diff --git a/README.md b/README.md index 3af3d41c..13bcc59a 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,35 @@ Some other folders: I've been commiting my research code to the `Research` folder but am not expecting other people use or look at that. +## Loading Sparse Autoencoders from Huggingface + +[Previously trained sparse autoencoders](https://huggingface.co/jbloom/GPT2-Small-SAEs) can be loaded from huggingface with close to single line of code. For more details and performance metrics for these sparse autoencoder, read my [blog post](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream). + +```python +import torch +from sae_training.utils import LMSparseAutoencoderSessionloader +from huggingface_hub import hf_hub_download + +layer = 8 # pick a layer you want. +REPO_ID = "jbloom/GPT2-Small-SAEs" +FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt" +path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME) +model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained( + path = path +) +sparse_autoencoder.eval() +``` + +You can also load the feature sparsity from huggingface. + +```python +FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576_log_feature_sparsity.pt" +path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME) +log_feature_sparsity = torch.load(path, map_location=sparse_autoencoder.cfg.device) + +``` + + ## Training a Sparse Autoencoder on a Language Model Sparse Autoencoders can be intimidating at first but it's fairly simple to train one once you know what each part of the config does. I've created a config class which you instantiate and pass to the runner which will complete your training run and log it's progress to wandb. @@ -87,7 +116,6 @@ cfg = LanguageModelSAERunnerConfig( # Dead Neurons and Sparsity use_ghost_grads=True, - feature_sampling_method = None, feature_sampling_window = 1000, dead_feature_window=5000, dead_feature_threshold = 1e-6,