⚡️ Flash STU ⚡️

Flash STU Logo

Table of Contents

  1. Introduction
  2. Features
  3. Installation
  4. Usage
  5. Configuration
  6. Contributing
  7. License
  8. Acknowledgments
  9. Citation

Introduction

This repository complements the Flash STU: Fast Spectral Transform Units paper and contains an optimized, open-source PyTorch implementation of the Spectral Transform Unit (STU) as proposed in Spectral State Space Models by Agarwal et al. (2024).

The STU module is a fast and flexible building block that can be adapted into a wide range of neural network architectures, especially those that aim to solve tasks with long-range dependencies.

Features

  • ⚡️ Fast convolutions using Flash FFT
  • 🚀 Fast, local attention using (sliding window) Flash Attention
  • 🌐 Support for distributed training using DDP and FSDP (see the sketch below)
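
Distributed training follows the standard PyTorch recipe. Below is a minimal sketch of wrapping a Flash STU model in DDP (FSDP is analogous); it assumes a torchrun launch (which sets LOCAL_RANK), the NCCL backend, and the constructor arguments shown in the Usage section below. Treat it as illustrative rather than as part of the repository.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from flash_stu import FlashSTU, FlashSTUConfig, get_spectral_filters

# Assumes a torchrun launch, which sets LOCAL_RANK for each process.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

config = FlashSTUConfig()  # pass your model arguments here, as in the Usage section
phi = get_spectral_filters(
    config.seq_len, config.num_eigh, config.use_hankel_L, device, config.torch_dtype
)
model = FlashSTU(config, phi).to(device)

# Wrap for data-parallel training; torch.distributed.fsdp.FullyShardedDataParallel
# is the FSDP alternative.
model = DDP(model, device_ids=[local_rank])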

Installation

Note: CUDA is required to run code from this repository.

This repository was tested with:

  • Python 3.12.5
  • PyTorch 2.4.1
  • Triton 3.0.0
  • CUDA 12.4

and may be incompatible with other versions.
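
If PyTorch and Triton are already installed, you can check your environment against these versions with a quick sanity-check snippet (not part of the repository):

import torch
import triton

print("PyTorch:", torch.__version__)        # tested with 2.4.1
print("Triton:", triton.__version__)        # tested with 3.0.0
print("CUDA (build):", torch.version.cuda)  # tested with 12.4
print("CUDA available:", torch.cuda.is_available())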

  1. Install PyTorch with CUDA support:

    pip install torch --index-url https://download.pytorch.org/whl/cu124
  2. Install required packages:

    pip install -e .
  3. Install Flash Attention:

    MAX_JOBS=4 pip install flash-attn --no-build-isolation
  4. Install Flash FFT:

     pip install git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
     pip install git+https://github.com/HazyResearch/flash-fft-conv.git

Alternatively, Flash STU itself can be installed directly from source:

pip install git+https://github.com/windsornguyen/flash-stu.git
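
Once everything is installed, a short smoke test can confirm that the package and its CUDA extensions import cleanly. The module names flash_attn and flashfftconv are the usual import names for these extensions, but they are assumptions here:

import torch
from flash_stu import FlashSTU  # the package installed above
import flash_attn               # Flash Attention extension
import flashfftconv             # Flash FFT convolution extension (module name assumed)

assert torch.cuda.is_available(), "Flash STU requires a CUDA-capable GPU"
print("flash_attn version:", flash_attn.__version__)
print("All imports succeeded.")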

Usage

Using Flash STU

Here is an example of how to import and use Flash STU:

from flash_stu import FlashSTU, FlashSTUConfig, get_spectral_filters
import torch

device = torch.device('cuda') # Flash STU requires CUDA

config = FlashSTUConfig(
   MODIFY_YOUR_ARGS_HERE,
)

phi = get_spectral_filters(
   config.seq_len, 
   config.num_eigh, 
   config.use_hankel_L, 
   device, 
   config.torch_dtype
)

model = FlashSTU(
   config, 
   phi
)

# Example forward pass: x is a batch of input token IDs with shape
# (batch_size, seq_len); this assumes your config defines vocab_size.
x = torch.randint(0, config.vocab_size, (1, config.seq_len), device=device)
y = model(x)

Training

An example LLM pretraining script is provided in example.py so you can try out the repository.

If your compute cluster does not have internet access, you will need to pre-download the entire dataset before running the example training script.

To download the dataset, run:

cd training
python data.py

Note: The FineWeb-Edu 10B-token sample is a relatively large dataset. It can be swapped out for something smaller, e.g. TinyStories (476.6M tokens).
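
If you prefer to stage a corpus yourself, independent of data.py, one option is the Hugging Face datasets library. The sketch below is illustrative only, and the dataset identifiers are assumptions that may need adjusting:

from datasets import load_dataset

# FineWeb-Edu 10B-token sample (identifier assumed)
fineweb = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train")

# Smaller alternative: TinyStories (identifier assumed)
tinystories = load_dataset("roneneldan/TinyStories", split="train")

print(fineweb)
print(tinystories)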

To begin training, make sure you are in the training directory and run the following command in your terminal:

torchrun example.py

If you are in a compute cluster that uses Slurm and environment modules, you can submit a job using the following command:

sbatch job.slurm

Model configurations can be adjusted as needed in config.json. Be sure to also tailor the Slurm job settings to your cluster's constraints.

Note: PyTorch's torch.compile currently does not have great support for distributed wrapper modules like DDP or FSDP. If you encounter errors during training, try disabling torch.compile. For more information on torch.compile, see this informal manual.
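
One simple workaround is to gate compilation behind a flag so it can be switched off without touching the rest of the training loop. The snippet below is only an illustration: use_compile is a hypothetical flag, not an option exposed by example.py, and the linear layer stands in for a DDP/FSDP-wrapped Flash STU model.

import torch
import torch.nn as nn

use_compile = False  # set to False if torch.compile errors out under DDP/FSDP

model = nn.Linear(512, 512).cuda()  # stand-in for your wrapped FlashSTU model
if use_compile:
    model = torch.compile(model)

out = model(torch.randn(8, 512, device="cuda"))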

Contributing

Contributions are welcome! Writing performant distributed code is always tricky, and we encourage you to:

  • Submit pull requests
  • Report issues
  • Help improve the project overall

License

Apache 2.0 License

You can freely use, modify, and distribute the software, even in proprietary products, as long as you:

  • Include proper attribution
  • Include a copy of the license
  • Mention any changes made

It also provides an express grant of patent rights from contributors.

See the LICENSE file for more details.

Acknowledgments

Special thanks to (in no particular order):

  • Elad Hazan and the authors of the Spectral State Space Models paper
  • Isabel Liu, Yagiz Devre, Evan Dogariu
  • The Flash Attention team
  • The Flash FFT team
  • The PyTorch team
  • Princeton Research Computing and Princeton Language and Intelligence, for supplying compute
  • Andrej Karpathy, for his awesome NanoGPT repository

Citation

If you use this repository, or otherwise find our work valuable, please cite Flash STU:

@article{flashstu,
  title={Flash STU: Fast Spectral Transform Units},
  author={Y. Isabel Liu and Windsor Nguyen and Yagiz Devre and Evan Dogariu and Anirudha Majumdar and Elad Hazan},
  journal={arXiv preprint arXiv:2409.10489},
  year={2024},
  url={https://arxiv.org/abs/2409.10489}
}
