add arg for dead neuron calc
jbloom-md committed Jan 18, 2024
1 parent 0319d89 commit ffb75fb
Showing 2 changed files with 76 additions and 39 deletions.
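
This commit adds a dead_feature_estimation_method argument that controls how dead neurons are identified before resampling. A minimal usage sketch, assuming the import path follows the file layout below (all other fields keep their defaults):

from sae_training.config import LanguageModelSAERunnerConfig

# Opt into the frequency-based dead-neuron criterion;
# the default remains "no_fire", per the diff below.
cfg = LanguageModelSAERunnerConfig(
    dead_feature_estimation_method="frequency",
    dead_feature_threshold=1e-8,
)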
101 changes: 66 additions & 35 deletions sae_training/config.py
@@ -12,6 +12,7 @@ class RunnerConfig(ABC):
"""
The config that's shared across all runners.
"""

# Data Generating Function (Model + Training Distribution)
model_name: str = "gelu-2l"
hook_point: str = "blocks.0.hook_mlp_out"
@@ -21,22 +22,23 @@ class RunnerConfig(ABC):
is_dataset_tokenized: bool = True
context_size: int = 128
use_cached_activations: bool = False
-cached_activations_path: Optional[str] = None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
+cached_activations_path: Optional[
+    str
+] = None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"

# SAE Parameters
d_in: int = 512


# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
store_batch_size: int = 1024

# Misc
device: str = "cpu"
seed: int = 42
dtype: torch.dtype = torch.float32

def __post_init__(self):
# Autofill cached_activations_path unless the user overrode it
if self.cached_activations_path is None:
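
The body of the autofill branch is collapsed in this view. Below is a hypothetical sketch of what it presumably constructs, based on the default-path comment on cached_activations_path; field names other than model_name and hook_point (dataset_path, hook_point_head_index) are assumptions:

# Hypothetical reconstruction of the collapsed autofill logic, for illustration only.
def default_cached_activations_path(cfg) -> str:
    # Mirrors the documented default:
    # "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
    path = f"activations/{cfg.dataset_path}/{cfg.model_name}/{cfg.hook_point}"
    if getattr(cfg, "hook_point_head_index", None) is not None:  # assumed optional field
        path = f"{path}_{cfg.hook_point_head_index}"
    return path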
@@ -55,79 +57,107 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
b_dec_init_method: str = "geometric_median"
expansion_factor: int = 4
from_pretrained_path: Optional[str] = None

# Training Parameters
l1_coefficient: float = 1e-3
lr: float = 3e-4
lr_scheduler_name: str = "constant"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
lr_warm_up_steps: int = 500
train_batch_size: int = 4096

# Resampling protocol args
feature_sampling_window: int = 2000
feature_sampling_method: str = "Anthropic"  # None or Anthropic
resample_batches: int = 32
feature_reinit_scale: float = 0.2
dead_feature_window: int = 1000  # must be smaller than feature_sampling_window
dead_feature_estimation_method: str = "no_fire"
dead_feature_threshold: float = 1e-8

# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training_language_model"
wandb_entity: str = None
wandb_log_frequency: int = 10

# Misc
n_checkpoints: int = 0
checkpoint_path: str = "checkpoints"

def __post_init__(self):
super().__post_init__()
self.d_sae = self.d_in * self.expansion_factor
-self.tokens_per_buffer = self.train_batch_size * self.context_size * self.n_batches_in_buffer
+self.tokens_per_buffer = (
+    self.train_batch_size * self.context_size * self.n_batches_in_buffer
+)

self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"

if self.feature_sampling_method not in [None, "l2", "anthropic"]:
raise ValueError(f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}")

raise ValueError(
f"feature_sampling_method must be None, l2, or anthropic. Got {self.feature_sampling_method}"
)

if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
raise ValueError(f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}")
raise ValueError(
f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
)
if self.b_dec_init_method == "zeros":
print("Warning: We are initializing b_dec to zeros. This is probably not what you want.")

print(
"Warning: We are initializing b_dec to zeros. This is probably not what you want."
)

self.device = torch.device(self.device)

unique_id = wandb.util.generate_id()
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

print(f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}")
print(
f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}"
)
# Print out some useful info:
-n_tokens_per_buffer = self.store_batch_size * self.context_size * self.n_batches_in_buffer
+n_tokens_per_buffer = (
+    self.store_batch_size * self.context_size * self.n_batches_in_buffer
+)
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer
print(f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}")

print(
f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}"
)

total_training_steps = self.total_training_tokens // self.train_batch_size
print(f"Total training steps: {total_training_steps}")

total_wandb_updates = total_training_steps // self.wandb_log_frequency
print(f"Total wandb updates: {total_wandb_updates}")

# how many times will we sample dead neurons?
# assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
n_dead_feature_samples = total_training_steps // self.dead_feature_window
n_feature_window_samples = total_training_steps // self.feature_sampling_window
print(
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 **6}"
)
print(
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 **6}"
)
print(f"We will reset neurons {n_dead_feature_samples} times.")
print(f"We will reset the sparsity calculation {n_feature_window_samples} times.")
print(
f"We will reset the sparsity calculation {n_feature_window_samples} times."
)
print(f"Number of tokens when resampling: {self.resample_batches * self.store_batch_size}")
# print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
print(f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}")
print(
f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}"
)


@dataclass
class CacheActivationsRunnerConfig(RunnerConfig):
"""
Configuration for caching activations of an LLM.
"""

# Activation caching stuff
shuffle_every_n_buffers: int = 10
n_shuffles_with_last_section: int = 10
@@ -138,5 +168,6 @@ def __post_init__(self):
super().__post_init__()
if self.use_cached_activations:
# this is a dummy property in this context; only here to avoid class compatibility headaches
raise ValueError("use_cached_activations should be False when running cache_activations_runner")

raise ValueError(
"use_cached_activations should be False when running cache_activations_runner"
)
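
As a quick sanity check on the numbers that __post_init__ prints, here is the arithmetic with the default values visible in this diff (a standalone sketch, not library code):

# Derived quantities, using only the defaults shown above.
d_in, expansion_factor = 512, 4
train_batch_size, context_size, n_batches_in_buffer = 4096, 128, 20
total_training_tokens = 2_000_000
wandb_log_frequency = 10
dead_feature_window, feature_sampling_window = 1000, 2000
resample_batches, store_batch_size = 32, 1024

d_sae = d_in * expansion_factor                                             # 2048
tokens_per_buffer = train_batch_size * context_size * n_batches_in_buffer  # 10_485_760
total_training_steps = total_training_tokens // train_batch_size           # 488
total_wandb_updates = total_training_steps // wandb_log_frequency          # 48
n_dead_feature_samples = total_training_steps // dead_feature_window       # 0
n_feature_window_samples = total_training_steps // feature_sampling_window # 0
tokens_when_resampling = resample_batches * store_batch_size               # 32_768

Note that at the default 2M training tokens there are only 488 steps, so neither the 1000-step dead-feature window nor the 2000-step sparsity window ever completes; resampling only kicks in on longer runs.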
14 changes: 10 additions & 4 deletions sae_training/train_sae_on_language_model.py
@@ -69,12 +69,16 @@ def train_sae_on_language_model(
feature_sparsity = act_freq_scores / n_frac_active_tokens

# if reset criterion is frequency in window, then use that to generate indices.
-# dead_neuron_indices = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0]
-
-# if reset criterion is has_fired, then use that to generate indices.
-dead_neuron_indices = (act_freq_scores == 0).nonzero(as_tuple=False)[:, 0]
+if sparse_autoencoder.cfg.dead_feature_estimation_method == "no_fire":
+    dead_neuron_indices = (act_freq_scores == 0).nonzero(as_tuple=False)[:, 0]
+elif sparse_autoencoder.cfg.dead_feature_estimation_method == "frequency":
+    dead_neuron_indices = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0]

if len(dead_neuron_indices) > 0:

+    if len(dead_neuron_indices) > sparse_autoencoder.cfg.resample_batches * sparse_autoencoder.cfg.store_batch_size:
+        print("Warning: more dead neurons than number of tokens. Consider sampling more tokens when resampling.")

sparse_autoencoder.resample_neurons_anthropic(
dead_neuron_indices,
model,
@@ -97,6 +101,8 @@ def train_sae_on_language_model(
increment = (current_lr - reduced_lr) / 10_000
optimizer.param_groups[0]['lr'] = reduced_lr
steps_before_reset = 10_000
+else:
+    print("No dead neurons, skipping resampling")

# Resample dead neurons
if (feature_sampling_method == "l2") and ((n_training_steps + 1) % dead_feature_window == 0):
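
To make the new dead_feature_estimation_method behaviour concrete, here is a toy, self-contained reproduction of the selection logic added above (synthetic numbers, not the library's actual code path):

import torch

act_freq_scores = torch.tensor([0.0, 3.0, 0.0, 1.0])  # how often each feature fired
n_frac_active_tokens = 1000                           # tokens seen in the window
feature_sparsity = act_freq_scores / n_frac_active_tokens

dead_feature_estimation_method = "no_fire"  # the new config arg; or "frequency"
dead_feature_threshold = 1e-8

if dead_feature_estimation_method == "no_fire":
    # dead = never fired at all in the window
    dead_neuron_indices = (act_freq_scores == 0).nonzero(as_tuple=False)[:, 0]
elif dead_feature_estimation_method == "frequency":
    # dead = fired less often than the threshold
    dead_neuron_indices = (feature_sparsity < dead_feature_threshold).nonzero(as_tuple=False)[:, 0]

print(dead_neuron_indices)  # tensor([0, 2]) under either criterion here

Under "no_fire" a feature must be completely silent to count as dead, while "frequency" also catches features that fire vanishingly rarely; with a threshold of 1e-8 the two criteria coincide unless the window spans over 1e8 tokens.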
