-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Zhifeng Kong
authored and
Zhifeng Kong
committed
Jun 2, 2021
1 parent
1cee44e
commit 6d82ebf
Showing
11 changed files
with
1,649 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
import torch | ||
import argparse | ||
|
||
from pytorch_fid.fid_score import calculate_fid_given_paths | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
# dataset and model | ||
parser.add_argument('-name', '--name', type=str, choices=["cifar10", "lsun_bedroom", "celeba64"], | ||
help='Name of experiment') | ||
parser.add_argument('-ema', '--ema', action='store_true', help='Whether use ema') | ||
|
||
# fast generation parameters | ||
parser.add_argument('-approxdiff', '--approxdiff', type=str, choices=['STD', 'STEP', 'VAR'], help='approximate diffusion process') | ||
parser.add_argument('-kappa', '--kappa', type=float, default=1.0, help='factor to be multiplied to sigma') | ||
parser.add_argument('-S', '--S', type=int, default=50, help='number of steps') | ||
parser.add_argument('-schedule', '--schedule', type=str, choices=['linear', 'quadratic'], help='noise level schedules') | ||
|
||
parser.add_argument('-gpu', '--gpu', type=int, default=0, help='gpu device') | ||
|
||
args = parser.parse_args() | ||
|
||
kwargs = {'batch_size': 50, 'device': torch.device('cuda:{}'.format(args.gpu)), 'dims': 2048} | ||
|
||
if args.approxdiff == 'STD': | ||
variance_schedule = '1000' | ||
else: | ||
variance_schedule = '{}{}'.format(args.S, args.schedule) | ||
folder = '{}{}_{}{}_kappa{}'.format('ema_' if args.ema else '', | ||
args.name, | ||
args.approxdiff, | ||
variance_schedule, | ||
args.kappa) | ||
if folder not in os.listdir('generated'): | ||
raise Exception('folder not found') | ||
|
||
paths = ['./generated/{}'.format(folder), | ||
'./pytorch_fid/{}_train_stat.npy'.format(args.name)] | ||
fid = calculate_fid_given_paths(paths=paths, **kwargs) | ||
print('{}: FID = {}'.format(folder, fid)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Official PyTorch implementation for "On Fast Sampling of Diffusion Probabilistic Models". | ||
FastDPM generation on CIFAR-10, CelebA, and LSUN datasets. See paper via [this link](https://arxiv.org/abs/2106.00132). | ||
|
||
# Pretrained models | ||
Download checkpoints from [this link](https://heibox.uni-heidelberg.de/d/01207c3f6b8441779abf/) and [this link](https://drive.google.com/file/d/1R_H-fJYXSH79wfSKs9D-fuKQVan5L-GR/view?usp=sharing). Put them under ```checkpoints\ema_diffusion_${dataset_name}_model\model.ckpt```, where ```${dataset_name}``` is ```cifar10```, ```celeba64```, ```lsun_bedroom```, ``lsun_church```, or ``lsun_cat```. | ||
|
||
# Usage | ||
- General command: ```python generate.py -ema -name ${dataset_name} -approxdiff ${approximate_diffusion_process} -kappa ${kappa} -S ${FastDPM_length} -schedule ${noise_level_schedule} -n ${number_to_generate} -bs ${batchsize} -gpu ${gpu_index}``` | ||
|
||
## CIFAR-10 | ||
Below are commands to generate CIFAR-10 images. | ||
- Standard DDPM generation: ```python generate.py -ema -name cifar10 -approxdiff STD -n 16 -bs 16``` | ||
- FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name cifar10 -approxdiff STEP -kappa 1.0 -S 50 -schedule quadratic -n 16 -bs 16``` | ||
- FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name cifar10 -approxdiff STEP -kappa 0.0 -S 50 -schedule quadratic -n 16 -bs 16``` | ||
- FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name cifar10 -approxdiff VAR -kappa 1.0 -S 50 -schedule quadratic -n 16 -bs 16``` | ||
- FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name cifar10 -approxdiff VAR -kappa 0.0 -S 50 -schedule quadratic -n 16 -bs 16``` | ||
|
||
## CelebA | ||
Below are commands to generate CelebA images. | ||
- Standard DDPM generation: ```python generate.py -ema -name celeba64 -approxdiff STD -n 16 -bs 16``` | ||
- FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name celeba64 -approxdiff STEP -kappa 1.0 -S 50 -schedule linear -n 16 -bs 16``` | ||
- FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name celeba64 -approxdiff STEP -kappa 0.0 -S 50 -schedule linear -n 16 -bs 16``` | ||
- FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name celeba64 -approxdiff VAR -kappa 1.0 -S 50 -schedule linear -n 16 -bs 16``` | ||
- FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name celeba64 -approxdiff VAR -kappa 0.0 -S 50 -schedule linear -n 16 -bs 16``` | ||
|
||
## LSUN_bedroom | ||
Below are commands to generate LSUN bedroom images. | ||
- Standard DDPM generation: ```python generate.py -ema -name lsun_bedroom -approxdiff STD -n 8 -bs 8``` | ||
- FastDPM generation (STEP + DDPM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff STEP -kappa 1.0 -S 50 -schedule linear -n 8 -bs 8``` | ||
- FastDPM generation (STEP + DDIM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff STEP -kappa 0.0 -S 50 -schedule linear -n 8 -bs 8``` | ||
- FastDPM generation (VAR + DDPM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff VAR -kappa 1.0 -S 50 -schedule linear -n 8 -bs 8``` | ||
- FastDPM generation (VAR + DDIM-rev): ```python generate.py -ema -name lsun_bedroom -approxdiff VAR -kappa 0.0 -S 50 -schedule linear -n 8 -bs 8``` | ||
|
||
## Note | ||
To generate 50K samples, set ```-n 50000``` and batchsize (```-bs```) divisible by 50K. | ||
|
||
# Compute FID | ||
To compute FID of generated samples, first make sure there are 50K images, and then run | ||
- ```python FID.py -ema -name cifar10 -approxdiff STEP -kappa 1.0 -S 50 -schedule quadratic``` | ||
|
||
# Code References | ||
- [DDPM TensorFlow official](https://github.com/hojonathanho/diffusion) | ||
- [DDPM PyTorch](https://github.com/pesser/pytorch_diffusion) | ||
- [DDPM CelebA-HQ](https://github.com/FengNiMa/pytorch_diffusion_model_celebahq) | ||
- [DDIM PyTorch](https://github.com/ermongroup/ddim) | ||
- [FID PyTorch](https://github.com/mseitzer/pytorch-fid) | ||
- [DiffWave PyTorch 1](https://github.com/lmnt-com/diffwave) | ||
- [DiffWave PyTorch 2](https://github.com/philsyn/DiffWave-Vocoder) | ||
- [DiffWave PyTorch 3](https://github.com/philsyn/DiffWave-unconditional) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
cifar10_cfg = { | ||
"resolution": 32, | ||
"in_channels": 3, | ||
"out_ch": 3, | ||
"ch": 128, | ||
"ch_mult": (1,2,2,2), | ||
"num_res_blocks": 2, | ||
"attn_resolutions": (16,), | ||
"dropout": 0.1, | ||
} | ||
|
||
lsun_cfg = { | ||
"resolution": 256, | ||
"in_channels": 3, | ||
"out_ch": 3, | ||
"ch": 128, | ||
"ch_mult": (1,1,2,2,4,4), | ||
"num_res_blocks": 2, | ||
"attn_resolutions": (16,), | ||
"dropout": 0.0, | ||
} | ||
|
||
celeba64_cfg = { | ||
"resolution": 64, | ||
"in_channels": 3, | ||
"out_ch": 3, | ||
"ch": 128, | ||
"ch_mult": (1,2,2,2,4), | ||
"num_res_blocks": 2, | ||
"attn_resolutions": (16,), | ||
"dropout": 0.1, | ||
} | ||
|
||
model_config_map = { | ||
"cifar10": cifar10_cfg, | ||
"lsun_bedroom": lsun_cfg, | ||
"lsun_cat": lsun_cfg, | ||
"lsun_church": lsun_cfg, | ||
"celeba64": celeba64_cfg | ||
} | ||
|
||
diffusion_config = { | ||
"beta_0": 0.0001, | ||
"beta_T": 0.02, | ||
"T": 1000, | ||
} | ||
|
||
model_var_type_map = { | ||
"cifar10": "fixedlarge", | ||
"lsun_bedroom": "fixedsmall", | ||
"lsun_cat": "fixedsmall", | ||
"lsun_church": "fixedsmall", | ||
} |
Oops, something went wrong.