Pretokenize runner (#148)
* feat: adding a pretokenize runner

* rewriting pretokenization based on feedback
chanind authored May 15, 2024
1 parent 9ce0fe4 commit f864178
Showing 7 changed files with 895 additions and 2 deletions.
9 changes: 8 additions & 1 deletion sae_lens/__init__.py
@@ -2,9 +2,14 @@

from .training.activations_store import ActivationsStore
from .training.cache_activations_runner import CacheActivationsRunner
from .training.config import CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig
from .training.config import (
CacheActivationsRunnerConfig,
LanguageModelSAERunnerConfig,
PretokenizeRunnerConfig,
)
from .training.evals import run_evals
from .training.lm_runner import language_model_sae_runner
from .training.pretokenize_runner import pretokenize_runner
from .training.sae_group import SparseAutoencoderDictionary
from .training.session_loader import LMSparseAutoencoderSessionloader
from .training.sparse_autoencoder import SparseAutoencoder
@@ -14,10 +19,12 @@
"LanguageModelSAERunnerConfig",
"CacheActivationsRunnerConfig",
"LMSparseAutoencoderSessionloader",
"PretokenizeRunnerConfig",
"SparseAutoencoder",
"SparseAutoencoderDictionary",
"run_evals",
"language_model_sae_runner",
"pretokenize_runner",
"CacheActivationsRunner",
"ActivationsStore",
"train_sae_group_on_language_model",
92 changes: 92 additions & 0 deletions sae_lens/training/batching.py
@@ -0,0 +1,92 @@
from typing import Generator, Iterator

import torch


def _add_tokens_to_batch(
batch: torch.Tensor | None,
tokens: torch.Tensor,
context_size: int,
is_start_of_sequence: bool,
begin_batch_token_id: int | None = None,
begin_sequence_token_id: int | None = None,
sequence_separator_token_id: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
original_tokens = tokens
# prepend the start of sequence token if needed
if is_start_of_sequence and begin_sequence_token_id is not None:
begin_sequence_token_id_tensor = torch.tensor(
[begin_sequence_token_id], dtype=torch.long, device=tokens.device
)
if tokens[0] != begin_sequence_token_id_tensor:
tokens = torch.concat([begin_sequence_token_id_tensor, tokens], dim=0)
# We're at the start of a new batch
if batch is None:
        # add the begin-batch token to the start if needed
if begin_batch_token_id is not None:
begin_batch_token_id_tensor = torch.tensor(
[begin_batch_token_id], dtype=torch.long, device=tokens.device
)
if tokens[0] != begin_batch_token_id_tensor:
tokens = torch.concat([begin_batch_token_id_tensor, tokens], dim=0)
batch = tokens[:context_size]
return tokens[:context_size], tokens[context_size:]
    # if we're appending to an existing batch, add the separator token if needed
if sequence_separator_token_id is not None:
sequence_separator_token_id_tensor = torch.tensor(
[sequence_separator_token_id], dtype=torch.long, device=tokens.device
)
if tokens[0] != sequence_separator_token_id_tensor:
tokens = torch.concat([sequence_separator_token_id_tensor, tokens], dim=0)
tokens_needed = context_size - batch.shape[0]
batch = torch.concat([batch, tokens[:tokens_needed]])

remaining_tokens = tokens[tokens_needed:]
    # it's possible we've prepended 2 tokens to original_tokens but only removed 1;
    # if so, return only the original tokens as the remainder
if len(remaining_tokens) > len(original_tokens):
remaining_tokens = original_tokens
return batch, remaining_tokens


@torch.no_grad()
def concat_and_batch_sequences(
tokens_iterator: Iterator[torch.Tensor],
context_size: int,
begin_batch_token_id: int | None = None,
begin_sequence_token_id: int | None = None,
sequence_separator_token_id: int | None = None,
) -> Generator[torch.Tensor, None, None]:
"""
    Generator to concatenate token sequences from `tokens_iterator`, yielding
    batches of exactly `context_size` tokens.
    Args:
        tokens_iterator: An iterator which returns 1D tensors of tokens
        context_size: Each batch will have this many tokens
        begin_batch_token_id: If provided, this token will be at position 0 of each batch
        begin_sequence_token_id: If provided, this token will be the first token of each sequence
        sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
"""
batch: torch.Tensor | None = None
for tokens in tokens_iterator:
assert (
len(tokens.shape) == 1
), f"tokens.shape should be 1D but was {tokens.shape}"
remaining_tokens = tokens
is_start_of_sequence = True
while len(remaining_tokens) > 0:
batch, remaining_tokens = _add_tokens_to_batch(
batch=batch,
tokens=remaining_tokens,
context_size=context_size,
is_start_of_sequence=is_start_of_sequence,
begin_batch_token_id=begin_batch_token_id,
begin_sequence_token_id=begin_sequence_token_id,
sequence_separator_token_id=sequence_separator_token_id,
)
is_start_of_sequence = False
if batch.shape[0] == context_size:
yield batch
batch = None
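For reference, a minimal sketch of driving the new batching helper directly; the token ids, context size, and separator id below are made up for illustration:

```python
import torch

from sae_lens.training.batching import concat_and_batch_sequences

# Three short "documents" as 1D token tensors (hypothetical token ids).
sequences = [
    torch.tensor([10, 11, 12, 13], dtype=torch.long),
    torch.tensor([20, 21], dtype=torch.long),
    torch.tensor([30, 31, 32, 33, 34, 35], dtype=torch.long),
]

# Pack them into fixed-length rows of 4 tokens, inserting a
# (hypothetical) separator token id 99 between concatenated documents.
for batch in concat_and_batch_sequences(
    tokens_iterator=iter(sequences),
    context_size=4,
    sequence_separator_token_id=99,
):
    print(batch)  # each yielded tensor has shape (4,); a trailing partial row is dropped
```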
31 changes: 30 additions & 1 deletion sae_lens/training/config.py
@@ -1,6 +1,6 @@
import os
from dataclasses import dataclass, field
from typing import Any, Optional, cast
from typing import Any, Literal, Optional, cast

import torch
import wandb
@@ -409,3 +409,32 @@ def _default_cached_activations_path(
if hook_point_head_index is not None:
path += f"_{hook_point_head_index}"
return path


@dataclass
class PretokenizeRunnerConfig:
tokenizer_name: str = "gpt2"
dataset_path: str = "NeelNanda/c4-10k"
split: str | None = "train"
data_files: list[str] | None = None
data_dir: str | None = None
num_proc: int = 4
context_size: int = 128
column_name: str = "text"
shuffle: bool = True
seed: int | None = None
streaming: bool = False

# special tokens
begin_batch_token: int | Literal["bos", "eos", "sep"] | None = "bos"
begin_sequence_token: int | Literal["bos", "eos", "sep"] | None = None
sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "eos"

# if saving locally, set save_path
save_path: str | None = None

# if saving to huggingface, set hf_repo_id
hf_repo_id: str | None = None
hf_num_shards: int = 64
hf_revision: str = "main"
hf_is_private_repo: bool = False
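As a sketch, a config for a purely local pretokenization run might look like the following; the save path is illustrative, and the other values mirror the defaults above:

```python
from sae_lens import PretokenizeRunnerConfig

# Hypothetical config: tokenize NeelNanda/c4-10k with the GPT-2 tokenizer
# into 128-token rows, shuffle, and save the result to a local directory.
cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",
    dataset_path="NeelNanda/c4-10k",
    split="train",
    context_size=128,
    shuffle=True,
    begin_batch_token="bos",
    sequence_separator_token="eos",
    save_path="./c4-10k-tokenized-gpt2",  # illustrative local path
)
```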
169 changes: 169 additions & 0 deletions sae_lens/training/pretokenize_runner.py
@@ -0,0 +1,169 @@
import io
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, Literal, cast

import torch
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from sae_lens import __version__
from sae_lens.training.batching import concat_and_batch_sequences
from sae_lens.training.config import PretokenizeRunnerConfig


@dataclass
class PretokenizedDatasetMetadata:
"""
This metadata will be saved along with the pretokenized dataset as a JSON file.
"""

sae_lens_version: str
tokenizer_name: str
original_dataset: str
original_split: str | None
original_data_files: list[str] | None
context_size: int
shuffled: bool
seed: int | None
begin_batch_token: int | Literal["bos", "eos", "sep"] | None
begin_sequence_token: int | Literal["bos", "eos", "sep"] | None
sequence_separator_token: int | Literal["bos", "eos", "sep"] | None


def metadata_from_config(cfg: PretokenizeRunnerConfig) -> PretokenizedDatasetMetadata:
return PretokenizedDatasetMetadata(
sae_lens_version=__version__,
tokenizer_name=cfg.tokenizer_name,
original_dataset=cfg.dataset_path,
original_split=cfg.split,
original_data_files=cfg.data_files,
context_size=cfg.context_size,
shuffled=cfg.shuffle,
seed=cfg.seed,
begin_batch_token=cfg.begin_batch_token,
begin_sequence_token=cfg.begin_sequence_token,
sequence_separator_token=cfg.sequence_separator_token,
)


def get_special_token_from_cfg(
cfg_token: int | Literal["bos", "eos", "sep"] | None,
tokenizer: PreTrainedTokenizerBase,
) -> int | None:
if cfg_token is None:
return None
if isinstance(cfg_token, int):
return cfg_token
if cfg_token == "bos":
return tokenizer.bos_token_id
if cfg_token == "eos":
return tokenizer.eos_token_id
if cfg_token == "sep":
return tokenizer.sep_token_id
raise ValueError(f"Invalid token type: {cfg_token}")


def pretokenize_dataset(
dataset: Dataset,
tokenizer: PreTrainedTokenizerBase,
cfg: PretokenizeRunnerConfig,
):
def process_examples(examples: dict[str, list[str]]):
tokens_iterator = cast(
Iterator[torch.Tensor],
(
tokenizer.encode(text, return_tensors="pt")[0]
for text in examples[cfg.column_name]
),
)
return {
"input_ids": list(
concat_and_batch_sequences(
tokens_iterator=tokens_iterator,
context_size=cfg.context_size,
begin_batch_token_id=get_special_token_from_cfg(
cfg.begin_batch_token, tokenizer
),
begin_sequence_token_id=get_special_token_from_cfg(
cfg.begin_sequence_token, tokenizer
),
sequence_separator_token_id=get_special_token_from_cfg(
cfg.sequence_separator_token, tokenizer
),
)
)
}

tokenized_dataset = dataset.map(
process_examples,
batched=True,
num_proc=cfg.num_proc,
remove_columns=dataset.column_names,
)
if cfg.shuffle:
tokenized_dataset = tokenized_dataset.shuffle(seed=cfg.seed)
tokenized_dataset.set_format(type="torch", columns=["input_ids"])
return tokenized_dataset


def push_to_hugging_face_hub(
dataset: Dataset,
cfg: PretokenizeRunnerConfig,
):
assert cfg.hf_repo_id is not None
dataset.push_to_hub(
repo_id=cfg.hf_repo_id,
num_shards=cfg.hf_num_shards,
private=cfg.hf_is_private_repo,
revision=cfg.hf_revision,
)
# also upload metadata file
metadata = metadata_from_config(cfg)
meta_io = io.BytesIO()
meta_contents = json.dumps(metadata.__dict__, indent=2, ensure_ascii=False).encode(
"utf-8"
)
meta_io.write(meta_contents)
meta_io.seek(0)

api = HfApi()
api.upload_file(
path_or_fileobj=meta_io,
path_in_repo="sae_lens.json",
repo_id=cfg.hf_repo_id,
repo_type="dataset",
commit_message="Add sae_lens metadata",
)


def pretokenize_runner(
cfg: PretokenizeRunnerConfig,
):
dataset = load_dataset(
cfg.dataset_path,
data_dir=cfg.data_dir,
data_files=cfg.data_files,
split=cfg.split,
streaming=cfg.streaming,
)
if isinstance(dataset, DatasetDict):
raise ValueError("Dataset has multiple splits. Must provide a 'split' param.")
tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
tokenizer.model_max_length = sys.maxsize
tokenized_dataset = pretokenize_dataset(cast(Dataset, dataset), tokenizer, cfg)

if cfg.save_path is not None:
tokenized_dataset.save_to_disk(cfg.save_path)
metadata = metadata_from_config(cfg)
metadata_path = Path(cfg.save_path) / "sae_lens.json"
with open(metadata_path, "w") as f:
json.dump(metadata.__dict__, f, indent=2, ensure_ascii=False)

if cfg.hf_repo_id is not None:
push_to_hugging_face_hub(tokenized_dataset, cfg)

return tokenized_dataset
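End to end, the runner could be invoked like this sketch; the repo id is a placeholder, and pushing to the Hub assumes an authenticated Hugging Face token:

```python
from sae_lens import PretokenizeRunnerConfig, pretokenize_runner

# Tokenize the raw dataset into 128-token rows and push the result,
# along with the sae_lens.json metadata file, to a Hugging Face dataset repo.
cfg = PretokenizeRunnerConfig(
    tokenizer_name="gpt2",
    dataset_path="NeelNanda/c4-10k",
    split="train",
    context_size=128,
    hf_repo_id="your-username/c4-10k-tokenized-gpt2",  # placeholder repo id
)
tokenized_dataset = pretokenize_runner(cfg)
print(tokenized_dataset[0]["input_ids"].shape)  # torch.Size([128])
```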