feat: Hooked toy model (#134)
* adds initial re-implementation of toy models

* removes instance dimension from toy models

* fixing up minor nits and adding more tests

---------

Co-authored-by: David Chanin <[email protected]>
evanhanders and chanind authored May 11, 2024
1 parent 1a3bedb commit 03aa25c
Showing 5 changed files with 411 additions and 233 deletions.
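The core structural change, removing the instance dimension, is easiest to see in the einsum pattern updated in `toy_model_runner.py` below: the runner was never set up to train more than one SAE (note the removed `n_instances=1` line further down). A minimal sketch of the before/after tensor shapes; the sizes are illustrative (taken from the config defaults), and only the two einsum strings come verbatim from the diff:

```python
import einops
import torch

# Illustrative sizes matching the config defaults; only the einsum
# strings below are taken verbatim from the diff.
n_features, n_hidden, batch_size = 5, 2, 1024

# Before: every tensor carried an (unused, size-1) instance dimension.
W_old = torch.randn(1, n_hidden, n_features)        # (instances, hidden, features)
batch_old = torch.randn(batch_size, 1, n_features)  # (batch, instances, features)
hidden_old = einops.einsum(
    batch_old,
    W_old,
    "batch_size instances features, instances hidden features -> batch_size instances hidden",
)  # shape (1024, 1, 2)

# After: the instance dimension is gone entirely.
W_new = torch.randn(n_hidden, n_features)           # (hidden, features)
batch_new = torch.randn(batch_size, n_features)     # (batch, features)
hidden_new = einops.einsum(
    batch_new,
    W_new,
    "batch_size features, hidden features -> batch_size hidden",
)  # shape (1024, 2)
```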
53 changes: 53 additions & 0 deletions sae_lens/training/config.py
@@ -346,6 +346,59 @@ def __post_init__(self):
 )


+@dataclass
+class ToyModelSAERunnerConfig:
+    # ReLU Model Parameters
+    n_features: int = 5
+    n_hidden: int = 2
+    n_correlated_pairs: int = 0
+    n_anticorrelated_pairs: int = 0
+    feature_probability: float = 0.025
+    model_training_steps: int = 10_000
+
+    # SAE Parameters
+    d_sae: int = 5
+
+    # Training Parameters
+    l1_coefficient: float = 1e-3
+    lr: float = 3e-4
+    train_batch_size: int = 1024
+    b_dec_init_method: str = "geometric_median"
+
+    # Sparsity / Dead Feature Handling
+    use_ghost_grads: bool = (
+        False  # not currently implemented, but the SAE class expects it.
+    )
+    feature_sampling_window: int = 100
+    dead_feature_window: int = 100  # unless this window is larger than feature_sampling_window
+    dead_feature_threshold: float = 1e-8
+
+    # Activation Store Parameters
+    total_training_tokens: int = 25_000
+
+    # WANDB
+    log_to_wandb: bool = True
+    wandb_project: str = "mats_sae_training_toy_model"
+    wandb_entity: str | None = None
+    wandb_log_frequency: int = 50
+
+    # Misc
+    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu"
+    seed: int = 42
+    checkpoint_path: str = "checkpoints"
+    dtype: str | torch.dtype = "float32"
+
+    def __post_init__(self):
+        self.d_in = self.n_hidden  # the ReLU model's hidden dim is the SAE's input dim
+
+        if isinstance(self.dtype, str) and self.dtype not in DTYPE_MAP:
+            raise ValueError(
+                f"dtype must be one of {list(DTYPE_MAP.keys())}. Got {self.dtype}"
+            )
+        elif isinstance(self.dtype, str):
+            self.dtype = DTYPE_MAP[self.dtype]
+
+
 def _default_cached_activations_path(
     dataset_path: str,
     model_name: str,
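A minimal usage sketch of the relocated config, assuming `DTYPE_MAP` maps dtype strings such as "float32" to the corresponding `torch.dtype` (as the `__post_init__` above implies). Note that `d_in` is derived from `n_hidden` rather than passed in:

```python
from sae_lens.training.config import ToyModelSAERunnerConfig

# All field values here are the dataclass defaults, spelled out for clarity.
cfg = ToyModelSAERunnerConfig(
    n_features=5,
    n_hidden=2,
    d_sae=5,
    dtype="float32",
    device="cpu",
)

print(cfg.d_in)   # 2: set from n_hidden in __post_init__
print(cfg.dtype)  # the mapped torch.dtype (presumably torch.float32)
```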
61 changes: 7 additions & 54 deletions sae_lens/training/toy_model_runner.py
@@ -1,80 +1,33 @@
-from dataclasses import dataclass
 from typing import Any, cast

 import einops
 import torch
 import wandb

+from sae_lens.training.config import ToyModelSAERunnerConfig
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
-from sae_lens.training.toy_models import Config as ToyConfig
-from sae_lens.training.toy_models import Model as ToyModel
+from sae_lens.training.toy_models import ReluOutputModel as ToyModel
+from sae_lens.training.toy_models import ToyConfig
 from sae_lens.training.train_sae_on_toy_model import train_toy_sae


-@dataclass
-class SAEToyModelRunnerConfig:
-    # ReLu Model Parameters
-    n_features: int = 5
-    n_hidden: int = 2
-    n_correlated_pairs: int = 0
-    n_anticorrelated_pairs: int = 0
-    feature_probability: float = 0.025
-    model_training_steps: int = 10_000
-
-    # SAE Parameters
-    d_sae: int = 5
-
-    # Training Parameters
-    l1_coefficient: float = 1e-3
-    lr: float = 3e-4
-    train_batch_size: int = 1024
-    b_dec_init_method: str = "geometric_median"
-
-    # Sparsity / Dead Feature Handling
-    use_ghost_grads: bool = (
-        False  # not currently implemented, but SAE class expects it.
-    )
-    feature_sampling_window: int = 100
-    dead_feature_window: int = 100  # unless this window is larger feature sampling,
-    dead_feature_threshold: float = 1e-8
-
-    # Activation Store Parameters
-    total_training_tokens: int = 25_000
-
-    # WANDB
-    log_to_wandb: bool = True
-    wandb_project: str = "mats_sae_training_toy_model"
-    wandb_entity: str | None = None
-    wandb_log_frequency: int = 50
-
-    # Misc
-    device: str = "cpu"
-    seed: int = 42
-    checkpoint_path: str = "checkpoints"
-    dtype: torch.dtype = torch.float32
-
-    def __post_init__(self):
-        self.d_in = self.n_hidden  # hidden for the ReLu model is the input for the SAE
-
-
-def toy_model_sae_runner(cfg: SAEToyModelRunnerConfig):
+def toy_model_sae_runner(cfg: ToyModelSAERunnerConfig):
"""
A runner for training an SAE on a toy model.
"""
# Toy Model Config
toy_model_cfg = ToyConfig(
-        n_instances=1,  # Not set up to train > 1 SAE so shouldn't do > 1 model.
n_features=cfg.n_features,
n_hidden=cfg.n_hidden,
n_correlated_pairs=cfg.n_correlated_pairs,
n_anticorrelated_pairs=cfg.n_anticorrelated_pairs,
-        feature_probability=cfg.feature_probability,
)

# Initialize Toy Model
     model = ToyModel(
         cfg=toy_model_cfg,
-        device=cfg.device,
+        feature_probability=cfg.feature_probability,
+        device=torch.device(cfg.device),
     )

# Train the Toy Model
@@ -85,7 +38,7 @@ def toy_model_sae_runner(cfg: SAEToyModelRunnerConfig):
     hidden = einops.einsum(
         batch,
         model.W,
-        "batch_size instances features, instances hidden features -> batch_size instances hidden",
+        "batch_size features, hidden features -> batch_size hidden",
     )

     sparse_autoencoder = SparseAutoencoder(
(diffs for the remaining 3 changed files did not load)
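Taken together, a hedged end-to-end sketch of invoking the runner with the relocated config. The return value of `toy_model_sae_runner` is not visible in the truncated diff, so treating it as the trained SAE is an assumption:

```python
from sae_lens.training.config import ToyModelSAERunnerConfig
from sae_lens.training.toy_model_runner import toy_model_sae_runner

# A small local smoke-test configuration with wandb logging disabled.
cfg = ToyModelSAERunnerConfig(
    n_features=5,
    n_hidden=2,
    d_sae=5,
    model_training_steps=1_000,
    total_training_tokens=10_000,
    log_to_wandb=False,
    device="cpu",
)

# Trains the ReLU toy model, computes hidden activations via the einsum above,
# then trains an SAE on them (assumed to return the trained SparseAutoencoder).
sae = toy_model_sae_runner(cfg)
```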
