diff --git a/denoising_diffusion_pytorch/karras_unet_3d.py b/denoising_diffusion_pytorch/karras_unet_3d.py index 717ea90b2..4c224d7d7 100644 --- a/denoising_diffusion_pytorch/karras_unet_3d.py +++ b/denoising_diffusion_pytorch/karras_unet_3d.py @@ -5,6 +5,7 @@ import math from math import sqrt, ceil from functools import partial +from typing import Optional, Union, Tuple import torch from torch import nn, einsum @@ -207,12 +208,15 @@ def __init__( attn_dim_head = 64, attn_res_mp_add_t = 0.3, attn_flash = False, - downsample = False + downsample = False, + downsample_config: Tuple[bool, bool, bool] = (True, True, True) ): super().__init__() dim_out = default(dim_out, dim) self.downsample = downsample + self.downsample_config = downsample_config + self.downsample_conv = None curr_dim = dim @@ -259,7 +263,10 @@ def forward( ): if self.downsample: t, h, w = x.shape[-3:] - x = F.interpolate(x, (t // 2, h // 2, w // 2), mode = 'trilinear') + resize_factors = tuple((2 if downsample else 1) for downsample in self.downsample_config) + interpolate_shape = tuple(shape // factor for shape, factor in zip((t, h, w), resize_factors)) + + x = F.interpolate(x, interpolate_shape, mode = 'trilinear') x = self.downsample_conv(x) x = self.pixel_norm(x) @@ -294,12 +301,15 @@ def __init__( attn_dim_head = 64, attn_res_mp_add_t = 0.3, attn_flash = False, - upsample = False + upsample = False, + upsample_config: Tuple[bool, bool, bool] = (True, True, True) ): super().__init__() dim_out = default(dim_out, dim) self.upsample = upsample + self.upsample_config = upsample_config + self.needs_skip = not upsample self.to_emb = None @@ -341,7 +351,10 @@ def forward( ): if self.upsample: t, h, w = x.shape[-3:] - x = F.interpolate(x, (t * 2, h * 2, w * 2), mode = 'trilinear') + resize_factors = tuple((2 if upsample else 1) for upsample in self.upsample_config) + interpolate_shape = tuple(shape * factor for shape, factor in zip((t, h, w), resize_factors)) + + x = F.interpolate(x, interpolate_shape, mode = 'trilinear') res = self.res_conv(x) @@ -416,12 +429,14 @@ def __init__( self, *, image_size, + frames, dim = 192, dim_max = 768, # channels will double every downsample and cap out to this value num_classes = None, # in paper, they do 1000 classes for a popular benchmark channels = 4, # 4 channels in paper for some reason, must be alpha channel? num_downsamples = 3, - num_blocks_per_stage = 4, + num_blocks_per_stage: Union[int, Tuple[int, ...]] = 4, + downsample_types: Optional[Tuple[str, ...]] = None, attn_res = (16, 8), fourier_dim = 16, attn_dim_head = 64, @@ -440,7 +455,9 @@ def __init__( # determine dimensions self.channels = channels + self.frames = frames self.image_size = image_size + input_channels = channels * (2 if self_condition else 1) # input and output blocks @@ -478,6 +495,25 @@ def __init__( self.num_downsamples = num_downsamples + # specifying downsample types (either image, frames, or both) + + downsample_types = default(downsample_types, 'all') + downsample_types = cast_tuple(downsample_types, num_downsamples) + + assert len(downsample_types) == num_downsamples + assert all([t in {'all', 'frame', 'image'} for t in downsample_types]) + + # number of blocks per downsample + + num_blocks_per_stage = cast_tuple(num_blocks_per_stage, num_downsamples) + + if len(num_blocks_per_stage) == num_downsamples: + first, *_ = num_blocks_per_stage + num_blocks_per_stage = (first, *num_blocks_per_stage) + + assert len(num_blocks_per_stage) == (num_downsamples + 1) + assert all([num_blocks >= 1 for num_blocks in num_blocks_per_stage]) + # attention attn_res = set(cast_tuple(attn_res)) @@ -498,7 +534,8 @@ def __init__( self.ups = ModuleList([]) curr_dim = dim - curr_res = image_size + curr_image_res = image_size + curr_frame_res = frames self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1) @@ -506,9 +543,9 @@ def __init__( prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs)) - assert num_blocks_per_stage >= 1 + init_num_blocks_per_stage, *rest_num_blocks_per_stage = num_blocks_per_stage - for _ in range(num_blocks_per_stage): + for _ in range(init_num_blocks_per_stage): enc = Encoder(curr_dim, curr_dim, **block_kwargs) dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs) @@ -517,20 +554,53 @@ def __init__( # stages - for _ in range(self.num_downsamples): + for _, layer_num_blocks_per_stage, layer_downsample_type in zip(range(self.num_downsamples), rest_num_blocks_per_stage, downsample_types): + dim_out = min(dim_max, curr_dim * 2) - upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs) - curr_res //= 2 - has_attn = curr_res in attn_res + downsample_image = layer_downsample_type in {'all', 'image'} + downsample_frame = layer_downsample_type in {'all', 'frame'} - downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs) + assert not (downsample_image and not divisible_by(curr_image_res, 2)) + assert not (downsample_frame and not divisible_by(curr_frame_res, 2)) + + down_and_upsample_config = ( + downsample_frame, + downsample_image, + downsample_image + ) + + upsample = Decoder( + dim_out, + curr_dim, + has_attn = curr_image_res in attn_res, + upsample = True, + upsample_config = down_and_upsample_config, + **block_kwargs + ) + + if downsample_image: + curr_image_res //= 2 + + if downsample_frame: + curr_frame_res //= 2 + + has_attn = curr_image_res in attn_res + + downsample = Encoder( + curr_dim, + dim_out, + downsample = True, + downsample_config = down_and_upsample_config, + has_attn = has_attn, + **block_kwargs + ) append(self.downs, downsample) prepend(self.ups, upsample) prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)) - for _ in range(num_blocks_per_stage): + for _ in range(layer_num_blocks_per_stage): enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs) dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs) @@ -541,7 +611,7 @@ def __init__( # take care of the two middle decoders - mid_has_attn = curr_res in attn_res + mid_has_attn = curr_image_res in attn_res self.mids = ModuleList([ Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs), @@ -563,7 +633,7 @@ def forward( ): # validate image shape - assert x.shape[1:] == (self.channels, self.image_size, self.image_size, self.image_size) + assert x.shape[1:] == (self.channels, self.frames, self.image_size, self.image_size) # self conditioning @@ -689,19 +759,30 @@ def forward(self, x): # example if __name__ == '__main__': + unet = KarrasUnet3D( + frames = 32, image_size = 64, - dim = 192, + dim = 8, dim_max = 768, + num_downsamples = 6, + num_blocks_per_stage = (4, 3, 2, 2, 2, 2), + downsample_types = ( + 'image', + 'frame', + 'image', + 'frame', + 'image', + 'frame', + ), + attn_dim_head = 8, num_classes = 1000, ) - images = torch.randn(2, 4, 64, 64, 64) + images = torch.randn(2, 4, 32, 64, 64) denoised_images = unet( images, time = torch.ones(2,), class_labels = torch.randint(0, 1000, (2,)) ) - - assert denoised_images.shape == images.shape diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py index edf68c3f9..0c191203f 100644 --- a/denoising_diffusion_pytorch/version.py +++ b/denoising_diffusion_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.15' +__version__ = '1.10.16'