Skip to content

Commit

Permalink
update (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w authored Oct 6, 2024
1 parent d81209b commit d792cbb
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 52 deletions.
2 changes: 2 additions & 0 deletions train_text_to_video_lora.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ for learning_rate in "${LEARNING_RATES[@]}"; do
--height_buckets 480 \
--width_buckets 720 \
--frame_buckets 49 \
--dataloader_num_workers 8 \
--pin_memory \
--validation_prompt \"BW_STYLE A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions:::BW_STYLE A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
Expand Down
2 changes: 2 additions & 0 deletions train_text_to_video_sft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ for learning_rate in "${LEARNING_RATES[@]}"; do
--height_buckets 480 \
--width_buckets 720 \
--frame_buckets 49 \
--dataloader_num_workers 8 \
--pin_memory \
--validation_prompt \"Tom, the mischievous gray cat, is sprawled out on a vibrant red pillow, his body relaxed and his eyes half-closed, as if he's just woken up or is about to doze off. His white paws are stretched out in front of him, and his tail is casually draped over the edge of the pillow. The setting appears to be a cozy corner of a room, with a warm yellow wall in the background and a hint of a wooden floor. The scene captures a rare moment of tranquility for Tom, contrasting with his usual energetic and playful demeanor:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance\" \
--validation_prompt_separator ::: \
--num_validation_videos 1 \
Expand Down
5 changes: 5 additions & 0 deletions training/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
default=0,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Whether or not to use the pinned memory setting in pytorch dataloader.",
)


def _get_validation_args(parser: argparse.ArgumentParser) -> None:
Expand Down
46 changes: 20 additions & 26 deletions training/cogvideox_text_to_video_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import diffusers
import torch
import transformers
import wandb
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
Expand All @@ -51,8 +52,6 @@
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel

import wandb


from args import get_args # isort:skip
from dataset import BucketSampler, VideoDatasetWithResizing # isort:skip
Expand Down Expand Up @@ -475,32 +474,14 @@ def load_model_hook(models, input_dir):
random_flip=args.random_flip,
)

def collate_fn_without_pre_encoding(data):
def collate_fn(data):
prompts = [x["prompt"] for x in data[0]]

videos = [x["video"] for x in data[0]]
videos = torch.stack(videos)
videos = videos.to(accelerator.device, dtype=weight_dtype)
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(videos).latent_dist
videos = latent_dist.sample() * VAE_SCALING_FACTOR
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
videos = videos.to(memory_format=torch.contiguous_format).float()

return {
"videos": videos,
"prompts": prompts,
}

def collate_fn_with_pre_encoding(data):
prompts = [x["prompt"] for x in data[0]]
prompts = torch.stack(prompts).to(accelerator.device, dtype=weight_dtype)
if args.load_tensors:
prompts = torch.stack(prompts).to(dtype=weight_dtype, non_blocking=True)

videos = [x["video"] for x in data[0]]
videos = torch.stack(videos).to(accelerator.device, dtype=weight_dtype)
videos = DiagonalGaussianDistribution(videos).sample() * VAE_SCALING_FACTOR
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
videos = videos.to(memory_format=torch.contiguous_format).float()
videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)

return {
"videos": videos,
Expand All @@ -511,8 +492,9 @@ def collate_fn_with_pre_encoding(data):
train_dataset,
batch_size=1,
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
collate_fn=collate_fn_with_pre_encoding if args.load_tensors else collate_fn_without_pre_encoding,
collate_fn=collate_fn,
num_workers=args.dataloader_num_workers,
pin_memory=args.pin_memory,
)

# Scheduler and math around the number of training steps.
Expand Down Expand Up @@ -637,9 +619,21 @@ def collate_fn_with_pre_encoding(data):
models_to_accumulate = [transformer]

with accelerator.accumulate(models_to_accumulate):
model_input = batch["videos"]
videos = batch["videos"].to(accelerator.device, non_blocking=True)
prompts = batch["prompts"]

# Encode videos
if not args.load_tensors:
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(videos).latent_dist
else:
latent_dist = DiagonalGaussianDistribution(videos)

videos = latent_dist.sample() * VAE_SCALING_FACTOR
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
videos = videos.to(memory_format=torch.contiguous_format).float()
model_input = videos

# Encode prompts
if not args.load_tensors:
prompt_embeds = compute_prompt_embeddings(
Expand Down
46 changes: 20 additions & 26 deletions training/cogvideox_text_to_video_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import diffusers
import torch
import transformers
import wandb
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import (
Expand All @@ -50,8 +51,6 @@
from tqdm.auto import tqdm
from transformers import AutoTokenizer, T5EncoderModel

import wandb


from args import get_args # isort:skip
from dataset import BucketSampler, VideoDatasetWithResizing # isort:skip
Expand Down Expand Up @@ -406,32 +405,14 @@ def load_model_hook(models, input_dir):
random_flip=args.random_flip,
)

def collate_fn_without_pre_encoding(data):
def collate_fn(data):
prompts = [x["prompt"] for x in data[0]]

videos = [x["video"] for x in data[0]]
videos = torch.stack(videos)
videos = videos.to(accelerator.device, dtype=weight_dtype)
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(videos).latent_dist
videos = latent_dist.sample() * VAE_SCALING_FACTOR
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
videos = videos.to(memory_format=torch.contiguous_format).float()

return {
"videos": videos,
"prompts": prompts,
}

def collate_fn_with_pre_encoding(data):
prompts = [x["prompt"] for x in data[0]]
prompts = torch.stack(prompts).to(accelerator.device, dtype=weight_dtype)
if args.load_tensors:
prompts = torch.stack(prompts).to(dtype=weight_dtype, non_blocking=True)

videos = [x["video"] for x in data[0]]
videos = torch.stack(videos).to(accelerator.device, dtype=weight_dtype)
videos = DiagonalGaussianDistribution(videos).sample() * VAE_SCALING_FACTOR
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
videos = videos.to(memory_format=torch.contiguous_format).float()
videos = torch.stack(videos).to(dtype=weight_dtype, non_blocking=True)

return {
"videos": videos,
Expand All @@ -442,8 +423,9 @@ def collate_fn_with_pre_encoding(data):
train_dataset,
batch_size=1,
sampler=BucketSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True),
collate_fn=collate_fn_with_pre_encoding if args.load_tensors else collate_fn_without_pre_encoding,
collate_fn=collate_fn,
num_workers=args.dataloader_num_workers,
pin_memory=args.pin_memory,
)

# Scheduler and math around the number of training steps.
Expand Down Expand Up @@ -568,9 +550,21 @@ def collate_fn_with_pre_encoding(data):
models_to_accumulate = [transformer]

with accelerator.accumulate(models_to_accumulate):
model_input = batch["videos"]
videos = batch["videos"].to(accelerator.device, non_blocking=True)
prompts = batch["prompts"]

# Encode videos
if not args.load_tensors:
videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(videos).latent_dist
else:
latent_dist = DiagonalGaussianDistribution(videos)

videos = latent_dist.sample() * VAE_SCALING_FACTOR
videos = videos.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
videos = videos.to(memory_format=torch.contiguous_format).float()
model_input = videos

# Encode prompts
if not args.load_tensors:
prompt_embeds = compute_prompt_embeddings(
Expand Down

0 comments on commit d792cbb

Please sign in to comment.