From 0a57e473b63f79043d69172592ec3ee00b852339 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 18 Sep 2024 09:13:51 +0100 Subject: [PATCH 1/5] support seqpos slicing --- sae_lens/config.py | 3 +++ sae_lens/training/activations_store.py | 27 +++++++++++++------------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index a74b7565..94bf1517 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig: store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations. train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop. normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output). + seqpos_slice (tuple): Determines slicing of (batch, seq, d_in) activations when constructing batches, during training. Example: for Othello we sometimes use (5, -5). device (str): The device to use. Usually cuda. act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram. seed (int): The seed to use. @@ -151,6 +152,7 @@ class LanguageModelSAERunnerConfig: normalize_activations: str = ( "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update) ) + seqpos_slice: tuple[int | None, ...] = (None,) # Misc device: str = "cpu" @@ -453,6 +455,7 @@ class CacheActivationsRunnerConfig: store_batch_size_prompts: int = 32 train_batch_size_tokens: int = 4096 normalize_activations: str = "none" # should always be none for activation caching + seqpos_slice: tuple[int | None, ...] = (None,) # Misc device: str = "cpu" diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 3a068cd2..d68aadf3 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -87,6 +87,7 @@ def from_config( model_kwargs=cfg.model_kwargs, autocast_lm=cfg.autocast_lm, dataset_trust_remote_code=cfg.dataset_trust_remote_code, + seqpos_slice=cfg.seqpos_slice, ) @classmethod @@ -146,6 +147,7 @@ def __init__( model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, dataset_trust_remote_code: bool | None = None, + seqpos_slice: tuple[int | None, ...] = (None,) ): self.model = model if model_kwargs is None: @@ -187,6 +189,7 @@ def __init__( self.dtype = DTYPE_MAP[dtype] self.cached_activations_path = cached_activations_path self.autocast_lm = autocast_lm + self.seqpos_slice = seqpos_slice self.n_dataset_processed = 0 @@ -428,7 +431,7 @@ def get_activations(self, batch_tokens: torch.Tensor): autocast_if_enabled = contextlib.nullcontext() with autocast_if_enabled: - layerwise_activations = self.model.run_with_cache( + layerwise_activations_cache = self.model.run_with_cache( batch_tokens, names_filter=[self.hook_name], stop_at_layer=self.hook_layer + 1, @@ -436,29 +439,26 @@ def get_activations(self, batch_tokens: torch.Tensor): **self.model_kwargs, )[1] - n_batches, n_context = batch_tokens.shape + layerwise_activations = layerwise_activations_cache[self.hook_name][:, slice(*self.seqpos_slice)] + n_batches, n_context = layerwise_activations.shape[:2] stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) if self.hook_head_index is not None: - stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][ + stacked_activations[:, :, 0] = layerwise_activations[ :, :, self.hook_head_index ] elif ( - layerwise_activations[self.hook_name].ndim > 3 + layerwise_activations.ndim > 3 ): # if we have a head dimension try: - stacked_activations[:, :, 0] = layerwise_activations[ - self.hook_name - ].view(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.view(n_batches, n_context, -1) except RuntimeError as e: print(f"Error during view operation: {e}") print("Attempting to use reshape instead...") - stacked_activations[:, :, 0] = layerwise_activations[ - self.hook_name - ].reshape(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations.reshape(n_batches, n_context, -1) else: - stacked_activations[:, :, 0] = layerwise_activations[self.hook_name] + stacked_activations[:, :, 0] = layerwise_activations return stacked_activations @@ -474,6 +474,7 @@ def get_buffer( If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react. """ context_size = self.context_size + training_context_size = len(range(context_size)[slice(*self.seqpos_slice)]) batch_size = self.store_batch_size_prompts d_in = self.d_in total_size = batch_size * n_batches_in_buffer @@ -481,7 +482,7 @@ def get_buffer( if self.cached_activations_path is not None: # Load the activations from disk - buffer_size = total_size * context_size + buffer_size = total_size * training_context_size # Initialize an empty tensor with an additional dimension for layers new_buffer = torch.zeros( (buffer_size, num_layers, d_in), @@ -535,7 +536,7 @@ def get_buffer( refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size) # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers new_buffer = torch.zeros( - (total_size, context_size, num_layers, d_in), + (total_size, training_context_size, num_layers, d_in), dtype=self.dtype, # type: ignore device=self.device, ) From 89714c8417034bdfed68dd27742692077ea828c1 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 18 Sep 2024 12:04:33 +0100 Subject: [PATCH 2/5] fix forward functions for gated --- sae_lens/sae.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index b347990a..cbf5d10d 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -427,6 +427,13 @@ def forward( with torch.no_grad(): x = x.to(self.dtype) sae_in = self.reshape_fn_in(x) # type: ignore + + # handle run time activation normalization if needed + sae_in = self.run_time_activation_norm_fn_in(sae_in) + + # apply b_dec_to_input if using that method. + sae_in = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input) + gating_pre_activation = sae_in @ self.W_enc + self.b_gate active_features = (gating_pre_activation > 0).float() @@ -455,10 +462,10 @@ def forward( sae_in = self.reshape_fn_in(x) # type: ignore # handle run time activation normalization if needed - x = self.run_time_activation_norm_fn_in(x) + sae_in = self.run_time_activation_norm_fn_in(sae_in) # apply b_dec_to_input if using that method. - sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input) + sae_in = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input) # "... d_in, d_in d_sae -> ... d_sae", hidden_pre = sae_in @ self.W_enc + self.b_enc @@ -495,11 +502,11 @@ def encode_gated( magnitude_pre_activation = self.hook_sae_acts_pre( sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag ) - feature_magnitudes = self.hook_sae_acts_post( - self.activation_fn(magnitude_pre_activation) - ) + feature_magnitudes = self.activation_fn(magnitude_pre_activation) - return active_features * feature_magnitudes + feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes) + + return feature_acts def encode_jumprelu( self, x: Float[torch.Tensor, "... d_in"] From d38622e715980fb88de9565e8a87aa223c801b87 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 18 Sep 2024 12:09:03 +0100 Subject: [PATCH 3/5] remove seqpos changes --- sae_lens/config.py | 12 +----------- sae_lens/training/activations_store.py | 27 +++++++++++++------------- 2 files changed, 14 insertions(+), 25 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 94bf1517..291117da 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -6,7 +6,6 @@ import torch import wandb from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict - from sae_lens import __version__ DTYPE_MAP = { @@ -60,7 +59,6 @@ class LanguageModelSAERunnerConfig: store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations. train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop. normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output). - seqpos_slice (tuple): Determines slicing of (batch, seq, d_in) activations when constructing batches, during training. Example: for Othello we sometimes use (5, -5). device (str): The device to use. Usually cuda. act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram. seed (int): The seed to use. @@ -149,10 +147,7 @@ class LanguageModelSAERunnerConfig: finetuning_tokens: int = 0 store_batch_size_prompts: int = 32 train_batch_size_tokens: int = 4096 - normalize_activations: str = ( - "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update) - ) - seqpos_slice: tuple[int | None, ...] = (None,) + normalize_activations: str = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update) # Misc device: str = "cpu" @@ -233,7 +228,6 @@ class LanguageModelSAERunnerConfig: sae_lens_training_version: str = field(default_factory=lambda: __version__) def __post_init__(self): - if self.resume: raise ValueError( "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path." @@ -398,7 +392,6 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]: } def to_dict(self) -> dict[str, Any]: - cfg_dict = { **self.__dict__, # some args may not be serializable by default @@ -410,7 +403,6 @@ def to_dict(self) -> dict[str, Any]: return cfg_dict def to_json(self, path: str) -> None: - if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) @@ -455,7 +447,6 @@ class CacheActivationsRunnerConfig: store_batch_size_prompts: int = 32 train_batch_size_tokens: int = 4096 normalize_activations: str = "none" # should always be none for activation caching - seqpos_slice: tuple[int | None, ...] = (None,) # Misc device: str = "cpu" @@ -489,7 +480,6 @@ def __post_init__(self): @dataclass class ToyModelSAERunnerConfig: - architecture: Literal["standard", "gated"] = "standard" # ReLu Model Parameters diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index d68aadf3..3a068cd2 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -87,7 +87,6 @@ def from_config( model_kwargs=cfg.model_kwargs, autocast_lm=cfg.autocast_lm, dataset_trust_remote_code=cfg.dataset_trust_remote_code, - seqpos_slice=cfg.seqpos_slice, ) @classmethod @@ -147,7 +146,6 @@ def __init__( model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, dataset_trust_remote_code: bool | None = None, - seqpos_slice: tuple[int | None, ...] = (None,) ): self.model = model if model_kwargs is None: @@ -189,7 +187,6 @@ def __init__( self.dtype = DTYPE_MAP[dtype] self.cached_activations_path = cached_activations_path self.autocast_lm = autocast_lm - self.seqpos_slice = seqpos_slice self.n_dataset_processed = 0 @@ -431,7 +428,7 @@ def get_activations(self, batch_tokens: torch.Tensor): autocast_if_enabled = contextlib.nullcontext() with autocast_if_enabled: - layerwise_activations_cache = self.model.run_with_cache( + layerwise_activations = self.model.run_with_cache( batch_tokens, names_filter=[self.hook_name], stop_at_layer=self.hook_layer + 1, @@ -439,26 +436,29 @@ def get_activations(self, batch_tokens: torch.Tensor): **self.model_kwargs, )[1] - layerwise_activations = layerwise_activations_cache[self.hook_name][:, slice(*self.seqpos_slice)] - n_batches, n_context = layerwise_activations.shape[:2] + n_batches, n_context = batch_tokens.shape stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) if self.hook_head_index is not None: - stacked_activations[:, :, 0] = layerwise_activations[ + stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][ :, :, self.hook_head_index ] elif ( - layerwise_activations.ndim > 3 + layerwise_activations[self.hook_name].ndim > 3 ): # if we have a head dimension try: - stacked_activations[:, :, 0] = layerwise_activations.view(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations[ + self.hook_name + ].view(n_batches, n_context, -1) except RuntimeError as e: print(f"Error during view operation: {e}") print("Attempting to use reshape instead...") - stacked_activations[:, :, 0] = layerwise_activations.reshape(n_batches, n_context, -1) + stacked_activations[:, :, 0] = layerwise_activations[ + self.hook_name + ].reshape(n_batches, n_context, -1) else: - stacked_activations[:, :, 0] = layerwise_activations + stacked_activations[:, :, 0] = layerwise_activations[self.hook_name] return stacked_activations @@ -474,7 +474,6 @@ def get_buffer( If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react. """ context_size = self.context_size - training_context_size = len(range(context_size)[slice(*self.seqpos_slice)]) batch_size = self.store_batch_size_prompts d_in = self.d_in total_size = batch_size * n_batches_in_buffer @@ -482,7 +481,7 @@ def get_buffer( if self.cached_activations_path is not None: # Load the activations from disk - buffer_size = total_size * training_context_size + buffer_size = total_size * context_size # Initialize an empty tensor with an additional dimension for layers new_buffer = torch.zeros( (buffer_size, num_layers, d_in), @@ -536,7 +535,7 @@ def get_buffer( refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size) # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers new_buffer = torch.zeros( - (total_size, training_context_size, num_layers, d_in), + (total_size, context_size, num_layers, d_in), dtype=self.dtype, # type: ignore device=self.device, ) From d9ea96a2bf83adcea6f916b5f01b887233fd48f0 Mon Sep 17 00:00:00 2001 From: callummcdougall Date: Wed, 18 Sep 2024 12:09:59 +0100 Subject: [PATCH 4/5] fix formatting (remove my changes) --- sae_lens/config.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 291117da..383698c9 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -6,6 +6,7 @@ import torch import wandb from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict + from sae_lens import __version__ DTYPE_MAP = { @@ -147,7 +148,9 @@ class LanguageModelSAERunnerConfig: finetuning_tokens: int = 0 store_batch_size_prompts: int = 32 train_batch_size_tokens: int = 4096 - normalize_activations: str = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update) + normalize_activations: str = ( + "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update) + ) # Misc device: str = "cpu" @@ -228,6 +231,7 @@ class LanguageModelSAERunnerConfig: sae_lens_training_version: str = field(default_factory=lambda: __version__) def __post_init__(self): + if self.resume: raise ValueError( "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path." @@ -392,6 +396,7 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]: } def to_dict(self) -> dict[str, Any]: + cfg_dict = { **self.__dict__, # some args may not be serializable by default @@ -403,6 +408,7 @@ def to_dict(self) -> dict[str, Any]: return cfg_dict def to_json(self, path: str) -> None: + if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) @@ -480,6 +486,7 @@ def __post_init__(self): @dataclass class ToyModelSAERunnerConfig: + architecture: Literal["standard", "gated"] = "standard" # ReLu Model Parameters @@ -589,4 +596,4 @@ class PretokenizeRunnerConfig: hf_repo_id: str | None = None hf_num_shards: int = 64 hf_revision: str = "main" - hf_is_private_repo: bool = False + hf_is_private_repo: bool = False \ No newline at end of file From c16366c53572ea3b5f8829f4def41862d06d3c59 Mon Sep 17 00:00:00 2001 From: jbloomAus Date: Fri, 20 Sep 2024 10:13:19 +0100 Subject: [PATCH 5/5] format --- sae_lens/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 383698c9..a74b7565 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -596,4 +596,4 @@ class PretokenizeRunnerConfig: hf_repo_id: str | None = None hf_num_shards: int = 64 hf_revision: str = "main" - hf_is_private_repo: bool = False \ No newline at end of file + hf_is_private_repo: bool = False