Skip to content

Commit

Permalink
fix: session loader wasn't working
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Apr 15, 2024
1 parent ac606a3 commit a928d7e
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions sae_lens/training/session_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.sae_group import SparseAutoencoderDictionary
from sae_lens.training.sparse_autoencoder import SparseAutoencoder


class LMSparseAutoencoderSessionloader:
Expand Down Expand Up @@ -46,17 +45,16 @@ def load_pretrained_sae(
"""

# load the SAE
sparse_autoencoder = SparseAutoencoder.load_from_pretrained(path)
sparse_autoencoder.to(device)
sparse_autoencoder.cfg.device = device
sparse_autoencoders = SparseAutoencoderDictionary.load_from_pretrained(
path, device
)
first_sparse_autoencoder_cfg = next(iter(sparse_autoencoders))[1].cfg

# load the model, SAE and activations loader with it.
session_loader = cls(sparse_autoencoder.cfg)
model, sae_group, activations_loader = (
session_loader.load_sae_training_group_session()
)
session_loader = cls(first_sparse_autoencoder_cfg)
model, _, activations_loader = session_loader.load_sae_training_group_session()

return model, sae_group, activations_loader
return model, sparse_autoencoders, activations_loader

def get_model(self, model_name: str) -> HookedTransformer:
"""
Expand Down

0 comments on commit a928d7e

Please sign in to comment.