diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 8bca4451..2e60c24f 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -255,6 +255,13 @@ def decode( return sae_out + @torch.no_grad() + def fold_W_dec_norm(self): + W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1) + self.W_dec.data = self.W_dec.data / W_dec_norms + self.W_enc.data = self.W_enc.data * W_dec_norms.T + self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() + def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None): if not os.path.exists(path): diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index a8a5274d..72fadcff 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -1,4 +1,5 @@ import os +from copy import deepcopy from pathlib import Path import pytest @@ -73,6 +74,41 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig): assert sae.b_dec.shape == (cfg.d_in,) +def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): + sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + assert sae.W_dec.norm(dim=-1).mean().item() != pytest.approx(1.0, abs=1e-6) + sae2 = deepcopy(sae) + sae2.fold_W_dec_norm() + + W_dec_norms = sae.W_dec.norm(dim=-1).unsqueeze(1) + assert torch.allclose(sae2.W_dec.data, sae.W_dec.data / W_dec_norms) + assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * W_dec_norms.T) + assert torch.allclose(sae2.b_enc.data, sae.b_enc.data * W_dec_norms.squeeze()) + + # fold_W_dec_norm should normalize W_dec to have unit norm. + assert sae2.W_dec.norm(dim=-1).mean().item() == pytest.approx(1.0, abs=1e-6) + + # we expect activations of features to differ by W_dec norm weights. + activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) + feature_activations_1 = sae.encode(activations) + feature_activations_2 = sae2.encode(activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=-1) + torch.testing.assert_close(feature_activations_2, expected_feature_activations_2) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + + def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: cfg = build_sae_cfg(device="cpu") model_path = str(tmp_path)