-
Notifications
You must be signed in to change notification settings - Fork 133
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
379 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,379 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Hack Day Analysis\n", | ||
"\n", | ||
"- Try loading our models from last night\n", | ||
"- Try out tiny stories models. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Get our model from last night" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", | ||
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mjbloom\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"Tracking run with wandb version 0.16.0" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"Run data is saved locally in <code>/home/paperspace/mats_sae_training/wandb/run-20231202_095753-nh2ujgd7</code>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
"Syncing run <strong><a href='https://wandb.ai/jbloom/mats_sae_training/runs/nh2ujgd7' target=\"_blank\">prime-snowflake-2</a></strong> to <a href='https://wandb.ai/jbloom/mats_sae_training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
" View project at <a href='https://wandb.ai/jbloom/mats_sae_training' target=\"_blank\">https://wandb.ai/jbloom/mats_sae_training</a>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/html": [ | ||
" View run at <a href='https://wandb.ai/jbloom/mats_sae_training/runs/nh2ujgd7' target=\"_blank\">https://wandb.ai/jbloom/mats_sae_training/runs/nh2ujgd7</a>" | ||
], | ||
"text/plain": [ | ||
"<IPython.core.display.HTML object>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact sae:v1, 128.13MB. 1 files... \n", | ||
"\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", | ||
"Done. 0:0:0.4\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import wandb\n", | ||
"run = wandb.init()\n", | ||
"artifact = run.use_artifact('jbloom/mats_sae_training_language_models/sae:v1', type='model')\n", | ||
"artifact_dir = artifact.download()\n", | ||
"wandb.finish()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"torch.Size([512, 32768])" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"weights = torch.load(artifact_dir + '/sae.pt')\n", | ||
"weights[\"W_enc\"].shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 40, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<All keys matched successfully>" | ||
] | ||
}, | ||
"execution_count": 40, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"import torch \n", | ||
"from dataclasses import dataclass\n", | ||
"from sae_training.SAE import SAE\n", | ||
"\n", | ||
"@dataclass\n", | ||
"class SAEConfig():\n", | ||
" d_in = weights[\"W_enc\"].shape[0]\n", | ||
" d_sae = weights[\"W_enc\"].shape[1]\n", | ||
" device = \"cuda\"\n", | ||
" dtype = torch.float32\n", | ||
" l1_coefficient = 0.0001\n", | ||
"\n", | ||
"sae_cfg = SAEConfig()\n", | ||
"SAE = SAE(sae_cfg)\n", | ||
"SAE.load_state_dict(weights)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 24, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Loaded pretrained model gelu-2l into HookedTransformer\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from transformer_lens import HookedTransformer\n", | ||
"\n", | ||
"model = HookedTransformer.from_pretrained(\"gelu-2l\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 34, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"application/vnd.jupyter.widget-view+json": { | ||
"model_id": "72e610439495413395ee5026cbc3169e", | ||
"version_major": 2, | ||
"version_minor": 0 | ||
}, | ||
"text/plain": [ | ||
"Resolving data files: 0%| | 0/23 [00:00<?, ?it/s]" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"# load some data\n", | ||
"from torch.utils.data import DataLoader\n", | ||
"from sae_training.activations_buffer import DataLoaderBuffer\n", | ||
"\n", | ||
"@dataclass\n", | ||
"class DataLoaderBufferConfig():\n", | ||
" d_in: int\n", | ||
" d_sae: int\n", | ||
" device: str\n", | ||
" dtype: torch.dtype\n", | ||
" batch_size: int\n", | ||
" n_workers: int\n", | ||
" context_size: int = 128\n", | ||
" store_batch_size: int = 16\n", | ||
" n_batches_in_buffer: int = 10\n", | ||
" hook_point: str = \"blocks.0.hook_mlp_out\"\n", | ||
" \n", | ||
"\n", | ||
"\n", | ||
"cfg = DataLoaderBufferConfig(\n", | ||
" d_in=weights[\"W_enc\"].shape[0],\n", | ||
" d_sae=weights[\"W_enc\"].shape[1],\n", | ||
" device=\"cuda\",\n", | ||
" dtype=torch.float32,\n", | ||
" batch_size=128,\n", | ||
" n_workers=0,\n", | ||
")\n", | ||
"\n", | ||
"activations_buffer = DataLoaderBuffer(\n", | ||
" cfg, model, data_path=\"NeelNanda/c4-tokenized-2b\",\n", | ||
")\n", | ||
"\n", | ||
"def get_new_dataloader(data_loader_buffer, batch_size):\n", | ||
" buffer = data_loader_buffer.get_buffer()\n", | ||
" dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))\n", | ||
" n_remaining_batches_in_buffer = len(dataloader)\n", | ||
" return dataloader, n_remaining_batches_in_buffer\n", | ||
"\n", | ||
"dataloader, n_remaining_batches_in_buffer = get_new_dataloader(activations_buffer, 128)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 36, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tokens = next(dataloader)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 57, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor(61.6924, device='cuda:0')" | ||
] | ||
}, | ||
"execution_count": 57, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"activations = activations_buffer.get_buffer()\n", | ||
"sae_out, feature_acts, loss, mse_loss, l1_loss = SAE(activations)\n", | ||
"(feature_acts > 0).float().sum(0).mean()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 41, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 42, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor(11.5401, device='cuda:0', grad_fn=<MeanBackward1>)" | ||
] | ||
}, | ||
"execution_count": 42, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"mse_loss" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 51, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor(100.7457, device='cuda:0')" | ||
] | ||
}, | ||
"execution_count": 51, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"(feature_acts > 0).float().sum(1).mean()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 53, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor(62.9661, device='cuda:0')" | ||
] | ||
}, | ||
"execution_count": 53, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "mats_sae_training", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |