This repository contains the PyTorch implementation for the paper "SLACK: Stable Learning of Augmentations with Cold-start and KL regularization", which was published in the Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) in June 2023.
If you find this work useful for your research, please cite the paper using the following BibTeX entry:
@InProceedings{Marrie_2023_CVPR,
author = {Marrie, Juliette and Arbel, Michael and Larlus, Diane and Mairal, Julien},
title = {SLACK: Stable Learning of Augmentations with Cold-start and KL regularization},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2023}
}
SLACK is evaluated in three phases:
- A pretraining phase
- A search phase, where a policy is found
- An evaluation phase, where a network is trained with the found policy on the full training data, with $n_{train}$ independent runs (seeds).
The final reported accuracy is the average over the
Create a conda environment:
conda create -n slack python=3.9.7
Install the dependencies:
pip install -r requirements.txt
Before running the scripts, set the root directory:
export ROOT=$HOME/SLACK/slack
The data should be contained in $ROOT/data/datasets.
The search logs are saved in $ROOT/data/outputs/[SEARCH_DIR], in a SEARCH_DIR directory specific to each search experiment (and automatically created). They contain the augmentation model in different formats: as checkpoint (models.ckpt), as numpy files (in pi, mu), as .txt (in genotype) and with simple visualizations (in plt_genotype). Also, metadata.yaml reports the hyperparameters used for the search and val.txt reports validation and training metrics.
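After a search finishes, the outputs listed above can be checked programmatically. The following is a minimal sketch and not part of the repository: the helper name `missing_search_outputs` is hypothetical, and it assumes only that the entries named above sit directly inside the search directory.

```python
from pathlib import Path

# Entries that a search directory is described as containing.
# Assumption: they all sit directly under [SEARCH_DIR].
EXPECTED_ENTRIES = (
    "models.ckpt",    # augmentation model checkpoint
    "pi",             # numpy files
    "mu",             # numpy files
    "genotype",       # .txt description of the policy
    "plt_genotype",   # simple visualizations
    "metadata.yaml",  # hyperparameters used for the search
    "val.txt",        # validation and training metrics
)


def missing_search_outputs(search_dir):
    """Return the expected entries that are absent from search_dir."""
    root = Path(search_dir)
    return [name for name in EXPECTED_ENTRIES if not (root / name).exists()]
```

An empty list means that every entry described above is present in the given directory.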
The (pre)training is performed using TrivialAugment's framework, located in $ROOT/TrivialAugment.
The (pre)training logs are saved in $ROOT/TrivialAugment/logs and can be evaluated with the $ROOT/TrivialAugment/aggregate_results.py script.
python aggregate_results.py --logdirs [DIRS] --split [train|test] --metric [top1|top5|loss] --step [STEP]
The pretraining checkpoint is saved directly in $ROOT/data/outputs/[SEARCH_DIR]. The other checkpoints are saved in $ROOT/TrivialAugment/ckpt.
Evaluate our policy on SPLIT with 8 seeds:
sh $ROOT/scripts/cifar/train.sh [10|100] [40x2|28x10] [SPLIT] git-policies
Evaluate our Uniform policy with 8 seeds:
sh $ROOT/scripts/cifar/train.sh [10|100] [40x2|28x10] uniform
Test accuracies on CIFAR (average over
| Method | CIFAR10 WRN-40-2 | CIFAR10 WRN-28-10 | CIFAR100 WRN-40-2 | CIFAR100 WRN-28-10 |
|---|---|---|---|---|
| TA (Wide) | 96.32 | 97.46 | 79.76 | 84.33 |
| Uniform policy | 96.12 | 97.26 | 78.79 | 82.82 |
| SLACK | 96.29 | 97.46 | 79.87 | 84.08 |
Best policies found for CIFAR10 (1,2) and CIFAR100 (3,4) with WRN-40-2 (1,3) and WRN-28-10 (2,4).
Class IDs for ImageNet-100 can be found here.
Evaluate our policy on SPLIT with SEED
sh $ROOT/scripts/imagenet100/train.sh [SPLIT] [SEED] git-policies
Evaluate our Uniform policy with SEED
sh $ROOT/scripts/imagenet100/train.sh uniform [SEED]
Evaluate TrivialAugment with SEED
sh $ROOT/scripts/imagenet100/train-TA.sh [ta|ta_wide] [SEED]
Test accuracies on ImageNet-100 (average over
| Method | ImageNet-100, ResNet-18 |
|---|---|
| TA (RA) | 85.87 |
| TA (Wide) | 86.39 |
| Uniform policy | 85.78 |
| SLACK | 86.06 |
Best policy found for ImageNet-100.
DomainNet is a dataset commonly used for domain generalization that contains 345 classes of images from six different domains: painting, clipart, sketch, infograph, quickdraw, real. It can be downloaded from the DomainBed suite.
Download
We evaluate on the six domains. For the two largest (real, quickdraw), we use a reduced training set of 50,000 images and the remainder of the data for testing. For the other domains, we hold out 20% of the data for testing. The filenames belonging to the train/test splits are stored in $ROOT/domain_net_splits/.
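The splitting rule described above can be sketched as follows. This is an illustration only: the splits actually used are the filename lists stored in the repository, and the helper `split_filenames`, its seed, and its shuffling procedure are assumptions rather than the authors' procedure.

```python
import random


def split_filenames(filenames, max_train=None, test_fraction=0.2, seed=0):
    """Split filenames into (train, test) lists.

    If max_train is given (e.g. 50000 for real and quickdraw), at most that
    many files go to training and the rest to testing. Otherwise a fraction
    test_fraction of the files is held out for testing.
    """
    # Assumption: sort first so the result does not depend on input order,
    # then shuffle with a fixed seed so the split is reproducible.
    shuffled = sorted(filenames)
    random.Random(seed).shuffle(shuffled)
    if max_train is not None:
        n_train = min(max_train, len(shuffled))
    else:
        n_train = len(shuffled) - round(test_fraction * len(shuffled))
    return shuffled[:n_train], shuffled[n_train:]
```

With 1,000 hypothetical filenames, `split_filenames(names)` yields 800 training and 200 testing files, while `split_filenames(names, max_train=50)` yields 50 training files and 950 testing files.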
Use the following script to separate the data into training and testing folders for a DOMAIN from {painting, clipart, sketch, infograph, quickdraw, real}:
python domain_net_splits/split_dataset.py --data_dir data/datasets/domain_net/[DOMAIN] --train_id domain_net_splits/npz/[DOMAIN]_train.npz
Evaluation
Evaluate our policy on SPLIT with SEED
sh $ROOT/scripts/domainnet/train.sh [DOMAIN] [SPLIT] [SEED] git-policies
Evaluate our Uniform policy with SEED
sh $ROOT/scripts/domainnet/train.sh [DOMAIN] uniform [SEED]
Evaluate TrivialAugment with SEED
sh $ROOT/scripts/domainnet/train-TA.sh [DOMAIN] [ta_imagenet|ta_cifar|ta_imagenet_wide|ta_cifar_wide] [SEED]
Test accuracies on DomainNet (average over
| Method | Real-50k | Quickdraw-50k | Infograph | Sketch | Painting | Clipart | Average |
|---|---|---|---|---|---|---|---|
| DomainBed | 62.54 | 66.54 | 26.76 | 59.54 | 58.31 | 66.23 | 57.23 |
| TA (RA) ImageNet | 70.85 | 67.85 | 35.24 | 65.63 | 64.75 | 70.29 | 62.43 |
| TA (Wide) ImageNet | 71.56 | 68.60 | 35.44 | 66.21 | 65.15 | 71.19 | 63.03 |
| TA (RA) CIFAR | 70.28 | 68.35 | 33.85 | 64.13 | 64.73 | 70.33 | 61.94 |
| TA (Wide) CIFAR | 71.12 | 69.29 | 34.21 | 65.52 | 64.81 | 71.01 | 62.66 |
| Uniform policy | 70.37 | 68.27 | 34.11 | 65.22 | 63.97 | 72.26 | 62.37 |
| SLACK | 71.00 | 68.14 | 34.78 | 65.41 | 64.83 | 72.65 | 62.80 |
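For the SLACK row, the Average column matches the unweighted mean of the six per-domain accuracies. A short worked check, using only the numbers from the table above:

```python
from statistics import mean

# Per-domain test accuracies of the SLACK row, copied from the table above.
slack = {
    "Real-50k": 71.00,
    "Quickdraw-50k": 68.14,
    "Infograph": 34.78,
    "Sketch": 65.41,
    "Painting": 64.83,
    "Clipart": 72.65,
}

# Unweighted mean over the six domains, rounded to two decimals.
average = round(mean(slack.values()), 2)
print(f"{average:.2f}")  # 62.80
```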
Best policies found for Sketch (left), Clipart (center) and Painting (right).
- Pretrain on SPLIT
  sh $ROOT/scripts/cifar/pretrain.sh [10|100] [40x2|28x10] [SPLIT]
- Search on SPLIT
  sh $ROOT/scripts/cifar/search.sh [10|100] [40x2|28x10] [SPLIT]
- Evaluate on 4 seeds for SPLIT
  sh $ROOT/scripts/cifar/train.sh [10|100] [40x2|28x10] [SPLIT]
You can also run our ablations (no-kl, warm-start, unrolled, pi-only, mu-only):
sh $ROOT/scripts/cifar/search.sh [10|100] [40x2|28x10] [SPLIT] [ABLATION]
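The CIFAR steps above can also be driven from Python. This is a sketch under assumptions, not a script shipped with the repository: `cifar_commands` is a hypothetical helper that only builds the argument lists for the shell scripts named above. SPLIT is passed through unchanged, and the ablation name is appended to the search command only, since that is the only place it appears above.

```python
import os

DATASETS = ("10", "100")
ARCHS = ("40x2", "28x10")
ABLATIONS = ("no-kl", "warm-start", "unrolled", "pi-only", "mu-only")


def cifar_commands(dataset, arch, split, ablation=None, root=None):
    """Build the pretrain, search and train commands for one CIFAR setting."""
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}")
    if arch not in ARCHS:
        raise ValueError(f"arch must be one of {ARCHS}")
    if ablation is not None and ablation not in ABLATIONS:
        raise ValueError(f"ablation must be one of {ABLATIONS}")
    # Assumption: ROOT is exported in the environment unless root is given.
    root = root if root is not None else os.environ["ROOT"]
    scripts = f"{root}/scripts/cifar"
    search = ["sh", f"{scripts}/search.sh", dataset, arch, split]
    if ablation is not None:
        search.append(ablation)
    return [
        ["sh", f"{scripts}/pretrain.sh", dataset, arch, split],
        search,
        ["sh", f"{scripts}/train.sh", dataset, arch, split],
    ]
```

Each returned list can be passed to `subprocess.run(cmd, check=True)` to execute the steps in order.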
- Pretrain on SPLIT
  sh $ROOT/scripts/imagenet100/pretrain.sh [SPLIT]
- Search on SPLIT
  sh $ROOT/scripts/imagenet100/search.sh [SPLIT]
- Evaluate on SPLIT with SEED
  sh $ROOT/scripts/imagenet100/train.sh [SPLIT] [SEED]
- Pretrain on SPLIT
  sh $ROOT/scripts/domainnet/pretrain.sh [DOMAIN] [SPLIT]
- Search on SPLIT
  sh $ROOT/scripts/domainnet/search.sh [DOMAIN] [SPLIT]
- Evaluate on SPLIT with SEED
  sh $ROOT/scripts/domainnet/train.sh [DOMAIN] [SPLIT] [SEED]