Skip to content

Commit

Permalink
last cleanup for the day
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 20, 2025
1 parent 5d8aa54 commit 28eddd6
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions titans_pytorch/mac_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def identity(t):
def round_up_multiple(seq, mult):
return ceil(seq / mult) * mult

def pack_with_inverse(t, pattern):
packed, packed_shape = pack(t, pattern)

def inverse(out, inv_pattern = None):
return unpack(out, packed_shape, default(inv_pattern, pattern))

return packed, inverse

def pad_at_dim(t, pad, dim = -1, value = 0.):
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = ((0, 0) * dims_from_right)
Expand Down Expand Up @@ -576,7 +584,7 @@ def forward(
x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)

mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
x, mem_ps = pack((x, mems), 'b * d')
x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')

x = inverse_segment(x)

Expand Down Expand Up @@ -634,7 +642,7 @@ def forward(

x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)

x, _ = unpack(x, mem_ps, 'b * d')
x, _ = inverse_pack_mems(x)

x = inverse_segment(x)

Expand Down

0 comments on commit 28eddd6

Please sign in to comment.