Skip to content

Commit

Permalink
use full attention at the center of the unet
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Apr 4, 2022
1 parent eb6e1b5 commit f461559
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
25 changes: 24 additions & 1 deletion denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,29 @@ def forward(self, x):
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
return self.to_out(out)

class Attention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)

def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
q = q * self.scale

sim = einsum('b h d i, b h d j -> b h i j', q, k)
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
attn = sim.softmax(dim = -1)

out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)

# model

class Unet(nn.Module):
Expand Down Expand Up @@ -220,7 +243,7 @@ def __init__(

mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.7.1',
version = '0.8.0',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit f461559

Please sign in to comment.