add dctinit for video editing
maxin-cn committed Aug 9, 2024
1 parent d0e9956 commit 0ad617a
Showing 1 changed file with 43 additions and 7 deletions.
pipelines/video_editing.py
@@ -24,6 +24,8 @@
 from datasets import video_transforms
 from torchvision import transforms
 from models.unet import UNet3DConditionModel
+from einops import repeat
+from utils import dct_low_pass_filter, exchanged_mixed_dct_freq
 
 def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
     with open(path, 'rb') as f:
@@ -142,7 +144,8 @@ def main(args):
 
 
     # video_path = './video_editing/A_man_walking_on_the_beach.mp4'
-    video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
+    # video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
+    video_path = './video_editing/test_03.mp4'
 
 
     video_reader = DecordInit()
@@ -157,7 +160,11 @@
 
     # image_path = "./video_editing/a_man_walking_in_the_park.png"
     image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
-    edit_content = prepare_image(image_path, vae, transform_video, device, dtype=torch.float16).to(device)
+
+    if args.use_dct:
+        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
+    else:
+        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)
 
     if not os.path.exists(args.save_img_path):
         os.makedirs(args.save_img_path)
@@ -177,7 +184,38 @@ def main(args):
                                output_type="latent").video
 
     # prompt = 'a man walking in the park'
-    prompt = 'a corgi walking in the park at sunrise, oil painting style'
+    # prompt = 'a corgi walking in the park at sunrise, oil painting style'
+    prompt = 'A girl is playing the guitar in her room'
+
+    if args.use_dct:
+        # filter params
+        print("Using DCT!")
+        edit_content_repeat = repeat(edit_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()
+
+        # define filter
+        freq_filter = dct_low_pass_filter(dct_coefficients=edit_content,
+                                          percentage=0.23)
+
+        noise = latents.to(dtype=torch.float64)
+
+        # add noise to base_content
+        diffuse_timesteps = torch.full((1,),int(985))
+        diffuse_timesteps = diffuse_timesteps.long()
+
+        # 3d content
+        edit_content_noise = scheduler.add_noise(
+            original_samples=edit_content_repeat.to(device),
+            noise=noise,
+            timesteps=diffuse_timesteps.to(device))
+
+        # 3d content
+        latents = exchanged_mixed_dct_freq(noise=noise,
+                                           base_content=edit_content_noise,
+                                           LPF_3d=freq_filter).to(dtype=torch.float16)
+
+        latents = latents.to(dtype=torch.float16)
+        edit_content = edit_content.to(dtype=torch.float16)
+
     videos = videogen_pipeline(prompt,
                                latents=latents,
                                base_content=edit_content,
@@ -186,7 +224,7 @@
                                width=args.image_size[1],
                                num_inference_steps=args.num_sampling_steps,
                                guidance_scale=1.0,
-                               # guidance_scale=args.guidance_scale,
+                               # guidance_scale=args.guidance_scale,
                                motion_bucket_id=args.motion_bucket_id,
                                enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
     imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4', videos[0], fps=8, quality=8) # highest quality is 10, lowest is 0
@@ -197,6 +235,4 @@ def main(args):
     parser.add_argument("--config", type=str, default="./configs/sample.yaml")
     args = parser.parse_args()
 
-    main(OmegaConf.load(args.config))
-
-
+    main(OmegaConf.load(args.config))
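
Note on the two helpers imported from utils: dct_low_pass_filter and exchanged_mixed_dct_freq are not part of this commit, so their implementations are not shown. The call sites above imply a 3D DCT low-pass mask that keeps a small fraction of low frequencies, and a mixing step that takes low frequencies from the noised edit content and high frequencies from fresh noise. The following is a minimal sketch of that idea, assuming the third-party torch-dct package; the function bodies, the mask layout, and the tensor shapes are illustrative guesses, not the repository's actual code.

# Illustrative sketch only; NOT the repository's utils implementation.
# Assumes the third-party torch-dct package (pip install torch-dct).
import torch
import torch_dct

def dct_low_pass_filter(dct_coefficients, percentage=0.23):
    # Binary mask keeping the lowest `percentage` of DCT frequencies along
    # the last three axes (frames, height, width); everything else is zeroed.
    mask = torch.zeros_like(dct_coefficients)
    f, h, w = mask.shape[-3:]
    mask[..., :round(f * percentage), :round(h * percentage), :round(w * percentage)] = 1.0
    return mask

def exchanged_mixed_dct_freq(noise, base_content, LPF_3d):
    # Low frequencies come from the noised base content, high frequencies
    # from the fresh noise; the inverse 3D DCT leaves the frequency domain.
    noise_freq = torch_dct.dct_3d(noise, norm='ortho')
    content_freq = torch_dct.dct_3d(base_content, norm='ortho')
    mixed_freq = content_freq * LPF_3d + noise_freq * (1 - LPF_3d)
    return torch_dct.idct_3d(mixed_freq, norm='ortho')

# Hypothetical usage with a (batch, channel, frame, height, width) latent:
content = torch.randn(1, 4, 16, 32, 32, dtype=torch.float64)
noise = torch.randn_like(content)
lpf = dct_low_pass_filter(content, percentage=0.23)
init_latents = exchanged_mixed_dct_freq(noise, content, lpf)

In the diff itself the mask is derived from the single-frame edit content and the repeated, noised content is mixed with the sampled latents; the float64 intermediates presumably keep the DCT round trip numerically stable before the result is cast back to float16 for sampling.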
