diff --git a/sae_lens/config.py b/sae_lens/config.py index 2175df33..085ee36c 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -124,6 +124,7 @@ class LanguageModelSAERunnerConfig: ) # SAE Parameters + architecture: Literal["standard", "gated"] = "standard" d_in: int = 512 d_sae: Optional[int] = None b_dec_init_method: str = "geometric_median" @@ -349,6 +350,8 @@ def total_training_steps(self) -> int: def get_base_sae_cfg_dict(self) -> dict[str, Any]: return { + # TEMP + "architecture": self.architecture, "d_in": self.d_in, "d_sae": self.d_sae, "dtype": self.dtype, @@ -474,6 +477,9 @@ def __post_init__(self): @dataclass class ToyModelSAERunnerConfig: + + architecture: Literal["standard", "gated"] = "standard" + # ReLu Model Parameters n_features: int = 5 n_hidden: int = 2 @@ -527,6 +533,7 @@ def __post_init__(self): def get_base_sae_cfg_dict(self) -> dict[str, Any]: # TO DO: Have the same hyperparameters as in the main sae runner. return { + "architecture": self.architecture, "d_in": self.d_in, "d_sae": self.d_sae, "dtype": self.dtype, diff --git a/sae_lens/sae.py b/sae_lens/sae.py index dd51d7e1..bd0f91b2 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -5,7 +5,7 @@ import json import os from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Literal, Optional, Tuple import einops import torch @@ -28,6 +28,8 @@ @dataclass class SAEConfig: + # architecture details + architecture: Literal["standard", "gated"] # forward pass details. d_in: int @@ -76,6 +78,7 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig": def to_dict(self) -> dict[str, Any]: return { + "architecture": self.architecture, "d_in": self.d_in, "d_sae": self.d_sae, "dtype": self.dtype, @@ -121,7 +124,12 @@ def __init__( self.device = torch.device(cfg.device) self.use_error_term = use_error_term - self.initialize_weights_basic() + if self.cfg.architecture == "standard": + self.initialize_weights_basic() + self.encode_fn = self.encode + elif self.cfg.architecture == "gated": + self.initialize_weights_gated() + self.encode_fn = self.encode_gated # handle presence / absence of scaling factor. if self.cfg.finetuning_scaling_factor: @@ -209,16 +217,50 @@ def initialize_weights_basic(self): torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device) ) + def initialize_weights_gated(self): + # Initialize the weights and biases for the gated encoder + self.W_enc = nn.Parameter( + torch.nn.init.kaiming_uniform_( + torch.empty( + self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device + ) + ) + ) + + self.b_gate = nn.Parameter( + torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device) + ) + + self.r_mag = nn.Parameter( + torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device) + ) + + self.b_mag = nn.Parameter( + torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device) + ) + + self.W_dec = nn.Parameter( + torch.nn.init.kaiming_uniform_( + torch.empty( + self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device + ) + ) + ) + + self.b_dec = nn.Parameter( + torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device) + ) + # Basic Forward Pass Functionality. 
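A note on the gated initialization above before moving on to the forward pass: with r_mag, b_gate, and b_mag all starting at zero, the gate and magnitude paths share the same pre-activation (exp(0) = 1), so a freshly initialized gated SAE computes exactly the features a standard ReLU SAE with a zero encoder bias would. A quick standalone check of that claim, with illustrative shapes rather than library code:

```python
import torch

d_in, d_sae = 16, 64
W_enc = torch.nn.init.kaiming_uniform_(torch.empty(d_in, d_sae))
b_gate = torch.zeros(d_sae)
b_mag = torch.zeros(d_sae)
r_mag = torch.zeros(d_sae)  # exp(0) == 1: both paths see the same weights

x = torch.randn(4, d_in)  # b_dec is zero at init, so no centering needed

gated_acts = ((x @ W_enc + b_gate) > 0).float() * torch.relu(
    x @ (W_enc * r_mag.exp()) + b_mag
)
standard_acts = torch.relu(x @ W_enc)  # a standard SAE with b_enc == 0

assert torch.allclose(gated_acts, standard_acts)
```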
def forward( self, x: torch.Tensor, ) -> torch.Tensor: - - feature_acts = self.encode(x) + feature_acts = self.encode_fn(x) sae_out = self.decode(feature_acts) - if self.use_error_term: + # TEMP + if self.use_error_term and self.cfg.architecture != "gated": with torch.no_grad(): # Recompute everything without hooks to get true error term # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct @@ -248,16 +290,62 @@ def forward( sae_out = self.run_time_activation_norm_fn_out(sae_out) sae_error = self.hook_sae_error(x - x_reconstruct_clean) + return self.hook_sae_output(sae_out + sae_error) + + # TODO: Add tests + elif self.use_error_term and self.cfg.architecture == "gated": + with torch.no_grad(): + x = x.to(self.dtype) + sae_in = self.reshape_fn_in(x) # type: ignore + gating_pre_activation = sae_in @ self.W_enc + self.b_gate + active_features = (gating_pre_activation > 0).float() + + # Magnitude path with weight sharing + magnitude_pre_activation = self.hook_sae_acts_pre( + sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag + ) + feature_magnitudes = self.hook_sae_acts_post( + self.activation_fn(magnitude_pre_activation) + ) + feature_acts_clean = active_features * feature_magnitudes + x_reconstruct_clean = self.reshape_fn_out( + self.apply_finetuning_scaling_factor(feature_acts_clean) + @ self.W_dec + + self.b_dec, + d_head=self.d_head, + ) + sae_error = self.hook_sae_error(x - x_reconstruct_clean) return self.hook_sae_output(sae_out + sae_error) return self.hook_sae_output(sae_out) + def encode_gated( + self, x: Float[torch.Tensor, "... d_in"] + ) -> Float[torch.Tensor, "... d_sae"]: + x = x.to(self.dtype) + x = self.reshape_fn_in(x) + sae_in = self.hook_sae_input(x - self.b_dec * self.cfg.apply_b_dec_to_input) + + # Gating path + gating_pre_activation = sae_in @ self.W_enc + self.b_gate + active_features = (gating_pre_activation > 0).float() + + # Magnitude path with weight sharing + magnitude_pre_activation = self.hook_sae_acts_pre( + sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag + ) + feature_magnitudes = self.hook_sae_acts_post( + self.activation_fn(magnitude_pre_activation) + ) + + return active_features * feature_magnitudes + def encode( self, x: Float[torch.Tensor, "... d_in"] ) -> Float[torch.Tensor, "... 
d_sae"]: """ - Calcuate SAE features from inputs + Calculate SAE features from inputs """ # move x to correct dtype @@ -301,7 +389,12 @@ def fold_W_dec_norm(self): W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1) self.W_dec.data = self.W_dec.data / W_dec_norms self.W_enc.data = self.W_enc.data * W_dec_norms.T - self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() + if self.cfg.architecture == "gated": + self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze() + self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze() + self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze() + else: + self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() @torch.no_grad() def fold_activation_norm_scaling_factor( diff --git a/sae_lens/toolkit/pretrained_sae_loaders.py b/sae_lens/toolkit/pretrained_sae_loaders.py index e5e9eb1d..3bc72a4d 100644 --- a/sae_lens/toolkit/pretrained_sae_loaders.py +++ b/sae_lens/toolkit/pretrained_sae_loaders.py @@ -124,6 +124,7 @@ def connor_rob_hook_z_loader( # } cfg_dict = { + "architecture": "standard", "d_in": old_cfg_dict["act_size"], "d_sae": old_cfg_dict["dict_size"], "dtype": "float32", @@ -169,6 +170,10 @@ def load_pretrained_sae_lens_sae_components( with open(cfg_path, "r") as f: cfg_dict = json.load(f) + cfg_dict["architecture"] = ( + "standard" # TODO: modify this when we add support for loading more architectures + ) + # filter config for varnames cfg_dict["device"] = device cfg_dict["dtype"] = dtype diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 87ec8c27..34c7c0d5 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -287,6 +287,7 @@ def _build_train_step_log_dict( "losses/mse_loss": mse_loss, "losses/l1_loss": l1_loss / self.current_l1_coefficient, # normalize by l1 coefficient + "losses/auxiliary_reconstruction_loss": output.auxiliary_reconstruction_loss, "losses/ghost_grad_loss": ghost_grad_loss, "losses/overall_loss": loss, # variance explained @@ -317,19 +318,17 @@ def _run_and_log_evals(self): model_kwargs=self.cfg.model_kwargs, ) - W_dec_norm_dist = self.sae.W_dec.norm(dim=1).detach().float().cpu().numpy() - b_e_dist = self.sae.b_enc.detach().float().cpu().numpy() - - # More detail on loss. 
-
-        # add weight histograms
-        eval_metrics = {
-            **eval_metrics,
-            **{
-                "weights/W_dec_norms": wandb.Histogram(W_dec_norm_dist),
-                "weights/b_e": wandb.Histogram(b_e_dist),
-            },
-        }
+        W_dec_norm_dist = self.sae.W_dec.norm(dim=1).detach().cpu().numpy()
+        eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist)  # type: ignore
+
+        if self.sae.cfg.architecture == "standard":
+            b_e_dist = self.sae.b_enc.detach().cpu().numpy()
+            eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist)  # type: ignore
+        elif self.sae.cfg.architecture == "gated":
+            b_gate_dist = self.sae.b_gate.detach().cpu().numpy()
+            eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist)  # type: ignore
+            b_mag_dist = self.sae.b_mag.detach().cpu().numpy()
+            eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist)  # type: ignore
 
         wandb.log(
             eval_metrics,
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py
index 7f99d1d2..e21acff3 100644
--- a/sae_lens/training/training_sae.py
+++ b/sae_lens/training/training_sae.py
@@ -27,6 +27,7 @@ class TrainStepOutput:
     mse_loss: float
     l1_loss: float
     ghost_grad_loss: float
+    auxiliary_reconstruction_loss: float = 0.0
 
 
 @dataclass
@@ -50,7 +51,8 @@ def from_sae_runner_config(
     ) -> "TrainingSAEConfig":
 
         return cls(
-            # base confg
+            # base config
+            architecture=cfg.architecture,
             d_in=cfg.d_in,
             d_sae=cfg.d_sae,  # type: ignore
             dtype=cfg.dtype,
@@ -104,6 +106,7 @@ def to_dict(self) -> dict[str, Any]:
     # parameters. Maybe there's a cleaner way to do this
     def get_base_sae_cfg_dict(self) -> dict[str, Any]:
         return {
+            "architecture": self.architecture,
             "d_in": self.d_in,
             "d_sae": self.d_sae,
             "activation_fn_str": self.activation_fn_str,
@@ -140,6 +143,15 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
         base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
         super().__init__(base_sae_cfg)
         self.cfg = cfg  # type: ignore
+
+        self.encode_with_hidden_pre_fn = (
+            self.encode_with_hidden_pre
+            if cfg.architecture != "gated"
+            else self.encode_with_hidden_pre_gated
+        )
+
         self.use_error_term = use_error_term
+        self.check_cfg_compatibility()  # run after use_error_term is set, since the check reads it
 
         self.initialize_weights_complex()
@@ -154,13 +166,20 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
     def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE":
         return cls(TrainingSAEConfig.from_dict(config_dict))
 
+    def check_cfg_compatibility(self):
+        if self.cfg.architecture == "gated":
+            assert (
+                self.cfg.use_ghost_grads is False
+            ), "Gated SAEs do not support ghost grads"
+            assert self.use_error_term is False, "Gated SAEs do not support error terms"
+
     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
     ) -> Float[torch.Tensor, "... d_sae"]:
         """
         Calcuate SAE features from inputs
         """
-        feature_acts, _ = self.encode_with_hidden_pre(x)
+        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
         return feature_acts
 
     def encode_with_hidden_pre(
@@ -188,12 +207,44 @@
 
         return feature_acts, hidden_pre_noised
 
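Reviewer note on the hunk below: encode_with_hidden_pre_gated shares a single W_enc between a Heaviside gate and a ReLU magnitude path. A minimal sketch of the computation it implements, using free-standing tensors with illustrative shapes rather than the class's parameters and hooks:

```python
import torch

d_in, d_sae = 512, 2048
W_enc = torch.randn(d_in, d_sae)  # shared by both paths
b_gate = torch.zeros(d_sae)
b_mag = torch.zeros(d_sae)
r_mag = torch.zeros(d_sae)  # per-feature log-scale for the magnitude path
b_dec = torch.zeros(d_in)

x = torch.randn(32, d_in)
sae_in = x - b_dec  # centering, when apply_b_dec_to_input is set

# Gate path: a Heaviside step decides which features fire.
active_features = ((sae_in @ W_enc + b_gate) > 0).float()

# Magnitude path: the same W_enc, rescaled per feature by exp(r_mag).
magnitude_pre_activation = sae_in @ (W_enc * r_mag.exp()) + b_mag
feature_magnitudes = torch.relu(magnitude_pre_activation)

feature_acts = active_features * feature_magnitudes  # shape (32, d_sae)
```

The step function itself passes no gradient, so the gate parameters can only learn through the sparsity and auxiliary terms that training_forward_pass applies to the ReLU of the gate pre-activations.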
+    def encode_with_hidden_pre_gated(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+
+        # move x to correct dtype
+        x = x.to(self.dtype)
+
+        # handle hook z reshaping if needed.
+        x = self.reshape_fn_in(x)  # type: ignore
+
+        # apply b_dec_to_input if using that method.
+        sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))
+
+        # Gating path with Heaviside step function
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).float()
+
+        # Magnitude path with weight sharing
+        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        # NOTE: unlike the standard path, no training noise (cfg.noise_scale) is
+        # added to the magnitude pre-activation here for now.
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Return both the gated feature activations and the magnitude pre-activations
+        return (
+            active_features * feature_magnitudes,
+            magnitude_pre_activation,
+        )
+
     def forward(
         self,
         x: Float[torch.Tensor, "... d_in"],
     ) -> Float[torch.Tensor, "... d_in"]:
 
-        feature_acts, _ = self.encode_with_hidden_pre(x)
+        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
         sae_out = self.decode(feature_acts)
 
         return sae_out
 
@@ -207,7 +258,7 @@ def training_forward_pass(
 
         # do a forward pass to get SAE out, but we also need the
         # hidden pre.
-        feature_acts, _ = self.encode_with_hidden_pre(sae_in)
+        feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in)
         sae_out = self.decode(feature_acts)
 
         # MSE LOSS
@@ -218,7 +269,7 @@
 
         if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None:
 
             # first half of second forward pass
-            _, hidden_pre = self.encode_with_hidden_pre(sae_in)
+            _, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
             ghost_grad_loss = self.calculate_ghost_grad_loss(
                 x=sae_in,
                 sae_out=sae_out,
@@ -229,17 +280,40 @@
         else:
             ghost_grad_loss = 0.0
 
-        # SPARSITY LOSS
-        # either the W_dec norms are 1 and this won't do anything or they are not 1
-        # and we're using their norm in the loss function.
-        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
-        sparsity = weighted_feature_acts.norm(
-            p=self.cfg.lp_norm, dim=-1
-        )  # sum over the feature dimension
+        if self.cfg.architecture == "gated":
+            # Gated SAE Loss Calculation
+
+            # Shared variables
+            sae_in_centered = (
+                self.reshape_fn_in(sae_in) - self.b_dec * self.cfg.apply_b_dec_to_input
+            )
+            pi_gate = sae_in_centered @ self.W_enc + self.b_gate
+            pi_gate_act = torch.relu(pi_gate)
+
+            # SFN sparsity loss - summed over the feature dimension and averaged over the batch
+            l1_loss = (
+                current_l1_coefficient
+                * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
+            )
+
+            # Auxiliary reconstruction loss - summed over the input dimension and averaged over the batch
+            via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
+            aux_reconstruction_loss = torch.sum(
+                (via_gate_reconstruction - sae_in) ** 2, dim=-1
+            ).mean()
+
+            loss = mse_loss + l1_loss + aux_reconstruction_loss
+        else:
+            # default SAE sparsity loss
+            weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+            sparsity = weighted_feature_acts.norm(
+                p=self.cfg.lp_norm, dim=-1
+            )  # sum over the feature dimension
 
-        l1_loss = (current_l1_coefficient * sparsity).mean()
+            l1_loss = (current_l1_coefficient * sparsity).mean()
+            loss = mse_loss + l1_loss + ghost_grad_loss
 
-        loss = mse_loss + l1_loss + ghost_grad_loss
+            aux_reconstruction_loss = torch.tensor(0.0)
 
         return TrainStepOutput(
             sae_in=sae_in,
@@ -253,6 +327,7 @@
             if isinstance(ghost_grad_loss, torch.Tensor)
             else ghost_grad_loss
         ),
+        auxiliary_reconstruction_loss=aux_reconstruction_loss.item(),
     )
 
     def calculate_ghost_grad_loss(
@@ -354,7 +429,7 @@ def initialize_weights_complex(self):
         elif self.cfg.normalize_sae_decoder:
             self.set_decoder_norm_to_unit_norm()
 
-        # Then we intialize the encoder weights (either as the transpose of decoder or not)
+        # Then we initialize the encoder weights (either as the transpose of decoder or not)
         if self.cfg.init_encoder_as_decoder_transpose:
             self.W_enc.data = self.W_dec.data.T.clone().contiguous()
         else:
diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py
index d5eb3e29..55293b04 100644
--- a/tests/benchmark/test_language_model_sae_runner.py
+++ b/tests/benchmark/test_language_model_sae_runner.py
@@ -103,3 +103,101 @@ def test_language_model_sae_runner():
 
     assert sae is not None
     # know whether or not this works by looking at the dashboard!
+
+
+def test_language_model_sae_runner_gated():
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+
+    # total_training_steps = 20_000
+    total_training_steps = 500
+    batch_size = 4096
+    total_training_tokens = total_training_steps * batch_size
+    print(f"Total Training Tokens: {total_training_tokens}")
+
+    lr_warm_up_steps = 0
+    lr_decay_steps = 40_000
+    print(f"lr_decay_steps: {lr_decay_steps}")
+    l1_warmup_steps = 10_000
+    print(f"l1_warmup_steps: {l1_warmup_steps}")
+
+    cfg = LanguageModelSAERunnerConfig(
+        # Pick a tiny model to make this easier.
+        model_name="gelu-1l",
+        architecture="gated",
+        ## MLP Layer 0 ##
+        hook_name="blocks.0.hook_mlp_out",
+        hook_layer=0,
+        d_in=512,
+        dataset_path="NeelNanda/c4-tokenized-2b",
+        context_size=256,
+        is_dataset_tokenized=True,
+        prepend_bos=True,  # I used to train GPT2 SAEs with a prepended BOS but no longer think we should do this.
+        # How big do we want our SAE to be?
+        expansion_factor=16,
+        # Dataset / Activation Store
+        # When we do a proper test:
+        # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100)
+        # For now:
+        training_tokens=total_training_tokens,  # For initial testing I think this is a good number.
+        train_batch_size_tokens=4096,
+        # Loss Function
+        ## Reconstruction Coefficient.
+        mse_loss_normalization=None,  # MSE loss normalization is not mentioned (so we use standard MSE loss), but note we take an average over the batch.
+        ## Anthropic does not mention using an Lp norm other than L1.
+        l1_coefficient=5,
+        lp_norm=1.0,
+        # Instead, they multiply the L1 loss contribution
+        # from each feature of the activations by the decoder norm of the corresponding feature.
+        scale_sparsity_penalty_by_decoder_norm=True,
+        # Learning Rate
+        lr_scheduler_name="constant",  # we set this independently of warmup and decay steps.
+        l1_warm_up_steps=l1_warmup_steps,
+        lr_warm_up_steps=lr_warm_up_steps,
+        lr_decay_steps=lr_decay_steps,
+        ## No ghost grad term.
+        use_ghost_grads=False,
+        # Initialization / Architecture
+        apply_b_dec_to_input=False,
+        # encoder bias zeros (I'm not sure what it is by default now)
+        # decoder bias zeros
+        b_dec_init_method="zeros",
+        normalize_sae_decoder=False,
+        decoder_heuristic_init=True,
+        init_encoder_as_decoder_transpose=True,
+        # Optimizer
+        lr=4e-5,
+        ## The Adam optimizer has no weight decay by default, so no need to worry about it here.
+        adam_beta1=0.9,
+        adam_beta2=0.999,
+        # Buffer details won't matter if we cache / shuffle our activations ahead of time.
+        n_batches_in_buffer=64,
+        store_batch_size_prompts=16,
+        normalize_activations="none",
+        # Feature Store
+        feature_sampling_window=1000,
+        dead_feature_window=1000,
+        dead_feature_threshold=1e-4,
+        # performance enhancement:
+        compile_sae=False,
+        # WANDB
+        log_to_wandb=True,  # always use wandb unless you are just testing code.
+        wandb_project="benchmark",
+        wandb_log_frequency=100,
+        # Misc
+        device=device,
+        seed=42,
+        n_checkpoints=0,
+        checkpoint_path="checkpoints",
+        dtype="float32",
+    )
+
+    sae = SAETrainingRunner(cfg).run()
+
+    assert sae is not None
+    # know whether or not this works by looking at the dashboard!
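Stepping outside the diff for a moment: the gated objective implemented in training_forward_pass above has three terms. A hedged sketch of the computation with free-standing tensors (shapes and coefficient are illustrative, and the b_dec centering is omitted; note that the Gated SAEs paper uses frozen decoder parameters inside the auxiliary term, which the test file below mirrors with W_dec.detach() while the training code above does not detach):

```python
import torch

d_in, d_sae, l1_coefficient = 512, 2048, 5.0
W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
b_gate, b_dec = torch.zeros(d_sae), torch.zeros(d_in)

sae_in = torch.randn(32, d_in)
sae_out = torch.randn(32, d_in)  # stand-in for decode(encode(sae_in))

# Reconstruction term: squared error summed over d_in, averaged over the batch.
mse_loss = ((sae_out - sae_in) ** 2).sum(dim=-1).mean()

# Sparsity term: L1 of the rectified gate pre-activations,
# weighted by the decoder row norms.
pi_gate_act = torch.relu(sae_in @ W_enc + b_gate)
l1_loss = l1_coefficient * (pi_gate_act * W_dec.norm(dim=1)).sum(dim=-1).mean()

# Auxiliary term: reconstruct the input from the gate path alone, giving the
# otherwise gradient-free gate a training signal.
via_gate_reconstruction = pi_gate_act @ W_dec + b_dec
aux_reconstruction_loss = ((via_gate_reconstruction - sae_in) ** 2).sum(dim=-1).mean()

loss = mse_loss + l1_loss + aux_reconstruction_loss
```

The unit tests further down compare against this decomposition under the default unit decoder norms, which is why their expected L1 term can drop the W_dec.norm(dim=1) factor.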
diff --git a/tests/unit/analysis/test_hooked_sae.py b/tests/unit/analysis/test_hooked_sae.py index 6d37fac2..61f09a82 100644 --- a/tests/unit/analysis/test_hooked_sae.py +++ b/tests/unit/analysis/test_hooked_sae.py @@ -42,6 +42,7 @@ def get_hooked_sae(model: HookedTransformer, act_name: str) -> SAE: d_in = site_to_size[site] sae_cfg = SAEConfig( + architecture="standard", d_in=d_in, d_sae=d_in * 2, dtype="float32", diff --git a/tests/unit/analysis/test_hooked_sae_transformer.py b/tests/unit/analysis/test_hooked_sae_transformer.py index 100a3c38..2cc32e97 100644 --- a/tests/unit/analysis/test_hooked_sae_transformer.py +++ b/tests/unit/analysis/test_hooked_sae_transformer.py @@ -42,6 +42,7 @@ def get_hooked_sae(model: HookedTransformer, act_name: str) -> SAE: d_in = site_to_size[site] sae_cfg = SAEConfig( + architecture="standard", d_in=d_in, d_sae=d_in * 2, dtype="float32", diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py index 2de701f4..81899610 100644 --- a/tests/unit/training/test_config.py +++ b/tests/unit/training/test_config.py @@ -45,6 +45,7 @@ def test_sae_training_runner_config_get_sae_base_parameters(): cfg = LanguageModelSAERunnerConfig() expected_config = { + "architecture": "standard", "d_in": 512, "d_sae": 2048, "activation_fn_str": "relu", diff --git a/tests/unit/training/test_gated_sae.py b/tests/unit/training/test_gated_sae.py new file mode 100644 index 00000000..1afa894d --- /dev/null +++ b/tests/unit/training/test_gated_sae.py @@ -0,0 +1,137 @@ +import pytest +import torch + +from sae_lens.training.training_sae import TrainingSAE +from tests.unit.helpers import build_sae_cfg + + +def test_gated_sae_initialization(): + cfg = build_sae_cfg() + setattr(cfg, "architecture", "gated") + sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + + assert sae.W_enc.shape == (cfg.d_in, cfg.d_sae) + assert sae.W_dec.shape == (cfg.d_sae, cfg.d_in) + # assert sae.b_enc.shape == (cfg.d_sae,) + assert sae.b_mag.shape == (cfg.d_sae,) + assert sae.b_gate.shape == (cfg.d_sae,) + assert sae.r_mag.shape == (cfg.d_sae,) + assert sae.b_dec.shape == (cfg.d_in,) + assert isinstance(sae.activation_fn, torch.nn.ReLU) + assert sae.device == torch.device("cpu") + assert sae.dtype == torch.float32 + + # biases + assert torch.allclose(sae.b_dec, torch.zeros_like(sae.b_dec), atol=1e-6) + assert torch.allclose(sae.b_mag, torch.zeros_like(sae.b_mag), atol=1e-6) + assert torch.allclose(sae.b_gate, torch.zeros_like(sae.b_gate), atol=1e-6) + + # check if the decoder weight norm is 1 by default + assert torch.allclose( + sae.W_dec.norm(dim=1), torch.ones_like(sae.W_dec.norm(dim=1)), atol=1e-6 + ) + + +def test_gated_sae_encoding(): + cfg = build_sae_cfg() + setattr(cfg, "architecture", "gated") + sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + + batch_size = 32 + d_in = sae.cfg.d_in + d_sae = sae.cfg.d_sae + + x = torch.randn(batch_size, d_in) + feature_acts, hidden_pre = sae.encode_with_hidden_pre_gated(x) + + assert feature_acts.shape == (batch_size, d_sae) + assert hidden_pre.shape == (batch_size, d_sae) + + # Check the gating mechanism + gating_pre_activation = x @ sae.W_enc + sae.b_gate + active_features = (gating_pre_activation > 0).float() + magnitude_pre_activation = x @ (sae.W_enc * sae.r_mag.exp()) + sae.b_mag + feature_magnitudes = sae.activation_fn(magnitude_pre_activation) + + expected_feature_acts = active_features * feature_magnitudes + assert torch.allclose(feature_acts, expected_feature_acts, atol=1e-6) + + +def test_gated_sae_loss(): + 
cfg = build_sae_cfg() + setattr(cfg, "architecture", "gated") + sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + + batch_size = 32 + d_in = sae.cfg.d_in + x = torch.randn(batch_size, d_in) + + train_step_output = sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=sae.cfg.l1_coefficient, + ) + + assert train_step_output.sae_out.shape == (batch_size, d_in) + assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae) + + sae_in_centered = x - sae.b_dec + via_gate_feature_magnitudes = torch.relu(sae_in_centered @ sae.W_enc + sae.b_gate) + preactivation_l1_loss = ( + sae.cfg.l1_coefficient * torch.sum(via_gate_feature_magnitudes, dim=-1).mean() + ) + + via_gate_reconstruction = ( + via_gate_feature_magnitudes @ sae.W_dec.detach() + sae.b_dec.detach() + ) + aux_reconstruction_loss = torch.sum( + (via_gate_reconstruction - x) ** 2, dim=-1 + ).mean() + + expected_loss = ( + train_step_output.mse_loss + preactivation_l1_loss + aux_reconstruction_loss + ) + assert ( + pytest.approx(train_step_output.loss.item(), rel=1e-3) == expected_loss.item() + ) + + +def test_gated_sae_forward_pass(): + cfg = build_sae_cfg() + setattr(cfg, "architecture", "gated") + sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + + batch_size = 32 + d_in = sae.cfg.d_in + + x = torch.randn(batch_size, d_in) + sae_out = sae(x) + + assert sae_out.shape == (batch_size, d_in) + + +def test_gated_sae_training_forward_pass(): + cfg = build_sae_cfg() + setattr(cfg, "architecture", "gated") + sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + + batch_size = 32 + d_in = sae.cfg.d_in + + x = torch.randn(batch_size, d_in) + train_step_output = sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=sae.cfg.l1_coefficient, + ) + + assert train_step_output.sae_out.shape == (batch_size, d_in) + assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae) + + # Detach the loss tensor and convert to numpy for comparison + detached_loss = train_step_output.loss.detach().cpu().numpy() + expected_loss = ( + train_step_output.mse_loss + + train_step_output.l1_loss + + train_step_output.auxiliary_reconstruction_loss + ) + + assert pytest.approx(detached_loss, rel=1e-3) == expected_loss diff --git a/tests/unit/training/test_sae_initialization.py b/tests/unit/training/test_sae_initialization.py index 9c8eaade..fe73ee23 100644 --- a/tests/unit/training/test_sae_initialization.py +++ b/tests/unit/training/test_sae_initialization.py @@ -33,6 +33,37 @@ def test_SparseAutoencoder_initialization_standard(): assert not torch.allclose(unit_normed_W_enc, unit_normed_W_dec, atol=1e-6) +def test_SparseAutoencoder_initialization_gated(): + cfg = build_sae_cfg() + setattr(cfg, "architecture", "gated") + sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + + assert sae.W_enc.shape == (cfg.d_in, cfg.d_sae) + assert sae.W_dec.shape == (cfg.d_sae, cfg.d_in) + assert sae.b_mag.shape == (cfg.d_sae,) + assert sae.b_gate.shape == (cfg.d_sae,) + assert sae.r_mag.shape == (cfg.d_sae,) + assert sae.b_dec.shape == (cfg.d_in,) + assert isinstance(sae.activation_fn, torch.nn.ReLU) + assert sae.device == torch.device("cpu") + assert sae.dtype == torch.float32 + + # biases + assert torch.allclose(sae.b_dec, torch.zeros_like(sae.b_dec), atol=1e-6) + assert torch.allclose(sae.b_mag, torch.zeros_like(sae.b_mag), atol=1e-6) + assert torch.allclose(sae.b_gate, torch.zeros_like(sae.b_gate), atol=1e-6) + + # check if the decoder weight norm is 1 by default + assert torch.allclose( + 
sae.W_dec.norm(dim=1), torch.ones_like(sae.W_dec.norm(dim=1)), atol=1e-6
+    )
+
+    # Default currently shouldn't be transpose initialization
+    unit_normed_W_enc = sae.W_enc / torch.norm(sae.W_enc, dim=0)
+    unit_normed_W_dec = sae.W_dec.T
+    assert not torch.allclose(unit_normed_W_enc, unit_normed_W_dec, atol=1e-6)
+
+
 def test_SparseAutoencoder_initialization_orthogonal_enc_dec():
     cfg = build_sae_cfg(decoder_orthogonal_init=True)
diff --git a/tests/unit/training/test_sae_trainer.py b/tests/unit/training/test_sae_trainer.py
index adb3a8fe..2573140e 100644
--- a/tests/unit/training/test_sae_trainer.py
+++ b/tests/unit/training/test_sae_trainer.py
@@ -185,6 +185,7 @@ def test_build_train_step_log_dict(trainer: SAETrainer) -> None:
         # l1 loss is scaled by l1_coefficient
         "losses/l1_loss": train_output.l1_loss / trainer.cfg.l1_coefficient,
         "losses/ghost_grad_loss": pytest.approx(0.15),
+        "losses/auxiliary_reconstruction_loss": 0.0,
         "losses/overall_loss": 0.5,
         "metrics/explained_variance": 0.75,
         "metrics/explained_variance_std": 0.25,
diff --git a/tutorials/training_a_gated_sae.ipynb b/tutorials/training_a_gated_sae.ipynb
new file mode 100644
index 00000000..ff319f9f
--- /dev/null
+++ b/tutorials/training_a_gated_sae.ipynb
@@ -0,0 +1,700 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "5O8tQblzOVHu"
+   },
+   "source": [
+    "# A Very Basic Gated SAE Training Run"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "shAFb9-lOVHu"
+   },
+   "source": [
+    "## Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "id": "LeRi_tw2dhae"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Requirement already satisfied: sae-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (3.3.0)\n",
+      "... (remaining pip install output elided) ...\n",
"Requirement already satisfied: python-dotenv<2.0.0,>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.0.1)\n", + "Requirement already satisfied: pyyaml<7.0.0,>=6.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (6.0.1)\n", + "Requirement already satisfied: pyzmq==26.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (26.0.0)\n", + "Requirement already satisfied: sae-vis<0.3.0,>=0.2.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.2.18)\n", + "Requirement already satisfied: safetensors<0.5.0,>=0.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.3)\n", + "Requirement already satisfied: transformers<5.0.0,>=4.38.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (4.41.2)\n", + "Requirement already satisfied: typer<0.13.0,>=0.12.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.12.3)\n", + "Requirement already satisfied: accelerate>=0.23.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.31.0)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.14.1)\n", + "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: einops>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.29)\n", + "Requirement already satisfied: numpy>=1.24 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.2.2)\n", + "Requirement already satisfied: rich>=12.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (13.7.1)\n", + "Requirement already satisfied: sentencepiece in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.0)\n", + "Requirement already satisfied: torch>=1.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.66.4)\n", + "Requirement already satisfied: typing-extensions in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.12.2)\n", + "Requirement already satisfied: wandb>=0.13.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.17.1)\n", + "Requirement already satisfied: importlib-metadata>=5.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (7.1.0)\n", + "Requirement already satisfied: 
nvidia-cublas-cu12==12.1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.5.40)\n", + "Requirement already satisfied: filelock in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.14.0)\n", + "Requirement already satisfied: packaging>=20.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (24.0)\n", + "Requirement already satisfied: psutil in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (5.9.0)\n", + "Requirement already satisfied: huggingface-hub in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (0.23.3)\n", + "Requirement already satisfied: blobfile<3.0.0,>=2.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.1.1)\n", + "Requirement already satisfied: boostedblob<0.16.0,>=0.15.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.15.3)\n", + "Requirement already satisfied: httpx<0.28.0,>=0.27.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.27.0)\n", + "Requirement already satisfied: 
orjson<4.0.0,>=3.10.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.10.4)\n", + "Requirement already satisfied: pytest<9.0.0,>=8.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (8.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: tiktoken<0.7.0,>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.6.0)\n", + "Requirement already satisfied: py2store in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.20)\n", + "Requirement already satisfied: graze in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (16.1.0)\n", + "Requirement already satisfied: pyarrow-hotfix in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.3.8)\n", + "Requirement already satisfied: requests>=2.32.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (2.32.3)\n", + "Requirement already satisfied: xxhash in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.17.1->sae-lens) (2024.3.1)\n", + "Requirement already satisfied: aiohttp in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.9.5)\n", + "Requirement already satisfied: zipp>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.19.2)\n", + "Requirement already satisfied: typeguard==2.13.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jaxtyping>=0.2.11->transformer-lens) (2.13.3)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (4.53.0)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages 
(from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.4.5)\n", + "Requirement already satisfied: pillow>=8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (10.3.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (3.1.2)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (2.9.0)\n", + "Requirement already satisfied: traitlets in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib-inline<0.2.0,>=0.1.6->sae-lens) (5.14.3)\n", + "Requirement already satisfied: click in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (8.1.7)\n", + "Requirement already satisfied: joblib in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (1.4.2)\n", + "Requirement already satisfied: regex>=2021.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (2024.5.15)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly<6.0.0,>=5.19.0->sae-lens) (8.3.0)\n", + "Requirement already satisfied: statsmodels>=0.9.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.14.2)\n", + "Requirement already satisfied: scipy>=0.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.13.1)\n", + "Requirement already satisfied: patsy>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.5.6)\n", + "Requirement already satisfied: six in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.16.0)\n", + "Requirement already satisfied: gprof2dot in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (2024.6.6)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (2.18.0)\n", + "Requirement already satisfied: dataclasses-json<0.7.0,>=0.6.4 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.6.7)\n", + "Requirement already satisfied: eindex-callum<0.2.0,>=0.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.1.1)\n", + "Requirement already satisfied: sympy in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
torch>=1.10->transformer-lens) (1.12.1)\n", + "Requirement already satisfied: networkx in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.3)\n", + "Requirement already satisfied: jinja2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.1.4)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformers<5.0.0,>=4.38.1->sae-lens) (0.19.1)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typer<0.13.0,>=0.12.3->sae-lens) (1.5.4)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (0.4.0)\n", + "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (3.1.43)\n", + "Requirement already satisfied: platformdirs in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (4.2.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (5.27.1)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (2.5.1)\n", + "Requirement already satisfied: setproctitle in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (1.3.3)\n", + "Requirement already satisfied: setuptools in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (70.0.0)\n", + "Requirement already satisfied: pycryptodomex~=3.8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.20.0)\n", + "Requirement already satisfied: urllib3<3,>=1.25.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.2.1)\n", + "Requirement already satisfied: lxml~=4.9 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.9.4)\n", + "Requirement already satisfied: uvloop>=0.16.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from boostedblob<0.16.0,>=0.15.3->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.19.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (4.0.3)\n", + "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (3.21.3)\n", + "Requirement already satisfied: typing-inspect<1,>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (0.9.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (4.0.11)\n", + "Requirement already satisfied: anyio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.4.0)\n", + "Requirement already satisfied: certifi in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2024.6.2)\n", + "Requirement already satisfied: httpcore==1.* in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.0.5)\n", + "Requirement already satisfied: idna in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.7)\n", + "Requirement already satisfied: sniffio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.3.1)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.14.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens) (0.1.2)\n", + "Requirement already satisfied: iniconfig in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.0)\n", + "Requirement already satisfied: pluggy<2.0,>=1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.2.0)\n", + "Requirement already satisfied: tomli>=1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
requests>=2.32.1->datasets<3.0.0,>=2.17.1->sae-lens) (3.3.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from scikit-learn<2.0.0,>=1.4.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.5.0)\n", + "Requirement already satisfied: dol in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from graze->babe<0.0.8,>=0.0.7->sae-lens) (0.2.47)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jinja2->torch>=1.10->transformer-lens) (2.1.5)\n", + "Requirement already satisfied: config2py in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.33)\n", + "Requirement already satisfied: importlib-resources in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (6.4.0)\n", + "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sympy->torch>=1.10->transformer-lens) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (5.0.1)\n", + "Requirement already satisfied: mypy-extensions>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (1.0.0)\n", + "Requirement already satisfied: i2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from config2py->py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "try:\n", + " #import google.colab # type: ignore\n", + " #from google.colab import output\n", + " %pip install sae-lens transformer-lens circuitsvis\n", + "except:\n", + " from IPython import get_ipython # type: ignore\n", + " ipython = get_ipython(); assert ipython is not None\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uy-b3CcSOVHu", + "outputId": "58ce28d0-f91f-436d-cf87-76bb26e2ecaf" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n", + "\n", + "if torch.cuda.is_available():\n", + " device = \"cuda\"\n", + "elif torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cpu\"\n", + "\n", + "print(\"Using device:\", device)\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jCHtPycOOVHw" + }, + "source": [ + "## Training on MLP Out" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "oAsZCAdJOVHw" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", + "Total training steps: 30000\n", + "Total wandb updates: 1000\n", + "n_tokens_per_feature_sampling_window (millions): 1048.576\n", + "n_tokens_per_dead_feature_window (millions): 1048.576\n", + "We will reset the sparsity calculation 30 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n", + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " return self.fget.__get__(instance, owner)()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcurt-tigges\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.1" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240611_143204-n7cy5v24" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/curt-tigges/sae_lens_tutorial" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/n7cy5v24" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:33<00:00, 30.13it/s]\n", + "5500| MSE Loss 208.944 | L1 167.607: 0%| | 225280/122880000 [08:05<71:26:53, 476.86it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interrupted, saving progress\n", + "done saving\n" + ] + }, + { + "ename": "InterruptedException", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mInterruptedException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 63\u001b[0m\n\u001b[1;32m 9\u001b[0m cfg \u001b[38;5;241m=\u001b[39m LanguageModelSAERunnerConfig(\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# Data Generating Function (Model + Training Distribution)\u001b[39;00m\n\u001b[1;32m 11\u001b[0m variant\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbaseline\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;66;03m# we'll use the gated variant.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 60\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 61\u001b[0m )\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# look at the next cell to see some instruction for what to do while this is running.\u001b[39;00m\n\u001b[0;32m---> 63\u001b[0m sparse_autoencoder \u001b[38;5;241m=\u001b[39m \u001b[43mSAETrainingRunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:87\u001b[0m, in \u001b[0;36mSAETrainingRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 78\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SAETrainer(\n\u001b[1;32m 79\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 80\u001b[0m sae\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msae,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 83\u001b[0m cfg\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg,\n\u001b[1;32m 84\u001b[0m )\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compile_if_needed()\n\u001b[0;32m---> 87\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer_with_interruption_handling\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 90\u001b[0m wandb\u001b[38;5;241m.\u001b[39mfinish()\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:130\u001b[0m, 
in \u001b[0;36mSAETrainingRunner.run_trainer_with_interruption_handling\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 127\u001b[0m signal\u001b[38;5;241m.\u001b[39msignal(signal\u001b[38;5;241m.\u001b[39mSIGTERM, interrupt_callback)\n\u001b[1;32m 129\u001b[0m \u001b[38;5;66;03m# train SAE\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, InterruptedException):\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minterrupted, saving progress\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/training/sae_trainer.py:162\u001b[0m, in \u001b[0;36mSAETrainer.fit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 159\u001b[0m layer_acts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mnext_batch()[:, \u001b[38;5;241m0\u001b[39m, :]\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_training_tokens \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtrain_batch_size_tokens\n\u001b[0;32m--> 162\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_train_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43msae\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msae\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msae_in\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_acts\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_log_train_step(step_output)\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/training/sae_trainer.py:216\u001b[0m, in \u001b[0;36mSAETrainer._train_step\u001b[0;34m(self, sae, sae_in)\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[38;5;66;03m# for documentation on autocasting see:\u001b[39;00m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;66;03m# https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html\u001b[39;00m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mautocast_if_enabled:\n\u001b[0;32m--> 216\u001b[0m train_step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msae\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_forward_pass\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[43m \u001b[49m\u001b[43msae_in\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msae_in\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 218\u001b[0m \u001b[43m \u001b[49m\u001b[43mdead_neuron_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdead_neurons\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 219\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mcurrent_l1_coefficient\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_l1_coefficient\u001b[43m,\u001b[49m\n\u001b[1;32m    220\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    222\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m    223\u001b[0m     did_fire \u001b[38;5;241m=\u001b[39m (train_step_output\u001b[38;5;241m.\u001b[39mfeature_acts \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/training/training_sae.py:303\u001b[0m, in \u001b[0;36mTrainingSAE.training_forward_pass\u001b[0;34m(self, sae_in, current_l1_coefficient, dead_neuron_mask)\u001b[0m\n\u001b[1;32m    295\u001b[0m l1_loss \u001b[38;5;241m=\u001b[39m (current_l1_coefficient \u001b[38;5;241m*\u001b[39m sparsity)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m    296\u001b[0m loss \u001b[38;5;241m=\u001b[39m mse_loss \u001b[38;5;241m+\u001b[39m l1_loss \u001b[38;5;241m+\u001b[39m ghost_grad_loss\n\u001b[1;32m    298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m TrainStepOutput(\n\u001b[1;32m    299\u001b[0m     sae_in\u001b[38;5;241m=\u001b[39msae_in,\n\u001b[1;32m    300\u001b[0m     sae_out\u001b[38;5;241m=\u001b[39msae_out,\n\u001b[1;32m    301\u001b[0m     feature_acts\u001b[38;5;241m=\u001b[39mfeature_acts,\n\u001b[1;32m    302\u001b[0m     loss\u001b[38;5;241m=\u001b[39mloss,\n\u001b[0;32m--> 303\u001b[0m     mse_loss\u001b[38;5;241m=\u001b[39m\u001b[43mmse_loss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m    304\u001b[0m     l1_loss\u001b[38;5;241m=\u001b[39ml1_loss\u001b[38;5;241m.\u001b[39mitem(),\n\u001b[1;32m    305\u001b[0m     ghost_grad_loss\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m    306\u001b[0m         ghost_grad_loss\u001b[38;5;241m.\u001b[39mitem()\n\u001b[1;32m    307\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(ghost_grad_loss, torch\u001b[38;5;241m.\u001b[39mTensor)\n\u001b[1;32m    308\u001b[0m         \u001b[38;5;28;01melse\u001b[39;00m ghost_grad_loss\n\u001b[1;32m    309\u001b[0m     ),\n\u001b[1;32m    310\u001b[0m )\n", + "File \u001b[0;32m~/projects/SAELens/sae_lens/sae_training_runner.py:25\u001b[0m, in \u001b[0;36minterrupt_callback\u001b[0;34m(sig_num, stack_frame)\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minterrupt_callback\u001b[39m(sig_num: Any, stack_frame: Any):\n\u001b[0;32m---> 25\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m InterruptedException()\n", + "\u001b[0;31mInterruptedException\u001b[0m: " + ] + } + ], + "source": [ + "total_training_steps = 30_000 # probably we should do more\n", + "batch_size = 4096\n", + "total_training_tokens = total_training_steps * batch_size\n", + "\n", + "lr_warm_up_steps = 0\n", + "lr_decay_steps = total_training_steps // 5 # 20% of training\n", + "l1_warm_up_steps = total_training_steps // 20 # 5% of training\n", + "\n", + "cfg = LanguageModelSAERunnerConfig(\n", + "    # Data Generating Function (Model + Training Distribution)\n", + "    architecture=\"standard\", # the standard (baseline) architecture; we'll train the gated variant below.\n", + "    model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + "    
hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + "    hook_layer=0, # Only one layer in the model.\n", + "    d_in=1024, # the width of the mlp output.\n", + "    dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", + "    is_dataset_tokenized=True,\n", + "    streaming=True, # we could pre-download the token dataset if it was small.\n", + "    # SAE Parameters\n", + "    mse_loss_normalization=None, # We won't normalize the MSE loss.\n", + "    expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + "    b_dec_init_method=\"zeros\", # we initialize the decoder bias to zeros; the geometric median of the activations is another option.\n", + "    apply_b_dec_to_input=False, # We won't subtract the decoder bias from the input.\n", + "    normalize_sae_decoder=False,\n", + "    scale_sparsity_penalty_by_decoder_norm=True,\n", + "    decoder_heuristic_init=True,\n", + "    init_encoder_as_decoder_transpose=True,\n", + "    normalize_activations=True,\n", + "    # Training Parameters\n", + "    lr=5e-5, # lower is better; we'll go fairly high to speed up the tutorial.\n", + "    adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", + "    adam_beta2=0.999,\n", + "    lr_scheduler_name=\"constant\", # constant learning rate with warm-up; there could be better schedules out there.\n", + "    lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", + "    lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", + "    l1_coefficient=5, # will control how sparse the feature activations are\n", + "    l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", + "    lp_norm=1.0, # the L1 penalty (and not an Lp norm for p < 1)\n", + "    train_batch_size_tokens=batch_size,\n", + "    context_size=256, # will control the length of the prompts we feed to the model. Larger is better but slower, so for the tutorial we'll use a short one.\n", + "    # Activation Store Parameters\n", + "    n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + "    training_tokens=total_training_tokens, # ~123 million tokens is quite a few, but we want to see good stats. 
Get a coffee, come back.\n", + "    store_batch_size_prompts=16,\n", + "    # Resampling protocol\n", + "    use_ghost_grads=False, # we don't use ghost grads anymore.\n", + "    feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + "    dead_feature_window=1000, # would affect resampling or ghost grads if we were using them.\n", + "    dead_feature_threshold=1e-4, # would affect resampling or ghost grads if we were using them.\n", + "    # WANDB\n", + "    log_to_wandb=True, # always use wandb unless you are just testing code.\n", + "    wandb_project=\"sae_lens_tutorial\",\n", + "    wandb_log_frequency=30,\n", + "    eval_every_n_wandb_logs=20,\n", + "    # Misc\n", + "    device=device,\n", + "    seed=42,\n", + "    n_checkpoints=0,\n", + "    checkpoint_path=\"checkpoints\",\n", + "    dtype=\"float32\"\n", + ")\n", + "# look at the next cell to see some instructions for what to do while this is running.\n", + "sparse_autoencoder = SAETrainingRunner(cfg).run()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-20-LR-5e-05-Tokens-1.229e+08\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", + "Total training steps: 30000\n", + "Total wandb updates: 1000\n", + "n_tokens_per_feature_sampling_window (millions): 1048.576\n", + "n_tokens_per_dead_feature_window (millions): 1048.576\n", + "We will reset the sparsity calculation 30 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + "  warnings.warn(\n", + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + "  return self.fget.__get__(instance, owner)()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: W&B API key is configured. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.1" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240616_143959-ch6e0a5s" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 16384-L1-20-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
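A quick aside on the InterruptedException traceback shown for the first run above: SAETrainingRunner registers signal handlers so that an interrupt (Ctrl-C or SIGTERM) is converted into an exception, caught around trainer.fit(), and used to save progress before exiting; that is where the "interrupted, saving progress" / "done saving" lines come from. A minimal sketch of the same pattern follows; `trainer.save_checkpoint()` is a hypothetical stand-in for the runner's real checkpointing call:

```python
import signal
from typing import Any


class InterruptedException(Exception):
    pass


def interrupt_callback(sig_num: Any, stack_frame: Any):
    # convert SIGINT/SIGTERM into an exception we can catch in the train loop
    raise InterruptedException()


def run_with_interruption_handling(trainer):
    try:
        signal.signal(signal.SIGINT, interrupt_callback)
        signal.signal(signal.SIGTERM, interrupt_callback)
        sae = trainer.fit()
    except (KeyboardInterrupt, InterruptedException):
        print("interrupted, saving progress")
        trainer.save_checkpoint()  # hypothetical; the runner calls its own checkpointing here
        print("done saving")
        raise
    return sae
```

This is why the interrupted baseline run above still saved its progress before the exception surfaced.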
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/curt-tigges/gated_sae_testing" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/curt-tigges/gated_sae_testing/runs/ch6e0a5s" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "30000| MSE Loss 143.062 | L1 0.000: 1%| | 1228800/122880000 [1:04:38<106:39:53, 316.81it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_l1_coefficient▁▂▂▃▃▄▄▅▅▆▆▇████████████████████████████
details/current_learning_rate████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/auxiliary_reconstruction_loss▁▃▃▄▄▅▅▅▆▆▆▇▇███████████████████████████
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss▁▂▃▄▄▅▅▅▆▆▆▆▆▆▇▆▆▆▆▆▆▇▆▆▇▆▆▆▇▆▇▆▆▆▆▇▆▇▆█
losses/overall_loss▁▃▄▄▅▆▆▆▇▇▇▇██████████████▇████▇▇█▇█▇█▇█
losses/sfn_sparsity_loss▂▃▅▆▆▇████▇▆▅▄▃▃▃▃▃▃▃▃▄▃▃▃▂▃▃▃▃▃▃▂▂▂▁▁▁▁
metrics/CE_loss_score██▇▇▆▆▆▆▆▆▆▆▆▅▆▆▅▆▅▃▅▁▆▅▆▆▅▅▃▆▅▅▆▅▆▆▅▅▃▄
metrics/ce_loss_with_ablation▅▃▁▆▄▃▅▆▅▄▄▁▅▃▄▃▃▄▁▃▄▆▄▄▄▃▆▄▃█▄▄▁▄▄▅█▄▄▃
metrics/ce_loss_with_sae▁▁▂▂▃▂▃▃▃▃▃▃▃▃▃▃▄▃▄▆▄█▃▄▃▃▄▄▆▃▃▄▃▄▃▃▄▄▆▅
metrics/ce_loss_without_sae▄▂▃█▃▁▃▃▄▅▃▄▂▃▄▃▄▃▃▃▃▃▄▁▂▂▆▃▄▅▃▄▃▅▇▅▃▃▃▂
metrics/explained_variance█▇▆▅▅▄▄▄▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▃▃▃▃▃▂▃▂▃▁
metrics/explained_variance_std▁▂▃▃▃▃▃▃▃▃▃▃▃▃▄▃▃▄▄▄▃▆▄▄▄▄▄▄▅▄▄▄▄▄▄▅▄▅▄█
metrics/l0█▅▅▂▃▅▆▄▅█▃▅▃▅▇▄▆▄▆▄▄▁▄▅▅▄▁▃▇▄▄▅▃▆▃▄▄▄▃▁
metrics/l2_norm▇▆▇▁▄▃▆▃▃▂▂▄▂▂▃▂▄▃▂▃▄▅▅▃▃▃▄▄▃▃▂▅▃▃▃▃▃▆▆█
metrics/l2_norm_in▃▃▇▂▄▂▅▅▅▅▄▆▃▃▁▃▄▅▃▂▅▂▅▄▃▄▃▅▂▅▄▃▃▄▁▄▂█▃▅
metrics/l2_ratio█▆▆▁▃▃▅▂▃▁▂▃▂▂▃▂▄▂▂▄▄▅▅▃▃▂▅▃▃▃▂▅▃▂▃▃▃▅▆█
metrics/mean_log10_feature_sparsity█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▅▄▄▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▂▃▆██
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▂▃▆██
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▃▃▄▇██
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▄▅▇█

Run summary:


details/current_l1_coefficient20
details/current_learning_rate0.0
details/n_training_tokens122880000
losses/auxiliary_reconstruction_loss227.78122
losses/ghost_grad_loss0.0
losses/l1_loss0.0
losses/mse_loss143.06226
losses/overall_loss434.46942
losses/sfn_sparsity_loss63.62593
metrics/CE_loss_score0.59248
metrics/ce_loss_with_ablation8.29373
metrics/ce_loss_with_sae4.50411
metrics/ce_loss_without_sae1.8969
metrics/explained_variance0.15973
metrics/explained_variance_std0.24142
metrics/l07705.52734
metrics/l2_norm14.99578
metrics/l2_norm_in17.58649
metrics/l2_ratio0.8463
metrics/mean_log10_feature_sparsity-0.74933
sparsity/below_1e-5681
sparsity/below_1e-6681
sparsity/dead_features681
sparsity/mean_passes_since_fired138.74988

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run 16384-L1-20-LR-5e-05-Tokens-1.229e+08 at: https://wandb.ai/curt-tigges/gated_sae_testing/runs/ch6e0a5s
View project at: https://wandb.ai/curt-tigges/gated_sae_testing
Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240616_143959-ch6e0a5s/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "total_training_steps = 30_000 # probably we should do more\n", + "batch_size = 4096\n", + "total_training_tokens = total_training_steps * batch_size\n", + "\n", + "lr_warm_up_steps = 0\n", + "lr_decay_steps = total_training_steps // 5 # 20% of training\n", + "l1_warm_up_steps = 10_000 #total_training_steps // 20 # 5% of training\n", + "\n", + "cfg = LanguageModelSAERunnerConfig(\n", + " # Data Generating Function (Model + Training Distribution)\n", + " architecture=\"gated\", # we'll use the gated variant.\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_name=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_layer=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", + " is_dataset_tokenized=True,\n", + " streaming=True, # we could pre-download the token dataset if it was small.\n", + " # SAE Parameters\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", + " apply_b_dec_to_input=True, # We won't apply the decoder weights to the input.\n", + " normalize_sae_decoder=False,\n", + " scale_sparsity_penalty_by_decoder_norm=False,\n", + " decoder_heuristic_init=True,\n", + " init_encoder_as_decoder_transpose=True,\n", + " normalize_activations=False,\n", + " # Training Parameters\n", + " lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", + " adam_beta2=0.999,\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", + " lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", + " l1_coefficient=20, # will control how sparse the feature activations are\n", + " l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", + " train_batch_size_tokens=batch_size,\n", + " context_size=256, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", + " # Activation Store Parameters\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. 
Get a coffee, come back.\n", + " store_batch_size_prompts=16,\n", + " # Resampling protocol\n", + " use_ghost_grads=False, # we don't use ghost grads anymore.\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", + " # WANDB\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"gated_sae_testing\",\n", + " wandb_log_frequency=30,\n", + " eval_every_n_wandb_logs=20,\n", + " # Misc\n", + " device=device,\n", + " seed=42,\n", + " n_checkpoints=0,\n", + " checkpoint_path=\"checkpoints\",\n", + " dtype=\"float32\"\n", + ")\n", + "# look at the next cell to see some instruction for what to do while this is running.\n", + "sparse_autoencoder = SAETrainingRunner(cfg).run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tutorials/training_a_sparse_autoencoder.ipynb b/tutorials/training_a_sparse_autoencoder.ipynb index 006b880c..31a25850 100644 --- a/tutorials/training_a_sparse_autoencoder.ipynb +++ b/tutorials/training_a_sparse_autoencoder.ipynb @@ -22,15 +22,162 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "id": "LeRi_tw2dhae" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sae-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (3.3.0)\n", + "Requirement already satisfied: transformer-lens in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.19.0)\n", + "Requirement already satisfied: circuitsvis in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (1.43.2)\n", + "Requirement already satisfied: automated-interpretability<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.3)\n", + "Requirement already satisfied: babe<0.0.8,>=0.0.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.0.7)\n", + "Requirement already satisfied: datasets<3.0.0,>=2.17.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (2.19.2)\n", + "Requirement already satisfied: matplotlib<4.0.0,>=3.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.9.0)\n", + "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.6 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.1.7)\n", + "Requirement already satisfied: nltk<4.0.0,>=3.8.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (3.8.1)\n", + "Requirement already satisfied: plotly<6.0.0,>=5.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) 
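To make the bookkeeping in the two configs above explicit: the "16384" and "Tokens-1.229e+08" parts of the run names follow from d_in * expansion_factor and total_training_steps * batch_size, and the details/current_l1_coefficient chart ramps smoothly from 0 up to l1_coefficient over the warm-up window. The ramp looks linear, so the schedule below assumes a linear warm-up; that shape is an inference from the chart, not a quote of the library internals:

```python
d_in = 1024
expansion_factor = 16
d_sae = d_in * expansion_factor
print(d_sae)  # 16384 -> the "16384-..." prefix in the run names

total_training_steps = 30_000
batch_size = 4096
total_training_tokens = total_training_steps * batch_size
print(f"{total_training_tokens:.3e}")  # 1.229e+08 -> the "Tokens-1.229e+08" suffix


# assumed linear L1 warm-up, matching the ramp in details/current_l1_coefficient
def current_l1(step: int, l1_coefficient: float, l1_warm_up_steps: int) -> float:
    if l1_warm_up_steps == 0:
        return l1_coefficient
    return l1_coefficient * min(1.0, step / l1_warm_up_steps)


print(current_l1(5_000, 20, 10_000))   # 10.0, halfway through the gated run's warm-up
print(current_l1(15_000, 20, 10_000))  # 20.0, fully warmed up
```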
(5.22.0)\n", + "Requirement already satisfied: plotly-express<0.5.0,>=0.4.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.1)\n", + "Requirement already satisfied: pytest-profiling<2.0.0,>=1.7.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.7.0)\n", + "Requirement already satisfied: python-dotenv<2.0.0,>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (1.0.1)\n", + "Requirement already satisfied: pyyaml<7.0.0,>=6.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (6.0.1)\n", + "Requirement already satisfied: pyzmq==26.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (26.0.0)\n", + "Requirement already satisfied: sae-vis<0.3.0,>=0.2.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.2.18)\n", + "Requirement already satisfied: safetensors<0.5.0,>=0.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.4.3)\n", + "Requirement already satisfied: transformers<5.0.0,>=4.38.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (4.41.2)\n", + "Requirement already satisfied: typer<0.13.0,>=0.12.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-lens) (0.12.3)\n", + "Requirement already satisfied: accelerate>=0.23.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.31.0)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.14.1)\n", + "Requirement already satisfied: better-abc<0.0.4,>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: einops>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.7.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.0.3)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.29)\n", + "Requirement already satisfied: numpy>=1.24 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (1.26.4)\n", + "Requirement already satisfied: pandas>=1.1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.2.2)\n", + "Requirement already satisfied: rich>=12.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (13.7.1)\n", + "Requirement already satisfied: sentencepiece in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.2.0)\n", + "Requirement already satisfied: torch>=1.10 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (2.1.2)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.66.4)\n", + "Requirement already satisfied: typing-extensions in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (4.12.2)\n", + "Requirement 
already satisfied: wandb>=0.13.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformer-lens) (0.17.1)\n", + "Requirement already satisfied: importlib-metadata>=5.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (7.1.0)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from circuitsvis) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->circuitsvis) (12.5.40)\n", + "Requirement already satisfied: filelock in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from triton==2.1.0->circuitsvis) (3.14.0)\n", + "Requirement already satisfied: packaging>=20.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (24.0)\n", + "Requirement already satisfied: psutil in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (5.9.0)\n", + "Requirement already satisfied: huggingface-hub in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from accelerate>=0.23.0->transformer-lens) (0.23.3)\n", + "Requirement already satisfied: blobfile<3.0.0,>=2.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.1.1)\n", + "Requirement already satisfied: boostedblob<0.16.0,>=0.15.3 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.15.3)\n", + "Requirement already satisfied: httpx<0.28.0,>=0.27.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.27.0)\n", + "Requirement already satisfied: orjson<4.0.0,>=3.10.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.10.4)\n", + "Requirement already satisfied: pytest<9.0.0,>=8.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (8.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0.0,>=1.4.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: tiktoken<0.7.0,>=0.6.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.6.0)\n", + "Requirement already satisfied: py2store in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.20)\n", + "Requirement already satisfied: graze in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (16.1.0)\n", + "Requirement already satisfied: pyarrow-hotfix in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.3.8)\n", + "Requirement already satisfied: requests>=2.32.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (2.32.3)\n", + "Requirement already satisfied: xxhash in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets<3.0.0,>=2.17.1->sae-lens) (2024.3.1)\n", + "Requirement already satisfied: aiohttp in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from datasets<3.0.0,>=2.17.1->sae-lens) (3.9.5)\n", + "Requirement already satisfied: zipp>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from importlib-metadata>=5.1.0->circuitsvis) (3.19.2)\n", + "Requirement already satisfied: typeguard==2.13.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jaxtyping>=0.2.11->transformer-lens) (2.13.3)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (4.53.0)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (1.4.5)\n", + "Requirement already satisfied: pillow>=8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (10.3.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (3.1.2)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib<4.0.0,>=3.8.3->sae-lens) (2.9.0)\n", + "Requirement already satisfied: traitlets in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from matplotlib-inline<0.2.0,>=0.1.6->sae-lens) (5.14.3)\n", + "Requirement already satisfied: click in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (8.1.7)\n", + "Requirement already satisfied: joblib in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (1.4.2)\n", + "Requirement already satisfied: regex>=2021.8.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from nltk<4.0.0,>=3.8.1->sae-lens) (2024.5.15)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pandas>=1.1.5->transformer-lens) (2024.1)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly<6.0.0,>=5.19.0->sae-lens) (8.3.0)\n", + "Requirement already satisfied: statsmodels>=0.9.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.14.2)\n", + "Requirement already satisfied: scipy>=0.18 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (1.13.1)\n", + "Requirement already satisfied: patsy>=0.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from plotly-express<0.5.0,>=0.4.1->sae-lens) (0.5.6)\n", + "Requirement already satisfied: six in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (1.16.0)\n", + "Requirement already satisfied: gprof2dot in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest-profiling<2.0.0,>=1.7.0->sae-lens) (2024.6.6)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from rich>=12.6.0->transformer-lens) (2.18.0)\n", + "Requirement already satisfied: dataclasses-json<0.7.0,>=0.6.4 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.6.7)\n", + "Requirement already satisfied: eindex-callum<0.2.0,>=0.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sae-vis<0.3.0,>=0.2.18->sae-lens) (0.1.1)\n", + "Requirement already satisfied: sympy in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (1.12.1)\n", + "Requirement already satisfied: networkx in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.3)\n", + "Requirement already satisfied: jinja2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from torch>=1.10->transformer-lens) (3.1.4)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from transformers<5.0.0,>=4.38.1->sae-lens) (0.19.1)\n", + "Requirement already satisfied: shellingham>=1.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typer<0.13.0,>=0.12.3->sae-lens) (1.5.4)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (0.4.0)\n", + "Requirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (3.1.43)\n", + "Requirement already satisfied: platformdirs in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (4.2.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,<6,>=3.19.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (5.27.1)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (2.5.1)\n", + "Requirement already satisfied: setproctitle in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (1.3.3)\n", + "Requirement already satisfied: setuptools in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from wandb>=0.13.5->transformer-lens) (70.0.0)\n", + "Requirement already satisfied: pycryptodomex~=3.8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.20.0)\n", + "Requirement already satisfied: urllib3<3,>=1.25.3 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.2.1)\n", + "Requirement already satisfied: lxml~=4.9 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from blobfile<3.0.0,>=2.1.1->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.9.4)\n", + "Requirement already satisfied: uvloop>=0.16.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from boostedblob<0.16.0,>=0.15.3->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.19.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in 
/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from aiohttp->datasets<3.0.0,>=2.17.1->sae-lens) (4.0.3)\n", + "Requirement already satisfied: marshmallow<4.0.0,>=3.18.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (3.21.3)\n", + "Requirement already satisfied: typing-inspect<1,>=0.4.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (0.9.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (4.0.11)\n", + "Requirement already satisfied: anyio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (4.4.0)\n", + "Requirement already satisfied: certifi in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2024.6.2)\n", + "Requirement already satisfied: httpcore==1.* in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.0.5)\n", + "Requirement already satisfied: idna in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.7)\n", + "Requirement already satisfied: sniffio in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.3.1)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.0->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (0.14.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=12.6.0->transformer-lens) (0.1.2)\n", + "Requirement already satisfied: iniconfig in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.0)\n", + "Requirement already satisfied: pluggy<2.0,>=1.5 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.5.0)\n", + "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from 
pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (1.2.0)\n", + "Requirement already satisfied: tomli>=1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from pytest<9.0.0,>=8.1.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (2.0.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from requests>=2.32.1->datasets<3.0.0,>=2.17.1->sae-lens) (3.3.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from scikit-learn<2.0.0,>=1.4.2->automated-interpretability<0.0.4,>=0.0.3->sae-lens) (3.5.0)\n", + "Requirement already satisfied: dol in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from graze->babe<0.0.8,>=0.0.7->sae-lens) (0.2.47)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from jinja2->torch>=1.10->transformer-lens) (2.1.5)\n", + "Requirement already satisfied: config2py in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.33)\n", + "Requirement already satisfied: importlib-resources in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from py2store->babe<0.0.8,>=0.0.7->sae-lens) (6.4.0)\n", + "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from sympy->torch>=1.10->transformer-lens) (1.3.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens) (5.0.1)\n", + "Requirement already satisfied: mypy-extensions>=0.3.0 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7.0,>=0.6.4->sae-vis<0.3.0,>=0.2.18->sae-lens) (1.0.0)\n", + "Requirement already satisfied: i2 in /home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages (from config2py->py2store->babe<0.0.8,>=0.0.7->sae-lens) (0.1.17)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "try:\n", - " import google.colab # type: ignore\n", - " from google.colab import output\n", + " #import google.colab # type: ignore\n", + " #from google.colab import output\n", " %pip install sae-lens transformer-lens circuitsvis\n", "except:\n", " from IPython import get_ipython # type: ignore\n", @@ -41,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -49,7 +196,23 @@ "id": "uy-b3CcSOVHu", "outputId": "58ce28d0-f91f-436d-cf87-76bb26e2ecaf" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "  from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n" + ] + } + ], "source": [ "import torch\n", "import os\n", @@ -85,11 +248,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "hFz6JUMuOVHv" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + "  warnings.warn(\n", + "/home/curttigges/miniconda3/envs/saelens/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + "  return self.fget.__get__(instance, owner)()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + } + ], "source": [ "from transformer_lens import HookedTransformer\n", "\n", @@ -118,11 +299,57 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "id": "G4ad4Zz1OVHv" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'Once upon a time, Bobby was hungry and needed something to do. He went to the subway but was far away.\\n\\nThe man wanted to get the hat, so the people wanted it. He found it hard to be a big, powerful bird. It wanted'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time, there was a trunk. The trunk was very rich, and it was a very special trunk. All the animals came across the trunk, and was very colorful. They took turns to fill it up when they bumped into a dragon. \\n\\n'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time, there was a young man. He was three years old. He wanted to learn how to keep the match safe. So he kept checking it every day.\\n\\nOne day a 3 year old girl wanted to learn about the fun trunk of the'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time, there was a little girl named Sally. She liked to play with her toys. One sunny day, Sally found a butterfly. She was so happy! She wanted to play with something new. So, she called her friend Tom.\\n\\nTom'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'Once upon a time, there was a little girl named Lola. She really loved playing with her pet cat, Tom and show them his appreciation.\\n\\nOne day, Lola licked the couch closer and soon found herself in a magical land! As soon as'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "# here we use generate to get 5 completions with temperature 1. 
Feel free to play with the prompt to make it more interesting.\n", "for i in range(5):\n", @@ -157,11 +384,64 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { "id": "TpmPoj7uOVHv" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a', ' time', ',', ' there', ' was', ' a', ' little', ' girl', ' named', ' Lily', '.', ' She', ' lived', ' in', ' a', ' big', ',', ' happy', ' little', ' girl', '.', ' On', ' her', ' big', ' adventure', ',']\n", + "Tokenized answer: [' Lily']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+              "Rank: 1        Logit: 18.81 Prob: 13.46% Token: | Lily|\n",
+              "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.81\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m13.46\u001b[0m\u001b[1m% Token: | Lily|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 20.48 Prob: 71.06% Token: | she|\n", + "Top 1th token. Logit: 18.81 Prob: 13.46% Token: | Lily|\n", + "Top 2th token. Logit: 17.35 Prob: 3.11% Token: | the|\n", + "Top 3th token. Logit: 17.26 Prob: 2.86% Token: | her|\n", + "Top 4th token. Logit: 16.74 Prob: 1.70% Token: | there|\n", + "Top 5th token. Logit: 16.43 Prob: 1.25% Token: | they|\n", + "Top 6th token. Logit: 15.80 Prob: 0.66% Token: | all|\n", + "Top 7th token. Logit: 15.64 Prob: 0.56% Token: | things|\n", + "Top 8th token. Logit: 15.28 Prob: 0.39% Token: | one|\n", + "Top 9th token. Logit: 15.24 Prob: 0.38% Token: | lived|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' Lily', 1)]\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Lily'\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from transformer_lens.utils import test_prompt\n", "\n", @@ -203,11 +483,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { "id": "Tic0RCUpOVHw" }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import circuitsvis as cv # optional dep, install with pip install circuitsvis\n", "\n", @@ -238,11 +540,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "id": "Nikp2ASlOVHw" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 200/200 [00:01<00:00, 103.20it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "example_prompt = model.generate(\n", " \"Once upon a time\",\n", @@ -295,11 +626,146 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "id": "oAsZCAdJOVHw" }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-5-LR-5e-05-Tokens-1.229e+08\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.001024\n", + "Total training steps: 30000\n", + "Total wandb updates: 1000\n", + "n_tokens_per_feature_sampling_window (millions): 1048.576\n", + "n_tokens_per_dead_feature_window (millions): 1048.576\n", + "We will reset the sparsity calculation 30 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading readme: 100%|██████████| 415/415 [00:00<00:00, 4.30MB/s]\n", + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mcurt-tigges\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.1" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/curttigges/projects/SAELens/tutorials/wandb/run-20240610_114538-yr1gvjdc" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/curt-tigges/sae_lens_tutorial" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/yr1gvjdc" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Estimating norm scaling factor: 100%|██████████| 1000/1000 [00:33<00:00, 29.79it/s]\n", + "30000| MSE Loss 187.703 | L1 156.883: 1%| | 1228800/122880000 [42:59<70:56:28, 476.34it/s]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:

details/current_l1_coefficient        ▁▅██████████████████████████████████████
details/current_learning_rate         ████████████████████████████████▇▇▅▅▄▃▂▁
details/n_training_tokens             ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss                ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss                        █▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss                       ▄▅█▇▆▆▅▄▄▄▄▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁
losses/overall_loss                   ▁▅█▇▇▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
metrics/CE_loss_score                 █▂▁▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▅▆▅▅▅▅▆▅▆▆▆▆▆▆▆▅
metrics/ce_loss_with_ablation         █▆▇▃▄▄▃▆▆▅▃▅▅▅▅▃▄▅▅▆▄▅▃▅▃▄▅▁▅▅▆▅▆▄▆▄▆▄▅▄
metrics/ce_loss_with_sae              ▂▇█▆▅▅▅▅▄▄▅▆▄▄▃▄▄▄▁▃▃▃▁▄▄▂▃▂▄▂▁▃▄▃▄▂▄▂▃▄
metrics/ce_loss_without_sae           ▇▇▇▅▄▄▆▆▅▄▇█▆▆▄▅▆▇▁▄▄▅▁▇▆▄▅▄▆▄▂▅▇▆▇▄▇▄▅█
metrics/explained_variance            ▅▄▁▂▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇█▇▇▇███████
metrics/explained_variance_std        ▁▅███▇▇▇▇▇▇▆▆▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▆▆▆▅▆
metrics/l0                            █▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm                       ▆▅▁▃▃▂▂▃▁▅▃▃▅▆▂█▄▃▄▆▇▃▆▃▃▁▂▅▅▆▇▅▄▅▃██▃▃▃
metrics/l2_norm_in                    ▅▆▅▆▆▄▅▅▃▆▄▄▆▇▄█▄▃▅▆▇▃▆▃▃▁▂▅▅▆▇▄▅▆▃▇█▄▄▃
metrics/l2_ratio                      ▇▅▁▃▃▃▃▄▃▆▅▅▆▆▄▇▆▆▅▆▇▅▇▆▅▅▅▆▆▇▇▇▆▆▆██▆▆▆
metrics/mean_log10_feature_sparsity   █▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/below_1e-5                   ▁▁▁▅▆███▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅
sparsity/below_1e-6                   ▁▁▁▁▁▃▆▆▆█▆█████▆▆▆▆▆▆▆▆▆▆▆▆▆▆
sparsity/dead_features                ▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▁▁▅█▁▁▁▅▅█▅█▅▅████▅▅▅████
sparsity/mean_passes_since_fired      ▂▁▁▁▁▁▁▁▂▂▁▁▂▂▂▂▂▂▄▂▂▂▃▃▄▄▄▃▃▃▄▅▆▄▄▅▆▇▇█

Run summary:

details/current_l1_coefficient        5
details/current_learning_rate         0.0
details/n_training_tokens             122880000
losses/ghost_grad_loss                0.0
losses/l1_loss                        31.37665
losses/mse_loss                       187.70346
losses/overall_loss                   344.5867
metrics/CE_loss_score                 0.90369
metrics/ce_loss_with_ablation         8.30545
metrics/ce_loss_with_sae              2.62867
metrics/ce_loss_without_sae           2.02306
metrics/explained_variance            0.66377
metrics/explained_variance_std        0.13242
metrics/l0                            192.95703
metrics/l2_norm                       24.64933
metrics/l2_norm_in                    31.38967
metrics/l2_ratio                      0.77361
metrics/mean_log10_feature_sparsity   -2.66941
sparsity/below_1e-5                   2
sparsity/below_1e-6                   2
sparsity/dead_features                2
sparsity/mean_passes_since_fired      0.86823

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run 16384-L1-5-LR-5e-05-Tokens-1.229e+08 at: https://wandb.ai/curt-tigges/sae_lens_tutorial/runs/yr1gvjdc
View project at: https://wandb.ai/curt-tigges/sae_lens_tutorial
Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240610_114538-yr1gvjdc/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "total_training_steps = 30_000 # probably we should do more\n", "batch_size = 4096\n", @@ -425,7 +891,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.14" } }, "nbformat": 4,
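
Note for reviewers: the training cell this hunk touches reduces to the launch below. This is a minimal sketch, not the notebook's exact cell: the config and runner names follow SAELens at the time of this PR, but the hook point, dataset path, `d_in`, and normalization setting are assumptions inferred from the run name 16384-L1-5-LR-5e-05-Tokens-1.229e+08, the tiny-stories-1L-21M model, and the "Estimating norm scaling factor" log line, since they are not visible in this diff.

    # Minimal sketch of the training launch behind the wandb run above.
    # Hyperparameters are read off the run name 16384-L1-5-LR-5e-05-Tokens-1.229e+08;
    # hook_name, dataset_path, d_in, and normalize_activations are assumptions.
    from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

    total_training_steps = 30_000
    batch_size = 4096
    total_training_tokens = total_training_steps * batch_size  # 1.2288e8, matching "Tokens-1.229e+08"

    cfg = LanguageModelSAERunnerConfig(
        architecture="standard",            # new config field; "gated" selects the gated encoder
        model_name="tiny-stories-1L-21M",
        hook_name="blocks.0.hook_mlp_out",  # assumption: a d_model-wide activation in this 1-layer model
        hook_layer=0,
        d_in=1024,                          # assumption: must equal the width of the hooked activation
        expansion_factor=16,                # 16 * 1024 = 16384 latents, matching the run name
        dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # assumption
        l1_coefficient=5.0,
        lr=5e-5,
        train_batch_size_tokens=batch_size,
        training_tokens=total_training_tokens,
        normalize_activations="expected_average_only_in",  # assumption: would explain the norm-estimation log line
        log_to_wandb=True,
        wandb_project="sae_lens_tutorial",
        device="cuda",
    )

    sae = SAETrainingRunner(cfg).run()

The run summary above (L0 ≈ 193, CE loss score ≈ 0.90, 2 dead features) is what this configuration produces with the standard architecture; switching `architecture` to "gated" is the one-line change this PR makes possible.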