-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathsample_fitv2_ddp.py
272 lines (239 loc) · 12.2 KB
/
sample_fitv2_ddp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Samples a large number of images from a pre-trained FiT model using DDP.
Subsequently saves a .npz file that can be used to compute FID and other
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
For a simple single-GPU/CPU sampling script, see sample.py.
"""
import os
import sys
import math
import torch
import argparse
import numpy as np
import torch.distributed as dist
import re
from omegaconf import OmegaConf
from tqdm import tqdm
from PIL import Image
from diffusers.models import AutoencoderKL
from fit.scheduler.transport import create_transport, Sampler
from fit.utils.eval_utils import create_npz_from_sample_folder, init_from_ckpt
from fit.utils.utils import instantiate_from_config
from fit.utils.sit_eval_utils import parse_sde_args, parse_ode_args
def ntk_scaled_init(head_dim, base=10000, alpha=8):
    """Return an NTK-scaled rotary-embedding base for a 2-D (x/y split) RoPE.

    The head dimension is shared between the x and y axes, so the per-axis
    dimension is ``head_dim // 2``. The base is then rescaled with the
    change-of-base formula ``base * alpha ** (d / (d - 2))``.

    Args:
        head_dim: attention head dimension.
        base: original RoPE base.
        alpha: scaling factor.

    Returns:
        The rescaled base (float).

    Raises:
        ZeroDivisionError: when ``head_dim // 2 == 2``.
    """
    per_axis_dim = head_dim // 2  # half for x, half for y
    exponent = per_axis_dim / (per_axis_dim - 2)
    return base * alpha ** exponent
def _resolve_weight_dtype(mixed):
    """Return the torch dtype selected by the ``--mixed`` flag.

    Raises:
        ValueError: if ``mixed`` is neither ``"fp32"`` nor ``"bf16"``.
    """
    dtypes = {"fp32": torch.float32, "bf16": torch.bfloat16}
    if mixed not in dtypes:
        raise ValueError(f"Unsupported --mixed value {mixed!r}; expected one of {sorted(dtypes)}")
    return dtypes[mixed]


# --interpolation CLI choice -> ``custom_freqs`` value written into the network config.
_INTERPOLATION_TO_CUSTOM_FREQS = {
    'linear': 'linear',  # positional-index interpolation (originally named "normal", now "linear")
    'dynntk': 'ntk-aware',
    'ntkpro1': 'ntk-aware-pro1',
    'ntkpro2': 'ntk-aware-pro2',
    'partntk': 'ntk-by-parts',
    'yarn': 'yarn',
}


def _configure_position_interpolation(network_params, args, n_patch_h, n_patch_w):
    """Write the RoPE interpolation settings into ``network_params`` in place.

    With ``--interpolation no`` the 'normal' frequencies are used. Otherwise
    the selected scheme is applied, the positional-embedding lengths are set
    to the target patch grid, and ``ori_max_pe_len`` comes from
    ``--ori-max-pe-len``.

    Raises:
        NotImplementedError: for an unknown interpolation scheme.
        ValueError: if interpolation is requested without ``--ori-max-pe-len``.
    """
    if args.interpolation == 'no':
        # No interpolation requested.
        network_params['custom_freqs'] = 'normal'
        network_params['online_rope'] = False
        return
    custom_freqs = _INTERPOLATION_TO_CUSTOM_FREQS.get(args.interpolation)
    if custom_freqs is None:
        raise NotImplementedError(f"Unknown interpolation scheme: {args.interpolation!r}")
    if args.ori_max_pe_len is None:
        raise ValueError("--ori-max-pe-len is required when --interpolation is not 'no'")
    network_params['custom_freqs'] = custom_freqs
    network_params['max_pe_len_h'] = n_patch_h
    network_params['max_pe_len_w'] = n_patch_w
    network_params['decouple'] = args.decouple
    network_params['ori_max_pe_len'] = int(args.ori_max_pe_len)
    network_params['online_rope'] = False


def _build_sample_fn(sampler, args):
    """Create the transport sampling function selected by ``--sampler-mode``.

    ``args.cfg_scale`` must already be a float.

    Raises:
        ValueError: if likelihood evaluation is combined with guidance.
        NotImplementedError: for an unknown sampler mode.
    """
    if args.sampler_mode == "ODE":
        if args.likelihood:
            if args.cfg_scale != 1.0:
                raise ValueError("Likelihood is incompatible with guidance")
            return sampler.sample_ode_likelihood(
                sampling_method=args.ode_sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        return sampler.sample_ode(
            sampling_method=args.ode_sampling_method,
            num_steps=args.num_sampling_steps,
            atol=args.atol,
            rtol=args.rtol,
            reverse=args.reverse,
        )
    if args.sampler_mode == "SDE":
        return sampler.sample_sde(
            sampling_method=args.sde_sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    raise NotImplementedError(f"Unknown sampler mode: {args.sampler_mode!r}")


def _make_position_inputs(batch, n_patch_h, n_patch_w, device, dtype):
    """Build the token grid coordinates, token mask and size tensors for one batch.

    Returns:
        grid: long tensor (batch, 2, n_patch_h * n_patch_w). Row 0 holds each
            token's column (w) index and row 1 its row (h) index.
        mask: (batch, n_patch_h * n_patch_w) tensor of ones in ``dtype``.
        size: long tensor (batch, 1, 2) holding ``(n_patch_h, n_patch_w)``.
    """
    grid_h = torch.arange(n_patch_h, dtype=torch.long)
    grid_w = torch.arange(n_patch_w, dtype=torch.long)
    mesh = torch.meshgrid(grid_w, grid_h, indexing='xy')
    grid = torch.cat([mesh[0].reshape(1, -1), mesh[1].reshape(1, -1)], dim=0)
    grid = grid.repeat(batch, 1, 1).to(device=device, dtype=torch.long)
    mask = torch.ones(batch, n_patch_h * n_patch_w).to(device=device, dtype=dtype)
    size = torch.tensor((n_patch_h, n_patch_w)).repeat(batch, 1).to(device=device, dtype=torch.long)
    size = size[:, None, :]
    return grid, mask, size


def main(args):
    """
    Run sampling.

    Each rank decodes its images and writes them as PNGs under
    ``<sample_dir>/official_fit/<checkpoint name>/``. Rank 0 also gathers
    every rank's images and saves the first ``args.num_fid_samples`` of them
    to ``<that folder>.npz`` (key ``arr_0``).
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    weight_dtype = _resolve_weight_dtype(args.mixed)
    # Convert before building the sampler: the likelihood check compares it numerically.
    args.cfg_scale = float(args.cfg_scale)
    if args.cfg_scale < 1.0:
        raise ValueError("In almost all cases, cfg_scale be >= 1.0")
    using_cfg = args.cfg_scale > 1.0

    if args.cfgdir == "":
        # Default: <first two ckpt path components>/configs/config.yaml
        ckpt_parts = args.ckpt.split("/")
        args.cfgdir = os.path.join(ckpt_parts[0], ckpt_parts[1], "configs/config.yaml")
    print("config dir: ", args.cfgdir)
    config = OmegaConf.load(args.cfgdir)
    config_diffusion = config.diffusion

    # Latent size; the code assumes an 8x VAE downsampling factor.
    H, W = args.image_height // 8, args.image_width // 8
    patch_size = config_diffusion.network_config.params.patch_size
    n_patch_h, n_patch_w = H // patch_size, W // patch_size
    n_tokens = n_patch_h * n_patch_w
    _configure_position_interpolation(config_diffusion.network_config.params, args, n_patch_h, n_patch_w)

    model = instantiate_from_config(config_diffusion.network_config).to(device, dtype=weight_dtype)
    init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True)
    model.eval()  # important

    # prepare first stage model
    vae_models = {'sd-ft-mse': 'stabilityai/sd-vae-ft-mse', 'sd-ft-ema': 'stabilityai/sd-vae-ft-ema'}
    if args.vae_decoder not in vae_models:
        raise ValueError(f"Unknown --vae-decoder {args.vae_decoder!r}")
    vae = AutoencoderKL.from_pretrained(vae_models[args.vae_decoder], local_files_only=True).to(device, dtype=weight_dtype)
    vae.eval()  # important

    # prepare transport
    transport = create_transport(**OmegaConf.to_container(config_diffusion.transport))  # default: velocity;
    sampler = Sampler(transport)
    sample_fn = _build_sample_fn(sampler, args)

    workdir_name = 'official_fit'
    folder_name = f'{args.ckpt.split("/")[-1].split(".")[0]}'
    sample_folder_dir = f"{args.sample_dir}/{workdir_name}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)  # also creates the parent workdir
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run.
    # To make things evenly-divisible, we sample a bit more than needed and discard the extras.
    n = args.per_proc_batch_size
    global_batch_size = n * world_size
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % world_size == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // world_size)
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)

    # Positional inputs and the null labels do not change between iterations (and consume no RNG): build once.
    grid, mask, size = _make_position_inputs(n, n_patch_h, n_patch_w, device, weight_dtype)
    if using_cfg:
        grid = torch.cat([grid, grid], 0)  # (B, 2, N) -> (2B, 2, N)
        mask = torch.cat([mask, mask], 0)  # (B, N) -> (2B, N)
        size = torch.cat([size, size], 0)  # (B, 1, 2) -> (2B, 1, 2)
        # NOTE(review): label 1000 is hard-coded as the null class, independent of --num-classes.
        y_null = torch.tensor([1000] * n, device=device)
        model_fn = model.forward_with_cfg
    else:
        model_fn = model.forward

    progress = tqdm(range(iterations)) if rank == 0 else range(iterations)
    all_images = []  # gathered uint8 (n, H, W, 3) arrays; only filled on rank 0, the only rank that uses them
    total = 0
    for _step in progress:
        # Sample inputs:
        z = torch.randn(
            (n, n_tokens, (patch_size ** 2) * model.in_channels)
        ).to(device=device, dtype=weight_dtype)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        if using_cfg:
            z = torch.cat([z, z], 0)  # (B, N, patch_size**2 * C) -> (2B, N, patch_size**2 * C)
            y = torch.cat([y, y_null], 0)  # (B,) -> (2B, )
            model_kwargs = dict(y=y, grid=grid, mask=mask, size=size, cfg_scale=args.cfg_scale, scale_pow=args.scale_pow)
        else:
            model_kwargs = dict(y=y, grid=grid, mask=mask, size=size)

        # Sample images:
        samples = sample_fn(z, model_fn, **model_kwargs)[-1]
        if using_cfg:
            samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
        # Keep the first n_tokens tokens; assumes samples keeps z's (B, N, patch_size**2 * C) layout.
        samples = samples[:, :n_tokens]
        samples = model.unpatchify(samples, (H, W))
        samples = vae.decode(samples / vae.config.scaling_factor).sample
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to(torch.uint8).contiguous()

        # gather samples (all_gather, since gather is not supported with NCCL)
        gathered_samples = [torch.zeros_like(samples) for _rank in range(world_size)]
        dist.all_gather(gathered_samples, samples)
        if rank == 0:
            all_images.extend(sample.cpu().numpy() for sample in gathered_samples)

        # Save samples to disk as individual .png files; ranks interleave so filenames are unique.
        for i, sample in enumerate(samples.cpu().numpy()):
            png_index = i * world_size + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{png_index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before writing the .npz
    dist.barrier()
    if rank == 0:
        arr = np.concatenate(all_images, axis=0)
        arr = arr[: int(args.num_fid_samples)]
        npz_path = f"{sample_folder_dir}.npz"
        np.savez(npz_path, arr_0=arr)
        print(f"Saved .npz file to {npz_path} [shape={arr.shape}].")
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Sample images from a pre-trained FiT checkpoint with DDP and pack them into an .npz for FID evaluation."
    )
    parser.add_argument("--cfgdir", type=str, default="",
                        help="Path to config.yaml; derived from --ckpt when empty.")
    parser.add_argument("--ckpt", type=str, default="")
    parser.add_argument("--sample-dir", type=str, default="workdir/eval")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    parser.add_argument("--image-height", type=int, default=256)
    parser.add_argument("--image-width", type=int, default=256)
    parser.add_argument("--num-classes", type=int, default=1000,
                        help="Class labels are drawn uniformly from [0, num_classes).")
    parser.add_argument("--vae-decoder", type=str, choices=['sd-ft-mse', 'sd-ft-ema'], default='sd-ft-ema')
    # Parsed as float; main() still calls float() on it, so string-typed callers keep working.
    parser.add_argument("--cfg-scale", type=float, default=1.0,
                        help="Classifier-free guidance scale; > 1.0 enables guidance.")
    parser.add_argument("--scale-pow", type=float, default=0.0)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--interpolation", type=str,
                        choices=['no', 'linear', 'yarn', 'dynntk', 'partntk', 'ntkpro1', 'ntkpro2'], default='no',
                        help="RoPE position-interpolation scheme.")
    parser.add_argument("--ori-max-pe-len", default=None, type=int,
                        help="Original max positional-embedding length; required when --interpolation is not 'no'.")
    parser.add_argument("--decouple", default=False, action="store_true",
                        help="Passed through to the network config as 'decouple' when interpolating.")
    parser.add_argument("--sampler-mode", default='SDE', choices=['SDE', 'ODE'])
    # TF32 is on by default; --no-tf32 makes it possible to turn it off.
    parser.add_argument("--tf32", dest="tf32", action='store_true', default=True,
                        help="Allow TF32 matmuls (default).")
    parser.add_argument("--no-tf32", dest="tf32", action='store_false',
                        help="Disable TF32 matmuls.")
    parser.add_argument("--mixed", type=str, default="fp32", choices=["fp32", "bf16"],
                        help="Weight dtype for the model and VAE.")
    parser.add_argument("--save-images", action='store_true', default=False,
                        help="Currently not read by main(); PNGs are always written.")
    parse_ode_args(parser)
    parse_sde_args(parser)
    args = parser.parse_args()
    main(args)