diff --git a/titans_pytorch/mac_transformer.py b/titans_pytorch/mac_transformer.py index b97d606..5651108 100644 --- a/titans_pytorch/mac_transformer.py +++ b/titans_pytorch/mac_transformer.py @@ -83,6 +83,14 @@ def identity(t): def round_up_multiple(seq, mult): return ceil(seq / mult) * mult +def pack_with_inverse(t, pattern): + packed, packed_shape = pack(t, pattern) + + def inverse(out, inv_pattern = None): + return unpack(out, packed_shape, default(inv_pattern, pattern)) + + return packed, inverse + def pad_at_dim(t, pad, dim = -1, value = 0.): dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) zeros = ((0, 0) * dims_from_right) @@ -576,7 +584,7 @@ def forward( x, inverse_segment = pad_and_segment_with_inverse(x, segment_len) mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0]) - x, mem_ps = pack((x, mems), 'b * d') + x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d') x = inverse_segment(x) @@ -634,7 +642,7 @@ def forward( x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens) - x, _ = unpack(x, mem_ps, 'b * d') + x, _ = inverse_pack_mems(x) x = inverse_segment(x)