
Commit

feat: validate that pretokenized dataset tokenizer matches model tokenizer (#215)

Co-authored-by: Joseph Bloom <[email protected]>
chanind and jbloomAus authored Jul 18, 2024
1 parent 22a0841 commit c73b811
Showing 2 changed files with 81 additions and 1 deletion.
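The check keys off a small sae_lens.json metadata file stored in the dataset repo, present when the dataset was pretokenized by SAE Lens (per the tests below). A minimal sketch of that contract, assuming only the tokenizer_name field the validator actually reads; any other fields would be hypothetical:

# Sketch of the sae_lens.json contract the validator relies on; only the
# "tokenizer_name" field is read, and the value here is illustrative.
import json

with open("sae_lens.json", "w") as f:
    json.dump({"tokenizer_name": "gpt2"}, f)

with open("sae_lens.json") as f:
    assert json.load(f)["tokenizer_name"] == "gpt2"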
42 changes: 42 additions & 0 deletions sae_lens/training/activations_store.py
@@ -1,17 +1,22 @@
from __future__ import annotations

import contextlib
import json
import os
from typing import Any, Generator, Iterator, Literal, cast

import numpy as np
import torch
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
from huggingface_hub import hf_hub_download
from huggingface_hub.utils._errors import HfHubHTTPError
from requests import HTTPError
from safetensors import safe_open
from safetensors.torch import save_file
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens.hook_points import HookedRootModule
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from sae_lens.config import (
DTYPE_MAP,
@@ -216,6 +221,15 @@ def __init__(
# TODO: investigate if this can work for iterable datasets, or if this is even worthwhile as a perf improvement
if hasattr(self.dataset, "set_format"):
self.dataset.set_format(type="torch", columns=[self.tokens_column]) # type: ignore

if (
isinstance(dataset, str)
and hasattr(model, "tokenizer")
and model.tokenizer is not None
):
validate_pretokenized_dataset_tokenizer(
dataset_path=dataset, model_tokenizer=model.tokenizer
)
else:
print(
"Warning: Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
@@ -639,3 +653,31 @@ def state_dict(self) -> dict[str, torch.Tensor]:

def save(self, file_path: str):
save_file(self.state_dict(), file_path)


def validate_pretokenized_dataset_tokenizer(
dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
) -> None:
"""
Helper to validate that the tokenizer used to pretokenize the dataset matches the model tokenizer.
"""
try:
tokenization_cfg_path = hf_hub_download(
dataset_path, "sae_lens.json", repo_type="dataset"
)
except HfHubHTTPError:
return
if tokenization_cfg_path is None:
return
with open(tokenization_cfg_path, "r") as f:
tokenization_cfg = json.load(f)
tokenizer_name = tokenization_cfg["tokenizer_name"]
try:
ds_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
# if we can't download the specified tokenizer to verify, just continue
except HTTPError:
return
if ds_tokenizer.get_vocab() != model_tokenizer.get_vocab():
raise ValueError(
f"Dataset tokenizer {tokenizer_name} does not match model tokenizer {model_tokenizer}."
)
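For reference, a minimal usage sketch of the new helper, assuming a model whose tokenizer is loaded; the dataset path mirrors the one used in the tests below:

# Usage sketch: raises ValueError on a vocab mismatch, and returns silently when
# the dataset carries no sae_lens.json metadata or the recorded tokenizer cannot
# be downloaded.
from transformer_lens import HookedTransformer

from sae_lens.training.activations_store import validate_pretokenized_dataset_tokenizer

model = HookedTransformer.from_pretrained("gpt2")
assert model.tokenizer is not None
validate_pretokenized_dataset_tokenizer(
    dataset_path="chanind/openwebtext-gpt2", model_tokenizer=model.tokenizer
)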
40 changes: 39 additions & 1 deletion tests/unit/training/test_activations_store.py
@@ -9,7 +9,10 @@

from sae_lens.config import LanguageModelSAERunnerConfig, PretokenizeRunnerConfig
from sae_lens.pretokenize_runner import pretokenize_dataset
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.activations_store import (
ActivationsStore,
validate_pretokenized_dataset_tokenizer,
)
from tests.unit.helpers import build_sae_cfg, load_model_cached


@@ -424,3 +427,38 @@ def test_activation_store__errors_if_neither_dataset_nor_dataset_path(

with pytest.raises(ValueError):
ActivationsStore.from_config(ts_model, cfg, override_dataset=None)


def test_validate_pretokenized_dataset_tokenizer_errors_if_the_tokenizer_doesnt_match_the_model():
ds_path = "chanind/openwebtext-gpt2"
model_tokenizer = HookedTransformer.from_pretrained("opt-125m").tokenizer
assert model_tokenizer is not None
with pytest.raises(ValueError):
validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer)


def test_validate_pretokenized_dataset_tokenizer_runs_successfully_if_tokenizers_match(
ts_model: HookedTransformer,
):
ds_path = "chanind/openwebtext-gpt2"
model_tokenizer = ts_model.tokenizer
assert model_tokenizer is not None
validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer)


def test_validate_pretokenized_dataset_tokenizer_does_nothing_if_the_dataset_is_not_created_by_sae_lens(
ts_model: HookedTransformer,
):
ds_path = "apollo-research/monology-pile-uncopyrighted-tokenizer-gpt2"
model_tokenizer = ts_model.tokenizer
assert model_tokenizer is not None
validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer)


def test_validate_pretokenized_dataset_tokenizer_does_nothing_if_the_dataset_path_doesnt_exist(
ts_model: HookedTransformer,
):
ds_path = "blah/nonsense-1234"
model_tokenizer = ts_model.tokenizer
assert model_tokenizer is not None
validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer)

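Note that the validator compares full vocabularies via get_vocab() rather than tokenizer names, so differently named checkpoints with identical vocabs still pass. A minimal sketch of that comparison, using the same tokenizer pairing as the mismatch test above:

# get_vocab() returns a token -> id dict, so dict equality compares both the
# token strings and their ids; the tokenizer ids mirror the tests above.
from transformers import AutoTokenizer

gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
opt_tok = AutoTokenizer.from_pretrained("facebook/opt-125m")

assert gpt2_tok.get_vocab() != opt_tok.get_vocab()  # would raise in the validator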