Skip to content

Commit

Permalink
keep track of tokens used seperately
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus authored and Curt Tigges committed Oct 18, 2024
1 parent 87601ba commit c168c2b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
17 changes: 14 additions & 3 deletions sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class EvalConfig:
library_version: str = field(default_factory=get_library_version)
git_hash: str = field(default_factory=get_git_hash)


def get_eval_everything_config(
batch_size_prompts: int | None = None,
n_eval_reconstruction_batches: int = 10,
Expand Down Expand Up @@ -160,12 +159,24 @@ def run_evals(
elif not previous_hook_z_reshaping_mode and sae.hook_z_reshaping_mode:
sae.turn_off_forward_pass_hook_z_reshaping()

total_tokens_evaluated = (
total_tokens_evaluated_eval_reconstruction = (
activation_store.context_size
* eval_config.n_eval_reconstruction_batches
* actual_batch_size
)
metrics["total_tokens_evaluated"] = total_tokens_evaluated

total_tokens_evaluated_eval_sparsity_variance = (
activation_store.context_size
* eval_config.n_eval_sparsity_variance_batches
* actual_batch_size
)

metrics["total_tokens_eval_reconstruction"] = (
total_tokens_evaluated_eval_reconstruction
)
metrics["total_tokens_eval_sparsity_variance"] = (
total_tokens_evaluated_eval_sparsity_variance
)

return metrics

Expand Down
2 changes: 1 addition & 1 deletion tests/benchmark/test_eval_all_loadable_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_eval_all_loadable_saes(

eval_config = get_eval_everything_config(
batch_size_prompts=8,
n_eval_reconstruction_batches=10,
n_eval_reconstruction_batches=3,
n_eval_sparsity_variance_batches=10,
)

Expand Down

0 comments on commit c168c2b

Please sign in to comment.