Skip to content

Commit

Permalink
fix_evals_bad_rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Feb 14, 2024
1 parent 736c40e commit 22e415d
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions sae_training/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ def run_evals(
eval_tokens = activation_store.get_batch_tokens()

# Get Reconstruction Score
recons_score, ntp_loss, recons_loss, zero_abl_loss = get_recons_loss(
sparse_autoencoder, model, activation_store
losses_df = recons_loss_batched(
sparse_autoencoder, model, activation_store, n_batches = 10,
)

recons_score = losses_df["score"].mean()
ntp_loss = losses_df["loss"].mean()
recons_loss = losses_df["recons_loss"].mean()
zero_abl_loss = losses_df["zero_abl_loss"].mean()

# get cache
_, cache = model.run_with_cache(
eval_tokens,
Expand Down Expand Up @@ -144,8 +149,9 @@ def head_replacement_hook(activations, hook):
def recons_loss_batched(sparse_autoencoder, model, activation_store, n_batches=100):
losses = []
for _ in tqdm(range(n_batches)):
batch_tokens = activation_store.get_batch_tokens()
score, loss, recons_loss, zero_abl_loss = get_recons_loss(
sparse_autoencoder, model, activation_store
sparse_autoencoder, model, batch_tokens
)
losses.append(
(
Expand Down

0 comments on commit 22e415d

Please sign in to comment.