feat: Add experimental Gemma Scope embedding SAEs (#299)
* add experimental embedding gemmascope SAEs

* format and lint
jbloomAus authored Sep 23, 2024
1 parent a708220 commit bb9ebbc
Showing 3 changed files with 33 additions and 5 deletions.
4 changes: 3 additions & 1 deletion sae_lens/evals.py
@@ -630,7 +630,9 @@ def process_results(eval_results: list[defaultdict[Any, Any]], output_dir: str):

     # Save individual JSON files
     for result in eval_results:
-        json_filename = f"{result['unique_id']}_{result['eval_cfg']['context_size']}_{result['eval_cfg']['dataset'].replace('/', '_')}.json"
+        json_filename = f"{result['unique_id']}_{result['eval_cfg']['context_size']}_{result['eval_cfg']['dataset']}.json".replace(
+            "/", "_"
+        )
         json_path = output_path / json_filename
         with open(json_path, "w") as f:
             json.dump(result, f, indent=2)
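Why sanitize the whole filename rather than just the dataset name: the SAE ids added in this commit (e.g. embedding/width_4k/average_l0_6) contain slashes, so a unique_id built from them would otherwise be interpreted as nested directories when the JSON file is opened. A minimal sketch of the difference (the unique_id value below is hypothetical, for illustration only):

result = {
    "unique_id": "gemma-scope-2b-pt-res-embedding/width_4k/average_l0_6",
    "eval_cfg": {"context_size": 128, "dataset": "Skylion007/openwebtext"},
}

# Old: only the dataset name was sanitized; slashes from the SAE id survive.
old = f"{result['unique_id']}_{result['eval_cfg']['context_size']}_{result['eval_cfg']['dataset'].replace('/', '_')}.json"

# New: sanitize the fully assembled filename in one pass.
new = f"{result['unique_id']}_{result['eval_cfg']['context_size']}_{result['eval_cfg']['dataset']}.json".replace("/", "_")

print(old)  # ...embedding/width_4k/average_l0_6_128_Skylion007_openwebtext.json
print(new)  # ...embedding_width_4k_average_l0_6_128_Skylion007_openwebtext.json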
18 changes: 16 additions & 2 deletions sae_lens/pretrained_saes.yaml
@@ -730,6 +730,14 @@ gemma-scope-2b-pt-res:
   model: gemma-2-2b
   conversion_func: gemma_2
   saes:
+  - id: embedding/width_4k/average_l0_6
+    path: embedding/width_4k/average_l0_6
+  - id: embedding/width_4k/average_l0_44
+    path: embedding/width_4k/average_l0_44
+  - id: embedding/width_4k/average_l0_21
+    path: embedding/width_4k/average_l0_21
+  - id: embedding/width_4k/average_l0_111
+    path: embedding/width_4k/average_l0_111
   - id: layer_0/width_16k/average_l0_105
     path: layer_0/width_16k/average_l0_105
     l0: 105
@@ -3728,6 +3736,14 @@ gemma-scope-9b-pt-res:
   model: gemma-2-9b
   conversion_func: gemma_2
   saes:
+  - id: embedding/width_4k/average_l0_14
+    path: embedding/width_4k/average_l0_14
+  - id: embedding/width_4k/average_l0_22
+    path: embedding/width_4k/average_l0_22
+  - id: embedding/width_4k/average_l0_7
+    path: embedding/width_4k/average_l0_7
+  - id: embedding/width_4k/average_l0_80
+    path: embedding/width_4k/average_l0_80
   - id: layer_0/width_131k/average_l0_11
     path: layer_0/width_131k/average_l0_11
     l0: 11
@@ -5687,8 +5703,6 @@ gemma-scope-9b-pt-res-canonical:
     path: layer_20/width_65k/average_l0_55
     neuronpedia: gemma-2-9b/20-gemmascope-res-65k
     l0: 55
-
-
 gemma-scope-9b-pt-att:
   repo_id: google/gemma-scope-9b-pt-att
   model: gemma-2-9b
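With these registry entries in place, the new embedding SAEs can be fetched like any other Gemma Scope SAE. A minimal usage sketch (assuming the standard SAE.from_pretrained interface; exact return values may differ across SAELens versions):

from sae_lens import SAE

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res",
    sae_id="embedding/width_4k/average_l0_6",
)
print(cfg_dict["hook_name"])  # expected: "hook_embed"
print(cfg_dict["d_in"])       # 2304 for gemma-2-2b

Note that the embedding entries list only id and path; unlike the layer SAEs, they carry no l0 field in the registry.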
16 changes: 14 additions & 2 deletions sae_lens/toolkit/pretrained_sae_loaders.py
@@ -271,6 +271,7 @@ def get_gemma_2_config(
 ) -> Dict[str, Any]:
     # Detect width from folder_name
     width_map = {
+        "width_4k": 4096,
         "width_16k": 16384,
         "width_32k": 32768,
         "width_65k": 65536,
@@ -291,7 +292,10 @@
     match = re.search(r"layer_(\d+)", folder_name)
     layer = int(match.group(1)) if match else layer_override
     if layer is None:
-        raise ValueError("Layer not found in folder_name and no override provided.")
+        if "embedding" in folder_name:
+            layer = 0
+        else:
+            raise ValueError("Layer not found in folder_name and no override provided.")

     # Model specific parameters
     model_params = {
@@ -311,8 +315,10 @@
     model_name, d_in = model_info["name"], model_info["d_in"]

     # Hook specific parameters
-    if "res" in repo_id:
+    if "res" in repo_id and "embedding" not in folder_name:
         hook_name = f"blocks.{layer}.hook_resid_post"
+    elif "res" in repo_id and "embedding" in folder_name:
+        hook_name = "hook_embed"
     elif "mlp" in repo_id:
         hook_name = f"blocks.{layer}.hook_mlp_out"
     elif "att" in repo_id:
@@ -398,6 +404,12 @@ def gemma_2_sae_loader(
     # No sparsity tensor for Gemma 2 SAEs
     log_sparsity = None

+    # If this is an embedding SAE, adjust for the scale of d_model because of how it was trained.
+    if "embedding" in folder_name:
+        print("Adjusting for d_model in embedding SAE")
+        state_dict["W_enc"].data = state_dict["W_enc"].data / np.sqrt(cfg_dict["d_in"])
+        state_dict["W_dec"].data = state_dict["W_dec"].data * np.sqrt(cfg_dict["d_in"])
+
     return cfg_dict, state_dict, log_sparsity


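The sqrt rescaling above can be read as input normalization folded into the weights: dividing W_enc by sqrt(d_in) is algebraically identical to shrinking each incoming activation by sqrt(d_in) before encoding, and multiplying W_dec by sqrt(d_in) scales the decoder output back up. A short check of the encoder-side identity (shapes are illustrative; b_enc is unaffected because the fold happens before the bias is added):

import numpy as np

d_in, d_sae = 2304, 4096
rng = np.random.default_rng(0)
W_enc = rng.normal(size=(d_in, d_sae))
b_enc = rng.normal(size=d_sae)
x = rng.normal(size=d_in)  # a raw hook_embed activation

pre_folded = x @ (W_enc / np.sqrt(d_in)) + b_enc
pre_scaled_input = (x / np.sqrt(d_in)) @ W_enc + b_enc
assert np.allclose(pre_folded, pre_scaled_input)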
