Skip to content

Commit

Permalink
anthropic sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Dec 22, 2023
1 parent ca74543 commit 048d267
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 26 deletions.
2 changes: 2 additions & 0 deletions sae_training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def collect_anthropic_resampling_losses(self, model, activation_store):
# del normal_logits

normal_activations = normal_activations_cache[self.cfg.hook_point]
if self.cfg.hook_point_head_index is not None:
normal_activations = normal_activations[:,:,self.cfg.hook_point_head_index]

# calculate the difference in loss
changes_in_loss = ce_loss_with_recons - ce_loss_without_recons
Expand Down
76 changes: 50 additions & 26 deletions sae_training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,28 @@ def train_sae_on_language_model(

feature_sparsity = act_freq_scores / n_frac_active_tokens
dead_neuron_indices = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0]
sparse_autoencoder.resample_neurons_anthropic(
dead_neuron_indices,
model,
optimizer,
activation_store
)

if use_wandb:
wandb.log(
{
"metrics/n_resampled_neurons": len(dead_neuron_indices),
},
step=n_training_steps,
if len(dead_neuron_indices) > 0:
sparse_autoencoder.resample_neurons_anthropic(
dead_neuron_indices,
model,
optimizer,
activation_store
)

# for now, we'll hardcode this.
current_lr = scheduler.get_last_lr()[0]
reduced_lr = current_lr * 0.1
increment = (current_lr - reduced_lr) / 1000
optimizer.param_groups[0]['lr'] = reduced_lr
steps_before_reset = 1000

if use_wandb:
wandb.log(
{
"metrics/n_resampled_neurons": len(dead_neuron_indices),
},
step=n_training_steps,
)

# for now, we'll hardcode this.
current_lr = scheduler.get_last_lr()[0]
reduced_lr = current_lr * 0.1
increment = (current_lr - reduced_lr) / 1000
optimizer.param_groups[0]['lr'] = reduced_lr
steps_before_reset = 1000


# Resample dead neurons
Expand Down Expand Up @@ -284,9 +285,23 @@ def run_evals(sparse_autoencoder: SparseAutoencoder, activation_store: Activatio
},
step=n_training_steps,
)

head_index = sparse_autoencoder.cfg.hook_point_head_index

def standard_replacement_hook(activations, hook):
activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
return activations

def head_replacement_hook(activations, hook):
new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype)
activations[:,:,head_index] = new_actions
return activations

head_index = sparse_autoencoder.cfg.hook_point_head_index
replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook

# get attn when using reconstructed activations
with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook, encoder=sparse_autoencoder))]):
with model.hooks(fwd_hooks=[(hook_point, partial(replacement_hook))]):
_, new_cache = model.run_with_cache(eval_tokens, names_filter=[get_act_name("pattern", hook_point_layer)])
patterns_reconstructed = new_cache[get_act_name("pattern", hook_point_layer)][:,hook_point_head_index].detach().cpu()
del new_cache
Expand Down Expand Up @@ -366,14 +381,26 @@ def run_evals(sparse_autoencoder: SparseAutoencoder, activation_store: Activatio
)

@torch.no_grad()
def get_recons_loss(sparse_autoencder, model, activation_store, batch_tokens):
def get_recons_loss(sparse_autoencoder, model, activation_store, batch_tokens):
hook_point = activation_store.cfg.hook_point
loss = model(batch_tokens, return_type="loss")

head_index = sparse_autoencoder.cfg.hook_point_head_index

def standard_replacement_hook(activations, hook):
activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
return activations

def head_replacement_hook(activations, hook):
new_actions = sparse_autoencoder.forward(activations[:,:,head_index])[0].to(activations.dtype)
activations[:,:,head_index] = new_actions
return activations

replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook
recons_loss = model.run_with_hooks(
batch_tokens,
return_type="loss",
fwd_hooks=[(hook_point, partial(replacement_hook, encoder=sparse_autoencder))],
fwd_hooks=[(hook_point, partial(replacement_hook))],
)

zero_abl_loss = model.run_with_hooks(
Expand All @@ -385,9 +412,6 @@ def get_recons_loss(sparse_autoencder, model, activation_store, batch_tokens):
return score, loss, recons_loss, zero_abl_loss


def replacement_hook(mlp_post, hook, encoder):
activations = encoder(mlp_post)[0].to(mlp_post.dtype)
return activations


def mean_ablate_hook(mlp_post, hook):
Expand Down

0 comments on commit 048d267

Please sign in to comment.