further-lm-improvments
jbloomAus committed Dec 1, 2023
1 parent eba5f79 commit 63048eb
Showing 3 changed files with 53 additions and 22 deletions.
6 changes: 4 additions & 2 deletions sae_training/SAE.py
@@ -96,7 +96,7 @@ def forward(self, x):
def resample_neurons(
self,
x: Float[Tensor, "batch_size n_hidden"],
frac_active_in_window: Float[Tensor, "window n_hidden_ae"],
feature_sparsity: Float[Tensor, "n_hidden_ae"],
neuron_resample_scale: float,
) -> None:
'''
@@ -106,7 +106,7 @@ def resample_neurons(
per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1).squeeze()

# Find the dead neurons in this instance. If all neurons are alive, continue
is_dead = (frac_active_in_window.sum(0) < 1e-8)
is_dead = (feature_sparsity < 1e-8)
dead_neurons = torch.nonzero(is_dead).squeeze(-1)
alive_neurons = torch.nonzero(~is_dead).squeeze(-1)
n_dead = dead_neurons.numel()
@@ -135,6 +135,8 @@ def resample_neurons(
# Lastly, set the new weights & biases
self.W_enc.data[:, dead_neurons] = replacement_values.T.squeeze(1)
self.b_enc.data[dead_neurons] = 0.0

return len(dead_neurons)

@torch.no_grad()
def set_decoder_norm_to_unit_norm(self):
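
The new resample_neurons contract: it takes a per-feature firing rate (feature_sparsity) rather than the raw window of activation counts, and it now returns the number of dead neurons it revived. A minimal sketch with toy values (illustrative only, not part of this commit; sae, batch and cfg are assumed from the surrounding training code):

import torch

feature_sparsity = torch.tensor([0.0, 0.03, 0.4])  # toy per-feature firing rates
is_dead = feature_sparsity < 1e-8                   # same threshold used inside resample_neurons
# -> tensor([ True, False, False]): only feature 0 would be resampled

# n_resampled = sae.resample_neurons(batch, feature_sparsity, cfg.feature_reinit_scale)
# the returned count is what the training loop logs as "metrics/n_resampled_neurons"
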
19 changes: 18 additions & 1 deletion sae_training/lm_runner.py
@@ -1,3 +1,4 @@
import os
from dataclasses import dataclass

import torch
@@ -67,7 +68,6 @@ def language_model_sae_runner(cfg):
model = HookedTransformer.from_pretrained(cfg.model_name) # any other cfg we should pass in here?

# initialize dataset
dataset = load_dataset(cfg.dataset_path, streaming=True, split="train")
activations_buffer = DataLoaderBuffer(
cfg, model, data_path=cfg.dataset_path
)
@@ -89,6 +89,23 @@ def language_model_sae_runner(cfg):
wandb_log_frequency = cfg.wandb_log_frequency
)



# save sae to checkpoints folder
unique_id = wandb.util.generate_id()
#make sure directory exists

os.makedirs(f"{cfg.checkpoint_path}/{unique_id}", exist_ok=True)
torch.save(sparse_autoencoder.state_dict(), f"{cfg.checkpoint_path}/{unique_id}/sae.pt")
# upload to wandb
if cfg.log_to_wandb:
model_artifact = wandb.Artifact(
"sae", type="model", metadata=dict(cfg.__dict__)
)
model_artifact.add_file(f"{cfg.checkpoint_path}/{unique_id}/sae.pt")
wandb.log_artifact(model_artifact)


if cfg.log_to_wandb:
wandb.finish()

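
The runner now writes the trained SAE to a unique checkpoint directory and, when wandb logging is enabled, uploads it as a model artifact. A sketch of how those weights could be recovered later (illustrative, not from this diff; the SparseAutoencoder constructor name and the wandb entity/project path are assumptions):

import torch
import wandb

# Option 1: reload the local checkpoint written above
sae = SparseAutoencoder(cfg)  # hypothetical constructor taking the same cfg
sae.load_state_dict(torch.load(f"{cfg.checkpoint_path}/{unique_id}/sae.pt"))

# Option 2: download the "sae" artifact logged to wandb and load it
api = wandb.Api()
artifact = api.artifact("my-entity/my-project/sae:latest")  # illustrative path
local_dir = artifact.download()
sae.load_state_dict(torch.load(f"{local_dir}/sae.pt"))
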
50 changes: 31 additions & 19 deletions sae_training/train_sae_on_language_model.py
@@ -41,22 +41,25 @@ def train_sae_on_language_model(
# Make sure the W_dec is still zero-norm
sae.set_decoder_norm_to_unit_norm()

# # Resample dead neurons
# if (feature_sampling_window is not None) and ((step + 1) % feature_sampling_window == 0):

# # Get the fraction of neurons active in the previous window
# frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)

# # run model with cach on inputs and get out hidden
# # _, cache = model(batch, return_cache=True)
# # hidden = cache[hook_point,0]

# # if standard resampling <- do this
# # Resample
# sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)

# # elif anthropic resampling <- do this
# # sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)
# Resample dead neurons
if (feature_sampling_window is not None) and ((n_training_steps + 1) % feature_sampling_window == 0):

# Get the fraction of neurons active in the previous window
frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)
feature_sparsity = frac_active_in_window.sum(0) / (
feature_sampling_window * batch_size
)
# if standard resampling <- do this
n_resampled_neurons = sae.resample_neurons(next(dataloader), feature_sparsity, feature_reinit_scale)
n_remaining_batches_in_buffer -= 1

# elif anthropic resampling <- do this
# run the model and reinit where recons loss is high.
if n_remaining_batches_in_buffer == 0:
dataloader, n_remaining_batches_in_buffer = get_new_dataloader(
data_loader_buffer, n_remaining_batches_in_buffer, batch_size)
else:
n_resampled_neurons = 0

# # Update learning rate here if using scheduler.

@@ -68,10 +71,10 @@
n_training_tokens += batch_size
n_remaining_batches_in_buffer -= 1

# Update the buffer if we've run out of batches
if n_remaining_batches_in_buffer == 0:
buffer = data_loader_buffer.get_buffer()
dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))
n_remaining_batches_in_buffer = len(dataloader)
dataloader, n_remaining_batches_in_buffer = get_new_dataloader(
data_loader_buffer, n_remaining_batches_in_buffer, batch_size)

with torch.no_grad():
# Calculate the sparsities, and add it to a list, calculate sparsity metrics
@@ -118,6 +121,7 @@
.float()
.mean()
.item(),
"metrics/n_resampled_neurons": n_resampled_neurons,
"details/n_training_tokens": n_training_tokens,
},
step=n_training_steps,
@@ -160,6 +164,14 @@
return sae


def get_new_dataloader(data_loader_buffer, n_remaining_batches_in_buffer, batch_size):
buffer = data_loader_buffer.get_buffer()
dataloader = iter(DataLoader(buffer, batch_size=batch_size, shuffle=True))
n_remaining_batches_in_buffer = len(dataloader)
return dataloader, n_remaining_batches_in_buffer



@torch.no_grad()
def get_recons_loss(sae, model, data_loader_buffer, num_batches=5):
hook_point = data_loader_buffer.cfg.hook_point
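
The resampling branch above assumes frac_active_list holds one per-feature activity count for each training step; that bookkeeping is outside this diff, so here is a sketch of what it plausibly looks like (names are illustrative assumptions, not from this commit):

import torch

frac_active_list = []  # one [n_hidden_ae] tensor appended per training step

def record_activity(feature_acts: torch.Tensor) -> None:
    # feature_acts: [batch_size, n_hidden_ae] SAE hidden activations for one batch
    did_fire = (feature_acts.abs() > 0).float().sum(dim=0)  # tokens that activated each feature
    frac_active_list.append(did_fire)

# Every feature_sampling_window steps, the trainer stacks the last window of counts and
# divides by (feature_sampling_window * batch_size) to get feature_sparsity, the quantity
# thresholded at 1e-8 inside resample_neurons.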
