Skip to content

Commit

Permalink
moving encoder/decoder construction to the model builder
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinpelletier committed Jan 8, 2025
1 parent 346c51e commit c78f485
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 58 deletions.
4 changes: 2 additions & 2 deletions tests/torchtune/models/flux/test_flux_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from torchtune.models.flux._autoencoder import FluxAutoencoder
from torchtune.models.flux import flux_1_autoencoder
from torchtune.training.seed import set_seed

BSZ = 32
Expand All @@ -26,7 +26,7 @@ def random():
class TestFluxAutoencoder:
@pytest.fixture
def model(self):
model = FluxAutoencoder(
model = flux_1_autoencoder(
resolution=RESOLUTION,
ch_in=CH_IN,
ch_out=3,
Expand Down
55 changes: 11 additions & 44 deletions torchtune/models/flux/_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from typing import List, Tuple

import torch
import torch.nn.functional as F
Expand All @@ -19,54 +19,21 @@ class FluxAutoencoder(nn.Module):
The image autoencoder for Flux diffusion models.
Args:
resolution (int): The height/width of the square input image.
ch_in (int): The number of channels of the input image.
ch_out (int): The number of channels of the output image.
ch_base (int): The base number of channels.
This gets multiplied by each of the `ch_mults` values to get the number of inner channels during downsampling/upsampling.
ch_mults (List[int]): The channel multiple per downsample/upsample block.
This gets multiplied by `ch_base` to get the number of inner channels during downsampling/upsampling.
ch_z (int): The number of latent channels (dimension of the latent vector `z`).
n_layers_per_resample_block (int): Number of resnet layers per downsample/upsample block.
scale_factor (float): Constant for scaling `z`.
shift_factor (float): Constant for shifting `z`.
img_shape (Tuple[int, int, int]): The shape of the input image (without the batch dimension).
encoder (nn.Module): The encoder module.
decoder (nn.Module): The decoder module.
"""

def __init__(
self,
resolution: int,
ch_in: int,
ch_out: int,
ch_base: int,
ch_mults: List[int],
ch_z: int,
n_layers_per_resample_block: int,
scale_factor: float,
shift_factor: float,
img_shape: Tuple[int, int, int],
encoder: nn.Module,
decoder: nn.Module,
):
super().__init__()
self.img_shape = (ch_in, resolution, resolution)

channels = [ch_base * mult for mult in ch_mults]

self.encoder = FluxEncoder(
ch_in=ch_in,
ch_z=ch_z,
channels=channels,
n_layers_per_down_block=n_layers_per_resample_block,
scale_factor=scale_factor,
shift_factor=shift_factor,
)

self.decoder = FluxDecoder(
ch_out=ch_out,
ch_z=ch_z,
channels=list(reversed(channels)),
# decoder gets one more layer per up block than the encoder's down blocks
n_layers_per_up_block=n_layers_per_resample_block + 1,
scale_factor=scale_factor,
shift_factor=shift_factor,
)
self._img_shape = img_shape
self.encoder = encoder
self.decoder = decoder

def forward(self, x: Tensor) -> Tensor:
"""
Expand All @@ -88,7 +55,7 @@ def encode(self, x: Tensor) -> Tensor:
Returns:
Tensor: latent encodings (shape = [bsz, ch_z, latent resolution, latent resolution])
"""
assert x.shape[1:] == self.img_shape
assert x.shape[1:] == self._img_shape
return self.encoder(x)

def decode(self, z: Tensor) -> Tensor:
Expand Down
65 changes: 53 additions & 12 deletions torchtune/models/flux/_model_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torchtune.models.flux._autoencoder import FluxAutoencoder
from typing import List

from torchtune.models.flux._autoencoder import FluxAutoencoder, FluxDecoder, FluxEncoder

def flux_1_autoencoder() -> FluxAutoencoder:

def flux_1_autoencoder(
resolution: int = 256,
ch_in: int = 3,
ch_out: int = 3,
ch_base: int = 128,
ch_mults: List[int] = [1, 2, 4, 4],
ch_z: int = 16,
n_layers_per_resample_block: int = 2,
scale_factor: float = 0.3611,
shift_factor: float = 0.1159,
) -> FluxAutoencoder:
"""
The image autoencoder for all current Flux diffusion models:
- FLUX.1-dev
Expand All @@ -15,18 +27,47 @@ def flux_1_autoencoder() -> FluxAutoencoder:
- FLUX.1-Depth-dev
- FLUX.1-Fill-dev
ch = number of channels (size of the channel dimension)
Args:
resolution (int): The height/width of the square input image.
ch_in (int): The number of channels of the input image.
ch_out (int): The number of channels of the output image.
ch_base (int): The base number of channels.
This gets multiplied by each of the `ch_mults` values to get the number of inner channels during downsampling/upsampling.
ch_mults (List[int]): The channel multiple per downsample/upsample block.
This gets multiplied by `ch_base` to get the number of inner channels during downsampling/upsampling.
ch_z (int): The number of latent channels (dimension of the latent vector `z`).
n_layers_per_resample_block (int): Number of resnet layers per downsample/upsample block.
scale_factor (float): Constant for scaling `z`.
shift_factor (float): Constant for shifting `z`.
Returns:
FluxAutoencoder
"""
# ch = number of channels (size of the channel dimension)
channels = [ch_base * mult for mult in ch_mults]

encoder = FluxEncoder(
ch_in=ch_in,
ch_z=ch_z,
channels=channels,
n_layers_per_down_block=n_layers_per_resample_block,
scale_factor=scale_factor,
shift_factor=shift_factor,
)

decoder = FluxDecoder(
ch_out=ch_out,
ch_z=ch_z,
channels=list(reversed(channels)),
# decoder gets one more layer per up block than the encoder's down blocks
n_layers_per_up_block=n_layers_per_resample_block + 1,
scale_factor=scale_factor,
shift_factor=shift_factor,
)

return FluxAutoencoder(
resolution=256,
ch_in=3,
ch_out=3,
ch_base=128,
ch_mults=[1, 2, 4, 4],
ch_z=16,
n_layers_per_resample_block=2,
scale_factor=0.3611,
shift_factor=0.1159,
img_shape=(ch_in, resolution, resolution),
encoder=encoder,
decoder=decoder,
)

0 comments on commit c78f485

Please sign in to comment.