trust Tero Karras as well as lucidrains/meshgpt-pytorch#64 and start removing groupnorms from all repos
lucidrains committed May 3, 2024
1 parent 9c9e403 commit 5999fc1
Showing 7 changed files with 79 additions and 91 deletions.
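In short, this commit drops `nn.GroupNorm` from the convolutional building blocks and standardizes on a channel-wise `RMSNorm`, trusting the normalization observations of Tero Karras and the discussion in lucidrains/meshgpt-pytorch#64; the `groups` / `resnet_block_groups` knobs disappear along with it. A minimal sketch of the swap as it lands in the `Block` module (the `RMSNorm` definition is the one added in denoising_diffusion_pytorch.py below; the body of `forward` past the norm is paraphrased from the surrounding code, not part of this diff):

    import torch
    import torch.nn.functional as F
    from torch import nn

    class RMSNorm(nn.Module):
        # channel-wise RMSNorm over (batch, channels, height, width) feature maps
        def __init__(self, dim):
            super().__init__()
            self.scale = dim ** 0.5
            self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

        def forward(self, x):
            return F.normalize(x, dim = 1) * self.g * self.scale

    class Block(nn.Module):
        def __init__(self, dim, dim_out):
            super().__init__()
            self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
            self.norm = RMSNorm(dim_out)   # previously nn.GroupNorm(groups, dim_out)
            self.act = nn.SiLU()

        def forward(self, x, scale_shift = None):
            x = self.norm(self.proj(x))
            if scale_shift is not None:    # FiLM-style conditioning from the time / class MLP
                scale, shift = scale_shift
                x = x * (scale + 1) + shift
            return self.act(x)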
3 changes: 2 additions & 1 deletion README.md
@@ -127,13 +127,14 @@ diffusion = GaussianDiffusion1D(
)

training_seq = torch.rand(64, 32, 128) # features are normalized from 0 to 1
- dataset = Dataset1D(training_seq) # this is just an example, but you can formulate your own Dataset and pass it into the `Trainer1D` below

loss = diffusion(training_seq)
loss.backward()

# Or using trainer

+ dataset = Dataset1D(training_seq) # this is just an example, but you can formulate your own Dataset and pass it into the `Trainer1D` below
+
trainer = Trainer1D(
diffusion,
dataset = dataset,
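The README edit simply moves the `Dataset1D` construction next to the `Trainer1D` call that actually needs it: computing the loss directly consumes the raw tensor, while the trainer path wraps it in a dataset. A sketch of the two paths, assuming the `Unet1D` / `GaussianDiffusion1D` constructor arguments shown elsewhere in the README (the values here are illustrative, not taken from this hunk):

    import torch
    from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D, Dataset1D, Trainer1D

    model = Unet1D(dim = 64, dim_mults = (1, 2, 4, 8), channels = 32)
    diffusion = GaussianDiffusion1D(model, seq_length = 128, timesteps = 1000)

    training_seq = torch.rand(64, 32, 128)  # features are normalized from 0 to 1

    # path 1: call the diffusion wrapper on the raw tensor -- no Dataset needed
    loss = diffusion(training_seq)
    loss.backward()

    # path 2: wrap the tensor in a Dataset and let Trainer1D drive the training loop
    dataset = Dataset1D(training_seq)
    trainer = Trainer1D(diffusion, dataset = dataset, train_batch_size = 32, train_lr = 8e-5)
    trainer.train()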
27 changes: 12 additions & 15 deletions denoising_diffusion_pytorch/classifier_free_guidance.py
@@ -148,10 +148,10 @@ def forward(self, x):
# building block modules

class Block(nn.Module):
- def __init__(self, dim, dim_out, groups = 8):
+ def __init__(self, dim, dim_out):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
- self.norm = nn.GroupNorm(groups, dim_out)
+ self.norm = RMSNorm(dim_out)
self.act = nn.SiLU()

def forward(self, x, scale_shift = None):
@@ -166,15 +166,15 @@ def forward(self, x, scale_shift = None):
return x

class ResnetBlock(nn.Module):
- def __init__(self, dim, dim_out, *, time_emb_dim = None, classes_emb_dim = None, groups = 8):
+ def __init__(self, dim, dim_out, *, time_emb_dim = None, classes_emb_dim = None):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(int(time_emb_dim) + int(classes_emb_dim), dim_out * 2)
) if exists(time_emb_dim) or exists(classes_emb_dim) else None

- self.block1 = Block(dim, dim_out, groups = groups)
- self.block2 = Block(dim_out, dim_out, groups = groups)
+ self.block1 = Block(dim, dim_out)
+ self.block2 = Block(dim_out, dim_out)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

def forward(self, x, time_emb = None, class_emb = None):
@@ -258,7 +258,6 @@ def __init__(
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
- resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
random_fourier_features = False,
@@ -283,8 +282,6 @@ def __init__(
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

- block_klass = partial(ResnetBlock, groups = resnet_block_groups)

# time embeddings

time_dim = dim * 4
@@ -328,31 +325,31 @@ def __init__(
is_last = ind >= (num_resolutions - 1)

self.downs.append(nn.ModuleList([
- block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
- block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
+ ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
+ ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))

mid_dim = dims[-1]
- self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads)))
- self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)

for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)

self.ups.append(nn.ModuleList([
- block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
- block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
+ ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
+ ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))

default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

- self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
+ self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

def forward_with_cond_scale(
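For users of this module, the visible API change is that `Unet.__init__` no longer accepts `resnet_block_groups`: `ResnetBlock` is now constructed directly instead of through a `partial` binding the group count. A hedged sketch of the caller-side difference (the other constructor arguments are illustrative assumptions, not taken from this diff):

    from denoising_diffusion_pytorch.classifier_free_guidance import Unet

    # before this commit, the group count could be tuned:
    # model = Unet(dim = 64, num_classes = 10, dim_mults = (1, 2, 4, 8), resnet_block_groups = 8)

    # after this commit the keyword is gone -- passing it raises a TypeError
    model = Unet(dim = 64, num_classes = 10, dim_mults = (1, 2, 4, 8))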
61 changes: 30 additions & 31 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -8,8 +8,9 @@

import torch
from torch import nn, einsum
- from torch.cuda.amp import autocast
import torch.nn.functional as F
+ from torch.nn import Module, ModuleList
+ from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
@@ -98,17 +99,18 @@ def Downsample(dim, dim_out = None):
nn.Conv2d(dim * 4, default(dim_out, dim), 1)
)

- class RMSNorm(nn.Module):
+ class RMSNorm(Module):
def __init__(self, dim):
super().__init__()
+ self.scale = dim ** 0.5
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

def forward(self, x):
- return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
+ return F.normalize(x, dim = 1) * self.g * self.scale

# sinusoidal positional embeds

- class SinusoidalPosEmb(nn.Module):
+ class SinusoidalPosEmb(Module):
def __init__(self, dim, theta = 10000):
super().__init__()
self.dim = dim
@@ -123,7 +125,7 @@ def forward(self, x):
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb

- class RandomOrLearnedSinusoidalPosEmb(nn.Module):
+ class RandomOrLearnedSinusoidalPosEmb(Module):
""" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """

@@ -142,11 +144,11 @@ def forward(self, x):

# building block modules

- class Block(nn.Module):
- def __init__(self, dim, dim_out, groups = 8):
+ class Block(Module):
+ def __init__(self, dim, dim_out):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
- self.norm = nn.GroupNorm(groups, dim_out)
+ self.norm = RMSNorm(dim_out)
self.act = nn.SiLU()

def forward(self, x, scale_shift = None):
@@ -160,16 +162,16 @@ def forward(self, x, scale_shift = None):
x = self.act(x)
return x

- class ResnetBlock(nn.Module):
- def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
+ class ResnetBlock(Module):
+ def __init__(self, dim, dim_out, *, time_emb_dim = None):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None

- self.block1 = Block(dim, dim_out, groups = groups)
- self.block2 = Block(dim_out, dim_out, groups = groups)
+ self.block1 = Block(dim, dim_out)
+ self.block2 = Block(dim_out, dim_out)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

def forward(self, x, time_emb = None):
@@ -186,7 +188,7 @@ def forward(self, x, time_emb = None):

return h + self.res_conv(x)

- class LinearAttention(nn.Module):
+ class LinearAttention(Module):
def __init__(
self,
dim,
@@ -231,7 +233,7 @@ def forward(self, x):
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
return self.to_out(out)

- class Attention(nn.Module):
+ class Attention(Module):
def __init__(
self,
dim,
@@ -269,7 +271,7 @@ def forward(self, x):

# model

- class Unet(nn.Module):
+ class Unet(Module):
def __init__(
self,
dim,
@@ -278,7 +280,6 @@ def __init__(
dim_mults = (1, 2, 4, 8),
channels = 3,
self_condition = False,
- resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
random_fourier_features = False,
@@ -303,8 +304,6 @@ def __init__(
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

- block_klass = partial(ResnetBlock, groups = resnet_block_groups)

# time embeddings

time_dim = dim * 4
@@ -341,43 +340,43 @@ def __init__(

# layers

- self.downs = nn.ModuleList([])
- self.ups = nn.ModuleList([])
+ self.downs = ModuleList([])
+ self.ups = ModuleList([])
num_resolutions = len(in_out)

for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
is_last = ind >= (num_resolutions - 1)

attn_klass = FullAttention if layer_full_attn else LinearAttention

- self.downs.append(nn.ModuleList([
- block_klass(dim_in, dim_in, time_emb_dim = time_dim),
- block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+ self.downs.append(ModuleList([
+ ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
+ ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))

mid_dim = dims[-1]
- self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
- self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
is_last = ind == (len(in_out) - 1)

attn_klass = FullAttention if layer_full_attn else LinearAttention

- self.ups.append(nn.ModuleList([
- block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
- block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+ self.ups.append(ModuleList([
+ ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+ ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))

default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

- self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
+ self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

@property
@@ -470,7 +469,7 @@ def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)

- class GaussianDiffusion(nn.Module):
+ class GaussianDiffusion(Module):
def __init__(
self,
model,
@@ -856,7 +855,7 @@ def __getitem__(self, index):

# trainer class

- class Trainer(object):
+ class Trainer:
def __init__(
self,
diffusion_model,
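Alongside the GroupNorm removal, `RMSNorm` itself is touched: the `sqrt(dim)` factor is now precomputed in `__init__` as `self.scale` rather than recomputed from `x.shape[1]` on every forward pass. The two formulations are numerically identical because the module is always constructed with `dim` equal to the incoming channel count; a quick equivalence check under that assumption:

    import torch
    import torch.nn.functional as F

    dim = 64
    g = torch.ones(1, dim, 1, 1)              # the learned gain at initialization
    x = torch.randn(2, dim, 8, 8)

    old = F.normalize(x, dim = 1) * g * (x.shape[1] ** 0.5)  # previous: scale from the runtime shape
    new = F.normalize(x, dim = 1) * g * (dim ** 0.5)         # now: scale precomputed from dim
    assert torch.allclose(old, new)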
27 changes: 12 additions & 15 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py
@@ -155,10 +155,10 @@ def forward(self, x):
# building block modules

class Block(nn.Module):
- def __init__(self, dim, dim_out, groups = 8):
+ def __init__(self, dim, dim_out):
super().__init__()
self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1)
- self.norm = nn.GroupNorm(groups, dim_out)
+ self.norm = RMSNorm(dim_out)
self.act = nn.SiLU()

def forward(self, x, scale_shift = None):
@@ -173,15 +173,15 @@ def forward(self, x, scale_shift = None):
return x

class ResnetBlock(nn.Module):
- def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
+ def __init__(self, dim, dim_out, *, time_emb_dim = None):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None

- self.block1 = Block(dim, dim_out, groups = groups)
- self.block2 = Block(dim_out, dim_out, groups = groups)
+ self.block1 = Block(dim, dim_out)
+ self.block2 = Block(dim_out, dim_out)
self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

def forward(self, x, time_emb = None):
@@ -262,7 +262,6 @@ def __init__(
dim_mults=(1, 2, 4, 8),
channels = 3,
self_condition = False,
- resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
random_fourier_features = False,
@@ -285,8 +284,6 @@ def __init__(
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

- block_klass = partial(ResnetBlock, groups = resnet_block_groups)

# time embeddings

time_dim = dim * 4
@@ -317,31 +314,31 @@ def __init__(
is_last = ind >= (num_resolutions - 1)

self.downs.append(nn.ModuleList([
- block_klass(dim_in, dim_in, time_emb_dim = time_dim),
- block_klass(dim_in, dim_in, time_emb_dim = time_dim),
+ ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
+ ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
]))

mid_dim = dims[-1]
- self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+ self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads)))
- self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
+ self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)

self.ups.append(nn.ModuleList([
- block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
- block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+ ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+ ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1)
]))

default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

- self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
+ self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv1d(dim, self.out_dim, 1)

def forward(self, x, time, x_self_cond = None):
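denoising_diffusion_pytorch_1d.py gets the same treatment, with `nn.Conv1d` in place of `nn.Conv2d`. The `RMSNorm` this file now references is not visible in the hunks above; a plausible sketch, assuming it mirrors the 2d version with a gain broadcast over (batch, channels, length) tensors:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class RMSNorm(nn.Module):
        # assumed 1d counterpart of the 2d RMSNorm above; the (1, dim, 1) gain shape is an assumption
        def __init__(self, dim):
            super().__init__()
            self.scale = dim ** 0.5
            self.g = nn.Parameter(torch.ones(1, dim, 1))

        def forward(self, x):
            return F.normalize(x, dim = 1) * self.g * self.scale

    norm = RMSNorm(32)
    out = norm(torch.randn(4, 32, 128))   # (batch, channels, length)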
