Skip to content

Commit

Permalink
Make feature sparsity an argument
Browse files Browse the repository at this point in the history
  • Loading branch information
hijohnnylin committed Apr 15, 2024
1 parent dde2481 commit 8230570
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 77 deletions.
6 changes: 4 additions & 2 deletions sae_lens/analysis/neuronpedia_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"bg_color_map", ["white", "darkorange"]
)

SPARSITY_THRESHOLD = -5
DEFAULT_SPARSITY_THRESHOLD = -5

HTML_ANOMALIES = {
"âĢĶ": "—",
Expand Down Expand Up @@ -64,6 +64,7 @@ def __init__(
sae_path: str,
model_id: str,
sae_id: str,
sparsity_threshold: int = DEFAULT_SPARSITY_THRESHOLD,
neuronpedia_outputs_folder: str = "../../neuronpedia_outputs",
init_session: bool = True,
# token pars
Expand All @@ -89,6 +90,7 @@ def __init__(
self.model_id = model_id
self.layer = self.sparse_autoencoder.cfg.hook_point_layer
self.sae_id = sae_id
self.sparsity_threshold = sparsity_threshold
self.n_features_at_a_time = n_features_at_a_time
self.n_batches_to_sample_from = n_batches_to_sample_from
self.n_prompts_to_select = n_prompts_to_select
Expand Down Expand Up @@ -171,7 +173,7 @@ def run(self):
sparsity = load_sparsity(self.sae_path)
sparsity = sparsity.to(self.device)
self.target_feature_indexes = (
(sparsity > SPARSITY_THRESHOLD).nonzero(as_tuple=True)[0].tolist()
(sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist()
)

# divide into batches
Expand Down
62 changes: 4 additions & 58 deletions tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -44,64 +44,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/johnnylin/.cache/huggingface/hub/models--jbloom--GPT2-Small-SAEs-Reformatted/snapshots/5bd69d8ccac6b19d91934c5aeed4866f8b6e50c7/blocks.0.hook_resid_pre\n",
"Loaded pretrained model gpt2-small into HookedTransformer\n",
"Moving model to device: mps\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/johnnylin/Documents/Projects/SAELens/.venv/lib/python3.12/site-packages/datasets/load.py:1461: FutureWarning: The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext\n",
"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"==== Starting at batch: 1\n",
"==== Ending at batch: 1\n",
"Total features to run: 19321\n",
"Total skipped: 5255\n",
"Total batches: 806\n",
"Hook Point Layer: 0\n",
"Hook Point: blocks.0.hook_resid_pre\n",
"Writing files to: ../../neuronpedia_outputs/gpt2-small_res-jb_blocks.0.hook_resid_pre\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 84%|████████▍ | 3435/4096 [02:41<00:31, 21.29it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 19\u001b[0m\n\u001b[1;32m 4\u001b[0m NP_OUTPUT_FOLDER \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../../neuronpedia_outputs\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 5\u001b[0m runner \u001b[38;5;241m=\u001b[39m NeuronpediaRunner(\n\u001b[1;32m 6\u001b[0m sae_path\u001b[38;5;241m=\u001b[39mSAE_PATH,\n\u001b[1;32m 7\u001b[0m model_id\u001b[38;5;241m=\u001b[39mMODEL_ID,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m end_batch_inclusive\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 19\u001b[0m \u001b[43mrunner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/analysis/neuronpedia_runner.py:219\u001b[0m, in \u001b[0;36mNeuronpediaRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# get tokens:\u001b[39;00m\n\u001b[1;32m 218\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 219\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_tokens\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 220\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_batches_to_sample_from\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_prompts_to_select\u001b[49m\n\u001b[1;32m 221\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 222\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTime to get tokens: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mend\u001b[38;5;250m \u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;250m \u001b[39mstart\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/analysis/neuronpedia_runner.py:123\u001b[0m, in \u001b[0;36mNeuronpediaRunner.get_tokens\u001b[0;34m(self, n_batches_to_sample_from, n_prompts_to_select)\u001b[0m\n\u001b[1;32m 121\u001b[0m pbar \u001b[38;5;241m=\u001b[39m tqdm(\u001b[38;5;28mrange\u001b[39m(n_batches_to_sample_from))\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m pbar:\n\u001b[0;32m--> 123\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactivation_store\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_batch_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 124\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m batch_tokens[torch\u001b[38;5;241m.\u001b[39mrandperm(batch_tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m])][\n\u001b[1;32m 125\u001b[0m : batch_tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 126\u001b[0m ]\n\u001b[1;32m 127\u001b[0m all_tokens_list\u001b[38;5;241m.\u001b[39mappend(batch_tokens)\n",
"File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/training/activations_store.py:227\u001b[0m, in \u001b[0;36mActivationsStore.get_batch_tokens\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m current_length \u001b[38;5;241m==\u001b[39m context_size:\n\u001b[1;32m 226\u001b[0m full_batch \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(current_batch, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m--> 227\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfull_batch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\n\u001b[1;32m 229\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 230\u001b[0m current_batch \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 231\u001b[0m current_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"outputs": [],
"source": [
"from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n",
"\n",
Expand All @@ -111,6 +56,7 @@
" sae_path=SAE_PATH,\n",
" model_id=MODEL_ID,\n",
" sae_id=SAE_ID,\n",
" sparsity_threshold=-5,\n",
" neuronpedia_outputs_folder=NP_OUTPUT_FOLDER,\n",
" init_session=True,\n",
" n_batches_to_sample_from=2**12,\n",
Expand Down
12 changes: 7 additions & 5 deletions tutorials/neuronpedia/make_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
SAE_PATH = sys.argv[1]
MODEL_ID = sys.argv[2]
SAE_ID = sys.argv[3]
N_BATCHES_SAMPLE = int(sys.argv[4])
N_PROMPTS_SELECT = int(sys.argv[5])
FEATURES_AT_A_TIME = int(sys.argv[6])
START_BATCH_INCLUSIVE = int(sys.argv[7])
END_BATCH_INCLUSIVE = int(sys.argv[8])
SPARSITY_THRESHOLD = int(sys.argv[4])
N_BATCHES_SAMPLE = int(sys.argv[5])
N_PROMPTS_SELECT = int(sys.argv[6])
FEATURES_AT_A_TIME = int(sys.argv[7])
START_BATCH_INCLUSIVE = int(sys.argv[8])
END_BATCH_INCLUSIVE = int(sys.argv[9])

NP_OUTPUT_FOLDER = "../../neuronpedia_outputs"

runner = NeuronpediaRunner(
sae_path=SAE_PATH,
model_id=MODEL_ID,
sae_id=SAE_ID,
sparsity_threshold=SPARSITY_THRESHOLD,
neuronpedia_outputs_folder=NP_OUTPUT_FOLDER,
init_session=True,
n_batches_to_sample_from=N_BATCHES_SAMPLE,
Expand Down
29 changes: 17 additions & 12 deletions tutorials/neuronpedia/make_features.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,47 @@

echo "===== This will start a batch job that generates features to upload to Neuronpedia."
echo "===== This takes one SAE directory as input at a time."
echo "===== Features will be output into ./neuronpedia_outputs/{model}_{hook_point}_{d_sae}/batch-{batch_num}.json"
echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/"

echo ""
echo "(Step 1 of 8)"
echo "(Step 1 of 9)"
echo "What is the absolute, full local file path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)?"
read saepath
# TODO: support huggingface directories

echo ""
echo "(Step 2 of 8)"
echo "(Step 2 of 9)"
echo "What's the model ID? This must exactly match (including casing) the model ID you created on Neuronpedia."
read modelid

echo ""
echo "(Step 3 of 8)"
echo "(Step 3 of 9)"
echo "What's the SAE ID?"
echo "This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). It's in the format [abbrev hook name]-[abbrev author name], like res-jb."
read saeid

echo ""
echo "(Step 4 of 8)"
echo "(Step 4 of 9)"
echo "How many features are in this SAE?"
read numfeatures

echo ""
echo "(Step 5 of 8)"
echo "(Step 5 of 9)"
read -p "What's your feature sparsity threshold? (default: -5): " sparsity
[ -z "${sparsity}" ] && sparsity='-5'

echo ""
echo "(Step 6 of 9)"
read -p "How many features do you want to generate per batch file? More requires more RAM. (default: 128): " perbatch
[ -z "${perbatch}" ] && perbatch='128'

echo ""
echo "(Step 6 of 8)"
echo "(Step 7 of 9)"
read -p "Enter number of batches to sample from (default: 4096): " batches
[ -z "${batches}" ] && batches='4096'

echo ""
echo "(Step 7 of 8)"
echo "(Step 8 of 9)"
read -p "Enter number of prompts to select from (default: 24576): " prompts
[ -z "${prompts}" ] && prompts='24576'

Expand All @@ -49,23 +54,23 @@ numbatches=$(expr $numfeatures / $perbatch)
echo "===== INFO: We'll generate $numbatches batches of $perbatch features per batch = $numfeatures total features"

echo ""
echo "(Step 8 of 8)"
echo "(Step 9 of 9)"
read -p "Do you want to resume from a specific batch number? Enter 1 to start from the beginning (default: 1): " startbatch
[ -z "${startbatch}" ] && startbatch='1'

endbatch=$(expr $numbatches)


echo ""
echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/batch-{batch_num}.json"
echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/"
read -p "===== Hit ENTER to start!" start

for j in $(seq $startbatch $endbatch)
do
echo ""
echo "===== BATCH: $j"
echo "RUNNING: python make_batch.py $saepath $modelid $saeid $leftbuffer $rightbuffer $batches $prompts $perbatch $j $j"
python make_batch.py $saepath $modelid $saeid $leftbuffer $rightbuffer $batches $prompts $perbatch $j $j
echo "RUNNING: python make_batch.py $saepath $modelid $saeid $sparsity $batches $prompts $perbatch $j $j"
python make_batch.py $saepath $modelid $saeid $sparsity $batches $prompts $perbatch $j $j
done

echo ""
Expand Down

0 comments on commit 8230570

Please sign in to comment.