[WIP][feat] add Mochi T2V #223
base: main
Conversation
transformer = MochiTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
)
transformer = cast_dit(transformer, torch.bfloat16)
If we do decide to go forward with this (the original implementation does it), this should be guarded with `is_training`, I think, because when not training we don't need to do this.
Since `load_diffusion_models` is only called for training purposes, there's no need to guard here, I think. Only `initialize_pipeline` needs to be guarded, which is specific to the layerwise fp8 upcasting case.
text_encoder_dtype: torch.dtype = torch.float32,
transformer_dtype: torch.dtype = torch.float32,
vae_dtype: torch.dtype = torch.float32,
Everything in FP32, following the original implementation and relying on autocast instead.
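For context, a minimal sketch (toy module, not the PR's code) of what relying on autocast means here: the weights stay in fp32, and only the compute inside the autocast region runs in lower precision.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # stand-in for the fp32 transformer
x = torch.randn(2, 8)

# Mixed-precision compute via autocast; the parameters themselves are never cast.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(model.weight.dtype, y.dtype)  # torch.float32 torch.bfloat16
```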
return {"latents": denoised_latents} | ||
|
||
|
||
def validation( |
This can be revisited to match what we do in:
https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi#reproducing-the-results-from-the-genmo-mochi-repo
Thanks for looking into this, Sayak.
I've made some suggestions about how we can handle the reversed timesteps. I'm not sure they'd work 100% out of the box, so we can only know by trying some runs. Maybe we don't need any changes to the flow-matching algorithm itself, just changes to what the prediction returns and to the reversed timestep we condition on.
Most of the changes look good. I have a few concerns though:

- Personally, I never prefer using `torch.autocast` anywhere. Okay to keep if you strongly believe it's required to make LoRA training work as expected.
- I would default the text encoder and VAE to `torch.bfloat16` if we can test that training works with that. I'm not sure whether text encoder and VAE precision has a significant impact on Mochi quality, but I'd like to believe it does not (at least for bf16). This matters because the default options should lead to "okay" memory usage; since we allow configuring these dtypes via the command line, they can be made fp32 if the user wants.
- `cast_dit` is okay to keep, but I would maybe train two LoRAs to test whether it is required by default (it may marginally increase memory requirements): one with the transformer in pure bf16, and one with `cast_dit`, which keeps some layers in fp32 (see the sketch after this list).
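To make the `cast_dit` point concrete, here is a generic sketch of selective casting of the kind it performs; the name-based keep-in-fp32 rule below is an assumption for illustration, not the actual genmo/finetrainers selection logic.

```python
import torch
import torch.nn as nn

def selective_cast(model: nn.Module, dtype: torch.dtype = torch.bfloat16, keep_fp32=("norm", "mod")) -> nn.Module:
    # Cast Linear layers to the low-precision dtype, skipping any whose qualified name
    # matches the keep list (so e.g. normalization/modulation projections stay in fp32).
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and not any(key in name for key in keep_fp32):
            module.to(dtype)
    return model

# Toy usage: the Linear gets cast, the LayerNorm (not a Linear) stays in fp32.
toy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
selective_cast(toy)
print(toy[0].weight.dtype, toy[1].weight.dtype)  # torch.bfloat16 torch.float32
```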
# TODO: This is for Mochi-1. Only the negation is the change here. Since that is a one-liner, I wonder
# if it's fine to do: `if is_mochi: ...`. Or would it be better to have something like:
# `prepare_timesteps()`?
# timesteps = (1 - sigmas) * self.scheduler.config.num_train_timesteps
I would pass sigmas/timesteps as-is and do the negation in the implementation of `forward_pass`. This is because we should not modify the flow-matching objective for models that handle timestep conditioning in the reverse manner.
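A minimal sketch of that idea (the names and signature here are hypothetical, not finetrainers' actual API): the shared trainer keeps passing the usual sigmas, and only the Mochi-specific forward pass reverses the timestep conditioning and negates the prediction.

```python
import torch

NUM_TRAIN_TIMESTEPS = 1000  # assumed scheduler config value

def mochi_forward_pass(transformer_fn, noisy_latents, sigmas):
    # Trainer-side sigmas are untouched; the reversal happens only inside this function.
    timesteps = (sigmas * NUM_TRAIN_TIMESTEPS).long()
    reversed_timesteps = NUM_TRAIN_TIMESTEPS - timesteps  # Mochi conditions on reversed timesteps
    pred = transformer_fn(noisy_latents, reversed_timesteps)
    # Negate so the shared flow-matching target (noise - latents) still applies downstream.
    return {"latents": -pred}

# Toy usage with a stand-in "transformer".
toy_transformer = lambda x, t: 0.5 * x
out = mochi_forward_pass(toy_transformer, torch.randn(2, 4), torch.rand(2, generator=torch.manual_seed(0)))
print(out["latents"].shape)  # torch.Size([2, 4])
```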
# It doesn't rely on `sigmas` configured in the scheduler. To handle that, should
# Mochi implement its own `prepare_sigmas()` similar to how `calculate_noisy_latents()` is implemented?
They are equivalent with `weighting_scheme="none"`, I think, but I need to crosscheck, so maybe we don't need any changes.
So, I emulated something similar:

from diffusers import FlowMatchEulerDiscreteScheduler
from finetrainers.utils.diffusion_utils import get_scheduler_sigmas, prepare_sigmas
import torch

batch_size = 2
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler_sigmas = get_scheduler_sigmas(scheduler)
sigmas = prepare_sigmas(
    scheduler=scheduler,
    sigmas=scheduler_sigmas,
    batch_size=batch_size,
    num_train_timesteps=scheduler.config.num_train_timesteps,
    flow_weighting_scheme=None,
    device="cpu",
    generator=torch.manual_seed(0),
)

print(sigmas)
print(torch.rand(2, generator=torch.manual_seed(0)))
Prints:
tensor([0.5040, 0.2320])
tensor([0.4963, 0.7682])
In the case of the "none" weighting scheme, we apply uniform weighting through `torch.randn()` and then compute the sigmas with indexing. But this is a bit different, I guess.
I don't think we apply uniform weighting through `torch.randn`, no? We apply it via `torch.rand` here too. From the uniformly sampled random values, we create indices on the linspace sigmas, so I think it should have the same effect.
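A rough sketch of that indexing path (my reconstruction for illustration, not the exact `prepare_sigmas` code): draw uniform samples with `torch.rand`, map them to indices over the linspace sigmas, then gather.

```python
import torch

num_train_timesteps = 1000
# Under the default config, the scheduler's sigmas form a linspace from 1.0 down to 1/1000.
scheduler_sigmas = torch.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)

u = torch.rand(2, generator=torch.manual_seed(0))  # uniform samples in [0, 1)
indices = (u * num_train_timesteps).long()         # indices onto the linspace
sigmas = scheduler_sigmas[indices]
print(sigmas)  # tensor([0.5040, 0.2320]), matching the values printed above
```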
My bad, I meant `torch.rand()`. The effect should be the same, but I want to be sure, so I will first try out `weighting_scheme=None` and then try what's originally done in Mochi.
This is perfect feedback. Will get to them tomorrow, thanks Aryan!
Co-authored-by: Aryan <[email protected]>
# https://github.com/genmoai/mochi/blob/aba74c1b5e0755b1fa3343d9e4bd22e89de77ab1/src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py#L656
# In short, Mochi operates on reversed targets which is why we need to negate
# the predictions.
denoised_latents = -denoised_latents
For our reference:
import torch
import torch.nn.functional as F
latents = torch.randn(1, 4, 64, 64, generator=torch.manual_seed(0))
noise = torch.randn(1, 4, 64, 64, generator=torch.manual_seed(1))
denoised_latents = torch.randn(1, 4, 64, 64, generator=torch.manual_seed(2))
# mochi way
targets = latents - noise
print("Mochi Way Loss:", F.mse_loss(denoised_latents.float(), targets.float()))
# our way
target = noise - latents # Correctly set the target for "our way"
denoised_latents = -denoised_latents # Negate the denoised_latents
print("Our Way Loss:", F.mse_loss(denoised_latents.float(), target.float())) # Compare to the correct target
Yields the same results.
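For the record, this equivalence is just the fact that MSE is invariant under negating both arguments: `||(-d) - (n - l)||^2 = ||d - (l - n)||^2`, so negating the prediction while flipping the target's sign gives the identical loss.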
Perfect 👨🍳
Some more investigations:

from diffusers import MochiTransformer3DModel, FlowMatchEulerDiscreteScheduler
from finetrainers.utils.diffusion_utils import prepare_sigmas
import torch

transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer"
).to("cuda")
scheduler = FlowMatchEulerDiscreteScheduler()

latents = torch.randn(1, 12, 4, 60, 106, generator=torch.manual_seed(0)).to("cuda")
prompt_embeds = torch.randn(1, 256, 4096, generator=torch.manual_seed(0)).to("cuda")
prompt_attention_mask = torch.randint(0, 2, size=(1, 256), generator=torch.manual_seed(0)).to("cuda")

# Original way.
# We are initializing sigma on CPU and then moving to CUDA.
original_sigma = torch.rand(latents.shape[0], generator=torch.manual_seed(0)).to("cuda")
timesteps = ((1 - original_sigma) * scheduler.config.num_train_timesteps).long()
with torch.no_grad():
    model_pred = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_attention_mask,
        timestep=timesteps,
        return_dict=False,
    )[0]
print(model_pred.shape)

# Our way.
scheduler_sigmas = scheduler.sigmas.clone()
sigmas = prepare_sigmas(
    scheduler=scheduler,
    sigmas=scheduler_sigmas,
    batch_size=1,
    num_train_timesteps=scheduler.config.num_train_timesteps,
    flow_weighting_scheme="none",
    device="cpu",
    generator=torch.manual_seed(0),
)
timesteps = (sigmas * scheduler.config.num_train_timesteps).long()
timesteps = timesteps.to("cuda")
with torch.no_grad():
    our_model_pred = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_attention_mask,
        timestep=scheduler.config.num_train_timesteps - timesteps,
        return_dict=False,
    )[0]
print(our_model_pred.shape)

try:
    print(torch.testing.assert_close(model_pred, our_model_pred, atol=1e-4, rtol=1e-4))
except AssertionError:
    print("Nope :()")

# If we combine original way of computing sigma and our way of getting timesteps?
timesteps = (original_sigma * scheduler.config.num_train_timesteps).long()
with torch.no_grad():
    hybrid_model_pred = transformer(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_attention_mask,
        timestep=scheduler.config.num_train_timesteps - timesteps,
        return_dict=False,
    )[0]
print(our_model_pred.shape)

try:
    print(torch.testing.assert_close(model_pred, hybrid_model_pred, atol=1e-4, rtol=1e-4))
except AssertionError:
    print("model_pred, hybrid_model_pred didn't.")
try:
    print(torch.testing.assert_close(our_model_pred, hybrid_model_pred, atol=1e-4, rtol=1e-4))
except AssertionError:
    print("our_model_pred, hybrid_model_pred didn't.")

None of the assertions pass here. I investigated the timestep calculation in isolation:

import torch

# original way
sigma = torch.rand(1, generator=torch.manual_seed(0))
timestep = ((1 - sigma) * 1000).long()

# our way
a_timestep = ((sigma) * 1000).long()
a_timestep = 1000 - a_timestep

print(timestep)    # prints 503
print(a_timestep)  # prints 504

I know the effects are the same, but there is a difference that is worth noting here.
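For illustration (not from the PR), the one-step offset comes purely from where the floor lands: in exact arithmetic, `floor(1000 * (1 - s))` equals `1000 - floor(1000 * s) - 1` whenever `1000 * s` is not an integer.

```python
import math

s = 0.4963  # roughly what torch.rand gives with seed 0
print(math.floor(1000 * (1 - s)))   # 503
print(1000 - math.floor(1000 * s))  # 504
```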
@sayakpaul To address this comment, I think the timestep calculation is okay to not match 1:1 because they seem to be the same logically. It's just a question of where the rounding happens.
Yes, I am not worried about it but noting here in case things go south hahaha.
Fixes #90 #217
This PR is not at all in a complete state. I am opening this PR in a minimal capacity to discuss some implementation details.
Once those are sorted, I will bring it to completion and open it up for a more rigorous review.
My questions are commented in the changes I am introducing.