add readme
jbloomAus committed Dec 4, 2023
1 parent 2f162f0 commit e9b8e56
Showing 1 changed file, README.md, with 87 additions and 5 deletions.
# MATS SAE Training
This is a quick-and-dirty codebase for training SAEs that we put together
so MATS 5.0 Scholars under Neel Nanda could do cool projects in their research sprint.
It prioritizes development speed over structure.
If you want something better structured (though it wasn't yet ready when we made this), check out [this repo](https://github.com/ai-safety-foundation/sparse_autoencoder).

Much of this code is also adapted from [Arthur's repo](https://github.com/ArthurConmy/sae).

## Set Up

```
conda create --name mats_sae_training python=3.11 -y
conda activate mats_sae_training
pip install -r requirements.txt
```

## Training a Sparse Autoencoder on a Language Model

(Warning: this config may go out of date!)

```python
import torch

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(

# Data Generating Function (Model + Training Distribution)
model_name = "gelu-2l",
hook_point = "blocks.0.hook_mlp_out",
hook_point_layer = 0,
d_in = 512,
dataset_path = "NeelNanda/c4-tokenized-2b",
is_dataset_tokenized=True,

# SAE Parameters
expansion_factor = 64, # determines the dimension of the SAE.

# Training Parameters
lr = 1e-4,
l1_coefficient = 3e-4,
train_batch_size = 4096,
context_size = 128,

# Activation Store Parameters
n_batches_in_buffer = 24,
total_training_tokens = 500_000 * 100, # 15 minutes on an A100
store_batch_size = 32,

# Resampling protocol
feature_sampling_method = 'l2',
feature_sampling_window = 1000, # would fire ~5 times on 500 million tokens
feature_reinit_scale = 0.2,
dead_feature_threshold = 1e-8,

# WANDB
log_to_wandb = True,
wandb_project = "mats_sae_training_language_models",
wandb_entity = None,

# Misc
device = "cuda",
seed = 42,
n_checkpoints = 5,
checkpoint_path = "checkpoints",
dtype = torch.float32,
)

sparse_autoencoder = language_model_sae_runner(cfg)

```
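As a rough sanity check, the width of the SAE implied by the config above is just `expansion_factor * d_in`; the snippet below is illustrative arithmetic, not a library call:

```python
# Illustrative arithmetic only: the SAE width implied by the config above.
d_in = 512
expansion_factor = 64
d_sae = expansion_factor * d_in
print(d_sae)  # 32768 learned features
```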


## Loading a Pretrained Sparse Autoencoder

```python
from sae_training.utils import LMSparseAutoencoderSessionloader  # import path assumed from the repo layout

path = "path/to/sparse_autoencoder.pt"
model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)
```
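Once a session is loaded, a minimal usage sketch is shown below. It assumes `model` is a TransformerLens `HookedTransformer` and that the returned `sparse_autoencoder` can be called directly on activations from the hook point it was trained on; the exact forward-pass signature may differ between versions of this repo, so treat this as a sketch rather than the canonical API.

```python
import torch

# Minimal sketch under the assumptions stated above.
prompt = "The quick brown fox jumps over the lazy dog"
tokens = model.to_tokens(prompt)

with torch.no_grad():
    # Cache activations at the hook point the SAE was trained on.
    _, cache = model.run_with_cache(tokens)
    mlp_out = cache["blocks.0.hook_mlp_out"]

    # Run the SAE on those activations. Inspect the return value before
    # relying on it, since its structure may vary across repo versions.
    sae_output = sparse_autoencoder(mlp_out)
```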



## Citations and References:

Research:
- [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features)
- [Sparse Autoencoders Find Highly Interpretable Features in Language Models](https://arxiv.org/abs/2309.08600)



Reference Implementations:
- [Neel Nanda](https://github.com/neelnanda-io/1L-Sparse-Autoencoder)
- [AI-Safety-Foundation](https://github.com/ai-safety-foundation/sparse_autoencoder)
- [Arthur Conmy](https://github.com/ArthurConmy/sae)
- [Callum McDougall](https://github.com/callummcdougall/sae-exercises-mats/tree/main)
