diff --git a/sae_lens/evals.py b/sae_lens/evals.py
index 772f91fb..48fe9d36 100644
--- a/sae_lens/evals.py
+++ b/sae_lens/evals.py
@@ -64,6 +64,8 @@ class EvalConfig:
     compute_l2_norms: bool = False
     compute_sparsity_metrics: bool = False
     compute_variance_metrics: bool = False
+    # compute featurewise density statistics
+    compute_featurewise_density_statistics: bool = False

     library_version: str = field(default_factory=get_library_version)
     git_hash: str = field(default_factory=get_git_hash)
@@ -85,6 +87,7 @@ def get_eval_everything_config(
         n_eval_sparsity_variance_batches=n_eval_sparsity_variance_batches,
         compute_sparsity_metrics=True,
         compute_variance_metrics=True,
+        compute_featurewise_density_statistics=True,
     )


@@ -96,7 +99,7 @@ def run_evals(
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
     ignore_tokens: set[int | None] = set(),
-) -> dict[str, Any]:
+) -> tuple[dict[str, Any], dict[str, Any]]:
     hook_name = sae.cfg.hook_name

     actual_batch_size = (
@@ -111,11 +114,11 @@ def run_evals(
     else:
         previous_hook_z_reshaping_mode = None

-    metrics = {}
+    all_metrics = {}

     if eval_config.compute_kl or eval_config.compute_ce_loss:
         assert eval_config.n_eval_reconstruction_batches > 0
-        metrics |= get_downstream_reconstruction_metrics(
+        all_metrics |= get_downstream_reconstruction_metrics(
             sae,
             model,
             activation_store,
@@ -134,20 +137,24 @@ def run_evals(
         or eval_config.compute_variance_metrics
     ):
         assert eval_config.n_eval_sparsity_variance_batches > 0
-        metrics |= get_sparsity_and_variance_metrics(
+        scalar_metrics, feature_metrics = get_sparsity_and_variance_metrics(
             sae,
             model,
             activation_store,
             compute_l2_norms=eval_config.compute_l2_norms,
             compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
             compute_variance_metrics=eval_config.compute_variance_metrics,
+            compute_featurewise_density_statistics=eval_config.compute_featurewise_density_statistics,
             n_batches=eval_config.n_eval_sparsity_variance_batches,
             eval_batch_size_prompts=actual_batch_size,
             model_kwargs=model_kwargs,
             ignore_tokens=ignore_tokens,
         )
+        all_metrics |= scalar_metrics
+    else:
+        feature_metrics = {}

-    if len(metrics) == 0:
+    if len(all_metrics) == 0:
         raise ValueError(
             "No metrics were computed, please set at least one metric to True."
         )
@@ -171,14 +178,14 @@ def run_evals(
             * actual_batch_size
         )

-    metrics["total_tokens_eval_reconstruction"] = (
+    all_metrics["total_tokens_eval_reconstruction"] = (
         total_tokens_evaluated_eval_reconstruction
     )
-    metrics["total_tokens_eval_sparsity_variance"] = (
+    all_metrics["total_tokens_eval_sparsity_variance"] = (
         total_tokens_evaluated_eval_sparsity_variance
     )

-    return metrics
+    return all_metrics, feature_metrics


 def get_downstream_reconstruction_metrics(
@@ -252,15 +259,18 @@ def get_sparsity_and_variance_metrics(
     compute_l2_norms: bool,
     compute_sparsity_metrics: bool,
     compute_variance_metrics: bool,
+    compute_featurewise_density_statistics: bool,
     eval_batch_size_prompts: int,
     model_kwargs: Mapping[str, Any],
     ignore_tokens: set[int | None] = set(),
-):
+) -> tuple[dict[str, Any], dict[str, Any]]:
     hook_name = sae.cfg.hook_name
     hook_head_index = sae.cfg.hook_head_index

     metric_dict = {}
+    feature_metric_dict = {}
+
     if compute_l2_norms:
         metric_dict["l2_norm_in"] = []
         metric_dict["l2_norm_out"] = []
@@ -273,6 +283,13 @@ def get_sparsity_and_variance_metrics(
         metric_dict["explained_variance"] = []
         metric_dict["mse"] = []
         metric_dict["cossim"] = []
+    if compute_featurewise_density_statistics:
+        feature_metric_dict["feature_density"] = []
+        feature_metric_dict["consistent_activation_heuristic"] = []
+
+    total_feature_acts = torch.zeros(sae.cfg.d_sae, device=sae.device)
+    total_feature_prompts = torch.zeros(sae.cfg.d_sae, device=sae.device)
+    total_tokens = 0

     for _ in range(n_batches):
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
@@ -327,7 +344,9 @@ def get_sparsity_and_variance_metrics(
         )
         flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d")

+        # TODO: Clean this up.
         # apply mask
+        masked_sae_feature_activations = sae_feature_activations * mask.unsqueeze(-1)
         flattened_sae_input = flattened_sae_input[flattened_mask]
         flattened_sae_feature_acts = flattened_sae_feature_acts[flattened_mask]
         flattened_sae_out = flattened_sae_out[flattened_mask]
@@ -383,13 +402,27 @@ def get_sparsity_and_variance_metrics(
             metric_dict["mse"].append(mse)
             metric_dict["cossim"].append(cossim)

+        if compute_featurewise_density_statistics:
+            sae_feature_activations_bool = (masked_sae_feature_activations > 0).float()
+            total_feature_acts += sae_feature_activations_bool.sum(dim=1).sum(dim=0)
+            total_feature_prompts += (sae_feature_activations_bool.sum(dim=1) > 0).sum(
+                dim=0
+            )
+            total_tokens += mask.sum()
+
+    # Aggregate scalar metrics
     metrics: dict[str, float] = {}
     for metric_name, metric_values in metric_dict.items():
-        # since we're masking, we need to flatten but may not have n_ctx for all metrics
-        # in all batches.
         metrics[f"{metric_name}"] = torch.cat(metric_values).mean().item()

-    return metrics
+    # Aggregate feature-wise metrics
+    feature_metrics: dict[str, torch.Tensor] = {}
+    feature_metrics["feature_density"] = (total_feature_acts / total_tokens).tolist()
+    feature_metrics["consistent_activation_heuristic"] = (
+        total_feature_acts / total_feature_prompts
+    ).tolist()
+
+    return metrics, feature_metrics


 @torch.no_grad()
@@ -645,7 +678,7 @@ def multiple_evals(
                 ] = eval_config.library_version
                 eval_metrics["eval_cfg"]["git_hash"] = eval_config.git_hash

-                run_eval_metrics = run_evals(
+                scalar_metrics, feature_metrics = run_evals(
                     sae=sae,
                     activation_store=activation_store,
                     model=current_model,
@@ -656,7 +689,8 @@ def multiple_evals(
                         current_model.tokenizer.bos_token_id,  # type: ignore
                     },
                 )
-                eval_metrics["metrics"] = run_eval_metrics
+                eval_metrics["metrics"] = scalar_metrics
+                eval_metrics["feature_metrics"] = feature_metrics

                 # Add SAE config
                 eval_metrics["sae_cfg"] = sae.cfg.to_dict()
@@ -766,6 +800,8 @@ def process_results(eval_results: list[defaultdict[Any, Any]], output_dir: str):

     args = arg_parser.parse_args()

+    # poetry run python sae_lens/evals.py "sae_bench_pythia70m_sweep_standard.*" "blocks.4.*" --save_path "pythia_70m.csv"
+
     eval_results = run_evaluations(args)
     output_files = process_results(eval_results, args.output_dir)

diff --git a/tests/benchmark/test_eval_all_loadable_saes.py b/tests/benchmark/test_eval_all_loadable_saes.py
index 4e33a908..04a79235 100644
--- a/tests/benchmark/test_eval_all_loadable_saes.py
+++ b/tests/benchmark/test_eval_all_loadable_saes.py
@@ -1,12 +1,22 @@
 # import pandas as pd
 # import plotly.express as px
 # import numpy as np
+import argparse
+import json
+from pathlib import Path
+
 import pytest
 import torch

 from sae_lens import SAE, ActivationsStore
 from sae_lens.analysis.neuronpedia_integration import open_neuronpedia_feature_dashboard
-from sae_lens.evals import all_loadable_saes, get_eval_everything_config, run_evals
+from sae_lens.evals import (
+    all_loadable_saes,
+    get_eval_everything_config,
+    process_results,
+    run_evals,
+    run_evaluations,
+)
 from sae_lens.toolkit.pretrained_sae_loaders import (
     SAEConfigLoadOptions,
     get_sae_config_from_hf,
@@ -149,7 +159,7 @@ def test_eval_all_loadable_saes(
     eval_config = get_eval_everything_config(
         batch_size_prompts=8,
         n_eval_reconstruction_batches=3,
-        n_eval_sparsity_variance_batches=10,
+        n_eval_sparsity_variance_batches=100,
     )

     metrics = run_evals(
@@ -168,3 +178,38 @@ def test_eval_all_loadable_saes(
     assert (
         pytest.approx(metrics["explained_variance"], abs=0.1) == expected_var_explained
     )
+
+
+@pytest.fixture
+def mock_evals_simple_args(tmp_path: Path):
+    class Args:
+        sae_regex_pattern = "gpt2-small-res-jb"
+        sae_block_pattern = "blocks.0.hook_resid_pre"
+        num_eval_batches = 1
+        eval_batch_size_prompts = 2
+        datasets = ["Skylion007/openwebtext"]
+        ctx_lens = [128]
+        output_dir = str(tmp_path)
+
+    return Args()
+
+
+def test_run_evaluations_process_results(mock_evals_simple_args: argparse.Namespace):
+    """
+    This test is more like an acceptance test for the evals code than a benchmark.
+    """
+    eval_results = run_evaluations(mock_evals_simple_args)
+    output_files = process_results(eval_results, mock_evals_simple_args.output_dir)
+
+    print("Evaluation complete. Output files:")
+    print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
+    print(f"Combined JSON: {output_files['combined_json']}")
+    print(f"CSV: {output_files['csv']}")
+
+    # open and validate the files
+    combined_json_path = output_files["combined_json"]
+    assert combined_json_path.exists()
+    with open(combined_json_path, "r") as f:
+        data = json.load(f)[0]
+    assert "metrics" in data
+    assert "feature_metrics" in data
diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py
index ce2f0325..ee4e7805 100644
--- a/tests/unit/test_evals.py
+++ b/tests/unit/test_evals.py
@@ -132,7 +132,7 @@ def test_run_evals_base_sae(
     model: HookedTransformer,
 ):

-    eval_metrics = run_evals(
+    eval_metrics, _ = run_evals(
         sae=base_sae,
         activation_store=activation_store,
         model=model,
@@ -150,7 +150,7 @@ def test_run_evals_training_sae(
     model: HookedTransformer,
 ):

-    eval_metrics = run_evals(
+    eval_metrics, _ = run_evals(
         sae=training_sae,
         activation_store=activation_store,
         model=model,
@@ -167,7 +167,7 @@ def test_run_evals_training_sae_ignore_bos(
     model: HookedTransformer,
 ):

-    eval_metrics = run_evals(
+    eval_metrics, _ = run_evals(
         sae=training_sae,
         activation_store=activation_store,
         model=model,
@@ -209,7 +209,7 @@ def test_training_eval_config(
         "relative_reconstruction_bias",
     ]
     eval_config = TRAINER_EVAL_CONFIG
-    eval_metrics = run_evals(
+    eval_metrics, _ = run_evals(
         sae=base_sae,
         activation_store=activation_store,
         model=model,
@@ -238,7 +238,7 @@ def test_training_eval_config_ignore_control_tokens(
         "relative_reconstruction_bias",
     ]
     eval_config = TRAINER_EVAL_CONFIG
-    eval_metrics = run_evals(
+    eval_metrics, _ = run_evals(
         sae=base_sae,
         activation_store=activation_store,
         model=model,