This repository complements the Flash STU: Fast Spectral Transform Units paper and contains an optimized, open-source PyTorch implementation of the Spectral Transform Unit (STU) as proposed in Spectral State Space Models by Agarwal et al. (2024).
The STU module is a fast and flexible building block that can be adapted into a wide range of neural network architectures, especially those that aim to solve tasks with long-range dependencies.
- ⚡️ Fast convolutions using Flash FFT
- 🚀 Fast, local attention using (sliding window) Flash Attention
- 🌐 Support for distributed training using DDP and FSDP
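For intuition on the first bullet: FFT-based convolution relies on the convolution theorem, which lets a length-L causal convolution be computed in O(L log L) time instead of O(L²). The snippet below is a minimal CPU sketch in NumPy, not the Flash FFT CUDA kernel this repository uses; the function name `fft_causal_conv` is hypothetical.

```python
import numpy as np

def fft_causal_conv(u: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causal convolution y[t] = sum_{s<=t} k[s] * u[t-s], computed via the FFT."""
    L = u.shape[-1]
    n = 2 * L  # zero-pad so the circular convolution does not wrap around
    U = np.fft.rfft(u, n=n)
    K = np.fft.rfft(k, n=n)
    # Pointwise product in frequency space equals convolution in time.
    return np.fft.irfft(U * K, n=n)[..., :L]

rng = np.random.default_rng(0)
u = rng.standard_normal(16)
k = rng.standard_normal(16)

y_fft = fft_causal_conv(u, k)
y_direct = np.convolve(u, k)[:16]  # O(L^2) reference implementation
print(np.allclose(y_fft, y_direct))
```

Padding to twice the sequence length is the key design choice here: without it, the FFT computes a circular convolution and late inputs leak into early outputs.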
Note: CUDA is required to run code from this repository.
This repository was tested with:
- Python 3.12.5
- PyTorch 2.4.1
- Triton 3.0.0
- CUDA 12.4
and may be incompatible with other versions.
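To compare your environment against the tested versions, something like the following sketch may help. It uses only the standard library; the helper name `check_versions` is hypothetical, and it covers only the two pip packages in the list (Python and CUDA versions need to be checked separately).

```python
from importlib import metadata

# Package versions this repository reports being tested with.
TESTED = {"torch": "2.4.1", "triton": "3.0.0"}

def check_versions(tested: dict) -> dict:
    """Return a status per package: 'ok', 'missing', or 'found <version>'."""
    report = {}
    for name, want in tested.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        # Local build tags such as '+cu124' are ignored for the comparison.
        report[name] = "ok" if have.split("+")[0] == want else f"found {have}"
    return report

for name, status in check_versions(TESTED).items():
    print(f"{name}: {status}")
```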
- Install PyTorch with CUDA support:
  pip install torch --index-url https://download.pytorch.org/whl/cu124
- Install required packages:
  pip install -e .
- Install Flash Attention:
  MAX_JOBS=4 pip install flash-attn --no-build-isolation
- Install Flash FFT:
  pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
  pip install git+https://github.com/HazyResearch/flash-fft-conv.git
Or from source:
pip install git+https://github.com/windsornguyen/flash-stu.git
Here is an example of how to import and use Flash STU:
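from flash_stu import FlashSTU, FlashSTUConfig, get_spectral_filters
import torch

device = torch.device('cuda')  # Flash STU requires CUDA

# Replace the placeholder below with your own configuration arguments.
config = FlashSTUConfig(
    MODIFY_YOUR_ARGS_HERE,
)

# Compute the spectral filters once and pass them to the model.
phi = get_spectral_filters(
    config.seq_len,
    config.num_eigh,
    config.use_hankel_L,
    device,
    config.torch_dtype,
)

model = FlashSTU(
    config,
    phi,
)

# x is a placeholder for your input batch; it is not defined in this snippet.
y = model(x)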
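For intuition on what the spectral filters are: in Spectral State Space Models (Agarwal et al., 2024), the filters are the top eigenvectors of a fixed Hankel matrix Z with entries Z[i, j] = 2 / ((i + j)^3 - (i + j)) for 1-indexed i, j, each scaled by its eigenvalue to the power 1/4. The NumPy sketch below follows that definition under stated assumptions. It is not this repository's `get_spectral_filters`: it ignores the `use_hankel_L` variant, device, and dtype, and the names `hankel_matrix` and `spectral_filters_sketch` are hypothetical.

```python
import numpy as np

def hankel_matrix(seq_len: int) -> np.ndarray:
    """Z[i, j] = 2 / ((i + j)^3 - (i + j)) with 1-indexed i, j."""
    idx = np.arange(1, seq_len + 1, dtype=np.float64)
    s = idx[:, None] + idx[None, :]
    return 2.0 / (s**3 - s)

def spectral_filters_sketch(seq_len: int, k: int) -> np.ndarray:
    """Top-k eigenvectors of Z, each scaled by eigenvalue ** 0.25. Shape: (seq_len, k)."""
    sigma, phi = np.linalg.eigh(hankel_matrix(seq_len))  # eigenvalues in ascending order
    sigma_k, phi_k = sigma[-k:], phi[:, -k:]
    # Guard against tiny negative eigenvalues caused by floating-point error.
    sigma_k = np.maximum(sigma_k, 0.0)
    return phi_k * sigma_k**0.25

phi_sketch = spectral_filters_sketch(seq_len=64, k=8)
print(phi_sketch.shape)
```

Because Z depends only on the sequence length, the filters are fixed rather than learned, which is why they can be computed once and passed into the model.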
An example LLM pretraining script is provided in example.py for you to test out the repository.
If your compute cluster does not have internet access, you will need to pre-download the entire dataset before running the example training script.
To download the dataset, run:
cd training
python data.py
Note: The FineWeb-Edu 10B-token sample is a relatively large dataset. It can be swapped out for something smaller, e.g. TinyStories (476.6M tokens).
To begin training, make sure you are in the training directory and run the following command in your terminal:
torchrun example.py
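torchrun starts one process per worker and tells each process who it is through the environment variables RANK, LOCAL_RANK, and WORLD_SIZE. The sketch below shows one way a script might read them, falling back to single-process values when they are unset. The helper name `read_dist_env` is hypothetical and is not taken from example.py.

```python
import os

def read_dist_env(env=os.environ) -> dict:
    """Read the process identity that torchrun exports, defaulting to a single process."""
    return {
        "rank": int(env.get("RANK", 0)),              # global process index
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # index within this node (GPU to use)
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }

info = read_dist_env()
is_main = info["rank"] == 0  # commonly the only process that logs and saves checkpoints
print(info, is_main)
```

On a single node with several GPUs, a launch such as `torchrun --standalone --nproc_per_node=4 example.py` would start four such processes; the GPU count here is only an example.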
If you are on a compute cluster that uses Slurm and environment modules, you can submit a job using the following command:
sbatch job.slurm
Model configurations can be adjusted as needed in config.json. Be sure to adjust the configurations of the Slurm job based on your cluster's constraints.
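As a rough illustration of working with a JSON configuration, the sketch below parses a config and fails early if any field read from `config` in the usage example (seq_len, num_eigh, use_hankel_L, torch_dtype) is absent. The JSON excerpt and all of its values are hypothetical; the real config.json may use different keys, and `load_config` is not part of this repository.

```python
import json

# Hypothetical excerpt; the real config.json may use different keys and values.
example_json = """
{
  "seq_len": 8192,
  "num_eigh": 24,
  "use_hankel_L": false,
  "torch_dtype": "bfloat16"
}
"""

REQUIRED = ("seq_len", "num_eigh", "use_hankel_L", "torch_dtype")

def load_config(text: str) -> dict:
    """Parse a JSON config and raise if a required field is absent."""
    cfg = json.loads(text)
    missing = [key for key in REQUIRED if key not in cfg]
    if missing:
        raise KeyError(f"config is missing fields: {missing}")
    return cfg

cfg = load_config(example_json)
print(cfg["seq_len"], cfg["num_eigh"])
```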
Note: PyTorch's torch.compile currently does not have great support for distributed wrapper modules like DDP or FSDP. If you encounter errors during training, try disabling torch.compile. For more information on torch.compile, see this informal manual.
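One simple way to make compilation easy to switch off is to gate it behind a flag. The sketch below assumes a hypothetical environment variable, USE_TORCH_COMPILE, and a hypothetical helper, `maybe_compile`; neither comes from this repository's training script.

```python
import os

def maybe_compile(model, enabled: bool):
    """Return torch.compile(model) when enabled, otherwise the model unchanged."""
    if not enabled:
        return model
    import torch  # imported lazily so the toggle itself has no hard dependency
    return torch.compile(model)

# USE_TORCH_COMPILE is a hypothetical variable for this sketch; compilation is off by default.
use_compile = os.environ.get("USE_TORCH_COMPILE", "0") == "1"

# In a training script, before wrapping the model for distributed training:
#   model = maybe_compile(model, use_compile)
```

With a gate like this, a failing run can be retried in eager mode without editing the training code.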
Contributions are welcome! Writing performant distributed code is always tricky. We welcome contributors to:
- Submit pull requests
- Report issues
- Help improve the project overall
Apache 2.0 License
You can freely use, modify, and distribute the software, even in proprietary products, as long as you:
- Include proper attribution
- Include a copy of the license
- Mention any changes made
It also provides an express grant of patent rights from contributors.
See the LICENSE file for more details.
Special thanks to (in no particular order):
- Elad Hazan and the authors of the Spectral State Space Models paper
- Isabel Liu, Yagiz Devre, Evan Dogariu
- The Flash Attention team
- The Flash FFT team
- The PyTorch team
- Princeton Research Computing and Princeton Language and Intelligence, for supplying compute
- Andrej Karpathy, for his awesome NanoGPT repository
If you use this repository, or otherwise find our work valuable, please cite Flash STU:
@article{flashstu,
  title={Flash STU: Fast Spectral Transform Units},
  author={Y. Isabel Liu and Windsor Nguyen and Yagiz Devre and Evan Dogariu and Anirudha Majumdar and Elad Hazan},
  journal={arXiv preprint arXiv:2409.10489},
  year={2024},
  url={https://arxiv.org/abs/2409.10489}
}