Skip to content

Commit

Permalink
feat: allow models to be passed in as overrides (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lewington-pitsos authored Jul 4, 2024
1 parent 1db84b5 commit dd95996
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions sae_lens/sae_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,28 @@ def __init__(
self,
cfg: LanguageModelSAERunnerConfig,
override_dataset: HfDataset | None = None,
override_model: HookedRootModule | None = None,
):
if override_dataset is not None:
logging.warning(
f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducable via configuration alone."
f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
)
if override_model is not None:
logging.warning(
f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
)

self.cfg = cfg

self.model = load_model(
self.cfg.model_class_name,
self.cfg.model_name,
device=self.cfg.device,
model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
)
if override_model is None:
self.model = load_model(
self.cfg.model_class_name,
self.cfg.model_name,
device=self.cfg.device,
model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
)
else:
self.model = override_model

self.activations_store = ActivationsStore.from_config(
self.model,
Expand Down

0 comments on commit dd95996

Please sign in to comment.