Merge pull request #76 from jbloomAus/faster-ci
perf: improving CI speed
jbloomAus authored Apr 11, 2024
2 parents 8784c74 + 392f982 commit 8b00000
Showing 12 changed files with 150 additions and 158 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -5,4 +5,4 @@ max-complexity = 25
 extend-select = E9, F63, F7, F82
 show-source = true
 statistics = true
-exclude = ./wandb/*, ./research/wandb/*
+exclude = ./wandb/*, ./research/wandb/*, .venv/*
34 changes: 30 additions & 4 deletions .github/workflows/build.yml
@@ -4,7 +4,6 @@ on:
   push:
     branches:
       - main
-      - clean_up_repo
   pull_request:
     branches:
       - main
@@ -25,14 +24,41 @@ jobs:
         python-version: ["3.10", "3.11"]

     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Cache Huggingface assets
+        uses: actions/cache@v4
+        with:
+          key: huggingface-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
+          path: ~/.cache/huggingface
+          restore-keys: |
+            huggingface-${{ runner.os }}-${{ matrix.python-version }}-
+      - name: Load cached Poetry installation
+        id: cached-poetry
+        uses: actions/cache@v4
+        with:
+          path: ~/.local # the path depends on the OS
+          key: poetry-${{ runner.os }}-${{ matrix.python-version }}-0 # increment to reset cache
       - name: Install Poetry
+        if: steps.cached-poetry.outputs.cache-hit != 'true'
         uses: snok/install-poetry@v1
+        with:
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+          installer-parallel: true
+      - name: Load cached venv
+        id: cached-poetry-dependencies
+        uses: actions/cache@v4
+        with:
+          path: .venv
+          key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
+          restore-keys: |
+            venv-${{ runner.os }}-${{ matrix.python-version }}-
       - name: Install dependencies
+        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
         run: poetry install --no-interaction
       - name: Lint with flake8
         run: poetry run flake8 .
@@ -49,7 +75,7 @@ jobs:
         uses: codecov/codecov-action@v4
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
-          slug: jbloomAus/mats_sae_training
+          slug: jbloomAus/SAELens

   release:
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ flake8 = "^7.0.0"
 isort = "^5.13.2"
 pyright = "^1.1.351"

+
 [tool.isort]
 profile = "black"

@@ -55,6 +56,7 @@ reportConstantRedefinition = "none"
 reportUnknownLambdaType = "none"
 reportPrivateUsage = "none"

+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
24 changes: 15 additions & 9 deletions sae_lens/training/activations_store.py
@@ -33,14 +33,15 @@ class ActivationsStore:
     cached_activations_path: str | None
     tokens_column: Literal["tokens", "input_ids", "text"]
     hook_point_head_index: int | None
+    _dataloader: Iterator[Any] | None = None
+    _storage_buffer: torch.Tensor | None = None

     @classmethod
     def from_config(
         cls,
         model: HookedTransformer,
         cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
         dataset: HfDataset | None = None,
-        create_dataloader: bool = True,
     ) -> "ActivationsStore":
         cached_activations_path = cfg.cached_activations_path
         # set cached_activations_path to None if we're not using cached activations
@@ -65,7 +66,6 @@ def from_config(
             device=cfg.device,
             dtype=cfg.dtype,
             cached_activations_path=cached_activations_path,
-            create_dataloader=create_dataloader,
         )

     def __init__(
@@ -85,7 +85,6 @@ def __init__(
         device: str | torch.device,
         dtype: str | torch.dtype,
         cached_activations_path: str | None = None,
-        create_dataloader: bool = True,
     ):
         self.model = model
         self.dataset = (
@@ -151,10 +150,17 @@ def __init__(

         # TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

-        if create_dataloader:
-            # fill buffer half a buffer, so we can mix it with a new buffer
-            self.storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2)
-            self.dataloader = self.get_data_loader()
+    @property
+    def storage_buffer(self) -> torch.Tensor:
+        if self._storage_buffer is None:
+            self._storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2)
+        return self._storage_buffer
+
+    @property
+    def dataloader(self) -> Iterator[Any]:
+        if self._dataloader is None:
+            self._dataloader = self.get_data_loader()
+        return self._dataloader

     def get_batch_tokens(self):
         """
@@ -363,7 +369,7 @@ def get_data_loader(
         mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]

         # 2. put 50 % in storage
-        self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
+        self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]

         # 3. put other 50 % in a dataloader
         dataloader = iter(
@@ -387,7 +393,7 @@ def next_batch(self):
             return next(self.dataloader)
         except StopIteration:
             # If the DataLoader is exhausted, create a new one
-            self.dataloader = self.get_data_loader()
+            self._dataloader = self.get_data_loader()
             return next(self.dataloader)

     def _get_next_dataset_tokens(self) -> torch.Tensor:
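Why this diff speeds up CI: ActivationsStore previously filled half a storage buffer and built a dataloader inside __init__, so every construction paid that cost even when the buffers were never read (as in cache_activations_runner below). Moving both behind lazy properties defers the work to first access and removes the need for the create_dataloader flag. A minimal sketch of the lazy-initialization pattern, with hypothetical names (ExpensiveStore and make_buffer are illustrative, not part of sae_lens):

from typing import Optional


def make_buffer() -> list[int]:
    print("building buffer...")  # stands in for the expensive get_buffer call
    return [0] * 1024


class ExpensiveStore:
    _buffer: Optional[list[int]] = None

    @property
    def buffer(self) -> list[int]:
        # built on first access, reused afterwards
        if self._buffer is None:
            self._buffer = make_buffer()
        return self._buffer


store = ExpensiveStore()          # cheap: no buffer built yet
assert len(store.buffer) == 1024  # prints "building buffer..." once
assert len(store.buffer) == 1024  # cached: no rebuild

Note that get_data_loader and next_batch now assign to self._storage_buffer and self._dataloader rather than going through the properties, since the properties have no setters.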
1 change: 0 additions & 1 deletion sae_lens/training/cache_activations_runner.py
@@ -16,7 +16,6 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
     activations_store = ActivationsStore.from_config(
         model,
         cfg,
-        create_dataloader=False,
     )

     # if the activations directory exists and has files in it, raise an exception
28 changes: 13 additions & 15 deletions sae_lens/training/sae_group.py
@@ -6,7 +6,7 @@
 import pickle
 from itertools import product
 from types import SimpleNamespace
-from typing import Any, Iterator
+from typing import Iterator

 import torch

@@ -71,11 +71,8 @@ def to(self, device: torch.device | str):
         for ae in self.autoencoders.values():
             ae.to(device)

-    # old pickled SAEs load as a dict
     @classmethod
-    def load_from_pretrained_legacy(
-        cls, path: str
-    ) -> "SparseAutoencoderDictionary" | dict[str, Any]:
+    def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary":
         """
         Load function for the model. Loads the model's state_dict and the config used to train it.
         This method can be called directly on the class, without needing an instance.
@@ -129,18 +126,19 @@ def load_from_pretrained_legacy(
                 f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
             )

-        return group
-        # # # Ensure the loaded state contains both 'cfg' and 'state_dict'
-        # # if "cfg" not in state_dict or "state_dict" not in state_dict:
-        # #     raise ValueError(
-        # #         "The loaded state dictionary must contain 'cfg' and 'state_dict' keys"
-        # #     )
+        # handle loading old autoencoders where before SAEGroup existed, where we just save a dict
+        if isinstance(group, dict):
+            cfg = group["cfg"]
+            sparse_autoencoder = SparseAutoencoder(cfg=cfg)
+            sparse_autoencoder.load_state_dict(group["state_dict"])
+            group = cls(cfg)
+            for key in group.autoencoders:
+                group.autoencoders[key] = sparse_autoencoder

-        # # Create an instance of the class using the loaded configuration
-        # instance = cls(cfg=state_dict["cfg"])
-        # instance.load_state_dict(state_dict["state_dict"])
+        if not isinstance(group, cls):
+            raise ValueError("The loaded object is not a valid SAEGroup")

-        # return instance
+        return group

     @classmethod
     def load_from_pretrained(
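The legacy loader now has a single return type: old checkpoints that were pickled as a bare {"cfg", "state_dict"} dict are converted into a SparseAutoencoderDictionary on load, and anything else raises, so callers no longer need isinstance checks. A hypothetical usage sketch (the checkpoint path is made up):

from sae_lens.training.sae_group import SparseAutoencoderDictionary

# Accepts .pt, .pkl, and .pkl.gz checkpoints, per the error message above.
group = SparseAutoencoderDictionary.load_from_pretrained_legacy(
    "checkpoints/old_run.pkl.gz"
)
for name, sae in group.autoencoders.items():  # always a dictionary of SAEs now
    print(name, type(sae).__name__)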
5 changes: 2 additions & 3 deletions tests/unit/conftest.py
@@ -1,9 +1,8 @@
 import pytest
-from transformer_lens import HookedTransformer

-from tests.unit.helpers import TINYSTORIES_MODEL
+from tests.unit.helpers import TINYSTORIES_MODEL, load_model_cached


 @pytest.fixture
 def ts_model():
-    return HookedTransformer.from_pretrained(TINYSTORIES_MODEL, device="cpu")
+    return load_model_cached(TINYSTORIES_MODEL)
22 changes: 19 additions & 3 deletions tests/unit/helpers.py
@@ -1,6 +1,7 @@
 from typing import Any

 import torch
+from transformer_lens import HookedTransformer

 from sae_lens.training.config import LanguageModelSAERunnerConfig

@@ -26,11 +27,11 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
         l1_coefficient=2e-3,
         lp_norm=1,
         lr=2e-4,
-        train_batch_size=2048,
-        context_size=64,
+        train_batch_size=4,
+        context_size=6,
         feature_sampling_window=50,
         dead_feature_threshold=1e-7,
-        n_batches_in_buffer=4,
+        n_batches_in_buffer=2,
         total_training_tokens=1_000_000,
         store_batch_size=4,
         log_to_wandb=False,
@@ -48,3 +49,18 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
         setattr(mock_config, key, val)

     return mock_config
+
+
+MODEL_CACHE: dict[str, HookedTransformer] = {}
+
+
+def load_model_cached(model_name: str) -> HookedTransformer:
+    """
+    helper to avoid unnecessarily loading the same model multiple times.
+    NOTE: if the model gets modified in tests this will not work.
+    """
+    if model_name not in MODEL_CACHE:
+        MODEL_CACHE[model_name] = HookedTransformer.from_pretrained(
+            model_name, device="cpu"
+        )
+    return MODEL_CACHE[model_name]
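Because MODEL_CACHE is module-level, every fixture or test in a session that asks for the same model shares one HookedTransformer instance instead of re-loading it from disk, which is where much of the unit-test speedup comes from (alongside the much smaller train_batch_size, context_size, and n_batches_in_buffer defaults above). A hypothetical test sketch (test_model_is_shared is a made-up example; the fixture mirrors tests/unit/conftest.py):

import pytest

from tests.unit.helpers import TINYSTORIES_MODEL, load_model_cached


@pytest.fixture
def ts_model():
    return load_model_cached(TINYSTORIES_MODEL)


def test_model_is_shared(ts_model):
    # The second lookup hits MODEL_CACHE, so the model loads from disk only once.
    assert load_model_cached(TINYSTORIES_MODEL) is ts_model

As the helper's docstring warns, a test that mutates the model would leak that state to every later test through the cache.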

