Skip to content

Commit

Permalink
another modification to karras unet3d for @QuantPrincess medical imag…
Browse files Browse the repository at this point in the history
…ing work lucidrains#295
  • Loading branch information
lucidrains committed Feb 26, 2024
1 parent cd85c6d commit 318ac92
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 21 deletions.
121 changes: 101 additions & 20 deletions denoising_diffusion_pytorch/karras_unet_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
from math import sqrt, ceil
from functools import partial
from typing import Optional, Union, Tuple

import torch
from torch import nn, einsum
Expand Down Expand Up @@ -207,12 +208,15 @@ def __init__(
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
downsample = False
downsample = False,
downsample_config: Tuple[bool, bool, bool] = (True, True, True)
):
super().__init__()
dim_out = default(dim_out, dim)

self.downsample = downsample
self.downsample_config = downsample_config

self.downsample_conv = None

curr_dim = dim
Expand Down Expand Up @@ -259,7 +263,10 @@ def forward(
):
if self.downsample:
t, h, w = x.shape[-3:]
x = F.interpolate(x, (t // 2, h // 2, w // 2), mode = 'trilinear')
resize_factors = tuple((2 if downsample else 1) for downsample in self.downsample_config)
interpolate_shape = tuple(shape // factor for shape, factor in zip((t, h, w), resize_factors))

x = F.interpolate(x, interpolate_shape, mode = 'trilinear')
x = self.downsample_conv(x)

x = self.pixel_norm(x)
Expand Down Expand Up @@ -294,12 +301,15 @@ def __init__(
attn_dim_head = 64,
attn_res_mp_add_t = 0.3,
attn_flash = False,
upsample = False
upsample = False,
upsample_config: Tuple[bool, bool, bool] = (True, True, True)
):
super().__init__()
dim_out = default(dim_out, dim)

self.upsample = upsample
self.upsample_config = upsample_config

self.needs_skip = not upsample

self.to_emb = None
Expand Down Expand Up @@ -341,7 +351,10 @@ def forward(
):
if self.upsample:
t, h, w = x.shape[-3:]
x = F.interpolate(x, (t * 2, h * 2, w * 2), mode = 'trilinear')
resize_factors = tuple((2 if upsample else 1) for upsample in self.upsample_config)
interpolate_shape = tuple(shape * factor for shape, factor in zip((t, h, w), resize_factors))

x = F.interpolate(x, interpolate_shape, mode = 'trilinear')

res = self.res_conv(x)

Expand Down Expand Up @@ -416,12 +429,14 @@ def __init__(
self,
*,
image_size,
frames,
dim = 192,
dim_max = 768, # channels will double every downsample and cap out to this value
num_classes = None, # in paper, they do 1000 classes for a popular benchmark
channels = 4, # 4 channels in paper for some reason, must be alpha channel?
num_downsamples = 3,
num_blocks_per_stage = 4,
num_blocks_per_stage: Union[int, Tuple[int, ...]] = 4,
downsample_types: Optional[Tuple[str, ...]] = None,
attn_res = (16, 8),
fourier_dim = 16,
attn_dim_head = 64,
Expand All @@ -440,7 +455,9 @@ def __init__(
# determine dimensions

self.channels = channels
self.frames = frames
self.image_size = image_size

input_channels = channels * (2 if self_condition else 1)

# input and output blocks
Expand Down Expand Up @@ -478,6 +495,25 @@ def __init__(

self.num_downsamples = num_downsamples

# specifying downsample types (either image, frames, or both)

downsample_types = default(downsample_types, 'all')
downsample_types = cast_tuple(downsample_types, num_downsamples)

assert len(downsample_types) == num_downsamples
assert all([t in {'all', 'frame', 'image'} for t in downsample_types])

# number of blocks per downsample

num_blocks_per_stage = cast_tuple(num_blocks_per_stage, num_downsamples)

if len(num_blocks_per_stage) == num_downsamples:
first, *_ = num_blocks_per_stage
num_blocks_per_stage = (first, *num_blocks_per_stage)

assert len(num_blocks_per_stage) == (num_downsamples + 1)
assert all([num_blocks >= 1 for num_blocks in num_blocks_per_stage])

# attention

attn_res = set(cast_tuple(attn_res))
Expand All @@ -498,17 +534,18 @@ def __init__(
self.ups = ModuleList([])

curr_dim = dim
curr_res = image_size
curr_image_res = image_size
curr_frame_res = frames

self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1)

# take care of skip connection for initial input block and first three encoder blocks

prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs))

assert num_blocks_per_stage >= 1
init_num_blocks_per_stage, *rest_num_blocks_per_stage = num_blocks_per_stage

for _ in range(num_blocks_per_stage):
for _ in range(init_num_blocks_per_stage):
enc = Encoder(curr_dim, curr_dim, **block_kwargs)
dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs)

Expand All @@ -517,20 +554,53 @@ def __init__(

# stages

for _ in range(self.num_downsamples):
for _, layer_num_blocks_per_stage, layer_downsample_type in zip(range(self.num_downsamples), rest_num_blocks_per_stage, downsample_types):

dim_out = min(dim_max, curr_dim * 2)
upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs)

curr_res //= 2
has_attn = curr_res in attn_res
downsample_image = layer_downsample_type in {'all', 'image'}
downsample_frame = layer_downsample_type in {'all', 'frame'}

downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs)
assert not (downsample_image and not divisible_by(curr_image_res, 2))
assert not (downsample_frame and not divisible_by(curr_frame_res, 2))

down_and_upsample_config = (
downsample_frame,
downsample_image,
downsample_image
)

upsample = Decoder(
dim_out,
curr_dim,
has_attn = curr_image_res in attn_res,
upsample = True,
upsample_config = down_and_upsample_config,
**block_kwargs
)

if downsample_image:
curr_image_res //= 2

if downsample_frame:
curr_frame_res //= 2

has_attn = curr_image_res in attn_res

downsample = Encoder(
curr_dim,
dim_out,
downsample = True,
downsample_config = down_and_upsample_config,
has_attn = has_attn,
**block_kwargs
)

append(self.downs, downsample)
prepend(self.ups, upsample)
prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs))

for _ in range(num_blocks_per_stage):
for _ in range(layer_num_blocks_per_stage):
enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs)
dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)

Expand All @@ -541,7 +611,7 @@ def __init__(

# take care of the two middle decoders

mid_has_attn = curr_res in attn_res
mid_has_attn = curr_image_res in attn_res

self.mids = ModuleList([
Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
Expand All @@ -563,7 +633,7 @@ def forward(
):
# validate image shape

assert x.shape[1:] == (self.channels, self.image_size, self.image_size, self.image_size)
assert x.shape[1:] == (self.channels, self.frames, self.image_size, self.image_size)

# self conditioning

Expand Down Expand Up @@ -689,19 +759,30 @@ def forward(self, x):
# example

if __name__ == '__main__':

unet = KarrasUnet3D(
frames = 32,
image_size = 64,
dim = 192,
dim = 8,
dim_max = 768,
num_downsamples = 6,
num_blocks_per_stage = (4, 3, 2, 2, 2, 2),
downsample_types = (
'image',
'frame',
'image',
'frame',
'image',
'frame',
),
attn_dim_head = 8,
num_classes = 1000,
)

images = torch.randn(2, 4, 64, 64, 64)
images = torch.randn(2, 4, 32, 64, 64)

denoised_images = unet(
images,
time = torch.ones(2,),
class_labels = torch.randint(0, 1000, (2,))
)

assert denoised_images.shape == images.shape
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.10.15'
__version__ = '1.10.16'

0 comments on commit 318ac92

Please sign in to comment.