From 2476afbffad41406840ebd5492c04acf90a0e62c Mon Sep 17 00:00:00 2001 From: Josh Engels Date: Wed, 26 Jun 2024 17:33:13 -0400 Subject: [PATCH 1/7] First round of evals --- sae_lens/evals.py | 266 +++++++++++++++++-------- sae_lens/training/activations_store.py | 33 ++- sae_lens/training/sae_trainer.py | 10 + tests/run_gpt2_evals.py | 77 +++++++ 4 files changed, 300 insertions(+), 86 deletions(-) create mode 100644 tests/run_gpt2_evals.py diff --git a/sae_lens/evals.py b/sae_lens/evals.py index e4df54ec..571686a9 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -1,7 +1,7 @@ from functools import partial -from typing import Any, Mapping, cast +from typing import Any, Mapping, Tuple -import pandas as pd +import einops import torch from transformer_lens.hook_points import HookedRootModule @@ -17,12 +17,12 @@ def run_evals( n_eval_batches: int = 10, eval_batch_size_prompts: int | None = None, model_kwargs: Mapping[str, Any] = {}, -) -> Mapping[str, Any]: +) -> dict[str, Any]: hook_name = sae.cfg.hook_name - hook_head_index = sae.cfg.hook_head_index - ### Evals - eval_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts) + actual_batch_size = ( + eval_batch_size_prompts or activation_store.store_batch_size_prompts + ) # TODO: Come up with a cleaner long term strategy here for SAEs that do reshaping. # turn off hook_z reshaping mode if it's on, and restore it after evals @@ -32,65 +32,24 @@ def run_evals( else: previous_hook_z_reshaping_mode = None - # Get Reconstruction Score - losses_df = recons_loss_batched( + metrics = get_downstream_reconstruction_metrics( sae, model, activation_store, n_batches=n_eval_batches, - eval_batch_size_prompts=eval_batch_size_prompts, + eval_batch_size_prompts=actual_batch_size, ) - recons_score = losses_df["score"].mean() - ntp_loss = losses_df["loss"].mean() - recons_loss = losses_df["recons_loss"].mean() - zero_abl_loss = losses_df["zero_abl_loss"].mean() - - # get cache - _, cache = model.run_with_cache( - eval_tokens, - prepend_bos=False, - names_filter=[hook_name], - **model_kwargs, - ) + activation_store.reset_input_dataset() - # we would include hook z, except that we now have base SAE's - # which will do their own reshaping for hook z. 
- has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"] - if hook_head_index is not None: - original_act = cache[hook_name][:, :, hook_head_index] - elif any(substring in hook_name for substring in has_head_dim_key_substrings): - original_act = cache[hook_name].flatten(-2, -1) - else: - original_act = cache[hook_name] - - # normalise if necessary - if activation_store.normalize_activations == "expected_average_only_in": - original_act = activation_store.apply_norm_scaling_factor(original_act) - - # send the (maybe normalised) activations into the SAE - sae_out = sae.decode(sae.encode(original_act.to(sae.device))).to( - original_act.device + metrics |= get_sparsity_and_variance_metrics( + sae, + model, + activation_store, + n_batches=n_eval_batches, + eval_batch_size_prompts=actual_batch_size, + model_kwargs=model_kwargs, ) - del cache - - l2_norm_in = torch.norm(original_act, dim=-1) - l2_norm_out = torch.norm(sae_out, dim=-1) - l2_norm_in_for_div = l2_norm_in.clone() - l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1 - l2_norm_ratio = l2_norm_out / l2_norm_in_for_div - - metrics = { - # l2 norms - "metrics/l2_norm": l2_norm_out.mean().item(), - "metrics/l2_ratio": l2_norm_ratio.mean().item(), - "metrics/l2_norm_in": l2_norm_in.mean().item(), - # CE Loss - "metrics/CE_loss_score": recons_score, - "metrics/ce_loss_without_sae": ntp_loss, - "metrics/ce_loss_with_sae": recons_loss, - "metrics/ce_loss_with_ablation": zero_abl_loss, - } # restore previous hook z reshaping mode if necessary if "hook_z" in hook_name: @@ -99,39 +58,167 @@ def run_evals( elif not previous_hook_z_reshaping_mode and sae.hook_z_reshaping_mode: sae.turn_off_forward_pass_hook_z_reshaping() + total_tokens_evaluated = ( + activation_store.context_size * n_eval_batches * actual_batch_size + ) + metrics["metrics/total_tokens_evaluated"] = total_tokens_evaluated + return metrics -def recons_loss_batched( +def get_downstream_reconstruction_metrics( sae: SAE, model: HookedRootModule, activation_store: ActivationsStore, - n_batches: int = 100, - eval_batch_size_prompts: int | None = None, + n_batches: int, + eval_batch_size_prompts: int, ): - losses = [] + metrics = [] for _ in range(n_batches): batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts) - score, loss, recons_loss, zero_abl_loss = get_recons_loss( + ( + recons_kl_div, + zero_abl_kl_div, + original_ce_loss, + recons_ce_loss, + zero_abl_ce_loss, + ) = get_recons_loss( sae, model, batch_tokens, activation_store, ) - losses.append( + + metrics.append( ( - score.mean().item(), - loss.mean().item(), - recons_loss.mean().item(), - zero_abl_loss.mean().item(), + recons_kl_div, + zero_abl_kl_div, + original_ce_loss, + recons_ce_loss, + zero_abl_ce_loss, ) ) - losses = pd.DataFrame( - losses, columns=cast(Any, ["score", "loss", "recons_loss", "zero_abl_loss"]) + recons_kl_div = torch.stack([metric[0] for metric in metrics]).mean() + zero_abl_kl_div = torch.stack([metric[1] for metric in metrics]).mean() + kl_div_score = (zero_abl_kl_div - recons_kl_div) / zero_abl_kl_div + + zero_abl_ce_loss = torch.stack([metric[4] for metric in metrics]).mean() + recons_ce_loss = torch.stack([metric[3] for metric in metrics]).mean() + original_ce_loss = torch.stack([metric[2] for metric in metrics]).mean() + ce_loss_score = (zero_abl_ce_loss - recons_ce_loss) / ( + zero_abl_ce_loss - original_ce_loss ) - return losses + metrics = { + "metrics/ce_loss_score": ce_loss_score.item(), + "metrics/ce_loss_without_sae": original_ce_loss.item(), + 
"metrics/ce_loss_with_sae": recons_ce_loss.item(), + "metrics/ce_loss_with_ablation": zero_abl_ce_loss.item(), + "metrics/kl_div_score": kl_div_score.item(), + "metrics/kl_div_without_sae": 0, + "metrics/kl_div_with_sae": recons_kl_div.item(), + "metrics/kl_div_with_ablation": zero_abl_kl_div.item(), + } + + return metrics + + +def get_sparsity_and_variance_metrics( + sae: SAE, + model: HookedRootModule, + activation_store: ActivationsStore, + n_batches: int, + eval_batch_size_prompts: int, + model_kwargs: Mapping[str, Any], +): + + metrics_list = [] + + hook_name = sae.cfg.hook_name + hook_head_index = sae.cfg.hook_head_index + + for _ in range(n_batches): + batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts) + + # get cache + _, cache = model.run_with_cache( + batch_tokens, + prepend_bos=False, + names_filter=[hook_name], + **model_kwargs, + ) + + # we would include hook z, except that we now have base SAE's + # which will do their own reshaping for hook z. + has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"] + if hook_head_index is not None: + original_act = cache[hook_name][:, :, hook_head_index] + elif any(substring in hook_name for substring in has_head_dim_key_substrings): + original_act = cache[hook_name].flatten(-2, -1) + else: + original_act = cache[hook_name] + + # normalise if necessary + if activation_store.normalize_activations == "expected_average_only_in": + original_act = activation_store.apply_norm_scaling_factor(original_act) + + # send the (maybe normalised) activations into the SAE + sae_feature_activations = sae.encode(original_act.to(sae.device)) + sae_out = sae.decode(sae_feature_activations).to(original_act.device) + del cache + + flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d") + flattened_sae_feature_acts = einops.rearrange( + sae_feature_activations, "b ctx d -> (b ctx) d" + ) + flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d") + + l2_norm_in = torch.norm(flattened_sae_input, dim=-1) + l2_norm_out = torch.norm(flattened_sae_out, dim=-1) + l2_norm_in_for_div = l2_norm_in.clone() + l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1 + l2_norm_ratio = l2_norm_out / l2_norm_in_for_div + + l0 = (flattened_sae_feature_acts > 0).sum(dim=-1) + l1 = flattened_sae_feature_acts.sum(dim=-1) + resid_sum_of_squares = ( + (flattened_sae_input - flattened_sae_out).pow(2).sum(dim=-1) + ) + total_sum_of_squares = ( + (flattened_sae_input - flattened_sae_input.mean(dim=0)).pow(2).sum(-1) + ) + explained_variance = 1 - resid_sum_of_squares / total_sum_of_squares + + metrics_list.append( + ( + l2_norm_in, + l2_norm_out, + l2_norm_ratio, + explained_variance, + l0.float(), + l1, + resid_sum_of_squares, + ) + ) + + metrics: dict[str, float] = {} + for i, metric_name in enumerate( + [ + "l2_norm_in", + "l2_norm_out", + "l2_ratio", + "explained_variance", + "l0", + "l1", + "mse", + ] + ): + metrics[f"metrics/{metric_name}"] = ( + torch.stack([m[i] for m in metrics_list]).mean().item() + ) + + return metrics @torch.no_grad() @@ -141,11 +228,13 @@ def get_recons_loss( batch_tokens: torch.Tensor, activation_store: ActivationsStore, model_kwargs: Mapping[str, Any] = {}, -): +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hook_name = sae.cfg.hook_name head_index = sae.cfg.hook_head_index - loss = model(batch_tokens, return_type="loss", **model_kwargs) + original_logits, original_ce_loss = model( + batch_tokens, return_type="both", **model_kwargs + ) # 
TODO(tomMcGrath): the rescaling below is a bit of a hack and could probably be tidied up def standard_replacement_hook(activations: torch.Tensor, hook: Any): @@ -226,23 +315,36 @@ def zero_ablate_hook(activations: torch.Tensor, hook: Any): else: replacement_hook = standard_replacement_hook - recons_loss = model.run_with_hooks( + recons_logits, recons_ce_loss = model.run_with_hooks( batch_tokens, - return_type="loss", + return_type="both", fwd_hooks=[(hook_name, partial(replacement_hook))], **model_kwargs, ) - zero_abl_loss = model.run_with_hooks( + zero_abl_logits, zero_abl_ce_loss = model.run_with_hooks( batch_tokens, - return_type="loss", + return_type="both", fwd_hooks=[(hook_name, zero_ablate_hook)], **model_kwargs, ) - div_val = zero_abl_loss - loss - div_val[torch.abs(div_val) < 0.0001] = 1.0 - - score = (zero_abl_loss - recons_loss) / div_val - - return score, loss, recons_loss, zero_abl_loss + def compute_kl(original_logits: torch.Tensor, new_logits: torch.Tensor): + original_probs = torch.nn.functional.softmax(original_logits, dim=-1) + log_original_probs = torch.log(original_probs) + new_probs = torch.nn.functional.softmax(new_logits, dim=-1) + log_new_probs = torch.log(new_probs) + kl_div = original_probs * (log_original_probs - log_new_probs) + kl_div = kl_div.sum(dim=-1) + return kl_div + + recons_kl_div = compute_kl(original_logits, recons_logits) + zero_abl_kl_div = compute_kl(original_logits, zero_abl_logits) + + return ( + recons_kl_div, + zero_abl_kl_div, + original_ce_loss, + recons_ce_loss, + zero_abl_ce_loss, + ) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 79cfed46..71e2f630 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -39,7 +39,7 @@ class ActivationsStore: model: HookedRootModule dataset: HfDataset cached_activations_path: str | None - tokens_column: Literal["tokens", "input_ids", "text"] + tokens_column: Literal["tokens", "input_ids", "text", "problem"] hook_name: str hook_layer: int hook_head_index: int | None @@ -89,6 +89,8 @@ def from_sae( cls, model: HookedRootModule, sae: SAE, + context_size: int | None = None, + dataset: HfDataset | str | None = None, streaming: bool = True, store_batch_size_prompts: int = 8, n_batches_in_buffer: int = 8, @@ -99,12 +101,12 @@ def from_sae( return cls( model=model, - dataset=sae.cfg.dataset_path, + dataset=sae.cfg.dataset_path if dataset is None else dataset, d_in=sae.cfg.d_in, hook_name=sae.cfg.hook_name, hook_layer=sae.cfg.hook_layer, hook_head_index=sae.cfg.hook_head_index, - context_size=sae.cfg.context_size, + context_size=sae.cfg.context_size if context_size is None else context_size, prepend_bos=sae.cfg.prepend_bos, streaming=streaming, store_batch_size_prompts=store_batch_size_prompts, @@ -188,9 +190,12 @@ def __init__( elif "text" in dataset_sample.keys(): self.is_dataset_tokenized = False self.tokens_column = "text" + elif "problem" in dataset_sample.keys(): + self.is_dataset_tokenized = False + self.tokens_column = "problem" else: raise ValueError( - "Dataset must have a 'tokens', 'input_ids', or 'text' column." + "Dataset must have a 'tokens', 'input_ids', 'text', or 'problem' column." 
) self.iterable_dataset = iter(self.dataset) # Reset iterator after checking @@ -248,6 +253,26 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e return scaling_factor + def shuffle_input_dataset(self, seed: int, buffer_size: int = 1): + """ + This applies a shuffle to the huggingface dataset that is the input to the activations store. This + also shuffles the shards of the dataset, which is especially useful for evaluating on different + sections of very large streaming datasets. Buffer size is only relevant for streaming datasets. + The default buffer_size of 1 means that only the shard will be shuffled; larger buffer sizes will + additionally shuffle individual elements within the shard. + """ + if type(self.dataset) == IterableDataset: + self.dataset = self.dataset.shuffle(seed=seed, buffer_size=buffer_size) + else: + self.dataset = self.dataset.shuffle(seed=seed) + self.iterable_dataset = iter(self.dataset) + + def reset_input_dataset(self): + """ + Resets the input dataset iterator to the beginning. + """ + self.iterable_dataset = iter(self.dataset) + @property def storage_buffer(self) -> torch.Tensor: if self._storage_buffer is None: diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 34c7c0d5..205e7282 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -318,6 +318,16 @@ def _run_and_log_evals(self): model_kwargs=self.cfg.model_kwargs, ) + # Remove eval metrics that are already logged during training + eval_metrics.pop("metrics/explained_variance", None) + eval_metrics.pop("metrics/explained_variance_std", None) + eval_metrics.pop("metrics/l0", None) + eval_metrics.pop("metrics/l1", None) + eval_metrics.pop("metrics/mse", None) + + # Remove metrics that are not useful for wandb logging + eval_metrics.pop("metrics/total_tokens_evaluated", None) + W_dec_norm_dist = self.sae.W_dec.norm(dim=1).detach().cpu().numpy() eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist) # type: ignore diff --git a/tests/run_gpt2_evals.py b/tests/run_gpt2_evals.py new file mode 100644 index 00000000..c56db592 --- /dev/null +++ b/tests/run_gpt2_evals.py @@ -0,0 +1,77 @@ +import pandas as pd +from tqdm import tqdm +from transformer_lens import HookedTransformer + +from sae_lens import SAE +from sae_lens.evals import run_evals +from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory +from sae_lens.training.activations_store import ActivationsStore + + +def all_loadable_saes() -> list[tuple[str, str, float, float]]: + all_loadable_saes = [] + saes_directory = get_pretrained_saes_directory() + for release, lookup in saes_directory.items(): + for sae_name in lookup.saes_map.keys(): + expected_var_explained = lookup.expected_var_explained[sae_name] + expected_l0 = lookup.expected_l0[sae_name] + all_loadable_saes.append( + (release, sae_name, expected_var_explained, expected_l0) + ) + + return all_loadable_saes + + +def eval_all_loadable_gpt2_saes( + num_eval_batches: int = 10, + eval_batch_size_prompts: int = 8, + datasets: list[str] = ["Skylion007/openwebtext", "lighteval/MATH"], + ctx_lens: list[int] = [64, 128, 256, 512], +) -> pd.DataFrame: + all_saes = all_loadable_saes() + gpt2_saes = [sae for sae in all_saes if "gpt2-small" in sae[0]] + + device = "cuda:0" + + model = HookedTransformer.from_pretrained("gpt2-small", device=device) + + data = [] + + for sae_name, sae_block, _, _ in tqdm(gpt2_saes): + + sae = SAE.from_pretrained( + release=sae_name, # see 
other options in sae_lens/pretrained_saes.yaml + sae_id=sae_block, # won't always be a hook point + device=device, + )[0] + + for ctx_len in ctx_lens: + for dataset in datasets: + activation_store = ActivationsStore.from_sae( + model, sae, context_size=ctx_len, dataset=dataset + ) + activation_store.shuffle_input_dataset(seed=42) + + eval_metrics = {} + eval_metrics["sae_id"] = f"{sae_name}-{sae_block}" + eval_metrics["context_size"] = ctx_len + eval_metrics["dataset"] = dataset + + eval_metrics |= run_evals( + sae=sae, + activation_store=activation_store, + model=model, + n_eval_batches=10, + eval_batch_size_prompts=8, + ) + + data.append(eval_metrics) + + return pd.DataFrame(data) + + +# %% + +if __name__ == "__main__": + df = eval_all_loadable_gpt2_saes() + df.to_csv("gpt2_saes_evals.csv", index=False) From 4be50115b8b2c43448557ee54ff8f0afe692d111 Mon Sep 17 00:00:00 2001 From: Josh Engels Date: Wed, 26 Jun 2024 18:20:29 -0400 Subject: [PATCH 2/7] Moving file --- tests/{ => benchmark}/run_gpt2_evals.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename tests/{ => benchmark}/run_gpt2_evals.py (95%) diff --git a/tests/run_gpt2_evals.py b/tests/benchmark/run_gpt2_evals.py similarity index 95% rename from tests/run_gpt2_evals.py rename to tests/benchmark/run_gpt2_evals.py index c56db592..ce5d2619 100644 --- a/tests/run_gpt2_evals.py +++ b/tests/benchmark/run_gpt2_evals.py @@ -61,8 +61,8 @@ def eval_all_loadable_gpt2_saes( sae=sae, activation_store=activation_store, model=model, - n_eval_batches=10, - eval_batch_size_prompts=8, + n_eval_batches=num_eval_batches, + eval_batch_size_prompts=eval_batch_size_prompts, ) data.append(eval_metrics) From f9aa2ddd20c1f8c26b9181e685f04c7638511bc1 Mon Sep 17 00:00:00 2001 From: Josh Engels Date: Wed, 26 Jun 2024 18:52:38 -0400 Subject: [PATCH 3/7] Adding script to evals.py --- sae_lens/evals.py | 153 ++++++++++++++++++ tests/benchmark/run_gpt2_evals.py | 77 --------- .../benchmark/test_eval_all_loadable_saes.py | 19 +-- 3 files changed, 154 insertions(+), 95 deletions(-) delete mode 100644 tests/benchmark/run_gpt2_evals.py diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 571686a9..05e56b79 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -1,11 +1,17 @@ +import argparse +import re from functools import partial from typing import Any, Mapping, Tuple import einops +import pandas as pd import torch +from tqdm import tqdm +from transformer_lens import HookedTransformer from transformer_lens.hook_points import HookedRootModule from sae_lens.sae import SAE +from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory from sae_lens.training.activations_store import ActivationsStore @@ -348,3 +354,150 @@ def compute_kl(original_logits: torch.Tensor, new_logits: torch.Tensor): recons_ce_loss, zero_abl_ce_loss, ) + + +def all_loadable_saes() -> list[tuple[str, str, float, float]]: + all_loadable_saes = [] + saes_directory = get_pretrained_saes_directory() + for release, lookup in saes_directory.items(): + for sae_name in lookup.saes_map.keys(): + expected_var_explained = lookup.expected_var_explained[sae_name] + expected_l0 = lookup.expected_l0[sae_name] + all_loadable_saes.append( + (release, sae_name, expected_var_explained, expected_l0) + ) + + return all_loadable_saes + + +def multiple_evals( + sae_regex_pattern: str, + sae_block_pattern: str, + num_eval_batches: int = 10, + eval_batch_size_prompts: int = 8, + datasets: list[str] = ["Skylion007/openwebtext", "lighteval/MATH"], + ctx_lens: list[int] = [64, 
128, 256, 512], +) -> pd.DataFrame: + + device = "cuda" if torch.cuda.is_available() else "cpu" + + sae_regex_compiled = re.compile(sae_regex_pattern) + sae_block_compiled = re.compile(sae_block_pattern) + all_saes = all_loadable_saes() + filtered_saes = [ + sae + for sae in all_saes + if sae_regex_compiled.fullmatch(sae[0]) and sae_block_compiled.fullmatch(sae[1]) + ] + + assert len(filtered_saes) > 0, "No SAEs matched the given regex patterns" + + eval_results = [] + + current_model = None + current_model_str = None + print(filtered_saes) + for sae_name, sae_block, _, _ in tqdm(filtered_saes): + + sae = SAE.from_pretrained( + release=sae_name, # see other options in sae_lens/pretrained_saes.yaml + sae_id=sae_block, # won't always be a hook point + device=device, + )[0] + + if current_model_str != sae.cfg.model_name: + del current_model # potentially saves GPU memory + current_model_str = sae.cfg.model_name + current_model = HookedTransformer.from_pretrained( + current_model_str, device=device + ) + assert current_model is not None + + for ctx_len in ctx_lens: + for dataset in datasets: + + activation_store = ActivationsStore.from_sae( + current_model, sae, context_size=ctx_len, dataset=dataset + ) + activation_store.shuffle_input_dataset(seed=42) + + eval_metrics = {} + eval_metrics["sae_id"] = f"{sae_name}-{sae_block}" + eval_metrics["context_size"] = ctx_len + eval_metrics["dataset"] = dataset + + eval_metrics |= run_evals( + sae=sae, + activation_store=activation_store, + model=current_model, + n_eval_batches=num_eval_batches, + eval_batch_size_prompts=eval_batch_size_prompts, + ) + + eval_results.append(eval_metrics) + + return pd.DataFrame(eval_results) + + +if __name__ == "__main__": + + # Example commands: + # python sae_lens/evals.py "gpt2-small-res-jb.*" "blocks.8.hook_resid_pre" --save_path "gpt2_small_jb_layer8_resid_pre_eval_results.csv" + # python sae_lens/evals.py "gpt2-small.*" "blocks.8.hook_resid_pre" --save_path "gpt2_small_layer8_resid_pre_eval_results.csv" + # python sae_lens/evals.py "gpt2-small.*" ".*" --save_path "gpt2_small_eval_results.csv" + # python sae_lens/evals.py "mistral.*" ".*" --save_path "mistral_eval_results.csv" + + arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs") + arg_parser.add_argument( + "sae_regex_pattern", + type=str, + help="Regex pattern to match SAE names. Can be an entire SAE name to match a specific SAE.", + ) + arg_parser.add_argument( + "sae_block_pattern", + type=str, + help="Regex pattern to match SAE block names. 
Can be an entire block name to match a specific block.", + ) + arg_parser.add_argument( + "--num_eval_batches", + type=int, + default=10, + help="Number of evaluation batches to run.", + ) + arg_parser.add_argument( + "--eval_batch_size_prompts", + type=int, + default=8, + help="Batch size for evaluation prompts.", + ) + arg_parser.add_argument( + "--datasets", + nargs="+", + default=["Skylion007/openwebtext", "lighteval/MATH"], + help="Datasets to evaluate on.", + ) + arg_parser.add_argument( + "--ctx_lens", + nargs="+", + default=[64, 128, 256, 512], + help="Context lengths to evaluate on.", + ) + arg_parser.add_argument( + "--save_path", + type=str, + default="eval_results.csv", + help="Path to save evaluation results to.", + ) + + args = arg_parser.parse_args() + + eval_results = multiple_evals( + sae_regex_pattern=args.sae_regex_pattern, + sae_block_pattern=args.sae_block_pattern, + num_eval_batches=args.num_eval_batches, + eval_batch_size_prompts=args.eval_batch_size_prompts, + datasets=args.datasets, + ctx_lens=args.ctx_lens, + ) + + eval_results.to_csv(args.save_path, index=False) diff --git a/tests/benchmark/run_gpt2_evals.py b/tests/benchmark/run_gpt2_evals.py deleted file mode 100644 index ce5d2619..00000000 --- a/tests/benchmark/run_gpt2_evals.py +++ /dev/null @@ -1,77 +0,0 @@ -import pandas as pd -from tqdm import tqdm -from transformer_lens import HookedTransformer - -from sae_lens import SAE -from sae_lens.evals import run_evals -from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory -from sae_lens.training.activations_store import ActivationsStore - - -def all_loadable_saes() -> list[tuple[str, str, float, float]]: - all_loadable_saes = [] - saes_directory = get_pretrained_saes_directory() - for release, lookup in saes_directory.items(): - for sae_name in lookup.saes_map.keys(): - expected_var_explained = lookup.expected_var_explained[sae_name] - expected_l0 = lookup.expected_l0[sae_name] - all_loadable_saes.append( - (release, sae_name, expected_var_explained, expected_l0) - ) - - return all_loadable_saes - - -def eval_all_loadable_gpt2_saes( - num_eval_batches: int = 10, - eval_batch_size_prompts: int = 8, - datasets: list[str] = ["Skylion007/openwebtext", "lighteval/MATH"], - ctx_lens: list[int] = [64, 128, 256, 512], -) -> pd.DataFrame: - all_saes = all_loadable_saes() - gpt2_saes = [sae for sae in all_saes if "gpt2-small" in sae[0]] - - device = "cuda:0" - - model = HookedTransformer.from_pretrained("gpt2-small", device=device) - - data = [] - - for sae_name, sae_block, _, _ in tqdm(gpt2_saes): - - sae = SAE.from_pretrained( - release=sae_name, # see other options in sae_lens/pretrained_saes.yaml - sae_id=sae_block, # won't always be a hook point - device=device, - )[0] - - for ctx_len in ctx_lens: - for dataset in datasets: - activation_store = ActivationsStore.from_sae( - model, sae, context_size=ctx_len, dataset=dataset - ) - activation_store.shuffle_input_dataset(seed=42) - - eval_metrics = {} - eval_metrics["sae_id"] = f"{sae_name}-{sae_block}" - eval_metrics["context_size"] = ctx_len - eval_metrics["dataset"] = dataset - - eval_metrics |= run_evals( - sae=sae, - activation_store=activation_store, - model=model, - n_eval_batches=num_eval_batches, - eval_batch_size_prompts=eval_batch_size_prompts, - ) - - data.append(eval_metrics) - - return pd.DataFrame(data) - - -# %% - -if __name__ == "__main__": - df = eval_all_loadable_gpt2_saes() - df.to_csv("gpt2_saes_evals.csv", index=False) diff --git 
a/tests/benchmark/test_eval_all_loadable_saes.py b/tests/benchmark/test_eval_all_loadable_saes.py index d6fa584b..ded6a842 100644 --- a/tests/benchmark/test_eval_all_loadable_saes.py +++ b/tests/benchmark/test_eval_all_loadable_saes.py @@ -3,11 +3,9 @@ # import numpy as np import pytest import torch -from tqdm import tqdm -# from sae_lens.training.evals import run_evals +from sae_lens.evals import all_loadable_saes from sae_lens.sae import SAE -from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory from sae_lens.training.activations_store import ActivationsStore from tests.unit.helpers import load_model_cached @@ -26,21 +24,6 @@ """ -# @pytest.fixture -def all_loadable_saes() -> list[tuple[str, str]]: - all_loadable_saes = [] - saes_directory = get_pretrained_saes_directory() - for release, lookup in tqdm(saes_directory.items()): - for sae_name in lookup.saes_map.keys(): - expected_var_explained = lookup.expected_var_explained[sae_name] - expected_l0 = lookup.expected_l0[sae_name] - all_loadable_saes.append( - (release, sae_name, expected_var_explained, expected_l0) - ) - - return all_loadable_saes - - @pytest.mark.parametrize( "release, sae_name, expected_var_explained, expected_l0", all_loadable_saes() ) From 389a15924345c17442937e98f45c8d2eb9c92b21 Mon Sep 17 00:00:00 2001 From: Josh Engels Date: Wed, 26 Jun 2024 19:08:28 -0400 Subject: [PATCH 4/7] Fixing test --- tests/unit/training/test_evals.py | 39 +++++++++++++++---------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/tests/unit/training/test_evals.py b/tests/unit/training/test_evals.py index 3062e87d..f2781349 100644 --- a/tests/unit/training/test_evals.py +++ b/tests/unit/training/test_evals.py @@ -93,6 +93,25 @@ def training_sae(cfg: LanguageModelSAERunnerConfig): return TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) +expected_keys = [ + "metrics/l2_norm_in", + "metrics/l2_ratio", + "metrics/l2_norm_out", + "metrics/explained_variance", + "metrics/l0", + "metrics/l1", + "metrics/mse", + "metrics/ce_loss_score", + "metrics/ce_loss_without_sae", + "metrics/ce_loss_with_sae", + "metrics/ce_loss_with_ablation", + "metrics/kl_div_score", + "metrics/kl_div_without_sae", + "metrics/kl_div_with_sae", + "metrics/kl_div_with_ablation", +] + + def test_run_evals_base_sae( base_sae: SAE, activation_store: ActivationsStore, @@ -107,16 +126,6 @@ def test_run_evals_base_sae( eval_batch_size_prompts=None, ) - expected_keys = [ - "metrics/l2_norm", - "metrics/l2_ratio", - "metrics/l2_norm_in", - "metrics/CE_loss_score", - "metrics/ce_loss_without_sae", - "metrics/ce_loss_with_sae", - "metrics/ce_loss_with_ablation", - ] - # results will be garbage without a real model. 
for key in expected_keys: assert key in eval_metrics @@ -136,15 +145,5 @@ def test_run_evals_training_sae( eval_batch_size_prompts=None, ) - expected_keys = [ - "metrics/l2_norm", - "metrics/l2_ratio", - "metrics/l2_norm_in", - "metrics/CE_loss_score", - "metrics/ce_loss_without_sae", - "metrics/ce_loss_with_sae", - "metrics/ce_loss_with_ablation", - ] - for key in expected_keys: assert key in eval_metrics From 265687c09ba3c6ae090cf5a97e7f70251c0cf66c Mon Sep 17 00:00:00 2001 From: Josh Engels Date: Wed, 26 Jun 2024 19:14:21 -0400 Subject: [PATCH 5/7] Updating example commands --- sae_lens/evals.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 05e56b79..ba0288b2 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -442,8 +442,8 @@ def multiple_evals( if __name__ == "__main__": # Example commands: - # python sae_lens/evals.py "gpt2-small-res-jb.*" "blocks.8.hook_resid_pre" --save_path "gpt2_small_jb_layer8_resid_pre_eval_results.csv" - # python sae_lens/evals.py "gpt2-small.*" "blocks.8.hook_resid_pre" --save_path "gpt2_small_layer8_resid_pre_eval_results.csv" + # python sae_lens/evals.py "gpt2-small-res-jb.*" "blocks.8.hook_resid_pre.*" --save_path "gpt2_small_jb_layer8_resid_pre_eval_results.csv" + # python sae_lens/evals.py "gpt2-small.*" "blocks.8.hook_resid_pre.*" --save_path "gpt2_small_layer8_resid_pre_eval_results.csv" # python sae_lens/evals.py "gpt2-small.*" ".*" --save_path "gpt2_small_eval_results.csv" # python sae_lens/evals.py "mistral.*" ".*" --save_path "mistral_eval_results.csv" From cf4ebcdfe93d96270da2ed108f37a5c8d9d97c75 Mon Sep 17 00:00:00 2001 From: Josh Engels Date: Thu, 4 Jul 2024 15:19:55 -0400 Subject: [PATCH 6/7] Making changes in response to comments --- sae_lens/evals.py | 293 ++++++++++++++++++------------ sae_lens/training/sae_trainer.py | 12 +- tests/unit/training/test_evals.py | 54 +++++- 3 files changed, 232 insertions(+), 127 deletions(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index ba0288b2..aab8e6e4 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -1,7 +1,8 @@ import argparse import re +from dataclasses import dataclass from functools import partial -from typing import Any, Mapping, Tuple +from typing import Any, Mapping import einops import pandas as pd @@ -15,19 +16,55 @@ from sae_lens.training.activations_store import ActivationsStore +# Everything by default is false so the user can just set the ones they want to true +@dataclass +class EvalConfig: + batch_size_prompts: int | None = None + + # Reconstruction metrics + n_eval_reconstruction_batches: int = 10 + compute_kl: bool = False + compute_ce_loss: bool = False + + # Sparsity and variance metrics + n_eval_sparsity_variance_batches: int = 1 + compute_l2_norms: bool = False + compute_sparsity_metrics: bool = False + compute_variance_metrics: bool = False + + +def get_eval_everything_config( + batch_size_prompts: int | None = None, + n_eval_reconstruction_batches: int = 10, + n_eval_sparsity_variance_batches: int = 1, +) -> EvalConfig: + """ + Returns an EvalConfig object with all metrics set to True, so that when passed to run_evals all available metrics will be run. 
+ """ + return EvalConfig( + batch_size_prompts=batch_size_prompts, + n_eval_reconstruction_batches=n_eval_reconstruction_batches, + compute_kl=True, + compute_ce_loss=True, + compute_l2_norms=True, + n_eval_sparsity_variance_batches=n_eval_sparsity_variance_batches, + compute_sparsity_metrics=True, + compute_variance_metrics=True, + ) + + @torch.no_grad() def run_evals( sae: SAE, activation_store: ActivationsStore, model: HookedRootModule, - n_eval_batches: int = 10, - eval_batch_size_prompts: int | None = None, + eval_config: EvalConfig = EvalConfig(), model_kwargs: Mapping[str, Any] = {}, ) -> dict[str, Any]: hook_name = sae.cfg.hook_name actual_batch_size = ( - eval_batch_size_prompts or activation_store.store_batch_size_prompts + eval_config.batch_size_prompts or activation_store.store_batch_size_prompts ) # TODO: Come up with a cleaner long term strategy here for SAEs that do reshaping. @@ -38,24 +75,44 @@ def run_evals( else: previous_hook_z_reshaping_mode = None - metrics = get_downstream_reconstruction_metrics( - sae, - model, - activation_store, - n_batches=n_eval_batches, - eval_batch_size_prompts=actual_batch_size, - ) + metrics = {} - activation_store.reset_input_dataset() + if eval_config.compute_kl or eval_config.compute_ce_loss: + assert eval_config.n_eval_reconstruction_batches > 0 + metrics |= get_downstream_reconstruction_metrics( + sae, + model, + activation_store, + compute_kl=eval_config.compute_kl, + compute_ce_loss=eval_config.compute_ce_loss, + n_batches=eval_config.n_eval_reconstruction_batches, + eval_batch_size_prompts=actual_batch_size, + ) - metrics |= get_sparsity_and_variance_metrics( - sae, - model, - activation_store, - n_batches=n_eval_batches, - eval_batch_size_prompts=actual_batch_size, - model_kwargs=model_kwargs, - ) + activation_store.reset_input_dataset() + + if ( + eval_config.compute_l2_norms + or eval_config.compute_sparsity_metrics + or eval_config.compute_variance_metrics + ): + assert eval_config.n_eval_sparsity_variance_batches > 0 + metrics |= get_sparsity_and_variance_metrics( + sae, + model, + activation_store, + compute_l2_norms=eval_config.compute_l2_norms, + compute_sparsity_metrics=eval_config.compute_sparsity_metrics, + compute_variance_metrics=eval_config.compute_variance_metrics, + n_batches=eval_config.n_eval_sparsity_variance_batches, + eval_batch_size_prompts=actual_batch_size, + model_kwargs=model_kwargs, + ) + + if len(metrics) == 0: + raise ValueError( + "No metrics were computed, please set at least one metric to True." 
+ ) # restore previous hook z reshaping mode if necessary if "hook_z" in hook_name: @@ -65,7 +122,9 @@ def run_evals( sae.turn_off_forward_pass_hook_z_reshaping() total_tokens_evaluated = ( - activation_store.context_size * n_eval_batches * actual_batch_size + activation_store.context_size + * eval_config.n_eval_reconstruction_batches + * actual_batch_size ) metrics["metrics/total_tokens_evaluated"] = total_tokens_evaluated @@ -76,57 +135,50 @@ def get_downstream_reconstruction_metrics( sae: SAE, model: HookedRootModule, activation_store: ActivationsStore, + compute_kl: bool, + compute_ce_loss: bool, n_batches: int, eval_batch_size_prompts: int, ): - metrics = [] + metrics_dict = {} + if compute_kl: + metrics_dict["kl_div_with_sae"] = [] + metrics_dict["kl_div_with_ablation"] = [] + if compute_ce_loss: + metrics_dict["ce_loss_with_sae"] = [] + metrics_dict["ce_loss_without_sae"] = [] + metrics_dict["ce_loss_with_ablation"] = [] + for _ in range(n_batches): batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts) - ( - recons_kl_div, - zero_abl_kl_div, - original_ce_loss, - recons_ce_loss, - zero_abl_ce_loss, - ) = get_recons_loss( + for metric_name, metric_value in get_recons_loss( sae, model, batch_tokens, activation_store, - ) + compute_kl=compute_kl, + compute_ce_loss=compute_ce_loss, + ).items(): + metrics_dict[metric_name].append(metric_value) - metrics.append( - ( - recons_kl_div, - zero_abl_kl_div, - original_ce_loss, - recons_ce_loss, - zero_abl_ce_loss, - ) + metrics: dict[str, float] = {} + for metric_name, metric_values in metrics_dict.items(): + metrics[f"metrics/{metric_name}"] = torch.stack(metric_values).mean().item() + + if compute_kl: + metrics["metrics/kl_div_score"] = ( + metrics["metrics/kl_div_with_ablation"] - metrics["metrics/kl_div_with_sae"] + ) / metrics["metrics/kl_div_with_ablation"] + + if compute_ce_loss: + metrics["metrics/ce_loss_score"] = ( + metrics["metrics/ce_loss_with_ablation"] + - metrics["metrics/ce_loss_with_sae"] + ) / ( + metrics["metrics/ce_loss_with_ablation"] + - metrics["metrics/ce_loss_without_sae"] ) - recons_kl_div = torch.stack([metric[0] for metric in metrics]).mean() - zero_abl_kl_div = torch.stack([metric[1] for metric in metrics]).mean() - kl_div_score = (zero_abl_kl_div - recons_kl_div) / zero_abl_kl_div - - zero_abl_ce_loss = torch.stack([metric[4] for metric in metrics]).mean() - recons_ce_loss = torch.stack([metric[3] for metric in metrics]).mean() - original_ce_loss = torch.stack([metric[2] for metric in metrics]).mean() - ce_loss_score = (zero_abl_ce_loss - recons_ce_loss) / ( - zero_abl_ce_loss - original_ce_loss - ) - - metrics = { - "metrics/ce_loss_score": ce_loss_score.item(), - "metrics/ce_loss_without_sae": original_ce_loss.item(), - "metrics/ce_loss_with_sae": recons_ce_loss.item(), - "metrics/ce_loss_with_ablation": zero_abl_ce_loss.item(), - "metrics/kl_div_score": kl_div_score.item(), - "metrics/kl_div_without_sae": 0, - "metrics/kl_div_with_sae": recons_kl_div.item(), - "metrics/kl_div_with_ablation": zero_abl_kl_div.item(), - } - return metrics @@ -135,15 +187,28 @@ def get_sparsity_and_variance_metrics( model: HookedRootModule, activation_store: ActivationsStore, n_batches: int, + compute_l2_norms: bool, + compute_sparsity_metrics: bool, + compute_variance_metrics: bool, eval_batch_size_prompts: int, model_kwargs: Mapping[str, Any], ): - metrics_list = [] - hook_name = sae.cfg.hook_name hook_head_index = sae.cfg.hook_head_index + metric_dict = {} + if compute_l2_norms: + metric_dict["l2_norm_in"] = 
[] + metric_dict["l2_norm_out"] = [] + metric_dict["l2_ratio"] = [] + if compute_sparsity_metrics: + metric_dict["l0"] = [] + metric_dict["l1"] = [] + if compute_variance_metrics: + metric_dict["explained_variance"] = [] + metric_dict["mse"] = [] + for _ in range(n_batches): batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts) @@ -180,49 +245,36 @@ def get_sparsity_and_variance_metrics( ) flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d") - l2_norm_in = torch.norm(flattened_sae_input, dim=-1) - l2_norm_out = torch.norm(flattened_sae_out, dim=-1) - l2_norm_in_for_div = l2_norm_in.clone() - l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1 - l2_norm_ratio = l2_norm_out / l2_norm_in_for_div - - l0 = (flattened_sae_feature_acts > 0).sum(dim=-1) - l1 = flattened_sae_feature_acts.sum(dim=-1) - resid_sum_of_squares = ( - (flattened_sae_input - flattened_sae_out).pow(2).sum(dim=-1) - ) - total_sum_of_squares = ( - (flattened_sae_input - flattened_sae_input.mean(dim=0)).pow(2).sum(-1) - ) - explained_variance = 1 - resid_sum_of_squares / total_sum_of_squares - - metrics_list.append( - ( - l2_norm_in, - l2_norm_out, - l2_norm_ratio, - explained_variance, - l0.float(), - l1, - resid_sum_of_squares, + if compute_l2_norms: + l2_norm_in = torch.norm(flattened_sae_input, dim=-1) + l2_norm_out = torch.norm(flattened_sae_out, dim=-1) + l2_norm_in_for_div = l2_norm_in.clone() + l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1 + l2_norm_ratio = l2_norm_out / l2_norm_in_for_div + metric_dict["l2_norm_in"].append(l2_norm_in) + metric_dict["l2_norm_out"].append(l2_norm_out) + metric_dict["l2_ratio"].append(l2_norm_ratio) + + if compute_sparsity_metrics: + l0 = (flattened_sae_feature_acts > 0).sum(dim=-1).float() + l1 = flattened_sae_feature_acts.sum(dim=-1) + metric_dict["l0"].append(l0) + metric_dict["l1"].append(l1) + + if compute_variance_metrics: + resid_sum_of_squares = ( + (flattened_sae_input - flattened_sae_out).pow(2).sum(dim=-1) ) - ) + total_sum_of_squares = ( + (flattened_sae_input - flattened_sae_input.mean(dim=0)).pow(2).sum(-1) + ) + explained_variance = 1 - resid_sum_of_squares / total_sum_of_squares + metric_dict["explained_variance"].append(explained_variance) + metric_dict["mse"].append(resid_sum_of_squares) metrics: dict[str, float] = {} - for i, metric_name in enumerate( - [ - "l2_norm_in", - "l2_norm_out", - "l2_ratio", - "explained_variance", - "l0", - "l1", - "mse", - ] - ): - metrics[f"metrics/{metric_name}"] = ( - torch.stack([m[i] for m in metrics_list]).mean().item() - ) + for metric_name, metric_values in metric_dict.items(): + metrics[f"metrics/{metric_name}"] = torch.stack(metric_values).mean().item() return metrics @@ -233,8 +285,10 @@ def get_recons_loss( model: HookedRootModule, batch_tokens: torch.Tensor, activation_store: ActivationsStore, + compute_kl: bool, + compute_ce_loss: bool, model_kwargs: Mapping[str, Any] = {}, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> dict[str, Any]: hook_name = sae.cfg.hook_name head_index = sae.cfg.hook_head_index @@ -242,6 +296,8 @@ def get_recons_loss( batch_tokens, return_type="both", **model_kwargs ) + metrics = {} + # TODO(tomMcGrath): the rescaling below is a bit of a hack and could probably be tidied up def standard_replacement_hook(activations: torch.Tensor, hook: Any): @@ -335,7 +391,7 @@ def zero_ablate_hook(activations: torch.Tensor, hook: Any): **model_kwargs, ) - def compute_kl(original_logits: torch.Tensor, new_logits: 
torch.Tensor): + def kl(original_logits: torch.Tensor, new_logits: torch.Tensor): original_probs = torch.nn.functional.softmax(original_logits, dim=-1) log_original_probs = torch.log(original_probs) new_probs = torch.nn.functional.softmax(new_logits, dim=-1) @@ -344,16 +400,18 @@ def compute_kl(original_logits: torch.Tensor, new_logits: torch.Tensor): kl_div = kl_div.sum(dim=-1) return kl_div - recons_kl_div = compute_kl(original_logits, recons_logits) - zero_abl_kl_div = compute_kl(original_logits, zero_abl_logits) + if compute_kl: + recons_kl_div = kl(original_logits, recons_logits) + zero_abl_kl_div = kl(original_logits, zero_abl_logits) + metrics["kl_div_with_sae"] = recons_kl_div + metrics["kl_div_with_ablation"] = zero_abl_kl_div - return ( - recons_kl_div, - zero_abl_kl_div, - original_ce_loss, - recons_ce_loss, - zero_abl_ce_loss, - ) + if compute_ce_loss: + metrics["ce_loss_with_sae"] = recons_ce_loss + metrics["ce_loss_without_sae"] = original_ce_loss + metrics["ce_loss_with_ablation"] = zero_abl_ce_loss + + return metrics def all_loadable_saes() -> list[tuple[str, str, float, float]]: @@ -394,6 +452,12 @@ def multiple_evals( eval_results = [] + eval_config = get_eval_everything_config( + batch_size_prompts=eval_batch_size_prompts, + n_eval_reconstruction_batches=num_eval_batches, + n_eval_sparsity_variance_batches=num_eval_batches, + ) + current_model = None current_model_str = None print(filtered_saes) @@ -430,8 +494,7 @@ def multiple_evals( sae=sae, activation_store=activation_store, model=current_model, - n_eval_batches=num_eval_batches, - eval_batch_size_prompts=eval_batch_size_prompts, + eval_config=eval_config, ) eval_results.append(eval_metrics) diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 205e7282..90b0824b 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -10,7 +10,7 @@ from sae_lens import __version__ from sae_lens.config import LanguageModelSAERunnerConfig -from sae_lens.evals import run_evals +from sae_lens.evals import EvalConfig, run_evals from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.optim import L1Scheduler, get_lr_scheduler from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput @@ -22,6 +22,13 @@ "unrotated_decoder": ["scaling_factor", "b_dec"], } +TRAINER_EVAL_CONFIG = EvalConfig( + n_eval_reconstruction_batches=10, + compute_ce_loss=True, + n_eval_sparsity_variance_batches=1, + compute_l2_norms=True, +) + def _log_feature_sparsity( feature_sparsity: torch.Tensor, eps: float = 1e-10 @@ -313,8 +320,7 @@ def _run_and_log_evals(self): sae=self.sae, activation_store=self.activation_store, model=self.model, - n_eval_batches=self.cfg.n_eval_batches, - eval_batch_size_prompts=self.cfg.eval_batch_size_prompts, + eval_config=TRAINER_EVAL_CONFIG, model_kwargs=self.cfg.model_kwargs, ) diff --git a/tests/unit/training/test_evals.py b/tests/unit/training/test_evals.py index f2781349..b853d375 100644 --- a/tests/unit/training/test_evals.py +++ b/tests/unit/training/test_evals.py @@ -3,9 +3,10 @@ from transformer_lens import HookedTransformer from sae_lens.config import LanguageModelSAERunnerConfig -from sae_lens.evals import run_evals +from sae_lens.evals import get_eval_everything_config, run_evals from sae_lens.sae import SAE from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.sae_trainer import TRAINER_EVAL_CONFIG from sae_lens.training.training_sae import TrainingSAE from tests.unit.helpers import 
TINYSTORIES_MODEL, build_sae_cfg, load_model_cached @@ -93,7 +94,7 @@ def training_sae(cfg: LanguageModelSAERunnerConfig): return TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) -expected_keys = [ +all_expected_keys = [ "metrics/l2_norm_in", "metrics/l2_ratio", "metrics/l2_norm_out", @@ -106,7 +107,6 @@ def training_sae(cfg: LanguageModelSAERunnerConfig): "metrics/ce_loss_with_sae", "metrics/ce_loss_with_ablation", "metrics/kl_div_score", - "metrics/kl_div_without_sae", "metrics/kl_div_with_sae", "metrics/kl_div_with_ablation", ] @@ -122,12 +122,11 @@ def test_run_evals_base_sae( sae=base_sae, activation_store=activation_store, model=model, - n_eval_batches=2, - eval_batch_size_prompts=None, + eval_config=get_eval_everything_config(), ) # results will be garbage without a real model. - for key in expected_keys: + for key in all_expected_keys: assert key in eval_metrics @@ -141,9 +140,46 @@ def test_run_evals_training_sae( sae=training_sae, activation_store=activation_store, model=model, - n_eval_batches=10, - eval_batch_size_prompts=None, + eval_config=get_eval_everything_config(), ) - for key in expected_keys: + print(eval_metrics) + for key in all_expected_keys: assert key in eval_metrics + + +def test_run_empty_evals( + base_sae: SAE, + activation_store: ActivationsStore, + model: HookedTransformer, +): + with pytest.raises(ValueError): + run_evals(sae=base_sae, activation_store=activation_store, model=model) + + +def test_training_eval_config( + base_sae: SAE, + activation_store: ActivationsStore, + model: HookedTransformer, +): + expected_keys = [ + "metrics/l2_norm_in", + "metrics/l2_ratio", + "metrics/l2_norm_out", + "metrics/ce_loss_score", + "metrics/ce_loss_without_sae", + "metrics/ce_loss_with_sae", + "metrics/ce_loss_with_ablation", + ] + eval_config = TRAINER_EVAL_CONFIG + eval_metrics = run_evals( + sae=base_sae, + activation_store=activation_store, + model=model, + eval_config=eval_config, + ) + sorted_returned_keys = sorted(eval_metrics.keys()) + sorted_expected_keys = sorted(expected_keys) + + for i in range(len(expected_keys)): + assert sorted_returned_keys[i] == sorted_expected_keys[i] From 5da6a13df27678d59e0d233b51dbf8758e190e34 Mon Sep 17 00:00:00 2001 From: Josh Engels Date: Thu, 4 Jul 2024 15:43:32 -0400 Subject: [PATCH 7/7] Adding type hint --- sae_lens/training/activations_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 714f654a..97550944 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -6,7 +6,7 @@ import numpy as np import torch -from datasets import Dataset, DatasetDict, load_dataset +from datasets import Dataset, DatasetDict, IterableDataset, load_dataset from safetensors import safe_open from safetensors.torch import save_file from torch.utils.data import DataLoader
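
Usage sketch: a minimal end-to-end example of calling the reworked eval entry point from these patches. The model name, SAE release, SAE id, and dataset below are taken from the patch's own test script and example commands; the context size, seed, batch size, and batch counts are illustrative values chosen here, not library defaults.

import torch
from transformer_lens import HookedTransformer

from sae_lens import SAE
from sae_lens.evals import get_eval_everything_config, run_evals
from sae_lens.training.activations_store import ActivationsStore

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a model and a matching pretrained SAE. The release / sae_id pair follows
# the example commands in the patch; SAE.from_pretrained returns a tuple whose
# first element is the SAE itself.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device=device,
)[0]

# Build an activations store, overriding the SAE's trained context size and
# dataset via the new from_sae keyword arguments, then shuffle the input
# dataset so evals run on a different slice of the (streaming) data.
activation_store = ActivationsStore.from_sae(
    model, sae, context_size=128, dataset="Skylion007/openwebtext"
)
activation_store.shuffle_input_dataset(seed=42)

# Request every available metric (KL, CE loss, L2 norms, sparsity, variance).
# run_evals returns a flat dict of "metrics/..." entries.
eval_config = get_eval_everything_config(
    batch_size_prompts=8,
    n_eval_reconstruction_batches=10,
    n_eval_sparsity_variance_batches=1,
)
metrics = run_evals(
    sae=sae,
    activation_store=activation_store,
    model=model,
    eval_config=eval_config,
)
print(metrics["metrics/ce_loss_score"], metrics["metrics/l0"])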