move assoc scan related stuff outside of store_memories, so it is easier to read

lucidrains committed Jan 21, 2025
1 parent e6e86fd commit 80f0a66
Showing 1 changed file with 45 additions and 32 deletions.
titans_pytorch/titans.py
@@ -301,6 +301,45 @@ def forward(self, x):

        return out

+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
# main neural memory

def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
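
AssocScan evaluates the first-order linear recurrence out_t = gates_t * out_{t-1} + inputs_t along the sequence dimension; associative_scan with binary_operator computes the same thing in parallel. Below is a minimal sequential reference of that recurrence, assuming (batch, seq, dim) tensors — naive_assoc_scan is an illustrative name, not from the repo:

import torch

def naive_assoc_scan(gates, inputs):
    # out_t = gates_t * out_{t-1} + inputs_t, computed one step at a time
    out = torch.zeros_like(inputs[:, 0])
    outputs = []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outputs.append(out)
    return torch.stack(outputs, dim = 1)

The accelerated kernels operate on power-of-two sequence lengths of at least 32, hence the 2 ** max(5, ...) expression: accelerate_scan_fn pads, say, a length-100 sequence out to 128, runs the scan, then slices the output back to 100.
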
@@ -339,6 +378,10 @@ def __init__(

        self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)

+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
        # norms

        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
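
With the flag now living on one module, both code paths of the wrapper should agree with the sequential reference above. A quick check on the default path (a sketch — the accelerated path additionally requires a CUDA device):

gates = torch.rand(2, 8, 16)
inputs = torch.randn(2, 8, 16)

assoc_scan = AssocScan(use_accelerated = False)
out = assoc_scan(gates, inputs)

assert torch.allclose(out, naive_assoc_scan(gates, inputs), atol = 1e-5)
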
@@ -566,36 +609,6 @@ def store_memories(

-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan

        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

        next_momentum = TensorDict() if has_momentum else None
@@ -610,12 +623,12 @@ def accelerate_scan_fn(gates, inputs):
            # derive momentum with associative scan - eq (10)

            if has_momentum:
-                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
                momentum = update

            # use associative scan again for learned forgetting (weight decay) - eq (13)

-            update = scan_fn(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update)

            updates[param_name] = inverse_pack(update)
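
The two self.assoc_scan calls chain the same recurrence twice: eq (10) builds momentum S_t = η_t * S_{t-1} + u_t from the per-chunk surprise u, and eq (13) applies learned forgetting, M_t = (1 - α_t) * M_{t-1} + S_t. A toy illustration using the reference scan from above — shapes and tensors here are hypothetical stand-ins for adaptive_momentum, surprise, and decay_factor:

# hypothetical shapes: batch 1, 8 chunks, 16 flattened parameter entries
surprise = torch.randn(1, 8, 16)           # plays the role of -lr * grad
adaptive_momentum = torch.rand(1, 8, 1)    # eta_t, broadcast over the last dim
decay_factor = torch.rand(1, 8, 1) * 0.1   # alpha_t, kept small

momentum = naive_assoc_scan(adaptive_momentum, surprise)   # eq (10)
update = naive_assoc_scan(1. - decay_factor, momentum)     # eq (13)
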
