Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhifeng Kong authored and Zhifeng Kong committed Jun 2, 2021
1 parent 1cee44e commit 6d82ebf
Show file tree
Hide file tree
Showing 11 changed files with 1,649 additions and 0 deletions.
41 changes: 41 additions & 0 deletions FID.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import torch
import argparse

from pytorch_fid.fid_score import calculate_fid_given_paths

if __name__ == '__main__':
parser = argparse.ArgumentParser()
# dataset and model
parser.add_argument('-name', '--name', type=str, choices=["cifar10", "lsun_bedroom", "celeba64"],
help='Name of experiment')
parser.add_argument('-ema', '--ema', action='store_true', help='Whether use ema')

# fast generation parameters
parser.add_argument('-approxdiff', '--approxdiff', type=str, choices=['STD', 'STEP', 'VAR'], help='approximate diffusion process')
parser.add_argument('-kappa', '--kappa', type=float, default=1.0, help='factor to be multiplied to sigma')
parser.add_argument('-S', '--S', type=int, default=50, help='number of steps')
parser.add_argument('-schedule', '--schedule', type=str, choices=['linear', 'quadratic'], help='noise level schedules')

parser.add_argument('-gpu', '--gpu', type=int, default=0, help='gpu device')

args = parser.parse_args()

kwargs = {'batch_size': 50, 'device': torch.device('cuda:{}'.format(args.gpu)), 'dims': 2048}

if args.approxdiff == 'STD':
variance_schedule = '1000'
else:
variance_schedule = '{}{}'.format(args.S, args.schedule)
folder = '{}{}_{}{}_kappa{}'.format('ema_' if args.ema else '',
args.name,
args.approxdiff,
variance_schedule,
args.kappa)
if folder not in os.listdir('generated'):
raise Exception('folder not found')

paths = ['./generated/{}'.format(folder),
'./pytorch_fid/{}_train_stat.npy'.format(args.name)]
fid = calculate_fid_given_paths(paths=paths, **kwargs)
print('{}: FID = {}'.format(folder, fid))
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Official PyTorch implementation for "On Fast Sampling of Diffusion Probabilistic Models".
FastDPM generation on CIFAR-10, CelebA, and LSUN datasets. See paper via [this link](https://arxiv.org/abs/2106.00132).

# Pretrained models
Download checkpoints from [this link](https://heibox.uni-heidelberg.de/d/01207c3f6b8441779abf/) and [this link](https://drive.google.com/file/d/1R_H-fJYXSH79wfSKs9D-fuKQVan5L-GR/view?usp=sharing). Put them under ```checkpoints\ema_diffusion_${dataset_name}_model\model.ckpt```, where ```${dataset_name}``` is ```cifar10```, ```celeba64```, ```lsun_bedroom```, ``lsun_church```, or ``lsun_cat```.

# Usage
- General command: ```python generate.py -ema -name ${dataset_name} -approxdiff ${approximate_diffusion_process} -kappa ${kappa} -S ${FastDPM_length} -schedule ${noise_level_schedule} -n ${number_to_generate} -bs ${batchsize} -gpu ${gpu_index}```

## CIFAR-10
Below are commands to generate CIFAR-10 images.
- Standard DDPM generation: ```python generate.py -ema -name cifar10 -approxdiff STD -n 16 -bs 16```
- FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name cifar10 -approxdiff STEP -kappa 1.0 -S 50 -schedule quadratic -n 16 -bs 16```
- FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name cifar10 -approxdiff STEP -kappa 0.0 -S 50 -schedule quadratic -n 16 -bs 16```
- FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name cifar10 -approxdiff VAR -kappa 1.0 -S 50 -schedule quadratic -n 16 -bs 16```
- FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name cifar10 -approxdiff VAR -kappa 0.0 -S 50 -schedule quadratic -n 16 -bs 16```

## CelebA
Below are commands to generate CelebA images.
- Standard DDPM generation: ```python generate.py -ema -name celeba64 -approxdiff STD -n 16 -bs 16```
- FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name celeba64 -approxdiff STEP -kappa 1.0 -S 50 -schedule linear -n 16 -bs 16```
- FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name celeba64 -approxdiff STEP -kappa 0.0 -S 50 -schedule linear -n 16 -bs 16```
- FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name celeba64 -approxdiff VAR -kappa 1.0 -S 50 -schedule linear -n 16 -bs 16```
- FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name celeba64 -approxdiff VAR -kappa 0.0 -S 50 -schedule linear -n 16 -bs 16```

## LSUN_bedroom
Below are commands to generate LSUN bedroom images.
- Standard DDPM generation: ```python generate.py -ema -name lsun_bedroom -approxdiff STD -n 8 -bs 8```
- FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff STEP -kappa 1.0 -S 50 -schedule linear -n 8 -bs 8```
- FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff STEP -kappa 0.0 -S 50 -schedule linear -n 8 -bs 8```
- FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff VAR -kappa 1.0 -S 50 -schedule linear -n 8 -bs 8```
- FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff VAR -kappa 0.0 -S 50 -schedule linear -n 8 -bs 8```

## Note
To generate 50K samples, set ```-n 50000``` and batchsize (```-bs```) divisible by 50K.

# Compute FID
To compute FID of generated samples, first make sure there are 50K images, and then run
- ```python FID.py -ema -name cifar10 -approxdiff STEP -kappa 1.0 -S 50 -schedule quadratic```

# Code References
- [DDPM TensorFlow official](https://github.com/hojonathanho/diffusion)
- [DDPM PyTorch](https://github.com/pesser/pytorch_diffusion)
- [DDPM CelebA-HQ](https://github.com/FengNiMa/pytorch_diffusion_model_celebahq)
- [DDIM PyTorch](https://github.com/ermongroup/ddim)
- [FID PyTorch](https://github.com/mseitzer/pytorch-fid)
- [DiffWave PyTorch 1](https://github.com/lmnt-com/diffwave)
- [DiffWave PyTorch 2](https://github.com/philsyn/DiffWave-Vocoder)
- [DiffWave PyTorch 3](https://github.com/philsyn/DiffWave-unconditional)
53 changes: 53 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
cifar10_cfg = {
"resolution": 32,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": (1,2,2,2),
"num_res_blocks": 2,
"attn_resolutions": (16,),
"dropout": 0.1,
}

lsun_cfg = {
"resolution": 256,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": (1,1,2,2,4,4),
"num_res_blocks": 2,
"attn_resolutions": (16,),
"dropout": 0.0,
}

celeba64_cfg = {
"resolution": 64,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": (1,2,2,2,4),
"num_res_blocks": 2,
"attn_resolutions": (16,),
"dropout": 0.1,
}

model_config_map = {
"cifar10": cifar10_cfg,
"lsun_bedroom": lsun_cfg,
"lsun_cat": lsun_cfg,
"lsun_church": lsun_cfg,
"celeba64": celeba64_cfg
}

diffusion_config = {
"beta_0": 0.0001,
"beta_T": 0.02,
"T": 1000,
}

model_var_type_map = {
"cifar10": "fixedlarge",
"lsun_bedroom": "fixedsmall",
"lsun_cat": "fixedsmall",
"lsun_church": "fixedsmall",
}
Loading

0 comments on commit 6d82ebf

Please sign in to comment.