This is the official implementation of ACFlow.
For the dependencies, refer to requirements.txt.
Download CelebA, CIFAR10, MNIST and Omniglot to your local workspace. You may need to update the path for each dataset in the datasets folder accordingly.
MNIST and CIFAR10 can be downloaded with torchvision. Links for CelebA and Omniglot are provided here. Please cite their work if you use this repo.
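For MNIST and CIFAR10, here is a minimal download sketch. It uses the `torchvision.datasets.MNIST` and `torchvision.datasets.CIFAR10` classes, which accept `root`, `train` and `download`. The helper name `download_with_torchvision` and the per-dataset subfolders under `root` are assumptions, not the repo's layout. After downloading, point the paths in the datasets folder at wherever the files ended up.

```python
import os


def download_with_torchvision(root, dry_run=False):
    """Fetch the MNIST and CIFAR10 train/test splits with torchvision.

    Hypothetical helper: the per-dataset subfolders under `root` are an
    assumption. Returns the list of (dataset name, directory, train flag) jobs.
    """
    jobs = [
        (name, os.path.join(root, name.lower()), train)
        for name in ("MNIST", "CIFAR10")
        for train in (True, False)
    ]
    if dry_run:
        # Only report what would be downloaded; do not touch the network.
        return jobs
    # Imported lazily so a dry run works without torchvision installed.
    from torchvision import datasets

    for name, directory, train in jobs:
        os.makedirs(directory, exist_ok=True)
        # Both dataset classes accept root, train and download keyword arguments.
        getattr(datasets, name)(root=directory, train=train, download=True)
    return jobs
```

For example, `download_with_torchvision("./data")` fetches both splits of both datasets, while `dry_run=True` only lists the planned jobs.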
You can train your own models with the scripts provided below, or you can download our pretrained weights from here.
Commands for CelebA (config files in exp/celeba):
- Train with a Gaussian base likelihood:
  python scripts/train.py --cfg_file=./exp/celeba/rnvp/params.json
- Train with an autoregressive likelihood:
  python scripts/train_tan.py --cfg_file=./exp/celeba/tan/params.json
- Compute the log likelihood on the test set, and the PSNR and PRD scores using samples:
  python scripts/test.py --cfg_file=./exp/celeba/rnvp/params.json
  NOTE: you can run this script multiple times with different random seeds to obtain the mean and standard deviation of the scores.
- Compute the joint likelihood p(x):
  python scripts/test_joint.py --cfg_file=./exp/celeba/rnvp/params.json
- Sample from an arbitrary conditional distribution p(x_u | x_o) for multiple imputation:
  python scripts/sample.py --cfg_file=./exp/celeba/rnvp/params.json
- Sample the 'Best Guess' single imputation:
  python scripts/sample_single.py --cfg_file=./exp/celeba/rnvp/params.json
- Sample from the joint distribution p(x):
  python scripts/sample_joint.py --cfg_file=./exp/celeba/rnvp/params.json
- Gibbs sampling, which samples the upper and lower halves of an image, each conditioned on the remaining half:
  python scripts/gibbs_sampling.py --cfg_file=./exp/celeba/rnvp/params.json
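The test.py step reports PSNR, and its note suggests repeating runs with different seeds. Below is a minimal sketch of those two computations, not the repo's own code. `psnr` uses the standard definition 10 * log10(MAX^2 / MSE), and `summarize_runs` reports the mean and sample standard deviation. Both function names are hypothetical. The pixel range (`max_val`) and whether only the unobserved pixels are scored (`mask`) are assumptions, because the README does not say how test.py handles them.

```python
import math
import statistics


def psnr(reference, imputed, max_val=1.0, mask=None):
    """Peak signal-to-noise ratio between two flat sequences of pixel values.

    max_val is the largest possible pixel value (1.0 for images scaled to
    [0, 1], 255 for 8-bit images). If mask is given, only positions where
    mask is truthy (e.g. the unobserved pixels) contribute to the error.
    """
    if mask is None:
        mask = [1] * len(reference)
    errors = [(r - x) ** 2 for r, x, m in zip(reference, imputed, mask) if m]
    if not errors:
        raise ValueError("mask selects no pixels")
    mse = sum(errors) / len(errors)
    if mse == 0:
        # Identical pixels: PSNR is unbounded.
        return math.inf
    return 10.0 * math.log10(max_val ** 2 / mse)


def summarize_runs(scores):
    """Mean and sample standard deviation of one score across seeded runs."""
    mean = statistics.mean(scores)
    std = statistics.stdev(scores) if len(scores) > 1 else 0.0
    return mean, std
```

For example, PSNR values collected from several seeded runs can be passed to `summarize_runs` to report a mean and standard deviation.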
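The Gibbs sampling step alternates between the two halves of an image. The following sketch shows that blocked Gibbs loop in plain Python and rests on three assumptions. First, `sample_conditional(x, mask)` is a hypothetical stand-in for the trained model's conditional sampler. Second, masks use 1 for observed pixels and 0 for unobserved pixels. Third, `make_toy_sampler` only draws uniform noise, so the loop runs without a model. None of this is the interface of scripts/gibbs_sampling.py.

```python
import random


def gibbs_halves(x, sample_conditional, num_steps):
    """Alternately resample the upper and lower half of an image.

    x is a list of rows. sample_conditional(x, mask) is a hypothetical
    stand-in for a trained model: it must return a new image that keeps
    pixels with mask == 1 (observed) and draws pixels with mask == 0 from
    p(x_u | x_o). Returns the chain of images, one entry per half-step.
    """
    height, width = len(x), len(x[0])
    half = height // 2
    # First mask: upper half unobserved (0), lower half observed (1).
    upper_unobserved = [[0 if r < half else 1 for _ in range(width)] for r in range(height)]
    # Second mask: the complement, so the lower half is resampled.
    lower_unobserved = [[1 - m for m in row] for row in upper_unobserved]
    chain = []
    for _ in range(num_steps):
        for mask in (upper_unobserved, lower_unobserved):
            x = sample_conditional(x, mask)
            chain.append(x)
    return chain


def make_toy_sampler(seed=0):
    """Toy conditional sampler (not a real model), for running the loop."""
    rng = random.Random(seed)

    def sample(x, mask):
        # Keep observed pixels and draw unobserved ones uniformly in [0, 1).
        return [[v if m else rng.random() for v, m in zip(row, mrow)]
                for row, mrow in zip(x, mask)]

    return sample
```

Each half-step conditions on the most recent values of the other half. With a learned model, the chain therefore targets the joint distribution only to the extent that the model's conditionals are mutually consistent.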
For MNIST, similar commands can be run with the config files provided in the exp/mnist folder.
For Omniglot, similar commands can be run with the config files provided in the exp/omniglot folder.
For CIFAR10, similar commands can be run with the config files provided in the exp/cifar folder.
The code for evaluating FID and PRD is adapted from their public implementations. Please cite their work if you use this repo.
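FID compares Gaussians fitted to feature embeddings of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The sketch below computes only this distance with NumPy and SciPy. It leaves out the Inception feature extractor that public FID implementations use, so `feats_real` and `feats_fake` are assumed to be precomputed (n_samples, n_features) arrays. The function names are hypothetical.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Matrix square root of the covariance product.
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    # sqrtm can return tiny imaginary parts from round-off; keep the real part.
    covmean = np.real(covmean)
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * np.trace(covmean))


def fid_from_features(feats_real, feats_fake):
    """FID from precomputed (n_samples, n_features) embedding arrays."""
    mu_r, sigma_r = feats_real.mean(axis=0), np.cov(feats_real, rowvar=False)
    mu_f, sigma_f = feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False)
    return frechet_distance(mu_r, sigma_r, mu_f, sigma_f)
```

A quick check of the formula: two Gaussians with diagonal covariances diag(1, 4) and diag(4, 9) and equal means have distance (1 - 2)^2 + (2 - 3)^2 = 2.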