From 8cba16a0a781541743d67f4c4600e0d7f363758a Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 20 Sep 2024 08:52:04 +0200 Subject: [PATCH 1/9] update --- .../train_cogvideox_image_to_video_lora.py | 1569 +++++++++++++++++ examples/cogvideo/train_cogvideox_lora.py | 9 +- 2 files changed, 1576 insertions(+), 2 deletions(-) create mode 100644 examples/cogvideo/train_cogvideox_image_to_video_lora.py diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py new file mode 100644 index 000000000000..650609f64daa --- /dev/null +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -0,0 +1,1569 @@ +# Copyright 2024 The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import math +import os +import shutil +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer + +import diffusers +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDPMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXTransformer3DModel, +) +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.optimization import get_scheduler +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid +from diffusers.training_utils import ( + cast_training_params, + clear_objs_and_retain_memory, +) +from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script for CogVideoX.") + + # Model information + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + # Dataset information + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_root", + type=str, + default=None, + help=("A folder containing the training data."), + ) + parser.add_argument( + "--video_column", + type=str, + default="video", + help="The column of the dataset containing videos. Or, the name of the file in `--instance_data_root` folder containing the line-separated path to video data.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--instance_data_root` folder containing the line-separated instance prompts.", + ) + parser.add_argument( + "--id_token", type=str, default=None, help="Identifier token appended to the start of each prompt if provided." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + # Validation + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", + ) + parser.add_argument( + "--validation_prompt_separator", + type=str, + default=":::", + help="String that separates multiple validation prompts", + ) + parser.add_argument( + "--num_validation_videos", + type=int, + default=1, + help="Number of videos that should be generated during validation per `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run validation every X epochs. Validation consists of running the prompt `args.validation_prompt` multiple times: `args.num_validation_videos`." + ), + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=6, + help="The guidance scale to use while sampling validation videos.", + ) + parser.add_argument( + "--use_dynamic_cfg", + action="store_true", + default=False, + help="Whether or not to use the default cosine dynamic guidance schedule when sampling validation videos.", + ) + + # Training information + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=128, + help=("The scaling factor to scale LoRA weight update. The actual scaling factor is `lora_alpha / rank`"), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="cogvideox-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="All input videos are resized to this height.", + ) + parser.add_argument( + "--width", + type=int, + default=720, + help="All input videos are resized to this width.", + ) + parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.") + parser.add_argument( + "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames." + ) + parser.add_argument( + "--skip_frames_start", + type=int, + default=0, + help="Number of frames to skip from the beginning of each input video. Useful if training data contains intro sequences.", + ) + parser.add_argument( + "--skip_frames_end", + type=int, + default=0, + help="Number of frames to skip from the end of each input video. Useful if training data contains outro sequences.", + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip videos horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) + + # Optimizer + parser.add_argument( + "--optimizer", + type=lambda s: s.lower(), + default="adam", + choices=["adam", "adamw", "prodigy"], + help=("The optimizer type to use."), + ) + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.", + ) + parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--prodigy_use_bias_correction", action="store_true", help="Turn on Adam's bias correction.") + parser.add_argument( + "--prodigy_safeguard_warmup", + action="store_true", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.", + ) + + # Other information + parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help="Directory where logs are stored.", + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default=None, + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + + return parser.parse_args() + + +class VideoDataset(Dataset): + def __init__( + self, + instance_data_root: Optional[str] = None, + dataset_name: Optional[str] = None, + dataset_config_name: Optional[str] = None, + caption_column: str = "text", + video_column: str = "video", + height: int = 480, + width: int = 720, + fps: int = 8, + max_num_frames: int = 49, + skip_frames_start: int = 0, + skip_frames_end: int = 0, + cache_dir: Optional[str] = None, + id_token: Optional[str] = None, + ) -> None: + super().__init__() + + self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None + self.dataset_name = dataset_name + self.dataset_config_name = dataset_config_name + self.caption_column = caption_column + self.video_column = video_column + self.height = height + self.width = width + self.fps = fps + self.max_num_frames = max_num_frames + self.skip_frames_start = skip_frames_start + self.skip_frames_end = skip_frames_end + self.cache_dir = cache_dir + self.id_token = id_token or "" + + if dataset_name is not None: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub() + else: + self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path() + + self.num_instance_videos = len(self.instance_video_paths) + if self.num_instance_videos != len(self.instance_prompts): + raise ValueError( + f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." + ) + + self.instance_videos = self._preprocess_data() + + def __len__(self): + return self.num_instance_videos + + def __getitem__(self, index): + return { + "instance_prompt": self.id_token + self.instance_prompts[index], + "instance_video": self.instance_videos[index], + } + + def _load_dataset_from_hub(self): + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_root instead." + ) + + # Downloading and loading a dataset from the hub. See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + self.dataset_name, + self.dataset_config_name, + cache_dir=self.cache_dir, + ) + column_names = dataset["train"].column_names + + if self.video_column is None: + video_column = column_names[0] + logger.info(f"`video_column` defaulting to {video_column}") + else: + video_column = self.video_column + if video_column not in column_names: + raise ValueError( + f"`--video_column` value '{video_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if self.caption_column is None: + caption_column = column_names[1] + logger.info(f"`caption_column` defaulting to {caption_column}") + else: + caption_column = self.caption_column + if self.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{self.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + instance_prompts = dataset["train"][caption_column] + instance_videos = [Path(self.instance_data_root, filepath) for filepath in dataset["train"][video_column]] + + return instance_prompts, instance_videos + + def _load_dataset_from_local_path(self): + if not self.instance_data_root.exists(): + raise ValueError("Instance videos root folder does not exist") + + prompt_path = self.instance_data_root.joinpath(self.caption_column) + video_path = self.instance_data_root.joinpath(self.video_column) + + if not prompt_path.exists() or not prompt_path.is_file(): + raise ValueError( + "Expected `--caption_column` to be path to a file in `--instance_data_root` containing line-separated text prompts." + ) + if not video_path.exists() or not video_path.is_file(): + raise ValueError( + "Expected `--video_column` to be path to a file in `--instance_data_root` containing line-separated paths to video data in the same directory." + ) + + with open(prompt_path, "r", encoding="utf-8") as file: + instance_prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] + with open(video_path, "r", encoding="utf-8") as file: + instance_videos = [ + self.instance_data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0 + ] + + if any(not path.is_file() for path in instance_videos): + raise ValueError( + "Expected '--video_column' to be a path to a file in `--instance_data_root` containing line-separated paths to video data but found atleast one path that is not a valid file." + ) + + return instance_prompts, instance_videos + + def _preprocess_data(self): + try: + import decord + except ImportError: + raise ImportError( + "The `decord` package is required for loading the video dataset. Install with `pip install decord`" + ) + + decord.bridge.set_bridge("torch") + + videos = [] + train_transforms = transforms.Compose( + [ + transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0), + ] + ) + + for filename in self.instance_video_paths: + video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height) + video_num_frames = len(video_reader) + + start_frame = min(self.skip_frames_start, video_num_frames) + end_frame = max(0, video_num_frames - self.skip_frames_end) + if end_frame <= start_frame: + frames = video_reader.get_batch([start_frame]) + elif end_frame - start_frame <= self.max_num_frames: + frames = video_reader.get_batch(list(range(start_frame, end_frame))) + else: + indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames)) + frames = video_reader.get_batch(indices) + + # Ensure that we don't go over the limit + frames = frames[: self.max_num_frames] + selected_num_frames = frames.shape[0] + + # Choose first (4k + 1) frames as this is how many is required by the VAE + remainder = (3 + (selected_num_frames % 4)) % 4 + if remainder != 0: + frames = frames[:-remainder] + selected_num_frames = frames.shape[0] + + assert (selected_num_frames - 1) % 4 == 0 + + # Training transforms + frames = frames.float() + frames = torch.stack([train_transforms(frame) for frame in frames], dim=0) + videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W] + + return videos + + +def save_model_card( + repo_id: str, + videos=None, + base_model: str = None, + validation_prompt=None, + repo_folder=None, + fps=8, +): + widget_dict = [] + if videos is not None: + for i, video in enumerate(videos): + export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}} + ) + + model_description = f""" +# CogVideoX LoRA - {repo_id} + + + +## Model description + +These are {repo_id} LoRA weights for {base_model}. + +The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). + +Was LoRA for the text encoder enabled? No. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import CogVideoXPipeline +import torch + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"]) + +# The LoRA adapter weights are determined by what was used for training. +# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. +# It can be made lower or higher from what was used in training to decrease or amplify the effect +# of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. +pipe.set_adapters(["cogvideox-lora"], [32 / 64]) + +video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/THUDM/CogVideoX-5b-I2V/blob/main/LICENSE). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=validation_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "image-to-video", + "diffusers-training", + "diffusers", + "lora", + "cogvideox", + "cogvideox-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipe, + args, + accelerator, + pipeline_args, + epoch, + is_final_validation: bool = False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_videos} videos with prompt: {pipeline_args['prompt']}." + ) + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipe.scheduler.config: + variance_type = pipe.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args) + pipe = pipe.to(accelerator.device) + # pipe.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + + videos = [] + for _ in range(args.num_validation_videos): + video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0] + videos.append(video) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "wandb": + video_filenames = [] + for i, video in enumerate(videos): + prompt = ( + pipeline_args["prompt"][:25] + .replace(" ", "_") + .replace(" ", "_") + .replace("'", "_") + .replace('"', "_") + .replace("/", "_") + ) + filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4") + export_to_video(video, filename, fps=8) + video_filenames.append(filename) + + tracker.log( + { + phase_name: [ + wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}") + for i, filename in enumerate(video_filenames) + ] + } + ) + + clear_objs_and_retain_memory([pipe]) + + return videos + + +def _get_t5_prompt_embeds( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + +def encode_prompt( + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = _get_t5_prompt_embeds( + tokenizer, + text_encoder, + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + text_input_ids=text_input_ids, + ) + return prompt_embeds + + +def compute_prompt_embeddings( + tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): + # Use DeepSpeed optimzer + if use_deepspeed: + from accelerate.utils import DummyOptim + + return DummyOptim( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + + # Optimizer creation + supported_optimizers = ["adam", "adamw", "prodigy"] + if args.optimizer not in supported_optimizers: + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]: + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + if args.optimizer.lower() == "adamw": + optimizer_class = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "adam": + optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + eps=args.adam_epsilon, + weight_decay=args.adam_weight_decay, + ) + elif args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + return optimizer + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prepare models and scheduler + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + + # CogVideoX-2b weights are stored in float16 + # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16 + load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16 + transformer = CogVideoXTransformer3DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=load_dtype, + revision=args.revision, + variant=args.variant, + ) + + vae = AutoencoderKLCogVideoX.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + + scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + + # We only train the additional adapter LoRA layers + text_encoder.requires_grad_(False) + transformer.requires_grad_(False) + vae.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.state.deepspeed_plugin: + # DeepSpeed is handling precision, use what's in the DeepSpeed config + if ( + "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"] + ): + weight_dtype = torch.float16 + if ( + "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config + and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"] + ): + weight_dtype = torch.float16 + else: + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + init_lora_weights=True, + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + CogVideoXImageToVideoPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"Unexpected save model: {model.__class__}") + + lora_state_dict = CogVideoXImageToVideoPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer_]) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params([transformer], dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + use_deepspeed_optimizer = ( + accelerator.state.deepspeed_plugin is not None + and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config + ) + use_deepspeed_scheduler = ( + accelerator.state.deepspeed_plugin is not None + and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ) + + optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) + + # Dataset and DataLoader + train_dataset = VideoDataset( + instance_data_root=args.instance_data_root, + dataset_name=args.dataset_name, + dataset_config_name=args.dataset_config_name, + caption_column=args.caption_column, + video_column=args.video_column, + height=args.height, + width=args.width, + fps=args.fps, + max_num_frames=args.max_num_frames, + skip_frames_start=args.skip_frames_start, + skip_frames_end=args.skip_frames_end, + cache_dir=args.cache_dir, + id_token=args.id_token, + ) + + def encode_video(video): + video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) + video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] + image = video[:, :, :1].clone() + latent_dist = vae.encode(video).latent_dist + image_latent_dist = vae.encode(image).latent_dist + return latent_dist, image_latent_dist + + train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos] + + def collate_fn(examples): + videos = [] + for example in examples: + latent_dist, image_latent_dist = example["instance_video"] + + video_latents = latent_dist.sample() * vae.config.scaling_factor + image_latents = image_latent_dist.sample() * vae.config.scaling_factor + + padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) + latent_padding = image_latents.zeros_like(padding_shape) + image_latents = torch.cat([image_latents, latent_padding], dim=1) + + latents = torch.cat([video_latents, image_latents], dim=2) + videos.append(latents) + + prompts = [example["instance_prompt"] for example in examples] + + videos = torch.cat(videos) + videos = videos.to(memory_format=torch.contiguous_format).float() + + return { + "videos": videos, + "prompts": prompts, + } + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if use_deepspeed_scheduler: + from accelerate.utils import DummyScheduler + + lr_scheduler = DummyScheduler( + name=args.lr_scheduler, + optimizer=optimizer, + total_num_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = args.tracker_name or "cogvideox-lora" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"]) + + logger.info("***** Running training *****") + logger.info(f" Num trainable parameters = {num_trainable_parameters}") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if not args.resume_from_checkpoint: + initial_global_step = 0 + else: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) + + # Delete VAE to save memory + clear_objs_and_retain_memory([vae]) + + # For DeepSpeed training + model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + + with accelerator.accumulate(models_to_accumulate): + model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W] + prompts = batch["prompts"] + + # encode prompts + prompt_embeds = compute_prompt_embeddings( + tokenizer, + text_encoder, + prompts, + model_config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + + # Sample noise that will be added to the latents + noise = torch.randn_like(model_input) + batch_size, num_frames, num_channels, height, width = model_input.shape + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device + ) + timesteps = timesteps.long() + + # Prepare rotary embeds + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=args.height, + width=args.width, + num_frames=num_frames, + vae_scale_factor_spatial=vae_scale_factor_spatial, + patch_size=model_config.patch_size, + attention_head_dim=model_config.attention_head_dim, + device=accelerator.device, + ) + if model_config.use_rotary_positional_embeddings + else None + ) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_output = transformer( + hidden_states=noisy_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timesteps, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + + alphas_cumprod = scheduler.alphas_cumprod[timesteps] + weights = 1 / (1 - alphas_cumprod) + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + + target = model_input + + loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + if accelerator.state.deepspeed_plugin is None: + optimizer.step() + optimizer.zero_grad() + + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"Removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: + # Create pipeline + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + text_encoder=unwrap_model(text_encoder), + scheduler=scheduler, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + validation_outputs = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + ) + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + dtype = ( + torch.float16 + if args.mixed_precision == "fp16" + else torch.bfloat16 + if args.mixed_precision == "bf16" + else torch.float32 + ) + transformer = transformer.to(dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + CogVideoXImageToVideoPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + ) + + # Cleanup trained models to save memory + clear_objs_and_retain_memory([transformer]) + + # Final test inference + pipe = CogVideoXImageToVideoPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + + # Load LoRA weights + lora_scaling = args.lora_alpha / args.rank + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") + pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + + # Run inference + validation_outputs = [] + if args.validation_prompt and args.num_validation_videos > 0: + validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) + for validation_prompt in validation_prompts: + pipeline_args = { + "prompt": validation_prompt, + "guidance_scale": args.guidance_scale, + "use_dynamic_cfg": args.use_dynamic_cfg, + "height": args.height, + "width": args.width, + } + + video = log_validation( + pipe=pipe, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + is_final_validation=True, + ) + validation_outputs.extend(video) + + if args.push_to_hub: + save_model_card( + repo_id, + videos=validation_outputs, + base_model=args.pretrained_model_name_or_path, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + fps=args.fps, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 137f3222f6d9..1235064b020f 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -870,7 +870,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False): ) args.optimizer = "adamw" - if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]): + if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]: logger.warning( f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was " f"set to {args.optimizer.lower()}" @@ -1305,6 +1305,9 @@ def collate_fn(examples): ) vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) + # Delete VAE to save memory + clear_objs_and_retain_memory([vae]) + # For DeepSpeed training model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config @@ -1434,7 +1437,6 @@ def collate_fn(examples): args.pretrained_model_name_or_path, transformer=unwrap_model(transformer), text_encoder=unwrap_model(text_encoder), - vae=unwrap_model(vae), scheduler=scheduler, revision=args.revision, variant=args.variant, @@ -1478,6 +1480,9 @@ def collate_fn(examples): transformer_lora_layers=transformer_lora_layers, ) + # Cleanup trained models to save memory + clear_objs_and_retain_memory([transformer]) + # Final test inference pipe = CogVideoXPipeline.from_pretrained( args.pretrained_model_name_or_path, From f79eb1de0a326816986ff0d773e7c37117da7048 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Sep 2024 02:33:57 +0200 Subject: [PATCH 2/9] update --- .../train_cogvideox_image_to_video_lora.py | 134 ++++++++++++------ examples/cogvideo/train_cogvideox_lora.py | 3 +- .../pipeline_cogvideox_image2video.py | 16 ++- 3 files changed, 109 insertions(+), 44 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index 650609f64daa..3a5723f516ad 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -17,6 +17,7 @@ import logging import math import os +import random import shutil from pathlib import Path from typing import List, Optional, Tuple, Union @@ -47,7 +48,13 @@ cast_training_params, clear_objs_and_retain_memory, ) -from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + export_to_video, + is_wandb_available, + load_image, +) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -146,6 +153,12 @@ def get_args(): default=None, help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", ) + parser.add_argument( + "--validation_images", + type=str, + default=None, + help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", + ) parser.add_argument( "--validation_prompt_separator", type=str, @@ -207,7 +220,7 @@ def get_args(): parser.add_argument( "--output_dir", type=str, - default="cogvideox-lora", + default="cogvideox-i2v-lora", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( @@ -332,6 +345,12 @@ def get_args(): default=False, help="Whether or not to use VAE tiling for saving memory.", ) + parser.add_argument( + "--noised_image_dropout", + type=float, + default=0.05, + help="Image condition dropout probability.", + ) # Optimizer parser.add_argument( @@ -449,6 +468,8 @@ def __init__( else: self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path() + self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts] + self.num_instance_videos = len(self.instance_video_paths) if self.num_instance_videos != len(self.instance_prompts): raise ValueError( @@ -462,7 +483,7 @@ def __len__(self): def __getitem__(self, index): return { - "instance_prompt": self.id_token + self.instance_prompts[index], + "instance_prompt": self.instance_prompts[index], "instance_video": self.instance_videos[index], } @@ -602,9 +623,10 @@ def save_model_card( widget_dict = [] if videos is not None: for i, video in enumerate(videos): - export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4", fps=fps)) + video_path = f"final_video_{i}.mp4" + export_to_video(video, os.path.join(repo_folder, video_path, fps=fps)) widget_dict.append( - {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}} + {"text": validation_prompt if validation_prompt else " ", "output": {"url": video_path}}, ) model_description = f""" @@ -616,7 +638,7 @@ def save_model_card( These are {repo_id} LoRA weights for {base_model}. -The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py). +The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_image_to_video_lora.py). Was LoRA for the text encoder enabled? No. @@ -627,19 +649,22 @@ def save_model_card( ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) ```py -from diffusers import CogVideoXPipeline import torch +from diffusers import CogVideoXImageToVideoPipeline +from diffusers.utils import load_image, export_to_video -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") -pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"]) +pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda") +pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-i2v-lora"]) # The LoRA adapter weights are determined by what was used for training. # In this case, we assume `--lora_alpha` is 32 and `--rank` is 64. # It can be made lower or higher from what was used in training to decrease or amplify the effect # of the LoRA upto a tolerance, beyond which one might notice no effect at all or overflows. -pipe.set_adapters(["cogvideox-lora"], [32 / 64]) +pipe.set_adapters(["cogvideox-i2v-lora"], [32 / 64]) -video = pipe("{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +image = load_image("/path/to/image") +video = pipe(image=image, "{validation_prompt}", guidance_scale=6, use_dynamic_cfg=True).frames[0] +export_to_video(video, "output.mp4", fps=8) ``` For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) @@ -1190,34 +1215,61 @@ def encode_video(video): video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0) video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] image = video[:, :, :1].clone() + latent_dist = vae.encode(video).latent_dist - image_latent_dist = vae.encode(image).latent_dist + + image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device) + image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype) + noisy_image = torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] + image_latent_dist = vae.encode(noisy_image).latent_dist + return latent_dist, image_latent_dist + train_dataset.instance_prompts = [ + compute_prompt_embeddings( + tokenizer, + text_encoder, + [prompt], + transformer.config.max_text_seq_length, + accelerator.device, + weight_dtype, + requires_grad=False, + ) + for prompt in train_dataset.instance_prompts + ] train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos] def collate_fn(examples): videos = [] + images = [] for example in examples: latent_dist, image_latent_dist = example["instance_video"] video_latents = latent_dist.sample() * vae.config.scaling_factor image_latents = image_latent_dist.sample() * vae.config.scaling_factor + video_latents = video_latents.permute(0, 2, 1, 3, 4) + image_latents = image_latents.permute(0, 2, 1, 3, 4) padding_shape = (video_latents.shape[0], video_latents.shape[1] - 1, *video_latents.shape[2:]) - latent_padding = image_latents.zeros_like(padding_shape) + latent_padding = image_latents.new_zeros(padding_shape) image_latents = torch.cat([image_latents, latent_padding], dim=1) - latents = torch.cat([video_latents, image_latents], dim=2) - videos.append(latents) + if random.random() < args.noised_image_dropout: + image_latents = torch.zeros_like(image_latents) - prompts = [example["instance_prompt"] for example in examples] + videos.append(video_latents) + images.append(image_latents) videos = torch.cat(videos) + images = torch.cat(images) videos = videos.to(memory_format=torch.contiguous_format).float() + images = images.to(memory_format=torch.contiguous_format).float() + + prompts = [example["instance_prompt"] for example in examples] + prompts = torch.cat(prompts) return { - "videos": videos, + "videos": (videos, images), "prompts": prompts, } @@ -1270,7 +1322,7 @@ def collate_fn(examples): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = args.tracker_name or "cogvideox-lora" + tracker_name = args.tracker_name or "cogvideox-i2v-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! @@ -1325,8 +1377,8 @@ def collate_fn(examples): ) vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) - # Delete VAE to save memory - clear_objs_and_retain_memory([vae]) + # Delete VAE and Text Encoder to save memory + clear_objs_and_retain_memory([vae, text_encoder]) # For DeepSpeed training model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config @@ -1338,19 +1390,12 @@ def collate_fn(examples): models_to_accumulate = [transformer] with accelerator.accumulate(models_to_accumulate): - model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W] - prompts = batch["prompts"] - - # encode prompts - prompt_embeds = compute_prompt_embeddings( - tokenizer, - text_encoder, - prompts, - model_config.max_text_seq_length, - accelerator.device, - weight_dtype, - requires_grad=False, - ) + video_latents, image_latents = batch["videos"] + video_latents = video_latents.to(dtype=weight_dtype) # [B, F, C, H, W] + image_latents = image_latents.to(dtype=weight_dtype) # [B, F, C, H, W] + model_input = torch.cat([video_latents, image_latents], dim=2) + + prompt_embeds = batch["prompts"] # Sample noise that will be added to the latents noise = torch.randn_like(model_input) @@ -1389,14 +1434,14 @@ def collate_fn(examples): image_rotary_emb=image_rotary_emb, return_dict=False, )[0] - model_pred = scheduler.get_velocity(model_output, noisy_model_input, timesteps) + model_pred = scheduler.get_velocity(model_output, noisy_model_input.chunk(2, dim=2)[0], timesteps) alphas_cumprod = scheduler.alphas_cumprod[timesteps] weights = 1 / (1 - alphas_cumprod) while len(weights.shape) < len(model_pred.shape): weights = weights.unsqueeze(-1) - target = model_input + target = video_latents loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) loss = loss.mean() @@ -1456,7 +1501,6 @@ def collate_fn(examples): pipe = CogVideoXImageToVideoPipeline.from_pretrained( args.pretrained_model_name_or_path, transformer=unwrap_model(transformer), - text_encoder=unwrap_model(text_encoder), scheduler=scheduler, revision=args.revision, variant=args.variant, @@ -1464,8 +1508,11 @@ def collate_fn(examples): ) validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) - for validation_prompt in validation_prompts: + validation_images = args.validation_images.split(args.validation_prompt_separator) + + for validation_image, validation_prompt in zip(validation_images, validation_prompts): pipeline_args = { + "image": load_image(validation_image), "prompt": validation_prompt, "guidance_scale": args.guidance_scale, "use_dynamic_cfg": args.use_dynamic_cfg, @@ -1519,15 +1566,18 @@ def collate_fn(examples): # Load LoRA weights lora_scaling = args.lora_alpha / args.rank - pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora") - pipe.set_adapters(["cogvideox-lora"], [lora_scaling]) + pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-i2v-lora") + pipe.set_adapters(["cogvideox-i2v-lora"], [lora_scaling]) # Run inference validation_outputs = [] if args.validation_prompt and args.num_validation_videos > 0: validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) - for validation_prompt in validation_prompts: + validation_images = args.validation_images.split(args.validation_prompt_separator) + + for validation_image, validation_prompt in zip(validation_images, validation_prompts): pipeline_args = { + "image": load_image(validation_image), "prompt": validation_prompt, "guidance_scale": args.guidance_scale, "use_dynamic_cfg": args.use_dynamic_cfg, @@ -1546,11 +1596,13 @@ def collate_fn(examples): validation_outputs.extend(video) if args.push_to_hub: + validation_prompt = args.validation_prompt or "" + validation_prompt = validation_prompt.split(args.validation_prompt_separator)[0] save_model_card( repo_id, videos=validation_outputs, base_model=args.pretrained_model_name_or_path, - validation_prompt=args.validation_prompt, + validation_prompt=validation_prompt, repo_folder=args.output_dir, fps=args.fps, ) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 1235064b020f..0d8e958a7b62 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -1194,6 +1194,7 @@ def collate_fn(examples): prompts = [example["instance_prompt"] for example in examples] videos = torch.cat(videos) + videos = videos.permute(0, 2, 1, 3, 4) videos = videos.to(memory_format=torch.contiguous_format).float() return { @@ -1318,7 +1319,7 @@ def collate_fn(examples): models_to_accumulate = [transformer] with accelerator.accumulate(models_to_accumulate): - model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W] + model_input = batch["videos"].to(dtype=weight_dtype) # [B, F, C, H, W] prompts = batch["prompts"] # encode prompts diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index a1576be97977..6f3073f6bc1e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL import torch @@ -23,6 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput +from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline @@ -152,7 +153,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class CogVideoXImageToVideoPipeline(DiffusionPipeline): +class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): r""" Pipeline for image-to-video generation using CogVideoX. @@ -547,6 +548,10 @@ def guidance_scale(self): def num_timesteps(self): return self._num_timesteps + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def interrupt(self): return self._interrupt @@ -573,6 +578,7 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: str = "pil", return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -636,6 +642,10 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -681,6 +691,7 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 2. Default call parameters @@ -770,6 +781,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_pred.float() From ed45568a698c58dfa853777f6fad17ca445faf34 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 21 Sep 2024 07:18:39 +0200 Subject: [PATCH 3/9] update --- .../train_cogvideox_image_to_video_lora.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index 3a5723f516ad..9bd3b5272633 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -19,6 +19,7 @@ import os import random import shutil +from datetime import timedelta from pathlib import Path from typing import List, Optional, Tuple, Union @@ -26,7 +27,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict from torch.utils.data import DataLoader, Dataset @@ -426,6 +427,7 @@ def get_args(): ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) + parser.add_argument("--nccl_timeout", type=int, default=600, help="NCCL backend timeout in seconds.") return parser.parse_args() @@ -976,13 +978,14 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + init_kwargs = InitProcessGroupKwargs(backend="nccl", timeout=timedelta(seconds=args.nccl_timeout)) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, - kwargs_handlers=[kwargs], + kwargs_handlers=[ddp_kwargs, init_kwargs], ) # Disable AMP for MPS. @@ -1391,22 +1394,27 @@ def collate_fn(examples): with accelerator.accumulate(models_to_accumulate): video_latents, image_latents = batch["videos"] + prompt_embeds = batch["prompts"] + video_latents = video_latents.to(dtype=weight_dtype) # [B, F, C, H, W] image_latents = image_latents.to(dtype=weight_dtype) # [B, F, C, H, W] - model_input = torch.cat([video_latents, image_latents], dim=2) - prompt_embeds = batch["prompts"] - - # Sample noise that will be added to the latents - noise = torch.randn_like(model_input) - batch_size, num_frames, num_channels, height, width = model_input.shape + batch_size, num_frames, num_channels, height, width = video_latents.shape # Sample a random timestep for each image timesteps = torch.randint( - 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device + 0, scheduler.config.num_train_timesteps, (batch_size,), device=video_latents.device ) timesteps = timesteps.long() + # Sample noise that will be added to the latents + noise = torch.randn_like(video_latents) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_video_latents = scheduler.add_noise(video_latents, noise, timesteps) + noisy_model_input = torch.cat([noisy_video_latents, image_latents], dim=2) + # Prepare rotary embeds image_rotary_emb = ( prepare_rotary_positional_embeddings( @@ -1422,10 +1430,6 @@ def collate_fn(examples): else None ) - # Add noise to the model input according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) - # Predict the noise residual model_output = transformer( hidden_states=noisy_model_input, @@ -1434,7 +1438,7 @@ def collate_fn(examples): image_rotary_emb=image_rotary_emb, return_dict=False, )[0] - model_pred = scheduler.get_velocity(model_output, noisy_model_input.chunk(2, dim=2)[0], timesteps) + model_pred = scheduler.get_velocity(model_output, noisy_video_latents, timesteps) alphas_cumprod = scheduler.alphas_cumprod[timesteps] weights = 1 / (1 - alphas_cumprod) From f8107be94caf1d269ab151470ded8df1b89c0881 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Oct 2024 12:29:07 +0200 Subject: [PATCH 4/9] update --- .../train_cogvideox_image_to_video_lora.py | 14 +++++--------- examples/cogvideo/train_cogvideox_lora.py | 6 ++---- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index 9bd3b5272633..ef82342469cc 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -45,10 +45,7 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed from diffusers.optimization import get_scheduler from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid -from diffusers.training_utils import ( - cast_training_params, - clear_objs_and_retain_memory, -) +from diffusers.training_utils import cast_training_params, free_memory from diffusers.utils import ( check_min_version, convert_unet_state_dict_to_peft, @@ -758,7 +755,8 @@ def log_validation( } ) - clear_objs_and_retain_memory([pipe]) + del pipe + free_memory() return videos @@ -1380,9 +1378,6 @@ def collate_fn(examples): ) vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) - # Delete VAE and Text Encoder to save memory - clear_objs_and_retain_memory([vae, text_encoder]) - # For DeepSpeed training model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config @@ -1552,7 +1547,8 @@ def collate_fn(examples): ) # Cleanup trained models to save memory - clear_objs_and_retain_memory([transformer]) + del transformer + free_memory() # Final test inference pipe = CogVideoXImageToVideoPipeline.from_pretrained( diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 4d1273bc3083..9e4e87c91ce4 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -1304,9 +1304,6 @@ def collate_fn(examples): ) vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) - # Delete VAE to save memory - clear_objs_and_retain_memory([vae]) - # For DeepSpeed training model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config @@ -1480,7 +1477,8 @@ def collate_fn(examples): ) # Cleanup trained models to save memory - clear_objs_and_retain_memory([transformer]) + del transformer + free_memory() # Final test inference pipe = CogVideoXPipeline.from_pretrained( From d519d7553fdf85c4aadf0f668e3680d6896444ec Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 8 Oct 2024 09:27:42 +0200 Subject: [PATCH 5/9] update --- examples/cogvideo/train_cogvideox_image_to_video_lora.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index ef82342469cc..af9ef54a03e1 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -25,7 +25,7 @@ import torch import transformers -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, InitProcessGroupKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder @@ -1190,7 +1190,7 @@ def load_model_hook(models, input_dir): ) use_deepspeed_scheduler = ( accelerator.state.deepspeed_plugin is not None - and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config ) optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) @@ -1221,7 +1221,7 @@ def encode_video(video): image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device) image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype) - noisy_image = torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] + noisy_image = image + torch.randn_like(image) * image_noise_sigma[:, None, None, None, None] image_latent_dist = vae.encode(noisy_image).latent_dist return latent_dist, image_latent_dist @@ -1494,7 +1494,7 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: # Create pipeline pipe = CogVideoXImageToVideoPipeline.from_pretrained( From 79caae374918202582e1ff008de71c0c76809ecd Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 8 Oct 2024 09:30:43 +0200 Subject: [PATCH 6/9] add coauthor Co-Authored-By: yuan-shenghai <963658029@qq.com> From abc5ec48c0d5fec4fd87907fd14e5e0f0c222f19 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 8 Oct 2024 09:31:03 +0200 Subject: [PATCH 7/9] add coauthor Co-Authored-By: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com> From 0a78cf16de0be9b594a7f019bce9ec09244b966d Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 8 Oct 2024 10:11:53 +0200 Subject: [PATCH 8/9] update Co-Authored-By: yuan-shenghai <963658029@qq.com> --- examples/cogvideo/train_cogvideox_image_to_video_lora.py | 4 ++-- examples/cogvideo/train_cogvideox_lora.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index af9ef54a03e1..0fdca2850784 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -1461,7 +1461,7 @@ def collate_fn(examples): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: @@ -1494,7 +1494,7 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break - if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if accelerator.is_main_process: if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0: # Create pipeline pipe = CogVideoXImageToVideoPipeline.from_pretrained( diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index fd6851caf1e7..ece2228147e2 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -25,7 +25,7 @@ import torch import torchvision.transforms as TT import transformers -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder @@ -1211,7 +1211,7 @@ def load_model_hook(models, input_dir): ) use_deepspeed_scheduler = ( accelerator.state.deepspeed_plugin is not None - and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config ) optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer) @@ -1456,7 +1456,7 @@ def collate_fn(examples): progress_bar.update(1) global_step += 1 - if accelerator.is_main_process: + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: From e0beb56744bf85a82a32d3f990e5da5ba92fbdb4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 15 Oct 2024 02:47:14 +0200 Subject: [PATCH 9/9] update --- examples/cogvideo/README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/examples/cogvideo/README.md b/examples/cogvideo/README.md index 8e57c2b19188..02887faeaa74 100644 --- a/examples/cogvideo/README.md +++ b/examples/cogvideo/README.md @@ -10,6 +10,11 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de At the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b). +> [!NOTE] +> The scripts for CogVideoX come with limited support and may not be fully compatible with different training techniques. They are not feature-rich either and simply serve as minimal examples of finetuning to take inspiration from and improve. +> +> A repository containing memory-optimized finetuning scripts with support for multiple resolutions, dataset preparation, captioning, etc. is available [here](https://github.com/a-r-r-o-w/cogvideox-factory), which will be maintained jointly by the CogVideoX and Diffusers team. + ## Data Preparation The training scripts accepts data in two formats. @@ -132,6 +137,8 @@ Assuming you are training on 50 videos of a similar concept, we have found 1500- - 1500 steps on 50 videos would correspond to `30` training epochs - 4000 steps on 100 videos would correspond to `40` training epochs +The following bash script launches training for text-to-video lora. + ```bash #!/bin/bash @@ -172,6 +179,8 @@ accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \ --report_to wandb ``` +For launching image-to-video finetuning instead, run the `train_cogvideox_image_to_video_lora.py` file instead. Additionally, you will have to pass `--validation_images` as paths to initial images corresponding to `--validation_prompts` for I2V validation to work. + To better track our training experiments, we're using the following flags in the command above: * `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. * `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. @@ -197,8 +206,6 @@ Note that setting the `` is not necessary. From some limited experimen > > Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data. - - ## Inference Once you have trained a lora model, the inference can be done simply loading the lora weights into the `CogVideoXPipeline`. @@ -227,3 +234,5 @@ prompt = ( frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0] export_to_video(frames, "output.mp4", fps=8) ``` + +If you've trained a LoRA for `CogVideoXImageToVideoPipeline` instead, everything in the above example remains the same except you must also pass an image as initial condition for generation.