
Flux Model #2302

Open · wants to merge 9 commits into main

Conversation

@calvinpelletier (Contributor) commented Jan 27, 2025

Context

This adds the main Flux flow-matching model to TorchTune.

NOTE: @pbontrager had mostly finished an implementation of this model before going on leave, so this implementation is temporary to unblock the rest of the Flux PRs. I just copied the code from the official Flux repo with some minimal changes. We can replace it with @pbontrager's version when he returns.

More Flux PRs coming soon.

Changelog

  • Flux flow model implementation and builders
  • LoRA builders
  • Unit tests
  • Removed input image shape assertion from autoencoder (the autoencoder can be used with many image resolutions)
  • Some utility functions for predicting the noise in an image latent using the flow model (see the sketch after this list)
  • Added checkpoint loading/saving logic for the flow model (which is currently a no-op since the temporary torchtune implementation is the same as the huggingface one, see note above)
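
To give a sense of what those noise-prediction utilities do, here is a minimal flow-matching sketch; the helper name, the flow model's call signature, and the tensor shapes are illustrative assumptions, not the actual utilities added in this PR.

import torch

def predict_noise_sketch(flow_model, latents, noise, timesteps, **conditioning):
    # Flow matching interpolates linearly between clean latents x_0 and noise:
    #   x_t = (1 - t) * x_0 + t * noise
    # and the model predicts the velocity (noise - x_0). The noise estimate is
    # then recovered as x_t + (1 - t) * v_pred.
    t = timesteps.view(-1, 1, 1, 1)  # timesteps: (B,), latents/noise: (B, C, H, W)
    noised = (1.0 - t) * latents + t * noise
    v_pred = flow_model(noised, timesteps=timesteps, **conditioning)  # illustrative call
    return noised + (1.0 - t) * v_pred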

Usage

tune download black-forest-labs/FLUX.1-dev --output-dir /tmp/flux
from torchtune.models.flux import flux_1_dev_flow_model
from torchtune.training.checkpointing import FullModelHFCheckpointer

model = flux_1_dev_flow_model()
checkpointer = FullModelHFCheckpointer(
    "/tmp/flux",                # checkpoint_dir: where `tune download` put the weights
    ["flux1-dev.safetensors"],  # checkpoint_files
    "FLUX",                     # model_type
    "/tmp/flux_",               # output_dir for saved checkpoints
)
sd = checkpointer.load_checkpoint()["model"]
model.load_state_dict(sd)

Test plan

Manual testing:

  • verified that the torchtune implementation of the flow model has exact output parity with the official implementation, with the same performance
  • verified that the LoRA trains properly

Checklist:

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test

pytorch-bot (bot) commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2302

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 6 Cancelled Jobs

As of commit 9d90451 with merge base 90fd2d3:


This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 27, 2025
@ebsmothers (Contributor) left a comment

I'm excited to see this PR! This is just a first pass; most of my comments are around the Flux model components. I don't have a ton to say about the noise prediction yet (I may just need to do more research / see it in action). Still need to look at the single/double stream blocks later, so I will probably come back with a few more comments.

@@ -0,0 +1,165 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
nit but let's name this _utils.py just for consistency with other similar files

self.proj = nn.Linear(dim, dim)

def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)

Why do we need to do the QKV unpacking in forward? Seems to me like we should just separate Q, K, and V as part of the checkpoint loading.
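
For concreteness, splitting a fused QKV projection during state dict conversion could look roughly like this; the key names and the [q; k; v] row layout are assumptions about the checkpoint format, not confirmed from this PR:

import torch

def split_fused_qkv(state_dict: dict, prefix: str, dim: int) -> dict:
    # Assumes a fused weight of shape (3 * dim, dim) stacked as [q; k; v].
    q_w, k_w, v_w = torch.split(state_dict.pop(f"{prefix}.qkv.weight"), dim, dim=0)
    state_dict[f"{prefix}.q_proj.weight"] = q_w
    state_dict[f"{prefix}.k_proj.weight"] = k_w
    state_dict[f"{prefix}.v_proj.weight"] = v_w
    if f"{prefix}.qkv.bias" in state_dict:
        q_b, k_b, v_b = torch.split(state_dict.pop(f"{prefix}.qkv.bias"), dim, dim=0)
        state_dict[f"{prefix}.q_proj.bias"] = q_b
        state_dict[f"{prefix}.k_proj.bias"] = k_b
        state_dict[f"{prefix}.v_proj.bias"] = v_b
    return state_dict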


Separately, I may be missing something... is there a gap in our existing MultiHeadAttention that necessitates the use of these? I think we should support Q norm, K norm, RoPE, all of that there. But admittedly I'm kind of unfamiliar with some of the Flux details.
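
For illustration, wiring the existing module with those features might look roughly like the sketch below; the constructor arguments are written from memory and should be checked against the current MultiHeadAttention signature, and Flux's multi-axis RoPE would still need its own positional embedding module:

from torch import nn
from torchtune.modules import MultiHeadAttention, RMSNorm, RotaryPositionalEmbeddings

head_dim = hidden_size // num_heads  # hidden_size / num_heads / qkv_bias taken from the Flux config above
attn = MultiHeadAttention(
    embed_dim=hidden_size,
    num_heads=num_heads,
    num_kv_heads=num_heads,
    head_dim=head_dim,
    q_proj=nn.Linear(hidden_size, num_heads * head_dim, bias=qkv_bias),
    k_proj=nn.Linear(hidden_size, num_heads * head_dim, bias=qkv_bias),
    v_proj=nn.Linear(hidden_size, num_heads * head_dim, bias=qkv_bias),
    output_proj=nn.Linear(num_heads * head_dim, hidden_size),
    pos_embeddings=RotaryPositionalEmbeddings(dim=head_dim),
    q_norm=RMSNorm(dim=head_dim),
    k_norm=RMSNorm(dim=head_dim),
)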

Comment on lines +38 to +50
in_channels: int,
out_channels: int,
vec_in_dim: int,
context_in_dim: int,
hidden_size: int,
mlp_ratio: float,
num_heads: int,
depth: int,
depth_single_blocks: int,
axes_dim: list[int],
theta: int,
qkv_bias: bool,
use_guidance: bool,

Any reason not to compose this out of lower-level modules like we do in TransformerDecoder and elsewhere? E.g.

def __init__(
    self,
    pe_embedder: nn.Module,
    img_in: nn.Module,
    time_in: nn.Module,
    vector_in: nn.Module,
    guidance_in: nn.Module,
    txt_in: nn.Module,
    double_blocks: nn.ModuleList,
    single_blocks: nn.ModuleList,
    final_layer: nn.Module,
):

I think one could make a case that the blocks may need some constraints, but I think most other components are single-input, single-output. This would make it a lot easier for someone to swap out e.g. MLPEmbedder with something else.


# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))

Should timestep_embedding be parametrized as well?
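
(For anyone unfamiliar: timestep_embedding here is the standard sinusoidal embedding from the diffusion literature. Roughly, as a sketch rather than the exact code in this PR:)

import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10_000) -> torch.Tensor:
    # Maps scalar timesteps of shape (B,) to sinusoidal features of shape (B, dim).
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half
    )
    args = t.float()[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)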

txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,

I think | in type annotations is only available from Python 3.10, and we still support 3.9 for now. I would just stick with Union for the time being.
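
i.e., something like the following, which works on 3.9 (Optional[X] is just shorthand for Union[X, None]):

from typing import Optional

from torch import Tensor

# Equivalent to "guidance: Tensor | None = None", but valid on Python 3.9:
guidance: Optional[Tensor] = None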

return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):

Same here... we already have this, right?
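
(For reference, the existing implementation could likely be reused directly; head_dim below is a placeholder for whichever dimension the Flux block normalizes over:)

from torchtune.modules import RMSNorm

head_dim = 128  # placeholder dimension
norm = RMSNorm(dim=head_dim, eps=1e-6)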

axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
use_guidance=False,

Noob q: the only diff between dev and schnell is that dev uses guidance and schnell doesn't?

return model


def _replace_linear_with_lora(

I think we've talked about this one for a bit anyways, so great to see it finally materializing. Maybe let's put it in modules/peft/_utils.py or something?


One other thing we'll have to watch out for here: for NF4-quantized layers we also add state dict hooks to prevent higher memory usage on checkpoint save. It may be worth including that logic in this utility.
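
For concreteness, such a utility could be a recursive module swap along these lines; the LoRALinear constructor arguments and its weight/bias attributes are assumptions to double-check, and the NF4 state dict hooks mentioned above are not shown:

from torch import nn
from torchtune.modules.peft import LoRALinear

def _replace_linear_with_lora(module: nn.Module, rank: int, alpha: float) -> None:
    # Recursively swap every nn.Linear for a LoRALinear of the same shape,
    # copying the frozen base weight (and bias) into the new module.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            lora = LoRALinear(
                in_dim=child.in_features,
                out_dim=child.out_features,
                rank=rank,
                alpha=alpha,
                use_bias=child.bias is not None,
            )
            lora.weight.data.copy_(child.weight.data)
            if child.bias is not None:
                lora.bias.data.copy_(child.bias.data)
            setattr(module, name, lora)
        else:
            _replace_linear_with_lora(child, rank, alpha)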

Comment on lines +721 to +722
elif self._model_type == ModelType.FLUX:
pass # the torchtune Flux model state dict is identical to the huggingface one

Not a problem for this PR, but in this big if/elif/.../elif/else I think this (no weight conversion) should actually be the else, not Llama2 (for whatever reason)

Comment on lines +14 to +15
PATCH_HEIGHT, PATCH_WIDTH = 2, 2
POSITION_DIM = 3

Noob q: these are constant in all cases?
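
(For context: Flux packs the latent image into 2x2 patches and attaches a 3-component position id to each patch and text token, which is where these constants come from. A rough sketch of the patching, with shapes assumed rather than taken from the PR:)

import torch

def patchify_with_position_ids(latents: torch.Tensor):
    # latents: (B, C, H, W), with H and W divisible by the 2x2 patch size.
    b, c, h, w = latents.shape
    patches = latents.view(b, c, h // 2, 2, w // 2, 2)
    patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)

    # One POSITION_DIM=3 id per patch: (unused text axis, row, column).
    ids = torch.zeros(h // 2, w // 2, 3)
    ids[..., 1] = torch.arange(h // 2)[:, None]
    ids[..., 2] = torch.arange(w // 2)[None, :]
    ids = ids.reshape(1, -1, 3).expand(b, -1, -1)
    return patches, ids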
