feat: Support Gated-SAEs (#188)
* Initial draft of encoder

* Second draft of Gated SAE implementation

* Added SFN loss implementation

* Latest modification of SFN loss training setup

* fix missing config use

* don't have special sfn loss

* add hooks and reshape

* sae error term not working, WIP

* make tests pass

* add benchmark for gated

---------

Co-authored-by: Joseph Bloom <[email protected]>
curt-tigges and jbloomAus authored Jun 25, 2024
1 parent 623a1eb commit 232c39c
Showing 14 changed files with 1,669 additions and 54 deletions.
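
In brief, the gated architecture decouples detecting which features fire from estimating how strongly they fire: a gating path thresholds W_enc·x + b_gate to pick active features, while a magnitude path reuses the same W_enc, rescaled elementwise by exp(r_mag) and offset by b_mag, to compute their activations. The concrete implementation appears in the sae_lens/sae.py diff below.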
7 changes: 7 additions & 0 deletions sae_lens/config.py
@@ -124,6 +124,7 @@ class LanguageModelSAERunnerConfig:
)

# SAE Parameters
architecture: Literal["standard", "gated"] = "standard"
d_in: int = 512
d_sae: Optional[int] = None
b_dec_init_method: str = "geometric_median"
@@ -349,6 +350,8 @@ def total_training_steps(self) -> int:

def get_base_sae_cfg_dict(self) -> dict[str, Any]:
return {
# TEMP
"architecture": self.architecture,
"d_in": self.d_in,
"d_sae": self.d_sae,
"dtype": self.dtype,
@@ -474,6 +477,9 @@ def __post_init__(self):

@dataclass
class ToyModelSAERunnerConfig:

architecture: Literal["standard", "gated"] = "standard"

# ReLu Model Parameters
n_features: int = 5
n_hidden: int = 2
@@ -527,6 +533,7 @@ def __post_init__(self):
def get_base_sae_cfg_dict(self) -> dict[str, Any]:
# TO DO: Have the same hyperparameters as in the main sae runner.
return {
"architecture": self.architecture,
"d_in": self.d_in,
"d_sae": self.d_sae,
"dtype": self.dtype,
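
As a quick orientation, a minimal sketch of selecting the new field through the runner config; the import path and the illustrative d_in/d_sae values are assumptions, not part of this commit:

from sae_lens import LanguageModelSAERunnerConfig

# Sketch only: selects the new gated architecture; "standard" stays the default.
cfg = LanguageModelSAERunnerConfig(
    architecture="gated",
    d_in=512,     # illustrative
    d_sae=4096,   # illustrative
)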
107 changes: 100 additions & 7 deletions sae_lens/sae.py
@@ -5,7 +5,7 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Literal, Optional, Tuple

import einops
import torch
@@ -28,6 +28,8 @@

@dataclass
class SAEConfig:
# architecture details
architecture: Literal["standard", "gated"]

# forward pass details.
d_in: int
@@ -76,6 +78,7 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":

def to_dict(self) -> dict[str, Any]:
return {
"architecture": self.architecture,
"d_in": self.d_in,
"d_sae": self.d_sae,
"dtype": self.dtype,
@@ -121,7 +124,12 @@ def __init__(
self.device = torch.device(cfg.device)
self.use_error_term = use_error_term

self.initialize_weights_basic()
if self.cfg.architecture == "standard":
self.initialize_weights_basic()
self.encode_fn = self.encode
elif self.cfg.architecture == "gated":
self.initialize_weights_gated()
self.encode_fn = self.encode_gated

# handle presence / absence of scaling factor.
if self.cfg.finetuning_scaling_factor:
@@ -209,16 +217,50 @@ def initialize_weights_basic(self):
torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
)

def initialize_weights_gated(self):
# Initialize the weights and biases for the gated encoder
self.W_enc = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
)
)
)

self.b_gate = nn.Parameter(
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
)

self.r_mag = nn.Parameter(
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
)

self.b_mag = nn.Parameter(
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
)

self.W_dec = nn.Parameter(
torch.nn.init.kaiming_uniform_(
torch.empty(
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
)
)
)

self.b_dec = nn.Parameter(
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
)
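
# Note on the parameterization above: the gating and magnitude paths share a
# single encoder matrix W_enc. The gating path adds b_gate; the magnitude path
# rescales W_enc elementwise by exp(r_mag) and adds b_mag (see encode_gated
# below). Compared with the standard SAE's single b_enc, this adds roughly
# 2 * d_sae extra parameters.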

# Basic Forward Pass Functionality.
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:

feature_acts = self.encode(x)
feature_acts = self.encode_fn(x)
sae_out = self.decode(feature_acts)

if self.use_error_term:
# TEMP
if self.use_error_term and self.cfg.architecture != "gated":
with torch.no_grad():
# Recompute everything without hooks to get true error term
# Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
@@ -248,16 +290,62 @@ def forward(

sae_out = self.run_time_activation_norm_fn_out(sae_out)
sae_error = self.hook_sae_error(x - x_reconstruct_clean)
return self.hook_sae_output(sae_out + sae_error)

# TODO: Add tests
elif self.use_error_term and self.cfg.architecture == "gated":
with torch.no_grad():
x = x.to(self.dtype)
sae_in = self.reshape_fn_in(x) # type: ignore
gating_pre_activation = sae_in @ self.W_enc + self.b_gate
active_features = (gating_pre_activation > 0).float()

# Magnitude path with weight sharing
magnitude_pre_activation = self.hook_sae_acts_pre(
sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
)
feature_magnitudes = self.hook_sae_acts_post(
self.activation_fn(magnitude_pre_activation)
)
feature_acts_clean = active_features * feature_magnitudes
x_reconstruct_clean = self.reshape_fn_out(
self.apply_finetuning_scaling_factor(feature_acts_clean)
@ self.W_dec
+ self.b_dec,
d_head=self.d_head,
)

sae_error = self.hook_sae_error(x - x_reconstruct_clean)
return self.hook_sae_output(sae_out + sae_error)

return self.hook_sae_output(sae_out)

def encode_gated(
self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
x = x.to(self.dtype)
x = self.reshape_fn_in(x)
sae_in = self.hook_sae_input(x - self.b_dec * self.cfg.apply_b_dec_to_input)

# Gating path
gating_pre_activation = sae_in @ self.W_enc + self.b_gate
active_features = (gating_pre_activation > 0).float()

# Magnitude path with weight sharing
magnitude_pre_activation = self.hook_sae_acts_pre(
sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
)
feature_magnitudes = self.hook_sae_acts_post(
self.activation_fn(magnitude_pre_activation)
)

return active_features * feature_magnitudes

def encode(
self, x: Float[torch.Tensor, "... d_in"]
) -> Float[torch.Tensor, "... d_sae"]:
"""
Calcuate SAE features from inputs
Calculate SAE features from inputs
"""

# move x to correct dtype
@@ -301,7 +389,12 @@ def fold_W_dec_norm(self):
W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
self.W_dec.data = self.W_dec.data / W_dec_norms
self.W_enc.data = self.W_enc.data * W_dec_norms.T
self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
if self.cfg.architecture == "gated":
self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
else:
self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

@torch.no_grad()
def fold_activation_norm_scaling_factor(
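
Pulling the pieces of encode_gated together, a standalone sketch of the gated encoder computation (plain tensors, no hook points or reshaping; the ReLU stands in for the library's default activation_fn and is an assumption here):

import torch

def gated_encode(x, W_enc, b_gate, r_mag, b_mag, b_dec, apply_b_dec_to_input=True):
    # Optionally centre the input on the decoder bias, as encode_gated does.
    sae_in = x - b_dec * apply_b_dec_to_input

    # Gating path: a hard 0/1 mask over features.
    pi_gate = sae_in @ W_enc + b_gate
    active = (pi_gate > 0).float()

    # Magnitude path: shares W_enc, rescaled elementwise by exp(r_mag).
    magnitudes = torch.relu(sae_in @ (W_enc * r_mag.exp()) + b_mag)

    return active * magnitudes  # feature_acts; decoding is feature_acts @ W_dec + b_dec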
5 changes: 5 additions & 0 deletions sae_lens/toolkit/pretrained_sae_loaders.py
@@ -124,6 +124,7 @@ def connor_rob_hook_z_loader(
# }

cfg_dict = {
"architecture": "standard",
"d_in": old_cfg_dict["act_size"],
"d_sae": old_cfg_dict["dict_size"],
"dtype": "float32",
@@ -169,6 +170,10 @@ def load_pretrained_sae_lens_sae_components(
with open(cfg_path, "r") as f:
cfg_dict = json.load(f)

cfg_dict["architecture"] = (
"standard" # TODO: modify this when we add support for loading more architectures
)

# filter config for varnames
cfg_dict["device"] = device
cfg_dict["dtype"] = dtype
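
The loader change above exists because SAEConfig now requires an architecture field, while checkpoints saved before this commit have no such key in their cfg.json, so the loader stamps them as standard. A rough sketch of the effect (the variable names follow the loader; the checkpoint contents are illustrative):

cfg_dict = json.load(open(cfg_path))        # older cfg.json: no "architecture" key
cfg_dict["architecture"] = "standard"       # injected by the loader for now
sae_cfg = SAEConfig.from_dict(cfg_dict)     # the new required field is present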
25 changes: 12 additions & 13 deletions sae_lens/training/sae_trainer.py
@@ -287,6 +287,7 @@ def _build_train_step_log_dict(
"losses/mse_loss": mse_loss,
"losses/l1_loss": l1_loss
/ self.current_l1_coefficient, # normalize by l1 coefficient
"losses/auxiliary_reconstruction_loss": output.auxiliary_reconstruction_loss,
"losses/ghost_grad_loss": ghost_grad_loss,
"losses/overall_loss": loss,
# variance explained
@@ -317,19 +318,17 @@ def _run_and_log_evals(self):
model_kwargs=self.cfg.model_kwargs,
)

W_dec_norm_dist = self.sae.W_dec.norm(dim=1).detach().float().cpu().numpy()
b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()

# More detail on loss.

# add weight histograms
eval_metrics = {
**eval_metrics,
**{
"weights/W_dec_norms": wandb.Histogram(W_dec_norm_dist),
"weights/b_e": wandb.Histogram(b_e_dist),
},
}
W_dec_norm_dist = self.sae.W_dec.norm(dim=1).detach().cpu().numpy()
eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist) # type: ignore

if self.sae.cfg.architecture == "standard":
b_e_dist = self.sae.b_enc.detach().cpu().numpy()
eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist) # type: ignore
elif self.sae.cfg.architecture == "gated":
b_gate_dist = self.sae.b_gate.detach().cpu().numpy()
eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist) # type: ignore
b_mag_dist = self.sae.b_mag.detach().cpu().numpy()
eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist) # type: ignore

wandb.log(
eval_metrics,
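
The trainer additionally logs losses/auxiliary_reconstruction_loss. The loss computation itself lives in the TrainingSAE changes that are not rendered on this page; as a hedged sketch, the Gated-SAE formulation computes this auxiliary term by reconstructing the input from the ReLU of the gating pre-activations through a frozen copy of the decoder, roughly:

# Sketch of the Gated-SAE auxiliary loss as usually formulated; this may not
# match the exact TrainingSAE implementation in this commit.
pi_gate = x @ W_enc + b_gate
x_hat_frozen = torch.relu(pi_gate) @ W_dec.detach() + b_dec.detach()
aux_loss = (x_hat_frozen - x).pow(2).sum(dim=-1).mean()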
(The remaining changed files are not rendered on this page.)
