Skip to content

Commit

Permalink
fix: Improve error message for Gemma Scope non-canonical ID not found (#288)
Browse files Browse the repository at this point in the history

* Update sae.py with a nicer Gemma Scope error message that encourages using the canonical release

* Update sae.py

* Update sae.py

* format

---------

Co-authored-by: jbloomAus <[email protected]>
  • Loading branch information
ArthurConmy and jbloomAus authored Sep 13, 2024
1 parent 9ce40c2 commit 9d34598
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion sae_lens/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,34 @@ def from_pretrained(
f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
)
elif sae_id not in sae_directory[release].saes_map:
# If using Gemma Scope and not the canonical release, give a hint to use it
if (
"gemma-scope" in release
and "canonical" not in release
and f"{release}-canonical" in sae_directory
):
canonical_ids = list(
sae_directory[release + "-canonical"].saes_map.keys()
)
# Shorten the lengthy string of valid IDs
if len(canonical_ids) > 5:
str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
else:
str_canonical_ids = str(canonical_ids)
value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
else:
value_suffix = ""

valid_ids = list(sae_directory[release].saes_map.keys())
# Shorten the lengthy string of valid IDs
if len(valid_ids) > 5:
str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
else:
str_valid_ids = str(valid_ids)

raise ValueError(
f"ID {sae_id} not found in release {release}. Valid IDs are {sae_directory[release].saes_map.keys()}"
f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
+ value_suffix
)
sae_info = sae_directory.get(release, None)
hf_repo_id = sae_info.repo_id if sae_info is not None else release
Expand Down

0 comments on commit 9d34598

Please sign in to comment.