This project reproduces the major results of two papers, SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows, and Normalizing Flows with Multi-Scale Autoregressive Priors, and, as a stretch goal, implements the idea of ProNF. The reproduction is written with the JAX library. The original code for the first paper is available in this repository, and the code for the second paper in this repository. Both original implementations use PyTorch; this repository contains a JAX implementation of them.
The video below is an oral presentation that gives an overview of the project's scope and results.
oral_presentation.mp4
# Install the project's Python dependencies first.
pip install -r requirements.txt
# Pin JAX afterwards so this exact version is what ends up installed,
# regardless of what requirements.txt resolved.
pip install jax==0.2.8
# Install the "+cuda100" jaxlib build from the JAX release index passed via -f.
# NOTE(review): "+cuda100" presumably targets CUDA 10.0 — choose the jaxlib
# build that matches the local CUDA installation.
pip install jaxlib==0.1.56+cuda100 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Command for checkerboard:
# Toy density experiment on the checkerboard dataset (train_abs_unif.py).
# '[200,100]' is quoted so the shell passes it literally: unquoted, it is a glob
# bracket expression (zsh aborts with "no matches found"; bash substitutes a
# file name if a file called 0, 1, 2 or "," exists in the working directory).
python experiments/toy/train_abs_unif.py --hidden_units '[200,100]' --dataset checkerboard --clim 0.05
Command for corners:
# Toy density experiment on the corners dataset (train_abs_flow.py, softplus scale).
# '[200,100]' is quoted so the shell does not treat it as a glob pattern.
python experiments/toy/train_abs_flow.py --hidden_units '[200,100]' --dataset corners --clim 0.1 --scale_fn softplus
Command for eight_gaussians:
# Toy density experiment on the eight_gaussians dataset (train_abs_flow.py, softplus scale).
# '[200,100]' is quoted so the shell does not treat it as a glob pattern.
python experiments/toy/train_abs_flow.py --hidden_units '[200,100]' --dataset eight_gaussians --clim 0.15 --scale_fn softplus
Command for fourcircle:
# Toy density experiment on the fourcircle dataset (train_abs_flow.py, softplus scale).
# '[200,100]' is quoted so the shell does not treat it as a glob pattern.
python experiments/toy/train_abs_flow.py --hidden_units '[200,100]' --dataset fourcircle --clim 0.2 --scale_fn softplus
Command for pooling = none:
# CIFAR-10 max-pooling experiment with --pooling none; run name "nonpool".
python experiments/max_pooling/max_pooling_experiment.py \
  --epochs 500 --batch_size 32 \
  --optimizer adamax --lr 1e-4 --gamma 0.995 \
  --eval_every 1 --check_every 10 --warmup 5000 \
  --num_steps 12 --num_scales 2 --dequant flow \
  --pooling none \
  --dataset cifar10 --augmentation eta \
  --name nonpool --model_dir ./experiments/max_pooling/checkpoints/
Command for pooling = max:
# CIFAR-10 max-pooling experiment with --pooling max; run name "maxpool".
python experiments/max_pooling/max_pooling_experiment.py \
  --epochs 500 --batch_size 32 \
  --optimizer adamax --lr 1e-4 --gamma 0.995 \
  --eval_every 1 --check_every 10 --warmup 5000 \
  --num_steps 12 --num_scales 2 --dequant flow \
  --pooling max \
  --dataset cifar10 --augmentation eta \
  --name maxpool --model_dir ./experiments/max_pooling/checkpoints/
# MSAR-SCF training with the sigmoid activation for 3000 epochs.
# --resume True presumably continues from checkpoints in --ckptdir (confirm in the script).
python experiments/msar_scf/train_msar_scf.py \
  --ckptdir "experiments/msar_scf/ckpt_sigmoid" \
  --activation "sigmoid" \
  --resume True \
  --num_epochs 3000
# --- Stage: 16x16 -> 32x32 (no --input_res passed; relies on the script default) ---
python experiments/pro_nf/train_pronf.py \
  --ckptdir "experiments/pro_nf/ckpt_32" \
  --resume True \
  --warmup 50000 \
  --ms
# --- Stage: 8x8 -> 16x16 ---
python experiments/pro_nf/train_pronf.py \
  --ckptdir "experiments/pro_nf/ckpt_16" \
  --resume True \
  --warmup 50000 \
  --input_res 16 \
  --num_layers 2 \
  --ms \
  --learning_rate 1e-4
# --- Stage: 4x4 -> 8x8 ---
python experiments/pro_nf/train_pronf.py \
  --ckptdir "experiments/pro_nf/ckpt_8" \
  --resume True \
  --warmup 50000 \
  --input_res 8 \
  --num_layers 2 \
  --ms \
  --learning_rate 1e-4
# --- Stage: 4x4 unconditional (no --ms flag) ---
python experiments/pro_nf/train_pronf.py \
  --ckptdir "experiments/pro_nf/ckpt_4" \
  --resume True \
  --warmup 50000 \
  --input_res 4
# --- Chain-up: merge.py is pointed at the parent directory holding the
# per-stage ckpt_* folders written above ---
python experiments/pro_nf/merge.py \
  --ckptdir "experiments/pro_nf" \
  --resume True
# ProNF variant in experiments/pro_nf_2; all runs use CIFAR-10, batch size 32
# and the "eta" augmentation.
# 32x32 model:
python experiments/pro_nf_2/pro_nf.py \
  --batch_size 32 --augmentation eta --dataset cifar10 \
  --image_size 32
# 16x16 model with the --smallest flag:
python experiments/pro_nf_2/pro_nf.py \
  --batch_size 32 --augmentation eta --dataset cifar10 \
  --image_size 16 --smallest
# Resumed run over image sizes 32 and 16.
# NOTE(review): presumably chains the two models above — confirm in pro_nf.py.
python experiments/pro_nf_2/pro_nf.py \
  --batch_size 32 --augmentation eta --dataset cifar10 \
  --image_size 32 16 --resume