feat: Adding Mistral SAEs (#178)
Note: normalize_activations is now a string and should be one of 'none', 'expected_average_only_in' (Anthropic April update, not yet folded), or 'constant_norm_rescale' (Anthropic February update).
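For reference, a minimal sketch of how the new string-valued option is passed to a training config, mirroring the docs/training_saes.md change in this commit (the import path and the validity of library defaults for all other fields are assumptions):

```python
from sae_lens import LanguageModelSAERunnerConfig  # import path assumed

# Choose one of the three strategies; per this commit, any other string
# raises a ValueError in __post_init__.
cfg = LanguageModelSAERunnerConfig(
    normalize_activations="expected_average_only_in",  # or "none" / "constant_norm_rescale"
)
```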

* Adding code to load mistral saes

* Black formatting

* Removing library changes that allowed forward pass normalization

* feat: support feb update style norm scaling for mistral saes

* remove accidental inclusion

---------
Co-authored-by: jbloomAus <[email protected]>
JoshEngels and jbloomAus authored Jun 14, 2024
1 parent 209696a commit 227d208
Showing 29 changed files with 339 additions and 68 deletions.
2 changes: 1 addition & 1 deletion docs/training_saes.md
@@ -52,7 +52,7 @@ cfg = LanguageModelSAERunnerConfig(
scale_sparsity_penalty_by_decoder_norm=True,
decoder_heuristic_init=True,
init_encoder_as_decoder_transpose=True,
normalize_activations=True,
normalize_activations="expected_average_only_in",
# Training Parameters
lr=5e-5,
adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -27,6 +27,7 @@ automated-interpretability = "^0.0.3"
python-dotenv = "^1.0.1"
pyyaml = "^6.0.1"
pytest-profiling = "^1.7.0"
zstandard = "^0.22.0"


[tool.poetry.group.dev.dependencies]
@@ -70,7 +71,9 @@ reportConstantRedefinition = "none"
reportUnknownLambdaType = "none"
reportPrivateUsage = "none"
reportDeprecated = "none"

ignore = [
"**/wandb/**"
]

[build-system]
requires = ["poetry-core"]
17 changes: 14 additions & 3 deletions sae_lens/config.py
@@ -55,7 +55,7 @@ class LanguageModelSAERunnerConfig:
finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
normalize_activations (bool): Whether to normalize activations. See Anthropic April update.
normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
device (str): The device to use. Usually cuda.
act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
seed (int): The seed to use.
@@ -141,7 +141,9 @@ class LanguageModelSAERunnerConfig:
finetuning_tokens: int = 0
store_batch_size_prompts: int = 32
train_batch_size_tokens: int = 4096
normalize_activations: bool = False
normalize_activations: str = (
"none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
)

# Misc
device: str = "cpu"
@@ -265,6 +267,15 @@ def __post_init__(self):
"If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
)

if self.normalize_activations not in [
"none",
"expected_average_only_in",
"constant_norm_rescale",
]:
raise ValueError(
f"normalize_activations must be none, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
)

if self.act_store_device == "with_model":
self.act_store_device = self.device

@@ -425,7 +436,7 @@ class CacheActivationsRunnerConfig:
training_tokens: int = 2_000_000
store_batch_size_prompts: int = 32
train_batch_size_tokens: int = 4096
normalize_activations: bool = False
normalize_activations: str = "none" # should always be none for activation caching

# Misc
device: str = "cpu"
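As a standalone illustration of the "constant_norm_rescale" strategy described in the normalize_activations docstring above, here is a sketch of the same arithmetic as the run_time_activation_norm_fn_in/out pair added to sae_lens/sae.py below (not the library code itself):

```python
import torch

d_in = 4
x = torch.randn(2, d_in)

# On the way in: rescale each activation vector to norm sqrt(d_in),
# remembering the per-row coefficient.
coeff = (d_in**0.5) / x.norm(dim=-1, keepdim=True)
x_scaled = x * coeff

# ... the SAE would run on x_scaled here ...
sae_out = x_scaled  # stand-in for sae.decode(sae.encode(x_scaled))

# On the way out: undo the rescaling so the output is back in the original scale.
recovered = sae_out / coeff
assert torch.allclose(recovered, x, atol=1e-5)
```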
14 changes: 7 additions & 7 deletions sae_lens/evals.py
@@ -65,7 +65,7 @@ def run_evals(
original_act = cache[hook_name]

# normalise if necessary
if activation_store.normalize_activations:
if activation_store.normalize_activations == "expected_average_only_in":
original_act = activation_store.apply_norm_scaling_factor(original_act)

# send the (maybe normalised) activations into the SAE
@@ -148,20 +148,20 @@ def get_recons_loss(
# TODO(tomMcGrath): the rescaling below is a bit of a hack and could probably be tidied up
def standard_replacement_hook(activations: torch.Tensor, hook: Any):
# Handle rescaling if SAE expects it
if activation_store.normalize_activations:
if activation_store.normalize_activations == "expected_average_only_in":
activations = activation_store.apply_norm_scaling_factor(activations)

# SAE class agnostic forward pass.
activations = sae.decode(sae.encode(activations)).to(activations.dtype)

# Unscale if activations were scaled prior to going into the SAE
if activation_store.normalize_activations:
if activation_store.normalize_activations == "expected_average_only_in":
activations = activation_store.unscale(activations)
return activations

def all_head_replacement_hook(activations: torch.Tensor, hook: Any):
# Handle rescaling if SAE expects it
if activation_store.normalize_activations:
if activation_store.normalize_activations == "expected_average_only_in":
activations = activation_store.apply_norm_scaling_factor(activations)

# SAE class agnostic forward pass.
@@ -174,14 +174,14 @@ def all_head_replacement_hook(activations: torch.Tensor, hook: Any):
) # reshape to match original shape

# Unscale if activations were scaled prior to going into the SAE
if activation_store.normalize_activations:
if activation_store.normalize_activations == "expected_average_only_in":
new_activations = activation_store.unscale(new_activations)

return new_activations

def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
# Handle rescaling if SAE expects it
if activation_store.normalize_activations:
if activation_store.normalize_activations == "expected_average_only_in":
activations = activation_store.apply_norm_scaling_factor(activations)

new_activations = sae.decoder(sae.encode(activations[:, :, head_index])).to(
@@ -190,7 +190,7 @@ def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
activations[:, :, head_index] = new_activations

# Unscale if activations were scaled prior to going into the SAE
if activation_store.normalize_activations:
if activation_store.normalize_activations == "expected_average_only_in":
activations = activation_store.unscale(activations)
return activations

17 changes: 17 additions & 0 deletions sae_lens/pretrained_saes.yaml
@@ -179,3 +179,20 @@ SAE_LOOKUP:
path: "gemma_2b_blocks.12.hook_resid_post_16384"
variance_explained: -3.6
l0: 62.0
mistral-7b-res-wg:
repo_id: "JoshEngels/Mistral-7B-Residual-Stream-SAEs"
model: "mistral-7b"
conversion_func: "mistral_7b_josh_engels_loader"
saes:
- id: "blocks.8.hook_resid_pre"
path: "mistral_7b_layer_8"
variance_explained: 0.74
l0: 82
- id: "blocks.16.hook_resid_pre"
path: "mistral_7b_layer_16"
variance_explained: 0.85
l0: 74
- id: "blocks.24.hook_resid_pre"
path: "mistral_7b_layer_24"
variance_explained: 0.72
l0: 75
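With this registry entry, the Mistral SAEs should be loadable by release and id through the standard lookup, which routes through mistral_7b_josh_engels_loader. A hedged sketch, assuming the SAE.from_pretrained interface of this version of the library:

```python
import torch
from sae_lens import SAE  # import path assumed

device = "cuda" if torch.cuda.is_available() else "cpu"

# release / sae_id match the pretrained_saes.yaml entry above; the
# (sae, cfg_dict, sparsity) return signature is assumed.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="mistral-7b-res-wg",
    sae_id="blocks.8.hook_resid_pre",
    device=device,
)
print(cfg_dict["normalize_activations"])  # expected: "constant_norm_rescale"
```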
38 changes: 36 additions & 2 deletions sae_lens/sae.py
@@ -44,7 +44,7 @@ class SAEConfig:
hook_head_index: Optional[int]
prepend_bos: bool
dataset_path: str
normalize_activations: bool
normalize_activations: str

# misc
dtype: str
@@ -65,7 +65,9 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":

# use only config terms that are in the dataclass
config_dict = {
k: v for k, v in config_dict.items() if k in cls.__dataclass_fields__ # type: ignore
k: v
for k, v in config_dict.items()
if k in cls.__dataclass_fields__ # pylint: disable=no-member
}
return cls(**config_dict)

@@ -146,6 +148,27 @@ def __init__(
# need to default the reshape fns
self.turn_off_forward_pass_hook_z_reshaping()

# handle run time activation normalization if needed:
if self.cfg.normalize_activations == "constant_norm_rescale":

# we need to scale the norm of the input and store the scaling factor
def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
x = x * self.x_norm_coeff
return x

def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
x = x / self.x_norm_coeff
del self.x_norm_coeff # prevents reusing
return x

self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

else:
self.run_time_activation_norm_fn_in = lambda x: x
self.run_time_activation_norm_fn_out = lambda x: x

self.setup() # Required for `HookedRootModule`s

def initialize_weights_basic(self):
@@ -206,6 +229,9 @@ def forward(
# handle hook z reshaping if needed.
sae_in = self.reshape_fn_in(x) # type: ignore

# handle run time activation normalization if needed
sae_in = self.run_time_activation_norm_fn_in(sae_in)

# apply b_dec_to_input if using that method.
sae_in_cent = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)

@@ -218,6 +244,7 @@
d_head=self.d_head,
)

sae_out = self.run_time_activation_norm_fn_out(sae_out)
sae_error = self.hook_sae_error(x - x_reconstruct_clean)

return self.hook_sae_output(sae_out + sae_error)
@@ -237,6 +264,9 @@ def encode(
# handle hook z reshaping if needed.
x = self.reshape_fn_in(x) # type: ignore

# handle run time activation normalization if needed
x = self.run_time_activation_norm_fn_in(x)

# apply b_dec_to_input if using that method.
sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

@@ -255,6 +285,10 @@ def decode(
self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
)

# handle run time activation normalization if needed
# will fail if you call this twice without calling encode in between.
sae_out = self.run_time_activation_norm_fn_out(sae_out)

# handle hook z reshaping if needed.
sae_out = self.reshape_fn_out(sae_out, self.d_head) # type: ignore

31 changes: 29 additions & 2 deletions sae_lens/toolkit/pretrained_sae_loaders.py
@@ -47,6 +47,7 @@ def sae_lens_loader(
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

# TODO: don't call a function at the end of another like this. Poor form.
return load_pretrained_sae_lens_sae_components(
cfg_path, sae_path, device, log_sparsity_path=log_sparsity_path
)
@@ -138,12 +139,26 @@ def connor_rob_hook_z_loader(
"prepend_bos": True,
"dataset_path": "apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
"context_size": 128,
"normalize_activations": False,
"normalize_activations": "none",
}

return cfg_dict, weights, None


def mistral_7b_josh_engels_loader(
repo_id: str,
folder_name: str,
device: Optional[str] = None,
force_download: bool = False,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], Optional[torch.Tensor]]:

cfg_dict, state_dict, log_sparsity = sae_lens_loader(
repo_id, folder_name, device, force_download
)
cfg_dict["normalize_activations"] = "constant_norm_rescale"
return cfg_dict, state_dict, log_sparsity


def load_pretrained_sae_lens_sae_components(
cfg_path: str,
weight_path: str,
@@ -197,8 +212,19 @@ def load_pretrained_sae_lens_sae_components(
if "activation_fn" not in cfg_dict:
cfg_dict["activation_fn_str"] = "relu"

# if missing then none.
if "normalize_activations" not in cfg_dict:
cfg_dict["normalize_activations"] = False
cfg_dict["normalize_activations"] = "none"
# if bool and True, then it's the April update method of normalizing activations and hasn't been folded in.
if "normalize_activations" in cfg_dict and isinstance(
cfg_dict["normalize_activations"], bool
):
# backwards compatibility
cfg_dict["normalize_activations"] = (
"none"
if not cfg_dict["normalize_activations"]
else "expected_average_only_in"
)

if "scaling_factor" in state_dict:
# we were adding it anyway for a period of time but are no longer doing so.
@@ -227,4 +253,5 @@
NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeLoader] = {
"sae_lens": sae_lens_loader, # type: ignore
"connor_rob_hook_z": connor_rob_hook_z_loader, # type: ignore
"mistral_7b_josh_engels_loader": mistral_7b_josh_engels_loader, # type: ignore
}
4 changes: 2 additions & 2 deletions sae_lens/training/activations_store.py
@@ -130,7 +130,7 @@ def __init__(
store_batch_size_prompts: int,
train_batch_size_tokens: int,
prepend_bos: bool,
normalize_activations: bool,
normalize_activations: str,
device: torch.device,
dtype: str,
cached_activations_path: str | None = None,
@@ -453,7 +453,7 @@ def get_buffer(self, n_batches_in_buffer: int) -> torch.Tensor:
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

# every buffer should be normalized:
if self.normalize_activations:
if self.normalize_activations == "expected_average_only_in":
new_buffer = self.apply_norm_scaling_factor(new_buffer)

return new_buffer
2 changes: 1 addition & 1 deletion sae_lens/training/sae_trainer.py
@@ -184,7 +184,7 @@ def fit(self) -> TrainingSAE:

@torch.no_grad()
def _estimate_norm_scaling_factor_if_needed(self) -> None:
if self.cfg.normalize_activations:
if self.cfg.normalize_activations == "expected_average_only_in":
self.activation_store.estimated_norm_scaling_factor = (
self.activation_store.estimate_norm_scaling_factor()
)
3 changes: 3 additions & 0 deletions sae_lens/training/training_sae.py
@@ -175,6 +175,9 @@ def encode_with_hidden_pre(
# apply b_dec_to_input if using that method.
sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))

# handle run time activation normalization if needed
x = self.run_time_activation_norm_fn_in(x)

# "... d_in, d_in d_sae -> ... d_sae",
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
hidden_pre_noised = hidden_pre + (
2 changes: 1 addition & 1 deletion scripts/ansible/configs_example/cache_acts.yml
@@ -23,7 +23,7 @@ prepend_bos: true
train_batch_size: 4096
n_batches_in_buffer: 4
store_batch_size: 128
normalize_activations: false
normalize_activations: none
shuffle_every_n_buffers: 8
n_shuffles_with_last_section: 1
n_shuffles_in_entire_dir: 1
2 changes: 1 addition & 1 deletion scripts/ansible/configs_example/train_sae/sweep_common.yml
@@ -76,7 +76,7 @@ adam_beta2: 0.999
# Buffer details won't matter if we cache / shuffle our activations ahead of time.
n_batches_in_buffer: 64
store_batch_size: 16
normalize_activations: False
normalize_activations: none
# Feature Store
feature_sampling_window: 1000
dead_feature_window: 1000
2 changes: 1 addition & 1 deletion scripts/caching_replication_how_train_saes.py
@@ -71,7 +71,7 @@
# buffer details
n_batches_in_buffer=4,
store_batch_size_prompts=128,
normalize_activations=False,
normalize_activations="none",
#
shuffle_every_n_buffers=8,
n_shuffles_with_last_section=1,
2 changes: 1 addition & 1 deletion scripts/replication_how_train_saes.py
@@ -93,7 +93,7 @@
# Buffer details won't matter if we cache / shuffle our activations ahead of time.
n_batches_in_buffer=64,
store_batch_size_prompts=16,
normalize_activations=False,
normalize_activations="none",
# Feature Store
feature_sampling_window=1000,
dead_feature_window=1000,
2 changes: 1 addition & 1 deletion scripts/replication_how_train_saes_control.py
@@ -93,7 +93,7 @@
# Buffer details won't matter if we cache / shuffle our activations ahead of time.
n_batches_in_buffer=64,
store_batch_size_prompts=16,
normalize_activations=False,
normalize_activations="none",
# Feature Store
feature_sampling_window=1000,
dead_feature_window=1000,
2 changes: 1 addition & 1 deletion scripts/sweep-gpt2-blocks.py
@@ -82,7 +82,7 @@
# Unsure if this is enough
n_batches_in_buffer=64,
store_batch_size_prompts=32,
normalize_activations=True,
normalize_activations="none",
# Feature Store
feature_sampling_window=1000,
dead_feature_window=1000,