feat: Replace print with controllable logging (#388)
* replaces in pretrained_sae_loaders.py

* replaces in load_model.py

* replaces in neuronpedia_integration.py

* replaces in tsea.py

* replaces in pretrained_saes.py

* replaces in cache_activations_runner.py

* replaces in activations_store.py

* replaces in training_sae.py

* replaces in upload_saes_to_huggingface.py

* replaces in sae_training_runner.py

* replaces in config.py

* fixes error for CI

---------

Co-authored-by: David Chanin <[email protected]>
anthonyduong9 and chanind authored Nov 30, 2024
1 parent 6a54dd6 commit 2bcd646
Showing 12 changed files with 79 additions and 60 deletions.
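
Every file follows the same pattern: import the shared package logger and route each former print call to the log level that matches its severity. A minimal, illustrative sketch of that pattern (not an exact hunk from this diff; the messages are shortened from ones that appear below):

from sae_lens import logger

logger.info("Started caching activations")                      # progress / status output
logger.warning("Ran out of samples while filling the buffer")   # recoverable problems
logger.error("Couldn't upload explanation to Neuronpedia")      # failures worth surfacing
logger.debug("Adjusting for d_model in embedding SAE")          # developer-level detail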
3 changes: 3 additions & 0 deletions sae_lens/__init__.py
@@ -1,5 +1,8 @@
__version__ = "5.0.0"

import logging

logger = logging.getLogger(__name__)

from .analysis.hooked_sae_transformer import HookedSAETransformer
from .cache_activations_runner import CacheActivationsRunner
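Because the package logger is created with logging.getLogger(__name__) at the top of sae_lens/__init__.py, downstream code can now dial the library's verbosity up or down through the standard logging module. A minimal sketch, assuming the application has no other logging configuration in place:

import logging

logging.basicConfig()  # attach a default handler if the application has none
logging.getLogger("sae_lens").setLevel(logging.WARNING)  # hide info/debug chatter
# logging.getLogger("sae_lens").setLevel(logging.DEBUG)  # or show everything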
50 changes: 27 additions & 23 deletions sae_lens/analysis/neuronpedia_integration.py
@@ -31,7 +31,7 @@
)
from tenacity import retry, stop_after_attempt, wait_random_exponential

from sae_lens import SAE
from sae_lens import SAE, logger

NEURONPEDIA_DOMAIN = "https://neuronpedia.org"

@@ -62,7 +62,7 @@ def NanAndInfReplacer(value: str):
def open_neuronpedia_feature_dashboard(sae: SAE, index: int):
sae_id = sae.cfg.neuronpedia_id
if sae_id is None:
print(
logger.warning(
"SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
)
else:
@@ -78,7 +78,7 @@ def get_neuronpedia_quick_list(

sae_id = sae.cfg.neuronpedia_id
if sae_id is None:
print(
logger.warning(
"SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
)
assert sae_id is not None
@@ -258,7 +258,7 @@ async def autointerp_neuronpedia_features(  # noqa: C901
Returns:
None
"""
print("\n\n")
logger.info("\n\n")

if os.getenv("OPENAI_API_KEY") is None:
if openai_api_key is None:
@@ -286,7 +286,7 @@ async def autointerp_neuronpedia_features(  # noqa: C901
if not skip_neuronpedia_api_key_test:
test_key(neuronpedia_api_key)

print("\n\n=== Step 1) Fetching features from Neuronpedia")
logger.info("\n\n=== Step 1) Fetching features from Neuronpedia")
for feature in features:
feature_data = get_neuronpedia_feature(
feature=feature.feature,
@@ -326,10 +326,10 @@ async def autointerp_neuronpedia_features(  # noqa: C901
for iteration_num, feature in enumerate(features):
start_time = datetime.now()

print(
logger.info(
f"\n========== Feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature} ({iteration_num + 1} of {len(features)} Features) =========="
)
print(
logger.info(
f"\n=== Step 2) Explaining feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
)

@@ -364,25 +364,27 @@ async def autointerp_neuronpedia_features(  # noqa: C901
num_samples=1,
)
except Exception as e:
print(f"ERROR, RETRYING: {e}")
logger.error(f"ERROR, RETRYING: {e}")
else:
break
else:
print(
logger.error(
f"ERROR: Failed to explain feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
)

assert len(explanations) == 1
explanation = explanations[0].rstrip(".")
print(f"===== {autointerp_explainer_model_name}'s explanation: {explanation}")
logger.info(
f"===== {autointerp_explainer_model_name}'s explanation: {explanation}"
)
feature.autointerp_explanation = explanation

scored_simulation = None
if do_score and autointerp_scorer_model_name:
print(
logger.info(
f"\n=== Step 3) Scoring feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}"
)
print("=== This can take up to 30 seconds.")
logger.info("=== This can take up to 30 seconds.")

temp_activation_records = [
ActivationRecord(
@@ -417,7 +419,7 @@ async def autointerp_neuronpedia_features(  # noqa: C901
)
score = scored_simulation.get_preferred_score()
except Exception as e:
print(f"ERROR, RETRYING: {e}")
logger.error(f"ERROR, RETRYING: {e}")
else:
break

@@ -427,15 +429,17 @@ async def autointerp_neuronpedia_features(  # noqa: C901
or len(scored_simulation.scored_sequence_simulations)
!= num_activations_to_use
):
print(
logger.error(
f"ERROR: Failed to score feature {feature.modelId}@{feature.layer}-{feature.dataset}:{feature.feature}. Skipping it."
)
continue
feature.autointerp_explanation_score = score
print(f"===== {autointerp_scorer_model_name}'s score: {(score * 100):.0f}")
logger.info(
f"===== {autointerp_scorer_model_name}'s score: {(score * 100):.0f}"
)

else:
print("=== Step 3) Skipping scoring as instructed.")
logger.info("=== Step 3) Skipping scoring as instructed.")

feature_data = {
"modelId": feature.modelId,
@@ -455,15 +459,15 @@ async def autointerp_neuronpedia_features(  # noqa: C901
if save_to_disk:
output_file = f"{output_dir}/{feature.modelId}-{feature.layer}-{feature.dataset}_feature-{feature.feature}_time-{datetime.now().strftime('%Y%m%d-%H%M%S')}.jsonl"
os.makedirs(output_dir, exist_ok=True)
print(f"\n=== Step 4) Saving feature to {output_file}")
logger.info(f"\n=== Step 4) Saving feature to {output_file}")
with open(output_file, "a") as f:
f.write(feature_data_str)
f.write("\n")
else:
print("\n=== Step 4) Skipping saving to disk.")
logger.info("\n=== Step 4) Skipping saving to disk.")

if upload_to_neuronpedia:
print("\n=== Step 5) Uploading feature to Neuronpedia")
logger.info("\n=== Step 5) Uploading feature to Neuronpedia")
upload_data = json.dumps(
{
"feature": feature_data,
Expand All @@ -476,15 +480,15 @@ async def autointerp_neuronpedia_features( # noqa: C901
url, json=upload_data_json, headers={"x-api-key": neuronpedia_api_key}
)
if response.status_code != 200:
print(
logger.error(
f"ERROR: Couldn't upload explanation to Neuronpedia: {response.text}"
)
else:
print(
logger.info(
f"===== Uploaded to Neuronpedia: {NEURONPEDIA_DOMAIN}/{feature.modelId}/{feature.layer}-{feature.dataset}/{feature.feature}"
)

end_time = datetime.now()
print(f"\n========== Time Spent for Feature: {end_time - start_time}\n")
logger.info(f"\n========== Time Spent for Feature: {end_time - start_time}\n")

print("\n\n========== Generation and Upload Complete ==========\n\n")
logger.info("\n\n========== Generation and Upload Complete ==========\n\n")
10 changes: 6 additions & 4 deletions sae_lens/analysis/tsea.py
@@ -11,6 +11,8 @@
from babe import UsNames
from transformer_lens import HookedTransformer

from sae_lens import logger


def get_enrichment_df(
projections: torch.Tensor,
@@ -180,10 +182,10 @@ def plot_top_k_feature_projections_by_token_and_category(

# scores = enrichment_scores[category][features]
scores = enrichment_scores[category].loc[features]
print(scores)
logger.debug(scores)
tokens_list = [model.to_single_str_token(i) for i in list(range(model.cfg.d_vocab))]

print(features)
logger.debug(features)
feature_logit_scores = pd.DataFrame(
dec_projection_onto_W_U[features].numpy(), index=features # type: ignore
).T
@@ -193,9 +195,9 @@
]

# display(feature_)
print(category)
logger.debug(category)
for feature, score in zip(features, scores): # type: ignore
print(feature)
logger.debug(feature)
score = -1 * np.log(1 - score) # convert to enrichment score
fig = px.histogram(
feature_logit_scores,
9 changes: 5 additions & 4 deletions sae_lens/cache_activations_runner.py
@@ -13,6 +13,7 @@
from tqdm import tqdm
from transformer_lens.HookedTransformer import HookedRootModule

from sae_lens import logger
from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
from sae_lens.load_model import load_model
from sae_lens.training.activations_store import ActivationsStore
@@ -241,7 +242,7 @@ def run(self) -> Dataset:

### Create temporary sharded datasets

print(f"Started caching activations for {self.cfg.dataset_path}")
logger.info(f"Started caching activations for {self.cfg.dataset_path}")

for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
try:
@@ -255,7 +256,7 @@
del buffer, shard

except StopIteration:
print(
logger.warning(
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
)
break
@@ -267,11 +268,11 @@
)

if self.cfg.shuffle:
print("Shuffling...")
logger.info("Shuffling...")
dataset = dataset.shuffle(seed=self.cfg.seed)

if self.cfg.hf_repo_id:
print("Pushing to Huggingface Hub...")
logger.info("Pushing to Huggingface Hub...")
dataset.push_to_hub(
repo_id=self.cfg.hf_repo_id,
num_shards=self.cfg.hf_num_shards,
26 changes: 14 additions & 12 deletions sae_lens/config.py
@@ -14,7 +14,7 @@
load_dataset,
)

from sae_lens import __version__
from sae_lens import __version__, logger

DTYPE_MAP = {
"float32": torch.float32,
@@ -342,7 +342,7 @@ def __post_init__(self):
self.checkpoint_path = f"{self.checkpoint_path}/{unique_id}"

if self.verbose:
print(
logger.info(
f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
)
# Print out some useful info:
@@ -351,43 +351,45 @@
* self.context_size
* self.n_batches_in_buffer
)
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 ** 6}")
logger.info(
f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 ** 6}"
)
n_contexts_per_buffer = (
self.store_batch_size_prompts * self.n_batches_in_buffer
)
print(
logger.info(
f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 ** 6}"
)

total_training_steps = (
self.training_tokens + self.finetuning_tokens
) // self.train_batch_size_tokens
print(f"Total training steps: {total_training_steps}")
logger.info(f"Total training steps: {total_training_steps}")

total_wandb_updates = total_training_steps // self.wandb_log_frequency
print(f"Total wandb updates: {total_wandb_updates}")
logger.info(f"Total wandb updates: {total_wandb_updates}")

# how many times will we sample dead neurons?
# assert self.dead_feature_window <= self.feature_sampling_window, "dead_feature_window must be smaller than feature_sampling_window"
n_feature_window_samples = (
total_training_steps // self.feature_sampling_window
)
print(
logger.info(
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
)
print(
logger.info(
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
)
print(
logger.info(
f"We will reset the sparsity calculation {n_feature_window_samples} times."
)
# print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size_tokens)
print(
# logger.info("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size_tokens)
logger.info(
f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
)

if self.use_ghost_grads:
print("Using Ghost Grads.")
logger.info("Using Ghost Grads.")

if self.context_size < 0:
raise ValueError(
10 changes: 6 additions & 4 deletions sae_lens/load_model.py
@@ -11,6 +11,8 @@
)
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from sae_lens import logger


def load_model(
model_class_name: str,
@@ -23,11 +25,11 @@ def load_model(
if "n_devices" in model_from_pretrained_kwargs:
n_devices = model_from_pretrained_kwargs["n_devices"]
if n_devices > 1:
print("MODEL LOADING:")
print("Setting model device to cuda for d_devices")
print(f"Will use cuda:0 to cuda:{n_devices-1}")
logger.info("MODEL LOADING:")
logger.info("Setting model device to cuda for d_devices")
logger.info(f"Will use cuda:0 to cuda:{n_devices-1}")
device = "cuda"
print("-------------")
logger.info("-------------")

if model_class_name == "HookedTransformer":
return HookedTransformer.from_pretrained_no_processing(
10 changes: 5 additions & 5 deletions sae_lens/sae_training_runner.py
@@ -1,5 +1,4 @@
import json
import logging
import os
import signal
import sys
@@ -11,6 +10,7 @@
from simple_parsing import ArgumentParser
from transformer_lens.hook_points import HookedRootModule

from sae_lens import logger
from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
from sae_lens.load_model import load_model
from sae_lens.sae import SAE_CFG_PATH, SAE_WEIGHTS_PATH, SPARSITY_PATH
@@ -46,11 +46,11 @@ def __init__(
override_sae: TrainingSAE | None = None,
):
if override_dataset is not None:
logging.warning(
logger.warning(
f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
)
if override_model is not None:
logging.warning(
logger.warning(
f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
)

@@ -156,10 +156,10 @@ def run_trainer_with_interruption_handling(self, trainer: SAETrainer):
sae = trainer.fit()

except (KeyboardInterrupt, InterruptedException):
print("interrupted, saving progress")
logger.warning("interrupted, saving progress")
checkpoint_name = trainer.n_training_tokens
self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
print("done saving")
logger.info("done saving")
raise

return sae
3 changes: 2 additions & 1 deletion sae_lens/toolkit/pretrained_sae_loaders.py
@@ -10,6 +10,7 @@
from safetensors import safe_open
from safetensors.torch import load_file

from sae_lens import logger
from sae_lens.config import DTYPE_MAP
from sae_lens.toolkit.pretrained_saes_directory import (
PretrainedSAELookup,
@@ -401,7 +402,7 @@ def gemma_2_sae_loader(

# if it is an embedding SAE, then we need to adjust for the scale of d_model because of how they trained it
if "embedding" in folder_name:
print("Adjusting for d_model in embedding SAE")
logger.debug("Adjusting for d_model in embedding SAE")
state_dict["W_enc"].data = state_dict["W_enc"].data / np.sqrt(cfg_dict["d_in"])
state_dict["W_dec"].data = state_dict["W_dec"].data * np.sqrt(cfg_dict["d_in"])
