jbloom-md committed Nov 30, 2023 · 1 parent b407aab · commit d1095af
Showing 10 changed files with 449 additions and 75 deletions.
@@ -0,0 +1,118 @@
from functools import partial

import einops
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer

import wandb
from sae_training.SAE import SAE

def train_sae_on_language_model(
    model: HookedTransformer,
    sae: SAE,
    dataloader: DataLoader,
    hook_point: str,  # name of the activation hook the SAE is trained on
    batch_size: int = 1024,
    feature_sampling_window: int = 100,  # how many training steps between resampling the features / considering neurons dead
    feature_reinit_scale: float = 0.2,  # how much to scale the resampled features by
    dead_feature_threshold: float = 1e-8,  # how infrequently a feature has to be active to be considered dead
    use_wandb: bool = False,
    wandb_log_frequency: int = 50,
):

    optimizer = torch.optim.Adam(sae.parameters())
    frac_active_list = []  # track active features

    sae.train()
    n_training_steps = 0
    pbar = tqdm(dataloader)
    for step, batch in enumerate(pbar):

        # Make sure W_dec still has unit norm
        sae.set_decoder_norm_to_unit_norm()

        # Resample dead neurons
        if (feature_sampling_window is not None) and ((step + 1) % feature_sampling_window == 0):

            # Get the fraction of steps each feature was active over the previous window
            frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)

            # Run the model with a cache on the inputs and read out the hidden activations at hook_point
            _, cache = model.run_with_cache(batch)
            hidden = cache[hook_point]

            # if standard resampling <- do this
            # Resample
            sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)

            # elif anthropic resampling <- do this
            # sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)

        # Update learning rate here if using a scheduler.

        # Generate activations by hooking the model at hook_point
        activations = list()

        def hook_store_activation(activation, hook, activations):
            # Stash the activation at the hook point so the SAE can train on it
            activations.append(activation)
            return activation

        hook_func = partial(hook_store_activation, activations=activations)
        _ = model.run_with_hooks(
            batch,
            fwd_hooks=[(hook_point, hook_func)],
        )

        # Forward and Backward Passes
        optimizer.zero_grad()
        _, feature_acts, loss, mse_loss, l1_loss = sae(activations.pop())
        # loss = reconstruction MSE + L1 regularization

        with torch.no_grad():

            # Calculate the feature sparsities and add them to the running list
            frac_active = einops.reduce(
                (feature_acts.abs() > dead_feature_threshold).float(),
                "batch_size hidden_ae -> hidden_ae", "mean")
            frac_active_list.append(frac_active)

            batch_size = batch.shape[0]
            log_frac_feature_activation = torch.log(frac_active + 1e-8)
            n_dead_features = (frac_active < dead_feature_threshold).sum()

            l0 = (feature_acts > 0).float().mean()
            l2_norm = torch.norm(feature_acts, dim=1).mean()

            if use_wandb and ((step + 1) % wandb_log_frequency == 0):
                wandb.log({
                    "losses/mse_loss": mse_loss.item(),
                    "losses/l1_loss": batch_size * l1_loss.item(),
                    "losses/overall_loss": loss.item(),
                    "metrics/l0": l0.item(),
                    "metrics/l2": l2_norm.item(),
                    # "metrics/feature_density_histogram": wandb.Histogram(log_frac_feature_activation.tolist()),
                    "metrics/n_dead_features": n_dead_features,
                    "metrics/n_alive_features": sae.d_sae - n_dead_features,
                }, step=n_training_steps)

            pbar.set_description(
                f"{step}| MSE Loss {mse_loss.item():.3f} | L0 {l0.item():.3f} | n_dead_features {n_dead_features}")

        loss.backward()

        # Taken from Arthur's code https://github.com/ArthurConmy/sae/blob/3f8c314d9c008ec40de57828762ec5c9159e4092/sae/utils.py#L91
        # TODO: do we actually need this?
        # Update grads so that they remove the component parallel to the decoder directions
        # W_dec has shape (d_sae, d_in)
        sae.remove_gradient_parallel_to_decoder_directions()
        optimizer.step()

        n_training_steps += 1

    return sae
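
For reference, here is a minimal sketch of how this training loop might be invoked. It is not part of the commit: the model choice, the hook point name, the random-token dataloader, and the SAE constructor arguments below are assumptions for illustration; only train_sae_on_language_model and its parameters come from the diff above.

    # Usage sketch (assumed setup; the SAE constructor and the dataset are not shown in this diff).
    import torch
    from torch.utils.data import DataLoader
    from transformer_lens import HookedTransformer

    from sae_training.SAE import SAE

    model = HookedTransformer.from_pretrained("gpt2")  # any TransformerLens model
    hook_point = "blocks.0.hook_mlp_out"               # activation the SAE reconstructs (assumed choice)

    # Assumed: pre-tokenized prompts batched into (batch, seq_len) integer tensors.
    tokens = torch.randint(0, model.cfg.d_vocab, (4096, 128))
    dataloader = DataLoader(tokens, batch_size=32)

    # Assumed constructor signature; the real SAE class is defined elsewhere in the repo.
    sae = SAE(d_in=model.cfg.d_mlp, d_sae=4 * model.cfg.d_mlp)

    sae = train_sae_on_language_model(
        model=model,
        sae=sae,
        dataloader=dataloader,
        hook_point=hook_point,
        use_wandb=False,
    )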