Commit: improve readme
jbloom-md committed Feb 18, 2024
1 parent 22e415d commit f3fe937
Showing 1 changed file with 29 additions and 1 deletion: README.md

I've been committing my research code to the `Research` folder, but I'm not expecting other people to use or look at it.

## Loading Sparse Autoencoders from Huggingface

[Previously trained sparse autoencoders](https://huggingface.co/jbloom/GPT2-Small-SAEs) can be loaded from huggingface in a few lines of code. For more details and performance metrics for these sparse autoencoders, read my [blog post](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream).

```python
import torch
from sae_training.utils import LMSparseAutoencoderSessionloader
from huggingface_hub import hf_hub_download

layer = 8  # pick the layer whose residual stream SAE you want
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"

# Download the checkpoint and restore the model, SAE, and activation store together.
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path=path
)
sparse_autoencoder.eval()  # put the SAE in eval mode for inference
```
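As a quick sanity check after loading, you can run the model, grab the residual stream activations at the matching hook point, and pass them through the SAE. This is a minimal sketch rather than part of the original example: the indexing of the SAE's return tuple is an assumption, so check the `SparseAutoencoder` forward method in `sae_training` if it errors.

```python
with torch.no_grad():
    # model is a TransformerLens HookedTransformer, so we can cache activations.
    _, cache = model.run_with_cache("The quick brown fox jumps over the lazy dog")
    acts = cache[f"blocks.{layer}.hook_resid_pre"]  # [batch, seq, d_model]
    sae_out = sparse_autoencoder(acts)[0]  # assumed: reconstruction is the first return value
    print(torch.nn.functional.mse_loss(sae_out, acts).item())
```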

You can also load the log feature sparsity tensor for each SAE from huggingface.

```python
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576_log_feature_sparsity.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
log_feature_sparsity = torch.load(path, map_location=sparse_autoencoder.cfg.device)
```
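The tensor holds one value per SAE feature, and plotting a histogram is a quick way to spot dead or very rarely firing features. A short sketch using matplotlib (an extra dependency, not something the repo requires):

```python
import matplotlib.pyplot as plt

# One log10 sparsity value per feature; dead features cluster at the far left.
plt.hist(log_feature_sparsity.cpu().numpy(), bins=100)
plt.xlabel("log10 feature sparsity")
plt.ylabel("feature count")
plt.title(f"Feature sparsity, layer {layer}")
plt.show()
```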


## Training a Sparse Autoencoder on a Language Model

Sparse autoencoders can be intimidating at first, but it's fairly simple to train one once you know what each part of the config does. I've created a config class which you instantiate and pass to the runner, which will complete your training run and log its progress to wandb.
```python
cfg = LanguageModelSAERunnerConfig(
    # ... (other options omitted in this excerpt)

    # Dead Neurons and Sparsity
    use_ghost_grads=True,
    feature_sampling_window=1000,
    dead_feature_window=5000,
    dead_feature_threshold=1e-6,
    # ...
)
```
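With the config in hand, you pass it to the runner to kick off training. A minimal sketch, assuming the entry point is `language_model_sae_runner` in `sae_training.lm_runner` (check the repo for the exact import path):

```python
from sae_training.lm_runner import language_model_sae_runner

# Runs the full training loop, logging progress to wandb, and returns the trained SAE.
sparse_autoencoder = language_model_sae_runner(cfg)
```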
