Organized basic eval metrics and eliminated NaNs
Curt Tigges authored and committed Oct 18, 2024
1 parent 97622b5 commit f6be1a6
Showing 2 changed files with 101 additions and 48 deletions.
92 changes: 74 additions & 18 deletions sae_lens/evals.py
@@ -1,13 +1,14 @@
import argparse
import json
import math
import re
import subprocess
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Any, Mapping
from typing import Any, Dict, List, Mapping, Union

import einops
import pandas as pd
@@ -120,11 +121,18 @@ def run_evals(
else:
previous_hook_z_reshaping_mode = None

all_metrics = {}
all_metrics = {
"model_behavior_preservation": {},
"model_performance_preservation": {},
"reconstruction_quality": {},
"shrinkage": {},
"sparsity": {},
"token_stats": {},
}

if eval_config.compute_kl or eval_config.compute_ce_loss:
assert eval_config.n_eval_reconstruction_batches > 0
all_metrics |= get_downstream_reconstruction_metrics(
reconstruction_metrics = get_downstream_reconstruction_metrics(
sae,
model,
activation_store,
@@ -135,6 +143,21 @@ def run_evals(
ignore_tokens=ignore_tokens,
verbose=verbose,
)

if eval_config.compute_kl:
all_metrics["model_behavior_preservation"].update({
"kl_div_score": reconstruction_metrics["kl_div_score"],
"kl_div_with_ablation": reconstruction_metrics["kl_div_with_ablation"],
"kl_div_with_sae": reconstruction_metrics["kl_div_with_sae"],
})

if eval_config.compute_ce_loss:
all_metrics["model_performance_preservation"].update({
"ce_loss_score": reconstruction_metrics["ce_loss_score"],
"ce_loss_with_ablation": reconstruction_metrics["ce_loss_with_ablation"],
"ce_loss_with_sae": reconstruction_metrics["ce_loss_with_sae"],
"ce_loss_without_sae": reconstruction_metrics["ce_loss_without_sae"],
})

activation_store.reset_input_dataset()
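The hunk above replaces the old flat `all_metrics |= ...` merge with per-category grouping: each `compute_*` flag gates which flat keys from the reconstruction metrics are copied into its named group. A minimal sketch of the pattern, using placeholder numbers rather than real eval output:

```python
# Sketch of the grouping pattern above; values are placeholders,
# not real eval output.
reconstruction_metrics = {
    "kl_div_score": 0.98,
    "kl_div_with_ablation": 10.0,
    "kl_div_with_sae": 0.2,
    "ce_loss_score": 0.99,
    "ce_loss_with_ablation": 12.1,
    "ce_loss_with_sae": 3.3,
    "ce_loss_without_sae": 3.2,
}

all_metrics = {
    "model_behavior_preservation": {},
    "model_performance_preservation": {},
}

# Each compute_* flag gates which flat keys are copied into its group.
kl_keys = ("kl_div_score", "kl_div_with_ablation", "kl_div_with_sae")
all_metrics["model_behavior_preservation"].update(
    {k: reconstruction_metrics[k] for k in kl_keys}
)
```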

@@ -144,7 +167,7 @@
or eval_config.compute_variance_metrics
):
assert eval_config.n_eval_sparsity_variance_batches > 0
scalar_metrics, feature_metrics = get_sparsity_and_variance_metrics(
sparsity_variance_metrics, feature_metrics = get_sparsity_and_variance_metrics(
sae,
model,
activation_store,
@@ -158,14 +181,32 @@
ignore_tokens=ignore_tokens,
verbose=verbose,
)
all_metrics |= scalar_metrics

if eval_config.compute_l2_norms:
all_metrics["shrinkage"].update({
"l2_norm_in": sparsity_variance_metrics["l2_norm_in"],
"l2_norm_out": sparsity_variance_metrics["l2_norm_out"],
"l2_ratio": sparsity_variance_metrics["l2_ratio"],
"relative_reconstruction_bias": sparsity_variance_metrics["relative_reconstruction_bias"],
})

if eval_config.compute_sparsity_metrics:
all_metrics["sparsity"].update({
"l0": sparsity_variance_metrics["l0"],
"l1": sparsity_variance_metrics["l1"],
})

if eval_config.compute_variance_metrics:
all_metrics["reconstruction_quality"].update({
"explained_variance": sparsity_variance_metrics["explained_variance"],
"mse": sparsity_variance_metrics["mse"],
"cossim": sparsity_variance_metrics["cossim"],
})
else:
feature_metrics = {}

if eval_config.compute_featurewise_weight_based_metrics:
feature_metrics |= get_featurewise_weight_based_metrics(
sae,
)
feature_metrics |= get_featurewise_weight_based_metrics(sae)

if len(all_metrics) == 0:
raise ValueError(
@@ -191,12 +232,13 @@ def run_evals(
* actual_batch_size
)

all_metrics["total_tokens_eval_reconstruction"] = (
total_tokens_evaluated_eval_reconstruction
)
all_metrics["total_tokens_eval_sparsity_variance"] = (
total_tokens_evaluated_eval_sparsity_variance
)
all_metrics["token_stats"] = {
"total_tokens_eval_reconstruction": total_tokens_evaluated_eval_reconstruction,
"total_tokens_eval_sparsity_variance": total_tokens_evaluated_eval_sparsity_variance,
}

# Remove empty metric groups
all_metrics = {k: v for k, v in all_metrics.items() if v}

return all_metrics, feature_metrics
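With the `token_stats` group and the empty-group filter in place, the returned metrics are nested one level deep by category. A hedged sketch of the resulting shape, with placeholder numbers:

```python
# Hypothetical run_evals output after this change: one sub-dict per
# category; groups whose flags were all disabled stay empty and are
# dropped by the comprehension shown above.
all_metrics = {
    "reconstruction_quality": {"explained_variance": 0.92, "mse": 0.004, "cossim": 0.97},
    "shrinkage": {},  # e.g. compute_l2_norms was False
    "sparsity": {"l0": 64.0, "l1": 102.5},
    "token_stats": {
        "total_tokens_eval_reconstruction": 40960,
        "total_tokens_eval_sparsity_variance": 40960,
    },
}
all_metrics = {k: v for k, v in all_metrics.items() if v}
assert "shrinkage" not in all_metrics
```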

@@ -781,12 +823,26 @@ def run_evaluations(args: argparse.Namespace) -> list[defaultdict[Any, Any]]:
return eval_results


def process_results(eval_results: list[defaultdict[Any, Any]], output_dir: str):
def replace_nans_with_negative_one(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: replace_nans_with_negative_one(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [replace_nans_with_negative_one(item) for item in obj]
elif isinstance(obj, float) and math.isnan(obj):
return -1
else:
return obj


def process_results(eval_results: List[Dict[str, Any]], output_dir: str) -> Dict[str, Union[List[Path], Path]]:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)

# Replace NaNs with -1 in each result
cleaned_results = [replace_nans_with_negative_one(result) for result in eval_results]

# Save individual JSON files
for result in eval_results:
for result in cleaned_results:
json_filename = f"{result['unique_id']}_{result['eval_cfg']['context_size']}_{result['eval_cfg']['dataset']}.json".replace(
"/", "_"
)
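The filename template flattens any `/` in the SAE id or dataset path so the result is a valid filename. A sketch with hypothetical values, purely to show the sanitization:

```python
# Hypothetical inputs, purely to illustrate the "/" replacement.
result = {
    "unique_id": "example-org/example-sae",
    "eval_cfg": {"context_size": 128, "dataset": "example-org/example-dataset"},
}
json_filename = (
    f"{result['unique_id']}_{result['eval_cfg']['context_size']}"
    f"_{result['eval_cfg']['dataset']}.json"
).replace("/", "_")
assert json_filename == "example-org_example-sae_128_example-org_example-dataset.json"
```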
@@ -796,10 +852,10 @@ def process_results(eval_results: list[defaultdict[Any, Any]], output_dir: str):

# Save all results in a single JSON file
with open(output_path / "all_eval_results.json", "w") as f:
json.dump(eval_results, f, indent=2)
json.dump(cleaned_results, f, indent=2)

# Convert to DataFrame and save as CSV
df = pd.json_normalize(eval_results) # type: ignore
df = pd.json_normalize(cleaned_results)
df.to_csv(output_path / "all_eval_results.csv", index=False)

return {
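The NaN handling is the second half of the commit message: `replace_nans_with_negative_one` walks dicts and lists recursively so that `json.dump` and `pd.json_normalize` never see a NaN (strict JSON has no NaN literal). A usage sketch, assuming this version of `sae_lens.evals` is importable:

```python
from sae_lens.evals import replace_nans_with_negative_one

# NaNs at any nesting depth become -1; everything else passes through.
result = {
    "sparsity": {"l0": float("nan"), "l1": 102.5},
    "per_batch": [1.0, float("nan")],
}
cleaned = replace_nans_with_negative_one(result)
assert cleaned == {"sparsity": {"l0": -1, "l1": 102.5}, "per_batch": [1.0, -1]}
```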
57 changes: 27 additions & 30 deletions tutorials/evaluating_saes_with_sae_lens_evals.ipynb
@@ -123,13 +123,6 @@
"- `n_eval_reconstruction_batches 20`: Number of prompt batches to use for reconstruction evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
@@ -214,36 +207,42 @@
"source": [
"### Understanding SAE Metrics\n",
"\n",
"The SAE metrics we compute can be grouped into five distinct categories. Let's explore each category and its associated metrics in detail. We'll then compare our three SAEs on each metric.\n",
"The SAE metrics we compute can be grouped into six distinct categories. Let's explore each category and its associated metrics in detail. We'll then compare our three SAEs on each metric.\n",
"\n",
"\n",
"#### Model Performance Preservation\n",
"\n",
"These metrics indicate how much the SAE affects the underlying model's performance and output:\n",
"\n",
"- **ce_loss_with_sae**: Cross-entropy loss of the model's output after applying the SAE.\n",
"- **ce_loss_without_sae**: Baseline cross-entropy loss of the original model without the SAE.\n",
"- **ce_loss_with_ablation**: Cross-entropy loss when the relevant activations are set to zero.\n",
"- **ce_loss_score**: A derived metric comparing the cross-entropy losses with and without the SAE.\n",
"\n",
"#### Model Behavior Preservation\n",
"\n",
"These metrics indicate differences between the distributions of logit predictions with and without the SAE.\n",
"\n",
"- **kl_div_score**: A derived metric comparing the KL divergence with SAE to the KL divergence with ablation.\n",
"- **kl_div_with_sae**: Kullback-Leibler divergence between the original model's output distribution and the distribution after applying the SAE. Lower values indicate better preservation of the model's behavior.\n",
"- **kl_div_with_ablation**: KL divergence between the original model's output and the output when the relevant activations are set to zero. This serves as a baseline for comparison.\n",
"\n",
"#### Reconstruction Quality\n",
"\n",
"These metrics assess how well the SAE can reconstruct the original input:\n",
"These metrics assess how well the SAE can reconstruct the original input activations at the target layer:\n",
"\n",
"- **l2_ratio**: The ratio of the L2 norm of the SAE output to the L2 norm of the SAE input. A value close to 1 indicates good preservation of the input's magnitude.\n",
"- **relative_reconstruction_bias**: Measures the bias in reconstruction. Values closer to 1 indicate less bias.\n",
"- **explained_variance**: The proportion of variance in the input that is explained by the SAE's reconstruction. Higher values indicate better reconstruction quality.\n",
"- **mse**: Mean Squared Error between the input and the reconstruction. Lower values indicate better reconstruction accuracy.\n",
"- **explained_variance**: The proportion of variance in the input that is explained by the SAE's reconstruction. Higher values indicate better reconstruction quality.\n",
"- **cossim**: Cosine similarity between the input and the reconstruction. Values closer to 1 indicate better preservation of the input's direction.\n",
"\n",
"#### Magnitude Preservation\n",
"#### Shrinkage\n",
"\n",
"These metrics show how well the SAE preserves the overall magnitude of the input:\n",
"\n",
"- **l2_norm_in**: The L2 norm (Euclidean norm) of the input activations.\n",
"- **l2_norm_out**: The L2 norm of the output activations after passing through the SAE.\n",
"\n",
"#### Model Behavior Preservation\n",
"\n",
"These metrics indicate how much the SAE affects the underlying model's performance and output distributions:\n",
"\n",
"- **kl_div_with_sae**: Kullback-Leibler divergence between the original model's output distribution and the distribution after applying the SAE. Lower values indicate better preservation of the model's behavior.\n",
"- **kl_div_with_ablation**: KL divergence between the original model's output and the output when the relevant activations are set to zero. This serves as a baseline for comparison.\n",
"- **kl_div_score**: A derived metric comparing the KL divergence with SAE to the KL divergence with ablation.\n",
"- **ce_loss_with_sae**: Cross-entropy loss of the model's output after applying the SAE.\n",
"- **ce_loss_without_sae**: Baseline cross-entropy loss of the original model without the SAE.\n",
"- **ce_loss_with_ablation**: Cross-entropy loss when the relevant activations are set to zero.\n",
"- **ce_loss_score**: A derived metric comparing the cross-entropy losses with and without the SAE.\n",
"- **l2_ratio**: The ratio of the L2 norm of the SAE output to the L2 norm of the SAE input. A value close to 1 indicates good preservation of the input's magnitude.\n",
"- **relative_reconstruction_bias**: Measures the bias in reconstruction. Values closer to 1 indicate less bias.\n",
"\n",
"#### Sparsity\n",
"\n",
@@ -252,7 +251,7 @@
"- **l0**: The L0 \"norm\" of the SAE's activations, which is the number of non-zero elements. It measures how many features are active.\n",
"- **l1**: The L1 norm of the SAE's activations, which is the sum of the absolute values. It's another measure of sparsity.\n",
"\n",
"#### Evaluation Scale\n",
"#### Token Statistics\n",
"\n",
"These metrics provide context about the amount of data used in the evaluation:\n",
"\n",
@@ -574,9 +573,7 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"S"
]
"source": []
}
],
"metadata": {
@@ -595,7 +592,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.4"
}
},
"nbformat": 4,
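The metric descriptions in the notebook above map onto short formulas. A simplified PyTorch sketch, assuming `x` is a batch of input activations, `x_hat` the SAE reconstruction, and `acts` the SAE's feature activations; these are illustrative definitions, not the exact sae_lens implementations:

```python
import torch

# Toy tensors standing in for real activations (batch of 8, d_model 512,
# 4096 SAE features); values are random, purely for illustration.
x = torch.randn(8, 512)
x_hat = x + 0.1 * torch.randn(8, 512)
acts = torch.relu(torch.randn(8, 4096)) * (torch.rand(8, 4096) < 0.02)

mse = ((x_hat - x) ** 2).mean()
explained_variance = 1 - ((x_hat - x).var(dim=-1) / x.var(dim=-1)).mean()
cossim = torch.nn.functional.cosine_similarity(x, x_hat, dim=-1).mean()
l2_ratio = (x_hat.norm(dim=-1) / x.norm(dim=-1)).mean()  # ~1 means little shrinkage
l0 = (acts != 0).float().sum(dim=-1).mean()  # avg. active features per token
l1 = acts.abs().sum(dim=-1).mean()
```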
