-
Notifications
You must be signed in to change notification settings - Fork 322
/
Copy pathinference.py
481 lines (419 loc) · 20.9 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
import random
import time
from pathlib import Path
import numpy as np
import torch
# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger
from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT, TRT_MAX_WIDTH, TRT_MAX_HEIGHT, TRT_MAX_BATCH_SIZE
from .diffusion.pipeline import StableDiffusionPipeline
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds
from peft import LoraConfig
class Resolution:
    """A single image size; ``str()`` renders it as ``"<height>x<width>"``."""

    def __init__(self, width, height):
        self.width, self.height = width, height

    def __str__(self):
        # Height first: this string is used as a lookup key elsewhere.
        return "{}x{}".format(self.height, self.width)
class ResolutionGroup:
    """The fixed list of supported (width, height) sizes, with a membership test."""

    def __init__(self):
        # (width, height) pairs, grouped by nominal aspect ratio.
        sizes = [
            (1024, 1024), (1280, 1280),              # 1:1
            (1024, 768), (1152, 864), (1280, 960),   # 4:3
            (768, 1024), (864, 1152), (960, 1280),   # 3:4
            (1280, 768),                             # 16:9
            (768, 1280),                             # 9:16
        ]
        self.data = [Resolution(w, h) for w, h in sizes]
        self.supported_sizes = {(res.width, res.height) for res in self.data}

    def is_valid(self, width, height):
        """Return True when ``(width, height)`` is one of the supported sizes."""
        return (width, height) in self.supported_sizes
# Aspect ratios (width / height) used to bucket a requested size in
# get_standard_shape(). Index i here lines up with STANDARD_SHAPE[i].
STANDARD_RATIO = np.array([1 / 1, 4 / 3, 3 / 4, 16 / 9, 9 / 16])

# Candidate (width, height) shapes for each ratio bucket above.
# NOTE: the "16:9" / "9:16" buckets map to 1280x768 / 768x1280, which are
# actually 5:3 / 3:5; the bucket is chosen by nearest ratio, so this is fine.
STANDARD_SHAPE = [
    [(1024, 1024), (1280, 1280)],  # 1:1
    [(1280, 960)],                 # 4:3
    [(960, 1280)],                 # 3:4
    [(1280, 768)],                 # 16:9
    [(768, 1280)],                 # 9:16
]

# Pixel areas of every candidate shape, one array per ratio bucket.
STANDARD_AREA = [
    np.array([width * height for (width, height) in group])
    for group in STANDARD_SHAPE
]
def get_standard_shape(target_width, target_height):
    """
    Map an arbitrary image size onto the nearest standard size.

    First the aspect-ratio bucket closest to ``target_width / target_height``
    is chosen from STANDARD_RATIO; then, inside that bucket, the shape whose
    pixel area is closest to ``target_width * target_height``.

    Returns
    -------
    (width, height) taken from STANDARD_SHAPE.
    """
    ratio_gaps = np.abs(STANDARD_RATIO - target_width / target_height)
    bucket = np.argmin(ratio_gaps)

    area_gaps = np.abs(STANDARD_AREA[bucket] - target_width * target_height)
    choice = np.argmin(area_gaps)

    width, height = STANDARD_SHAPE[bucket][choice]
    return width, height
def _to_tuple(val):
if isinstance(val, (list, tuple)):
if len(val) == 1:
val = [val[0], val[0]]
elif len(val) == 2:
val = tuple(val)
else:
raise ValueError(f"Invalid value: {val}")
elif isinstance(val, (int, float)):
val = (val, val)
else:
raise ValueError(f"Invalid value: {val}")
return val
def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
                 embedder_t5, infer_mode, sampler=None):
    """
    Get scheduler and pipeline for sampling. The sampler and pipeline are both
    based on diffusers and make some modifications.

    Parameters
    ----------
    args: argparse-style namespace
        Reads ``sampler`` (when ``sampler`` is not given), ``noise_schedule``,
        ``beta_start``, ``beta_end``, ``predict_type`` and ``infer_steps``.
    vae, text_encoder, tokenizer, model, embedder_t5:
        Components forwarded to StableDiffusionPipeline (``model`` as ``unet``).
    device:
        Device the scheduler timesteps and the pipeline are moved to.
    rank: int
        The progress bar is shown only when ``rank == 0``.
    infer_mode: str
        Forwarded to StableDiffusionPipeline.
    sampler: str, optional
        Key into SAMPLER_FACTORY; falls back to ``args.sampler`` when falsy.

    Returns
    -------
    pipeline: StableDiffusionPipeline
    sampler_name: str

    Raises
    ------
    KeyError: if the sampler name is not a SAMPLER_FACTORY key.
    """
    sampler = sampler or args.sampler

    # Load sampler from factory
    if sampler not in SAMPLER_FACTORY:
        raise KeyError(f"Unknown sampler: {sampler!r}. Available samplers: {list(SAMPLER_FACTORY)}")
    factory_entry = SAMPLER_FACTORY[sampler]
    # Copy the defaults: writing into the factory dict directly would mutate
    # the shared module-level SAMPLER_FACTORY and leak this call's arguments
    # into every later caller.
    kwargs = dict(factory_entry['kwargs'])
    scheduler_name = factory_entry['scheduler']

    # Update sampler according to the arguments
    kwargs['beta_schedule'] = args.noise_schedule
    kwargs['beta_start'] = args.beta_start
    kwargs['beta_end'] = args.beta_end
    kwargs['prediction_type'] = args.predict_type

    # Build scheduler according to the sampler (a class name in diffusers.schedulers).
    scheduler_class = getattr(schedulers, scheduler_name)
    scheduler = scheduler_class(**kwargs)
    logger.debug(f"Using sampler: {sampler} with scheduler: {scheduler}")

    # Set timesteps for inference steps.
    scheduler.set_timesteps(args.infer_steps, device)

    # Only enable progress bar for rank 0
    progress_bar_config = {} if rank == 0 else {'disable': True}

    pipeline = StableDiffusionPipeline(vae=vae,
                                       text_encoder=text_encoder,
                                       tokenizer=tokenizer,
                                       unet=model,
                                       scheduler=scheduler,
                                       feature_extractor=None,
                                       safety_checker=None,
                                       requires_safety_checker=False,
                                       progress_bar_config=progress_bar_config,
                                       embedder_t5=embedder_t5,
                                       infer_mode=infer_mode,
                                       )

    pipeline = pipeline.to(device)

    return pipeline, sampler
class End2End(object):
    """End-to-end HunYuan-DiT text-to-image inference.

    ``__init__`` loads, from ``<models_root_path>/t2i``:
      * a Bert text encoder (``clip_text_encoder``) and tokenizer (``tokenizer``),
      * an mT5 embedder (``mt5``),
      * a VAE (``sdxl-vae-fp16-fix``),
      * the DiT backbone -- a PyTorch model when ``args.infer_mode`` is 'fa' or
        'torch', or a TensorRT engine when it is 'trt',
    and builds the sampling pipeline. ``predict`` generates images from a prompt.
    """

    def __init__(self, args, models_root_path):
        """
        Args:
            args: argparse-style namespace with the inference options read
                below (``model``, ``image_size``, ``infer_mode``, ``lora_ckpt``,
                ``text_len``, ``text_states_dim``, ``load_key``, ``dit_weight``,
                ``use_fp16``, ``learn_sigma`` and the sampler options).
            models_root_path: directory that contains the ``t2i`` folder.

        Raises:
            ValueError: for an unknown ``args.infer_mode``; ValueError/KeyError
                are also raised by checkpoint loading (see load_torch_weights).
        """
        self.args = args

        # Check arguments
        t2i_root_path = Path(models_root_path) / "t2i"
        self.root = t2i_root_path
        logger.info(f"Got text-to-image model root path: {t2i_root_path}")

        # Set device and disable gradient.
        # NOTE: torch.set_grad_enabled(False) is process-global, not scoped to this object.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_grad_enabled(False)
        # Disable BertModel logging checkpoint info
        tf_logger.setLevel('ERROR')

        # ========================================================================
        logger.info(f"Loading CLIP Text Encoder...")
        text_encoder_path = self.root / "clip_text_encoder"
        self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
        logger.info(f"Loading CLIP Text Encoder finished")

        # ========================================================================
        logger.info(f"Loading CLIP Tokenizer...")
        tokenizer_path = self.root / "tokenizer"
        self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
        logger.info(f"Loading CLIP Tokenizer finished")

        # ========================================================================
        logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
        t5_text_encoder_path = self.root / 'mt5'
        embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
        self.embedder_t5 = embedder_t5
        self.embedder_t5.model.to(self.device)  # Only move encoder to device
        logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")

        # ========================================================================
        logger.info(f"Loading VAE...")
        vae_path = self.root / "sdxl-vae-fp16-fix"
        self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
        logger.info(f"Loading VAE finished")

        # ========================================================================
        # Create model structure and load the checkpoint
        logger.info(f"Building HunYuan-DiT model...")
        model_config = HUNYUAN_DIT_CONFIG[self.args.model]
        self.patch_size = model_config['patch_size']
        self.head_size = model_config['hidden_size'] // model_config['num_heads']
        # Rotary embeddings precomputed for every ResolutionGroup size (used for TensorRT models).
        self.resolutions, self.freqs_cis_img = self.standard_shapes()

        self.image_size = _to_tuple(self.args.image_size)
        # The model's input grid is the image size divided by 8 in each dimension.
        latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)

        self.infer_mode = self.args.infer_mode
        if self.infer_mode in ['fa', 'torch']:
            # Build model structure
            self.model = HunYuanDiT(self.args,
                                    input_size=latent_size,
                                    **model_config,
                                    log_fn=logger.info,
                                    ).half().to(self.device)  # Force to use fp16
            # Load model checkpoint
            self.load_torch_weights()

            lora_ckpt = args.lora_ckpt
            if lora_ckpt is not None and lora_ckpt != "":
                logger.info(f"Loading Lora checkpoint {lora_ckpt}...")
                self.model.load_adapter(lora_ckpt)
                # Fold the adapter weights into the base model.
                self.model.merge_and_unload()

            self.model.eval()
            logger.info(f"Loading torch model finished")
        elif self.infer_mode == 'trt':
            # Imported lazily so TensorRT is only required in 'trt' mode.
            from .modules.trt.hcf_model import TRTModel

            trt_dir = self.root / "model_trt"
            engine_dir = trt_dir / "engine"
            plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
            model_name = "model_onnx"

            logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
            self.model = TRTModel(model_name=model_name,
                                  engine_dir=str(engine_dir),
                                  image_height=TRT_MAX_HEIGHT,
                                  image_width=TRT_MAX_WIDTH,
                                  text_maxlen=args.text_len,
                                  embedding_dim=args.text_states_dim,
                                  plugin_path=str(plugin_path),
                                  max_batch_size=TRT_MAX_BATCH_SIZE,
                                  )
            logger.info(f"Loading TensorRT model finished")
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Build inference pipeline. We use a customized StableDiffusionPipeline.
        logger.info(f"Loading inference pipeline...")
        self.pipeline, self.sampler = self.load_sampler()
        logger.info(f'Loading pipeline finished')

        # ========================================================================
        self.default_negative_prompt = NEGATIVE_PROMPT
        logger.info("==================================================")
        logger.info(f"                Model is ready.                  ")
        logger.info("==================================================")

    @staticmethod
    def _find_dir_checkpoint(dit_weight, load_key, warn=None):
        """Locate the checkpoint inside a ``--dit-weight`` directory.

        Only ``*.pt`` files directly inside ``dit_weight`` are considered.
        Two layouts are recognized:

          * official weights ``pytorch_model_*.pt`` -> returns
            ``(dit_weight / f"pytorch_model_{load_key}.pt", True)``; the caller
            checks that this path exists.
          * DeepSpeed checkpoints ``*_model_states.pt`` -> returns
            ``(first such file in sorted order, False)``.

        Args:
            dit_weight (Path): directory to search.
            load_key (str): selects ``pytorch_model_<load_key>.pt``.
            warn (callable, optional): called with a message when several
                DeepSpeed files match.

        Returns:
            tuple: ``(model_path, bare_model)``.

        Raises:
            ValueError: if there are no ``*.pt`` files or none match a known layout.
        """
        # Sorted so the pick among several candidates does not depend on the
        # arbitrary order in which Path.glob yields directory entries.
        files = sorted(dit_weight.glob("*.pt"))
        if len(files) == 0:
            raise ValueError(f"No model weights found in {dit_weight}")
        # Compare on the bare file name: str(path) carries the directory
        # prefix, so a startswith() on it would (almost) never match.
        if any(f.name.startswith('pytorch_model_') for f in files):
            return dit_weight / f"pytorch_model_{load_key}.pt", True
        state_files = [f for f in files if f.name.endswith('_model_states.pt')]
        if state_files:
            model_path = state_files[0]
            if len(state_files) > 1 and warn is not None:
                warn(f"Multiple model weights found in {dit_weight}, using {model_path}")
            return model_path, False
        raise ValueError(f"Invalid model path: {dit_weight} with unrecognized weight format: "
                         f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                         f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
                         f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
                         f"specific weight file, please provide the full path to the file.")

    def load_torch_weights(self):
        """Load the DiT checkpoint into ``self.model`` with ``strict=True``.

        The file is taken from ``args.dit_weight`` when set (a directory, see
        ``_find_dir_checkpoint``, or a single ``*.pt`` file); otherwise from
        ``<root>/model/pytorch_model_<load_key>.pt``.

        Raises:
            ValueError: missing/invalid path, or a style-embedder mismatch
                between the checkpoint and the model configuration.
            NotImplementedError: for ``.safetensors`` files.
            KeyError: a nested checkpoint without ``load_key``.
        """
        load_key = self.args.load_key
        if self.args.dit_weight is not None:
            dit_weight = Path(self.args.dit_weight)
            if dit_weight.is_dir():
                model_path, bare_model = self._find_dir_checkpoint(
                    dit_weight, load_key, warn=logger.warning)
            elif dit_weight.is_file():
                model_path = dit_weight
                # Decided after loading, from the checkpoint's top-level keys.
                bare_model = 'unknown'
            else:
                raise ValueError(f"Invalid model path: {dit_weight}")
        else:
            model_dir = self.root / "model"
            model_path = model_dir / f"pytorch_model_{load_key}.pt"
            bare_model = True

        if not model_path.exists():
            raise ValueError(f"model_path not exists: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        if model_path.suffix == '.safetensors':
            raise NotImplementedError(f"Loading safetensors is not supported yet.")

        # Assume it's a single weight file in the *.pt format.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

        # A checkpoint with top-level 'ema'/'module' entries nests the weights.
        if bare_model == 'unknown' and ('ema' in state_dict or 'module' in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                               f"are: {list(state_dict.keys())}.")

        if 'style_embedder.weight' in state_dict and not hasattr(self.model, 'style_embedder'):
            raise ValueError(f"You might be attempting to load the weights of HunYuanDiT version <= 1.1. You need "
                             f"to set `--use-style-cond --size-cond 1024 1024 --beta-end 0.03` to adapt to these weights. "
                             f"Alternatively, you can use weights of version >= 1.2, which no longer depend on "
                             f"these two parameters.")
        if 'style_embedder.weight' not in state_dict and hasattr(self.model, 'style_embedder'):
            raise ValueError(f"You might be attempting to load the weights of HunYuanDiT version >= 1.2. You need "
                             f"to remove `--use-style-cond` and `--size-cond 1024 1024` to adapt to these weights.")

        # Don't set strict=False. Always explicitly check the state_dict.
        self.model.load_state_dict(state_dict, strict=True)

    def load_sampler(self, sampler=None):
        """Build the pipeline for ``sampler`` (default ``args.sampler``); returns ``(pipeline, sampler)``."""
        pipeline, sampler = get_pipeline(self.args,
                                         self.vae,
                                         self.clip_text_encoder,
                                         self.tokenizer,
                                         self.model,
                                         device=self.device,
                                         rank=0,
                                         embedder_t5=self.embedder_t5,
                                         infer_mode=self.infer_mode,
                                         sampler=sampler,
                                         )
        return pipeline, sampler

    def calc_rope(self, height, width):
        """Compute the 2-D rotary position embedding for a ``height`` x ``width`` image.

        The pixel size is reduced by 8 and then by ``patch_size`` to get the
        token grid; the base grid corresponds to 512 pixels.
        """
        th = height // 8 // self.patch_size
        tw = width // 8 // self.patch_size
        base_size = 512 // 8 // self.patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
        rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
        return rope

    def standard_shapes(self):
        """Precompute rotary embeddings for every ResolutionGroup size.

        Returns ``(ResolutionGroup, dict)``; the dict is keyed by ``str(reso)``,
        i.e. ``"<height>x<width>"``.
        """
        resolutions = ResolutionGroup()
        freqs_cis_img = {}
        for reso in resolutions.data:
            freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
        return resolutions, freqs_cis_img

    def predict(self,
                user_prompt,
                height=1024,
                width=1024,
                seed=None,
                enhanced_prompt=None,
                negative_prompt=None,
                infer_steps=100,
                guidance_scale=6,
                batch_size=1,
                src_size_cond=(1024, 1024),
                sampler=None,
                use_style_cond=False,
                ):
        """Generate images for a prompt.

        Args:
            user_prompt (str): the prompt; used unless ``enhanced_prompt`` is given.
            height, width (int): requested size. In 'fa'/'torch' mode each is
                floored to a multiple of 16; in 'trt' mode the nearest standard
                shape is used (see get_standard_shape).
            seed (int, optional): random in [0, 1_000_000] when None.
            enhanced_prompt (str, optional): replaces ``user_prompt`` for generation.
            negative_prompt (str, optional): defaults to NEGATIVE_PROMPT when None or ''.
            infer_steps, guidance_scale, batch_size: forwarded to the pipeline.
            src_size_cond (int | pair | None): size condition (HunYuanDiT <= 1.1).
            sampler (str, optional): switches (and rebuilds) the pipeline when it
                differs from the current sampler.
            use_style_cond (bool): pass a style tensor (HunYuanDiT <= 1.1).

        Returns:
            dict: ``{'images': <first pipeline output>, 'seed': seed}``.

        Raises:
            TypeError / ValueError: on invalid arguments.
        """
        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if seed is None:
            seed = random.randint(0, 1_000_000)
        if not isinstance(seed, int):
            raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
        generator = set_seeds(seed, device=self.device)
        # ========================================================================
        # Arguments: target_width, target_height
        # ========================================================================
        if width <= 0 or height <= 0:
            raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
        logger.info(f"Input (height, width) = ({height}, {width})")
        if self.infer_mode in ['fa', 'torch']:
            # We must force height and width to align to 16 and to be an integer.
            target_height = int((height // 16) * 16)
            target_width = int((width // 16) * 16)
            logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
            # Sizes below 16 would floor to 0 and break the pipeline further down.
            if target_height <= 0 or target_width <= 0:
                raise ValueError(f"`height` and `width` must be at least 16, got height={height}, width={width}")
        elif self.infer_mode == 'trt':
            target_width, target_height = get_standard_shape(width, height)
            logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
        else:
            raise ValueError(f"Unknown infer_mode: {self.infer_mode}")

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(user_prompt, str):
            raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
        user_prompt = user_prompt.strip()
        prompt = user_prompt

        if enhanced_prompt is not None:
            if not isinstance(enhanced_prompt, str):
                raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
            enhanced_prompt = enhanced_prompt.strip()
            prompt = enhanced_prompt

        # negative prompt
        if negative_prompt is None or negative_prompt == '':
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")

        # ========================================================================
        # Arguments: style. (A fixed argument. Don't Change it.)
        # ========================================================================
        if use_style_cond:
            # Only for hydit <= 1.1
            style = torch.as_tensor([0, 0] * batch_size, device=self.device)
        else:
            style = None

        # ========================================================================
        # Inner arguments: image_meta_size (Please refer to SDXL.)
        # ========================================================================
        if src_size_cond is None:
            size_cond = None
            image_meta_size = None
        else:
            # Only for hydit <= 1.1
            if isinstance(src_size_cond, int):
                src_size_cond = [src_size_cond, src_size_cond]
            if not isinstance(src_size_cond, (list, tuple)):
                raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
            if len(src_size_cond) != 2:
                raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
            size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
            # One row per image, duplicated (2 * batch_size rows).
            image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)

        # ========================================================================
        start_time = time.time()
        logger.debug(f"""
                       prompt: {user_prompt}
              enhanced prompt: {enhanced_prompt}
                         seed: {seed}
              (height, width): {(target_height, target_width)}
              negative_prompt: {negative_prompt}
                   batch_size: {batch_size}
               guidance_scale: {guidance_scale}
                  infer_steps: {infer_steps}
              image_meta_size: {size_cond}
        """)
        # Reuse a precomputed rotary embedding when the size is a standard one.
        reso = f'{target_height}x{target_width}'
        if reso in self.freqs_cis_img:
            freqs_cis_img = self.freqs_cis_img[reso]
        else:
            freqs_cis_img = self.calc_rope(target_height, target_width)

        if sampler is not None and sampler != self.sampler:
            self.pipeline, self.sampler = self.load_sampler(sampler)

        samples = self.pipeline(
            height=target_height,
            width=target_width,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_images_per_prompt=batch_size,
            guidance_scale=guidance_scale,
            num_inference_steps=infer_steps,
            image_meta_size=image_meta_size,
            style=style,
            return_dict=False,
            generator=generator,
            freqs_cis_img=freqs_cis_img,
            use_fp16=self.args.use_fp16,
            learn_sigma=self.args.learn_sigma,
        )[0]
        gen_time = time.time() - start_time
        logger.debug(f"Success, time: {gen_time}")

        return {
            'images': samples,
            'seed': seed,
        }