diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py
index cfb9ed7c..6944159e 100644
--- a/sae_lens/training/config.py
+++ b/sae_lens/training/config.py
@@ -93,7 +93,7 @@ class LanguageModelSAERunnerConfig:
     lr_scheduler_name: str | list[str] = (
         "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
     )
-    lr_warm_up_steps: int | list[int] = 500
+    lr_warm_up_steps: int | list[int] = 0
     lr_end: float | list[float] | None = (
         None  # only used for cosine annealing, default is lr / 10
     )
diff --git a/sae_lens/training/optim.py b/sae_lens/training/optim.py
index 1329f163..ee6d7ade 100644
--- a/sae_lens/training/optim.py
+++ b/sae_lens/training/optim.py
@@ -109,7 +109,6 @@ def __init__(
         sparse_autoencoder: SparseAutoencoder,
     ):
 
-        self.type = type
         self.l1_warmup_steps = l1_warm_up_steps
         self.total_steps = total_steps
         self.sparse_autoencoder = sparse_autoencoder
diff --git a/tests/unit/training/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py
index 3477ae83..df927bee 100644
--- a/tests/unit/training/test_train_sae_on_language_model.py
+++ b/tests/unit/training/test_train_sae_on_language_model.py
@@ -15,7 +15,7 @@
 
 from sae_lens import __version__
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
+from sae_lens.training.optim import L1Scheduler
 from sae_lens.training.sae_group import SparseAutoencoderDictionary
 from sae_lens.training.sparse_autoencoder import (
     SAE_CFG_PATH,
@@ -31,6 +31,7 @@
     SAETrainContext,
     SAETrainingRunState,
     TrainStepOutput,
+    _build_train_context,
     _build_train_step_log_dict,
     _log_feature_sparsity,
     _save_checkpoint,
@@ -42,46 +43,26 @@
 from tests.unit.helpers import build_sae_cfg, load_model_cached
 
 
-# TODO: Address why we have this code here rather than importing it.
 def build_train_ctx(
     sae: SparseAutoencoder,
     act_freq_scores: Tensor | None = None,
     n_forward_passes_since_fired: Tensor | None = None,
     n_frac_active_tokens: int = 0,
 ) -> SAETrainContext:
-    """
-    Factory helper to build a default SAETrainContext object.
-    """
-    assert sae.cfg.d_sae is not None
-    assert not isinstance(sae.cfg.lr, list)
-    optimizer = torch.optim.Adam(sae.parameters(), lr=sae.cfg.lr)
-    return SAETrainContext(
-        act_freq_scores=(
-            torch.zeros(sae.cfg.d_sae) if act_freq_scores is None else act_freq_scores
-        ),
-        n_forward_passes_since_fired=(
-            torch.zeros(sae.cfg.d_sae)
-            if n_forward_passes_since_fired is None
-            else n_forward_passes_since_fired
-        ),
-        n_frac_active_tokens=n_frac_active_tokens,
-        optimizer=optimizer,
-        lr_scheduler=get_lr_scheduler(
-            "constant",
-            lr=sae.cfg.lr,
-            optimizer=optimizer,
-            training_steps=1000,
-            lr_end=0,
-            warm_up_steps=0,
-            decay_steps=0,
-            num_cycles=1,
-        ),
-        l1_scheduler=L1Scheduler(
-            l1_warm_up_steps=0,
-            total_steps=sae.cfg.training_tokens,
-            sparse_autoencoder=sae,
-        ),
-    )
+    """Builds a training context. We need to have this version so we can override some attributes."""
+    # Build train context
+    ctx = _build_train_context(sae, sae.cfg.training_tokens)
+    # Override attributes if required for testing
+    ctx.n_frac_active_tokens = n_frac_active_tokens
+    if n_forward_passes_since_fired is not None:
+        ctx.n_forward_passes_since_fired = n_forward_passes_since_fired
+    else:
+        ctx.n_forward_passes_since_fired = torch.zeros(sae.cfg.d_sae)  # type: ignore
+    if act_freq_scores is not None:
+        ctx.act_freq_scores = act_freq_scores
+    else:
+        ctx.act_freq_scores = torch.zeros(sae.cfg.d_sae)  # type: ignore
+    return ctx
 
 
 def modify_sae_output(
@@ -411,7 +392,7 @@ def assert_close(sd1: Any, sd2: Any):
     assert train_contexts_2.keys() == train_contexts.keys()
     for k in train_contexts.keys():
         ctx1 = train_contexts[k]
-        ctx2 = train_contexts[k]
+        ctx2 = train_contexts_2[k]
         sd1 = ctx1.state_dict()
         sd2 = ctx2.state_dict()
         assert_close(sd1, sd2)