From 832642b9ff717a5e831235840ccef602e93af80f Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 23 Jan 2025 13:50:38 -0800 Subject: [PATCH 1/9] checkpointer --- torchtune/training/checkpointing/_checkpointer.py | 15 +++++++++++++-- torchtune/training/checkpointing/_utils.py | 4 ++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 7c8d2b0bed..f2ea92e650 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -436,8 +436,9 @@ def __init__( self._weight_map: Dict[str, str] = None # the config.json file contains model params needed for state dict conversion - self._config = json.loads( - Path.joinpath(self._checkpoint_dir, "config.json").read_text() + config_path = Path.joinpath(self._checkpoint_dir, "config.json") + self._config = ( + json.loads(config_path.read_text()) if config_path.exists() else {} ) # repo_id is necessary for when saving an adapter config, so its compatible with HF. @@ -604,6 +605,14 @@ def load_checkpoint(self) -> Dict[str, Any]: converted_state_dict[training.MODEL_KEY] = t5_encoder_hf_to_tune( merged_state_dict, ) + elif self._model_type == ModelType.FLUX_AUTOENCODER: + from torchtune.models.flux._convert_weights import flux_ae_hf_to_tune + + converted_state_dict[training.MODEL_KEY] = flux_ae_hf_to_tune( + merged_state_dict, + ) + elif self._model_type == ModelType.FLUX_FLOW: + converted_state_dict[training.MODEL_KEY] = merged_state_dict else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, @@ -709,6 +718,8 @@ def save_checkpoint( dim=self._config["hidden_size"], head_dim=self._config.get("head_dim", None), ) + elif self._model_type == ModelType.FLUX_FLOW: + pass else: state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf( state_dict[training.MODEL_KEY], diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 0366b6d2b7..c4c8628def 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -93,6 +93,8 @@ class ModelType(Enum): QWEN2 (str): Qwen2 family of models. See :func:`~torchtune.models.qwen2.qwen2` CLIP_TEXT (str): CLIP text encoder. See :func:`~torchtune.models.clip.clip_text_encoder_large` T5_ENCODER (str): T5 text encoder. See :func:`~torchtune.models.t5.t5_v1_1_xxl_encoder` + FLUX_AUTOENCODER (str): Flux autoencoder. See :func:`~torchtune.models.flux.flux_1_autoencoder` + FLUX_FLOW (str): Flux flow model. 
See :func:`~torchtune.models.flux.flux_1_dev_flow_model` Example: >>> # Usage in a checkpointer class @@ -114,6 +116,8 @@ class ModelType(Enum): QWEN2: str = "qwen2" CLIP_TEXT: str = "clip_text" T5_ENCODER: str = "t5_encoder" + FLUX_AUTOENCODER: str = "flux_autoencoder" + FLUX_FLOW: str = "flux_flow" class FormattedCheckpointFiles: From 07d519db68430086ec8587efb68a9d9e63df21bf Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 23 Jan 2025 13:52:42 -0800 Subject: [PATCH 2/9] flux flow model builders --- torchtune/models/flux/__init__.py | 12 +- torchtune/models/flux/_autoencoder.py | 6 +- torchtune/models/flux/_model_builders.py | 155 ++++++++++++++++++++++- 3 files changed, 164 insertions(+), 9 deletions(-) diff --git a/torchtune/models/flux/__init__.py b/torchtune/models/flux/__init__.py index 3d08ac24fc..8a1acae5c1 100644 --- a/torchtune/models/flux/__init__.py +++ b/torchtune/models/flux/__init__.py @@ -3,8 +3,18 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._model_builders import flux_1_autoencoder +from ._model_builders import ( + flux_1_autoencoder, + flux_1_dev_flow_model, + flux_1_schnell_flow_model, + lora_flux_1_dev_flow_model, + lora_flux_1_schnell_flow_model, +) __all__ = [ "flux_1_autoencoder", + "flux_1_dev_flow_model", + "flux_1_schnell_flow_model", + "lora_flux_1_dev_flow_model", + "lora_flux_1_schnell_flow_model", ] diff --git a/torchtune/models/flux/_autoencoder.py b/torchtune/models/flux/_autoencoder.py index 666178d1d8..0e010f9b27 100644 --- a/torchtune/models/flux/_autoencoder.py +++ b/torchtune/models/flux/_autoencoder.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Tuple +from typing import List import torch import torch.nn.functional as F @@ -19,19 +19,16 @@ class FluxAutoencoder(nn.Module): The image autoencoder for Flux diffusion models. Args: - img_shape (Tuple[int, int, int]): The shape of the input image (without the batch dimension). encoder (nn.Module): The encoder module. decoder (nn.Module): The decoder module. """ def __init__( self, - img_shape: Tuple[int, int, int], encoder: nn.Module, decoder: nn.Module, ): super().__init__() - self._img_shape = img_shape self.encoder = encoder self.decoder = decoder @@ -55,7 +52,6 @@ def encode(self, x: Tensor) -> Tensor: Returns: Tensor: latent encodings (shape = [bsz, ch_z, latent resolution, latent resolution]) """ - assert x.shape[1:] == self._img_shape return self.encoder(x) def decode(self, z: Tensor) -> Tensor: diff --git a/torchtune/models/flux/_model_builders.py b/torchtune/models/flux/_model_builders.py index 84ed08c060..b88314fc37 100644 --- a/torchtune/models/flux/_model_builders.py +++ b/torchtune/models/flux/_model_builders.py @@ -5,11 +5,14 @@ # LICENSE file in the root directory of this source tree. from typing import List +from torch import nn + from torchtune.models.flux._autoencoder import FluxAutoencoder, FluxDecoder, FluxEncoder +from torchtune.models.flux._flow_model import FluxFlowModel +from torchtune.modules.peft import DoRALinear, LoRALinear def flux_1_autoencoder( - resolution: int = 256, ch_in: int = 3, ch_out: int = 3, ch_base: int = 128, @@ -30,7 +33,6 @@ def flux_1_autoencoder( ch = number of channels (size of the channel dimension) Args: - resolution (int): The height/width of the square input image. ch_in (int): The number of channels of the input image. 
ch_out (int): The number of channels of the output image. ch_base (int): The base number of channels. @@ -67,7 +69,154 @@ def flux_1_autoencoder( ) return FluxAutoencoder( - img_shape=(ch_in, resolution, resolution), encoder=encoder, decoder=decoder, ) + + +def flux_1_dev_flow_model(): + """ + Flow-matching model for FLUX.1-dev + + Returns: + FluxFlowModel + """ + return FluxFlowModel( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + use_guidance=True, + ) + + +def flux_1_schnell_flow_model(): + """ + Flow-matching model for FLUX.1-schnell + + Returns: + FluxFlowModel + """ + return FluxFlowModel( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + use_guidance=False, + ) + + +def lora_flux_1_dev_flow_model( + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + quantize_base: bool = False, + use_dora: bool = False, +): + """ + Flow-matching model for FLUX.1-dev with linear layers replaced with LoRA modules. + + Args: + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + quantize_base (bool): Whether to quantize base model weights. Default: False + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + Default: False + + Returns: + FluxFlowModel + """ + model = flux_1_dev_flow_model() + _replace_linear_with_lora( + model, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + use_dora=use_dora, + ) + return model + + +def lora_flux_1_schnell_flow_model( + lora_rank: int = 16, + lora_alpha: float = 16.0, + lora_dropout: float = 0.0, + quantize_base: bool = False, + use_dora: bool = False, +): + """ + Flow-matching model for FLUX.1-schnell with linear layers replaced with LoRA modules. + + Args: + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + quantize_base (bool): Whether to quantize base model weights. Default: False + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). 
+ Default: False + + Returns: + FluxFlowModel + """ + model = flux_1_schnell_flow_model() + _replace_linear_with_lora( + model, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + use_dora=use_dora, + ) + return model + + +def _replace_linear_with_lora( + module: nn.Module, + rank: int, + alpha: float, + dropout: float, + quantize_base: bool, + use_dora: bool, +) -> None: + lora_cls = DoRALinear if use_dora else LoRALinear + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + new_child = lora_cls( + in_dim=child.in_features, + out_dim=child.out_features, + rank=rank, + alpha=alpha, + dropout=dropout, + use_bias=child.bias is not None, + quantize_base=quantize_base, + ) + setattr(module, name, new_child) + else: + _replace_linear_with_lora( + child, + rank=rank, + alpha=alpha, + dropout=dropout, + quantize_base=quantize_base, + use_dora=use_dora, + ) From 5f3f1f0fde494c36ecf8792bfd807bdbfad0405a Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 23 Jan 2025 14:03:15 -0800 Subject: [PATCH 3/9] flux flow model --- torchtune/models/flux/_flow_model.py | 443 +++++++++++++++++++++++++++ 1 file changed, 443 insertions(+) create mode 100644 torchtune/models/flux/_flow_model.py diff --git a/torchtune/models/flux/_flow_model.py b/torchtune/models/flux/_flow_model.py new file mode 100644 index 0000000000..b996f8359b --- /dev/null +++ b/torchtune/models/flux/_flow_model.py @@ -0,0 +1,443 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import math +from dataclasses import dataclass + +import torch +from torch import nn, Tensor + + +class FluxFlowModel(nn.Module): + """ + Flow-matching model for Flux v1 from https://github.com/black-forest-labs/flux + + Args: + in_channels (int): The number of channels of the "img" input (the latent from the autoencoder). + out_channels (int): The number of channels of the output. + vec_in_dim (int): The number of dimensions of the "y" input (the CLIP embedding). + context_in_dim (int): The number of dimensions of the "txt" input (the T5 text encoding). + hidden_size (int): The inner dimension of the model. + mlp_ratio (float): The size of the dimension relative to the hidden size of the MLPs in the attention blocks. + num_heads (int): The number of attention heads. + depth (int): The number of double stream blocks. + depth_single_blocks (int): The number of single stream blocks. + axes_dim (list[int]): The dimension of the rope embeddings. + theta (int): The theta of the rope embeddings. + qkv_bias (bool): Enable bias for the QKV projections. + use_guidance (bool): Enable guidance (True for dev models and False for schnell). + + Raises: + ValueError: If hidden_size, num_heads, axes_dim, or pe_dim is invalid. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + vec_in_dim: int, + context_in_dim: int, + hidden_size: int, + mlp_ratio: float, + num_heads: int, + depth: int, + depth_single_blocks: int, + axes_dim: list[int], + theta: int, + qkv_bias: bool, + use_guidance: bool, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + if hidden_size % num_heads != 0: + raise ValueError( + f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}" + ) + pe_dim = hidden_size // num_heads + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") + self.hidden_size = hidden_size + self.num_heads = num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if use_guidance + else None + ) + self.txt_in = nn.Linear(context_in_dim, self.hidden_size) + + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + ) + for _ in range(depth) + ] + ) + + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio) + for _ in range(depth_single_blocks) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def forward( + self, + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor | None = None, + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + # running on sequences img + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.guidance_in is not None: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = torch.cat((txt, img), 1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + return img + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + t = time_factor * t + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(t.device) + + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + if torch.is_floating_point(t): + embedding = embedding.to(t) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms).to(dtype=x_dtype) * self.scale + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + return q.to(v), k.to(v) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor, pe: Tensor) -> Tensor: + qkv = self.qkv(x) + q, k, v = unpack_qkv(qkv, self.num_heads) + q, k = self.norm(q, k, v) + x = attention(q, k, v, pe=pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: Tensor + scale: Tensor + gate: Tensor + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk( + self.multiplier, dim=-1 + ) + + return ( + ModulationOut(*out[:3]), + ModulationOut(*out[3:]) if self.is_double else None, + ) + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + 
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approximate="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor + ) -> tuple[Tensor, Tensor]: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = unpack_qkv(img_qkv, self.num_heads) + img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = unpack_qkv(txt_qkv, self.num_heads) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) + + # run actual attention + q = torch.cat((txt_q, img_q), dim=2) + k = torch.cat((txt_k, img_k), dim=2) + v = torch.cat((txt_v, img_v), dim=2) + + attn = attention(q, k, v, pe=pe) + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + return img, txt + + +class SingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers as described in + https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + qkv, mlp = torch.split( + self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) + + q, k, v = unpack_qkv(qkv, self.num_heads) + q, k = self.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe) + # compute activation in mlp stream, cat again and run second linear layer + output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + return x + mod.gate * output + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: + q, k = apply_rope(q, k, pe) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + + # B H L D -> B L (H D) + B, H, L, D = x.shape # noqa: N806 + x = x.permute(0, 2, 1, 3).reshape(B, L, H * D) + + return x + + +def rope(pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos, omega) + out = torch.stack( + [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 + ) + + # B N D 4 -> B N D 2 2 + B, N, D, _ = out.shape # noqa: N806 + out = out.reshape(B, N, D, 2, 2) + + return out.float() + + +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +def unpack_qkv(qkv, num_heads): + # B L (3 H D) -> 3 B H L D + B, L, _ = qkv.shape # noqa: N806 + q, k, v = qkv.reshape(B, L, 3, num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0) + return q, k, v From 9374a4f0e5921e40fd0bc64be172fbafddd949d8 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 23 Jan 2025 22:26:56 -0800 Subject: [PATCH 4/9] 
flow model unit test --- .../models/flux/test_flux_flow_model.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/torchtune/models/flux/test_flux_flow_model.py diff --git a/tests/torchtune/models/flux/test_flux_flow_model.py b/tests/torchtune/models/flux/test_flux_flow_model.py new file mode 100644 index 0000000000..83ade8b271 --- /dev/null +++ b/tests/torchtune/models/flux/test_flux_flow_model.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from torchtune.models.flux._flow_model import FluxFlowModel +from torchtune.models.flux._util import predict_noise +from torchtune.training.seed import set_seed + +BSZ = 32 +CH = 4 +RES = 8 +Y_DIM = 16 +TXT_DIM = 8 + +# model inputs/outputs are sequences of 2x2 latent patches +MODEL_CH = CH * 4 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestFluxFlowModel: + @pytest.fixture + def model(self): + model = FluxFlowModel( + in_channels=MODEL_CH, + out_channels=MODEL_CH, + vec_in_dim=Y_DIM, + context_in_dim=TXT_DIM, + hidden_size=16, + mlp_ratio=2.0, + num_heads=2, + depth=1, + depth_single_blocks=1, + axes_dim=[2, 2, 4], + theta=10_000, + qkv_bias=True, + use_guidance=True, + ) + + for param in model.parameters(): + param.data.uniform_(0, 0.1) + + return model + + @pytest.fixture + def latents(self): + return torch.randn(BSZ, CH, RES, RES) + + @pytest.fixture + def clip_encodings(self): + return torch.randn(BSZ, Y_DIM) + + @pytest.fixture + def t5_encodings(self): + return torch.randn(BSZ, 8, TXT_DIM) + + @pytest.fixture + def timesteps(self): + return torch.rand(BSZ) + + @pytest.fixture + def guidance(self): + return torch.rand(BSZ) * 3 + 1 + + def test_forward( + self, model, latents, clip_encodings, t5_encodings, timesteps, guidance + ): + actual = predict_noise( + model, latents, clip_encodings, t5_encodings, timesteps, guidance + ) + assert actual.shape == (BSZ, CH, RES, RES) + + actual = torch.mean(actual, dim=(0, 2, 3)) + print(actual) + expected = torch.tensor([1.9532, 2.0414, 2.2768, 2.2754]) + torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) + + def test_backward( + self, model, latents, clip_encodings, t5_encodings, timesteps, guidance + ): + pred = predict_noise( + model, latents, clip_encodings, t5_encodings, timesteps, guidance + ) + loss = pred.mean() + loss.backward() From 7e2c81b41a7f8342c382450901b1ee7e18c6db73 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 23 Jan 2025 22:58:38 -0800 Subject: [PATCH 5/9] util --- torchtune/models/flux/_util.py | 165 +++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 torchtune/models/flux/_util.py diff --git a/torchtune/models/flux/_util.py b/torchtune/models/flux/_util.py new file mode 100644 index 0000000000..40c3c05d17 --- /dev/null +++ b/torchtune/models/flux/_util.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Optional + +import torch +from torch import Tensor + +from torchtune.models.flux._flow_model import FluxFlowModel + +PATCH_HEIGHT, PATCH_WIDTH = 2, 2 +POSITION_DIM = 3 + + +def predict_noise( + model: FluxFlowModel, + latents: Tensor, + clip_encodings: Tensor, + t5_encodings: Tensor, + timesteps: Tensor, + guidance: Optional[Tensor] = None, +) -> Tensor: + """ + Use Flux's flow-matching model to predict the noise in image latents. + + Args: + model (FluxFlowModel): The Flux flow model. + latents (Tensor): Image encodings from the Flux autoencoder. + Shape: [bsz, 16, latent height, latent width] + clip_encodings (Tensor): CLIP text encodings. + Shape: [bsz, 768] + t5_encodings (Tensor): T5 text encodings. + Shape: [bsz, sequence length, 256 or 512] + timesteps (Tensor): The amount of noise (0 to 1). + Shape: [bsz] + guidance (Optional[Tensor]): The guidance value (1.5 to 4) if guidance-enabled model. + Shape: [bsz] + Default: None + + Returns: + Tensor: The noise prediction. + Shape: [bsz, 16, latent height, latent width] + """ + bsz, _, latent_height, latent_width = latents.shape + + # Create positional encodings + latent_pos_enc = create_position_encoding_for_latents( + bsz, latent_height, latent_width + ) + text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM) + + # Convert latent into a sequence of patches + latents = pack_latents(latents) + + # Predict noise + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc.to(latents), + txt=t5_encodings.to(latents), + txt_ids=text_pos_enc.to(latents), + y=clip_encodings.to(latents), + timesteps=timesteps.to(latents), + guidance=guidance.to(latents) if guidance is not None else None, + ) + + # Convert sequence of patches to latent shape + latent_noise_pred = unpack_latents(latent_noise_pred, latent_height, latent_width) + + return latent_noise_pred + + +def create_position_encoding_for_latents( + bsz: int, latent_height: int, latent_width: int +) -> Tensor: + """ + Create the packed latents' position encodings for the Flux flow model. + + Args: + bsz (int): The batch size. + latent_height (int): The height of the latent. + latent_width (int): The width of the latent. + + Returns: + Tensor: The position encodings. + Shape: [bsz, (latent_height // PATCH_HEIGHT) * (latent_width // PATCH_WIDTH), POSITION_DIM) + """ + height = latent_height // PATCH_HEIGHT + width = latent_width // PATCH_WIDTH + + position_encoding = torch.zeros(height, width, POSITION_DIM) + + row_indices = torch.arange(height) + position_encoding[:, :, 1] = row_indices.unsqueeze(1) + + col_indices = torch.arange(width) + position_encoding[:, :, 2] = col_indices.unsqueeze(0) + + # Flatten and repeat for the full batch + # [height, width, 3] -> [bsz, height * width, 3] + position_encoding = position_encoding.view(1, height * width, POSITION_DIM) + position_encoding = position_encoding.repeat(bsz, 1, 1) + + return position_encoding + + +def pack_latents(x: Tensor) -> Tensor: + """ + Rearrange latents from an image-like format into a sequence of patches. + + Equivalent to `einops.rearrange("b c (h ph) (w pw) -> b (h w) (c ph pw)")`. + + Args: + x (Tensor): The unpacked latents. + Shape: [bsz, ch, latent height, latent width] + + Returns: + Tensor: The packed latents. 
+ Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw) + """ + b, c, latent_height, latent_width = x.shape + h = latent_height // PATCH_HEIGHT + w = latent_width // PATCH_WIDTH + + # [b, c, h*ph, w*ph] -> [b, c, h, w, ph, pw] + x = x.unfold(2, PATCH_HEIGHT, PATCH_HEIGHT).unfold(3, PATCH_WIDTH, PATCH_WIDTH) + + # [b, c, h, w, ph, PW] -> [b, h, w, c, ph, PW] + x = x.permute(0, 2, 3, 1, 4, 5) + + # [b, h, w, c, ph, PW] -> [b, h*w, c*ph*PW] + return x.reshape(b, h * w, c * PATCH_HEIGHT * PATCH_WIDTH) + + +def unpack_latents(x: Tensor, latent_height: int, latent_width: int) -> Tensor: + """ + Rearrange latents from a sequence of patches into an image-like format. + + Equivalent to `einops.rearrange("b (h w) (c ph pw) -> b c (h ph) (w pw)")`. + + Args: + x (Tensor): The packed latents. + Shape: (bsz, (latent_height // ph) * (latent_width // pw), ch * ph * pw) + latent_height (int): The height of the unpacked latents. + latent_width (int): The width of the unpacked latents. + + Returns: + Tensor: The unpacked latents. + Shape: [bsz, ch, latent height, latent width] + """ + b, _, c_ph_pw = x.shape + h = latent_height // PATCH_HEIGHT + w = latent_width // PATCH_WIDTH + c = c_ph_pw // (PATCH_HEIGHT * PATCH_WIDTH) + + # [b, h*w, c*ph*pw] -> [b, h, w, c, ph, pw] + x = x.reshape(b, h, w, c, PATCH_HEIGHT, PATCH_WIDTH) + + # [b, h, w, c, ph, pw] -> [b, c, h, ph, w, pw] + x = x.permute(0, 3, 1, 4, 2, 5) + + # [b, c, h, ph, w, pw] -> [b, c, h*ph, w*pw] + return x.reshape(b, c, h * PATCH_HEIGHT, w * PATCH_WIDTH) From 923a17101cbfe96fc29d618ae17d571fb13403f8 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Sun, 26 Jan 2025 23:24:58 -0800 Subject: [PATCH 6/9] model type --- torchtune/training/checkpointing/_checkpointer.py | 2 +- torchtune/training/checkpointing/_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index f2ea92e650..2581784d01 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -611,7 +611,7 @@ def load_checkpoint(self) -> Dict[str, Any]: converted_state_dict[training.MODEL_KEY] = flux_ae_hf_to_tune( merged_state_dict, ) - elif self._model_type == ModelType.FLUX_FLOW: + elif self._model_type == ModelType.FLUX: converted_state_dict[training.MODEL_KEY] = merged_state_dict else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index c4c8628def..9129e786b1 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -94,7 +94,7 @@ class ModelType(Enum): CLIP_TEXT (str): CLIP text encoder. See :func:`~torchtune.models.clip.clip_text_encoder_large` T5_ENCODER (str): T5 text encoder. See :func:`~torchtune.models.t5.t5_v1_1_xxl_encoder` FLUX_AUTOENCODER (str): Flux autoencoder. See :func:`~torchtune.models.flux.flux_1_autoencoder` - FLUX_FLOW (str): Flux flow model. See :func:`~torchtune.models.flux.flux_1_dev_flow_model` + FLUX (str): Main Flux model. 
See :func:`~torchtune.models.flux.flux_1_dev_flow_model` Example: >>> # Usage in a checkpointer class @@ -117,7 +117,7 @@ class ModelType(Enum): CLIP_TEXT: str = "clip_text" T5_ENCODER: str = "t5_encoder" FLUX_AUTOENCODER: str = "flux_autoencoder" - FLUX_FLOW: str = "flux_flow" + FLUX: str = "flux" class FormattedCheckpointFiles: From 78d29b5b9a61b77ff84aba41e6b536f86153720f Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 27 Jan 2025 15:20:14 -0800 Subject: [PATCH 7/9] lora unit test --- .../models/flux/test_flux_flow_model.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/torchtune/models/flux/test_flux_flow_model.py b/tests/torchtune/models/flux/test_flux_flow_model.py index 83ade8b271..feaa25105d 100644 --- a/tests/torchtune/models/flux/test_flux_flow_model.py +++ b/tests/torchtune/models/flux/test_flux_flow_model.py @@ -4,10 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from copy import deepcopy + import pytest import torch from torchtune.models.flux._flow_model import FluxFlowModel +from torchtune.models.flux._model_builders import _replace_linear_with_lora from torchtune.models.flux._util import predict_noise from torchtune.training.seed import set_seed @@ -91,3 +94,49 @@ def test_backward( ) loss = pred.mean() loss.backward() + + def test_lora( + self, model, latents, clip_encodings, t5_encodings, timesteps, guidance + ): + # Setup LoRA model + lora_model = deepcopy(model) + _replace_linear_with_lora( + lora_model, + rank=2, + alpha=2, + dropout=0.0, + quantize_base=False, + use_dora=False, + ) + lora_model.load_state_dict(model.state_dict(), strict=False) + lora_model.requires_grad_(False) + _lora_enable_grad(lora_model) + + # Check param counts + total_params = sum(p.numel() for p in lora_model.parameters()) + trainable_params = sum( + p.numel() for p in lora_model.parameters() if p.requires_grad + ) + assert total_params == 24416 + assert trainable_params == 3280 + + # Check parity with original model + pred = predict_noise( + model, latents, clip_encodings, t5_encodings, timesteps, guidance + ) + lora_pred = predict_noise( + lora_model, latents, clip_encodings, t5_encodings, timesteps, guidance + ) + torch.testing.assert_close(pred, lora_pred, atol=1e-4, rtol=1e-4) + + # Check backward pass works + loss = lora_pred.mean() + loss.backward() + + +def _lora_enable_grad(module): + for name, child in module.named_children(): + if name.startswith("lora_"): + child.requires_grad_(True) + else: + _lora_enable_grad(child) From 213bde5c1e69541caab6b0ad39ad2d72f0eb7f4e Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 27 Jan 2025 15:26:51 -0800 Subject: [PATCH 8/9] added comment to checkpointer --- torchtune/training/checkpointing/_checkpointer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 2581784d01..4d1701aced 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -718,8 +718,8 @@ def save_checkpoint( dim=self._config["hidden_size"], head_dim=self._config.get("head_dim", None), ) - elif self._model_type == ModelType.FLUX_FLOW: - pass + elif self._model_type == ModelType.FLUX: + pass # the torchtune Flux model state dict is identical to the huggingface one else: state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf( 
state_dict[training.MODEL_KEY], From 9d9045108371552c4558d663a1dc5e5ceef6783b Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 28 Jan 2025 12:36:21 -0800 Subject: [PATCH 9/9] fix unit test --- tests/torchtune/models/flux/test_flux_autoencoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/torchtune/models/flux/test_flux_autoencoder.py b/tests/torchtune/models/flux/test_flux_autoencoder.py index bb385dbc94..5189566de4 100644 --- a/tests/torchtune/models/flux/test_flux_autoencoder.py +++ b/tests/torchtune/models/flux/test_flux_autoencoder.py @@ -27,7 +27,6 @@ class TestFluxAutoencoder: @pytest.fixture def model(self): model = flux_1_autoencoder( - resolution=RESOLUTION, ch_in=CH_IN, ch_out=3, ch_base=32,
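
Usage sketch (not part of the patch series above; purely illustrative). The snippet below
mirrors the toy configuration from tests/torchtune/models/flux/test_flux_flow_model.py so it
stays cheap to run; the production builders flux_1_dev_flow_model() / flux_1_schnell_flow_model()
construct the full-size 3072-dim network with the same forward interface. Shapes follow the
docstrings introduced in this series; the tensors are random placeholders rather than real
autoencoder latents or text encodings.

    import torch

    from torchtune.models.flux._flow_model import FluxFlowModel
    from torchtune.models.flux._util import predict_noise

    # Toy-sized flow model (same hyperparameters as the unit test).
    model = FluxFlowModel(
        in_channels=16,       # 4 latent channels * 2x2 patch
        out_channels=16,
        vec_in_dim=16,        # toy CLIP embedding dim (768 in the real model)
        context_in_dim=8,     # toy T5 embedding dim (4096 in the real model)
        hidden_size=16,
        mlp_ratio=2.0,
        num_heads=2,
        depth=1,
        depth_single_blocks=1,
        axes_dim=[2, 2, 4],   # must sum to hidden_size // num_heads
        theta=10_000,
        qkv_bias=True,
        use_guidance=True,    # dev-style model: forward expects a guidance tensor
    )

    bsz = 2
    latents = torch.randn(bsz, 4, 8, 8)    # stand-in for Flux autoencoder latents
    clip_encodings = torch.randn(bsz, 16)  # stand-in for pooled CLIP embeddings
    t5_encodings = torch.randn(bsz, 8, 8)  # stand-in for T5 encoding sequences
    timesteps = torch.rand(bsz)            # noise level in [0, 1]
    guidance = torch.rand(bsz) * 3 + 1     # only needed because use_guidance=True

    # predict_noise packs the latents into 2x2 patches, builds the positional ids,
    # runs the flow model, and unpacks the prediction back to the latent shape.
    noise_pred = predict_noise(
        model, latents, clip_encodings, t5_encodings, timesteps, guidance
    )
    assert noise_pred.shape == latents.shape  # [bsz, 4, 8, 8]

For LoRA fine-tuning, the same pattern as the new LoRA unit test applies: wrap the linear
layers, load the base weights non-strictly, then train only the adapter parameters. The
name-based loop below is a simplified stand-in for the test's _lora_enable_grad helper.

    from copy import deepcopy

    from torchtune.models.flux._model_builders import _replace_linear_with_lora

    lora_model = deepcopy(model)
    _replace_linear_with_lora(
        lora_model, rank=2, alpha=2.0, dropout=0.0, quantize_base=False, use_dora=False
    )
    # Restore the base weights; the freshly initialized adapter weights are left as-is.
    lora_model.load_state_dict(model.state_dict(), strict=False)
    # Freeze everything except the LoRA adapter parameters (lora_a / lora_b).
    lora_model.requires_grad_(False)
    for name, param in lora_model.named_parameters():
        if "lora_" in name:
            param.requires_grad_(True)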