Skip to content

Commit

Permalink
fix: modify duplicate neuronpedia ids in config.yml, add test. (#265)
Browse files Browse the repository at this point in the history
* fix duplicate ids

* fix test that had mistake
  • Loading branch information
jbloomAus authored Aug 23, 2024
1 parent 5c2d391 commit 0555178
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
10 changes: 5 additions & 5 deletions sae_lens/pretrained_saes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ gpt2-small-res-jb:
l0: 31.0
- id: blocks.5.hook_resid_pre
path: blocks.5.hook_resid_pre
neuronpedia: gpt2-small/0-res-jb
neuronpedia: gpt2-small/5-res-jb
variance_explained: 0.9
l0: 41.0
- id: blocks.6.hook_resid_pre
Expand Down Expand Up @@ -6305,18 +6305,18 @@ gemma-scope-9b-it-res:
l0: 88
gemma-scope-9b-it-res-canonical:
repo_id: google/gemma-scope-9b-it-res
model: gemma-2-9b
model: gemma-2-9b-it
conversion_func: gemma_2
saes:
- id: layer_9/width_131k/canonical
path: layer_9/width_131k/average_l0_121/params.npz
neuronpedia: gemma-2-9b/9-gemmascope-res-131k
neuronpedia: gemma-2-9b-it/9-gemmascope-it-res-131k
- id: layer_20/width_131k/canonical
path: layer_20/width_131k/average_l0_81/params.npz
neuronpedia: gemma-2-9b/20-gemmascope-res-131k
neuronpedia: gemma-2-9b-it/20-gemmascope-it-res-131k
- id: layer_31/width_131k/canonical
path: layer_31/width_131k/average_l0_109/params.npz
neuronpedia: gemma-2-9b/31-gemmascope-res-131k
neuronpedia: gemma-2-9b-it/31-gemmascope-it-res-131k
gemma-scope-27b-pt-res:
repo_id: google/gemma-scope-27b-pt-res
model: gemma-2-2b
Expand Down
40 changes: 39 additions & 1 deletion tests/unit/toolkit/test_pretrained_saes_directory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pandas as pd

from sae_lens.toolkit.pretrained_saes_directory import (
PretrainedSAELookup,
get_pretrained_saes_directory,
Expand Down Expand Up @@ -64,7 +66,7 @@ def test_get_pretrained_saes_directory():
"blocks.2.hook_resid_pre": "gpt2-small/2-res-jb",
"blocks.3.hook_resid_pre": "gpt2-small/3-res-jb",
"blocks.4.hook_resid_pre": "gpt2-small/4-res-jb",
"blocks.5.hook_resid_pre": "gpt2-small/0-res-jb",
"blocks.5.hook_resid_pre": "gpt2-small/5-res-jb",
"blocks.6.hook_resid_pre": "gpt2-small/6-res-jb",
"blocks.7.hook_resid_pre": "gpt2-small/7-res-jb",
"blocks.8.hook_resid_pre": "gpt2-small/8-res-jb",
Expand All @@ -76,3 +78,39 @@ def test_get_pretrained_saes_directory():
)

assert sae_directory["gpt2-small-res-jb"] == expected_result


def test_get_pretrained_saes_directory_unique_np_ids():

# ideally this code should be elsewhere but as a stop-gap we'll leave it here.
df = pd.DataFrame.from_records(
{k: v.__dict__ for k, v in get_pretrained_saes_directory().items()}
).T
df.drop(
columns=[
"repo_id",
"saes_map",
"expected_var_explained",
"expected_l0",
"config_overrides",
"conversion_func",
],
inplace=True,
)
df["neuronpedia_id_list"] = df["neuronpedia_id"].apply(lambda x: list(x.items()))
df_exploded = df.explode("neuronpedia_id_list")
df_exploded[["sae_lens_id", "neuronpedia_id"]] = pd.DataFrame(
df_exploded["neuronpedia_id_list"].tolist(), index=df_exploded.index
)
df_exploded = df_exploded.drop(columns=["neuronpedia_id_list"])
df_exploded = df_exploded.reset_index(drop=True)
df_exploded["neuronpedia_set"] = df_exploded["neuronpedia_id"].apply(
lambda x: "-".join(x.split("/")[-1].split("-")[1:]) if x is not None else None
)

duplicate_ids = df_exploded.groupby("neuronpedia_id").sae_lens_id.apply(
lambda x: len(x)
)
assert (
duplicate_ids.max() == 1
), f"Duplicate IDs found: {duplicate_ids[duplicate_ids > 1]}"

0 comments on commit 0555178

Please sign in to comment.