Merge pull request #17 from Benw8888/sae_group_pr
SAE Group for sweeps PR
jbloomAus authored Feb 29, 2024
2 parents ad84706 + dd24413 commit 3e78bce
Showing 17 changed files with 906 additions and 6,650 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -5,4 +5,4 @@ max-complexity = 25
extend-select = E9, F63, F7, F82
show-source = true
statistics = true
exclude = ./sae_training/geom_median/
exclude = ./sae_training/geom_median/, ./wandb/*, ./research/wandb/*
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -7,19 +7,19 @@ repos:
- id: check-added-large-files
args: [--maxkb=250000]
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
args: ['--config=.flake8']
additional_dependencies: [
'flake8-blind-except',
'flake8-docstrings',
# 'flake8-docstrings',
'flake8-bugbear',
'flake8-comprehensions',
'flake8-docstrings',
'flake8-implicit-str-concat',
'pydocstyle>=5.0.0',
]
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -18,6 +18,7 @@ ipykernel = "^6.29.2"
matplotlib = "^3.8.3"
matplotlib-inline = "^0.1.6"
eindex = {git = "https://github.com/callummcdougall/eindex.git"}
datasets = "^2.17.1"


[tool.poetry.group.dev.dependencies]
@@ -33,4 +34,4 @@ profile = "black"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
build-backend = "poetry.core.masonry.api"
111 changes: 66 additions & 45 deletions sae_training/activations_store.py
@@ -23,15 +23,15 @@ def __init__(
self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
self.iterable_dataset = iter(self.dataset)

# check if it's tokenized
if "tokens" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = True
print("Dataset is tokenized! Updating config.")
elif "text" in next(self.iterable_dataset).keys():
self.cfg.is_dataset_tokenized = False
print("Dataset is not tokenized! Updating config.")
# Check if dataset is tokenized
dataset_sample = next(self.iterable_dataset)
self.cfg.is_dataset_tokenized = "tokens" in dataset_sample.keys()
print(
f"Dataset is {'tokenized' if self.cfg.is_dataset_tokenized else 'not tokenized'}! Updating config."
)
self.iterable_dataset = iter(self.dataset) # Reset iterator after checking

if self.cfg.use_cached_activations:
if self.cfg.use_cached_activations: # EDIT: load from multi-layer acts
# Sanity check: does the cache directory exist?
assert os.path.exists(
self.cfg.cached_activations_path
@@ -144,39 +144,65 @@ def get_batch_tokens(self):
return batch_tokens[:batch_size]

def get_activations(self, batch_tokens, get_loss=False):
act_name = self.cfg.hook_point
hook_point_layer = self.cfg.hook_point_layer
"""
Returns activations of shape (batches, context, num_layers, d_in)
"""
layers = (
self.cfg.hook_point_layer
if isinstance(self.cfg.hook_point_layer, list)
else [self.cfg.hook_point_layer]
)
act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
hook_point_max_layer = max(layers)
if self.cfg.hook_point_head_index is not None:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name][:, :, self.cfg.hook_point_head_index]
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
)[1]
activations_list = [
layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
for act_name in act_names
]
else:
activations = self.model.run_with_cache(
batch_tokens, names_filter=act_name, stop_at_layer=hook_point_layer + 1
)[1][act_name]
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
)[1]
activations_list = [
layerwise_activations[act_name] for act_name in act_names
]

# Stack along a new dimension to keep separate layers distinct
stacked_activations = torch.stack(activations_list, dim=2)

return activations
return stacked_activations

def get_buffer(self, n_batches_in_buffer):
context_size = self.cfg.context_size
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = (
len(self.cfg.hook_point_layer)
if isinstance(self.cfg.hook_point_layer, list)
else 1
) # Number of hook points or layers

if self.cfg.use_cached_activations:
# Load the activations from disk
buffer_size = total_size * context_size
# Initialize an empty tensor (flattened along all dims except d_in)
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, d_in), dtype=self.cfg.dtype, device=self.cfg.device
(buffer_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
)
n_tokens_filled = 0

# The activations may be split across multiple files,
# Or we might only want a subset of one file (depending on the sizes)
# Assume activations for different layers are stored separately and need to be combined
while n_tokens_filled < buffer_size:
# Load the next file
# Make sure it exists
if not os.path.exists(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
):
@@ -192,55 +218,49 @@ def get_buffer(self, n_batches_in_buffer):
)
print(f"Returning a buffer of size {n_tokens_filled} instead.")
print("\n\n")
new_buffer = new_buffer[:n_tokens_filled]
break
new_buffer = new_buffer[:n_tokens_filled, ...]
return new_buffer

activations = torch.load(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
)

# If we only want a subset of the file, take it
taking_subset_of_file = False
if n_tokens_filled + activations.shape[0] > buffer_size:
activations = activations[: buffer_size - n_tokens_filled]
activations = activations[: buffer_size - n_tokens_filled, ...]
taking_subset_of_file = True

# Add it to the buffer
new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0]] = (
activations
)
new_buffer[
n_tokens_filled : n_tokens_filled + activations.shape[0], ...
] = activations

# Update counters
n_tokens_filled += activations.shape[0]
if taking_subset_of_file:
self.next_idx_within_buffer = activations.shape[0]
else:
self.next_cache_idx += 1
self.next_idx_within_buffer = 0

n_tokens_filled += activations.shape[0]

return new_buffer

refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
# refill_iterator = tqdm(refill_iterator, desc="generate activations")

# Initialize empty tensor buffer of the maximum required size
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, d_in),
(total_size, context_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
)

# Insert activations directly into pre-allocated buffer
# pbar = tqdm(total=n_batches_in_buffer, desc="Filling buffer")
for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens)
new_buffer[refill_batch_idx_start : refill_batch_idx_start + batch_size] = (
refill_activations
)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
] = refill_activations

# pbar.update(1)

new_buffer = new_buffer.reshape(-1, d_in)
new_buffer = new_buffer.reshape(-1, num_layers, d_in)
new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])]

return new_buffer
@@ -260,7 +280,8 @@ def get_data_loader(

# 1. # create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer]
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer],
dim=0,
)

mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
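The main change to activations_store.py is that get_activations now pulls several hook points in a single run_with_cache call and stacks them along a new layer dimension, so buffers carry a (num_layers, d_in) slice per token. The following sketch (illustrative names and sizes, not code from the PR) shows the resulting shape convention:

# Illustrative sketch only (not part of the diff): the shape convention that the
# updated get_activations/get_buffer introduce. Names and sizes are made up.
import torch

batch, context, d_in = 4, 8, 16
hook_point_layer = [0, 1, 2]                      # may now be a list of layers
layers = hook_point_layer if isinstance(hook_point_layer, list) else [hook_point_layer]
num_layers = len(layers)

# one (batch, context, d_in) tensor per hook point, as run_with_cache would return
activations_list = [torch.randn(batch, context, d_in) for _ in layers]

# stack along a new dimension so the layers stay distinct
stacked = torch.stack(activations_list, dim=2)    # (batch, context, num_layers, d_in)

# the buffer flattens batch and context but keeps the layer axis, then shuffles tokens
buffer = stacked.reshape(-1, num_layers, d_in)    # (batch * context, num_layers, d_in)
buffer = buffer[torch.randperm(buffer.shape[0])]
print(buffer.shape)                               # torch.Size([32, 3, 16])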
3 changes: 2 additions & 1 deletion sae_training/cache_activations_runner.py
@@ -29,7 +29,8 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer
)
n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer)
for i in tqdm(range(n_buffers), desc="Caching activations"):
# for i in tqdm(range(n_buffers), desc="Caching activations"):
for i in range(n_buffers):
buffer = activations_store.get_buffer(cfg.n_batches_in_buffer)
torch.save(buffer, f"{activations_store.cfg.cached_activations_path}/{i}.pt")
del buffer
7 changes: 5 additions & 2 deletions sae_training/config.py
@@ -14,7 +14,7 @@ class RunnerConfig(ABC):

# Data Generating Function (Model + Training Distibuion)
model_name: str = "gelu-2l"
hook_point: str = "blocks.0.hook_mlp_out"
hook_point: str = "blocks.{layer}.hook_mlp_out"
hook_point_layer: int = 0
hook_point_head_index: Optional[int] = None
dataset_path: str = "NeelNanda/c4-tokenized-2b"
@@ -56,9 +56,11 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
b_dec_init_method: str = "geometric_median"
expansion_factor: int = 4
from_pretrained_path: Optional[str] = None
d_sae: Optional[int] = None

# Training Parameters
l1_coefficient: float = 1e-3
lp_norm: float = 1
lr: float = 3e-4
lr_scheduler_name: str = (
"constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
@@ -86,7 +88,8 @@ class LanguageModelSAERunnerConfig(RunnerConfig):

def __post_init__(self):
super().__post_init__()
self.d_sae = self.d_in * self.expansion_factor
if not isinstance(self.expansion_factor, list):
self.d_sae = self.d_in * self.expansion_factor
self.tokens_per_buffer = (
self.train_batch_size * self.context_size * self.n_batches_in_buffer
)
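The config change replaces the hard-coded hook point with a {layer} template and skips deriving d_sae when expansion_factor is a list, which is what lets one config describe a sweep over layers and SAE widths. A minimal sketch of both behaviours (the list values below are hypothetical, not from the PR):

# Illustrative sketch only (not part of the diff): the templated hook_point and
# the new d_sae guard. Values are made up for the example.
hook_point = "blocks.{layer}.hook_mlp_out"
hook_point_layer = [0, 3, 7]                      # a list enables multi-layer sweeps

layers = hook_point_layer if isinstance(hook_point_layer, list) else [hook_point_layer]
act_names = [hook_point.format(layer=layer) for layer in layers]
# ['blocks.0.hook_mlp_out', 'blocks.3.hook_mlp_out', 'blocks.7.hook_mlp_out']

d_in, expansion_factor = 512, [4, 8]              # a list means "sweep over widths"
if not isinstance(expansion_factor, list):
    d_sae = d_in * expansion_factor               # single SAE: width fixed in __post_init__
else:
    d_sae = None                                  # left unset; resolved per SAE in the group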
17 changes: 9 additions & 8 deletions sae_training/evals.py
@@ -17,6 +17,7 @@ def run_evals(
activation_store: ActivationsStore,
model: HookedTransformer,
n_training_steps: int,
suffix: str = "",
):
hook_point = sparse_autoencoder.cfg.hook_point
hook_point_layer = sparse_autoencoder.cfg.hook_point_layer
@@ -71,13 +72,13 @@ wandb.log(
wandb.log(
{
# l2 norms
"metrics/l2_norm": l2_norm_out.mean().item(),
"metrics/l2_ratio": l2_norm_ratio.mean().item(),
f"metrics/l2_norm{suffix}": l2_norm_out.mean().item(),
f"metrics/l2_ratio{suffix}": l2_norm_ratio.mean().item(),
# CE Loss
"metrics/CE_loss_score": recons_score,
"metrics/ce_loss_without_sae": ntp_loss,
"metrics/ce_loss_with_sae": recons_loss,
"metrics/ce_loss_with_ablation": zero_abl_loss,
f"metrics/CE_loss_score{suffix}": recons_score,
f"metrics/ce_loss_without_sae{suffix}": ntp_loss,
f"metrics/ce_loss_with_sae{suffix}": recons_loss,
f"metrics/ce_loss_with_ablation{suffix}": zero_abl_loss,
},
step=n_training_steps,
)
@@ -142,8 +143,8 @@ def head_replacement_hook(activations, hook):
if wandb.run is not None:
wandb.log(
{
"metrics/kldiv_reconstructed": kl_result_reconstructed.mean().item(),
"metrics/kldiv_ablation": kl_result_ablation.mean().item(),
f"metrics/kldiv_reconstructed{suffix}": kl_result_reconstructed.mean().item(),
f"metrics/kldiv_ablation{suffix}": kl_result_ablation.mean().item(),
},
step=n_training_steps,
)
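In evals.py, run_evals gains a suffix parameter that is appended to every wandb metric key, so evaluating several SAEs in one run no longer overwrites the same metrics. A small sketch of the key naming (the suffix value here is hypothetical):

# Illustrative sketch only (not part of the diff): how the new `suffix` argument
# keeps wandb metric keys distinct when run_evals is called once per SAE in a group.
def metric_keys(suffix: str = "") -> list[str]:
    return [f"metrics/l2_norm{suffix}", f"metrics/CE_loss_score{suffix}"]

print(metric_keys())            # ['metrics/l2_norm', 'metrics/CE_loss_score']
print(metric_keys("/layer_3"))  # ['metrics/l2_norm/layer_3', 'metrics/CE_loss_score/layer_3']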
(Diffs for the remaining changed files are not shown.)
