Add value error if both d_sae and expansion_factor set #301

Merged
12 changes: 10 additions & 2 deletions sae_lens/config.py
@@ -131,7 +131,9 @@ class LanguageModelSAERunnerConfig:
    d_in: int = 512
    d_sae: Optional[int] = None
    b_dec_init_method: str = "geometric_median"
    expansion_factor: int = 4
    expansion_factor: Optional[int] = (
        None  # defaults to 4 if d_sae and expansion_factor are None
    )
    activation_fn: str = "relu"  # relu, tanh-relu, topk
    activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)  # for topk
    normalize_sae_decoder: bool = True
@@ -246,7 +248,13 @@ def __post_init__(self):
self.hook_head_index,
)

        if not isinstance(self.expansion_factor, list):
        if self.d_sae is not None and self.expansion_factor is not None:
            raise ValueError("You can't set both d_sae and expansion_factor.")

        if self.d_sae is None and self.expansion_factor is None:
            self.expansion_factor = 4

        if self.d_sae is None and self.expansion_factor is not None:
            self.d_sae = self.d_in * self.expansion_factor
        self.tokens_per_buffer = (
            self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
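
As a quick illustration of the new behaviour (a minimal sketch, assuming the remaining config defaults construct cleanly, as the new test in tests/unit/training/test_config.py further down also exercises):

from sae_lens.config import LanguageModelSAERunnerConfig

# Setting both d_sae and expansion_factor is now rejected in __post_init__
try:
    LanguageModelSAERunnerConfig(d_in=512, d_sae=2048, expansion_factor=4)
except ValueError as err:
    print(err)  # You can't set both d_sae and expansion_factor.

# Setting neither falls back to expansion_factor=4, so d_sae = 4 * d_in
cfg = LanguageModelSAERunnerConfig(d_in=512)
assert cfg.expansion_factor == 4
assert cfg.d_sae == 2048

# Setting only expansion_factor still derives d_sae from d_in
cfg = LanguageModelSAERunnerConfig(d_in=512, expansion_factor=8)
assert cfg.d_sae == 4096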
107 changes: 67 additions & 40 deletions tests/unit/helpers.py
@@ -1,5 +1,5 @@
import copy
from typing import Any
from typing import Any, Optional, TypedDict

from transformer_lens import HookedTransformer

@@ -9,51 +9,78 @@
TINYSTORIES_DATASET = "roneneldan/TinyStories"


class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
model_name: str
hook_name: str
hook_layer: int
hook_head_index: Optional[int]
dataset_path: str
dataset_trust_remote_code: bool
is_dataset_tokenized: bool
use_cached_activations: bool
d_in: int
l1_coefficient: float
lp_norm: float
lr: float
train_batch_size_tokens: int
context_size: int
feature_sampling_window: int
dead_feature_threshold: float
dead_feature_window: int
n_batches_in_buffer: int
training_tokens: int
store_batch_size_prompts: int
log_to_wandb: bool
wandb_project: str
wandb_entity: str
wandb_log_frequency: int
device: str
seed: int
checkpoint_path: str
dtype: str
prepend_bos: bool


def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
"""
Helper to create a mock instance of LanguageModelSAERunnerConfig.
"""
# Create a mock object with the necessary attributes
mock_config = LanguageModelSAERunnerConfig(
model_name=TINYSTORIES_MODEL,
hook_name="blocks.0.hook_mlp_out",
hook_layer=0,
hook_head_index=None,
dataset_path=TINYSTORIES_DATASET,
dataset_trust_remote_code=True,
is_dataset_tokenized=False,
use_cached_activations=False,
d_in=64,
expansion_factor=2,
l1_coefficient=2e-3,
lp_norm=1,
lr=2e-4,
train_batch_size_tokens=4,
context_size=6,
feature_sampling_window=50,
dead_feature_threshold=1e-7,
dead_feature_window=1000,
n_batches_in_buffer=2,
training_tokens=1_000_000,
store_batch_size_prompts=4,
log_to_wandb=False,
wandb_project="test_project",
wandb_entity="test_entity",
wandb_log_frequency=10,
device="cpu",
seed=24,
checkpoint_path="test/checkpoints",
dtype="float32",
prepend_bos=True,
)
mock_config_dict: LanguageModelSAERunnerConfigDict = {
"model_name": TINYSTORIES_MODEL,
"hook_name": "blocks.0.hook_mlp_out",
"hook_layer": 0,
"hook_head_index": None,
"dataset_path": TINYSTORIES_DATASET,
"dataset_trust_remote_code": True,
"is_dataset_tokenized": False,
"use_cached_activations": False,
"d_in": 64,
"l1_coefficient": 2e-3,
"lp_norm": 1,
"lr": 2e-4,
"train_batch_size_tokens": 4,
"context_size": 6,
"feature_sampling_window": 50,
"dead_feature_threshold": 1e-7,
"dead_feature_window": 1000,
"n_batches_in_buffer": 2,
"training_tokens": 1_000_000,
"store_batch_size_prompts": 4,
"log_to_wandb": False,
"wandb_project": "test_project",
"wandb_entity": "test_entity",
"wandb_log_frequency": 10,
"device": "cpu",
"seed": 24,
"checkpoint_path": "test/checkpoints",
"dtype": "float32",
"prepend_bos": True,
}

    for key, val in kwargs.items():
        setattr(mock_config, key, val)
    for key, value in kwargs.items():
        mock_config_dict[key] = value

# Call the post-init method to set any derived attributes
# useful for checking the correctness of the configuration
# in the tests.
mock_config.__post_init__()
mock_config = LanguageModelSAERunnerConfig(**mock_config_dict)
Collaborator

fair enough! This feels less hacky than doing __post_init__() like before 👍

Contributor Author

Haha, yeah, we shouldn't effectively call __post_init__() twice.


    # reset checkpoint path (as we add an id each time)
    mock_config.checkpoint_path = (
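
For context, a minimal sketch of the two construction patterns the thread above compares (Cfg, build_cfg_old, and build_cfg_new are hypothetical names, not the project's code):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Cfg:
    d_in: int = 512
    d_sae: Optional[int] = None

    def __post_init__(self):
        # derived field, recomputed whenever __post_init__ runs
        self.d_sae = 4 * self.d_in

def build_cfg_old(**kwargs: Any) -> Cfg:
    cfg = Cfg()                 # __post_init__ already ran on the defaults
    for key, val in kwargs.items():
        setattr(cfg, key, val)  # silently accepts typos such as d_inn=64
    cfg.__post_init__()         # must run a second time to refresh d_sae
    return cfg

def build_cfg_new(**kwargs: Any) -> Cfg:
    overrides: dict[str, Any] = {"d_in": 64}
    overrides.update(kwargs)
    return Cfg(**overrides)     # runs __post_init__ once; unknown keys raise TypeError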
5 changes: 0 additions & 5 deletions tests/unit/test_evals.py
@@ -33,15 +33,13 @@
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
@@ -51,23 +49,20 @@
{
"model_name": "tiny-stories-1M",
"dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.attn.hook_z",
"hook_layer": 1,
"d_in": 16 * 4,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.attn.hook_q",
"hook_layer": 1,
"d_in": 16 * 4,
5 changes: 0 additions & 5 deletions tests/unit/training/test_activations_store.py
@@ -28,7 +28,6 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
Collaborator

Why were these values changed?

Contributor Author

Because build_sae_cfg() now takes a dict, updates another dict, and then passes it to a dataclass, whereas before it instantiated a dataclass, then looped over the dict and called setattr() to set the dataclass attributes. So now, since "tokenized" doesn't match any attribute of LanguageModelSAERunnerConfig, its __init__() method would raise an error and the tests would fail.

I checked throughout the codebase that there are no references to a field tokenized for any instance of LanguageModelSAERunnerConfig.

Collaborator

oooh nice, so if we mistype attributes in tests now we'll get an error? 🥇
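
Concretely (a hypothetical call; the exact wording of the message depends on the Python version), the removed key now fails loudly at construction time instead of being silently ignored:

build_sae_cfg(tokenized=False)
# TypeError: __init__() got an unexpected keyword argument 'tokenized'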

"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
@@ -37,15 +36,13 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.attn.hook_z",
"hook_layer": 1,
"d_in": 64,
},
{
"model_name": "gelu-2l",
"dataset_path": "NeelNanda/c4-tokenized-2b",
"tokenized": True,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 512,
@@ -54,7 +51,6 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
{
"model_name": "gpt2",
"dataset_path": "apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
"tokenized": True,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 768,
@@ -63,7 +59,6 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
{
"model_name": "gpt2",
"dataset_path": "Skylion007/openwebtext",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 768,
12 changes: 12 additions & 0 deletions tests/unit/training/test_config.py
@@ -79,3 +79,15 @@ def test_sae_training_runner_config_raises_error_if_resume_true():
    with pytest.raises(ValueError):
        _ = LanguageModelSAERunnerConfig(resume=True)
    assert True


def test_sae_training_runner_config_raises_error_if_d_sae_and_expansion_factor_not_none():
    with pytest.raises(ValueError):
        _ = LanguageModelSAERunnerConfig(d_sae=128, expansion_factor=4)
    assert True


def test_sae_training_runner_config_expansion_factor():
    cfg = LanguageModelSAERunnerConfig()

    assert cfg.expansion_factor == 4
7 changes: 2 additions & 5 deletions tests/unit/training/test_sae_basic.py
@@ -16,15 +16,13 @@
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
@@ -34,15 +32,13 @@
{
"model_name": "tiny-stories-1M",
"dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.attn.hook_z",
"hook_layer": 1,
"d_in": 64,
@@ -202,7 +198,8 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:

def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
cfg = build_sae_cfg(
activation_fn_str="topk", activation_fn_kwargs={"k": 30}, device="cpu"
Collaborator

why was this changed?

Contributor Author
@anthonyduong9 anthonyduong9 Sep 25, 2024

For similar reasons as in this comment. activation_fn_str isn't an attribute of LanguageModelSAERunnerConfig.

activation_fn_kwargs={"k": 30},
device="cpu",
)
model_path = str(tmp_path)
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
2 changes: 1 addition & 1 deletion tests/unit/training/test_sae_initialization.py
@@ -65,7 +65,7 @@ def test_SparseAutoencoder_initialization_gated():


def test_SparseAutoencoder_initialization_orthogonal_enc_dec():
cfg = build_sae_cfg(decoder_orthogonal_init=True)
cfg = build_sae_cfg(decoder_orthogonal_init=True, expansion_factor=2)

sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
projections = sae.W_dec.T @ sae.W_dec
1 change: 0 additions & 1 deletion tests/unit/training/test_sae_trainer.py
@@ -204,7 +204,6 @@ def test_train_sae_group_on_language_model__runs(
checkpoint_dir = tmp_path / "checkpoint"
cfg = build_sae_cfg(
checkpoint_path=str(checkpoint_dir),
train_batch_size=32,
Collaborator

Why was this changed?

Contributor Author
@anthonyduong9 anthonyduong9 Sep 25, 2024

For similar reasons as in this comment. train_batch_size isn't an attribute of LanguageModelSAERunnerConfig.

training_tokens=100,
context_size=8,
)
5 changes: 0 additions & 5 deletions tests/unit/training/test_sae_training.py
@@ -19,15 +19,13 @@
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
@@ -37,15 +35,13 @@
{
"model_name": "tiny-stories-1M",
"dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
"tokenized": False,
"hook_name": "blocks.1.hook_resid_pre",
"hook_layer": 1,
"d_in": 64,
@@ -54,7 +50,6 @@
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"tokenized": False,
"hook_name": "blocks.1.attn.hook_z",
"hook_layer": 1,
"d_in": 64,