
Commit 83fa147
just give knn attention its own relative positional bias
lucidrains committed Apr 23, 2022
1 parent aaf9a0a commit 83fa147
Showing 2 changed files with 4 additions and 6 deletions.
memorizing_transformers_pytorch/memorizing_transformers_pytorch.py
@@ -191,8 +191,6 @@ def __init__(
         super().__init__()
         self.heads = heads
         self.scale = nn.Parameter(torch.ones(heads, 1, 1) * math.log(attn_scale_init))
-        self.local_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))
-        self.knn_attn_bias = nn.Parameter(torch.zeros(heads, 1, 1))

         inner_dim = heads * dim_head
         self.xl_max_memories = xl_max_memories
@@ -242,7 +240,6 @@ def forward(
         if exists(rel_pos_bias):
             sim = rel_pos_bias[..., -i:, -j:] + sim

-        sim = sim + self.local_attn_bias
         mask_value = -torch.finfo(sim.dtype).max

         causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
@@ -254,7 +251,6 @@ def forward(
             mem_k, mem_v = mem_kv.unbind(dim = -2)

             sim_mem = einsum('b h i d, b h i j d -> b h i j', q, mem_k) * scale
-            sim_mem = sim_mem + self.knn_attn_bias
             sim_mem = sim_mem.masked_fill(~mem_mask, mask_value)

             # calculate new XL memories, as well as memories to be discarded
@@ -356,6 +352,7 @@ def __init__(
         # relative positional bias

         self.rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)
+        self.knn_rel_pos_bias = T5RelativePositionBias(scale = dim_head ** 0.5, heads = heads)

         # layers

@@ -481,6 +478,7 @@ def forward(
         max_context_len = max([seq_len, *map(lambda t: (t.shape[-3] if exists(t) else 0) + seq_len, xl_memories)])

         rel_pos_bias = self.rel_pos_bias(seq_len, max_context_len, device = device)
+        knn_rel_pos_bias = self.knn_rel_pos_bias(seq_len, max_context_len, device = device)

         # keep track of new xl memories

@@ -494,7 +492,7 @@ def forward(
             is_memorizing_layer = layer_num in self.memorizing_layers
             is_xl_memory_layer = layer_num in self.xl_memory_layers

-            attn_kwargs = dict(rel_pos_bias = rel_pos_bias)
+            attn_kwargs = dict(rel_pos_bias = rel_pos_bias if not is_memorizing_layer else knn_rel_pos_bias)

             if is_memorizing_layer:
                 attn_kwargs = {**attn_kwargs, 'knn_memory': next(knn_memories_iter), 'add_knn_memory': add_knn_memory}
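For intuition, here is a minimal, self-contained sketch of the pattern this commit moves to: the local attention path and the kNN-memory attention path each get their own learned relative-position bias module, chosen per layer, instead of sharing one bias plus per-head scalar offsets. This is not the library's code; SimpleRelativePositionBias is an assumed stand-in for the repo's T5RelativePositionBias (the bucketing is deliberately simplified), and the shapes and names around it are illustrative.

import torch
from torch import nn

class SimpleRelativePositionBias(nn.Module):
    # assumed stand-in for T5RelativePositionBias: maps relative distance -> per-head bias
    def __init__(self, heads, num_buckets = 32, max_distance = 128, scale = 1.):
        super().__init__()
        self.scale = scale
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    def forward(self, i, j, device = None):
        # query positions are the last i of the j total key positions
        q_pos = torch.arange(j - i, j, device = device)
        k_pos = torch.arange(j, device = device)
        rel = k_pos[None, :] - q_pos[:, None]                       # (i, j) relative offsets
        # crude bucketing: clamp, then wrap into a fixed number of buckets
        buckets = rel.clamp(-self.max_distance, self.max_distance) % self.num_buckets
        bias = self.relative_attention_bias(buckets)                # (i, j, heads)
        return bias.permute(2, 0, 1) * self.scale                   # (heads, i, j)

heads, dim_head, seq_len = 8, 64, 16

# one bias for local attention, a separate one for kNN memory attention (the point of this commit)
rel_pos_bias     = SimpleRelativePositionBias(heads, scale = dim_head ** 0.5)
knn_rel_pos_bias = SimpleRelativePositionBias(heads, scale = dim_head ** 0.5)

sim = torch.randn(1, heads, seq_len, seq_len)                       # attention logits (b h i j)

is_memorizing_layer = True                                          # decided per layer in the real model
bias_fn = knn_rel_pos_bias if is_memorizing_layer else rel_pos_bias
sim = sim + bias_fn(seq_len, seq_len, device = sim.device)
print(sim.shape)                                                    # torch.Size([1, 8, 16, 16])

A per-head scalar bias has no positional dependence, so giving the memory branch a full relative-position bias of its own lets it weight retrieved keys by distance independently of the local attention pattern.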
setup.py (1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
     name = 'memorizing-transformers-pytorch',
     packages = find_packages(exclude=[]),
-    version = '0.3.6',
+    version = '0.3.7',
     license='MIT',
     description = 'Memorizing Transformer - Pytorch',
     author = 'Phil Wang',

0 comments on commit 83fa147
