import math
import os
import random
import threading
import time
import argparse
import cv2
import tempfile
import imageio_ffmpeg
import gradio as gr
import torch
from PIL import Image
from diffusers import (
    CogVideoXPipeline,
    CogVideoXDPMScheduler,
    CogVideoXVideoToVideoPipeline,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.utils import export_to_video, load_video, load_image
from datetime import datetime, timedelta
from diffusers.image_processor import VaeImageProcessor
from openai import OpenAI
import moviepy.editor as mp
import utils
from rife_model import load_rife_model, rife_inference_with_latents
from huggingface_hub import hf_hub_download, snapshot_download
import gc
import platform

# Add imports for quantization
from transformers import T5EncoderModel
from diffusers import AutoencoderKLCogVideoX


def is_bf16_supported():
    if torch.cuda.is_available():
        return torch.cuda.is_bf16_supported()
    return False


if is_bf16_supported():
    default_dtype = torch.bfloat16
    print("Using bfloat16 precision")
else:
    default_dtype = torch.float16
    print("Using float16 precision")


def open_folder(folder_path):
    if platform.system() == "Windows":
        os.startfile(folder_path)
    elif platform.system() == "Linux":
        os.system(f'xdg-open "{folder_path}"')
    elif platform.system() == "Darwin":  # macOS
        os.system(f'open "{folder_path}"')


try:
    from torchao.quantization import quantize_, int8_weight_only, int8_dynamic_activation_int8_weight

    TORCHAO_AVAILABLE = True
except ImportError:
    TORCHAO_AVAILABLE = False

device = "cuda" if torch.cuda.is_available() else "cpu"

hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=default_dtype).to("cpu")
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
i2v_transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=default_dtype
)

os.makedirs("./outputs", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)

upscale_model = utils.load_sd_upscale("model_real_esran/RealESRGAN_x4.pth", device)
frame_interpolation_model = load_rife_model("model_rife")


def load_and_quantize_model(quantization_type):
    text_encoder = T5EncoderModel.from_pretrained(
        "THUDM/CogVideoX-5b-I2V", subfolder="text_encoder", torch_dtype=default_dtype
    )
    transformer = CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=default_dtype
    )
    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-5b-I2V", subfolder="vae", torch_dtype=default_dtype
    )
    if quantization_type == "int8" and TORCHAO_AVAILABLE:
        quantize_(text_encoder, int8_weight_only())
        quantize_(transformer, int8_weight_only())
        quantize_(vae, int8_weight_only())
    elif quantization_type == "fp8":
        # Check if GPU supports FP8
        text_encoder = text_encoder.to(torch.float8_e4m3fn)
        transformer = transformer.to(torch.float8_e4m3fn)
        vae = vae.to(torch.float8_e4m3fn)
    return text_encoder, transformer, vae


def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
    width, height = get_video_dimensions(input_video)
    if width == 720 and height == 480:
        processed_video = input_video
    else:
        processed_video = center_crop_resize(input_video)
    return processed_video


def get_video_dimensions(input_video_path):
    reader = imageio_ffmpeg.read_frames(input_video_path)
    metadata = next(reader)
    return metadata["size"]


def center_crop_resize(input_video_path, target_width=720, target_height=480):
    cap = cv2.VideoCapture(input_video_path)

    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    orig_fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    width_factor = target_width / orig_width
    height_factor = target_height / orig_height
    resize_factor = max(width_factor, height_factor)
    inter_width = int(orig_width * resize_factor)
    inter_height = int(orig_height * resize_factor)

    target_fps = 8
    ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
    skip = min(5, ideal_skip)  # Cap at 5

    while (total_frames / (skip + 1)) < 49 and skip > 0:
        skip -= 1

    processed_frames = []
    frame_count = 0
    total_read = 0

    while frame_count < 49 and total_read < total_frames:
        ret, frame = cap.read()
        if not ret:
            break

        if total_read % (skip + 1) == 0:
            resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)

            start_x = (inter_width - target_width) // 2
            start_y = (inter_height - target_height) // 2
            cropped = resized[start_y : start_y + target_height, start_x : start_x + target_width]

            processed_frames.append(cropped)
            frame_count += 1

        total_read += 1

    cap.release()

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_video_path = temp_file.name

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(temp_video_path, fourcc, target_fps, (target_width, target_height))

    for frame in processed_frames:
        out.write(frame)

    out.release()

    return temp_video_path


def infer(
    prompt: str,
    image_input: str,
    video_input: str,
    video_strength: float,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int = -1,
    use_cpu_offload: bool = True,
    use_slicing: bool = True,
    use_tiling: bool = True,
    quantization_type: str = "none",
    progress=gr.Progress(track_tqdm=True),
):
    if seed == -1:
        seed = random.randint(0, 2**8 - 1)

    # Load (and optionally quantize) the text encoder, transformer and VAE
    text_encoder, transformer, vae = load_and_quantize_model(quantization_type)

    if video_input is not None:
        video = load_video(video_input)[:49]  # Limit to 49 frames
        pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
            "THUDM/CogVideoX-5b",
            transformer=transformer,
            vae=vae,
            scheduler=pipe.scheduler,
            tokenizer=pipe.tokenizer,
            text_encoder=text_encoder,
            torch_dtype=default_dtype,
        ).to(device)

        if use_cpu_offload:
            pipe_video.enable_sequential_cpu_offload()
        if use_slicing:
            pipe_video.vae.enable_slicing()
        if use_tiling:
            pipe_video.vae.enable_tiling()

        video_pt = pipe_video(
            video=video,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            num_videos_per_prompt=1,
            strength=video_strength,
            use_dynamic_cfg=True,
            output_type="pt",
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).frames
        gc.collect()
        torch.cuda.empty_cache()
    elif image_input is not None:
        pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
            "THUDM/CogVideoX-5b-I2V",
            transformer=transformer,
            vae=vae,
            scheduler=pipe.scheduler,
            tokenizer=pipe.tokenizer,
            text_encoder=text_encoder,
            torch_dtype=default_dtype,
        ).to(device)

        if use_cpu_offload:
            pipe_image.enable_sequential_cpu_offload()
        if use_slicing:
            pipe_image.vae.enable_slicing()
        if use_tiling:
            pipe_image.vae.enable_tiling()

        image_input = Image.fromarray(image_input).resize(size=(720, 480))  # Convert to PIL
        image = load_image(image_input)
        video_pt = pipe_image(
            image=image,
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            num_videos_per_prompt=1,
            use_dynamic_cfg=True,
            output_type="pt",
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).frames
        gc.collect()
        torch.cuda.empty_cache()
    else:
        pipe.to(device)
        pipe.transformer = transformer
        pipe.vae = vae
        pipe.text_encoder = text_encoder

        if use_cpu_offload:
            pipe.enable_sequential_cpu_offload()
        if use_slicing:
            pipe.vae.enable_slicing()
        if use_tiling:
            pipe.vae.enable_tiling()

        video_pt = pipe(
            prompt=prompt,
            num_videos_per_prompt=1,
            num_inference_steps=num_inference_steps,
            num_frames=49,
            use_dynamic_cfg=True,
            output_type="pt",
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cpu").manual_seed(seed),
        ).frames
        gc.collect()

    return (video_pt, seed)


def get_unique_filename(base_path, extension):
    directory = os.path.dirname(base_path)
    filename = os.path.basename(base_path)
    name, ext = os.path.splitext(filename)

    counter = 0
    while True:
        if counter == 0:
            new_filename = f"{name}{extension}"
        else:
            new_filename = f"{name}_{counter:04d}{extension}"
        new_path = os.path.join(directory, new_filename)
        if not os.path.exists(new_path):
            return new_path
        counter += 1


def delete_old_files():
    # Periodically remove generated files older than 10 minutes
    while True:
        now = datetime.now()
        cutoff = now - timedelta(minutes=10)
        directories = ["./outputs", "./gradio_tmp"]

        for directory in directories:
            for filename in os.listdir(directory):
                file_path = os.path.join(directory, filename)
                if os.path.isfile(file_path):
                    file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
                    if file_mtime < cutoff:
                        os.remove(file_path)
        time.sleep(600)


threading.Thread(target=delete_old_files, daemon=True).start()


def generate(
    prompt,
    image_input,
    video_input,
    video_strength,
    seed_value,
    num_inference_steps,
    guidance_scale,
    scale_status,
    rife_status,
    use_cpu_offload,
    use_slicing,
    use_tiling,
    quantization_type,
    num_generations,
    progress=gr.Progress(track_tqdm=True),
):
    all_video_paths = []
    all_gif_paths = []
    all_seeds = []

    for i in range(num_generations):
        latents, seed = infer(
            prompt,
            image_input,
            video_input,
            video_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seed=seed_value if i == 0 else -1,  # Use provided seed only for first generation
            use_cpu_offload=use_cpu_offload,
            use_slicing=use_slicing,
            use_tiling=use_tiling,
            quantization_type=quantization_type,
            progress=progress,
        )
        if rife_status:
            latents = rife_inference_with_latents(frame_interpolation_model, latents)
        if scale_status:
            latents = utils.upscale_batch_and_concatenate(upscale_model, latents, device)

        batch_size = latents.shape[0]
        batch_video_frames = []
        for batch_idx in range(batch_size):
            pt_image = latents[batch_idx]
            pt_image = torch.stack([pt_image[i] for i in range(pt_image.shape[0])])

            image_np = VaeImageProcessor.pt_to_numpy(pt_image)
            image_pil = VaeImageProcessor.numpy_to_pil(image_np)
            batch_video_frames.append(image_pil)

        base_filename = "output_" if video_input is None else os.path.splitext(os.path.basename(video_input))[0]
        video_path = get_unique_filename(os.path.join("outputs", f"{base_filename}.mp4"), ".mp4")
        utils.save_video(
            batch_video_frames[0], fps=math.ceil((len(batch_video_frames[0]) - 1) / 6), output_path=video_path
        )

        gif_path = get_unique_filename(video_path.replace(".mp4", ".gif"), ".gif")
        clip = mp.VideoFileClip(video_path)
        clip = clip.set_fps(8)
        clip = clip.resize(height=240)
        clip.write_gif(gif_path, fps=8)

        all_video_paths.append(video_path)
        all_gif_paths.append(gif_path)
        all_seeds.append(seed)

    # Return only the last generated video for display
    video_update = gr.update(visible=True, value=all_video_paths[-1])
    gif_update = gr.update(visible=True, value=all_gif_paths[-1])
    seed_update = gr.update(visible=True, value=all_seeds[-1])

    return all_video_paths[-1], video_update, gif_update, seed_update


with gr.Blocks() as demo:
    gr.Markdown("""