feat: Save estimated norm scaling factor during checkpointing (#395)
* refactor saving

* save estimated_norm_scaling_factor

* use new constant names elsewhere

* estimate norm scaling factor in `ActivationsStore` init

* fix tests

* add test

* tweaks

* safetensors path

* remove scaling factor on fold

* test scaling factor value

* format

* format

* undo silly change

* format

* save fn protocol

* make save fn static

* test which checkpoints have estimated norm scaling factor

* fix test

* fmt
oli-clive-griffin authored Dec 6, 2024
1 parent 53180e0 commit 63a15a0
Showing 9 changed files with 228 additions and 77 deletions.
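
For orientation, here is a minimal, self-contained sketch of the idea behind this change, using toy tensors rather than the SAE Lens API: activations are rescaled so their mean L2 norm is roughly sqrt(d_in) (matching the `np.sqrt(self.d_in) / mean_norm` estimate in `ActivationsStore`), that factor is what now gets persisted with checkpoints, and at the end of training it is folded into the weights so raw activations can be used directly.

```python
import torch

# Toy illustration (not the SAE Lens API): estimate a norm scaling factor,
# train on scaled activations, then fold the scalar into the encoder weights.
d_in = 64
acts = torch.randn(1024, d_in) * 3.7                # stand-in model activations
scale = (d_in ** 0.5) / acts.norm(dim=-1).mean()    # sqrt(d_in) / mean activation norm
scaled_acts = acts * scale                          # what the SAE is trained on

W_enc = torch.randn(d_in, 16)                       # stand-in encoder weights
W_enc_folded = W_enc * scale                        # fold the factor into the weights

# Feeding raw activations through the folded weights matches training-time behaviour.
assert torch.allclose(scaled_acts @ W_enc, acts @ W_enc_folded, atol=1e-4)
```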
39 changes: 28 additions & 11 deletions sae_lens/sae.py
@@ -7,6 +7,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Tuple, TypeVar, Union, overload

import einops
@@ -28,9 +29,9 @@
get_pretrained_saes_directory,
)

SPARSITY_PATH = "sparsity.safetensors"
SAE_WEIGHTS_PATH = "sae_weights.safetensors"
SAE_CFG_PATH = "cfg.json"
SPARSITY_FILENAME = "sparsity.safetensors"
SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
SAE_CFG_FILENAME = "cfg.json"

T = TypeVar("T", bound="SAE")

@@ -480,24 +481,40 @@ def fold_activation_norm_scaling_factor(
# once we normalize, we shouldn't need to scale activations.
self.cfg.normalize_activations = "none"

def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
if not os.path.exists(path):
os.mkdir(path)
@overload
def save_model(self, path: str | Path) -> Tuple[Path, Path]: ...

@overload
def save_model(
self, path: str | Path, sparsity: torch.Tensor
) -> Tuple[Path, Path, Path]: ...

def save_model(self, path: str | Path, sparsity: Optional[torch.Tensor] = None):
path = Path(path)

if not path.exists():
path.mkdir(parents=True)

# generate the weights
state_dict = self.state_dict()
self.process_state_dict_for_saving(state_dict)
save_file(state_dict, f"{path}/{SAE_WEIGHTS_PATH}")
model_weights_path = path / SAE_WEIGHTS_FILENAME
save_file(state_dict, model_weights_path)

# save the config
config = self.cfg.to_dict()

with open(f"{path}/{SAE_CFG_PATH}", "w") as f:
cfg_path = path / SAE_CFG_FILENAME
with open(cfg_path, "w") as f:
json.dump(config, f)

if sparsity is not None:
sparsity_in_dict = {"sparsity": sparsity}
save_file(sparsity_in_dict, f"{path}/{SPARSITY_PATH}") # type: ignore
sparsity_path = path / SPARSITY_FILENAME
save_file(sparsity_in_dict, sparsity_path)
return model_weights_path, cfg_path, sparsity_path

return model_weights_path, cfg_path

# overwrite this in subclasses to modify the state_dict in-place before saving
def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
@@ -512,15 +529,15 @@ def load_from_pretrained(
cls, path: str, device: str = "cpu", dtype: str | None = None
) -> "SAE":
# get the config
config_path = os.path.join(path, SAE_CFG_PATH)
config_path = os.path.join(path, SAE_CFG_FILENAME)
with open(config_path) as f:
cfg_dict = json.load(f)
cfg_dict = handle_config_defaulting(cfg_dict)
cfg_dict["device"] = device
if dtype is not None:
cfg_dict["dtype"] = dtype

weight_path = os.path.join(path, SAE_WEIGHTS_PATH)
weight_path = os.path.join(path, SAE_WEIGHTS_FILENAME)
cfg_dict, state_dict = read_sae_from_disk(
cfg_dict=cfg_dict,
weight_path=weight_path,
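The main API change in this file is that `save_model` now returns the paths it wrote, with the return arity depending on whether a sparsity tensor was passed. Below is a standalone sketch of that overload pattern; the `TinySaver` class and file names are illustrative, not part of the library.

```python
from __future__ import annotations

from pathlib import Path
from typing import Optional, Tuple, overload

import torch


class TinySaver:
    @overload
    def save(self, path: str | Path) -> Tuple[Path, Path]: ...

    @overload
    def save(self, path: str | Path, sparsity: torch.Tensor) -> Tuple[Path, Path, Path]: ...

    def save(self, path: str | Path, sparsity: Optional[torch.Tensor] = None):
        # Mirrors the structure above: create the directory, write weights and config,
        # and only write (and return) a sparsity file when a tensor is provided.
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        weights_path = path / "weights.bin"
        cfg_path = path / "cfg.json"
        weights_path.touch()
        cfg_path.write_text("{}")
        if sparsity is not None:
            sparsity_path = path / "sparsity.bin"
            sparsity_path.touch()
            return weights_path, cfg_path, sparsity_path
        return weights_path, cfg_path


# Callers can now unpack the written paths directly:
# weights_path, cfg_path = TinySaver().save("checkpoints/demo")
```

The training runner below reuses the returned paths when logging wandb artifacts instead of rebuilding them from string constants.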
56 changes: 26 additions & 30 deletions sae_lens/sae_training_runner.py
@@ -1,19 +1,18 @@
import json
import os
import signal
import sys
from typing import Any, Sequence, cast
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast

import torch
import wandb
from safetensors.torch import save_file
from simple_parsing import ArgumentParser
from transformer_lens.hook_points import HookedRootModule

from sae_lens import logger
from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
from sae_lens.load_model import load_model
from sae_lens.sae import SAE_CFG_PATH, SAE_WEIGHTS_PATH, SPARSITY_PATH
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.geometric_median import compute_geometric_median
from sae_lens.training.sae_trainer import SAETrainer
@@ -153,7 +152,7 @@ def run_trainer_with_interruption_handling(self, trainer: SAETrainer):

except (KeyboardInterrupt, InterruptedException):
logger.warning("interrupted, saving progress")
checkpoint_name = trainer.n_training_tokens
checkpoint_name = str(trainer.n_training_tokens)
self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
logger.info("done saving")
raise
@@ -180,60 +179,57 @@ def _init_sae_group_b_decs(
layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
self.sae.initialize_b_dec_with_mean(layer_acts) # type: ignore

@staticmethod
def save_checkpoint(
self,
trainer: SAETrainer,
checkpoint_name: int | str,
checkpoint_name: str,
wandb_aliases: list[str] | None = None,
) -> str:
checkpoint_path = f"{trainer.cfg.checkpoint_path}/{checkpoint_name}"
) -> None:
base_path = Path(trainer.cfg.checkpoint_path) / checkpoint_name
base_path.mkdir(exist_ok=True, parents=True)

os.makedirs(checkpoint_path, exist_ok=True)
trainer.activations_store.save(
str(base_path / "activations_store_state.safetensors")
)

path = f"{checkpoint_path}"
os.makedirs(path, exist_ok=True)
if trainer.sae.cfg.normalize_sae_decoder:
trainer.sae.set_decoder_norm_to_unit_norm()

if self.sae.cfg.normalize_sae_decoder:
self.sae.set_decoder_norm_to_unit_norm()
self.sae.save_model(path)
weights_path, cfg_path, sparsity_path = trainer.sae.save_model(
str(base_path),
trainer.log_feature_sparsity,
)

# let's overwrite the cfg file with the trainer cfg, which is a superset of the original cfg
# and should not cause issues, but gives us more info about SAEs we trained in SAE Lens.
config = trainer.cfg.to_dict()
with open(f"{path}/cfg.json", "w") as f:
with open(cfg_path, "w") as f:
json.dump(config, f)

log_feature_sparsities = {"sparsity": trainer.log_feature_sparsity}

log_feature_sparsity_path = f"{path}/{SPARSITY_PATH}"
save_file(log_feature_sparsities, log_feature_sparsity_path)

if trainer.cfg.log_to_wandb and os.path.exists(log_feature_sparsity_path):
if trainer.cfg.log_to_wandb:
# Avoid wandb saving errors such as:
# ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
sae_name = self.sae.get_name().replace("/", "__")
sae_name = trainer.sae.get_name().replace("/", "__")

# save model weights and cfg
model_artifact = wandb.Artifact(
sae_name,
type="model",
metadata=dict(trainer.cfg.__dict__),
)

model_artifact.add_file(f"{path}/{SAE_WEIGHTS_PATH}")
model_artifact.add_file(f"{path}/{SAE_CFG_PATH}")

model_artifact.add_file(str(weights_path))
model_artifact.add_file(str(cfg_path))
wandb.log_artifact(model_artifact, aliases=wandb_aliases)

# save log feature sparsity
sparsity_artifact = wandb.Artifact(
f"{sae_name}_log_feature_sparsity",
type="log_feature_sparsity",
metadata=dict(trainer.cfg.__dict__),
)
sparsity_artifact.add_file(log_feature_sparsity_path)
sparsity_artifact.add_file(str(sparsity_path))
wandb.log_artifact(sparsity_artifact)

return checkpoint_path


def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
if len(args) == 0:
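For reference, this is roughly what a checkpoint directory written by the new `save_checkpoint` should contain. The helper below only inspects a directory; the expected layout is inferred from the diff rather than verified against a real run.

```python
from __future__ import annotations

from pathlib import Path

# Expected contents of <checkpoint_path>/<checkpoint_name>/, as inferred from the diff above.
EXPECTED_CHECKPOINT_FILES = [
    "activations_store_state.safetensors",  # now includes estimated_norm_scaling_factor when set
    "sae_weights.safetensors",
    "cfg.json",                             # overwritten with the full trainer cfg
    "sparsity.safetensors",
]


def describe_checkpoint(base_path: str | Path) -> None:
    base_path = Path(base_path)
    for name in EXPECTED_CHECKPOINT_FILES:
        status = "found" if (base_path / name).exists() else "missing"
        print(f"{name}: {status}")
```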
25 changes: 23 additions & 2 deletions sae_lens/training/activations_store.py
@@ -4,7 +4,8 @@
import json
import os
import warnings
from typing import Any, Generator, Iterator, Literal, cast
from collections.abc import Generator, Iterator
from typing import Any, Literal, cast

import datasets
import numpy as np
@@ -234,7 +235,7 @@ def __init__(

self.n_dataset_processed = 0

self.estimated_norm_scaling_factor = 1.0
self.estimated_norm_scaling_factor = None

# Check if dataset is tokenized
dataset_sample = next(iter(self.dataset))
@@ -395,10 +396,22 @@ def load_cached_activation_dataset(self) -> Dataset | None:

return activations_dataset

def set_norm_scaling_factor_if_needed(self):
if self.normalize_activations == "expected_average_only_in":
self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()

def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
if self.estimated_norm_scaling_factor is None:
raise ValueError(
"estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
)
return activations * self.estimated_norm_scaling_factor

def unscale(self, activations: torch.Tensor) -> torch.Tensor:
if self.estimated_norm_scaling_factor is None:
raise ValueError(
"estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
)
return activations / self.estimated_norm_scaling_factor

def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
Expand All @@ -410,7 +423,10 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e
for _ in tqdm(
range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
):
# temporarily set estimated_norm_scaling_factor to 1.0 so the dataloader works
self.estimated_norm_scaling_factor = 1.0
acts = self.next_batch()
self.estimated_norm_scaling_factor = None
norms_per_batch.append(acts.norm(dim=-1).mean().item())
mean_norm = np.mean(norms_per_batch)
return np.sqrt(self.d_in) / mean_norm
@@ -701,9 +717,14 @@ def state_dict(self) -> dict[str, torch.Tensor]:
}
if self._storage_buffer is not None: # first time might be None
result["storage_buffer"] = self._storage_buffer
if self.estimated_norm_scaling_factor is not None:
result["estimated_norm_scaling_factor"] = torch.tensor(
self.estimated_norm_scaling_factor
)
return result

def save(self, file_path: str):
"""save the state dict to a file in safetensors format"""
save_file(self.state_dict(), file_path)


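The scaling factor lives on the store as a plain Python float (or `None`), so persisting it through safetensors means wrapping it in a 0-dim tensor and unwrapping it with `.item()` on load. A minimal round-trip sketch with an assumed file name:

```python
import torch
from safetensors.torch import load_file, save_file

# Round-trip a scalar scaling factor through safetensors (file name is illustrative).
state = {"estimated_norm_scaling_factor": torch.tensor(0.27)}
save_file(state, "activations_store_state.safetensors")

loaded = load_file("activations_store_state.safetensors")
factor = loaded["estimated_norm_scaling_factor"].item()
print(f"restored estimated_norm_scaling_factor = {factor:.2f}")
```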
39 changes: 21 additions & 18 deletions sae_lens/training/sae_trainer.py
@@ -1,6 +1,6 @@
import contextlib
from dataclasses import dataclass
from typing import Any, cast
from typing import Any, Optional, Protocol, cast

import torch
import wandb
@@ -43,6 +43,15 @@ class TrainSAEOutput:
log_feature_sparsities: torch.Tensor


class SaveCheckpointFn(Protocol):
def __call__(
self,
trainer: "SAETrainer",
checkpoint_name: str,
wandb_aliases: Optional[list[str]] = None,
) -> None: ...


class SAETrainer:
"""
Core SAE class used for inference. For training, see TrainingSAE.
@@ -53,12 +62,12 @@ def __init__(
model: HookedRootModule,
sae: TrainingSAE,
activation_store: ActivationsStore,
save_checkpoint_fn, # type: ignore
save_checkpoint_fn: SaveCheckpointFn,
cfg: LanguageModelSAERunnerConfig,
) -> None:
self.model = model
self.sae = sae
self.activation_store = activation_store
self.activations_store = activation_store
self.save_checkpoint = save_checkpoint_fn
self.cfg = cfg

@@ -165,12 +174,14 @@ def dead_neurons(self) -> torch.Tensor:
def fit(self) -> TrainingSAE:
pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")

self._estimate_norm_scaling_factor_if_needed()
self.activations_store.set_norm_scaling_factor_if_needed()

# Train loop
while self.n_training_tokens < self.cfg.total_training_tokens:
# Do a training step.
layer_acts = self.activation_store.next_batch()[:, 0, :].to(self.sae.device)
layer_acts = self.activations_store.next_batch()[:, 0, :].to(
self.sae.device
)
self.n_training_tokens += self.cfg.train_batch_size_tokens

step_output = self._train_step(sae=self.sae, sae_in=layer_acts)
@@ -187,10 +198,11 @@ def fit(self) -> TrainingSAE:
self._begin_finetuning_if_needed()

# fold the estimated norm scaling factor into the sae weights
if self.activation_store.estimated_norm_scaling_factor is not None:
if self.activations_store.estimated_norm_scaling_factor is not None:
self.sae.fold_activation_norm_scaling_factor(
self.activation_store.estimated_norm_scaling_factor
self.activations_store.estimated_norm_scaling_factor
)
self.activations_store.estimated_norm_scaling_factor = None

# save final sae group to checkpoints folder
self.save_checkpoint(
@@ -202,15 +214,6 @@ def fit(self) -> TrainingSAE:
pbar.close()
return self.sae

@torch.no_grad()
def _estimate_norm_scaling_factor_if_needed(self) -> None:
if self.cfg.normalize_activations == "expected_average_only_in":
self.activation_store.estimated_norm_scaling_factor = (
self.activation_store.estimate_norm_scaling_factor()
)
else:
self.activation_store.estimated_norm_scaling_factor = 1.0

def _train_step(
self,
sae: TrainingSAE,
@@ -331,7 +334,7 @@ def _run_and_log_evals(self):
self.sae.eval()
eval_metrics, _ = run_evals(
sae=self.sae,
activation_store=self.activation_store,
activation_store=self.activations_store,
model=self.model,
eval_config=self.trainer_eval_config,
model_kwargs=self.cfg.model_kwargs,
@@ -392,7 +395,7 @@ def _checkpoint_if_needed(self):
):
self.save_checkpoint(
trainer=self,
checkpoint_name=self.n_training_tokens,
checkpoint_name=str(self.n_training_tokens),
)
self.checkpoint_thresholds.pop(0)

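Because `save_checkpoint_fn` is now typed against the `SaveCheckpointFn` protocol, any callable with a matching signature can be passed to the trainer. A hypothetical stand-in (not part of the library) that only logs progress:

```python
from __future__ import annotations

from typing import Optional


def log_only_checkpoint(
    trainer: "SAETrainer",  # SAETrainer from sae_lens/training/sae_trainer.py
    checkpoint_name: str,
    wandb_aliases: Optional[list[str]] = None,
) -> None:
    # Stand-in checkpoint hook: record progress without writing anything to disk.
    print(f"checkpoint {checkpoint_name}: {trainer.n_training_tokens} tokens seen")


# Hypothetical wiring (argument names from the __init__ shown above):
# SAETrainer(model=model, sae=sae, activation_store=store,
#            save_checkpoint_fn=log_only_checkpoint, cfg=cfg)
```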
(Remaining changed files not shown.)
