Loss Functions for Ordinal Data and Importance Weighting #10
base: main
Conversation
…luded beyond a single key specifying the start of the Rosetta energy terms. This is problematic when many Rosetta energy functions are combined into one, or when the ordering of columns is not guaranteed.
…s to a warning telling the user that they are proceeding at their own risk
…oss function for both MSE and CORN loss
…les.py object does not need to be passed into the transfer_model, which is unserializable in hparams
…nd divide by zero errors.
code/datamodules.py (Outdated)
```python
                        type=str, default="")
parser.add_argument("--aux_input_num",
                    help="number of auxiliary inputs",
                    type=int,default=1)
```
We should be able to get the # of auxiliary inputs from the given aux_input_names, which would avoid needing a separate argument.
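A minimal sketch of this suggestion, assuming `aux_input_names` arrives as a comma-separated string (the actual parsing in `datamodules.py` may differ):

```python
# Hypothetical illustration: derive the count from the names themselves,
# making --aux_input_num redundant. Assumes a comma-separated string.
aux_input_names = "score,dG"  # placeholder value for --aux_input_names
aux_input_num = len([name for name in aux_input_names.split(",") if name])
print(aux_input_num)  # 2
```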
I reviewed the diff but didn't run any code. Overall, it looks good to me so far. There's a comment above about getting the number of auxiliary inputs from the given auxiliary input names instead of providing it as a separate argument. Another minor suggestion would be to follow the same coding style as the existing code / PEP8. The new code does this for the most part, but there are some places where an additional space needs to be inserted, e.g. after the comma in `type=int,default=1)`.
Ask Arnav to do a quick skim once inference and multilibrary CORAL are done.
TL;DR: This update gives METL finetuned models the ability to use an ordinal-specific loss (either CORN or CORAL loss), originally implemented by Sebastian Raschka here, along with other small updates. I recommend looking at the Loss Arguments section of `notebooks/finetuning.ipynb` for a comprehensive view of how to run these models with an ordinal-specific loss.

This PR is a work in progress, as CORAL still needs to be tested on a GPU (update May 31, 2025: tested) and multilibrary CORAL loss needs to be implemented. However, given that the auxiliary inputs are available to the CORAL layer, these updates should be minimal.
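For readers unfamiliar with these losses, here is a minimal, self-contained sketch of how CORN and CORAL are computed with Raschka's `coral_pytorch` package; the tensors and sizes are illustrative, not this PR's actual code:

```python
import torch
from coral_pytorch.losses import corn_loss, coral_loss
from coral_pytorch.dataset import levels_from_labelbatch, corn_label_from_logits

num_classes = 5
labels = torch.randint(0, num_classes, (8,))   # illustrative ordinal labels
logits = torch.randn(8, num_classes - 1)       # both losses use N-1 logits per sample

# CORN: conditional binary subtasks; takes raw integer labels directly.
loss_corn = corn_loss(logits, labels, num_classes=num_classes)
preds = corn_label_from_logits(logits)         # map logits back to integer classes

# CORAL: labels are first expanded into N-1 cumulative binary "levels".
levels = levels_from_labelbatch(labels, num_classes=num_classes)
loss_coral = coral_loss(logits, levels)
```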
Additionally, `inference.py` needs to be tested with these functions.

**compute_rosetta_standardization.py**
- `--columns2ignore` in `code/compute_rosetta_standardization.py`: these columns are not saved to the database. This is important for situations where columns after the chosen start column (via `energies_start_col`) should not be included. A sketch of the effect follows.
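A hedged sketch with hypothetical helper and variable names (the PR's actual implementation may differ):

```python
import pandas as pd

def select_energy_columns(df: pd.DataFrame,
                          energies_start_col: str,
                          columns2ignore: list[str]) -> pd.DataFrame:
    """Keep energy terms from the start column onward, minus ignored ones."""
    start = df.columns.get_loc(energies_start_col)
    energies = df.iloc[:, start:]                 # columns from the start key onward
    return energies.drop(columns=columns2ignore)  # never saved to the database
```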
**datamodules.py**

- `--num_classes`, required for all ordinal loss functions.
- `--use_importance_weights`, to get weights based on training-set class balance. You can set your own importance weights through `set_importance_weights`; this can be done for MSE, CORN, and CORAL loss. A sketch of one weighting scheme follows this list.
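A sketch of one common way such weights are derived (inverse class frequency); the exact scheme behind `--use_importance_weights` may differ, and note that `coral_pytorch`'s CORAL loss weights the N-1 binary tasks rather than the N classes:

```python
import torch

def class_balance_weights(train_labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical inverse-frequency weights over integer labels, normalized to mean 1."""
    counts = torch.bincount(train_labels, minlength=num_classes).float()
    weights = counts.sum() / (num_classes * counts.clamp(min=1.0))
    return weights / weights.mean()

# e.g. pass the result to the datamodule via set_importance_weights(...)
```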
**models.py**

- A CORAL top net (via `top_net_type`), which includes class-specific bias terms.
- `top_net_output_dim` is specific to CORN: there must be a task predicting the probability of a sample falling in each class, so the head takes dimension N-1, where N is the number of classes.
- `preinit_bias` is CORAL-specific and initializes the custom bias terms for better convergence in practice (although not tested with DMS data). A construction sketch follows this list.
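A sketch of how the head might branch on these arguments, using `CoralLayer` from `coral_pytorch` (the argument handling here is illustrative, not the PR's actual code):

```python
import torch.nn as nn
from coral_pytorch.layers import CoralLayer

def build_head(top_net_type: str, in_dim: int, num_classes: int,
               preinit_bias: bool = True) -> nn.Module:
    if top_net_type == "coral":
        # single shared weight vector plus N-1 class-specific bias terms
        return CoralLayer(size_in=in_dim, num_classes=num_classes,
                          preinit_bias=preinit_bias)
    if top_net_type == "corn":
        # plain linear head; top_net_output_dim = num_classes - 1
        return nn.Linear(in_dim, num_classes - 1)
    return nn.Linear(in_dim, 1)  # regression head for MSE
```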
**parse_rosetta_data.py**

- `int_cols` allows the user to specify which columns of the Rosetta energy terms are integers, as these columns must be explicitly defined due to an HDF saving error. A casting sketch follows.
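A sketch of the kind of explicit cast this implies before the HDF write (column names are hypothetical):

```python
import pandas as pd

def cast_int_cols(df: pd.DataFrame, int_cols: list[str]) -> pd.DataFrame:
    """Explicitly cast known-integer Rosetta energy columns."""
    df[int_cols] = df[int_cols].astype("int64")  # avoids mixed-type HDF save errors
    return df

# e.g. cast_int_cols(df, ["nres", "n_hbonds"]).to_hdf("rosetta.h5", key="energies")
```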
**tasks.py**

The majority of the updates are here.

- `--loss_func` allows the user to choose a loss from 'mse', 'corn', and 'coral'.
- `--corn_coral_log_feature`: this special flag only affects the logging process. All N-1 per-task probabilities (where N is the number of classes) and the true predictions are saved for CORAL and CORN; however, parity plots and other metrics, such as Spearman and Pearson correlation, are calculated on a single column. This flag tells the logging process which column to use (see the sketch below).
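A sketch of what the flag selects, with hypothetical tensors; for CORN the per-task sigmoids are conditional probabilities, so the actual logging code may apply a cumulative product first:

```python
import torch

num_classes = 5
logits = torch.randn(8, num_classes - 1)    # N-1 per-task outputs from the head
probas = torch.sigmoid(logits)              # per-task probabilities (all N-1 are logged)

corn_coral_log_feature = 0                  # hypothetical flag value
scores = probas[:, corn_coral_log_feature]  # single column used for parity plots / correlations
```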
**train_target_model.py & training_utils.py**

**finetuning.ipynb**