Official implementation of MMD-B-Fair, published at AISTATS 2023, by Namrata Deka and Danica J. Sutherland.

Requirements:
- Python 3.8.10
- PyTorch 1.12.1
- Torchvision 0.13.1
- Wandb 0.13.3

To train a model:

- Create a `.yml` config file and store it in `config` (see the available examples, VERY IMPORTANT). The file should specify all data, model and trainer settings/hyperparameters.
- Then run `python main.py -v <relative-path-to-yml-from-config> --seed <seed> -w`

  Example: `python main.py -v eq/adult/lambda_1.yml --seed 42 -w`
If you do not want to sync to wandb during training, add the option `-m offline` and sync at any time later with the `wandb sync` command.

If `--seed` is not specified, it defaults to `0`.
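The two documented arguments can be illustrated with a minimal `argparse` sketch. This is not the repository's `main.py`. The parser below only mirrors the behavior described above: `-v` takes a path relative to `config`, and `--seed` defaults to `0`. The `dest` name and the helper `resolve_config_path` are hypothetical. `-w` and `-m` are left out because their exact definitions are not documented here.

```python
import argparse
import os

# Assumption: config files live in ./config, as described above.
CONFIG_DIR = "config"


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring only the documented -v and --seed options.
    parser = argparse.ArgumentParser(description="Sketch of the MMD-B-Fair CLI")
    parser.add_argument("-v", dest="config", required=True,
                        help="path to the .yml file, relative to config/")
    parser.add_argument("--seed", type=int, default=0,
                        help="random seed (defaults to 0)")
    return parser


def resolve_config_path(relative_path: str) -> str:
    # Hypothetical helper: "eq/adult/lambda_1.yml" -> "config/eq/adult/lambda_1.yml"
    return os.path.join(CONFIG_DIR, relative_path)


if __name__ == "__main__":
    args = build_parser().parse_args(["-v", "eq/adult/lambda_1.yml"])
    print(resolve_config_path(args.config), args.seed)
```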
Trained models are saved in the location specified by `experiment.output_location`, in a subfolder named after the seed. In wandb, experiments are logged under `<config file name>/<seed>` in the `mmd-b-fair` workspace.
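To illustrate this layout, the hypothetical helper below computes where a run's outputs would end up, given a nested config dictionary. It makes two assumptions. First, `experiment.output_location` is a plain directory path. Second, `<config file name>` means the config file's base name without the `.yml` extension. That second point is a guess, so check the repository's code for the exact naming.

```python
import os
from typing import Any, Dict, Tuple


def run_locations(config: Dict[str, Any], config_file: str, seed: int) -> Tuple[str, str]:
    """Return (model_dir, wandb_run_path) for one run (hypothetical helper)."""
    # Trained models go in <experiment.output_location>/<seed>.
    model_dir = os.path.join(config["experiment"]["output_location"], str(seed))
    # Assumption: "<config file name>" is the base name without ".yml".
    config_name = os.path.splitext(os.path.basename(config_file))[0]
    wandb_run_path = f"{config_name}/{seed}"
    return model_dir, wandb_run_path


if __name__ == "__main__":
    cfg = {"experiment": {"output_location": "outputs/adult"}}  # hypothetical value
    print(run_locations(cfg, "eq/adult/lambda_1.yml", 42))
```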
This repository makes heavy use of the factory design pattern for modularity. To add new datasets, models and/or trainers, follow the steps below:
- Create the new data/model/trainer class under the appropriate directory. All trainer classes must inherit from `BaseTrainer`, and all models must inherit from `BaseModel`.
- Create corresponding builders for the new classes.
- Register all builder objects with the respective factories in `data/data.py`, `model/model.py` and `trainer/trainer.py`. The `data.data_key`, `model.model_key` and `model.trainer_key` entries in the config files must match the registered factory keys.
- Specify class-specific arguments in the config file. For example, dataset arguments must go in `data.args`.
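The registration steps above can be sketched in plain Python. Everything here is a hypothetical stand-in. The real `BaseModel`, `BaseTrainer`, builders and factories live in `model/model.py`, `trainer/trainer.py` and `data/data.py`, and they may use different class and method names. The sketch shows only the general idea:

- A builder is registered under a string key.
- The key read from the config (here `model.model_key`) selects which builder is called.
- Class-specific arguments come from an `args` section. `model.args` is assumed by analogy with the documented `data.args`.

```python
from typing import Any, Callable, Dict


class BaseModel:
    """Stand-in for the repository's BaseModel (interface assumed)."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class MyMLP(BaseModel):
    """A hypothetical new model class that inherits from BaseModel."""

    def __init__(self, hidden_dim: int = 64, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim


class MyMLPBuilder:
    """Hypothetical builder: turns config arguments into a model instance."""

    def __call__(self, **args: Any) -> BaseModel:
        return MyMLP(**args)


class ObjectFactory:
    """Minimal registry mapping string keys to builders."""

    def __init__(self) -> None:
        self._builders: Dict[str, Callable[..., Any]] = {}

    def register_builder(self, key: str, builder: Callable[..., Any]) -> None:
        self._builders[key] = builder

    def create(self, key: str, **args: Any) -> Any:
        # Fail loudly if the config key does not match any registered builder.
        if key not in self._builders:
            raise ValueError(f"Unknown key {key!r}; registered: {sorted(self._builders)}")
        return self._builders[key](**args)


# Registration (in the repository this happens in model/model.py).
model_factory = ObjectFactory()
model_factory.register_builder("my_mlp", MyMLPBuilder())

# A parsed config; "model.args" is assumed to mirror the documented "data.args".
config = {"model": {"model_key": "my_mlp", "args": {"hidden_dim": 128}}}
model = model_factory.create(config["model"]["model_key"], **config["model"]["args"])
print(type(model).__name__, model.hidden_dim)  # prints: MyMLP 128
```

Keeping construction behind string keys means that adding a class only touches its module and one registration line. The training entry point never needs to import the new class directly.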