Add value error if both `d_sae` and `expansion_factor` set #301
Conversation
adfe458 to 3d1b263
The core code changes look good, but why were unrelated test values changed?
```diff
-    # useful for checking the correctness of the configuration
-    # in the tests.
-    mock_config.__post_init__()
+    mock_config = LanguageModelSAERunnerConfig(**mock_config_dict)
```
fair enough! This feels less hacky than doing `__post_init__()` like before 👍
Haha, yeah, we shouldn't effectively call `__post_init__()` twice.
```diff
@@ -28,7 +28,6 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
     {
         "model_name": "tiny-stories-1M",
         "dataset_path": "roneneldan/TinyStories",
-        "tokenized": False,
```
Why were these values changed?
Because `build_sae_cfg()` now takes a dict, updates another dict, and then passes it to a dataclass, whereas before it instantiated a dataclass, then looped over the dict and called `setattr()` to set the dataclass attributes. So now, since `"tokenized"` doesn't match any attribute of `LanguageModelSAERunnerConfig`, its `__init__()` method raises a `TypeError` and the tests would fail.

I checked throughout the codebase that there are no references to a field `tokenized` for any instance of `LanguageModelSAERunnerConfig`.
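The difference is easy to demonstrate with a minimal sketch. `MiniConfig` below is a hypothetical stand-in, not the real `LanguageModelSAERunnerConfig`: the old `setattr()` loop silently accepts any field name, while passing the dict straight to the dataclass constructor raises a `TypeError` on a typo.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniConfig:
    model_name: str = "tiny-stories-1M"
    d_sae: Optional[int] = None

# Old pattern: a setattr() loop silently accepts unknown field names
cfg = MiniConfig()
for key, value in {"tokenized": False}.items():
    setattr(cfg, key, value)  # no error; the typo goes unnoticed

# New pattern: the constructor rejects unknown keyword arguments
try:
    MiniConfig(**{"tokenized": False})
except TypeError as err:
    print(err)  # ...got an unexpected keyword argument 'tokenized'
```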
oooh nice, so if we mistype attributes in tests now we'll get an error? 🥇
```diff
@@ -202,7 +198,8 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:

 def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
     cfg = build_sae_cfg(
         activation_fn_str="topk", activation_fn_kwargs={"k": 30}, device="cpu"
```
why was this changed?
For similar reasons as in this comment. `activation_fn_str` isn't an attribute of `LanguageModelSAERunnerConfig`.
```diff
@@ -204,7 +204,6 @@ def test_train_sae_group_on_language_model__runs(
     checkpoint_dir = tmp_path / "checkpoint"
     cfg = build_sae_cfg(
         checkpoint_path=str(checkpoint_dir),
-        train_batch_size=32,
```
Why was this changed?
For similar reasons as in this comment. `train_batch_size` isn't an attribute of `LanguageModelSAERunnerConfig`.
Awesome work with this, and thanks for fixing the issue of tests having invalid config fields set!
…us#301)

* adds `ValueError` if both `d_sae` and `expansion_factor` set
* renames class
* removes commented out line
Description

Changes `LanguageModelSAERunnerConfig` so that if `d_sae` and `expansion_factor` are both not `None`, `__post_init__()` raises a `ValueError`, and if `d_sae` and `expansion_factor` are both `None`, `__post_init__()` sets `expansion_factor` to `4`. `d_sae` and `expansion_factor` are both attributes that set the size of an SAE, and `expansion_factor` has a default value of `4` for backward compatibility.

Fixes #47
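A minimal sketch of the validation described above (the class name `SAESizeConfig` is illustrative, and the real `LanguageModelSAERunnerConfig` has many more fields and derives the SAE width elsewhere):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SAESizeConfig:
    d_sae: Optional[int] = None
    expansion_factor: Optional[int] = None

    def __post_init__(self) -> None:
        # Setting both at once is ambiguous: raise rather than silently
        # preferring one over the other
        if self.d_sae is not None and self.expansion_factor is not None:
            raise ValueError("Set either d_sae or expansion_factor, not both.")
        # Neither set: fall back to the old default of 4 for backward
        # compatibility
        if self.d_sae is None and self.expansion_factor is None:
            self.expansion_factor = 4

print(SAESizeConfig().expansion_factor)  # 4
print(SAESizeConfig(d_sae=16384).d_sae)  # 16384
```

Constructing `SAESizeConfig(d_sae=16384, expansion_factor=4)` raises the `ValueError`.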
Type of change

Please delete options that are not relevant.

Checklist:

- You have tested formatting, typing and unit tests (acceptance tests not currently in use)
- You have run `make check-ci` to check format and linting. (you can run `make format` to format code if needed.)

Performance Check.
If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:

Please link to wandb dashboards with a control and test group.