From f1ee575ff6a99ab3afe887cb4ea9420f6d525271 Mon Sep 17 00:00:00 2001
From: anthonyduong
Date: Tue, 24 Sep 2024 13:12:55 -0700
Subject: [PATCH 1/3] adds ValueError if both d_sae and expansion_factor set

---
 sae_lens/config.py                            |  12 +-
 tests/unit/helpers.py                         | 107 +++++++++++-------
 tests/unit/test_evals.py                      |   5 -
 tests/unit/training/test_activations_store.py |   5 -
 tests/unit/training/test_config.py            |  12 ++
 tests/unit/training/test_sae_basic.py         |   8 +-
 .../unit/training/test_sae_initialization.py  |   2 +-
 tests/unit/training/test_sae_trainer.py       |   1 -
 tests/unit/training/test_sae_training.py      |   5 -
 9 files changed, 93 insertions(+), 64 deletions(-)

diff --git a/sae_lens/config.py b/sae_lens/config.py
index a74b7565..3fc4ac75 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -131,7 +131,9 @@ class LanguageModelSAERunnerConfig:
     d_in: int = 512
     d_sae: Optional[int] = None
     b_dec_init_method: str = "geometric_median"
-    expansion_factor: int = 4
+    expansion_factor: Optional[int] = (
+        None  # defaults to 4 if d_sae and expansion_factor is None
+    )
     activation_fn: str = "relu"  # relu, tanh-relu, topk
     activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)  # for topk
     normalize_sae_decoder: bool = True
@@ -246,7 +248,13 @@ def __post_init__(self):
                 self.hook_head_index,
             )
 
-        if not isinstance(self.expansion_factor, list):
+        if self.d_sae is not None and self.expansion_factor is not None:
+            raise ValueError("You can't set both d_sae and expansion_factor.")
+
+        if self.d_sae is None and self.expansion_factor is None:
+            self.expansion_factor = 4
+
+        if self.d_sae is None and self.expansion_factor is not None:
             self.d_sae = self.d_in * self.expansion_factor
         self.tokens_per_buffer = (
             self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py
index cf11cc9e..1ddb76a1 100644
--- a/tests/unit/helpers.py
+++ b/tests/unit/helpers.py
@@ -1,5 +1,5 @@
 import copy
-from typing import Any
+from typing import Any, Optional, TypedDict
 
 from transformer_lens import HookedTransformer
 
@@ -9,51 +9,78 @@
 TINYSTORIES_DATASET = "roneneldan/TinyStories"
 
 
+class ConfigKwargsType(TypedDict, total=False):
+    model_name: str
+    hook_name: str
+    hook_layer: int
+    hook_head_index: Optional[int]
+    dataset_path: str
+    dataset_trust_remote_code: bool
+    is_dataset_tokenized: bool
+    use_cached_activations: bool
+    d_in: int
+    l1_coefficient: float
+    lp_norm: float
+    lr: float
+    train_batch_size_tokens: int
+    context_size: int
+    feature_sampling_window: int
+    dead_feature_threshold: float
+    dead_feature_window: int
+    n_batches_in_buffer: int
+    training_tokens: int
+    store_batch_size_prompts: int
+    log_to_wandb: bool
+    wandb_project: str
+    wandb_entity: str
+    wandb_log_frequency: int
+    device: str
+    seed: int
+    checkpoint_path: str
+    dtype: str
+    prepend_bos: bool
+
+
 def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
     """
     Helper to create a mock instance of LanguageModelSAERunnerConfig.
""" - # Create a mock object with the necessary attributes - mock_config = LanguageModelSAERunnerConfig( - model_name=TINYSTORIES_MODEL, - hook_name="blocks.0.hook_mlp_out", - hook_layer=0, - hook_head_index=None, - dataset_path=TINYSTORIES_DATASET, - dataset_trust_remote_code=True, - is_dataset_tokenized=False, - use_cached_activations=False, - d_in=64, - expansion_factor=2, - l1_coefficient=2e-3, - lp_norm=1, - lr=2e-4, - train_batch_size_tokens=4, - context_size=6, - feature_sampling_window=50, - dead_feature_threshold=1e-7, - dead_feature_window=1000, - n_batches_in_buffer=2, - training_tokens=1_000_000, - store_batch_size_prompts=4, - log_to_wandb=False, - wandb_project="test_project", - wandb_entity="test_entity", - wandb_log_frequency=10, - device="cpu", - seed=24, - checkpoint_path="test/checkpoints", - dtype="float32", - prepend_bos=True, - ) + config_kwargs: ConfigKwargsType = { + "model_name": TINYSTORIES_MODEL, + "hook_name": "blocks.0.hook_mlp_out", + "hook_layer": 0, + "hook_head_index": None, + "dataset_path": TINYSTORIES_DATASET, + "dataset_trust_remote_code": True, + "is_dataset_tokenized": False, + "use_cached_activations": False, + "d_in": 64, + "l1_coefficient": 2e-3, + "lp_norm": 1, + "lr": 2e-4, + "train_batch_size_tokens": 4, + "context_size": 6, + "feature_sampling_window": 50, + "dead_feature_threshold": 1e-7, + "dead_feature_window": 1000, + "n_batches_in_buffer": 2, + "training_tokens": 1_000_000, + "store_batch_size_prompts": 4, + "log_to_wandb": False, + "wandb_project": "test_project", + "wandb_entity": "test_entity", + "wandb_log_frequency": 10, + "device": "cpu", + "seed": 24, + "checkpoint_path": "test/checkpoints", + "dtype": "float32", + "prepend_bos": True, + } - for key, val in kwargs.items(): - setattr(mock_config, key, val) + for key, value in kwargs.items(): + config_kwargs[key] = value - # Call the post-init method to set any derived attributes - # useful for checking the correctness of the configuration - # in the tests. 
-    mock_config.__post_init__()
+    mock_config = LanguageModelSAERunnerConfig(**config_kwargs)
 
     # reset checkpoint path (as we add an id to each each time)
     mock_config.checkpoint_path = (
diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py
index 9d85e41b..1f9bb835 100644
--- a/tests/unit/test_evals.py
+++ b/tests/unit/test_evals.py
@@ -33,7 +33,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -41,7 +40,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -51,7 +49,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -59,7 +56,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.attn.hook_z",
             "hook_layer": 1,
             "d_in": 16 * 4,
@@ -67,7 +63,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.attn.hook_q",
             "hook_layer": 1,
             "d_in": 16 * 4,
diff --git a/tests/unit/training/test_activations_store.py b/tests/unit/training/test_activations_store.py
index 20598d31..9a7ae432 100644
--- a/tests/unit/training/test_activations_store.py
+++ b/tests/unit/training/test_activations_store.py
@@ -28,7 +28,6 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -37,7 +36,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.attn.hook_z",
             "hook_layer": 1,
             "d_in": 64,
@@ -45,7 +43,6 @@
         {
             "model_name": "gelu-2l",
             "dataset_path": "NeelNanda/c4-tokenized-2b",
-            "tokenized": True,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 512,
@@ -54,7 +51,6 @@
         {
             "model_name": "gpt2",
             "dataset_path": "apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
-            "tokenized": True,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 768,
@@ -63,7 +59,6 @@
         {
             "model_name": "gpt2",
             "dataset_path": "Skylion007/openwebtext",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 768,
diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py
index e4cc461a..ca0f154a 100644
--- a/tests/unit/training/test_config.py
+++ b/tests/unit/training/test_config.py
@@ -79,3 +79,15 @@ def test_sae_training_runner_config_raises_error_if_resume_true():
     with pytest.raises(ValueError):
         _ = LanguageModelSAERunnerConfig(resume=True)
     assert True
+
+
+def test_sae_training_runner_config_raises_error_if_d_sae_and_expansion_factor_not_none():
+    with pytest.raises(ValueError):
+        _ = LanguageModelSAERunnerConfig(d_sae=128, expansion_factor=4)
+    assert True
+
+
+def test_sae_training_runner_config_expansion_factor():
+    cfg = LanguageModelSAERunnerConfig()
+
+    assert cfg.expansion_factor == 4
diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py
index 60dfaddb..6264b2ba 100644
--- a/tests/unit/training/test_sae_basic.py
+++ b/tests/unit/training/test_sae_basic.py
@@ -16,7 +16,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -24,7 +23,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -34,7 +32,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -42,7 +39,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.attn.hook_z",
             "hook_layer": 1,
             "d_in": 64,
@@ -202,7 +198,9 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:
 
 def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
     cfg = build_sae_cfg(
-        activation_fn_str="topk", activation_fn_kwargs={"k": 30}, device="cpu"
+        # activation_fn_str="topk", activation_fn_kwargs={"k": 30}, device="cpu"
+        activation_fn_kwargs={"k": 30},
+        device="cpu",
     )
     model_path = str(tmp_path)
     sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
diff --git a/tests/unit/training/test_sae_initialization.py b/tests/unit/training/test_sae_initialization.py
index fe73ee23..e0db877d 100644
--- a/tests/unit/training/test_sae_initialization.py
+++ b/tests/unit/training/test_sae_initialization.py
@@ -65,7 +65,7 @@ def test_SparseAutoencoder_initialization_gated():
 
 
 def test_SparseAutoencoder_initialization_orthogonal_enc_dec():
-    cfg = build_sae_cfg(decoder_orthogonal_init=True)
+    cfg = build_sae_cfg(decoder_orthogonal_init=True, expansion_factor=2)
     sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
 
     projections = sae.W_dec.T @ sae.W_dec
diff --git a/tests/unit/training/test_sae_trainer.py b/tests/unit/training/test_sae_trainer.py
index 22c6f0a5..a39cb93e 100644
--- a/tests/unit/training/test_sae_trainer.py
+++ b/tests/unit/training/test_sae_trainer.py
@@ -204,7 +204,6 @@ def test_train_sae_group_on_language_model__runs(
     checkpoint_dir = tmp_path / "checkpoint"
     cfg = build_sae_cfg(
         checkpoint_path=str(checkpoint_dir),
-        train_batch_size=32,
         training_tokens=100,
         context_size=8,
     )
diff --git a/tests/unit/training/test_sae_training.py b/tests/unit/training/test_sae_training.py
index c334da4d..692a78b3 100644
--- a/tests/unit/training/test_sae_training.py
+++ b/tests/unit/training/test_sae_training.py
@@ -19,7 +19,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -27,7 +26,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -37,7 +35,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -45,7 +42,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
-            "tokenized": False,
             "hook_name": "blocks.1.hook_resid_pre",
             "hook_layer": 1,
             "d_in": 64,
@@ -54,7 +50,6 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "tokenized": False,
             "hook_name": "blocks.1.attn.hook_z",
             "hook_layer": 1,
             "d_in": 64,

From d77c0ca8505b8ec3136fb2b70fb7a8dea7ca09e8 Mon Sep 17 00:00:00 2001
From: anthonyduong
Date: Tue, 24 Sep 2024 13:26:39 -0700
Subject: [PATCH 2/3] renames class

---
 tests/unit/helpers.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py
index 1ddb76a1..78d04054 100644
--- a/tests/unit/helpers.py
+++ b/tests/unit/helpers.py
@@ -9,7 +9,7 @@
 TINYSTORIES_DATASET = "roneneldan/TinyStories"
 
 
-class ConfigKwargsType(TypedDict, total=False):
+class LanguageModelSAERunnerConfigDict(TypedDict, total=False):
     model_name: str
     hook_name: str
     hook_layer: int
@@ -45,7 +45,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
     """
     Helper to create a mock instance of LanguageModelSAERunnerConfig.
     """
-    config_kwargs: ConfigKwargsType = {
+    mock_config_dict: LanguageModelSAERunnerConfigDict = {
         "model_name": TINYSTORIES_MODEL,
         "hook_name": "blocks.0.hook_mlp_out",
         "hook_layer": 0,
@@ -78,9 +78,9 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
     }
 
     for key, value in kwargs.items():
-        config_kwargs[key] = value
+        mock_config_dict[key] = value
 
-    mock_config = LanguageModelSAERunnerConfig(**config_kwargs)
+    mock_config = LanguageModelSAERunnerConfig(**mock_config_dict)
 
     # reset checkpoint path (as we add an id to each each time)
     mock_config.checkpoint_path = (

From 3d1b263a0a9d776b076dcbedce5f4cf2d4edd881 Mon Sep 17 00:00:00 2001
From: anthonyduong
Date: Tue, 24 Sep 2024 13:32:36 -0700
Subject: [PATCH 3/3] removes commented out line

---
 tests/unit/training/test_sae_basic.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py
index 6264b2ba..55428bb9 100644
--- a/tests/unit/training/test_sae_basic.py
+++ b/tests/unit/training/test_sae_basic.py
@@ -198,7 +198,6 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:
 
 def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
     cfg = build_sae_cfg(
-        # activation_fn_str="topk", activation_fn_kwargs={"k": 30}, device="cpu"
        activation_fn_kwargs={"k": 30},
         device="cpu",
     )
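Not part of the patch series above: a minimal sketch of the configuration behaviour the new __post_init__ checks in PATCH 1/3 are meant to enforce, assuming LanguageModelSAERunnerConfig is importable from sae_lens.config as in the diffs.

from sae_lens.config import LanguageModelSAERunnerConfig

# Setting both d_sae and expansion_factor is rejected after this change.
try:
    LanguageModelSAERunnerConfig(d_sae=128, expansion_factor=4)
except ValueError as err:
    print(err)  # You can't set both d_sae and expansion_factor.

# With neither set, expansion_factor falls back to 4, so d_sae = d_in * 4.
cfg = LanguageModelSAERunnerConfig()
print(cfg.expansion_factor)  # 4
print(cfg.d_sae == cfg.d_in * 4)  # True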