diff --git a/pyproject.toml b/pyproject.toml
index 00b9f3a..26c8bda 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.9"
+version = "0.1.10"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,7 +26,7 @@ classifiers=[
 
 dependencies = [
     "accelerated-scan>=0.2.0",
-    "axial_positional_embedding>=0.3.5",
+    "axial_positional_embedding>=0.3.6",
     "einops>=0.8.0",
     "einx>=0.3.0",
     "hyper-connections>=0.1.8",
diff --git a/titans_pytorch/mac_transformer.py b/titans_pytorch/mac_transformer.py
index 5651108..51fef1e 100644
--- a/titans_pytorch/mac_transformer.py
+++ b/titans_pytorch/mac_transformer.py
@@ -593,11 +593,9 @@ def forward(
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
 
-        pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
-
-        x = x + pos_emb[:seq_len_with_mem]
+        x = x + pos_emb
 
         # prep flex attention
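
Note (not part of the diff): the new `forward_with_seq_len` call arrives with the `axial_positional_embedding>=0.3.6` bump and is assumed to fold in the window count, flatten, and truncate that the removed lines did by hand at the call site. Below is a minimal sketch of that assumed behavior, written entirely in terms of the old call from this diff; the helper name `pos_emb_for_seq_len` is hypothetical and this is not the library's actual implementation.

```python
from math import ceil

def pos_emb_for_seq_len(axial_pos_emb, seq_len_with_mem, neural_mem_segment_len):
    # number of memory windows needed to cover the sequence (old call site logic)
    neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)

    # build a flattened (windows * segment_len, dim) positional embedding (old call site logic)
    pos_emb = axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)

    # truncate to the exact sequence length before it is added to the residual stream
    return pos_emb[:seq_len_with_mem]
```

With that logic presumably living inside the library, the call site in `forward` reduces to the single `forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))` call plus `x = x + pos_emb`, as in the hunk above.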