Skip to content

Commit

Permalink
get decoder fine tuning working
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus authored and jbloom-md committed Apr 18, 2024
1 parent 1666a68 commit 11a71e1
Show file tree
Hide file tree
Showing 11 changed files with 488 additions and 634 deletions.
2 changes: 1 addition & 1 deletion sae_lens/training/activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def from_config(
context_size=cfg.context_size,
d_in=cfg.d_in,
n_batches_in_buffer=cfg.n_batches_in_buffer,
total_training_tokens=cfg.total_training_tokens,
total_training_tokens=cfg.training_tokens,
store_batch_size=cfg.store_batch_size,
train_batch_size=cfg.train_batch_size,
prepend_bos=cfg.prepend_bos,
Expand Down
4 changes: 2 additions & 2 deletions sae_lens/training/cache_activations_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
else:
os.makedirs(activations_store.cached_activations_path)

print(f"Started caching {cfg.total_training_tokens} activations")
print(f"Started caching {cfg.training_tokens} activations")
tokens_per_buffer = (
cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer
)
n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer)
n_buffers = math.ceil(cfg.training_tokens / tokens_per_buffer)
# for i in tqdm(range(n_buffers), desc="Caching activations"):
for i in range(n_buffers):
buffer = activations_store.get_buffer(cfg.n_batches_in_buffer)
Expand Down
30 changes: 24 additions & 6 deletions sae_lens/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class LanguageModelSAERunnerConfig:

# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
training_tokens: int = 2_000_000
finetuning_tokens: int = 0
store_batch_size: int = 32
train_batch_size: int = 4096

Expand All @@ -56,11 +57,20 @@ class LanguageModelSAERunnerConfig:
prepend_bos: bool = True

# Training Parameters

## Batch size
train_batch_size: int = 4096

## Adam
adam_beta1: float | list[float] = 0
adam_beta2: float | list[float] = 0.999

## Loss Function
mse_loss_normalization: Optional[str] = None
l1_coefficient: float | list[float] = 1e-3
lp_norm: float | list[float] = 1

## Learning Rate Schedule
lr: float | list[float] = 3e-4
lr_scheduler_name: str | list[str] = (
"constant" # constant, cosineannealing, cosineannealingwarmrestarts
Expand All @@ -71,7 +81,9 @@ class LanguageModelSAERunnerConfig:
)
lr_decay_steps: int | list[int] = 0
n_restart_cycles: int | list[int] = 1 # used only for cosineannealingwarmrestarts
train_batch_size: int = 4096

## FineTuning
finetuning_method: Optional[str] = None # scale, decoder or unrotated_decoder

# Resampling protocol args
use_ghost_grads: bool | list[bool] = (
Expand Down Expand Up @@ -111,7 +123,7 @@ def __post_init__(self):
)

if self.run_name is None:
self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
raise ValueError(
Expand All @@ -129,6 +141,12 @@ def __post_init__(self):
elif isinstance(self.dtype, str):
self.dtype: torch.dtype = DTYPE_MAP[self.dtype]

# if we use decoder fine tuning, we can't be applying b_dec to the input
if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
raise ValueError(
"If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
)

self.device: str | torch.device = torch.device(self.device)

if self.lr_end is None:
Expand All @@ -144,7 +162,7 @@ def __post_init__(self):

if self.verbose:
print(
f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
)
# Print out some useful info:
n_tokens_per_buffer = (
Expand All @@ -156,7 +174,7 @@ def __post_init__(self):
f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}"
)

total_training_steps = self.total_training_tokens // self.train_batch_size
total_training_steps = self.training_tokens // self.train_batch_size
print(f"Total training steps: {total_training_steps}")

total_wandb_updates = total_training_steps // self.wandb_log_frequency
Expand Down Expand Up @@ -209,7 +227,7 @@ class CacheActivationsRunnerConfig:

# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
training_tokens: int = 2_000_000
store_batch_size: int = 32
train_batch_size: int = 4096

Expand Down
4 changes: 4 additions & 0 deletions sae_lens/training/sae_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary"
if not hasattr(cfg, "model_kwargs"):
cfg.model_kwargs = {}
sparse_autoencoder = SparseAutoencoder(cfg=cfg)
# add dummy scaling factor to the state dict
group["state_dict"]["scaling_factor"] = torch.ones(
cfg.d_sae, dtype=cfg.dtype, device=cfg.device
)
sparse_autoencoder.load_state_dict(group["state_dict"])
group = cls(cfg)
for key in group.autoencoders:
Expand Down
16 changes: 15 additions & 1 deletion sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ def __init__(
torch.zeros(self.d_in, dtype=self.dtype, device=self.device)
)

# scaling factor for fine-tuning (not to be used in initial training)
self.scaling_factor = nn.Parameter(
torch.ones(self.d_sae, dtype=self.dtype, device=self.device)
)

self.hook_sae_in = HookPoint()
self.hook_hidden_pre = HookPoint()
self.hook_hidden_post = HookPoint()
Expand All @@ -124,7 +129,8 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)

sae_out = self.hook_sae_out(
einops.einsum(
feature_acts,
feature_acts
* self.scaling_factor, # need to make sure this handled when loading old models.
self.W_dec,
"... d_sae, d_sae d_in -> ... d_in",
)
Expand Down Expand Up @@ -330,6 +336,14 @@ def load_from_pretrained(cls, path: str, device: str = "cpu"):
with safe_open(weight_path, framework="pt", device=device) as f: # type: ignore
for k in f.keys():
tensors[k] = f.get_tensor(k)

# old saves may not have scaling factors.
if "scaling_factor" not in tensors:
assert isinstance(config.d_sae, int)
tensors["scaling_factor"] = torch.ones(
config.d_sae, dtype=config.dtype, device=config.device
)

sae.load_state_dict(tensors)

return sae
Expand Down
44 changes: 43 additions & 1 deletion sae_lens/training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
from sae_lens.training.sae_group import SparseAutoencoderDictionary
from sae_lens.training.sparse_autoencoder import SparseAutoencoder

# used to map between parameters which are updated during finetuning and the config str.
FINETUNING_PARAMETERS = {
"scale": ["scaling_factor"],
"decoder": ["scaling_factor", "W_dec", "b_dec"],
"unrotated_decoder": ["scaling_factor", "b_dec"],
}


def _log_feature_sparsity(
feature_sparsity: torch.Tensor, eps: float = 1e-10
Expand All @@ -35,6 +42,7 @@ class SAETrainContext:
n_frac_active_tokens: int
optimizer: Optimizer
scheduler: LRScheduler
finetuning: bool = False

@property
def feature_sparsity(self) -> torch.Tensor:
Expand All @@ -44,6 +52,21 @@ def feature_sparsity(self) -> torch.Tensor:
def log_feature_sparsity(self) -> torch.Tensor:
return _log_feature_sparsity(self.feature_sparsity)

def begin_finetuning(self, sae: SparseAutoencoder):

# finetuning method should be set in the config
# if not, then we don't finetune
if not isinstance(sae.cfg.finetuning_method, str):
return

for name, param in sae.named_parameters():
if name in FINETUNING_PARAMETERS[sae.cfg.finetuning_method]:
param.requires_grad = True
else:
param.requires_grad = False

self.finetuning = True


@dataclass
class TrainSAEGroupOutput:
Expand Down Expand Up @@ -88,10 +111,13 @@ def train_sae_group_on_language_model(
use_wandb: bool = False,
wandb_log_frequency: int = 50,
) -> TrainSAEGroupOutput:
total_training_tokens = sae_group.cfg.total_training_tokens
total_training_tokens = (
sae_group.cfg.training_tokens + sae_group.cfg.finetuning_tokens
)
total_training_steps = total_training_tokens // batch_size
n_training_steps = 0
n_training_tokens = 0
started_fine_tuning = False

checkpoint_thresholds = []
if n_checkpoints > 0:
Expand Down Expand Up @@ -180,6 +206,16 @@ def train_sae_group_on_language_model(
)
pbar.update(batch_size)

### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
if (not started_fine_tuning) and (
n_training_tokens > sae_group.cfg.training_tokens
):
started_fine_tuning = True
for name, sparse_autoencoder in sae_group.autoencoders.items():
ctx = train_contexts[name]
# this should turn grads on for the scaling factor and other parameters.
ctx.begin_finetuning(sae_group.autoencoders[name])

# save final sae group to checkpoints folder
final_checkpoint = _save_checkpoint(
sae_group,
Expand Down Expand Up @@ -248,6 +284,12 @@ def _build_train_context(
)
n_frac_active_tokens = 0

# we don't train the scaling factor (initially)
# set requires grad to false for the scaling factor
for name, param in sae.named_parameters():
if "scaling_factor" in name:
param.requires_grad = False

optimizer = Adam(
sae.parameters(),
lr=sae.cfg.lr,
Expand Down
Loading

0 comments on commit 11a71e1

Please sign in to comment.