Enhance CogVideoX-Fun with Reward Backpropagation (Preference Optimization)

We explore the Reward Backpropagation technique [1, 2] to optimize the videos generated by CogVideoX-Fun-V1.1 for better alignment with human preferences. We provide pre-trained models (i.e., reward LoRAs) along with the training script. You can use these LoRAs to enhance the corresponding base model as a plug-in, or train your own reward LoRA.

Demo

CogVideoX-Fun-V1.1-5B

| Prompt | CogVideoX-Fun-V1.1-5B | CogVideoX-Fun-V1.1-5B + HPSv2.1 Reward LoRA | CogVideoX-Fun-V1.1-5B + MPS Reward LoRA |
| --- | --- | --- | --- |
| Pig with wings flying above a diamond mountain | 00000008.mp4 | 00000008.mp4 | 00000008.mp4 |
| A dog runs through a field while a cat climbs a tree | 00000002.mp4 | 00000002.mp4 | 00000002.mp4 |
| Crystal cake shimmering beside a metal apple | 00000006.mp4 | 00000006.mp4 | 00000006.mp4 |
| Elderly artist with a white beard painting on a white canvas | 00000005.mp4 | 00000005.mp4 | 00000005.mp4 |

CogVideoX-Fun-V1.1-2B

| Prompt | CogVideoX-Fun-V1.1-2B | CogVideoX-Fun-V1.1-2B + HPSv2.1 Reward LoRA | CogVideoX-Fun-V1.1-2B + MPS Reward LoRA |
| --- | --- | --- | --- |
| A blue car drives past a white picket fence on a sunny day | 00000002.mp4 | 00000002.mp4 | 00000002.mp4 |
| Blue jay swooping near a red maple tree | 00000003.mp4 | 00000003.mp4 | 00000003.mp4 |
| Yellow curtains swaying near a blue sofa | 00000004.mp4 | 00000004.mp4 | 00000004.mp4 |
| White tractor plowing near a green farmhouse | 00000005.mp4 | 00000005.mp4 | 00000005.mp4 |

Note

The above test prompts are from T2V-CompBench. All videos were generated with a LoRA weight of 0.7.

Model Zoo

| Name | Base Model | Reward Model | Hugging Face | Description |
| --- | --- | --- | --- | --- |
| CogVideoX-Fun-V1.1-5b-InP-HPS2.1.safetensors | CogVideoX-Fun-V1.1-5b | HPS v2.1 | 🤗Link | Official HPS v2.1 reward LoRA (rank=128 and network_alpha=64) for CogVideoX-Fun-V1.1-5b-InP. It is trained with a batch size of 8 for 1,500 steps. |
| CogVideoX-Fun-V1.1-2b-InP-HPS2.1.safetensors | CogVideoX-Fun-V1.1-2b | HPS v2.1 | 🤗Link | Official HPS v2.1 reward LoRA (rank=128 and network_alpha=64) for CogVideoX-Fun-V1.1-2b-InP. It is trained with a batch size of 8 for 3,000 steps. |
| CogVideoX-Fun-V1.1-5b-InP-MPS.safetensors | CogVideoX-Fun-V1.1-5b | MPS | 🤗Link | Official MPS reward LoRA (rank=128 and network_alpha=64) for CogVideoX-Fun-V1.1-5b-InP. It is trained with a batch size of 8 for 5,500 steps. |
| CogVideoX-Fun-V1.1-2b-InP-MPS.safetensors | CogVideoX-Fun-V1.1-2b | MPS | 🤗Link | Official MPS reward LoRA (rank=128 and network_alpha=64) for CogVideoX-Fun-V1.1-2b-InP. It is trained with a batch size of 8 for 16,000 steps. |

Inference

We provide example inference code that runs CogVideoX-Fun-V1.1-5b-InP with its HPS v2.1 reward LoRA.

import torch
from diffusers import CogVideoXDDIMScheduler

from cogvideox.models.transformer3d import CogVideoXTransformer3DModel
from cogvideox.pipeline.pipeline_cogvideox_inpaint import CogVideoX_Fun_Pipeline_Inpaint
from cogvideox.utils.lora_utils import merge_lora
from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid

model_path = "alibaba-pai/CogVideoX-Fun-V1.1-5b-InP"
lora_path = "alibaba-pai/CogVideoX-Fun-V1.1-Reward-LoRAs/CogVideoX-Fun-V1.1-5b-InP-HPS2.1.safetensors"
lora_weight = 0.7

prompt = "Pig with wings flying above a diamond mountain"
sample_size = [512, 512]
video_length = 49

transformer = CogVideoXTransformer3DModel.from_pretrained_2d(model_path, subfolder="transformer").to(torch.bfloat16)
scheduler = CogVideoXDDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
pipeline = CogVideoX_Fun_Pipeline_Inpaint.from_pretrained(
    model_path, transformer=transformer, scheduler=scheduler, torch_dtype=torch.bfloat16
)
pipeline.enable_model_cpu_offload()
pipeline = merge_lora(pipeline, lora_path, lora_weight)

generator = torch.Generator(device="cuda").manual_seed(42)
input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=sample_size)
sample = pipeline(
    prompt,
    num_frames = video_length,
    negative_prompt = "bad detailed",
    height = sample_size[0],
    width = sample_size[1],
    generator = generator,
    guidance_scale = 7.0,
    num_inference_steps = 50,
    video = input_video,
    mask_video = input_video_mask,
).videos

save_videos_grid(sample, "samples/output.mp4", fps=8)
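
To sanity-check the effect of the reward LoRA, you can score the generated frames with the same reward model. The snippet below is an optional add-on rather than part of the original example: it assumes sample is a float tensor in [0, 1] shaped (batch, channels, frames, height, width), as consumed by save_videos_grid above, and it uses the hpsv2.score(image_path, prompt, hps_version=...) interface of the hpsv2 package installed in the Setup section below.

import hpsv2
import torchvision

# Save the first frame of the first sample and score it with HPS v2.1.
first_frame = sample[0, :, 0]  # (channels, height, width), values in [0, 1]
torchvision.utils.save_image(first_frame, "samples/output_first_frame.png")
score = hpsv2.score("samples/output_first_frame.png", prompt, hps_version="v2.1")
print(f"HPS v2.1 score: {score}")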

Training

The training code is based on train_lora.py. We provide a shell script to train the HPS v2.1 reward LoRA for CogVideoX-Fun-V1.1-2b-InP, which can be trained on a single A10 GPU with 24GB of VRAM. To further reduce the VRAM requirement, please read Important Args.
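
Conceptually, reward backpropagation samples a video with the diffusion model, decodes (part of) the result, scores it with a differentiable reward model, and backpropagates the negative reward into the LoRA parameters, keeping gradients only for the last denoising step(s) to bound memory. The toy snippet below is only a self-contained illustration of that gradient flow with dummy modules; it is not the actual training script, and none of its names come from this repository.

import torch
import torch.nn as nn

torch.manual_seed(0)

# Dummy stand-ins: a frozen "denoiser" and a trainable LoRA-like update.
denoiser = nn.Linear(16, 16)
for p in denoiser.parameters():
    p.requires_grad_(False)
lora_delta = nn.Linear(16, 16, bias=False)
nn.init.zeros_(lora_delta.weight)

def reward_fn(x):
    # Dummy differentiable reward: higher is better, maximized when outputs are near 1.
    return -(x - 1.0).pow(2).mean(dim=-1)

optimizer = torch.optim.AdamW(lora_delta.parameters(), lr=1e-2)
latents = torch.randn(4, 16)

for step in range(200):
    with torch.no_grad():               # earlier "sampling" steps carry no gradients
        x = denoiser(latents)
    x = denoiser(x) + lora_delta(x)     # final step stays in the graph (truncated backprop)
    loss = -reward_fn(x).mean()         # maximizing the reward = minimizing its negative
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"final loss: {loss.item():.4f}")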

Setup

Please read the quick-start section to set up the CogVideoX-Fun environment. If you are using the HPS reward model, run the following script to install its dependencies:

# For HPS reward model only
pip install hpsv2
site_packages=$(python -c "import site; print(site.getsitepackages()[0])")
wget -O $site_packages/hpsv2/src/open_clip/factory.py https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/package/patches/hpsv2_src_open_clip_factory_patches.py
wget -O $site_packages/hpsv2/src/open_clip/bpe_simple_vocab_16e6.txt.gz https://github.com/tgxs002/HPSv2/raw/refs/heads/master/hpsv2/src/open_clip/bpe_simple_vocab_16e6.txt.gz

Note

Since some models are downloaded automatically from Hugging Face, please run HF_ENDPOINT=https://hf-mirror.com sh scripts/train_reward_lora.sh if you cannot access huggingface.co.

Important Args

  • rank: The rank of the LoRA model. The higher the rank, the more parameters the LoRA has and the more it can learn (including some unnecessary information). By default, we set the rank to 128. You can lower this value to reduce training GPU memory usage and the LoRA file size.
  • network_alpha: A scaling factor that changes how the LoRA affects the base model weights (the LoRA update is scaled by network_alpha / rank). In general, it can be set to half of the rank.
  • prompt_path: The path to the prompt file (in txt format, one prompt per line) used for sampling training videos. We randomly selected 701 prompts from MovieGenBench.
  • train_sample_height and train_sample_width: The resolution of the sampled training videos. We found that training at a 256x256 resolution generalizes to other resolutions. Reducing the resolution saves GPU memory during training, but the resolution should be equal to or greater than the image input resolution of the reward model. Due to the resize and crop preprocessing operations, we suggest a 1:1 aspect ratio.
  • reward_fn and reward_fn_kwargs: The reward model name and its keyword arguments. All supported reward models (Aesthetic Predictor v2/v2.5, HPS v2/v2.1, PickScore and MPS) can be found in reward_fn.py. You can also customize your own reward model (e.g., combining an aesthetic predictor with HPS); a sketch is shown after this list.
  • num_decoded_latents and num_sampled_frames: The number of decoded latents (for the VAE) and sampled frames (for the reward model). Since CogVideoX-Fun adopts a 3D causal VAE, we found that decoding only the first latent to obtain the first frame for reward computation not only reduces training memory usage but also prevents excessive reward optimization and preserves the dynamics of the generated videos.
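
As a rough illustration of the customization mentioned above, the sketch below combines two reward signals into a single reward function. The interface is an assumption for illustration only: aesthetic_score and hps_score are hypothetical stand-ins (replaced here by dummies so the snippet runs on its own), not the real implementations in reward_fn.py.

import torch

def aesthetic_score(frames: torch.Tensor, prompts: list[str]) -> torch.Tensor:
    return frames.mean(dim=(1, 2, 3, 4))   # dummy per-video score, one scalar per sample

def hps_score(frames: torch.Tensor, prompts: list[str]) -> torch.Tensor:
    return frames.std(dim=(1, 2, 3, 4))    # dummy per-video score, one scalar per sample

def combined_reward(frames: torch.Tensor, prompts: list[str],
                    aesthetic_weight: float = 0.3, hps_weight: float = 0.7) -> torch.Tensor:
    # Weighted sum of two differentiable reward signals.
    return aesthetic_weight * aesthetic_score(frames, prompts) + hps_weight * hps_score(frames, prompts)

frames = torch.rand(2, 3, 1, 256, 256, requires_grad=True)   # (batch, channels, frames, H, W)
reward = combined_reward(frames, ["prompt a", "prompt b"])
reward.mean().backward()                                      # gradients reach the decoded frames
print(reward.shape)                                           # torch.Size([2])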

Limitations

  1. We observe that after training to a certain extent, the reward continues to increase but the quality of the generated videos does not improve further. The model learns shortcuts (adding artifacts in the background, i.e., adversarial patches) to increase the reward.
  2. Currently, there is still a lack of suitable preference models for video generation. Directly using image preference models cannot evaluate preferences along the temporal dimension (such as dynamism and consistency). Furthermore, we find that using image preference models decreases the dynamism of generated videos. Although this can be mitigated by computing the reward using only the first frame of the decoded video, the impact still persists.

References

  1. Clark, Kevin, et al. "Directly fine-tuning diffusion models on differentiable rewards." In ICLR 2024.
  2. Prabhudesai, Mihir, et al. "Aligning text-to-image diffusion models with reward backpropagation." arXiv preprint arXiv:2310.03739 (2023).