Skip to content

Commit

Permalink
Add notebook to transfer W&B models to HF (#154)
Browse files Browse the repository at this point in the history
hard to check this works quickly but assuming it does.
  • Loading branch information
tomMcGrath authored May 20, 2024
1 parent f445fdf commit 91239c1
Showing 1 changed file with 306 additions and 0 deletions.
306 changes: 306 additions & 0 deletions scripts/wandb_to_hf.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Upload SAEs from Weights & Biases to HuggingFace"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This script uploads trained SAEs from Weights and Biases run artifacts to a HuggingFace model repo. Each run will get its own folder containing the weights, a config JSON and a tensor of sparsity values. The models are first downloaded to the local machine in the `scripts/artifacts` folder, then uploaded to HuggingFace, before the artifacts downloaded in this run are then deleted from `scripts/artifacts` (pre-existing artifacts in this folder will be left unchanged).\n",
"\n",
"To run this script uou'll need to:\n",
"- Create a HuggingFace model repo to upload to (or use an existing one).\n",
"- Get a HuggingFace [write access token](https://huggingface.co/docs/hub/en/security-tokens) for the repo.\n",
"- Put these variables in the cell below, along with the W&B project name you want to transfer.\n",
"\n",
"Note: only finished runs will be transferred. You can change this or add extra filtering in the cell beginning with the comment \"more filtering on W&B runs can be done here\"."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Script variables"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wandb_project_name = \"YOUR-WANDB-PROJECT\"\n",
"hf_repo_id = \"YOUR-HF-REPO\"\n",
"hf_token = \"YOUR-HF-TOKEN\" # do not upload to github!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## W&B downloads"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import wandb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Get runs from W&B"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"api = wandb.Api()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# more filtering on W&B runs can be done here\n",
"runs = api.runs(wandb_project_name)\n",
"completed_runs = [run for run in runs if run.state == \"finished\"]\n",
"sorted_by_layer = sorted(completed_runs, key=lambda x: x.config[\"hook_point_layer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### W&B helper functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_model_endstr(run):\n",
" d_sae = int(run.config[\"d_in\"]) * int(run.config[\"expansion_factor\"])\n",
" hook_point = run.config[\"hook_point\"]\n",
" return \"_\".join([hook_point, str(d_sae)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_model(artifact, model_endstr):\n",
" return artifact.name.split(\":\")[0].endswith(model_endstr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_sparsity_endstr(run):\n",
" model_endstr = get_model_endstr(run)\n",
" return model_endstr + \"_log_feature_sparsity\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_sparsity(artifact, sparsity_endstr):\n",
" return artifact.name.split(\":\")[0].endswith(sparsity_endstr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_model_artifact(run):\n",
" model_endstr = get_model_endstr(run)\n",
" for a in run.logged_artifacts():\n",
" if is_model(a, model_endstr):\n",
" return a\n",
" \n",
" else:\n",
" continue\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_sparsity_artifact(run):\n",
" sparsity_endstr = get_sparsity_endstr(run)\n",
" for a in run.logged_artifacts():\n",
" if is_sparsity(a, sparsity_endstr):\n",
" return a\n",
" \n",
" else:\n",
" continue"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def download_model_and_sparsity(run):\n",
" print(f\"Downloading model & sparsity for {run.config['hook_point']}\")\n",
" model_artifact = get_model_artifact(run)\n",
" model_path = model_artifact.download()\n",
" sparsity_artifact = get_sparsity_artifact(run)\n",
" sparsity_artifact.download(root=model_path)\n",
" return model_path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download from W&B (quite slow)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_paths = [download_model_and_sparsity(run) for run in sorted_by_layer]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## HF uploads"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import HfApi\n",
"api = HfApi()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def upload_to_hf(model_path):\n",
" repo_path = model_path.split('/')[-1]\n",
" api.upload_folder(\n",
" folder_path=model_path,\n",
" repo_id=hf_repo_id,\n",
" path_in_repo=repo_path,\n",
" token=hf_token,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for model_path in model_paths:\n",
" print(f\"Uploading {model_path}\")\n",
" upload_to_hf(model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleanup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for model_path in model_paths:\n",
" ## first remove directory contents\n",
" for f in os.scandir(model_path):\n",
" os.remove(f)\n",
"\n",
" ## next remove the (now empty) dir\n",
" os.rmdir(model_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "sae-lens-QXoybXMy-py3.10",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 91239c1

Please sign in to comment.