Skip to content

Commit

Permalink
add verbose mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus authored and Curt Tigges committed Oct 18, 2024
1 parent bc17fa5 commit 15f1b59
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def run_evals(
eval_config: EvalConfig = EvalConfig(),
model_kwargs: Mapping[str, Any] = {},
ignore_tokens: set[int | None] = set(),
verbose: bool = False,
) -> tuple[dict[str, Any], dict[str, Any]]:

hook_name = sae.cfg.hook_name
Expand Down Expand Up @@ -132,6 +133,7 @@ def run_evals(
n_batches=eval_config.n_eval_reconstruction_batches,
eval_batch_size_prompts=actual_batch_size,
ignore_tokens=ignore_tokens,
verbose=verbose,
)

activation_store.reset_input_dataset()
Expand All @@ -154,6 +156,7 @@ def run_evals(
eval_batch_size_prompts=actual_batch_size,
model_kwargs=model_kwargs,
ignore_tokens=ignore_tokens,
verbose=verbose,
)
all_metrics |= scalar_metrics
else:
Expand Down Expand Up @@ -207,8 +210,8 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
encoder_bias = sae.b_enc.cpu().tolist()
encoder_decoder_cosine_sim = (
torch.nn.functional.cosine_similarity(
- unit_norm_decoder,
- unit_norm_encoders,
+ unit_norm_decoder.T,
+ unit_norm_encoders.T,
)
.cpu()
.tolist()
Expand All @@ -230,6 +233,7 @@ def get_downstream_reconstruction_metrics(
n_batches: int,
eval_batch_size_prompts: int,
ignore_tokens: set[int | None] = set(),
verbose: bool = False,
):
metrics_dict = {}
if compute_kl:
Expand All @@ -240,7 +244,11 @@ def get_downstream_reconstruction_metrics(
metrics_dict["ce_loss_without_sae"] = []
metrics_dict["ce_loss_with_ablation"] = []

- for _ in range(n_batches):
+ batch_iter = range(n_batches)
+ if verbose:
+ batch_iter = tqdm(batch_iter, desc="Reconstruction Batches")
+
+ for _ in batch_iter:
batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
for metric_name, metric_value in get_recons_loss(
sae,
Expand Down Expand Up @@ -296,6 +304,7 @@ def get_sparsity_and_variance_metrics(
eval_batch_size_prompts: int,
model_kwargs: Mapping[str, Any],
ignore_tokens: set[int | None] = set(),
verbose: bool = False,
) -> tuple[dict[str, Any], dict[str, Any]]:

hook_name = sae.cfg.hook_name
Expand Down Expand Up @@ -324,7 +333,11 @@ def get_sparsity_and_variance_metrics(
total_feature_prompts = torch.zeros(sae.cfg.d_sae, device=sae.device)
total_tokens = 0

- for _ in range(n_batches):
+ batch_iter = range(n_batches)
+ if verbose:
+ batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches")
+
+ for _ in batch_iter:
batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)

if len(ignore_tokens) > 0:
Expand Down Expand Up @@ -656,6 +669,7 @@ def multiple_evals(
datasets: list[str] = ["Skylion007/openwebtext", "lighteval/MATH"],
ctx_lens: list[int] = [128],
output_dir: str = "eval_results",
verbose: bool = False,
) -> list[defaultdict[Any, Any]]:

device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -722,6 +736,7 @@ def multiple_evals(
current_model.tokenizer.eos_token_id, # type: ignore
current_model.tokenizer.bos_token_id, # type: ignore
},
verbose=verbose,
)
eval_metrics["metrics"] = scalar_metrics
eval_metrics["feature_metrics"] = feature_metrics
Expand Down Expand Up @@ -757,6 +772,7 @@ def run_evaluations(args: argparse.Namespace) -> list[defaultdict[Any, Any]]:
datasets=args.datasets,
ctx_lens=args.ctx_lens,
output_dir=args.output_dir,
verbose=args.verbose,
)

return eval_results
Expand Down Expand Up @@ -874,6 +890,11 @@ def process_results(eval_results: list[defaultdict[Any, Any]], output_dir: str):
default="eval_results",
help="Directory to save evaluation results",
)
arg_parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose output with tqdm loaders.",
)

args = arg_parser.parse_args()

Expand Down

0 comments on commit 15f1b59

Please sign in to comment.