This codebase is a fork of the OpenLTH codebase created by Facebook. For details on basic usage of the codebase, see https://github.com/facebookresearch/open_lth.
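To get started, clone this repository:
git clone https://github.com/sahibsin/Pruning.git
cd Pruning
To create a network to prune, use OpenLTH's train subcommand. For example, to train LeNet-300-100 on MNIST or ResNet-20 on CIFAR-10: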
python open_lth.py train --default_hparams=mnist_lenet_300_100
python open_lth.py train --default_hparams=cifar_resnet_20
If you wish to explore pruning at steps other than the beginning and end of training, you can add an additional flag to save the weights at other steps:
python open_lth.py train --default_hparams=cifar_resnet_20 --weight_save_steps=1000it,2000it,3000it
The above command will save the weights at iterations 1000, 2000, and 3000 for later use.
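Typing a long list of checkpoints by hand is error-prone. The helper below is a hypothetical convenience, not part of the codebase. It builds a value for --weight_save_steps in the same `<N>it` format used in the command above.

```python
def weight_save_steps(start: int, stop: int, every: int) -> str:
    """Build a comma-separated --weight_save_steps value such as '1000it,2000it'.

    Hypothetical helper: iterations run from `start` to `stop` inclusive,
    spaced `every` iterations apart, each written with the 'it' suffix.
    """
    if every <= 0:
        raise ValueError("every must be positive")
    return ",".join(f"{step}it" for step in range(start, stop + 1, every))


flag = "--weight_save_steps=" + weight_save_steps(1000, 3000, 1000)
print(flag)  # --weight_save_steps=1000it,2000it,3000it
```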
To prune a network, we use the branch functionality of OpenLTH. (We have refactored this functionality slightly to make it possible to create branches of training jobs.)
We have created a branch called oneshot that can be found in training/branch/oneshot_experiments.py.
This branch makes it possible to prune the network to various sparsities using each of the pruning methods.
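To show what "pruning to a given sparsity" means, here is a minimal pure-Python sketch of one-shot magnitude pruning. It makes two assumptions. First, --prune_fraction is the fraction of weights removed, which the 0.75 / 75% example below indicates. Second, weights are ranked globally across all layers. It is not the code in training/branch/oneshot_experiments.py, and `magnitude_prune` is a hypothetical name.

```python
import random


def magnitude_prune(layers: dict, prune_fraction: float) -> dict:
    """Return a 0/1 mask per layer that removes the `prune_fraction` smallest-magnitude weights.

    Assumption: weights are ranked globally (across all layers), not per layer.
    """
    # Flatten to (|w|, layer name, index) so every weight competes in one ranking.
    ranked = sorted(
        (abs(w), name, i) for name, ws in layers.items() for i, w in enumerate(ws)
    )
    num_pruned = int(round(prune_fraction * len(ranked)))
    masks = {name: [1] * len(ws) for name, ws in layers.items()}
    for _, name, i in ranked[:num_pruned]:
        masks[name][i] = 0  # 0 marks a pruned weight
    return masks


rng = random.Random(0)
layers = {
    "fc1": [rng.gauss(0.0, 1.0) for _ in range(300)],
    "fc2": [rng.gauss(0.0, 1.0) for _ in range(100)],
}
masks = magnitude_prune(layers, prune_fraction=0.75)
kept = sum(sum(m) for m in masks.values())
total = sum(len(m) for m in masks.values())
print(f"sparsity: {1 - kept / total:.2f}")  # sparsity: 0.75
```

Because the ranking is global, individual layers can end up with different sparsities even though the network as a whole reaches 75%.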
For example, the following commands prune each network to 75% sparsity with each of the available strategies.
LeNet-300-100 (MNIST):
python open_lth.py branch train oneshot --default_hparams=mnist_lenet_300_100 --strategy=magnitude --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=mnist_lenet_300_100 --strategy=random --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=mnist_lenet_300_100 --strategy=snip10 --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=mnist_lenet_300_100 --strategy=grasp10 --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=mnist_lenet_300_100 --strategy=graspabs10 --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=mnist_lenet_300_100 --strategy=synflow --prune_fraction=0.75 --prune_iterations=100
python open_lth.py branch train oneshot --default_hparams=mnist_lenet_300_100 --strategy=synflow --prune_fraction=0.75 --prune_iterations=10
ResNet-20 (CIFAR-10):
python open_lth.py branch train oneshot --default_hparams=cifar_resnet_20 --strategy=magnitude --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=cifar_resnet_20 --strategy=random --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=cifar_resnet_20 --strategy=snip10 --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=cifar_resnet_20 --strategy=grasp10 --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=cifar_resnet_20 --strategy=graspabs10 --prune_fraction=0.75
python open_lth.py branch train oneshot --default_hparams=cifar_resnet_20 --strategy=synflow --prune_fraction=0.75 --prune_iterations=100
For instance, the ResNet-20 command with --strategy=magnitude prunes the ResNet-20 created in the training step above to 75% sparsity using magnitude pruning at initialization. It then trains the pruned network normally from there.
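The example commands differ only in the model, the strategy, and (for SynFlow) --prune_iterations, so a short script can generate them. This sketch uses only the flags and values that appear in the commands above. The function name, the `DRY_RUN` switch, and running the jobs one after another with `subprocess.run` are illustrative choices, not part of the codebase.

```python
import subprocess

DRY_RUN = True  # set to False to actually launch the jobs, one after another

MODELS = ["mnist_lenet_300_100", "cifar_resnet_20"]
STRATEGIES = ["magnitude", "random", "snip10", "grasp10", "graspabs10", "synflow"]


def oneshot_command(model: str, strategy: str, prune_fraction: float) -> list:
    """Build the argv for one run of the oneshot branch."""
    cmd = [
        "python", "open_lth.py", "branch", "train", "oneshot",
        f"--default_hparams={model}",
        f"--strategy={strategy}",
        f"--prune_fraction={prune_fraction}",
    ]
    if strategy == "synflow":
        # The examples above run SynFlow iteratively with 100 pruning iterations.
        cmd.append("--prune_iterations=100")
    return cmd


commands = [oneshot_command(m, s, 0.75) for m in MODELS for s in STRATEGIES]
for cmd in commands:
    print(" ".join(cmd))
    if not DRY_RUN:
        subprocess.run(cmd, check=True)  # stop at the first failing job
```

Passing an argument list rather than a shell string avoids quoting problems if you later add flags with unusual values.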
The available values for the --strategy flag include:
- random
- magnitude
- snipN (uses N examples per class to compute the scores)
- graspN (uses N examples per class to compute the scores)
- graspabsN (uses N examples per class to compute the scores)
- synflow
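The N suffix on snipN, graspN, and graspabsN is the number of examples per class used to compute the scores. The hypothetical parser below is not the codebase's own. It shows how a value such as snip10 splits into a method name and N, and how many examples that amounts to on a 10-class dataset such as MNIST or CIFAR-10 (an assumption about the dataset).

```python
import re
from typing import Optional, Tuple

# Strategies that take no N suffix.
NO_SUFFIX = {"random", "magnitude", "synflow"}
# "graspabs" is listed before "grasp" so the regex tries the longer name first.
SUFFIXED = re.compile(r"(graspabs|grasp|snip)(\d+)")


def parse_strategy(strategy: str) -> Tuple[str, Optional[int]]:
    """Split a --strategy value into (method, examples_per_class).

    Hypothetical helper: returns None for methods that take no N suffix.
    """
    if strategy in NO_SUFFIX:
        return strategy, None
    match = SUFFIXED.fullmatch(strategy)
    if match is None:
        raise ValueError(f"unknown strategy: {strategy!r}")
    return match.group(1), int(match.group(2))


method, n = parse_strategy("snip10")
num_classes = 10  # assumption: a 10-class dataset such as MNIST or CIFAR-10
print(method, n, n * num_classes)  # snip 10 100
```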
By default, all of these methods will use one-shot pruning. To make the method iterative, set the --prune_iterations flag to the desired number of pruning iterations (e.g., 100 for SynFlow).
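The text does not say how the target fraction is divided among iterations. One common choice is the exponential schedule from the SynFlow paper, where each iteration keeps the same fraction of the remaining weights. With this schedule, the density after iteration k of n is (1 - p)^(k/n). The sketch below computes that schedule. Treat it as an assumption about this codebase, not a description of it.

```python
def density_schedule(prune_fraction: float, prune_iterations: int) -> list:
    """Fraction of weights remaining after each pruning iteration.

    Assumes an exponential schedule: iteration k of n leaves (1 - p) ** (k / n)
    of the weights, so the final density is exactly 1 - p.
    """
    if not 0.0 <= prune_fraction < 1.0:
        raise ValueError("prune_fraction must be in [0, 1)")
    if prune_iterations < 1:
        raise ValueError("prune_iterations must be >= 1")
    keep = 1.0 - prune_fraction
    return [keep ** (k / prune_iterations) for k in range(1, prune_iterations + 1)]


schedule = density_schedule(0.75, 100)
print(f"after 1 iteration: {schedule[0]:.4f}")
print(f"after 100 iterations: {schedule[-1]:.4f}")  # after 100 iterations: 0.2500
```

With prune_iterations=1 the schedule collapses to a single step straight to 25% density, which is the one-shot default.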
By default, this branch always prunes the weights with the lowest scores. For GraSP, this is the wrong behavior: GraSP should prune the weights with the highest scores. To prune the weights with the highest scores instead (or to invert a pruning method where appropriate), set the --prune_highest flag.
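In effect, --prune_highest changes which end of the score ranking is removed. In this hypothetical helper (not the branch's actual code), scores are sorted and either the lowest or the highest `prune_fraction` of them are marked for pruning.

```python
def pruned_indices(scores: list, prune_fraction: float, prune_highest: bool = False) -> set:
    """Indices of the weights to remove.

    By default the lowest-scoring weights are pruned; with prune_highest=True
    the highest-scoring weights are pruned instead (the behavior GraSP needs).
    """
    num_pruned = int(round(prune_fraction * len(scores)))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=prune_highest)
    return set(order[:num_pruned])


scores = [0.9, -0.3, 0.1, 0.5]
print(sorted(pruned_indices(scores, 0.5)))                      # [1, 2]
print(sorted(pruned_indices(scores, 0.5, prune_highest=True)))  # [0, 3]
```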
To prune using the state of the network at a different iteration, set the --prune_step and --state_step flags to the desired iteration (e.g., 1000it). This only works for iterations whose weights were saved during training (see --weight_save_steps above). The weights at step 0 and at the last step of training are saved by default.
To perform lottery ticket rewinding, set --prune_step to the last step of training and set --state_step to the desired rewinding iteration.
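As we read these two flags, rewinding has two steps. The mask is computed from the weights at --prune_step (here, the end of training), and the surviving weights are then reset to their values at --state_step. The sketch below illustrates this on a single flat layer, assuming magnitude pruning. The function name and weight values are hypothetical.

```python
def rewind(weights_at_prune_step: list, weights_at_state_step: list,
           prune_fraction: float) -> list:
    """Take the mask from the prune-step weights and the values from the rewind point."""
    if len(weights_at_prune_step) != len(weights_at_state_step):
        raise ValueError("both checkpoints must have the same shape")
    n = len(weights_at_prune_step)
    num_pruned = int(round(prune_fraction * n))
    # Rank by magnitude at the prune step; the smallest num_pruned are removed.
    order = sorted(range(n), key=lambda i: abs(weights_at_prune_step[i]))
    pruned = set(order[:num_pruned])
    # Surviving weights restart from their values at the state (rewind) step.
    return [0.0 if i in pruned else w for i, w in enumerate(weights_at_state_step)]


final = [0.8, -0.05, 0.02, -1.2]  # hypothetical weights at the last step of training
early = [0.3, 0.4, -0.1, -0.6]    # hypothetical weights at, e.g., 1000it
print(rewind(final, early, prune_fraction=0.5))  # [0.3, 0.0, 0.0, -0.6]
```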
To randomize the pruning mask layerwise, set the --randomize_layerwise flag.
To reinitialize the pruned network, set the --reinitialize flag.
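The flag names suggest two ablations from the pruning-at-initialization literature. The first shuffles each layer's pruning mask at random, which keeps every layer's sparsity but discards which specific weights were chosen. The second draws fresh values for the surviving weights. The sketch below illustrates both operations on a single layer. It reflects our assumption about what the flags do, not a reading of the code, and the function names are hypothetical.

```python
import random


def shuffle_mask_layerwise(mask: list, rng: random.Random) -> list:
    """Randomly permute one layer's 0/1 mask; the number of kept weights is unchanged."""
    shuffled = list(mask)
    rng.shuffle(shuffled)
    return shuffled


def reinitialize(mask: list, rng: random.Random, std: float = 1.0) -> list:
    """Draw fresh weights for kept positions; pruned positions stay at zero.

    `std` is a placeholder: a real reinitialization would use the model's own initializer.
    """
    return [rng.gauss(0.0, std) if m else 0.0 for m in mask]


rng = random.Random(0)
mask = [1, 1, 0, 0, 0, 0, 1, 0]
new_mask = shuffle_mask_layerwise(mask, rng)
new_weights = reinitialize(new_mask, rng)
print(sum(new_mask) == sum(mask))  # True: per-layer sparsity is preserved
```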
To use a standard normal initialization, add the flag --model_init=standard_normal both when training the network and when running the oneshot branch.
Available models (values for --default_hparams):
- mnist_lenet_300_100
- cifar_resnet_20
- cifar_vgg_16
- imagenet_resnet_50
- tinyimagenet_resnet_18
- tinyimagenet_modifiedresnet_18
Available datasets:
- mnist
- cifar10
- tinyimagenet (the version we use in the main body; you need to download it, install it according to datasets/tinyimagenet.py, and add it to platforms/local.py)
- tinyimagenet2 (the version we use for Modified ResNet-18; you need to download it, install it according to datasets/tinyimagenet.py, and add it to platforms/local.py)
- imagenet (you need to download it, install it according to datasets/tinyimagenet.py, and add it to platforms/local.py)
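To delete all saved experiment results, run the command below. The path is the author's local OpenLTH data directory; replace it with your own.
rm -rf /Users/sahib/open_lth_data2/*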