Skip to content

Commit

Permalink
Merge branch 'main' of github.com:jbloomAus/mats_sae_training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucyfarnik committed Jan 26, 2024
2 parents 74d4fb8 + 4c5fed8 commit a22d856
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sae_training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import numpy as np
import plotly_express as px
import torch
import wandb
from torch.optim import Adam
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.optim import get_scheduler
from sae_training.sparse_autoencoder import SparseAutoencoder
Expand All @@ -29,6 +29,8 @@ def train_sae_on_language_model(
wandb_log_frequency: int = 50,
):

if feature_sampling_method is not None:
feature_sampling_method = feature_sampling_method.lower()

total_training_tokens = sparse_autoencoder.cfg.total_training_tokens
total_training_steps = total_training_tokens // batch_size
Expand Down Expand Up @@ -64,7 +66,7 @@ def train_sae_on_language_model(
sparse_autoencoder.set_decoder_norm_to_unit_norm()


if (feature_sampling_method.lower()=="anthropic") and ((n_training_steps + 1) % dead_feature_window == 0):
if (feature_sampling_method=="anthropic") and ((n_training_steps + 1) % dead_feature_window == 0):

feature_sparsity = act_freq_scores / n_frac_active_tokens

Expand Down

0 comments on commit a22d856

Please sign in to comment.