diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py index 1878fc5d..34fdbd8b 100644 --- a/sae_training/activations_store.py +++ b/sae_training/activations_store.py @@ -156,6 +156,8 @@ def get_batch_tokens(self): def get_activations(self, batch_tokens: torch.Tensor): """ Returns activations of shape (batches, context, num_layers, d_in) + + d_in may result from a concatenated head dimension. """ layers = ( self.cfg.hook_point_layer @@ -174,6 +176,11 @@ def get_activations(self, batch_tokens: torch.Tensor): activations_list = [ act[:, :, self.cfg.hook_point_head_index] for act in activations_list ] + elif activations_list[0].ndim > 3: # if we have a head dimension + # flatten the head dimension + activations_list = [ + act.view(act.shape[0], act.shape[1], -1) for act in activations_list + ] # Stack along a new dimension to keep separate layers distinct stacked_activations = torch.stack(activations_list, dim=2) diff --git a/sae_training/evals.py b/sae_training/evals.py index 770ac689..e5fb278c 100644 --- a/sae_training/evals.py +++ b/sae_training/evals.py @@ -48,19 +48,19 @@ def run_evals( ) # get act - if sparse_autoencoder.cfg.hook_point_head_index is not None: - original_act = cache[sparse_autoencoder.cfg.hook_point][ - :, :, sparse_autoencoder.cfg.hook_point_head_index - ] + if hook_point_head_index is not None: + original_act = cache[hook_point][:, :, hook_point_head_index] + elif "attn" in hook_point: + original_act = cache[hook_point].flatten(-2, -1) else: - original_act = cache[sparse_autoencoder.cfg.hook_point] - - sae_out, _feature_acts, _, _, _, _ = sparse_autoencoder(original_act) - patterns_original = ( - cache[get_act_name("pattern", hook_point_layer)][:, hook_point_head_index] - .detach() - .cpu() - ) + original_act = cache[hook_point] + + sae_out, _, _, _, _, _ = sparse_autoencoder(original_act) + # patterns_original = ( + # cache[get_act_name("pattern", hook_point_layer)][:, hook_point_head_index] + # .detach() + # .cpu() 
+ # ) del cache if "cuda" in str(model.cfg.device): @@ -84,71 +84,83 @@ def run_evals( step=n_training_steps, ) - head_index = sparse_autoencoder.cfg.hook_point_head_index - - def standard_replacement_hook(activations: torch.Tensor, hook: Any): - activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype) - return activations - - def head_replacement_hook(activations: torch.Tensor, hook: Any): - new_actions = sparse_autoencoder.forward(activations[:, :, head_index])[0].to( - activations.dtype - ) - activations[:, :, head_index] = new_actions - return activations - - head_index = sparse_autoencoder.cfg.hook_point_head_index - replacement_hook = ( - standard_replacement_hook if head_index is None else head_replacement_hook - ) + # head_index = sparse_autoencoder.cfg.hook_point_head_index + + # def standard_replacement_hook(activations: torch.Tensor, hook: Any): + # activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype) + # return activations + + # def all_head_replacement_hook(activations: torch.Tensor, hook: Any): + # new_activations = sparse_autoencoder.forward(activations)[0].to( + # activations.dtype + # ) + # activations = new_activations.reshape( + # activations.shape + # ) # reshape to match original shape + # return activations + + # def single_head_replacement_hook(activations: torch.Tensor, hook: Any): + # new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[ + # 0 + # ].to(activations.dtype) + # activations[:, :, head_index] = new_activations + # return activations + + # if "attn" in hook_point: + # if hook_point_head_index is None: + # replacement_hook = all_head_replacement_hook + # else: + # replacement_hook = single_head_replacement_hook + # else: + # replacement_hook = standard_replacement_hook # get attn when using reconstructed activations - with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]): - _, new_cache = model.run_with_cache( - eval_tokens, 
names_filter=[get_act_name("pattern", hook_point_layer)] - ) - patterns_reconstructed = ( - new_cache[get_act_name("pattern", hook_point_layer)][ - :, hook_point_head_index - ] - .detach() - .cpu() - ) - del new_cache - - # get attn when using reconstructed activations - with model.hooks(fwd_hooks=[(hook_point, partial(zero_ablate_hook))]): - _, zero_ablation_cache = model.run_with_cache( - eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)] - ) - patterns_ablation = ( - zero_ablation_cache[get_act_name("pattern", hook_point_layer)][ - :, hook_point_head_index - ] - .detach() - .cpu() - ) - del zero_ablation_cache - - if sparse_autoencoder.cfg.hook_point_head_index: - kl_result_reconstructed = kl_divergence_attention( - patterns_original, patterns_reconstructed - ) - kl_result_reconstructed = kl_result_reconstructed.sum(dim=-1).numpy() - - kl_result_ablation = kl_divergence_attention( - patterns_original, patterns_ablation - ) - kl_result_ablation = kl_result_ablation.sum(dim=-1).numpy() - - if wandb.run is not None: - wandb.log( - { - f"metrics/kldiv_reconstructed{suffix}": kl_result_reconstructed.mean().item(), - f"metrics/kldiv_ablation{suffix}": kl_result_ablation.mean().item(), - }, - step=n_training_steps, - ) + # with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]): + # _, new_cache = model.run_with_cache( + # eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)] + # ) + # patterns_reconstructed = ( + # new_cache[get_act_name("pattern", hook_point_layer)][ + # :, hook_point_head_index + # ] + # .detach() + # .cpu() + # ) + # del new_cache + + # # get attn when using reconstructed activations + # with model.hooks(fwd_hooks=[(hook_point, partial(zero_ablate_hook))]): + # _, zero_ablation_cache = model.run_with_cache( + # eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)] + # ) + # patterns_ablation = ( + # zero_ablation_cache[get_act_name("pattern", hook_point_layer)][ + # :, 
hook_point_head_index + # ] + # .detach() + # .cpu() + # ) + # del zero_ablation_cache + + # if sparse_autoencoder.cfg.hook_point_head_index: + # kl_result_reconstructed = kl_divergence_attention( + # patterns_original, patterns_reconstructed + # ) + # kl_result_reconstructed = kl_result_reconstructed.sum(dim=-1).numpy() + + # kl_result_ablation = kl_divergence_attention( + # patterns_original, patterns_ablation + # ) + # kl_result_ablation = kl_result_ablation.sum(dim=-1).numpy() + + # if wandb.run is not None: + # wandb.log( + # { + # f"metrics/kldiv_reconstructed{suffix}": kl_result_reconstructed.mean().item(), + # f"metrics/kldiv_ablation{suffix}": kl_result_ablation.mean().item(), + # }, + # step=n_training_steps, + # ) def recons_loss_batched( @@ -187,22 +199,36 @@ def get_recons_loss( ): hook_point = sparse_autoencoder.cfg.hook_point loss = model(batch_tokens, return_type="loss") - head_index = sparse_autoencoder.cfg.hook_point_head_index + hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index def standard_replacement_hook(activations: torch.Tensor, hook: Any): activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype) return activations - def head_replacement_hook(activations: torch.Tensor, hook: Any): - new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[ - 0 - ].to(activations.dtype) - activations[:, :, head_index] = new_activations + def all_head_replacement_hook(activations: torch.Tensor, hook: Any): + new_activations = sparse_autoencoder.forward(activations.flatten(-2, -1))[0].to( + activations.dtype + ) + new_activations = new_activations.reshape( + activations.shape + ) # reshape to match original shape + return new_activations + + def single_head_replacement_hook(activations: torch.Tensor, hook: Any): + new_activations = sparse_autoencoder.forward( + activations[:, :, hook_point_head_index] + )[0].to(activations.dtype) + activations[:, :, hook_point_head_index] = new_activations return 
activations - replacement_hook = ( - standard_replacement_hook if head_index is None else head_replacement_hook - ) + if "attn" in hook_point: + if hook_point_head_index is None: + replacement_hook = all_head_replacement_hook + else: + replacement_hook = single_head_replacement_hook + else: + replacement_hook = standard_replacement_hook + recons_loss = model.run_with_hooks( batch_tokens, return_type="loss", diff --git a/scripts/run.ipynb b/scripts/run.ipynb index 8892a8ae..316a0dc2 100644 --- a/scripts/run.ipynb +++ b/scripts/run.ipynb @@ -16,9 +16,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: mps\n" + ] + } + ], "source": [ "import torch\n", "import os\n", @@ -218,9 +226,7 @@ " total_training_tokens=1_000_000 * 800,\n", " store_batch_size=32,\n", " # Resampling protocol\n", - " feature_sampling_method=\"anthropic\",\n", " feature_sampling_window=2000, # Doesn't currently matter.\n", - " feature_reinit_scale=0.2,\n", " dead_feature_window=40000,\n", " dead_feature_threshold=1e-8,\n", " # WANDB\n", @@ -315,86 +321,142 @@ "# Tiny Stories" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MLP Out" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from sae_training.config import LanguageModelSAERunnerConfig\n", + "from sae_training.lm_runner import language_model_sae_runner\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "if device == \"cpu\" and torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distribution)\n", + " model_name=\"tiny-stories-1M\",\n", + " 
hook_point=\"blocks.1.mlp.hook_post\",\n", + " hook_point_layer=1,\n", + " d_in=256,\n", + " # dataset_path=\"roneneldan/TinyStories\",\n", + " # is_dataset_tokenized=False,\n", + " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", + " is_dataset_tokenized=True,\n", + " # SAE Parameters\n", + " expansion_factor=16,\n", + " # Training Parameters\n", + " lr=1e-4,\n", + " lp_norm=1.0,\n", + " l1_coefficient=2e-4,\n", + " train_batch_size=4096,\n", + " context_size=128,\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=128,\n", + " total_training_tokens=1_000_000 * 20,\n", + " store_batch_size=32,\n", + " feature_sampling_window=500, # So we see the histograms. \n", + " dead_feature_window=250,\n", + " # WANDB\n", + " log_to_wandb=True,\n", + " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " wandb_log_frequency=10,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "sparse_autoencoder = language_model_sae_runner(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hook Z\n", + "\n" + ] + }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Run name: 4096-L1-[0.0002, 0.0003, 0.0006]-LR-0.0001-Tokens-1.000e+07\n", + "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", "n_tokens_per_buffer (millions): 0.524288\n", "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 2441\n", - "Total wandb updates: 244\n", + "Total training steps: 4882\n", + "Total wandb updates: 488\n", "n_tokens_per_feature_sampling_window (millions): 262.144\n", "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 4 
times.\n", + "We will reset the sparsity calculation 9 times.\n", "Number tokens in sparsity calculation window: 2.05e+06\n", "Loaded pretrained model tiny-stories-1M into HookedTransformer\n", - "Moving model to device: mps\n", - "Run name: 4096-L1-0.0002-LR-0.0001-Tokens-1.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", - "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 2441\n", - "Total wandb updates: 244\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 4 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Run name: 4096-L1-0.0002-LR-0.0001-Tokens-1.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", - "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 2441\n", - "Total wandb updates: 244\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 4 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Run name: 4096-L1-0.0003-LR-0.0001-Tokens-1.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", - "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 2441\n", - "Total wandb updates: 244\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 4 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Run name: 4096-L1-0.0003-LR-0.0001-Tokens-1.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", - "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 2441\n", - "Total wandb updates: 244\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window 
(millions): 131.072\n", - "We will reset the sparsity calculation 4 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Run name: 4096-L1-0.0006-LR-0.0001-Tokens-1.000e+07\n", + "Moving model to device: mps\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", "n_tokens_per_buffer (millions): 0.524288\n", "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 2441\n", - "Total wandb updates: 244\n", + "Total training steps: 4882\n", + "Total wandb updates: 488\n", "n_tokens_per_feature_sampling_window (millions): 262.144\n", "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 4 times.\n", + "We will reset the sparsity calculation 9 times.\n", "Number tokens in sparsity calculation window: 2.05e+06\n", - "Run name: 4096-L1-0.0006-LR-0.0001-Tokens-1.000e+07\n", + "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", "n_tokens_per_buffer (millions): 0.524288\n", "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 2441\n", - "Total wandb updates: 244\n", + "Total training steps: 4882\n", + "Total wandb updates: 488\n", "n_tokens_per_feature_sampling_window (millions): 262.144\n", "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 4 times.\n", + "We will reset the sparsity calculation 9 times.\n", "Number tokens in sparsity calculation window: 2.05e+06\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mjbloom\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, { "data": { "text/html": [ - "wandb version 0.16.4 is available! To upgrade, please run:\n", + "wandb version 0.16.5 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ @@ -419,7 +481,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /Users/josephbloom/GithubRepositories/mats_sae_training/scripts/wandb/run-20240325_154839-voo84o5b" + "Run data is saved locally in /Users/josephbloom/GithubRepositories/mats_sae_training/scripts/wandb/run-20240326_191703-ec6k6v87" ], "text/plain": [ "" @@ -431,7 +493,7 @@ { "data": { "text/html": [ - "Syncing run 4096-L1-[0.0002, 0.0003, 0.0006]-LR-0.0001-Tokens-1.000e+07 to Weights & Biases (docs)
" + "Syncing run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -455,7 +517,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/voo84o5b" + " View run at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87" ], "text/plain": [ "" @@ -468,106 +530,76 @@ "name": "stderr", "output_type": "stream", "text": [ - "Objective value: 339329.6875: 4%|▍ | 4/100 [00:00<00:04, 22.66it/s]\n", + "Objective value: 116883.7422: 10%|█ | 10/100 [00:00<00:00, 128.72it/s]\n", "/Users/josephbloom/GithubRepositories/mats_sae_training/sae_training/sparse_autoencoder.py:161: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", - "/Users/josephbloom/miniforge3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2171: UserWarning: Run (qy1ho0vw) is finished. The call to `_console_raw_callback` will be ignored. 
Please make sure that you are using an active run.\n", - " lambda data: self._console_raw_callback(\"stderr\", data),\n", - "2442| MSE Loss 0.000 | L1 0.001: : 10002432it [12:04, 13803.94it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.12it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.96it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.22it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.21it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.12it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.12it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.20it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.23it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 5.00it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.71it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.96it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.93it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.23it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.22it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.96it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.13it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.15it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.14it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]\n", - "'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 7e8a4e24-d581-41e5-a4c6-efe32048b72f)')' thrown while requesting GET https://huggingface.co/datasets/apollo-research/roneneldan-TinyStories-tokenizer-gpt2/resolve/bc8db71bbc792977b43d430bddeeb9906e193f8d/data/train-00000-of-00004.parquet\n", - "Retrying in 1s [Retry 1/5].\n", - "'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without 
response')), '(Request ID: d09e549c-a5e6-45c9-aed2-3cc66d288c0b)')' thrown while requesting GET https://huggingface.co/datasets/apollo-research/roneneldan-TinyStories-tokenizer-gpt2/resolve/bc8db71bbc792977b43d430bddeeb9906e193f8d/data/train-00000-of-00004.parquet\n", - "Retrying in 2s [Retry 2/5].\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.79it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.84it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.16it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.20it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.94it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.15it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.18it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.18it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.15it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 5.00it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.74it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.69it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.13it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.75it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.95it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.82it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.12it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.14it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.99it/s]\n", - "100%|██████████| 10/10 [00:05<00:00, 1.90it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.88it/s]\n", - "100%|██████████| 
10/10 [00:01<00:00, 5.15it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.13it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.71it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.87it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.70it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.96it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.83it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.91it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.72it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.11it/s]\n" + "100%|██████████| 10/10 [00:02<00:00, 4.93it/s] 405504/20000000 [00:14<08:53, 36739.57it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 811008/20000000 [00:31<18:45, 17042.47it/s] \n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 1224704/20000000 [00:47<10:43, 29194.89it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]| 1634304/20000000 [01:05<08:10, 37468.33it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.64it/s]| 2039808/20000000 [01:20<07:36, 39322.02it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.08it/s]| 2453504/20000000 [01:37<07:55, 36873.53it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 2863104/20000000 [01:52<07:16, 39292.24it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 3272704/20000000 [02:09<06:52, 40537.06it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 3678208/20000000 [02:26<27:40, 9829.56it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 4087808/20000000 [02:41<06:11, 42798.13it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.50it/s] | 4497408/20000000 [03:01<08:53, 29055.95it/s] \n", 
+ "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 4911104/20000000 [03:16<06:55, 36330.89it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.57it/s] | 5316608/20000000 [03:34<06:31, 37461.30it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.87it/s] | 5726208/20000000 [03:50<05:45, 41309.20it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 6139904/20000000 [04:07<06:03, 38122.10it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.90it/s] | 6549504/20000000 [04:24<05:43, 39198.19it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.91it/s] | 6955008/20000000 [04:43<05:01, 43328.38it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.84it/s] | 7368704/20000000 [05:00<12:14, 17200.22it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 7778304/20000000 [05:14<04:44, 43005.09it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.78it/s] | 8183808/20000000 [05:32<06:31, 30153.11it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.80it/s] | 8597504/20000000 [05:47<04:22, 43375.86it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 5.00it/s] | 9007104/20000000 [06:09<05:16, 34784.52it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.55it/s] | 9416704/20000000 [06:24<04:36, 38252.78it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.75it/s] | 9822208/20000000 [06:42<03:58, 42593.01it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.99it/s] | 10235904/20000000 [06:59<19:05, 8524.91it/s] \n", + "100%|██████████| 10/10 [00:02<00:00, 4.98it/s] | 10645504/20000000 [07:14<03:30, 44384.65it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 11055104/20000000 [07:31<05:24, 27562.66it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.83it/s] | 11464704/20000000 [07:45<03:26, 41316.56it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.81it/s] | 11870208/20000000 [08:02<03:44, 36217.25it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 12279808/20000000 [08:16<02:52, 44715.52it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.85it/s] | 12693504/20000000 
[08:34<03:02, 40061.41it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.02it/s] | 13103104/20000000 [08:48<02:38, 43563.35it/s]\n", + "100%|██████████| 10/10 [00:04<00:00, 2.17it/s] | 13508608/20000000 [09:05<02:34, 41937.09it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 13922304/20000000 [09:24<05:07, 19779.09it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 14327808/20000000 [09:38<02:05, 45367.15it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 14741504/20000000 [09:54<02:49, 30943.53it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.05it/s] | 15147008/20000000 [10:08<01:46, 45610.98it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.06it/s] | 15556608/20000000 [10:24<01:49, 40440.85it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 15966208/20000000 [10:38<01:29, 45251.75it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16379904/20000000 [10:55<01:22, 43941.70it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16789504/20000000 [11:11<04:30, 11859.26it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 17195008/20000000 [11:25<01:02, 44607.68it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.97it/s] | 17608704/20000000 [11:41<01:38, 24188.35it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.00it/s] | 18018304/20000000 [11:54<00:42, 46425.69it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.06it/s]▏| 18423808/20000000 [12:13<00:44, 35420.18it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]▍| 18837504/20000000 [12:27<00:26, 43914.73it/s]\n", + "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]▌| 19243008/20000000 [12:45<00:19, 38931.67it/s]\n", + "100%|██████████| 10/10 [00:02<00:00, 4.95it/s]▊| 19656704/20000000 [12:59<00:07, 43804.93it/s]\n", + "4883| MSE Loss 0.000 | L1 0.000: 100%|█████████▉| 19996672/20000000 [13:14<00:00, 37714.53it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Saved model to 
checkpoints/dchpw62o/final_sae_group_tiny-stories-1M_blocks.1.mlp.hook_post_4096.pt\n" + "Saved model to checkpoints/sf7u2imk/final_sae_group_tiny-stories-1M_blocks.1.attn.hook_z_1024.pt\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2f1e8cf1010b46f9b0fb5d3cd23e5bd3", + "model_id": "1dffd84a387d4cf48100fbe143287481", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Label(value='0.038 MB of 24.149 MB uploaded\\r'), FloatProgress(value=0.0015833850643276347, max…" + "VBox(children=(Label(value='0.053 MB of 0.569 MB uploaded\\r'), FloatProgress(value=0.0935266880101429, max=1.0…" ] }, "metadata": {}, @@ -588,7 +620,7 @@ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", " \n", - "

Run history:


details/current_learning_rate_coeff0.0002▁▂▃▄▅▅▆▇████████████████████████████████
details/current_learning_rate_coeff0.0003▁▂▃▄▅▅▆▇████████████████████████████████
details/current_learning_rate_coeff0.0006▁▂▃▄▅▅▆▇████████████████████████████████
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss_coeff0.0002▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/ghost_grad_loss_coeff0.0003▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/ghost_grad_loss_coeff0.0006▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss_coeff0.0002██▇▆▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss_coeff0.0003██▇▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss_coeff0.0006█▇▆▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss_coeff0.0002█▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss_coeff0.0003█▅▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss_coeff0.0006█▆▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss_coeff0.0002█▆▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss_coeff0.0003█▇▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss_coeff0.0006█▇▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score_coeff0.0002▁▄▅▅▆▆▇▇▇▇▇█████████████
metrics/CE_loss_score_coeff0.0003▁▄▄▄▅▅▆▆▇▇▇▇▇▇▇█████████
metrics/CE_loss_score_coeff0.0006▁▄▂▂▃▄▄▅▅▅▆▆▆▇▇▇▇▇██████
metrics/ce_loss_with_ablation_coeff0.0002▂▃▂▅▃▃▄▇▆▇█▄▅▃▄▇▆▂▁▄▄▄█▂
metrics/ce_loss_with_ablation_coeff0.0003▆▄▁▅█▅▅█▄▄▆▅▇█▇▆▄▇▅▅▅▇▆▅
metrics/ce_loss_with_ablation_coeff0.0006▄▆▅▆▃▄▅▆▇▆▆█▄▅▅▅▆▆▆▅▅▅▁▄
metrics/ce_loss_with_sae_coeff0.0002█▅▄▄▃▃▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁
metrics/ce_loss_with_sae_coeff0.0003█▅▅▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
metrics/ce_loss_with_sae_coeff0.0006█▆▇▇▆▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁
metrics/ce_loss_without_sae_coeff0.0002▃▅▄▃▃▃▅▅▃▅█▂▄▁█▇▄▄▁▂▃▄▅▇
metrics/ce_loss_without_sae_coeff0.0003▃▄▆▅▅▅▆▆▄▁▄▄▅▅█▆▇▃▅▅▆▇▅▆
metrics/ce_loss_without_sae_coeff0.0006▃▃▄▃▁▂▃▃▄█▃▂▅▅▅▃▆▅▃▄▅▄▂▂
metrics/explained_variance_coeff0.0002▁▄▆▇▇▇▇█████████████████████████████████
metrics/explained_variance_coeff0.0003▁▄▆▇▇▇▇▇▇▇██████████████████████████████
metrics/explained_variance_coeff0.0006▁▃▆▆▆▆▆▆▇▇▇▇▇▇▇█████████████████████████
metrics/explained_variance_std_coeff0.0002▇██▆▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/explained_variance_std_coeff0.0003▆██▆▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/explained_variance_std_coeff0.0006▅▇█▆▄▄▅▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
metrics/l0_coeff0.0002███▇▆▅▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0_coeff0.0003███▇▆▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0_coeff0.0006██▇▇▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm_coeff0.0002▂▂▁▁▂▄▄▅▅▆▆▆▇▇▇▇▇▇█▇████
metrics/l2_norm_coeff0.0003▄▃▁▁▃▄▄▅▆▆▆▆▇▇▇▇▇▇▇█████
metrics/l2_norm_coeff0.0006▆▄▁▁▂▃▄▅▅▆▆▆▇▇▇▇▇▇██████
metrics/l2_ratio_coeff0.0002▂▂▁▁▂▄▄▅▅▆▆▆▇▇▇▇▇▇██████
metrics/l2_ratio_coeff0.0003▄▃▁▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇██████
metrics/l2_ratio_coeff0.0006▆▄▁▁▂▃▄▅▅▆▆▆▇▇▇▇▇███████
metrics/mean_log10_feature_sparsity_coeff0.0002█▄▂▁
metrics/mean_log10_feature_sparsity_coeff0.0003█▄▂▁
metrics/mean_log10_feature_sparsity_coeff0.0006█▄▂▁
sparsity/below_1e-5_coeff0.0002▁▁▁█
sparsity/below_1e-5_coeff0.0003▁▁▂█
sparsity/below_1e-5_coeff0.0006▁▁▁█
sparsity/below_1e-6_coeff0.0002▁▁▁▁
sparsity/below_1e-6_coeff0.0003▁▁▁▁
sparsity/below_1e-6_coeff0.0006▁▁▁▁
sparsity/dead_features_coeff0.0002▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/dead_features_coeff0.0003▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁██▅▁▁▁▁▅
sparsity/dead_features_coeff0.0006▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▄▅█▆▆▇
sparsity/mean_passes_since_fired_coeff0.0002▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▅▅▄▇▆▅▇▇██
sparsity/mean_passes_since_fired_coeff0.0003▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▅▆▆▇▆▆▇██
sparsity/mean_passes_since_fired_coeff0.0006▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▅▅▆▆▆▇██

Run summary:


details/current_learning_rate_coeff0.00020.0001
details/current_learning_rate_coeff0.00030.0001
details/current_learning_rate_coeff0.00060.0001
details/n_training_tokens9994240
losses/ghost_grad_loss_coeff0.00020.0
losses/ghost_grad_loss_coeff0.00030.0
losses/ghost_grad_loss_coeff0.00060.0
losses/l1_loss_coeff0.00023.09463
losses/l1_loss_coeff0.00032.58137
losses/l1_loss_coeff0.00061.83998
losses/mse_loss_coeff0.00020.00019
losses/mse_loss_coeff0.00030.00026
losses/mse_loss_coeff0.00060.00053
losses/overall_loss_coeff0.00020.00081
losses/overall_loss_coeff0.00030.00104
losses/overall_loss_coeff0.00060.00163
metrics/CE_loss_score_coeff0.00020.92625
metrics/CE_loss_score_coeff0.00030.88257
metrics/CE_loss_score_coeff0.00060.74837
metrics/ce_loss_with_ablation_coeff0.00027.74356
metrics/ce_loss_with_ablation_coeff0.00037.77858
metrics/ce_loss_with_ablation_coeff0.00067.76302
metrics/ce_loss_with_sae_coeff0.00023.10251
metrics/ce_loss_with_sae_coeff0.00033.29611
metrics/ce_loss_with_sae_coeff0.00063.90444
metrics/ce_loss_without_sae_coeff0.00022.73316
metrics/ce_loss_without_sae_coeff0.00032.70042
metrics/ce_loss_without_sae_coeff0.00062.60843
metrics/explained_variance_coeff0.00020.96407
metrics/explained_variance_coeff0.00030.9494
metrics/explained_variance_coeff0.00060.89856
metrics/explained_variance_std_coeff0.00020.03024
metrics/explained_variance_std_coeff0.00030.04194
metrics/explained_variance_std_coeff0.00060.08105
metrics/l0_coeff0.0002119.72095
metrics/l0_coeff0.000377.5647
metrics/l0_coeff0.000633.18384
metrics/l2_norm_coeff0.00021.39449
metrics/l2_norm_coeff0.00031.36607
metrics/l2_norm_coeff0.00061.28269
metrics/l2_ratio_coeff0.00020.93444
metrics/l2_ratio_coeff0.00030.91607
metrics/l2_ratio_coeff0.00060.86204
metrics/mean_log10_feature_sparsity_coeff0.0002-1.81471
metrics/mean_log10_feature_sparsity_coeff0.0003-2.2457
metrics/mean_log10_feature_sparsity_coeff0.0006-3.13876
sparsity/below_1e-5_coeff0.00021
sparsity/below_1e-5_coeff0.00036
sparsity/below_1e-5_coeff0.000627
sparsity/below_1e-6_coeff0.00020
sparsity/below_1e-6_coeff0.00030
sparsity/below_1e-6_coeff0.00060
sparsity/dead_features_coeff0.00020
sparsity/dead_features_coeff0.00031
sparsity/dead_features_coeff0.00067
sparsity/mean_passes_since_fired_coeff0.00020.23755
sparsity/mean_passes_since_fired_coeff0.00031.10229
sparsity/mean_passes_since_fired_coeff0.00065.02368

" + "

Run history:


details/current_learning_rate▁▃▅▆████████████████████████████████████
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss██▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▅▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▂▃▂▅▃▆▅▃▄▆▇▆▅▇▅▄▇▅▁▆▄▅▆▄█▄▅▆▄▅▅▃▂▄▄▅▅█▆▆
metrics/ce_loss_with_sae█▅▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▁▂▁▂▂▂
metrics/ce_loss_without_sae▄▄▁▃▄▆▅▃▆▅█▆▅▆▅▄▅▆▁▇▆▅▆▃█▆▆▆▄▇▆▃▃▆▃▆▄█▇▅
metrics/explained_variance▁▅▇▇▇███████████████████████████████████
metrics/explained_variance_std██▆▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0██▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▆▆▇▇▇▆▆▆▇█▇▇▆▇▇▆▇▇▇▆▇████▇▇▇▇▇▇▇▇▇█▇▇▇
metrics/l2_ratio▁▃▁▂▄▃▂▄▆▅▅▅▅▆▅▆▇▆▆▇▇▆▆▆▇▆▆▇▆▇▆▇▇▇█▆▆▇▇▇
metrics/mean_log10_feature_sparsity█▇▅▄▃▃▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▂▂▄▇▄▃▅▆▄▇██

Run summary:


details/current_learning_rate0.0001
details/n_training_tokens19988480
losses/ghost_grad_loss0.0
losses/l1_loss1.41017
losses/mse_loss8e-05
losses/overall_loss0.00036
metrics/CE_loss_score0.98362
metrics/ce_loss_with_ablation5.49512
metrics/ce_loss_with_sae2.71813
metrics/ce_loss_without_sae2.67199
metrics/explained_variance0.98647
metrics/explained_variance_std0.00905
metrics/l0166.02246
metrics/l2_norm1.39317
metrics/l2_ratio0.99823
metrics/mean_log10_feature_sparsity-1.53525
sparsity/below_1e-50
sparsity/below_1e-60
sparsity/dead_features0
sparsity/mean_passes_since_fired0.02051

" ], "text/plain": [ "" @@ -600,7 +632,7 @@ { "data": { "text/html": [ - " View run 4096-L1-[0.0002, 0.0003, 0.0006]-LR-0.0001-Tokens-1.000e+07 at: https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/voo84o5b
Synced 7 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)" + " View run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 at: https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87
Synced 7 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" @@ -612,7 +644,7 @@ { "data": { "text/html": [ - "Find logs at: ./wandb/run-20240325_154839-voo84o5b/logs" + "Find logs at: ./wandb/run-20240326_191703-ec6k6v87/logs" ], "text/plain": [ "" @@ -620,6 +652,14 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "4883| MSE Loss 0.000 | L1 0.000: : 20000768it [13:29, 37714.53it/s] /Users/josephbloom/miniforge3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2171: UserWarning: Run (ec6k6v87) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n" + ] } ], "source": [ @@ -637,9 +677,9 @@ "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.mlp.hook_post\",\n", + " hook_point=\"blocks.1.attn.hook_z\",\n", " hook_point_layer=1,\n", - " d_in=256,\n", + " d_in=64,\n", " # dataset_path=\"roneneldan/TinyStories\",\n", " # is_dataset_tokenized=False,\n", " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", @@ -649,12 +689,13 @@ " expansion_factor=16,\n", " # Training Parameters\n", " lr=1e-4,\n", - " l1_coefficient=[2e-4,3e-4,6e-4],\n", + " lp_norm=1.0,\n", + " l1_coefficient=2e-4,\n", " train_batch_size=4096,\n", " context_size=128,\n", " # Activation Store Parameters\n", " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 100,\n", + " total_training_tokens=1_000_000 * 20,\n", " store_batch_size=32,\n", " feature_sampling_window=500, # So we see the histograms. 
\n", " dead_feature_window=250,\n", diff --git a/tests/unit/test_activations_store.py b/tests/unit/test_activations_store.py index f3ffd68c..044fb88a 100644 --- a/tests/unit/test_activations_store.py +++ b/tests/unit/test_activations_store.py @@ -28,6 +28,14 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]: "hook_point_layer": 1, "d_in": 64, }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "tokenized": False, + "hook_point": "blocks.1.attn.hook_z", + "hook_point_layer": 1, + "d_in": 64, + }, { "model_name": "gelu-2l", "dataset_path": "NeelNanda/c4-tokenized-2b", @@ -53,7 +61,13 @@ def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]: "d_in": 768, }, ], - ids=["tiny-stories-1M", "gelu-2l-tokenized", "gpt2-tokenized", "gpt2"], + ids=[ + "tiny-stories-1M-resid-pre", + "tiny-stories-1M-attn-out", + "gelu-2l-tokenized", + "gpt2-tokenized", + "gpt2", + ], ) def cfg(request: pytest.FixtureRequest) -> SimpleNamespace: # This function will be called with each parameter set diff --git a/tests/unit/test_sparse_autoencoder.py b/tests/unit/test_sparse_autoencoder.py index 3b1ab860..9ca02b4f 100644 --- a/tests/unit/test_sparse_autoencoder.py +++ b/tests/unit/test_sparse_autoencoder.py @@ -10,25 +10,56 @@ from sae_training.activations_store import ActivationsStore from sae_training.sparse_autoencoder import SparseAutoencoder -TEST_MODEL = "tiny-stories-1M" -TEST_DATASET = "roneneldan/TinyStories" - -@pytest.fixture -def cfg(): +# Define a new fixture for different configurations +@pytest.fixture( + params=[ + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "tokenized": False, + "hook_point": "blocks.1.hook_resid_pre", + "hook_point_layer": 1, + "d_in": 64, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", + "tokenized": False, + "hook_point": "blocks.1.hook_resid_pre", + "hook_point_layer": 1, + 
"d_in": 64, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "tokenized": False, + "hook_point": "blocks.1.attn.hook_z", + "hook_point_layer": 1, + "d_in": 64, + }, + ], + ids=[ + "tiny-stories-1M-resid-pre", + "tiny-stories-1M-resid-pre-pretokenized", + "tiny-stories-1M-attn-out", + ], +) +def cfg(request: pytest.FixtureRequest) -> SimpleNamespace: """ Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig. """ + params = request.param # Create a mock object with the necessary attributes mock_config = SimpleNamespace() - mock_config.model_name = TEST_MODEL - mock_config.hook_point = "blocks.0.hook_mlp_out" - mock_config.hook_point_layer = 0 + mock_config.model_name = params["model_name"] + mock_config.dataset_path = params["dataset_path"] + mock_config.is_dataset_tokenized = params["tokenized"] + mock_config.hook_point = params["hook_point"] + mock_config.hook_point_layer = params["hook_point_layer"] + mock_config.d_in = params["d_in"] mock_config.hook_point_head_index = None - mock_config.dataset_path = TEST_DATASET - mock_config.is_dataset_tokenized = False mock_config.use_cached_activations = False - mock_config.d_in = 64 mock_config.use_ghost_grads = False mock_config.expansion_factor = 2 mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor @@ -66,12 +97,12 @@ def sparse_autoencoder(cfg: Any): @pytest.fixture -def model(): - return HookedTransformer.from_pretrained(TEST_MODEL) +def model(cfg: SimpleNamespace): + return HookedTransformer.from_pretrained(cfg.model_name, device="cpu") @pytest.fixture -def activation_store(cfg: Any, model: HookedTransformer): +def activation_store(cfg: SimpleNamespace, model: HookedTransformer): return ActivationsStore(cfg, model)