From 80f0a665f4802087a1f6300ee9dd932cec73dc87 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 21 Jan 2025 08:17:27 -0800
Subject: [PATCH] move assoc scan related stuff outside of store_memories, so
 it is easier to read

---
 titans_pytorch/titans.py | 77 +++++++++++++++++++++++-----------------
 1 file changed, 45 insertions(+), 32 deletions(-)

diff --git a/titans_pytorch/titans.py b/titans_pytorch/titans.py
index 17f5e32..9cf76cc 100644
--- a/titans_pytorch/titans.py
+++ b/titans_pytorch/titans.py
@@ -301,6 +301,45 @@ def forward(self, x):
 
         return out
 
+# associative scan wrapper
+
+class AssocScan(Module):
+    def __init__(
+        self,
+        use_accelerated = False
+    ):
+        super().__init__()
+        self.use_accelerated = use_accelerated
+
+    def forward(self, gates, inputs):
+
+        if not self.use_accelerated:
+            _, outputs = associative_scan(binary_operator, (gates, inputs))
+            return outputs
+
+        from accelerated_scan.triton import scan as triton_scan
+        from accelerated_scan.warp import scan as warp_scan
+
+        scan = triton_scan if gates.is_cuda else warp_scan
+
+        def accelerate_scan_fn(gates, inputs):
+            gates = gates.expand_as(inputs)
+            gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+            seq_len = gates.shape[-1]
+            next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+            gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+            inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+            outputs = scan(gates.contiguous(), inputs.contiguous())
+
+            outputs = outputs[..., :seq_len]
+            outputs = rearrange(outputs, 'b d n -> b n d')
+            return outputs
+
+        return accelerate_scan_fn(gates, inputs)
+
 # main neural memory
 
 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
@@ -339,6 +378,10 @@ def __init__(
 
         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)
 
+        # associative scan
+
+        self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
+
         # norms
 
         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -566,36 +609,6 @@ def store_memories(
 
         # determine scan function
 
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates.contiguous(), inputs.contiguous())
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
         next_momentum = TensorDict() if has_momentum else None
@@ -610,12 +623,12 @@ def accelerate_scan_fn(gates, inputs):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
-                update = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+                update = self.assoc_scan(adaptive_momentum, surprise) # momentum is S / surprise in the paper
                 momentum = update
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
-            update = scan_fn(1. - decay_factor, update)
+            update = self.assoc_scan(1. - decay_factor, update)
 
             updates[param_name] = inverse_pack(update)
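
Note (not part of the patch): both code paths inside the new AssocScan wrapper, the default associative_scan(binary_operator, ...) path and the accelerated_scan triton/warp kernels, evaluate the same first-order linear recurrence out[t] = gates[t] * out[t-1] + inputs[t] over the sequence dimension; the momentum scan of eq (10) uses gates = adaptive_momentum and the forgetting/weight-decay scan of eq (13) uses gates = 1. - decay_factor. The sketch below is a hypothetical sequential reference for that recurrence (naive_linear_scan does not exist in the repo) that one could use to sanity-check either path.

import torch

def naive_linear_scan(gates, inputs):
    # gates, inputs: (batch, seq_len, dim); gates may broadcast over the last dim
    out = []
    prev = torch.zeros_like(inputs[:, 0])

    for t in range(inputs.shape[1]):
        # out[t] = gates[t] * out[t-1] + inputs[t]
        prev = gates[:, t] * prev + inputs[:, t]
        out.append(prev)

    return torch.stack(out, dim = 1)

# hypothetical usage against the module introduced in this patch:
# scan = AssocScan(use_accelerated = False)
# assert torch.allclose(scan(gates, inputs), naive_linear_scan(gates, inputs), atol = 1e-5)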