From cf4ebcdfe93d96270da2ed108f37a5c8d9d97c75 Mon Sep 17 00:00:00 2001
From: Josh Engels
Date: Thu, 4 Jul 2024 15:19:55 -0400
Subject: [PATCH] Making changes in response to comments

---
 sae_lens/evals.py                 | 293 ++++++++++++++++++------------
 sae_lens/training/sae_trainer.py  |  12 +-
 tests/unit/training/test_evals.py |  54 +++++-
 3 files changed, 232 insertions(+), 127 deletions(-)

diff --git a/sae_lens/evals.py b/sae_lens/evals.py
index ba0288b2..aab8e6e4 100644
--- a/sae_lens/evals.py
+++ b/sae_lens/evals.py
@@ -1,7 +1,8 @@
 import argparse
 import re
+from dataclasses import dataclass
 from functools import partial
-from typing import Any, Mapping, Tuple
+from typing import Any, Mapping
 
 import einops
 import pandas as pd
@@ -15,19 +16,55 @@
 from sae_lens.training.activations_store import ActivationsStore
 
 
+# Everything by default is false so the user can just set the ones they want to true
+@dataclass
+class EvalConfig:
+    batch_size_prompts: int | None = None
+
+    # Reconstruction metrics
+    n_eval_reconstruction_batches: int = 10
+    compute_kl: bool = False
+    compute_ce_loss: bool = False
+
+    # Sparsity and variance metrics
+    n_eval_sparsity_variance_batches: int = 1
+    compute_l2_norms: bool = False
+    compute_sparsity_metrics: bool = False
+    compute_variance_metrics: bool = False
+
+
+def get_eval_everything_config(
+    batch_size_prompts: int | None = None,
+    n_eval_reconstruction_batches: int = 10,
+    n_eval_sparsity_variance_batches: int = 1,
+) -> EvalConfig:
+    """
+    Returns an EvalConfig object with all metrics set to True, so that when passed to run_evals all available metrics will be run.
+    """
+    return EvalConfig(
+        batch_size_prompts=batch_size_prompts,
+        n_eval_reconstruction_batches=n_eval_reconstruction_batches,
+        compute_kl=True,
+        compute_ce_loss=True,
+        compute_l2_norms=True,
+        n_eval_sparsity_variance_batches=n_eval_sparsity_variance_batches,
+        compute_sparsity_metrics=True,
+        compute_variance_metrics=True,
+    )
+
+
 @torch.no_grad()
 def run_evals(
     sae: SAE,
     activation_store: ActivationsStore,
     model: HookedRootModule,
-    n_eval_batches: int = 10,
-    eval_batch_size_prompts: int | None = None,
+    eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
 ) -> dict[str, Any]:
 
     hook_name = sae.cfg.hook_name
     actual_batch_size = (
-        eval_batch_size_prompts or activation_store.store_batch_size_prompts
+        eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
     )
 
     # TODO: Come up with a cleaner long term strategy here for SAEs that do reshaping.
@@ -38,24 +75,44 @@ def run_evals(
     else:
         previous_hook_z_reshaping_mode = None
 
-    metrics = get_downstream_reconstruction_metrics(
-        sae,
-        model,
-        activation_store,
-        n_batches=n_eval_batches,
-        eval_batch_size_prompts=actual_batch_size,
-    )
+    metrics = {}
 
-    activation_store.reset_input_dataset()
+    if eval_config.compute_kl or eval_config.compute_ce_loss:
+        assert eval_config.n_eval_reconstruction_batches > 0
+        metrics |= get_downstream_reconstruction_metrics(
+            sae,
+            model,
+            activation_store,
+            compute_kl=eval_config.compute_kl,
+            compute_ce_loss=eval_config.compute_ce_loss,
+            n_batches=eval_config.n_eval_reconstruction_batches,
+            eval_batch_size_prompts=actual_batch_size,
+        )
 
-    metrics |= get_sparsity_and_variance_metrics(
-        sae,
-        model,
-        activation_store,
-        n_batches=n_eval_batches,
-        eval_batch_size_prompts=actual_batch_size,
-        model_kwargs=model_kwargs,
-    )
+        activation_store.reset_input_dataset()
+
+    if (
+        eval_config.compute_l2_norms
+        or eval_config.compute_sparsity_metrics
+        or eval_config.compute_variance_metrics
+    ):
+        assert eval_config.n_eval_sparsity_variance_batches > 0
+        metrics |= get_sparsity_and_variance_metrics(
+            sae,
+            model,
+            activation_store,
+            compute_l2_norms=eval_config.compute_l2_norms,
+            compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
+            compute_variance_metrics=eval_config.compute_variance_metrics,
+            n_batches=eval_config.n_eval_sparsity_variance_batches,
+            eval_batch_size_prompts=actual_batch_size,
+            model_kwargs=model_kwargs,
+        )
+
+    if len(metrics) == 0:
+        raise ValueError(
+            "No metrics were computed, please set at least one metric to True."
+        )
 
     # restore previous hook z reshaping mode if necessary
     if "hook_z" in hook_name:
@@ -65,7 +122,9 @@ def run_evals(
         sae.turn_off_forward_pass_hook_z_reshaping()
 
     total_tokens_evaluated = (
-        activation_store.context_size * n_eval_batches * actual_batch_size
+        activation_store.context_size
+        * eval_config.n_eval_reconstruction_batches
+        * actual_batch_size
     )
 
     metrics["metrics/total_tokens_evaluated"] = total_tokens_evaluated
@@ -76,57 +135,50 @@ def get_downstream_reconstruction_metrics(
     sae: SAE,
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    compute_kl: bool,
+    compute_ce_loss: bool,
     n_batches: int,
     eval_batch_size_prompts: int,
 ):
-    metrics = []
+    metrics_dict = {}
+    if compute_kl:
+        metrics_dict["kl_div_with_sae"] = []
+        metrics_dict["kl_div_with_ablation"] = []
+    if compute_ce_loss:
+        metrics_dict["ce_loss_with_sae"] = []
+        metrics_dict["ce_loss_without_sae"] = []
+        metrics_dict["ce_loss_with_ablation"] = []
+
     for _ in range(n_batches):
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
-        (
-            recons_kl_div,
-            zero_abl_kl_div,
-            original_ce_loss,
-            recons_ce_loss,
-            zero_abl_ce_loss,
-        ) = get_recons_loss(
+        for metric_name, metric_value in get_recons_loss(
             sae,
             model,
             batch_tokens,
             activation_store,
-        )
+            compute_kl=compute_kl,
+            compute_ce_loss=compute_ce_loss,
+        ).items():
+            metrics_dict[metric_name].append(metric_value)
 
-        metrics.append(
-            (
-                recons_kl_div,
-                zero_abl_kl_div,
-                original_ce_loss,
-                recons_ce_loss,
-                zero_abl_ce_loss,
-            )
+    metrics: dict[str, float] = {}
+    for metric_name, metric_values in metrics_dict.items():
+        metrics[f"metrics/{metric_name}"] = torch.stack(metric_values).mean().item()
+
+    if compute_kl:
+        metrics["metrics/kl_div_score"] = (
+            metrics["metrics/kl_div_with_ablation"] - metrics["metrics/kl_div_with_sae"]
+        ) / metrics["metrics/kl_div_with_ablation"]
+
+    if compute_ce_loss:
+        metrics["metrics/ce_loss_score"] = (
+            metrics["metrics/ce_loss_with_ablation"]
+            - metrics["metrics/ce_loss_with_sae"]
+        ) / (
+            metrics["metrics/ce_loss_with_ablation"]
+            - metrics["metrics/ce_loss_without_sae"]
         )
 
-    recons_kl_div = torch.stack([metric[0] for metric in metrics]).mean()
-    zero_abl_kl_div = torch.stack([metric[1] for metric in metrics]).mean()
-    kl_div_score = (zero_abl_kl_div - recons_kl_div) / zero_abl_kl_div
-
-    zero_abl_ce_loss = torch.stack([metric[4] for metric in metrics]).mean()
-    recons_ce_loss = torch.stack([metric[3] for metric in metrics]).mean()
-    original_ce_loss = torch.stack([metric[2] for metric in metrics]).mean()
-    ce_loss_score = (zero_abl_ce_loss - recons_ce_loss) / (
-        zero_abl_ce_loss - original_ce_loss
-    )
-
-    metrics = {
-        "metrics/ce_loss_score": ce_loss_score.item(),
-        "metrics/ce_loss_without_sae": original_ce_loss.item(),
-        "metrics/ce_loss_with_sae": recons_ce_loss.item(),
-        "metrics/ce_loss_with_ablation": zero_abl_ce_loss.item(),
-        "metrics/kl_div_score": kl_div_score.item(),
-        "metrics/kl_div_without_sae": 0,
-        "metrics/kl_div_with_sae": recons_kl_div.item(),
-        "metrics/kl_div_with_ablation": zero_abl_kl_div.item(),
-    }
-
     return metrics
 
 
@@ -135,15 +187,28 @@ def get_sparsity_and_variance_metrics(
     model: HookedRootModule,
     activation_store: ActivationsStore,
     n_batches: int,
+    compute_l2_norms: bool,
+    compute_sparsity_metrics: bool,
+    compute_variance_metrics: bool,
     eval_batch_size_prompts: int,
     model_kwargs: Mapping[str, Any],
 ):
-    metrics_list = []
-
     hook_name = sae.cfg.hook_name
     hook_head_index = sae.cfg.hook_head_index
 
+    metric_dict = {}
+    if compute_l2_norms:
+        metric_dict["l2_norm_in"] = []
+        metric_dict["l2_norm_out"] = []
+        metric_dict["l2_ratio"] = []
+    if compute_sparsity_metrics:
+        metric_dict["l0"] = []
+        metric_dict["l1"] = []
+    if compute_variance_metrics:
+        metric_dict["explained_variance"] = []
+        metric_dict["mse"] = []
+
     for _ in range(n_batches):
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
 
@@ -180,49 +245,36 @@ def get_sparsity_and_variance_metrics(
             )
             flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d")
 
-            l2_norm_in = torch.norm(flattened_sae_input, dim=-1)
-            l2_norm_out = torch.norm(flattened_sae_out, dim=-1)
-            l2_norm_in_for_div = l2_norm_in.clone()
-            l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1
-            l2_norm_ratio = l2_norm_out / l2_norm_in_for_div
-
-            l0 = (flattened_sae_feature_acts > 0).sum(dim=-1)
-            l1 = flattened_sae_feature_acts.sum(dim=-1)
-            resid_sum_of_squares = (
-                (flattened_sae_input - flattened_sae_out).pow(2).sum(dim=-1)
-            )
-            total_sum_of_squares = (
-                (flattened_sae_input - flattened_sae_input.mean(dim=0)).pow(2).sum(-1)
-            )
-            explained_variance = 1 - resid_sum_of_squares / total_sum_of_squares
-
-            metrics_list.append(
-                (
-                    l2_norm_in,
-                    l2_norm_out,
-                    l2_norm_ratio,
-                    explained_variance,
-                    l0.float(),
-                    l1,
-                    resid_sum_of_squares,
+            if compute_l2_norms:
+                l2_norm_in = torch.norm(flattened_sae_input, dim=-1)
+                l2_norm_out = torch.norm(flattened_sae_out, dim=-1)
+                l2_norm_in_for_div = l2_norm_in.clone()
+                l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1
+                l2_norm_ratio = l2_norm_out / l2_norm_in_for_div
+                metric_dict["l2_norm_in"].append(l2_norm_in)
+                metric_dict["l2_norm_out"].append(l2_norm_out)
+                metric_dict["l2_ratio"].append(l2_norm_ratio)
+
+            if compute_sparsity_metrics:
+                l0 = (flattened_sae_feature_acts > 0).sum(dim=-1).float()
+                l1 = flattened_sae_feature_acts.sum(dim=-1)
+                metric_dict["l0"].append(l0)
+                metric_dict["l1"].append(l1)
+
+            if compute_variance_metrics:
+                resid_sum_of_squares = (
+                    (flattened_sae_input - flattened_sae_out).pow(2).sum(dim=-1)
                 )
-            )
+                total_sum_of_squares = (
+                    (flattened_sae_input - flattened_sae_input.mean(dim=0)).pow(2).sum(-1)
+                )
+                explained_variance = 1 - resid_sum_of_squares / total_sum_of_squares
+                metric_dict["explained_variance"].append(explained_variance)
+                metric_dict["mse"].append(resid_sum_of_squares)
 
     metrics: dict[str, float] = {}
-    for i, metric_name in enumerate(
-        [
-            "l2_norm_in",
-            "l2_norm_out",
-            "l2_ratio",
-            "explained_variance",
-            "l0",
-            "l1",
-            "mse",
-        ]
-    ):
-        metrics[f"metrics/{metric_name}"] = (
-            torch.stack([m[i] for m in metrics_list]).mean().item()
-        )
+    for metric_name, metric_values in metric_dict.items():
+        metrics[f"metrics/{metric_name}"] = torch.stack(metric_values).mean().item()
 
     return metrics
 
@@ -233,8 +285,10 @@ def get_recons_loss(
     model: HookedRootModule,
     batch_tokens: torch.Tensor,
     activation_store: ActivationsStore,
+    compute_kl: bool,
+    compute_ce_loss: bool,
     model_kwargs: Mapping[str, Any] = {},
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> dict[str, Any]:
     hook_name = sae.cfg.hook_name
     head_index = sae.cfg.hook_head_index
 
@@ -242,6 +296,8 @@ def get_recons_loss(
         batch_tokens, return_type="both", **model_kwargs
     )
 
+    metrics = {}
+
     # TODO(tomMcGrath): the rescaling below is a bit of a hack and could probably be tidied up
     def standard_replacement_hook(activations: torch.Tensor, hook: Any):
 
@@ -335,7 +391,7 @@ def zero_ablate_hook(activations: torch.Tensor, hook: Any):
         **model_kwargs,
     )
 
-    def compute_kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
+    def kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
         original_probs = torch.nn.functional.softmax(original_logits, dim=-1)
         log_original_probs = torch.log(original_probs)
         new_probs = torch.nn.functional.softmax(new_logits, dim=-1)
@@ -344,16 +400,18 @@ def compute_kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
         kl_div = kl_div.sum(dim=-1)
         return kl_div
 
-    recons_kl_div = compute_kl(original_logits, recons_logits)
-    zero_abl_kl_div = compute_kl(original_logits, zero_abl_logits)
+    if compute_kl:
+        recons_kl_div = kl(original_logits, recons_logits)
+        zero_abl_kl_div = kl(original_logits, zero_abl_logits)
+        metrics["kl_div_with_sae"] = recons_kl_div
+        metrics["kl_div_with_ablation"] = zero_abl_kl_div
 
-    return (
-        recons_kl_div,
-        zero_abl_kl_div,
-        original_ce_loss,
-        recons_ce_loss,
-        zero_abl_ce_loss,
-    )
+    if compute_ce_loss:
+        metrics["ce_loss_with_sae"] = recons_ce_loss
+        metrics["ce_loss_without_sae"] = original_ce_loss
+        metrics["ce_loss_with_ablation"] = zero_abl_ce_loss
+
+    return metrics
 
 
 def all_loadable_saes() -> list[tuple[str, str, float, float]]:
@@ -394,6 +452,12 @@ def multiple_evals(
 
     eval_results = []
 
+    eval_config = get_eval_everything_config(
+        batch_size_prompts=eval_batch_size_prompts,
+        n_eval_reconstruction_batches=num_eval_batches,
+        n_eval_sparsity_variance_batches=num_eval_batches,
+    )
+
     current_model = None
     current_model_str = None
     print(filtered_saes)
@@ -430,8 +494,7 @@ def multiple_evals(
                 sae=sae,
                 activation_store=activation_store,
                 model=current_model,
-                n_eval_batches=num_eval_batches,
-                eval_batch_size_prompts=eval_batch_size_prompts,
+                eval_config=eval_config,
             )
 
             eval_results.append(eval_metrics)
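A note on the two score metrics defined in evals.py above, since the normalization is easy to misread: both rescale a raw quantity so that 1.0 means the SAE reconstruction is indistinguishable from the original activations and 0.0 means it is no better than zero-ablating the hook point. A minimal sketch of the arithmetic, with made-up numbers rather than values from any real run:

    def kl_div_score(kl_with_ablation: float, kl_with_sae: float) -> float:
        # Fraction of the KL divergence caused by zero-ablation that the SAE avoids.
        return (kl_with_ablation - kl_with_sae) / kl_with_ablation

    def ce_loss_score(
        ce_with_ablation: float, ce_with_sae: float, ce_without_sae: float
    ) -> float:
        # Fraction of the CE-loss gap (ablated vs. clean model) that the SAE closes.
        return (ce_with_ablation - ce_with_sae) / (ce_with_ablation - ce_without_sae)

    # Illustrative numbers only:
    assert abs(kl_div_score(0.50, 0.05) - 0.90) < 1e-9   # SAE recovers 90% of the KL gap
    assert abs(ce_loss_score(7.0, 3.2, 3.0) - 0.95) < 1e-9  # SAE closes 95% of the CE gap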
diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py
index 205e7282..90b0824b 100644
--- a/sae_lens/training/sae_trainer.py
+++ b/sae_lens/training/sae_trainer.py
@@ -10,7 +10,7 @@
 from sae_lens import __version__
 from sae_lens.config import LanguageModelSAERunnerConfig
-from sae_lens.evals import run_evals
+from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
 from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput
@@ -22,6 +22,13 @@
     "unrotated_decoder": ["scaling_factor", "b_dec"],
 }
 
+TRAINER_EVAL_CONFIG = EvalConfig(
+    n_eval_reconstruction_batches=10,
+    compute_ce_loss=True,
+    n_eval_sparsity_variance_batches=1,
+    compute_l2_norms=True,
+)
+
 
 def _log_feature_sparsity(
     feature_sparsity: torch.Tensor, eps: float = 1e-10
@@ -313,8 +320,7 @@ def _run_and_log_evals(self):
             sae=self.sae,
             activation_store=self.activation_store,
             model=self.model,
-            n_eval_batches=self.cfg.n_eval_batches,
-            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+            eval_config=TRAINER_EVAL_CONFIG,
             model_kwargs=self.cfg.model_kwargs,
         )
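For context, this is roughly how the new entry point is driven end to end. The sketch below is not part of the patch; the model name, SAE release/id, and loading setup are illustrative assumptions (any HookedTransformer, SAE, and ActivationsStore triple that sae_lens can load should work the same way):

    from transformer_lens import HookedTransformer

    from sae_lens import SAE
    from sae_lens.evals import get_eval_everything_config, run_evals
    from sae_lens.training.activations_store import ActivationsStore
    from sae_lens.training.sae_trainer import TRAINER_EVAL_CONFIG

    # Assumed setup: load a model, a pretrained SAE, and a matching activations store.
    model = HookedTransformer.from_pretrained("gpt2")
    sae, _, _ = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
    store = ActivationsStore.from_sae(model=model, sae=sae)

    # Everything on: KL, CE loss, L2 norms, sparsity, and variance metrics.
    metrics = run_evals(
        sae=sae, activation_store=store, model=model,
        eval_config=get_eval_everything_config(),
    )

    # The cheaper subset the trainer logs during training (CE loss + L2 norms only).
    train_metrics = run_evals(
        sae=sae, activation_store=store, model=model,
        eval_config=TRAINER_EVAL_CONFIG,
    )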
diff --git a/tests/unit/training/test_evals.py b/tests/unit/training/test_evals.py
index f2781349..b853d375 100644
--- a/tests/unit/training/test_evals.py
+++ b/tests/unit/training/test_evals.py
@@ -3,9 +3,10 @@
 from transformer_lens import HookedTransformer
 
 from sae_lens.config import LanguageModelSAERunnerConfig
-from sae_lens.evals import run_evals
+from sae_lens.evals import get_eval_everything_config, run_evals
 from sae_lens.sae import SAE
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.training.sae_trainer import TRAINER_EVAL_CONFIG
 from sae_lens.training.training_sae import TrainingSAE
 from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached
 
@@ -93,7 +94,7 @@ def training_sae(cfg: LanguageModelSAERunnerConfig):
     return TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
 
 
-expected_keys = [
+all_expected_keys = [
     "metrics/l2_norm_in",
     "metrics/l2_ratio",
     "metrics/l2_norm_out",
     "metrics/ce_loss_score",
     "metrics/ce_loss_without_sae",
     "metrics/ce_loss_with_sae",
     "metrics/ce_loss_with_ablation",
     "metrics/kl_div_score",
-    "metrics/kl_div_without_sae",
     "metrics/kl_div_with_sae",
     "metrics/kl_div_with_ablation",
 ]
@@ -122,12 +122,11 @@ def test_run_evals_base_sae(
         sae=base_sae,
         activation_store=activation_store,
         model=model,
-        n_eval_batches=2,
-        eval_batch_size_prompts=None,
+        eval_config=get_eval_everything_config(),
     )
 
     # results will be garbage without a real model.
-    for key in expected_keys:
+    for key in all_expected_keys:
         assert key in eval_metrics
 
 
@@ -141,9 +140,46 @@ def test_run_evals_training_sae(
         sae=training_sae,
         activation_store=activation_store,
         model=model,
-        n_eval_batches=10,
-        eval_batch_size_prompts=None,
+        eval_config=get_eval_everything_config(),
     )
 
-    for key in expected_keys:
+    print(eval_metrics)
+    for key in all_expected_keys:
         assert key in eval_metrics
+
+
+def test_run_empty_evals(
+    base_sae: SAE,
+    activation_store: ActivationsStore,
+    model: HookedTransformer,
+):
+    with pytest.raises(ValueError):
+        run_evals(sae=base_sae, activation_store=activation_store, model=model)
+
+
+def test_training_eval_config(
+    base_sae: SAE,
+    activation_store: ActivationsStore,
+    model: HookedTransformer,
+):
+    expected_keys = [
+        "metrics/l2_norm_in",
+        "metrics/l2_ratio",
+        "metrics/l2_norm_out",
+        "metrics/ce_loss_score",
+        "metrics/ce_loss_without_sae",
+        "metrics/ce_loss_with_sae",
+        "metrics/ce_loss_with_ablation",
+    ]
+    eval_config = TRAINER_EVAL_CONFIG
+    eval_metrics = run_evals(
+        sae=base_sae,
+        activation_store=activation_store,
+        model=model,
+        eval_config=eval_config,
+    )
+    sorted_returned_keys = sorted(eval_metrics.keys())
+    sorted_expected_keys = sorted(expected_keys)
+
+    for i in range(len(expected_keys)):
+        assert sorted_returned_keys[i] == sorted_expected_keys[i]
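One behavior of the new config worth spelling out, since the tests above only exercise the all-on, all-off, and trainer presets: every compute_* flag defaults to False, so a bare EvalConfig() makes run_evals raise the ValueError (which is what test_run_empty_evals checks), while a partially enabled config runs only the loops it needs. A hypothetical sparsity-only call, assuming the same sae / activation_store / model objects as in the tests:

    from sae_lens.evals import EvalConfig, run_evals

    # Only the sparsity branch runs; the KL/CE reconstruction batches are skipped.
    sparsity_only = EvalConfig(compute_sparsity_metrics=True)
    metrics = run_evals(
        sae=sae, activation_store=activation_store, model=model,
        eval_config=sparsity_only,
    )
    # Expected keys: "metrics/l0", "metrics/l1", "metrics/total_tokens_evaluated"
    # (note the token count is still derived from n_eval_reconstruction_batches,
    # even when no reconstruction batches were actually run).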