throw out yet another memory model, gated residual mlp variant
lucidrains committed Jan 20, 2025
1 parent 6d6721a commit 0e4a1d4
Showing 3 changed files with 48 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "titans-pytorch"
version = "0.1.5"
version = "0.1.7"
description = "Titans"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
3 changes: 2 additions & 1 deletion titans_pytorch/__init__.py
@@ -2,7 +2,8 @@
NeuralMemory,
MemoryMLP,
MemoryAttention,
-FactorizedMemoryMLP
+FactorizedMemoryMLP,
+GatedResidualMemoryMLP
)

from titans_pytorch.mac_transformer import (
46 changes: 45 additions & 1 deletion titans_pytorch/titans.py
@@ -1,10 +1,11 @@
from __future__ import annotations
from typing import Callable

import math
from functools import partial

import torch
-from torch import nn, Tensor
+from torch import nn, cat, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module, Parameter, ParameterList
from torch.func import functional_call, vmap, grad
@@ -154,6 +155,49 @@ def forward(

return x

# memory mlp, but with gated residual + final projection

class GatedResidualMemoryMLP(Module):
    def __init__(
        self,
        dim,
        depth,
        expansion_factor = 2.
    ):
        super().__init__()
        dim_hidden = int(dim * expansion_factor)

        # per-layer weights: w1 (dim -> hidden), w2 (hidden -> dim), gate projection (2 * dim -> dim)

        self.weights = ParameterList([
            ParameterList([
                Parameter(torch.randn(dim, dim_hidden)),
                Parameter(torch.randn(dim_hidden, dim)),
                Parameter(torch.randn(dim * 2, dim)),
            ]) for _ in range(depth)
        ])

        self.final_proj = Parameter(torch.randn(dim, dim))

        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def forward(
        self,
        x
    ):
        for weight1, weight2, to_gates in self.weights:
            res = x

            hidden = x @ weight1
            hidden = F.silu(hidden)
            branch_out = hidden @ weight2

            # gated residual - gates computed from both branch output and residual, then lerp between the two

            gates = cat((branch_out, res), dim = -1) @ to_gates
            x = res.lerp(branch_out, gates.sigmoid())

        return x @ self.final_proj

# memory mlp with factorized weights
# so can tradeoff capacity for smaller chunk sizes

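For reference, a minimal usage sketch of the new module based on the forward pass above (the batch and sequence sizes are illustrative, not from the commit):

```python
import torch
from titans_pytorch import GatedResidualMemoryMLP

mlp = GatedResidualMemoryMLP(dim = 64, depth = 2)

x = torch.randn(2, 128, 64)   # (batch, seq, dim)
out = mlp(x)                  # gated residual layers + final projection, shape preserved

assert out.shape == x.shape
```

Like MemoryMLP, FactorizedMemoryMLP, and MemoryAttention, this variant is presumably intended to be swapped in as the inner memory model of NeuralMemory; check the NeuralMemory signature in this version for the exact argument.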
