Skip to content

Commit

Permalink
notebook_for_keith
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Dec 2, 2023
1 parent 2b43980 commit d06e09b
Showing 1 changed file with 379 additions and 0 deletions.
379 changes: 379 additions & 0 deletions hack_day_analysis.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,379 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hack Day Analysis\n",
"\n",
"- Try loading our models from last night\n",
"- Try out tiny stories models. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get our model from last night"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mjbloom\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.16.0"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>/home/paperspace/mats_sae_training/wandb/run-20231202_095753-nh2ujgd7</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/jbloom/mats_sae_training/runs/nh2ujgd7' target=\"_blank\">prime-snowflake-2</a></strong> to <a href='https://wandb.ai/jbloom/mats_sae_training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/jbloom/mats_sae_training' target=\"_blank\">https://wandb.ai/jbloom/mats_sae_training</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/jbloom/mats_sae_training/runs/nh2ujgd7' target=\"_blank\">https://wandb.ai/jbloom/mats_sae_training/runs/nh2ujgd7</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact sae:v1, 128.13MB. 1 files... \n",
"\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n",
"Done. 0:0:0.4\n"
]
}
],
"source": [
"import wandb\n",
"run = wandb.init()\n",
"artifact = run.use_artifact('jbloom/mats_sae_training_language_models/sae:v1', type='model')\n",
"artifact_dir = artifact.download()\n",
"wandb.finish()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([512, 32768])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights = torch.load(artifact_dir + '/sae.pt')\n",
"weights[\"W_enc\"].shape"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch \n",
"from dataclasses import dataclass\n",
"from sae_training.SAE import SAE\n",
"\n",
"@dataclass\n",
"class SAEConfig():\n",
" d_in = weights[\"W_enc\"].shape[0]\n",
" d_sae = weights[\"W_enc\"].shape[1]\n",
" device = \"cuda\"\n",
" dtype = torch.float32\n",
" l1_coefficient = 0.0001\n",
"\n",
"sae_cfg = SAEConfig()\n",
"SAE = SAE(sae_cfg)\n",
"SAE.load_state_dict(weights)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model gelu-2l into HookedTransformer\n"
]
}
],
"source": [
"from transformer_lens import HookedTransformer\n",
"\n",
"model = HookedTransformer.from_pretrained(\"gelu-2l\")"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "72e610439495413395ee5026cbc3169e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/23 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# load some data\n",
"from torch.utils.data import DataLoader\n",
"from sae_training.activations_buffer import DataLoaderBuffer\n",
"\n",
"@dataclass\n",
"class DataLoaderBufferConfig():\n",
" d_in: int\n",
" d_sae: int\n",
" device: str\n",
" dtype: torch.dtype\n",
" batch_size: int\n",
" n_workers: int\n",
" context_size: int = 128\n",
" store_batch_size: int = 16\n",
" n_batches_in_buffer: int = 10\n",
" hook_point: str = \"blocks.0.hook_mlp_out\"\n",
" \n",
"\n",
"\n",
"cfg = DataLoaderBufferConfig(\n",
" d_in=weights[\"W_enc\"].shape[0],\n",
" d_sae=weights[\"W_enc\"].shape[1],\n",
" device=\"cuda\",\n",
" dtype=torch.float32,\n",
" batch_size=128,\n",
" n_workers=0,\n",
")\n",
"\n",
"activations_buffer = DataLoaderBuffer(\n",
" cfg, model, data_path=\"NeelNanda/c4-tokenized-2b\",\n",
")\n",
"\n",
"def get_new_dataloader(data_loader_buffer, batch_size):\n",
" buffer = data_loader_buffer.get_buffer()\n",
" dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))\n",
" n_remaining_batches_in_buffer = len(dataloader)\n",
" return dataloader, n_remaining_batches_in_buffer\n",
"\n",
"dataloader, n_remaining_batches_in_buffer = get_new_dataloader(activations_buffer, 128)\n"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"tokens = next(dataloader)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(61.6924, device='cuda:0')"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"activations = activations_buffer.get_buffer()\n",
"sae_out, feature_acts, loss, mse_loss, l1_loss = SAE(activations)\n",
"(feature_acts > 0).float().sum(0).mean()"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(11.5401, device='cuda:0', grad_fn=<MeanBackward1>)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mse_loss"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(100.7457, device='cuda:0')"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(feature_acts > 0).float().sum(1).mean()"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(62.9661, device='cuda:0')"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mats_sae_training",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit d06e09b

Please sign in to comment.