Commit
add_analysis_files_for_post
jbloom-md committed Mar 10, 2024
1 parent 3949a46 commit e75323c
Showing 4 changed files with 1,594 additions and 0 deletions.
102 changes: 102 additions & 0 deletions sae_analysis/feature_statistics.py
@@ -0,0 +1,102 @@
import pandas as pd
import torch
from tqdm import tqdm
from transformer_lens import HookedTransformer

from sae_training.sparse_autoencoder import SparseAutoencoder


@torch.no_grad()
def get_feature_property_df(
    sparse_autoencoder: SparseAutoencoder, feature_sparsity: torch.Tensor
) -> pd.DataFrame:
    """
    Build a per-feature summary dataframe for a sparse autoencoder.

    Example:
        feature_property_df = get_feature_property_df(sparse_autoencoder, log_feature_density.cpu())
    """

    # Decoder directions are used as-is; the row-normalization is left
    # commented out.
    W_dec_normalized = (
        sparse_autoencoder.W_dec.cpu()
    )  # / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True)
    # W_enc has shape [d_in, d_sae]; each feature's encoder direction is a
    # column, so normalize over dim=0.
    W_enc_normalized = (
        sparse_autoencoder.W_enc.cpu()
        / sparse_autoencoder.W_enc.cpu().norm(dim=0, keepdim=True)
    )
    # Per-feature dot product between decoder row i and encoder column i,
    # giving a [d_sae] vector.
    d_e_projection = (W_dec_normalized * W_enc_normalized.T).sum(dim=-1)
    # Projection of the decoder bias onto each decoder direction: [d_sae].
    b_dec_projection = sparse_autoencoder.b_dec.cpu() @ W_dec_normalized.T

    temp_df = pd.DataFrame(
        {
            "log_feature_sparsity": feature_sparsity + 1e-10,
            "d_e_projection": d_e_projection,
            "b_enc": sparse_autoencoder.b_enc.detach().cpu(),
            "b_dec_projection": b_dec_projection,
            "feature": list(range(sparse_autoencoder.cfg.d_sae)),
            # Features with log10 sparsity below -9 essentially never fire.
            "dead_neuron": (feature_sparsity < -9).cpu(),
        }
    )

return temp_df
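
# A usage sketch (assumes an SAE and its log feature sparsity tensor have
# already been loaded, e.g. via sae_analysis.toolkit.get_all_gpt2_small_saes):
#
#   df = get_feature_property_df(sparse_autoencoder, log_feature_sparsity.cpu())
#   print(df.loc[~df["dead_neuron"], "d_e_projection"].describe())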


@torch.no_grad()
def get_stats_df(projection: torch.Tensor):
"""
Returns a dataframe with the mean, std, skewness and kurtosis of the projection
"""
mean = projection.mean(dim=1, keepdim=True)
diffs = projection - mean
var = (diffs**2).mean(dim=1, keepdim=True)
std = torch.pow(var, 0.5)
zscores = diffs / std
skews = torch.mean(torch.pow(zscores, 3.0), dim=1)
kurtosis = torch.mean(torch.pow(zscores, 4.0), dim=1)

stats_df = pd.DataFrame(
{
"feature": range(len(skews)),
"mean": mean.numpy().squeeze(),
"std": std.numpy().squeeze(),
"skewness": skews.numpy(),
"kurtosis": kurtosis.numpy(),
}
)

return stats_df
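
# Sanity-check sketch: the rows of a standard-normal matrix should come out
# with skewness near 0 and kurtosis near 3.
#
#   stats = get_stats_df(torch.randn(16, 10_000))
#   print(stats[["skewness", "kurtosis"]].round(2))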


@torch.no_grad()
def get_all_stats_dfs(
gpt2_small_sparse_autoencoders: dict[str, SparseAutoencoder], # [hook_point, sae]
gpt2_small_sae_sparsities: dict[str, torch.Tensor], # [hook_point, sae]
model: HookedTransformer,
cosine_sim: bool = False,
):
stats_dfs = []
pbar = tqdm(gpt2_small_sparse_autoencoders.keys())
for key in pbar:
layer = int(key.split(".")[1])
sparse_autoencoder = gpt2_small_sparse_autoencoders[key]
pbar.set_description(f"Processing layer {sparse_autoencoder.cfg.hook_point}")
W_U_stats_df_dec, _ = get_W_U_W_dec_stats_df(
sparse_autoencoder.W_dec.cpu(), model, cosine_sim
)
        log_feature_sparsity = gpt2_small_sae_sparsities[key].detach().cpu()
        W_U_stats_df_dec["log_feature_sparsity"] = log_feature_sparsity
        # resid_post at layer L is the same point in the residual stream as
        # resid_pre at layer L + 1, so shift its layer index by one.
        W_U_stats_df_dec["layer"] = layer + (1 if "post" in key else 0)
stats_dfs.append(W_U_stats_df_dec)

W_U_stats_df_dec_all_layers = pd.concat(stats_dfs, axis=0)
return W_U_stats_df_dec_all_layers


@torch.no_grad()
def get_W_U_W_dec_stats_df(
W_dec: torch.Tensor, model: HookedTransformer, cosine_sim: bool = False
) -> tuple[pd.DataFrame, torch.Tensor]:
    # W_U: [d_model, d_vocab]; W_dec: [d_sae, d_model].
    W_U = model.W_U.detach().cpu()
    if cosine_sim:
        # Normalize each vocab direction (column) of the unembedding.
        W_U = W_U / W_U.norm(dim=0, keepdim=True)
    # [d_sae, d_vocab] projection of each decoder direction onto the unembedding.
    dec_projection_onto_W_U = W_dec @ W_U
    W_U_stats_df = get_stats_df(dec_projection_onto_W_U)
return W_U_stats_df, dec_projection_onto_W_U
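

if __name__ == "__main__":
    # Minimal end-to-end sketch (an assumption about usage, not a documented
    # entry point): load GPT-2 small and every residual-stream SAE via
    # sae_analysis.toolkit, then summarize decoder-to-unembedding stats.
    from sae_analysis.toolkit import get_all_gpt2_small_saes

    model = HookedTransformer.from_pretrained("gpt2")
    saes, sparsities = get_all_gpt2_small_saes()
    all_stats = get_all_stats_dfs(saes, sparsities, model, cosine_sim=True)
    print(all_stats.groupby("layer")[["skewness", "kurtosis"]].mean())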
51 changes: 51 additions & 0 deletions sae_analysis/toolkit.py
@@ -0,0 +1,51 @@
import webbrowser

import torch
from huggingface_hub import hf_hub_download

from sae_training.sparse_autoencoder import SparseAutoencoder


def get_all_gpt2_small_saes() -> (
    tuple[dict[str, SparseAutoencoder], dict[str, torch.Tensor]]
):
    """
    Download the 13 GPT-2 small residual-stream SAEs (12 hook_resid_pre,
    plus layer 11 hook_resid_post) from the Hugging Face Hub, returning
    them keyed by hook point along with their log feature sparsities.
    """
    REPO_ID = "jbloom/GPT2-Small-SAEs"
gpt2_small_sparse_autoencoders = {}
gpt2_small_saes_log_feature_sparsities = {}
for layer in range(12):
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
sae = SparseAutoencoder.load_from_pretrained(f"{path}")
sae.cfg.use_ghost_grads = False
gpt2_small_sparse_autoencoders[sae.cfg.hook_point] = sae

FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576_log_feature_sparsity.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
log_feature_density = torch.load(path, map_location=sae.cfg.device)
gpt2_small_saes_log_feature_sparsities[sae.cfg.hook_point] = log_feature_density

    # Also load the layer-11 hook_resid_post SAE, which covers the final
    # residual stream position.
    layer = 11
FILENAME = (
f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_post_24576.pt"
)
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
sae = SparseAutoencoder.load_from_pretrained(f"{path}")
sae.cfg.use_ghost_grads = False
gpt2_small_sparse_autoencoders[sae.cfg.hook_point] = sae

FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_post_24576_log_feature_sparsity.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
log_feature_density = torch.load(path, map_location=sae.cfg.device)
gpt2_small_saes_log_feature_sparsities[sae.cfg.hook_point] = log_feature_density

return gpt2_small_sparse_autoencoders, gpt2_small_saes_log_feature_sparsities


def open_neuronpedia(feature_id: int, layer: int = 0):
    """Open the Neuronpedia page for a GPT-2 small residual-stream feature."""
    url = f"https://www.neuronpedia.org/gpt2-small/{layer}-res-jb/{feature_id}"
    print(f"Feature {feature_id}")
    webbrowser.open_new_tab(url)
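

if __name__ == "__main__":
    # Usage sketch (an assumption about how the toolkit is driven): load all
    # SAEs, then open one feature's Neuronpedia page in a browser.
    saes, log_sparsities = get_all_gpt2_small_saes()
    print(f"Loaded {len(saes)} SAEs across {len(log_sparsities)} hook points")
    open_neuronpedia(feature_id=0, layer=5)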
