Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Mar 28, 2024
1 parent 8aadcd3 commit 67dfb46
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
5 changes: 4 additions & 1 deletion sae_training/activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,9 @@ def _get_next_dataset_tokens(self) -> torch.Tensor:
device=device,
requires_grad=False,
)
-        if not self.cfg.prepend_bos and tokens[0] == self.model.tokenizer.bos_token_id:
+        if (
+            not self.cfg.prepend_bos
+            and tokens[0] == self.model.tokenizer.bos_token_id
+        ):
tokens = tokens[1:]
return tokens
6 changes: 4 additions & 2 deletions sae_training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,17 @@ def __post_init__(self):

# how many times will we sample dead neurons?
# assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
-        n_feature_window_samples = total_training_steps // self.feature_sampling_window
+        n_feature_window_samples = (
+            total_training_steps // self.feature_sampling_window
+        )
print(
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}"
)
print(
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
)
print(
-            f"We will reset the sparsity calculation {n_feature_window_samples} times."
+            f"We will reset the sparsity calculation {n_feature_window_samples} times."
)
# print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
print(
Expand Down
6 changes: 3 additions & 3 deletions scripts/run.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
" dataset_path=\"NeelNanda/c4-tokenized-2b\",\n",
" is_dataset_tokenized=True,\n",
" # SAE Parameters\n",
-    "    expansion_factor=[16,32,64],\n",
+    "    expansion_factor=[16, 32, 64],\n",
" b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n",
" # Training Parameters\n",
" lr=0.0012,\n",
Expand Down Expand Up @@ -368,7 +368,7 @@
" n_batches_in_buffer=128,\n",
" total_training_tokens=1_000_000 * 20,\n",
" store_batch_size=32,\n",
-    "    feature_sampling_window=500,  # So we see the histograms. \n",
+    "    feature_sampling_window=500,  # So we see the histograms.\n",
" dead_feature_window=250,\n",
" # WANDB\n",
" log_to_wandb=True,\n",
Expand Down Expand Up @@ -697,7 +697,7 @@
" n_batches_in_buffer=128,\n",
" total_training_tokens=1_000_000 * 20,\n",
" store_batch_size=32,\n",
-    "    feature_sampling_window=500,  # So we see the histograms. \n",
+    "    feature_sampling_window=500,  # So we see the histograms.\n",
" dead_feature_window=250,\n",
" # WANDB\n",
" log_to_wandb=True,\n",
Expand Down

0 comments on commit 67dfb46

Please sign in to comment.