feat: add w_dec_norm folding (#167)
* feat: add w_dec_norm folding

* format
jbloomAus authored May 29, 2024
1 parent 4850b16 commit f1908a3
Showing 2 changed files with 43 additions and 0 deletions.
7 changes: 7 additions & 0 deletions sae_lens/sae.py
@@ -255,6 +255,13 @@ def decode(

        return sae_out

    @torch.no_grad()
    def fold_W_dec_norm(self):
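        # Divide each decoder row by its norm and compensate by scaling the
        # matching encoder column and bias entry, preserving the SAE's outputs.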
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T
        self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):

        if not os.path.exists(path):
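Why this transformation is output-preserving: for a ReLU SAE, encode computes relu(x @ W_enc + b_enc) and decode computes f @ W_dec + b_dec. Dividing decoder row i by its norm while multiplying encoder column i and b_enc[i] by that same positive norm scales feature i's pre-activation by the norm; ReLU commutes with positive scaling, and the rescaled decoder row divides it back out, so reconstructions are unchanged. A minimal self-contained sketch of the same algebra on bare tensors (the plain ReLU forward pass and all names here are illustrative assumptions, not the library's code):

import torch

torch.manual_seed(0)
d_in, d_sae = 8, 16
W_enc, b_enc = torch.randn(d_in, d_sae), torch.randn(d_sae)
W_dec, b_dec = torch.randn(d_sae, d_in), torch.randn(d_in)

def forward(W_enc, b_enc, W_dec, b_dec, x):
    f = torch.relu(x @ W_enc + b_enc)  # feature activations
    return f @ W_dec + b_dec  # reconstruction

x = torch.randn(4, d_in)
out_before = forward(W_enc, b_enc, W_dec, b_dec, x)

# Fold: divide each decoder row by its norm, multiply the matching
# encoder column and bias entry by the same norm (mirrors fold_W_dec_norm).
norms = W_dec.norm(dim=-1, keepdim=True)  # (d_sae, 1)
out_after = forward(
    W_enc * norms.T, b_enc * norms.squeeze(), W_dec / norms, b_dec, x
)

# Reconstructions agree; only the per-feature activation scale has moved.
torch.testing.assert_close(out_before, out_after)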
36 changes: 36 additions & 0 deletions tests/unit/training/test_sae_basic.py
@@ -1,4 +1,5 @@
import os
from copy import deepcopy
from pathlib import Path

import pytest
@@ -73,6 +74,41 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig):
    assert sae.b_dec.shape == (cfg.d_in,)


def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig):
    sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
    sae.turn_off_forward_pass_hook_z_reshaping()  # hook z reshaping not needed here.
    assert sae.W_dec.norm(dim=-1).mean().item() != pytest.approx(1.0, abs=1e-6)
    sae2 = deepcopy(sae)
    sae2.fold_W_dec_norm()

    W_dec_norms = sae.W_dec.norm(dim=-1).unsqueeze(1)
    assert torch.allclose(sae2.W_dec.data, sae.W_dec.data / W_dec_norms)
    assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * W_dec_norms.T)
    assert torch.allclose(sae2.b_enc.data, sae.b_enc.data * W_dec_norms.squeeze())

    # fold_W_dec_norm should normalize W_dec to have unit norm.
    assert sae2.W_dec.norm(dim=-1).mean().item() == pytest.approx(1.0, abs=1e-6)

    # we expect activations of features to differ by W_dec norm weights.
    activations = torch.randn(10, 4, cfg.d_in, device=cfg.device)
    feature_activations_1 = sae.encode(activations)
    feature_activations_2 = sae2.encode(activations)

    assert torch.allclose(
        feature_activations_1.nonzero(),
        feature_activations_2.nonzero(),
    )

    expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=-1)
    torch.testing.assert_close(feature_activations_2, expected_feature_activations_2)

    sae_out_1 = sae.decode(feature_activations_1)
    sae_out_2 = sae2.decode(feature_activations_2)

    # but actual outputs should be the same
    torch.testing.assert_close(sae_out_1, sae_out_2)


def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None:
    cfg = build_sae_cfg(device="cpu")
    model_path = str(tmp_path)
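In practice, folding is a one-time, in-place step applied to a trained SAE before analysis. A hedged usage sketch reusing the construction from the test above (whether you build from a config dict, as here, or load pretrained weights depends on your workflow):

sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
sae.fold_W_dec_norm()  # in-place; every W_dec row now has unit norm
# Feature activations are rescaled by the old decoder row norms, but
# decode(encode(x)) is unchanged, so downstream reconstructions match.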
