-
Notifications
You must be signed in to change notification settings - Fork 546
/
sample_video.py
58 lines (49 loc) · 1.99 KB
/
sample_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import time
from pathlib import Path
from loguru import logger
from datetime import datetime
from hyvideo.utils.file_utils import save_videos_grid
from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler
def main():
args = parse_args()
print(args)
models_root_path = Path(args.model_base)
if not models_root_path.exists():
raise ValueError(f"`models_root` not exists: {models_root_path}")
# Create save folder to save the samples
save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
if not os.path.exists(args.save_path):
os.makedirs(save_path, exist_ok=True)
# Load models
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
# Get the updated args
args = hunyuan_video_sampler.args
# Start sampling
# TODO: batch inference check
outputs = hunyuan_video_sampler.predict(
prompt=args.prompt,
height=args.video_size[0],
width=args.video_size[1],
video_length=args.video_length,
seed=args.seed,
negative_prompt=args.neg_prompt,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
num_videos_per_prompt=args.num_videos,
flow_shift=args.flow_shift,
batch_size=args.batch_size,
embedded_guidance_scale=args.embedded_cfg_scale
)
samples = outputs['samples']
# Save samples
if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
for i, sample in enumerate(samples):
sample = samples[i].unsqueeze(0)
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
save_videos_grid(sample, save_path, fps=24)
logger.info(f'Sample save to: {save_path}')
if __name__ == "__main__":
main()