Commit set_up_lm_runner

jbloom-md committed Nov 30, 2023
1 parent b407aab commit d1095af
Showing 10 changed files with 449 additions and 75 deletions.
180 changes: 178 additions & 2 deletions dev.ipynb

Large diffs are not rendered by default.

48 changes: 34 additions & 14 deletions sae_training/SAE.py
@@ -1,5 +1,4 @@

#%%
"""Most of this is just copied over from Arthur's code and slightly simplified:
https://github.com/ArthurConmy/sae/blob/main/sae/model.py
"""
@@ -13,7 +12,6 @@
from transformer_lens.hook_points import HookedRootModule, HookPoint


#%%
# TODO make sure that W_dec stays unit norm during training
class SAE(HookedRootModule):
def __init__(
@@ -28,6 +26,7 @@ def __init__(
f"d_in must be an int but was {self.d_in=}; {type(self.d_in)=}"
)
self.d_sae = cfg.d_sae
self.l1_coefficient = cfg.l1_coefficient
self.dtype = cfg.dtype
self.device = cfg.device

@@ -62,7 +61,7 @@ def __init__(

self.setup() # Required for `HookedRootModule`s

def forward(self, x, return_mode: Literal["sae_out", "hidden_post", "both"]="both"):
def forward(self, x):
sae_in = self.hook_sae_in(
x - self.b_dec
) # Remove encoder bias as per Anthropic
@@ -75,25 +74,23 @@ def forward(self, x, return_mode: Literal["sae_out", "hidden_post", "both"]="both"):
)
+ self.b_enc
)
hidden_post = self.hook_hidden_post(torch.nn.functional.relu(hidden_pre))
feature_acts = self.hook_hidden_post(torch.nn.functional.relu(hidden_pre))

sae_out = self.hook_sae_out(
einops.einsum(
hidden_post,
feature_acts,
self.W_dec,
"... d_sae, d_sae d_in -> ... d_in",
)
+ self.b_dec
)

mse_loss = ((sae_out - x)**2).mean()
l1_loss = torch.abs(feature_acts).sum()
loss = mse_loss + self.l1_coefficient * l1_loss

return sae_out, feature_acts, loss, mse_loss, l1_loss

if return_mode == "sae_out":
return sae_out
elif return_mode == "hidden_post":
return hidden_post
elif return_mode == "both":
return sae_out, hidden_post
else:
raise ValueError(f"Unexpected {return_mode=}")

@torch.no_grad()
def resample_neurons(
@@ -105,7 +102,7 @@ def resample_neurons(
'''
Resamples neurons that have been dead for `dead_neuron_window` steps, according to `frac_active`.
'''
sae_out = self.forward(x, return_mode="sae_out")
sae_out, _, _, _, _ = self.forward(x)
per_token_l2_loss = (sae_out - x).pow(2).sum(dim=-1).squeeze()

# Find the dead neurons in this instance. If all neurons are alive, continue
@@ -138,3 +135,26 @@ def resample_neurons(
# Lastly, set the new weights & biases
self.W_enc.data[:, dead_neurons] = replacement_values.T.squeeze(1)
self.b_enc.data[dead_neurons] = 0.0

@torch.no_grad()
def set_decoder_norm_to_unit_norm(self):
self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

@torch.no_grad()
def remove_gradient_parallel_to_decoder_directions(self):
'''
Update grads so that the component parallel to each decoder direction is removed.
W_dec (and its grad) have shape (d_sae, d_in).
'''

parallel_component = einops.einsum(
self.W_dec.grad,
self.W_dec.data,
"d_sae d_in, d_sae d_in -> d_sae",
)

self.W_dec.grad -= einops.einsum(
parallel_component,
self.W_dec.data,
"d_sae, d_sae d_in -> d_sae d_in",
)
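Below is a minimal sketch (not part of the commit) of how the new SAE pieces are meant to combine in one training step: normalise the decoder, run the forward pass that now returns the loss terms directly, then project the gradient before stepping so the unit-norm constraint survives the update to first order. The optimizer and the activation batch are assumed.

import torch

from sae_training.SAE import SAE


def sae_train_step(sae: SAE, optimizer: torch.optim.Optimizer, activations: torch.Tensor):
    # Keep decoder rows on the unit sphere before the forward pass.
    sae.set_decoder_norm_to_unit_norm()

    optimizer.zero_grad()
    # forward() now returns every loss component instead of taking a return_mode flag.
    sae_out, feature_acts, loss, mse_loss, l1_loss = sae(activations)
    loss.backward()

    # Remove the gradient component parallel to each decoder row so the
    # optimizer step does not change the decoder norms (to first order).
    sae.remove_gradient_parallel_to_decoder_directions()
    optimizer.step()
    return loss.item(), mse_loss.item(), l1_loss.item()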
11 changes: 8 additions & 3 deletions sae_training/lm_datasets.py
@@ -1,5 +1,6 @@
from datasets import load_dataset


# To do: preprocess_tokenized_dataset, preprocess_text_dataset, preprocess other dataset
def preprocess_tokenized_dataset(source_batch: dict, context_size: int) -> dict:
tokenized_prompts = source_batch["tokens"]
@@ -20,14 +21,16 @@ def preprocess_tokenized_dataset(source_batch: dict, context_size: int) -> dict:

def get_mapped_dataset(cfg):
# Load the dataset
context_size = cfg["context_size"]
dataset_path = cfg["dataset_path"]
context_size = cfg.context_size
dataset_path = cfg.dataset_path
dataset_split = "train"
buffer_size: int = 1000
preprocess_batch_size: int = 1000

dataset = load_dataset(dataset_path, streaming=True, split=dataset_split) # type: ignore

ids = dataset.to_iterable_dataset() # try out shards here
# ids = ids.filter(filter_fn).map(process_fn)

# Setup preprocessing
existing_columns = list(next(iter(dataset)).keys())
mapped_dataset = dataset.map(
@@ -44,3 +47,5 @@ def get_mapped_dataset(cfg):
dataset = mapped_dataset.shuffle(buffer_size=buffer_size)
return dataset



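For reference, a sketch of the streaming pipeline that get_mapped_dataset builds, with illustrative values standing in for the config (the dataset path and sizes below are just the defaults used elsewhere in this commit, not requirements):

from datasets import load_dataset

from sae_training.lm_datasets import preprocess_tokenized_dataset

dataset = load_dataset("NeelNanda/c4-tokenized-2b", streaming=True, split="train")
existing_columns = list(next(iter(dataset)).keys())
dataset = dataset.map(
    preprocess_tokenized_dataset,  # chops pre-tokenized rows into fixed-length contexts
    batched=True,
    batch_size=1000,
    fn_kwargs={"context_size": 128},
    remove_columns=existing_columns,
)
dataset = dataset.shuffle(buffer_size=1000)
first_example = next(iter(dataset))  # one preprocessed example pulled from the stream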
61 changes: 45 additions & 16 deletions sae_training/lm_runner.py
@@ -1,16 +1,19 @@
from dataclasses import dataclass

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer

from sae_training.activation_store import ActivationStore
from sae_training.lm_datasets import get_mapped_dataset
from sae_training.lm_datasets import preprocess_tokenized_dataset

# from sae_training.activation_store import ActivationStore
from sae_training.SAE import SAE
from sae_training.train_sae import train_sae
from sae_training.train_sae_on_language_model import train_sae_on_language_model


@dataclass
class SAERunnerConfig:
class LanguageModelSAERunnerConfig:

# Data Generating Function (Model + Training Distribution)
model_name: str = "gelu-2l"
@@ -19,6 +22,7 @@ class SAERunnerConfig:
dataset_path: str = "NeelNanda/c4-tokenized-2b"

# SAE Parameters
d_in: int = 768
expansion_factor: int = 4

# Training Parameters
@@ -27,7 +31,14 @@ class SAERunnerConfig:
train_batch_size: int = 4096
context_size: int = 128

# Resampling protocol args
feature_sampling_window: int = 100
feature_reinit_scale: float = 0.2
dead_feature_threshold: float = 1e-8


# Activation Store Parameters
shuffle_buffer_size: int = 10_000
# max_store_size: int = 384 * 4096 * 2
# max_activations: int = 2_000_000_000
# resample_frequency: int = 122_880_000
@@ -36,32 +47,50 @@

# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training"
wandb_project: str = "mats_sae_training_language_model"
wandb_entity: str = None
wandb_log_frequency: int = 50

# Misc
device: str = "cpu"
seed: int = 42
checkpoint_path: str = "checkpoints"
dtype: torch.dtype = torch.float32

def __post_init__(self):
self.d_sae = self.d_in * self.expansion_factor

def sae_runner(cfg):
def language_model_sae_runner(cfg):


model = HookedTransformer.from_pretrained("gelu-2l") # any other cfg we should pass in here?
# get the model
model = HookedTransformer.from_pretrained(cfg.model_name) # any other cfg we should pass in here?

# initialize dataset
dataset = get_mapped_dataset(cfg)
activation_store = ActivationStore(cfg, dataset)

dataset = load_dataset(cfg.dataset_path, streaming=True, split="train")
existing_columns = list(next(iter(dataset)).keys())
mapped_dataset = dataset.map(
preprocess_tokenized_dataset, # preprocess is what differentiates different datasets
batched=True,
batch_size=cfg.train_batch_size,
fn_kwargs={"context_size": cfg.context_size},
remove_columns=existing_columns,
)
dataset = mapped_dataset.shuffle(buffer_size=cfg.shuffle_buffer_size)
dataloader = DataLoader(dataset, batch_size=cfg.train_batch_size)

# initialize the SAE
sparse_autoencoder = SAE(cfg)

# train SAE
sparse_autoencoder = train_sae(
model,
activation_store,
sparse_autoencoder,
cfg)
sparse_autoencoder = train_sae_on_language_model(
model, sparse_autoencoder, dataloader,
batch_size = cfg.train_batch_size,
feature_sampling_window = cfg.feature_sampling_window,
feature_reinit_scale = cfg.feature_reinit_scale,
dead_feature_threshold = cfg.dead_feature_threshold,
use_wandb = cfg.log_to_wandb,
wandb_log_frequency = cfg.wandb_log_frequency
)

return trained_sae
return sparse_autoencoder
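A hypothetical invocation of the new runner, using only the dataclass defaults shown above (wandb logging switched off for the example):

import torch

from sae_training.lm_runner import LanguageModelSAERunnerConfig, language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(
    model_name="gelu-2l",
    dataset_path="NeelNanda/c4-tokenized-2b",
    d_in=768,
    expansion_factor=4,  # __post_init__ sets d_sae = 4 * 768
    train_batch_size=4096,
    context_size=128,
    log_to_wandb=False,
    device="cpu",
    dtype=torch.float32,
)
sparse_autoencoder = language_model_sae_runner(cfg)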
5 changes: 2 additions & 3 deletions sae_training/toy_model_runner.py
@@ -8,7 +8,7 @@
from sae_training.SAE import SAE
from sae_training.toy_models import Config as ToyConfig
from sae_training.toy_models import Model as ToyModel
from sae_training.train_sae import train_sae
from sae_training.train_sae_on_toy_model import train_toy_sae


@dataclass
@@ -84,12 +84,11 @@ def toy_model_sae_runner(cfg):
if cfg.log_to_wandb:
wandb.init(project="sae-training-test", config=cfg)

sae = train_sae(
sae = train_toy_sae(
model, # need model so we can do evals for neuron resampling
sae,
hidden.detach().squeeze(),
use_wandb=cfg.log_to_wandb,
l1_coeff=cfg.l1_coefficient,
batch_size=cfg.train_batch_size,
n_epochs=cfg.train_epochs,
feature_sampling_window=cfg.feature_sampling_window,
118 changes: 118 additions & 0 deletions sae_training/train_sae_on_language_model.py
@@ -0,0 +1,118 @@
from functools import partial

import einops
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer

import wandb
from sae_training.SAE import SAE


def train_sae_on_language_model(
model: HookedTransformer,
sae: SAE,
dataloader: DataLoader,
batch_size: int = 1024,
feature_sampling_window: int = 100, # how many training steps between resampling the features / considering neurons dead
feature_reinit_scale: float = 0.2, # how much to scale the resampled features by
dead_feature_threshold: float = 1e-8, # how infrequently a feature has to be active to be considered dead
use_wandb: bool = False,
wandb_log_frequency: int = 50,):

optimizer = torch.optim.Adam(sae.parameters())
frac_active_list = [] # track active features

sae.train()
n_training_steps = 0
pbar = tqdm(dataloader)
for step, batch in enumerate(pbar):

# Make sure the W_dec is still unit-norm
sae.set_decoder_norm_to_unit_norm()

# Resample dead neurons
if (feature_sampling_window is not None) and ((step + 1) % feature_sampling_window == 0):

# Get the fraction of neurons active in the previous window
frac_active_in_window = torch.stack(frac_active_list[-feature_sampling_window:], dim=0)

# run the model with cache on the inputs and get the hidden activations
# _, cache = model.run_with_cache(batch)
# hidden = cache[hook_point]

# if standard resampling <- do this
# Resample
sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)

# elif anthropic resampling <- do this
# sae.resample_neurons(hidden, frac_active_in_window, feature_reinit_scale)

# Update learning rate here if using scheduler.

# Generate activations by hooking the model's forward pass
activations = list()
def hook_store_activation(activation, hook, activations):
activations.append(activation)
return activation

hook_func = partial(hook_store_activation, activations=activations)
_ = model.run_with_hooks(
batch, fwd_hooks=[(hook_point, hook_func)]
)

# Forward and Backward Passes
optimizer.zero_grad()
_, feature_acts, loss, mse_loss, l1_loss = sae(activations.pop())
# loss = reconstruction MSE + L1 regularization

with torch.no_grad():

# Calculate the sparsities, and add it to a list
frac_active = einops.reduce(
(feature_acts.abs() > dead_feature_threshold).float(),
"batch_size hidden_ae -> hidden_ae", "mean")
frac_active_list.append(frac_active)

batch_size = batch.shape[0]
log_frac_feature_activation = torch.log(frac_active + 1e-8)
n_dead_features = (frac_active < dead_feature_threshold).sum()

l0 = (feature_acts > 0).float().mean()
l2_norm = torch.norm(feature_acts, dim=1).mean()


if use_wandb and ((step + 1) % wandb_log_frequency == 0):
wandb.log({
"losses/mse_loss": mse_loss.item(),
"losses/l1_loss": batch_size*l1_loss.item(),
"losses/overall_loss": loss.item(),
"metrics/l0": l0.item(),
"metrics/l2": l2_norm.item(),
# "metrics/feature_density_histogram": wandb.Histogram(log_frac_feature_activation.tolist()),
"metrics/n_dead_features": n_dead_features,
"metrics/n_alive_features": sae.d_sae - n_dead_features,
}, step=n_training_steps)

pbar.set_description(f"{step}| MSE Loss {mse_loss.item():.3f} | L0 {l0.item():.3f} | n_dead_features {n_dead_features}")

loss.backward()

# Taken from Arthur's code https://github.com/ArthurConmy/sae/blob/3f8c314d9c008ec40de57828762ec5c9159e4092/sae/utils.py#L91
# TODO do we actually need this?
# Update grads so that they remove the parallel component
# (d_sae, d_in) shape
sae.remove_gradient_parallel_to_decoder_directions()
optimizer.step()

n_training_steps += 1

return sae
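Because the activation-capture step above is still a work in progress, here is a self-contained sketch of the hook pattern it is reaching for; the hook point name and the prompt are placeholder assumptions, not choices made by this commit:

from functools import partial

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gelu-2l")
hook_point = "blocks.0.hook_mlp_out"  # assumed hook point, not fixed by this commit


def store_activation(activation, hook, store):
    # TransformerLens calls hooks with (activation, hook); stash a detached copy.
    store.append(activation.detach())
    return activation


store = []
tokens = model.to_tokens("Hello world")
_ = model.run_with_hooks(tokens, fwd_hooks=[(hook_point, partial(store_activation, store=store))])
activations = store.pop()  # shape: (batch, seq_len, d_model)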