From f4615599bcbc884dcb7efa0dd70840f011a0fbb9 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 4 Apr 2022 09:03:41 -0700
Subject: [PATCH] use full attention at the center of the unet

---
NOTE(review): this copy was rebuilt from a version of the patch whose line
breaks and indentation had been collapsed. Line structure is restored so that
every hunk matches its @@ line counts. Indentation widths (4 spaces in the
.py file, 2 spaces in setup.py) are inferred, not recovered -- confirm against
the repository or apply with `git apply --ignore-whitespace`. The From:
header carries no e-mail address in this copy.

 .../denoising_diffusion_pytorch.py            | 25 ++++++++++++++++++-
 setup.py                                      |  2 +-
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index 5fafe76e5..ae9b17a08 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -175,6 +175,29 @@ def forward(self, x):
         out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
         return self.to_out(out)
 
+class Attention(nn.Module):
+    def __init__(self, dim, heads = 4, dim_head = 32):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x).chunk(3, dim = 1)
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
+        q = q * self.scale
+
+        sim = einsum('b h d i, b h d j -> b h i j', q, k)
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b h i j, b h d j -> b h i d', attn, v)
+        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
+        return self.to_out(out)
+
 # model
 
 class Unet(nn.Module):
@@ -220,7 +243,7 @@ def __init__(
 
         mid_dim = dims[-1]
         self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
-        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
+        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
         self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
diff --git a/setup.py b/setup.py
index caf64391c..3db1e3852 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'denoising-diffusion-pytorch',
   packages = find_packages(),
-  version = '0.7.1',
+  version = '0.8.0',
   license='MIT',
   description = 'Denoising Diffusion Probabilistic Models - Pytorch',
   author = 'Phil Wang',