This repository contains the implementation of the Titans architecture, a next-generation framework for scalable sequence modeling introduced in the paper "Titans: Learning to Memorize at Test Time". Titans redefine memory management in deep learning, seamlessly integrating short-term and long-term memory modules to handle large context windows efficiently and effectively.
- Memory as Context (MAC): Combines input sequences with long-term and persistent memory, using attention mechanisms to dynamically decide the relevance of historical data.
- Memory as Gate (MAG): Employs sliding-window attention for short-term memory and a gating mechanism to blend in long-term context (a minimal gating sketch follows this list).
- Memory as Layer (MAL): Treats the memory module as an independent layer, compressing past and current information before it reaches the attention layers.
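As a rough sketch of the MAG-style blend (not the exact classes in `titans_memory_architectures.py`), the gate can be a learned sigmoid that mixes a sliding-window (short-term) branch with a long-term memory branch; the class name and shapes below are illustrative placeholders:

```python
import torch
import torch.nn as nn

class GatedMemoryBlend(nn.Module):
    """Minimal sketch of a MAG-style gate: blend the sliding-window (short-term)
    branch output with the long-term memory branch output via a learned sigmoid
    gate. The branch modules and this class are stand-ins, not the repo's API."""

    def __init__(self, feature_dim: int):
        super().__init__()
        self.gate = nn.Linear(2 * feature_dim, feature_dim)

    def forward(self, short_term: torch.Tensor, long_term: torch.Tensor) -> torch.Tensor:
        # g in (0, 1) decides, per feature, how much long-term context to mix in.
        g = torch.sigmoid(self.gate(torch.cat([short_term, long_term], dim=-1)))
        return g * short_term + (1.0 - g) * long_term

# Example: blend two (batch, seq_len, feature_dim) branch outputs.
blend = GatedMemoryBlend(feature_dim=128)
out = blend(torch.randn(8, 32, 128), torch.randn(8, 32, 128))
```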
- `PersistentMemory`: Provides static, task-specific knowledge (see the sketch after this list).
- `LongTermMemory`: Encodes historical patterns for effective retrieval.
- `SlidingWindowAttention`: Processes short-term memory with a focus on recent context.
- MAC/MAG/MAL Implementations: Three architectural variants tailored for different sequence modeling tasks.
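To make the persistent-memory idea concrete, here is a minimal sketch in which persistent memory is a set of learnable, input-independent tokens prepended to the sequence; the class name and shapes are illustrative and may differ from this repo's `PersistentMemory`:

```python
import torch
import torch.nn as nn

class PersistentMemorySketch(nn.Module):
    """Minimal sketch: persistent memory as learnable, input-independent tokens
    prepended to the input sequence (hypothetical shapes, not the repo's API)."""

    def __init__(self, num_tokens: int, feature_dim: int):
        super().__init__()
        self.tokens = nn.Parameter(torch.randn(num_tokens, feature_dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, feature_dim) -> (batch, num_tokens + seq_len, feature_dim)
        batch = x.size(0)
        mem = self.tokens.unsqueeze(0).expand(batch, -1, -1)
        return torch.cat([mem, x], dim=1)

pm = PersistentMemorySketch(num_tokens=4, feature_dim=128)
print(pm(torch.randn(8, 32, 128)).shape)  # torch.Size([8, 36, 128])
```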
- `titans_memory_architectures.py`: Core implementation of the Titans architecture, including the MAC, MAG, and MAL variants.
- `train.py`: Script for training the Titans model.
- `evaluate.py`: Script for evaluating the model on specific datasets.
- `datasets.py`: Preprocessing and loading scripts for the supported datasets.
```python
import torch

# Import the MAC, MAG, and MAL architectures
from titans_memory_architectures import MemoryAsContext, MemoryAsGate, MemoryAsLayer

# Initialize models
mac = MemoryAsContext(feature_dim=128, memory_size=10)
mag = MemoryAsGate(feature_dim=128)
mal = MemoryAsLayer(feature_dim=128)

# Input data: batch size 8, sequence length 32, feature dimension 128
inputs = torch.randn(8, 32, 128)

# Forward pass through each variant
output_mac = mac(inputs)
output_mag = mag(inputs)
output_mal = mal(inputs)
```
Clone this repository:

```bash
git clone https://github.com/yourusername/titans-memory.git
cd titans-memory
```

Install the required dependencies:

```bash
pip install -r requirements.txt
```
- WikiText-103: For language modeling.
- PIQA, HellaSwag: For commonsense reasoning.
- ETTh/ETTm: For time-series forecasting.
Use the `datasets.py` script to preprocess your dataset. Example:

```bash
python datasets.py --dataset wikitext --output_dir ./processed_data
```
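For orientation, the sketch below shows one common way to tokenize WikiText-103 using the Hugging Face `datasets` and `transformers` libraries; the library choice, tokenizer, and output format are assumptions and may differ from what this repo's `datasets.py` actually does (note that running it from the repo root would let the local `datasets.py` shadow the Hugging Face package):

```python
# Illustrative only: tokenize WikiText-103 and save the result to disk.
from datasets import load_dataset          # Hugging Face `datasets` (assumption)
from transformers import AutoTokenizer     # Hugging Face `transformers` (assumption)

raw = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # tokenizer choice is an assumption

# Tokenize the raw text column and drop it, keeping only token IDs and masks.
tokenized = raw.map(
    lambda batch: tokenizer(batch["text"]),
    batched=True,
    remove_columns=["text"],
)
tokenized.save_to_disk("./processed_data")
```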
Train the Titans model using `train.py`:

```bash
python train.py --model mac --dataset ./processed_data --epochs 10 --batch_size 16
```
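For reference, here is a minimal sketch of the kind of training step `train.py` might perform; the optimizer, loss, and random data are placeholders, and the assumption that each variant returns a tensor shaped like its input is not guaranteed by the repo:

```python
import torch
import torch.nn as nn
from titans_memory_architectures import MemoryAsContext

model = MemoryAsContext(feature_dim=128, memory_size=10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # placeholder objective; the real loss depends on the task

for step in range(100):
    # Stand-in batch and targets: (batch, seq_len, feature_dim).
    inputs = torch.randn(16, 32, 128)
    targets = torch.randn(16, 32, 128)

    optimizer.zero_grad()
    outputs = model(inputs)  # assumed to match the input shape
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
```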
Evaluate the model using `evaluate.py`:

```bash
python evaluate.py --model_path ./checkpoints/best_model.pt --dataset ./processed_data
```
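Loading a checkpoint for evaluation programmatically could look like the sketch below; it assumes the checkpoint is a plain `state_dict` saved with `torch.save`, which may not match how `train.py` writes `best_model.pt`:

```python
import torch
from titans_memory_architectures import MemoryAsContext

# Rebuild the model with the hyperparameters used at training time,
# then load the saved weights (state-dict layout is an assumption).
model = MemoryAsContext(feature_dim=128, memory_size=10)
model.load_state_dict(torch.load("./checkpoints/best_model.pt", map_location="cpu"))
model.eval()

with torch.no_grad():
    predictions = model(torch.randn(4, 32, 128))  # stand-in evaluation batch
```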
- Language Modeling: Achieved state-of-the-art perplexity on WikiText-103.
- Commonsense Reasoning: Outperformed GPT-4 and Llama 3.1 on PIQA and HellaSwag.
- Time-Series Forecasting: Modeled long-term dependencies effectively on the ETTh/ETTm benchmarks.
Contributions are welcome! Feel free to submit issues or pull requests.
This repository is licensed under the MIT License.