
[WIP][feat] add Mochi T2V #223

Draft · wants to merge 7 commits into main
Conversation

@sayakpaul (Collaborator) commented Jan 15, 2025

Fixes #90 #217

This PR is not at all in a complete state. I am opening it in a minimal capacity to discuss some implementation details.

Once those are sorted, I will bring it to completion and open it up for a more rigorous review.

My questions are left as comments on the changes I am introducing.

@sayakpaul requested a review from a-r-r-o-w on January 15, 2025 at 11:17
transformer = MochiTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
)
transformer = cast_dit(transformer, torch.bfloat16)
Collaborator Author:

If we do decide to go forward with this (the original implementation does it), I think it should be guarded with is_training, because when not training we don't need to do this.
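
For concreteness, the guard I have in mind would look roughly like this (a sketch only; the helper name is illustrative, not the finetrainers API):

import torch

# Sketch of the proposed guard (illustrative): cast the DiT to bf16 only when training.
def maybe_cast_dit(transformer, cast_dit, is_training):
    if is_training:
        transformer = cast_dit(transformer, torch.bfloat16)
    return transformer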

Owner:

Since load_diffusion_models is only called for training purposes, I don't think there's a need to guard here. Only initialize_pipeline needs to be guarded, which is specific to the layerwise fp8 upcasting case.

Comment on lines +66 to +68
text_encoder_dtype: torch.dtype = torch.float32,
transformer_dtype: torch.dtype = torch.float32,
vae_dtype: torch.dtype = torch.float32,
Collaborator Author:

Everything in FP32, following the original implementation and relying on autocast instead.
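
For context, the pattern this refers to looks roughly like the following (a minimal sketch, not the actual finetrainers code):

import torch

# Minimal sketch: keep the module weights in fp32 and let autocast run the
# forward computation in bf16.
def autocast_forward(transformer, **inputs):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return transformer(**inputs)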

return {"latents": denoised_latents}


def validation(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Owner

@a-r-r-o-w a-r-r-o-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this, Sayak.

I've made some suggestions about how we can handle the reversed timesteps. I'm not sure if they'd work 100% out of the box, so we can only know by trying some runs. Maybe we don't need any changes to the flow matching algorithm itself, but just what the prediction returns and the reversed timestep to condition on.

Most of the changes look good. I have a few concerns though:

  • Personally, I never prefer using torch.autocast anywhere. Okay to keep if you strongly believe it's required to make lora training work as expected.
  • I would default the text encoder and vae to torch.bfloat16 if we can test that training works with that (see the sketch after this list). I'm not sure text encoder and vae precision has a significant impact on Mochi quality, but I'd like to believe it does not (at least for bf16). This is because the default options should lead to "okay" memory usage. Since we allow configuring these dtypes via the command line, they can be made fp32 if the user wants.
  • cast_dit is okay to keep, but I would maybe train two loras to test whether it is required by default (it may marginally increase memory requirements): one with the transformer in pure bf16 and one with cast_dit, which keeps some layers in fp32.
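
For the second point, the suggested defaults would amount to roughly the following in the loading signature (a sketch only; fp32 stays available via the command-line dtype arguments):

import torch

# Sketch of the suggested defaults (illustrative, not a committed change).
text_encoder_dtype: torch.dtype = torch.bfloat16
transformer_dtype: torch.dtype = torch.bfloat16
vae_dtype: torch.dtype = torch.bfloat16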

Comment on lines +712 to +715
# TODO: This is for Mochi-1. Only the negation is the change here. Since that is a one-liner, I wonder
# if it's fine to do: `if is_mochi: ...`. Or would it be better to have something like:
# `prepare_timesteps()`?
# timesteps = (1 - sigmas) * self.scheduler.config.num_train_timesteps
Owner:

I would pass sigmas/timesteps as-is and do the negation in the implementation of forward_pass. This is because we should not modify the flow-match objective for models that handle timestep conditioning in the reverse manner.
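
A rough sketch of that suggestion (illustrative names, not the final finetrainers API): the shared code passes sigmas as-is, and only the Mochi-specific forward pass derives the reversed timestep it conditions on.

import torch

# Illustrative sketch: Mochi conditions on reversed timesteps, t = (1 - sigma) * T,
# so the reversal lives inside the model-specific forward pass.
def mochi_forward_pass(transformer, latents, sigmas, num_train_timesteps=1000, **conditions):
    timesteps = ((1 - sigmas) * num_train_timesteps).long()
    return transformer(hidden_states=latents, timestep=timesteps, return_dict=False, **conditions)[0]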

Comment on lines 97 to 98
# It doesn't rely on `sigmas` configured in the scheduler. To handle that, should
# Mochi implement its own `prepare_sigmas()` similar to how `calculate_noisy_latents()` is implemented?
Owner:

They are equivalent with weighting_scheme="none", I think, but I need to cross-check, so we may not need any changes.

Collaborator Author:

So, I emulated something similar:

from diffusers import FlowMatchEulerDiscreteScheduler
from finetrainers.utils.diffusion_utils import get_scheduler_sigmas, prepare_sigmas
import torch 

batch_size = 2

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler_sigmas = get_scheduler_sigmas(scheduler)
sigmas = prepare_sigmas(
    scheduler=scheduler,
    sigmas=scheduler_sigmas,
    batch_size=batch_size,
    num_train_timesteps=scheduler.config.num_train_timesteps,
    flow_weighting_scheme=None,
    device="cpu",
    generator=torch.manual_seed(0),
)
print(sigmas)

print(torch.rand(2, generator=torch.manual_seed(0)))

Prints:

tensor([0.5040, 0.2320])
tensor([0.4963, 0.7682])

In case of "none" weighting scheme, we apply uniform weighting through torch.randn() and then compute the sigmas with indexing. But this is a bit different I guess.

Owner:

I don't think we apply uniform weighting through torch.randn, no? We apply it via torch.rand too here.

From the uniformly sampled random values, we create indices on the linspace sigmas, so I think it should be the same effect
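
To sketch that argument numerically (the linspace here is an approximation of the scheduler's default sigmas, not the exact get_scheduler_sigmas/prepare_sigmas code):

import torch

num_train_timesteps = 1000
# Approximate stand-in for the scheduler's default sigmas: a descending linspace over (0, 1].
linspace_sigmas = torch.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)

u = torch.rand(100_000, generator=torch.manual_seed(0))  # uniform samples, as in the Mochi recipe
indices = (u * num_train_timesteps).long()               # indices created from the uniform samples
indexed_sigmas = linspace_sigmas[indices]                # roughly 1 - u, i.e. still uniform over (0, 1]

print(u.mean(), indexed_sigmas.mean())  # both close to 0.5
print(u.std(), indexed_sigmas.std())    # both close to 1 / sqrt(12) ≈ 0.289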

Collaborator Author:

My bad, I meant torch.rand() only.

The effect should be the same, but I want to be sure, so I will first try out "weighting_scheme=None" and then try out what's originally done in Mochi.

(Additional review threads on finetrainers/utils/diffusion_utils.py and finetrainers/models/mochi/lora.py were marked resolved.)

@sayakpaul (Collaborator Author):

This is perfect feedback. Will get to them tomorrow, thanks Aryan!

# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
# In short, Mochi operates on reversed targets which is why we need to negate
# the predictions.
denoised_latents = -denoised_latents
Collaborator Author:

For our reference:

import torch 
import torch.nn.functional as F

latents = torch.randn(1, 4, 64, 64, generator=torch.manual_seed(0))
noise = torch.randn(1, 4, 64, 64, generator=torch.manual_seed(1))
denoised_latents = torch.randn(1, 4, 64, 64, generator=torch.manual_seed(2))

# mochi way
targets = latents - noise 
print("Mochi Way Loss:", F.mse_loss(denoised_latents.float(), targets.float()))

# our way
target = noise - latents  # Correctly set the target for "our way"
denoised_latents = -denoised_latents  # Negate the denoised_latents
print("Our Way Loss:", F.mse_loss(denoised_latents.float(), target.float()))  # Compare to the correct target

Yields the same results.
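
This is expected, since mse(-pred, noise - latents) == mse(pred, latents - noise): negating both arguments of the squared error leaves it unchanged.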

Owner:

Perfect 👨‍🍳

@sayakpaul (Collaborator Author):

Some more investigations:

from diffusers import MochiTransformer3DModel, FlowMatchEulerDiscreteScheduler
from finetrainers.utils.diffusion_utils import prepare_sigmas
import torch


transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer"
).to("cuda")
scheduler = FlowMatchEulerDiscreteScheduler()

latents = torch.randn(1, 12, 4, 60, 106, generator=torch.manual_seed(0)).to("cuda")
prompt_embeds = torch.randn(1, 256, 4096, generator=torch.manual_seed(0)).to("cuda")
prompt_attention_mask = torch.randint(0, 2, size=(1, 256), generator=torch.manual_seed(0)).to("cuda")

# Original way.
# We are initializing sigma on CPU and then moving to CUDA.
original_sigma = torch.rand(latents.shape[0], generator=torch.manual_seed(0)).to("cuda")
timesteps = ((1 - original_sigma) * scheduler.config.num_train_timesteps).long()

with torch.no_grad():
    model_pred = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_attention_mask,
        timestep=timesteps,
        return_dict=False,
    )[0]
    print(model_pred.shape)


# Our way.
scheduler_sigmas = scheduler.sigmas.clone()
sigmas = prepare_sigmas(
    scheduler=scheduler,
    sigmas=scheduler_sigmas,
    batch_size=1,
    num_train_timesteps=scheduler.config.num_train_timesteps,
    flow_weighting_scheme="none",
    device="cpu",
    generator=torch.manual_seed(0),
)
timesteps = (sigmas * scheduler.config.num_train_timesteps).long()
timesteps = timesteps.to("cuda")

with torch.no_grad():
    our_model_pred = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_attention_mask,
        timestep=scheduler.config.num_train_timesteps - timesteps,
        return_dict=False,
    )[0]
    print(our_model_pred.shape)

try:
    print(torch.testing.assert_close(model_pred, our_model_pred, atol=1e-4, rtol=1e-4))
except AssertionError:
    print("Nope :()")

# If we combine original way of computing sigma and our way of getting timesteps?
timesteps = (original_sigma * scheduler.config.num_train_timesteps).long()
with torch.no_grad():
    hybrid_model_pred = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_attention_mask,
        timestep=scheduler.config.num_train_timesteps - timesteps,
        return_dict=False,
    )[0]
    print(hybrid_model_pred.shape)

try:
    print(torch.testing.assert_close(model_pred, hybrid_model_pred, atol=1e-4, rtol=1e-4))
except AssertionError:
    print("model_pred, hybrid_model_pred didn't.")

try:
    print(torch.testing.assert_close(our_model_pred, hybrid_model_pred, atol=1e-4, rtol=1e-4))
except AssertionError:
    print("our_model_pred, hybrid_model_pred didn't.")

None of the assertions pass here.

I investigated the timestep calculations a bit more:

import torch

# original way
sigma = torch.rand(1, generator=torch.manual_seed(0))
timestep = ((1 - sigma) * 1000).long()

# our way
a_timestep = ((sigma) * 1000).long()
a_timestep = 1000 - a_timestep

print(timestep) # prints 503
print(a_timestep) # prints 504

I know the effects are the same, but there is a difference that is worth noting here.
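
(Concretely, when sigma * 1000 is not an integer, ((1 - sigma) * 1000).long() equals 1000 - (sigma * 1000).long() - 1, so the two routes land exactly one timestep apart.)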

@a-r-r-o-w (Owner):

@sayakpaul To address this comment: I think it's okay for the timestep calculation not to match 1:1, because they seem to be the same logically. It's just a question of where the rounding from .long() is applied, so not really concerning imo.

@sayakpaul (Collaborator Author):

Yes, I am not worried about it but noting here in case things go south hahaha.
