diff --git a/pipelines/video_editing.py b/pipelines/video_editing.py
index 1f2d290..593cbd3 100644
--- a/pipelines/video_editing.py
+++ b/pipelines/video_editing.py
@@ -24,6 +24,8 @@
 from datasets import video_transforms
 from torchvision import transforms
 from models.unet import UNet3DConditionModel
+from einops import repeat
+from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
 
 def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
     with open(path, 'rb') as f:
@@ -142,7 +144,8 @@ def main(args):
 
 
     # video_path = './video_editing/A_man_walking_on_the_beach.mp4'
-    video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
+    # video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
+    video_path = './video_editing/test_03.mp4'
 
     video_reader = DecordInit()
 
@@ -157,7 +160,11 @@ def main(args):
 
     # image_path = "./video_editing/a_man_walking_in_the_park.png"
    image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
-    edit_content = prepare_image(image_path, vae, transform_video, device, dtype=torch.float16).to(device)
+
+    if args.use_dct:
+        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
+    else:
+        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
 
     if not os.path.exists(args.save_img_path):
         os.makedirs(args.save_img_path)
@@ -177,7 +184,38 @@ def main(args):
                                        output_type="latent").video
 
     # prompt = 'a man walking in the park'
-    prompt = 'a corgi walking in the park at sunrise, oil painting style'
+    # prompt = 'a corgi walking in the park at sunrise, oil painting style'
+    prompt = 'A girl is playing the guitar in her room'
+
+    if args.use_dct:
+        # filter params
+        print("Using DCT!")
+        edit_content_repeat = repeat(edit_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
+
+        # define filter
+        freq_filter = dct_low_pass_filter(dct_coefficients=edit_content,
+                                          percentage=0.23)
+
+        noise = latents.to(dtype=torch.float64)
+
+        # add noise to base_content
+        diffuse_timesteps = torch.full((1,), int(985))
+        diffuse_timesteps = diffuse_timesteps.long()
+
+        # 3d content
+        edit_content_noise = scheduler.add_noise(
+            original_samples=edit_content_repeat.to(device),
+            noise=noise,
+            timesteps=diffuse_timesteps.to(device))
+
+        # 3d content
+        latents = exchanged_mixed_dct_freq(noise=noise,
+                                           base_content=edit_content_noise,
+                                           LPF_3d=freq_filter).to(dtype=torch.float16)
+
+        latents = latents.to(dtype=torch.float16)
+        edit_content = edit_content.to(dtype=torch.float16)
+
     videos = videogen_pipeline(prompt,
                                latents=latents,
                                base_content=edit_content,
@@ -186,7 +224,7 @@ def main(args):
                                width=args.image_size[1],
                                num_inference_steps=args.num_sampling_steps,
                                guidance_scale=1.0,
-                               # guidance_scale=args.guidance_scale,
+                               # guidance_scale=args.guidance_scale,
                                motion_bucket_id=args.motion_bucket_id,
                                enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
     imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4', videos[0], fps=8, quality=8)  # highest quality is 10, lowest is 0
@@ -197,6 +235,4 @@ def main(args):
     parser.add_argument("--config", type=str, default="./configs/sample.yaml")
     args = parser.parse_args()
 
-    main(OmegaConf.load(args.config))
-
-
+    main(OmegaConf.load(args.config))
\ No newline at end of file
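
Note on the two helpers imported from utils: dct_low_pass_filter and exchanged_mixed_dct_freq are used above but their implementations are not part of this diff. The sketch below is only a plausible reconstruction of the idea (a 3D DCT low-pass exchange: low frequencies taken from the noised base content, high frequencies from the freshly sampled noise). The scipy-based DCT, the per-axis cutoff mask, and the meaning of percentage are assumptions, not the repository's actual code.

# NOT part of the repository -- illustrative sketch only; the real utils
# implementations may differ in mask construction and DCT backend.
import numpy as np
import torch
from scipy.fft import dctn, idctn

def dct_low_pass_filter(dct_coefficients, percentage=0.23):
    # Boolean mask keeping roughly the lowest `percentage` of DCT frequencies
    # along the frame, height and width axes (assumed b c f h w layout).
    mask = torch.zeros_like(dct_coefficients, dtype=torch.bool)
    f, h, w = dct_coefficients.shape[-3:]
    mask[..., :max(1, int(f * percentage)),
              :max(1, int(h * percentage)),
              :max(1, int(w * percentage))] = True
    return mask

def exchanged_mixed_dct_freq(noise, base_content, LPF_3d):
    # 3D DCT over (f, h, w): low-frequency coefficients come from the noised
    # base content, high-frequency coefficients from the sampling noise.
    noise_dct = dctn(noise.cpu().numpy(), axes=(-3, -2, -1), norm='ortho')
    base_dct = dctn(base_content.cpu().numpy(), axes=(-3, -2, -1), norm='ortho')
    mixed = np.where(LPF_3d.cpu().numpy(), base_dct, noise_dct)
    mixed = idctn(mixed, axes=(-3, -2, -1), norm='ortho')
    return torch.from_numpy(mixed).to(device=noise.device, dtype=noise.dtype)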