add feature density histogram to evals + consistent activation heuristic
jbloomAus authored and Curt Tigges committed Oct 18, 2024
1 parent c168c2b commit 9341398
Showing 3 changed files with 102 additions and 21 deletions.
64 changes: 50 additions & 14 deletions sae_lens/evals.py
@@ -64,6 +64,8 @@ class EvalConfig:
compute_l2_norms: bool = False
compute_sparsity_metrics: bool = False
compute_variance_metrics: bool = False
# compute featurewise density statistics
compute_featurewise_density_statistics: bool = False

library_version: str = field(default_factory=get_library_version)
git_hash: str = field(default_factory=get_git_hash)
@@ -85,6 +87,7 @@ def get_eval_everything_config(
n_eval_sparsity_variance_batches=n_eval_sparsity_variance_batches,
compute_sparsity_metrics=True,
compute_variance_metrics=True,
compute_featurewise_density_statistics=True,
)
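
The new flag defaults to False on EvalConfig and is switched on by get_eval_everything_config. A minimal sketch of enabling it, using only fields visible in this diff (batch counts are illustrative):

from sae_lens.evals import EvalConfig, get_eval_everything_config

# Everything-on config; batch counts here are illustrative.
cfg_all = get_eval_everything_config(
    batch_size_prompts=8,
    n_eval_reconstruction_batches=3,
    n_eval_sparsity_variance_batches=10,
)

# Or opt in explicitly on a hand-built config.
cfg = EvalConfig(
    compute_sparsity_metrics=True,
    compute_featurewise_density_statistics=True,
)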


@@ -96,7 +99,7 @@ def run_evals(
eval_config: EvalConfig = EvalConfig(),
model_kwargs: Mapping[str, Any] = {},
ignore_tokens: set[int | None] = set(),
) -> dict[str, Any]:
) -> tuple[dict[str, Any], dict[str, Any]]:

hook_name = sae.cfg.hook_name
actual_batch_size = (
@@ -111,11 +114,11 @@ def run_evals(
else:
previous_hook_z_reshaping_mode = None

metrics = {}
all_metrics = {}

if eval_config.compute_kl or eval_config.compute_ce_loss:
assert eval_config.n_eval_reconstruction_batches > 0
metrics |= get_downstream_reconstruction_metrics(
all_metrics |= get_downstream_reconstruction_metrics(
sae,
model,
activation_store,
@@ -134,20 +137,24 @@
or eval_config.compute_variance_metrics
):
assert eval_config.n_eval_sparsity_variance_batches > 0
metrics |= get_sparsity_and_variance_metrics(
scalar_metrics, feature_metrics = get_sparsity_and_variance_metrics(
sae,
model,
activation_store,
compute_l2_norms=eval_config.compute_l2_norms,
compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
compute_variance_metrics=eval_config.compute_variance_metrics,
compute_featurewise_density_statistics=eval_config.compute_featurewise_density_statistics,
n_batches=eval_config.n_eval_sparsity_variance_batches,
eval_batch_size_prompts=actual_batch_size,
model_kwargs=model_kwargs,
ignore_tokens=ignore_tokens,
)
all_metrics |= scalar_metrics
else:
feature_metrics = {}

if len(metrics) == 0:
if len(all_metrics) == 0:
raise ValueError(
"No metrics were computed, please set at least one metric to True."
)
@@ -171,14 +178,14 @@ def run_evals(
* actual_batch_size
)

metrics["total_tokens_eval_reconstruction"] = (
all_metrics["total_tokens_eval_reconstruction"] = (
total_tokens_evaluated_eval_reconstruction
)
metrics["total_tokens_eval_sparsity_variance"] = (
all_metrics["total_tokens_eval_sparsity_variance"] = (
total_tokens_evaluated_eval_sparsity_variance
)

return metrics
return all_metrics, feature_metrics


def get_downstream_reconstruction_metrics(
@@ -252,15 +259,18 @@ def get_sparsity_and_variance_metrics(
compute_l2_norms: bool,
compute_sparsity_metrics: bool,
compute_variance_metrics: bool,
compute_featurewise_density_statistics: bool,
eval_batch_size_prompts: int,
model_kwargs: Mapping[str, Any],
ignore_tokens: set[int | None] = set(),
):
) -> tuple[dict[str, Any], dict[str, Any]]:

hook_name = sae.cfg.hook_name
hook_head_index = sae.cfg.hook_head_index

metric_dict = {}
feature_metric_dict = {}

if compute_l2_norms:
metric_dict["l2_norm_in"] = []
metric_dict["l2_norm_out"] = []
@@ -273,6 +283,13 @@ def get_sparsity_and_variance_metrics(
metric_dict["explained_variance"] = []
metric_dict["mse"] = []
metric_dict["cossim"] = []
if compute_featurewise_density_statistics:
feature_metric_dict["feature_density"] = []
feature_metric_dict["consistent_activation_heuristic"] = []

total_feature_acts = torch.zeros(sae.cfg.d_sae, device=sae.device)
total_feature_prompts = torch.zeros(sae.cfg.d_sae, device=sae.device)
total_tokens = 0

for _ in range(n_batches):
batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
@@ -327,7 +344,9 @@ def get_sparsity_and_variance_metrics(
)
flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d")

# TODO: Clean this up.
# apply mask
masked_sae_feature_activations = sae_feature_activations * mask.unsqueeze(-1)
flattened_sae_input = flattened_sae_input[flattened_mask]
flattened_sae_feature_acts = flattened_sae_feature_acts[flattened_mask]
flattened_sae_out = flattened_sae_out[flattened_mask]
@@ -383,13 +402,27 @@ def get_sparsity_and_variance_metrics(
metric_dict["mse"].append(mse)
metric_dict["cossim"].append(cossim)

if compute_featurewise_density_statistics:
sae_feature_activations_bool = (masked_sae_feature_activations > 0).float()
total_feature_acts += sae_feature_activations_bool.sum(dim=1).sum(dim=0)
total_feature_prompts += (sae_feature_activations_bool.sum(dim=1) > 0).sum(
dim=0
)
total_tokens += mask.sum()

# Aggregate scalar metrics
metrics: dict[str, float] = {}
for metric_name, metric_values in metric_dict.items():
# since we're masking, we need to flatten but may not have n_ctx for all metrics
# in all batches.
metrics[f"{metric_name}"] = torch.cat(metric_values).mean().item()

return metrics
# Aggregate feature-wise metrics
feature_metrics: dict[str, torch.Tensor] = {}
feature_metrics["feature_density"] = (total_feature_acts / total_tokens).tolist()
feature_metrics["consistent_activation_heuristic"] = (
total_feature_acts / total_feature_prompts
).tolist()

return metrics, feature_metrics
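
Here, feature_density is the fraction of masked tokens on which a feature fires (activation > 0), and the consistent activation heuristic is the average number of firings per prompt among the prompts in which the feature fires at all. A self-contained sketch of the same arithmetic on a toy activation tensor (all names here are illustrative, not library API):

import torch

# Toy activations: 2 prompts, 4 tokens each, 3 features; shape (batch, ctx, d_sae).
acts = torch.tensor([
    [[0.0, 1.0, 0.0], [0.5, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.5, 0.0]],
])

fired = (acts > 0).float()
total_feature_acts = fired.sum(dim=1).sum(dim=0)            # firings per feature
total_feature_prompts = (fired.sum(dim=1) > 0).sum(dim=0)   # prompts where each feature fired
total_tokens = acts.shape[0] * acts.shape[1]

feature_density = total_feature_acts / total_tokens
consistent_activation_heuristic = total_feature_acts / total_feature_prompts

print(feature_density)                  # tensor([0.1250, 0.5000, 0.0000])
print(consistent_activation_heuristic)  # tensor([1., 2., nan]); nan for never-active features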


@torch.no_grad()
@@ -645,7 +678,7 @@ def multiple_evals(
] = eval_config.library_version
eval_metrics["eval_cfg"]["git_hash"] = eval_config.git_hash

run_eval_metrics = run_evals(
scalar_metrics, feature_metrics = run_evals(
sae=sae,
activation_store=activation_store,
model=current_model,
Expand All @@ -656,7 +689,8 @@ def multiple_evals(
current_model.tokenizer.bos_token_id, # type: ignore
},
)
eval_metrics["metrics"] = run_eval_metrics
eval_metrics["metrics"] = scalar_metrics
eval_metrics["feature_metrics"] = feature_metrics

# Add SAE config
eval_metrics["sae_cfg"] = sae.cfg.to_dict()
@@ -766,6 +800,8 @@ def process_results(eval_results: list[defaultdict[Any, Any]], output_dir: str):

args = arg_parser.parse_args()

# poetry run python sae_lens/evals.py "sae_bench_pythia70m_sweep_standard.*" "blocks.4.*" --save_path "pythia_70m.csv"

eval_results = run_evaluations(args)
output_files = process_results(eval_results, args.output_dir)

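With these changes run_evals returns a pair of dictionaries rather than one. A sketch of the updated call site, assuming sae, activation_store, and model are already constructed elsewhere:

from sae_lens.evals import EvalConfig, run_evals

scalar_metrics, feature_metrics = run_evals(
    sae=sae,
    activation_store=activation_store,
    model=model,
    eval_config=EvalConfig(
        n_eval_sparsity_variance_batches=2,
        compute_sparsity_metrics=True,
        compute_featurewise_density_statistics=True,
    ),
)
print(scalar_metrics["total_tokens_eval_sparsity_variance"])
print(len(feature_metrics["feature_density"]))  # one value per SAE feature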
49 changes: 47 additions & 2 deletions tests/benchmark/test_eval_all_loadable_saes.py
@@ -1,12 +1,22 @@
# import pandas as pd
# import plotly.express as px
# import numpy as np
import argparse
import json
from pathlib import Path

import pytest
import torch

from sae_lens import SAE, ActivationsStore
from sae_lens.analysis.neuronpedia_integration import open_neuronpedia_feature_dashboard
from sae_lens.evals import all_loadable_saes, get_eval_everything_config, run_evals
from sae_lens.evals import (
all_loadable_saes,
get_eval_everything_config,
process_results,
run_evals,
run_evaluations,
)
from sae_lens.toolkit.pretrained_sae_loaders import (
SAEConfigLoadOptions,
get_sae_config_from_hf,
@@ -149,7 +159,7 @@ def test_eval_all_loadable_saes(
eval_config = get_eval_everything_config(
batch_size_prompts=8,
n_eval_reconstruction_batches=3,
n_eval_sparsity_variance_batches=10,
n_eval_sparsity_variance_batches=100,
)

metrics = run_evals(
@@ -168,3 +178,38 @@
assert (
pytest.approx(metrics["explained_variance"], abs=0.1) == expected_var_explained
)


@pytest.fixture
def mock_evals_simple_args(tmp_path: Path):
class Args:
sae_regex_pattern = "gpt2-small-res-jb"
sae_block_pattern = "blocks.0.hook_resid_pre"
num_eval_batches = 1
eval_batch_size_prompts = 2
datasets = ["Skylion007/openwebtext"]
ctx_lens = [128]
output_dir = str(tmp_path)

return Args()


def test_run_evaluations_process_results(mock_evals_simple_args: argparse.Namespace):
"""
This test is more like an acceptance test for the evals code than a benchmark.
"""
eval_results = run_evaluations(mock_evals_simple_args)
output_files = process_results(eval_results, mock_evals_simple_args.output_dir)

print("Evaluation complete. Output files:")
print(f"Individual JSONs: {len(output_files['individual_jsons'])}") # type: ignore
print(f"Combined JSON: {output_files['combined_json']}")
print(f"CSV: {output_files['csv']}")

# open and validate the files
combined_json_path = output_files["combined_json"]
assert combined_json_path.exists()
with open(combined_json_path, "r") as f:
data = json.load(f)[0]
assert "metrics" in data
assert "feature_metrics" in data
10 changes: 5 additions & 5 deletions tests/unit/test_evals.py
@@ -132,7 +132,7 @@ def test_run_evals_base_sae(
model: HookedTransformer,
):

eval_metrics = run_evals(
eval_metrics, _ = run_evals(
sae=base_sae,
activation_store=activation_store,
model=model,
@@ -150,7 +150,7 @@ def test_run_evals_training_sae(
model: HookedTransformer,
):

eval_metrics = run_evals(
eval_metrics, _ = run_evals(
sae=training_sae,
activation_store=activation_store,
model=model,
@@ -167,7 +167,7 @@ def test_run_evals_training_sae_ignore_bos(
model: HookedTransformer,
):

eval_metrics = run_evals(
eval_metrics, _ = run_evals(
sae=training_sae,
activation_store=activation_store,
model=model,
@@ -209,7 +209,7 @@ def test_training_eval_config(
"relative_reconstruction_bias",
]
eval_config = TRAINER_EVAL_CONFIG
eval_metrics = run_evals(
eval_metrics, _ = run_evals(
sae=base_sae,
activation_store=activation_store,
model=model,
@@ -238,7 +238,7 @@ def test_training_eval_config_ignore_control_tokens(
"relative_reconstruction_bias",
]
eval_config = TRAINER_EVAL_CONFIG
eval_metrics = run_evals(
eval_metrics, _ = run_evals(
sae=base_sae,
activation_store=activation_store,
model=model,
