Skip to content

Commit

Permalink
Merge pull request #1 from jbloomAus/activations_on_disk
Browse files Browse the repository at this point in the history
Activations on disk
  • Loading branch information
jbloomAus authored Dec 14, 2023
2 parents b5344a3 + 94ed3e6 commit e5f198e
Show file tree
Hide file tree
Showing 7 changed files with 621 additions and 28 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,5 @@ cython_debug/
wandb/
checkpoints/
data/
artifacts/
artifacts/
activations/
400 changes: 397 additions & 3 deletions research/run.ipynb

Large diffs are not rendered by default.

74 changes: 69 additions & 5 deletions sae_training/activations_store.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer


class ActivationsStore:
"""
Class for streaming tokens and generating and storing activations
while training SAEs.
"""
def __init__(
self, cfg, model: HookedTransformer,
self, cfg, model: HookedTransformer, create_dataloader: bool = True,
):
self.cfg = cfg
self.model = model
Expand All @@ -26,9 +26,30 @@ def __init__(
self.cfg.is_dataset_tokenized = False
print("Dataset is not tokenized! Updating config.")

# fill buffer half a buffer, so we can mix it with a new buffer
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()
if self.cfg.use_cached_activations:
# Sanity check: does the cache directory exist?
assert os.path.exists(self.cfg.cached_activations_path), \
f"Cache directory {self.cfg.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."

self.next_cache_idx = 0 # which file to open next
self.next_idx_within_buffer = 0 # where to start reading from in that file

# Check that we have enough data on disk
first_buffer = torch.load(f"{self.cfg.cached_activations_path}/0.pt")
buffer_size_on_disk = first_buffer.shape[0]
n_buffers_on_disk = len(os.listdir(self.cfg.cached_activations_path))
# Note: we're assuming all files have the same number of tokens
# (which seems reasonable imo since that's what our script does)
n_activations_on_disk = buffer_size_on_disk * n_buffers_on_disk
assert n_activations_on_disk > self.cfg.total_training_tokens, \
f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but cfg.total_training_tokens is {self.cfg.total_training_tokens/1e6:.1f}M."

# TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

if create_dataloader:
# fill buffer half a buffer, so we can mix it with a new buffer
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()

def get_batch_tokens(self):
"""
Expand Down Expand Up @@ -137,6 +158,49 @@ def get_buffer(self, n_batches_in_buffer):
d_in = self.cfg.d_in
total_size = batch_size * n_batches_in_buffer

if self.cfg.use_cached_activations:
# Load the activations from disk
buffer_size = total_size * context_size
# Initialize an empty tensor (flattened along all dims except d_in)
new_buffer = torch.zeros((buffer_size, d_in), dtype=self.cfg.dtype,
device=self.cfg.device)
n_tokens_filled = 0

# The activations may be split across multiple files,
# Or we might only want a subset of one file (depending on the sizes)
while n_tokens_filled < buffer_size:
# Load the next file
# Make sure it exists
if not os.path.exists(f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"):
print("\n\nWarning: Ran out of cached activation files earlier than expected.")
print(f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}.")
if buffer_size % self.cfg.total_training_tokens != 0:
print("This might just be a rounding error — your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens")
print(f"Returning a buffer of size {n_tokens_filled} instead.")
print("\n\n")
new_buffer = new_buffer[:n_tokens_filled]
break
activations = torch.load(f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt")

# If we only want a subset of the file, take it
taking_subset_of_file = False
if n_tokens_filled + activations.shape[0] > buffer_size:
activations = activations[:buffer_size - n_tokens_filled]
taking_subset_of_file = True

# Add it to the buffer
new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0]] = activations

# Update counters
n_tokens_filled += activations.shape[0]
if taking_subset_of_file:
self.next_idx_within_buffer = activations.shape[0]
else:
self.next_cache_idx += 1
self.next_idx_within_buffer = 0

return new_buffer

refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
# refill_iterator = tqdm(refill_iterator, desc="generate activations")

Expand Down
49 changes: 49 additions & 0 deletions sae_training/cache_activations_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import math
import os

import torch
from transformer_lens import HookedTransformer
from tqdm import tqdm

from sae_training.activations_store import ActivationsStore
from sae_training.config import CacheActivationsRunnerConfig
from sae_training.utils import shuffle_activations_pairwise


def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
model = HookedTransformer.from_pretrained(cfg.model_name)
model.to(cfg.device)
activations_store = ActivationsStore(cfg, model, create_dataloader=False)

# if the activations directory exists and has files in it, raise an exception
if os.path.exists(activations_store.cfg.cached_activations_path):
if len(os.listdir(activations_store.cfg.cached_activations_path)) > 0:
raise Exception(f"Activations directory ({activations_store.cfg.cached_activations_path}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files.")
else:
os.makedirs(activations_store.cfg.cached_activations_path)

print(f"Started caching {cfg.total_training_tokens} activations")
tokens_per_buffer = cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer
n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer)
for i in tqdm(range(n_buffers), desc="Caching activations"):
buffer = activations_store.get_buffer(cfg.n_batches_in_buffer)
torch.save(buffer, f"{activations_store.cfg.cached_activations_path}/{i}.pt")
del buffer

if i % cfg.shuffle_every_n_buffers == 0 and i > 0:
# Shuffle the buffers on disk

# Do random pairwise shuffling between the last shuffle_every_n_buffers buffers
for _ in range(cfg.n_shuffles_with_last_section):
shuffle_activations_pairwise(activations_store.cfg.cached_activations_path,
buffer_idx_range=(i - cfg.shuffle_every_n_buffers, i))

# Do more random pairwise shuffling between all the buffers
for _ in range(cfg.n_shuffles_in_entire_dir):
shuffle_activations_pairwise(activations_store.cfg.cached_activations_path,
buffer_idx_range=(0, i))

# More final shuffling (mostly in case we didn't end on an i divisible by shuffle_every_n_buffers)
for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"):
shuffle_activations_pairwise(activations_store.cfg.cached_activations_path,
buffer_idx_range=(0, n_buffers))
68 changes: 52 additions & 16 deletions sae_training/config.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,55 @@

from abc import ABC
from dataclasses import dataclass
from typing import Optional

import torch

import wandb


@dataclass
class LanguageModelSAERunnerConfig:
class RunnerConfig(ABC):
"""
Configuration for training a sparse autoencoder on a language model.
The config that's shared across all runners.
"""

# Data Generating Function (Model + Training Distibuion)
model_name: str = "gelu-2l"
hook_point: str = "blocks.0.hook_mlp_out"
hook_point_layer: int = 0
hook_point_head_index: Optional[int] = None
dataset_path: str = "NeelNanda/c4-tokenized-2b"
is_dataset_tokenized: bool = True
context_size: int = 128
use_cached_activations: bool = False
cached_activations_path: Optional[str] = None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"

# SAE Parameters
d_in: int = 512

# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
store_batch_size: int = 1024

# Misc
device: str = "cpu"
seed: int = 42
dtype: torch.dtype = torch.float32

def __post_init__(self):
# Autofill cached_activations_path unless the user overrode it
if self.cached_activations_path is None:
self.cached_activations_path = f"activations/{self.dataset_path.replace('/', '_')}/{self.model_name.replace('/', '_')}/{self.hook_point}"
if self.hook_point_head_index is not None:
self.cached_activations_path += f"_{self.hook_point_head_index}"


@dataclass
class LanguageModelSAERunnerConfig(RunnerConfig):
"""
Configuration for training a sparse autoencoder on a language model.
"""

# SAE Parameters
expansion_factor: int = 4

# Training Parameters
Expand All @@ -31,7 +58,6 @@ class LanguageModelSAERunnerConfig:
lr_scheduler_name: str = "constant" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
lr_warm_up_steps: int = 500
train_batch_size: int = 4096
context_size: int = 128

# Resampling protocol args
feature_sampling_method: str = "l2" # None or l2
Expand All @@ -40,25 +66,18 @@ class LanguageModelSAERunnerConfig:
dead_feature_window: int = 100 # unless this window is larger feature sampling,
dead_feature_threshold: float = 1e-8

# Activation Store Parameters
n_batches_in_buffer: int = 20
total_training_tokens: int = 2_000_000
store_batch_size: int = 1024

# WANDB
log_to_wandb: bool = True
wandb_project: str = "mats_sae_training_language_model"
wandb_entity: str = None
wandb_log_frequency: int = 10

# Misc
device: str = "cpu"
seed: int = 42
n_checkpoints: int = 0
checkpoint_path: str = "checkpoints"
dtype: torch.dtype = torch.float32

def __post_init__(self):
super().__post_init__()
self.d_sae = self.d_in * self.expansion_factor
self.tokens_per_buffer = self.train_batch_size * self.context_size * self.n_batches_in_buffer

Expand All @@ -69,7 +88,7 @@ def __post_init__(self):

unique_id = wandb.util.generate_id()
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

# Print out some useful info:
n_tokens_per_buffer = self.store_batch_size * self.context_size * self.n_batches_in_buffer
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 **6}")
Expand All @@ -82,4 +101,21 @@ def __post_init__(self):
# how many times will we sample dead neurons?
n_dead_feature_samples = total_training_steps // self.dead_feature_window - 1
print(f"n_dead_feature_samples: {n_dead_feature_samples}")


@dataclass
class CacheActivationsRunnerConfig(RunnerConfig):
"""
Configuration for caching activations of an LLM.
"""
# Activation caching stuff
shuffle_every_n_buffers: int = 10
n_shuffles_with_last_section: int = 10
n_shuffles_in_entire_dir: int = 10
n_shuffles_final: int = 100

def __post_init__(self):
super().__post_init__()
if self.use_cached_activations:
# this is a dummy property in this context; only here to avoid class compatibility headaches
raise ValueError("use_cached_activations should be False when running cache_activations_runner")

25 changes: 25 additions & 0 deletions sae_training/timeit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
This is a util to time the execution of a function.
(Has to be a separate file, if you put it in utils.py you get circular imports; need to find a permanent home for it)
"""

from functools import wraps
import time

def timeit(func):
"""
Decorator to time a function.
Taken from https://dev.to/kcdchennai/python-decorator-to-measure-execution-time-54hk
"""
@wraps(func)
def timeit_wrapper(*args, **kwargs):
start_time = time.perf_counter()
result = func(*args, **kwargs)
end_time = time.perf_counter()
total_time = end_time - start_time
print(f'Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
return result
return timeit_wrapper

30 changes: 27 additions & 3 deletions sae_training/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional, Tuple

from typing import Tuple
import torch
from transformer_lens import HookedTransformer

Expand Down Expand Up @@ -75,4 +74,29 @@ def get_activations_loader(self, cfg: LanguageModelSAERunnerConfig, model: Hooke
cfg, model,
)

return activations_loader
return activations_loader

def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[int, int]):
"""
Shuffles two buffers on disk.
"""
assert buffer_idx_range[0] < buffer_idx_range[1], \
"buffer_idx_range[0] must be smaller than buffer_idx_range[1]"

buffer_idx1 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()
buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()
while buffer_idx1 == buffer_idx2: # Make sure they're not the same
buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()

buffer1 = torch.load(f"{datapath}/{buffer_idx1}.pt")
buffer2 = torch.load(f"{datapath}/{buffer_idx2}.pt")
joint_buffer = torch.cat([buffer1, buffer2])

# Shuffle them
joint_buffer = joint_buffer[torch.randperm(joint_buffer.shape[0])]
shuffled_buffer1 = joint_buffer[:buffer1.shape[0]]
shuffled_buffer2 = joint_buffer[buffer1.shape[0]:]

# Save them back
torch.save(shuffled_buffer1, f"{datapath}/{buffer_idx1}.pt")
torch.save(shuffled_buffer2, f"{datapath}/{buffer_idx2}.pt")

0 comments on commit e5f198e

Please sign in to comment.