Flux Model #2302

Open · wants to merge 9 commits into main
1 change: 0 additions & 1 deletion tests/torchtune/models/flux/test_flux_autoencoder.py
@@ -27,7 +27,6 @@ class TestFluxAutoencoder:
     @pytest.fixture
     def model(self):
         model = flux_1_autoencoder(
-            resolution=RESOLUTION,
             ch_in=CH_IN,
             ch_out=3,
             ch_base=32,
142 changes: 142 additions & 0 deletions tests/torchtune/models/flux/test_flux_flow_model.py
@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy

import pytest
import torch

from torchtune.models.flux._flow_model import FluxFlowModel
from torchtune.models.flux._model_builders import _replace_linear_with_lora
from torchtune.models.flux._util import predict_noise
from torchtune.training.seed import set_seed

BSZ = 32
CH = 4
RES = 8
Y_DIM = 16
TXT_DIM = 8

# model inputs/outputs are sequences of 2x2 latent patches
MODEL_CH = CH * 4


@pytest.fixture(autouse=True)
def random():
    set_seed(0)


class TestFluxFlowModel:
    @pytest.fixture
    def model(self):
        model = FluxFlowModel(
            in_channels=MODEL_CH,
            out_channels=MODEL_CH,
            vec_in_dim=Y_DIM,
            context_in_dim=TXT_DIM,
            hidden_size=16,
            mlp_ratio=2.0,
            num_heads=2,
            depth=1,
            depth_single_blocks=1,
            axes_dim=[2, 2, 4],
            theta=10_000,
            qkv_bias=True,
            use_guidance=True,
        )

        for param in model.parameters():
            param.data.uniform_(0, 0.1)

        return model

    @pytest.fixture
    def latents(self):
        return torch.randn(BSZ, CH, RES, RES)

    @pytest.fixture
    def clip_encodings(self):
        return torch.randn(BSZ, Y_DIM)

    @pytest.fixture
    def t5_encodings(self):
        return torch.randn(BSZ, 8, TXT_DIM)

    @pytest.fixture
    def timesteps(self):
        return torch.rand(BSZ)

    @pytest.fixture
    def guidance(self):
        return torch.rand(BSZ) * 3 + 1

    def test_forward(
        self, model, latents, clip_encodings, t5_encodings, timesteps, guidance
    ):
        actual = predict_noise(
            model, latents, clip_encodings, t5_encodings, timesteps, guidance
        )
        assert actual.shape == (BSZ, CH, RES, RES)

        actual = torch.mean(actual, dim=(0, 2, 3))
        print(actual)
        expected = torch.tensor([1.9532, 2.0414, 2.2768, 2.2754])
        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)

    def test_backward(
        self, model, latents, clip_encodings, t5_encodings, timesteps, guidance
    ):
        pred = predict_noise(
            model, latents, clip_encodings, t5_encodings, timesteps, guidance
        )
        loss = pred.mean()
        loss.backward()

    def test_lora(
        self, model, latents, clip_encodings, t5_encodings, timesteps, guidance
    ):
        # Setup LoRA model
        lora_model = deepcopy(model)
        _replace_linear_with_lora(
            lora_model,
            rank=2,
            alpha=2,
            dropout=0.0,
            quantize_base=False,
            use_dora=False,
        )
        lora_model.load_state_dict(model.state_dict(), strict=False)
        lora_model.requires_grad_(False)
        _lora_enable_grad(lora_model)

        # Check param counts
        total_params = sum(p.numel() for p in lora_model.parameters())
        trainable_params = sum(
            p.numel() for p in lora_model.parameters() if p.requires_grad
        )
        assert total_params == 24416
        assert trainable_params == 3280

        # Check parity with original model
        pred = predict_noise(
            model, latents, clip_encodings, t5_encodings, timesteps, guidance
        )
        lora_pred = predict_noise(
            lora_model, latents, clip_encodings, t5_encodings, timesteps, guidance
        )
        torch.testing.assert_close(pred, lora_pred, atol=1e-4, rtol=1e-4)

        # Check backward pass works
        loss = lora_pred.mean()
        loss.backward()


def _lora_enable_grad(module):
    for name, child in module.named_children():
        if name.startswith("lora_"):
            child.requires_grad_(True)
        else:
            _lora_enable_grad(child)
12 changes: 11 additions & 1 deletion torchtune/models/flux/__init__.py
@@ -3,8 +3,18 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from ._model_builders import flux_1_autoencoder
+from ._model_builders import (
+    flux_1_autoencoder,
+    flux_1_dev_flow_model,
+    flux_1_schnell_flow_model,
+    lora_flux_1_dev_flow_model,
+    lora_flux_1_schnell_flow_model,
+)
 
 __all__ = [
     "flux_1_autoencoder",
+    "flux_1_dev_flow_model",
+    "flux_1_schnell_flow_model",
+    "lora_flux_1_dev_flow_model",
+    "lora_flux_1_schnell_flow_model",
 ]
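For orientation, a minimal usage sketch of the builders this diff exports. This is not taken from the PR: it assumes the full-model builders follow torchtune's usual no-argument builder convention and that the LoRA variants accept rank/alpha keyword arguments; the authoritative signatures live in torchtune/models/flux/_model_builders.py.

# Sketch only: the call signatures below are assumptions, not confirmed by this diff.
from torchtune.models.flux import flux_1_dev_flow_model, lora_flux_1_dev_flow_model

# Full FLUX.1-dev flow model (assumed to take no required arguments).
model = flux_1_dev_flow_model()

# LoRA variant (lora_rank/lora_alpha kwargs assumed, mirroring other torchtune LoRA builders).
lora_model = lora_flux_1_dev_flow_model(lora_rank=8, lora_alpha=16)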
6 changes: 1 addition & 5 deletions torchtune/models/flux/_autoencoder.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import List, Tuple
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -19,19 +19,16 @@ class FluxAutoencoder:
     The image autoencoder for Flux diffusion models.
 
     Args:
-        img_shape (Tuple[int, int, int]): The shape of the input image (without the batch dimension).
         encoder (nn.Module): The encoder module.
         decoder (nn.Module): The decoder module.
     """
 
     def __init__(
         self,
-        img_shape: Tuple[int, int, int],
         encoder: nn.Module,
         decoder: nn.Module,
     ):
         super().__init__()
-        self._img_shape = img_shape
         self.encoder = encoder
         self.decoder = decoder
 
@@ -55,7 +52,6 @@ def encode(self, x: Tensor) -> Tensor:
         Returns:
             Tensor: latent encodings (shape = [bsz, ch_z, latent resolution, latent resolution])
         """
-        assert x.shape[1:] == self._img_shape
         return self.encoder(x)
 
     def decode(self, z: Tensor) -> Tensor: