Merge branch 'main' into faster-ci
chanind committed Apr 11, 2024
2 parents 9e3863c + 8784c74 commit 89e1568
Showing 25 changed files with 600 additions and 482 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -27,6 +27,7 @@ mkdocs-autorefs = "^1.0.1"
mkdocs-section-index = "^0.3.8"
mkdocstrings = "^0.24.1"
mkdocstrings-python = "^1.9.0"
safetensors = "^0.4.2"


[tool.poetry.group.dev.dependencies]
@@ -67,4 +68,4 @@ version_variables = [
"pyproject.toml:version",
]
branch = "main"
build_command = "pip install poetry && poetry build"
build_command = "pip install poetry && poetry build"
4 changes: 2 additions & 2 deletions sae_lens/__init__.py
@@ -5,7 +5,7 @@
from .training.config import CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig
from .training.evals import run_evals
from .training.lm_runner import language_model_sae_runner
from .training.sae_group import SAEGroup
from .training.sae_group import SparseAutoencoderDictionary
from .training.session_loader import LMSparseAutoencoderSessionloader
from .training.sparse_autoencoder import SparseAutoencoder
from .training.train_sae_on_language_model import train_sae_group_on_language_model
@@ -15,7 +15,7 @@
"CacheActivationsRunnerConfig",
"LMSparseAutoencoderSessionloader",
"SparseAutoencoder",
"SAEGroup",
"SparseAutoencoderDictionary",
"run_evals",
"language_model_sae_runner",
"cache_activations_runner",
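The group class is re-exported under its new name, so downstream imports need updating; a minimal sketch:

# After this commit the group of SAEs is exported as SparseAutoencoderDictionary.
from sae_lens import SparseAutoencoder, SparseAutoencoderDictionary

# The old name is gone:
# from sae_lens import SAEGroup  # would now raise ImportError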
6 changes: 3 additions & 3 deletions sae_lens/analysis/dashboard_runner.py
@@ -11,7 +11,6 @@
import plotly
import plotly.express as px
import torch
import wandb
from sae_vis.data_config_classes import (
ActsHistogramConfig,
Column,
@@ -24,6 +23,7 @@
from torch.nn.functional import cosine_similarity
from tqdm import tqdm

import wandb
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader


@@ -134,9 +134,9 @@ def init_sae_session(self):
self.model,
sae_group,
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path)
# TODO: handle multiple autoencoders
self.sparse_autoencoder = sae_group.autoencoders[0]
self.sparse_autoencoder = next(iter(sae_group))[1]

def get_tokens(
self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6
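Both the dashboard and Neuronpedia runners now call `load_pretrained_sae` and take the first autoencoder out of the group by iteration instead of indexing `.autoencoders`. A minimal sketch of that access pattern, assuming `SparseAutoencoderDictionary` yields `(name, SparseAutoencoder)` pairs when iterated, as the diff implies, and that `sae_path` below is a hypothetical local path to a pretrained SAE:

from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

sae_path = "path/to/pretrained_sae"  # hypothetical path for illustration
model, sae_group, activation_store = LMSparseAutoencoderSessionloader.load_pretrained_sae(
    sae_path
)

# Iterating the group yields (name, autoencoder) pairs; grab the first SAE.
first_name, sparse_autoencoder = next(iter(sae_group))
print(first_name, type(sparse_autoencoder))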
11 changes: 11 additions & 0 deletions sae_lens/analysis/neuronpedia_integration.py
@@ -0,0 +1,11 @@
import webbrowser


def open_neuronpedia(
feature_id: int, layer: int = 0, model: str = "gpt2-small", dataset: str = "res-jb"
):

path_to_html = f"https://www.neuronpedia.org/{model}/{layer}-{dataset}/{feature_id}"

print(f"Feature {feature_id}")
webbrowser.open_new_tab(path_to_html)
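A short usage sketch of the new helper (the feature id below is an arbitrary example): it formats the Neuronpedia URL for a model, layer, dataset, and feature, then opens it in a new browser tab.

from sae_lens.analysis.neuronpedia_integration import open_neuronpedia

# Feature 1234 is an arbitrary example id.
open_neuronpedia(feature_id=1234, layer=5, model="gpt2-small", dataset="res-jb")
# Opens https://www.neuronpedia.org/gpt2-small/5-res-jb/1234 in a new browser tab.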
4 changes: 2 additions & 2 deletions sae_lens/analysis/neuronpedia_runner.py
@@ -94,9 +94,9 @@ def init_sae_session(self):
self.model,
sae_group,
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path)
# TODO: handle multiple autoencoders
self.sparse_autoencoder = sae_group.autoencoders[0]
self.sparse_autoencoder = next(iter(sae_group))[1]

def get_tokens(
self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6
51 changes: 0 additions & 51 deletions sae_lens/analysis/toolkit.py

This file was deleted.

126 changes: 126 additions & 0 deletions sae_lens/toolkit/pretrained_saes.py
@@ -0,0 +1,126 @@
import json
import os

import torch
from huggingface_hub import hf_hub_download, list_files_info
from safetensors import safe_open
from tqdm import tqdm

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.sparse_autoencoder import SparseAutoencoder


def load_sparsity(path: str) -> torch.Tensor:
sparsity_path = os.path.join(path, "sparsity.safetensors")
with safe_open(sparsity_path, framework="pt", device="cpu") as f: # type: ignore
sparsity = f.get_tensor("sparsity")
return sparsity


def get_gpt2_res_jb_saes() -> (
tuple[dict[str, SparseAutoencoder], dict[str, torch.Tensor]]
):

GPT2_SMALL_RESIDUAL_SAES_REPO_ID = "jbloom/GPT2-Small-SAEs-Reformatted"
GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS = [
f"blocks.{layer}.hook_resid_pre" for layer in range(12)
] + ["blocks.11.hook_resid_post"]

saes = {}
sparsities = {}
for hook_point in tqdm(GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS):
# download the files required:
FILENAME = f"{hook_point}/cfg.json"
hf_hub_download(repo_id=GPT2_SMALL_RESIDUAL_SAES_REPO_ID, filename=FILENAME)

FILENAME = f"{hook_point}/sae_weights.safetensors"
path = hf_hub_download(
repo_id=GPT2_SMALL_RESIDUAL_SAES_REPO_ID, filename=FILENAME
)

FILENAME = f"{hook_point}/sparsity.safetensors"
path = hf_hub_download(
repo_id=GPT2_SMALL_RESIDUAL_SAES_REPO_ID, filename=FILENAME
)

# Then use our function to download the files
folder_path = os.path.dirname(path)
sae = SparseAutoencoder.load_from_pretrained(folder_path)
sparsity = load_sparsity(folder_path)
saes[hook_point] = sae
sparsities[hook_point] = sparsity

return saes, sparsities


def convert_connor_rob_sae_to_our_saelens_format(
state_dict: dict[str, torch.Tensor],
config: dict[str, int | str],
device: str = "cpu",
):
"""
# can get session like so.
model, ae_alt, activation_store = LMSparseAutoencoderSessionloader(
cfg
).load_sae_training_group_session()
next(iter(ae_alt))[1].load_state_dict(state_dict)
return model, ae_alt, activation_store
"""

expansion_factor = int(config["dict_size"]) // int(config["act_size"])

cfg = LanguageModelSAERunnerConfig(
model_name=config["model_name"], # type: ignore
hook_point=config["act_name"], # type: ignore
hook_point_layer=config["layer"], # type: ignore
# data
# dataset_path = "/share/data/datasets/pile/the-eye.eu/public/AI/pile/train", # Training set of The Pile
dataset_path="NeelNanda/openwebtext-tokenized-9b",
is_dataset_tokenized=True,
d_in=config["act_size"], # type: ignore
expansion_factor=expansion_factor,
context_size=config["seq_len"], # type: ignore
device=device,
store_batch_size=32,
n_batches_in_buffer=10,
prepend_bos=False,
verbose=False,
dtype=torch.float32,
)

ae_alt = SparseAutoencoder(cfg)
ae_alt.load_state_dict(state_dict)
return ae_alt


def get_gpt2_small_ckrk_attn_out_saes() -> dict[str, SparseAutoencoder]:

REPO_ID = "ckkissane/attn-saes-gpt2-small-all-layers"

# list all files in repo
saes_weights = {}
sae_configs = {}
repo_files = list_files_info(REPO_ID)
for i in tqdm(repo_files):
file_name = i.path
if file_name.endswith(".pt"):
# print(f"Downloading {file_name}")
path = hf_hub_download(REPO_ID, file_name)
name = path.split("/")[-1].split(".pt")[0]
saes_weights[name] = torch.load(path, map_location="mps")
elif file_name.endswith(".json"):
# print(f"Downloading {file_name}")
config_path = hf_hub_download(REPO_ID, file_name)
name = config_path.split("/")[-1].split("_cfg.json")[0]
sae_configs[name] = json.load(open(config_path, "r"))

saes = {}
for name, config in sae_configs.items():
print(f"Loading {name}")
saes[name] = convert_connor_rob_sae_to_our_saelens_format(
saes_weights[name], config
)

return saes
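A usage sketch for the new loaders, assuming the referenced Hugging Face repos are reachable: `get_gpt2_res_jb_saes` downloads a config, weight, and sparsity file for each of the 13 GPT-2 small residual-stream hook points and returns one `SparseAutoencoder` plus its sparsity tensor per hook point.

from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Downloads (or reuses the local Hugging Face cache for) all 13 residual SAEs.
saes, sparsities = get_gpt2_res_jb_saes()

sae = saes["blocks.8.hook_resid_pre"]             # SparseAutoencoder for layer 8 resid_pre
sparsity = sparsities["blocks.8.hook_resid_pre"]  # per-feature sparsity tensor
print(type(sae).__name__, sparsity.shape)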
32 changes: 19 additions & 13 deletions sae_lens/training/activations_store.py
@@ -33,14 +33,15 @@ class ActivationsStore:
cached_activations_path: str | None
tokens_column: Literal["tokens", "input_ids", "text"]
hook_point_head_index: int | None
_dataloader: Iterator[Any] | None = None
_storage_buffer: torch.Tensor | None = None

@classmethod
def from_config(
cls,
model: HookedTransformer,
cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
dataset: HfDataset | None = None,
create_dataloader: bool = True,
) -> "ActivationsStore":
cached_activations_path = cfg.cached_activations_path
# set cached_activations_path to None if we're not using cached activations
@@ -65,7 +66,6 @@ def from_config(
device=cfg.device,
dtype=cfg.dtype,
cached_activations_path=cached_activations_path,
create_dataloader=create_dataloader,
)

def __init__(
@@ -83,9 +83,8 @@ def __init__(
train_batch_size: int,
prepend_bos: bool,
device: str | torch.device,
dtype: torch.dtype,
dtype: str | torch.dtype,
cached_activations_path: str | None = None,
create_dataloader: bool = True,
):
self.model = model
self.dataset = (
@@ -151,10 +150,17 @@ def __init__(

# TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

if create_dataloader:
# fill buffer half a buffer, so we can mix it with a new buffer
self.storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()
@property
def storage_buffer(self) -> torch.Tensor:
if self._storage_buffer is None:
self._storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2)
return self._storage_buffer

@property
def dataloader(self) -> Iterator[Any]:
if self._dataloader is None:
self._dataloader = self.get_data_loader()
return self._dataloader

def get_batch_tokens(self):
"""
@@ -259,7 +265,7 @@ def get_activations(self, batch_tokens: torch.Tensor):

return stacked_activations

def get_buffer(self, n_batches_in_buffer: int):
def get_buffer(self, n_batches_in_buffer: int) -> torch.Tensor:
context_size = self.context_size
batch_size = self.store_batch_size
d_in = self.d_in
@@ -272,7 +278,7 @@ def get_buffer(self, n_batches_in_buffer: int):
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, num_layers, d_in),
dtype=self.dtype,
dtype=self.dtype, # type: ignore
device=self.device,
)
n_tokens_filled = 0
@@ -323,7 +329,7 @@ def get_buffer(self, n_batches_in_buffer: int):
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, num_layers, d_in),
dtype=self.dtype,
dtype=self.dtype, # type: ignore
device=self.device,
)

@@ -363,7 +369,7 @@ def get_data_loader(
mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

# 2. put 50 % in storage
self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

# 3. put other 50 % in a dataloader
dataloader = iter(
@@ -387,7 +393,7 @@ def next_batch(self):
return next(self.dataloader)
except StopIteration:
# If the DataLoader is exhausted, create a new one
self.dataloader = self.get_data_loader()
self._dataloader = self.get_data_loader()
return next(self.dataloader)

def _get_next_dataset_tokens(self) -> torch.Tensor:
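The storage buffer and dataloader are now built lazily through properties instead of eagerly in `__init__`, so constructing the store does no work until activations are first requested; this is also why `cache_activations_runner` below no longer needs a `create_dataloader=False` flag. A minimal sketch of the resulting behaviour, assuming defaults are acceptable for the config fields not shown:

from transformer_lens import HookedTransformer

from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.config import LanguageModelSAERunnerConfig

model = HookedTransformer.from_pretrained("gpt2")
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_point="blocks.8.hook_resid_pre",
    hook_point_layer=8,
    dataset_path="NeelNanda/openwebtext-tokenized-9b",
    is_dataset_tokenized=True,
    d_in=768,
)

store = ActivationsStore.from_config(model, cfg)  # cheap: no buffer is filled yet
# The first next_batch() (or the first access of store.storage_buffer /
# store.dataloader) triggers buffer filling and dataloader creation.
batch = store.next_batch()
print(batch.shape)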
1 change: 0 additions & 1 deletion sae_lens/training/cache_activations_runner.py
@@ -16,7 +16,6 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
activations_store = ActivationsStore.from_config(
model,
cfg,
create_dataloader=False,
)

# if the activations directory exists and has files in it, raise an exception
