-
Notifications
You must be signed in to change notification settings - Fork 133
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add notebook to transfer W&B models to HF (#154)
hard to check this works quickly but assuming it does.
- Loading branch information
1 parent
f445fdf
commit 91239c1
Showing
1 changed file
with
306 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Upload SAEs from Weights & Biases to HuggingFace" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"This script uploads trained SAEs from Weights and Biases run artifacts to a HuggingFace model repo. Each run will get its own folder containing the weights, a config JSON and a tensor of sparsity values. The models are first downloaded to the local machine in the `scripts/artifacts` folder, then uploaded to HuggingFace, before the artifacts downloaded in this run are then deleted from `scripts/artifacts` (pre-existing artifacts in this folder will be left unchanged).\n", | ||
"\n", | ||
"To run this script uou'll need to:\n", | ||
"- Create a HuggingFace model repo to upload to (or use an existing one).\n", | ||
"- Get a HuggingFace [write access token](https://huggingface.co/docs/hub/en/security-tokens) for the repo.\n", | ||
"- Put these variables in the cell below, along with the W&B project name you want to transfer.\n", | ||
"\n", | ||
"Note: only finished runs will be transferred. You can change this or add extra filtering in the cell beginning with the comment \"more filtering on W&B runs can be done here\"." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Script variables" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"wandb_project_name = \"YOUR-WANDB-PROJECT\"\n", | ||
"hf_repo_id = \"YOUR-HF-REPO\"\n", | ||
"hf_token = \"YOUR-HF-TOKEN\" # do not upload to github!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## W&B downloads" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import wandb" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Get runs from W&B" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"api = wandb.Api()\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# more filtering on W&B runs can be done here\n", | ||
"runs = api.runs(wandb_project_name)\n", | ||
"completed_runs = [run for run in runs if run.state == \"finished\"]\n", | ||
"sorted_by_layer = sorted(completed_runs, key=lambda x: x.config[\"hook_point_layer\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### W&B helper functions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_model_endstr(run):\n", | ||
" d_sae = int(run.config[\"d_in\"]) * int(run.config[\"expansion_factor\"])\n", | ||
" hook_point = run.config[\"hook_point\"]\n", | ||
" return \"_\".join([hook_point, str(d_sae)])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def is_model(artifact, model_endstr):\n", | ||
" return artifact.name.split(\":\")[0].endswith(model_endstr)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_sparsity_endstr(run):\n", | ||
" model_endstr = get_model_endstr(run)\n", | ||
" return model_endstr + \"_log_feature_sparsity\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def is_sparsity(artifact, sparsity_endstr):\n", | ||
" return artifact.name.split(\":\")[0].endswith(sparsity_endstr)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_model_artifact(run):\n", | ||
" model_endstr = get_model_endstr(run)\n", | ||
" for a in run.logged_artifacts():\n", | ||
" if is_model(a, model_endstr):\n", | ||
" return a\n", | ||
" \n", | ||
" else:\n", | ||
" continue\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_sparsity_artifact(run):\n", | ||
" sparsity_endstr = get_sparsity_endstr(run)\n", | ||
" for a in run.logged_artifacts():\n", | ||
" if is_sparsity(a, sparsity_endstr):\n", | ||
" return a\n", | ||
" \n", | ||
" else:\n", | ||
" continue" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def download_model_and_sparsity(run):\n", | ||
" print(f\"Downloading model & sparsity for {run.config['hook_point']}\")\n", | ||
" model_artifact = get_model_artifact(run)\n", | ||
" model_path = model_artifact.download()\n", | ||
" sparsity_artifact = get_sparsity_artifact(run)\n", | ||
" sparsity_artifact.download(root=model_path)\n", | ||
" return model_path" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Download from W&B (quite slow)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model_paths = [download_model_and_sparsity(run) for run in sorted_by_layer]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## HF uploads" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from huggingface_hub import HfApi\n", | ||
"api = HfApi()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def upload_to_hf(model_path):\n", | ||
" repo_path = model_path.split('/')[-1]\n", | ||
" api.upload_folder(\n", | ||
" folder_path=model_path,\n", | ||
" repo_id=hf_repo_id,\n", | ||
" path_in_repo=repo_path,\n", | ||
" token=hf_token,\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for model_path in model_paths:\n", | ||
" print(f\"Uploading {model_path}\")\n", | ||
" upload_to_hf(model_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Cleanup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for model_path in model_paths:\n", | ||
" ## first remove directory contents\n", | ||
" for f in os.scandir(model_path):\n", | ||
" os.remove(f)\n", | ||
"\n", | ||
" ## next remove the (now empty) dir\n", | ||
" os.rmdir(model_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "sae-lens-QXoybXMy-py3.10", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |