Commit

cleanup axial pos emb
lucidrains committed Jan 21, 2025
1 parent 8ba1245 commit 9bd87ac
Showing 2 changed files with 4 additions and 6 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.9"
+version = "0.1.10"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -26,7 +26,7 @@ classifiers=[

 dependencies = [
     "accelerated-scan>=0.2.0",
-    "axial_positional_embedding>=0.3.5",
+    "axial_positional_embedding>=0.3.6",
     "einops>=0.8.0",
     "einx>=0.3.0",
     "hyper-connections>=0.1.8",
6 changes: 2 additions & 4 deletions titans_pytorch/mac_transformer.py
@@ -593,11 +593,9 @@ def forward(
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network

-        neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
-
-        pos_emb = self.axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))

-        x = x + pos_emb[:seq_len_with_mem]
+        x = x + pos_emb

         # prep flex attention

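For context on the simplification: the old code computed the number of axial windows by hand, flattened the embedding, and then truncated it back to the sequence length, while forward_with_seq_len (the reason for the axial_positional_embedding>=0.3.6 bump) folds those steps into one call. Below is a minimal sketch of the before/after usage; the ContinuousAxialPositionalEmbedding class name and its constructor arguments are assumptions about the library, and only the two call forms are taken from the diff above.

    from math import ceil

    # assumed import and class name; the diff only shows how the instance is called
    from axial_positional_embedding import ContinuousAxialPositionalEmbedding

    dim = 512
    neural_mem_segment_len = 4
    seq_len_with_mem = 10  # deliberately not a multiple of the segment length

    # assumed constructor arguments
    axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

    # before: compute the window count manually, flatten, then slice back to the true length
    neural_mem_windows = ceil(seq_len_with_mem / neural_mem_segment_len)
    pos_emb_old = axial_pos_emb((neural_mem_windows, neural_mem_segment_len), flatten = True)
    pos_emb_old = pos_emb_old[:seq_len_with_mem]

    # after: the helper does the window math and the truncation internally
    pos_emb_new = axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))

    # both paths should yield a (seq_len_with_mem, dim) tensor that is added to the token embeddings
    assert pos_emb_old.shape == pos_emb_new.shape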
