fix: Fix issues with resumption testing (#144)
* fix always-true comparison in train context testing

* set default warmup steps to zero

* remove unused type attribute from L1Scheduler

* update training tests to use real context builder

* add docstring for build_train_ctx
tomMcGrath authored May 13, 2024
1 parent 6fadcfd commit 085d04f
Showing 3 changed files with 18 additions and 38 deletions.
2 changes: 1 addition & 1 deletion sae_lens/training/config.py
@@ -93,7 +93,7 @@ class LanguageModelSAERunnerConfig:
     lr_scheduler_name: str | list[str] = (
         "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
     )
-    lr_warm_up_steps: int | list[int] = 500
+    lr_warm_up_steps: int | list[int] = 0
     lr_end: float | list[float] | None = (
         None  # only used for cosine annealing, default is lr / 10
     )
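With the default now 0, learning-rate warm-up becomes opt-in. A minimal sketch of requesting it explicitly when building the runner config (field names are taken from the diff above; all other fields are assumed to keep workable defaults):

from sae_lens.training.config import LanguageModelSAERunnerConfig

# Opt back in to LR warm-up now that it no longer happens by default.
cfg = LanguageModelSAERunnerConfig(
    lr_scheduler_name="constant",
    lr_warm_up_steps=500,
)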
1 change: 0 additions & 1 deletion sae_lens/training/optim.py
@@ -109,7 +109,6 @@ def __init__(
         sparse_autoencoder: SparseAutoencoder,
     ):

-        self.type = type
         self.l1_warmup_steps = l1_warm_up_steps
         self.total_steps = total_steps
         self.sparse_autoencoder = sparse_autoencoder
53 changes: 17 additions & 36 deletions tests/unit/training/test_train_sae_on_language_model.py
@@ -15,7 +15,7 @@

 from sae_lens import __version__
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
+from sae_lens.training.optim import L1Scheduler
 from sae_lens.training.sae_group import SparseAutoencoderDictionary
 from sae_lens.training.sparse_autoencoder import (
     SAE_CFG_PATH,
@@ -31,6 +31,7 @@
     SAETrainContext,
     SAETrainingRunState,
     TrainStepOutput,
+    _build_train_context,
     _build_train_step_log_dict,
     _log_feature_sparsity,
     _save_checkpoint,
@@ -42,46 +43,26 @@
 from tests.unit.helpers import build_sae_cfg, load_model_cached


-# TODO: Address why we have this code here rather than importing it.
 def build_train_ctx(
     sae: SparseAutoencoder,
     act_freq_scores: Tensor | None = None,
     n_forward_passes_since_fired: Tensor | None = None,
     n_frac_active_tokens: int = 0,
 ) -> SAETrainContext:
-    """
-    Factory helper to build a default SAETrainContext object.
-    """
-    assert sae.cfg.d_sae is not None
-    assert not isinstance(sae.cfg.lr, list)
-    optimizer = torch.optim.Adam(sae.parameters(), lr=sae.cfg.lr)
-    return SAETrainContext(
-        act_freq_scores=(
-            torch.zeros(sae.cfg.d_sae) if act_freq_scores is None else act_freq_scores
-        ),
-        n_forward_passes_since_fired=(
-            torch.zeros(sae.cfg.d_sae)
-            if n_forward_passes_since_fired is None
-            else n_forward_passes_since_fired
-        ),
-        n_frac_active_tokens=n_frac_active_tokens,
-        optimizer=optimizer,
-        lr_scheduler=get_lr_scheduler(
-            "constant",
-            lr=sae.cfg.lr,
-            optimizer=optimizer,
-            training_steps=1000,
-            lr_end=0,
-            warm_up_steps=0,
-            decay_steps=0,
-            num_cycles=1,
-        ),
-        l1_scheduler=L1Scheduler(
-            l1_warm_up_steps=0,
-            total_steps=sae.cfg.training_tokens,
-            sparse_autoencoder=sae,
-        ),
-    )
+    """Builds a training context. We need to have this version so we can override some attributes."""
+    # Build train context
+    ctx = _build_train_context(sae, sae.cfg.training_tokens)
+    # Override attributes if required for testing
+    ctx.n_frac_active_tokens = n_frac_active_tokens
+    if n_forward_passes_since_fired is not None:
+        ctx.n_forward_passes_since_fired = n_forward_passes_since_fired
+    else:
+        ctx.n_forward_passes_since_fired = torch.zeros(sae.cfg.d_sae)  # type: ignore
+    if act_freq_scores is not None:
+        ctx.act_freq_scores = act_freq_scores
+    else:
+        ctx.act_freq_scores = torch.zeros(sae.cfg.d_sae)  # type: ignore
+    return ctx


 def modify_sae_output(
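For illustration, a hedged usage sketch of the rewritten helper as a test might call it; building the SAE via build_sae_cfg() and SparseAutoencoder(cfg) is an assumption based on the test helpers imported above, not part of this diff:

cfg = build_sae_cfg()  # assumed: returns a small default config suitable for unit tests
sae = SparseAutoencoder(cfg)  # assumed: an SAE can be constructed directly from the config
# Pretend every feature has gone 100 forward passes without firing and check
# that the override is carried through to the returned context.
ctx = build_train_ctx(
    sae,
    n_forward_passes_since_fired=torch.full((sae.cfg.d_sae,), 100.0),
)
assert ctx.n_frac_active_tokens == 0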
@@ -411,7 +392,7 @@ def assert_close(sd1: Any, sd2: Any):
     assert train_contexts_2.keys() == train_contexts.keys()
     for k in train_contexts.keys():
         ctx1 = train_contexts[k]
-        ctx2 = train_contexts[k]
+        ctx2 = train_contexts_2[k]
         sd1 = ctx1.state_dict()
         sd2 = ctx2.state_dict()
         assert_close(sd1, sd2)
