Flux Model #2302
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2302
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 6 Cancelled Jobs as of commit 9d90451 with merge base 90fd2d3.
NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I'm excited to see this PR! This is just a first pass, most of my comments are around the Flux model components. I don't have a ton to say about the noise prediction yet (I may just need to do more research/see it in action). Still need to look at the single/double stream blocks later, so will probably come back with a few more comments
@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
nit, but let's name this _utils.py just for consistency with other similar files.
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
Why do we need to do the QKV unpacking in forward? Seems to me like we should just separate Q, K, and V as part of the checkpoint loading
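To illustrate the suggestion, here's a rough sketch of splitting the fused QKV weight during checkpoint conversion instead of in forward (the key names qkv.weight / q_proj.weight etc. are placeholders, not the actual Flux checkpoint keys):

```python
import torch


def split_fused_qkv(state_dict: dict, hidden_size: int) -> dict:
    """Sketch: split each fused `qkv` weight into separate q/k/v projection weights."""
    converted = {}
    for key, value in state_dict.items():
        if key.endswith("qkv.weight"):
            # fused weight is [3 * hidden_size, hidden_size] -> three [hidden_size, hidden_size]
            q, k, v = torch.split(value, hidden_size, dim=0)
            prefix = key[: -len("qkv.weight")]
            converted[prefix + "q_proj.weight"] = q
            converted[prefix + "k_proj.weight"] = k
            converted[prefix + "v_proj.weight"] = v
        else:
            converted[key] = value
    return converted
```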
Separately, I may be missing something... is there a gap in our existing MultiHeadAttention that necessitates the use of these? I think we should support Q norm, K norm, RoPE, all of that there. But admittedly I'm kind of unfamiliar with some of the Flux details.
        in_channels: int,
        out_channels: int,
        vec_in_dim: int,
        context_in_dim: int,
        hidden_size: int,
        mlp_ratio: float,
        num_heads: int,
        depth: int,
        depth_single_blocks: int,
        axes_dim: list[int],
        theta: int,
        qkv_bias: bool,
        use_guidance: bool,
Any reason not to compose this from lower-level modules like we do in TransformerDecoder and elsewhere? E.g.
def __init__(
    self,
    pe_embedder: nn.Module,
    img_in: nn.Module,
    time_in: nn.Module,
    vector_in: nn.Module,
    guidance_in: nn.Module,
    txt_in: nn.Module,
    double_blocks: nn.ModuleList,
    single_blocks: nn.ModuleList,
    final_layer: nn.Module,
):
I think one could make a case that the blocks may need some constraints, but I think most other components are single-input, single-output. This would make it a lot easier for someone to swap out e.g. MLPEmbedder with something else.
        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
Should timestep_embedding be parametrized as well?
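For example, something along these lines (just a sketch; the module name, the time_embed_dim argument, and passing timestep_embedding as a callable are all made up for illustration):

```python
from typing import Callable

from torch import Tensor, nn


class TimeConditioning(nn.Module):
    """Sketch: expose the hard-coded 256 as a constructor argument."""

    def __init__(
        self,
        time_in: nn.Module,
        timestep_embedding: Callable[[Tensor, int], Tensor],
        time_embed_dim: int = 256,
    ):
        super().__init__()
        self.time_in = time_in
        self.timestep_embedding = timestep_embedding
        self.time_embed_dim = time_embed_dim

    def forward(self, timesteps: Tensor) -> Tensor:
        # previously: self.time_in(timestep_embedding(timesteps, 256))
        return self.time_in(self.timestep_embedding(timesteps, self.time_embed_dim))
```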
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
I think `|` is only available from Python 3.10 and we still support 3.9 for now. Would just stick with Union for the time being.
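i.e. something like:

```python
from typing import Optional

from torch import Tensor

# Python 3.9-compatible equivalent of `guidance: Tensor | None = None`
guidance: Optional[Tensor] = None
```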
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
Same here.. we already have this, right?
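If it's equivalent, I believe the existing one can just be imported (assuming the signature hasn't changed):

```python
from torchtune.modules import RMSNorm

# reuse the existing implementation rather than redefining it here
norm = RMSNorm(dim=128)  # dim value is just an example
```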
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    use_guidance=False,
Noob q: the only diff between dev and schnell is that dev uses guidance and schnell doesn't?
    return model


def _replace_linear_with_lora(
I think we've talked about this one for a bit anyways, so great to see it finally materializing. Maybe let's put it in modules/peft/_utils.py or something?
One other thing we'll have to watch out for here: for NF4-quantized layers we also add state dict hooks to prevent higher memory usage on checkpoint save. It may be worth including that logic in this utility.
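For reference, a rough sketch of the kind of recursive replacement utility being discussed (the LoRALinear kwargs here are assumptions, and the NF4 state dict hooks mentioned above would still need to be registered on top of this):

```python
from torch import nn

from torchtune.modules.peft import LoRALinear


def _replace_linear_with_lora(
    module: nn.Module,
    rank: int,
    alpha: float,
) -> None:
    """Recursively swap every nn.Linear in `module` for a LoRALinear."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            lora = LoRALinear(
                in_dim=child.in_features,
                out_dim=child.out_features,
                rank=rank,
                alpha=alpha,
                use_bias=child.bias is not None,
            )
            setattr(module, name, lora)
        else:
            _replace_linear_with_lora(child, rank, alpha)
```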
elif self._model_type == ModelType.FLUX:
    pass  # the torchtune Flux model state dict is identical to the huggingface one
Not a problem for this PR, but in this big if/elif/.../elif/else I think this (no weight conversion) should actually be the else, not Llama2 (for whatever reason)
PATCH_HEIGHT, PATCH_WIDTH = 2, 2
POSITION_DIM = 3
Noob q: these are constant in all cases?
Context
This adds the main Flux flow-matching model to TorchTune.
NOTE: @pbontrager had mostly finished an implementation of this model before going on leave, so this implementation is temporary to unblock the rest of the Flux PRs. I just copied the code from the official Flux repo with some minimal changes. We can replace it with @pbontrager's version when he returns.
More Flux PRs coming soon.
Changelog
Usage
Test plan
Manual testing:
Checklist:
pre-commit install
pytest tests
pytest tests -m integration_test