Skip to content

Commit

Permalink
Fix SAE failing to upload to wandb due to artifact name. (#224)
Browse files Browse the repository at this point in the history
* Fix SAE artifact name.

* format

---------

Co-authored-by: Joseph Bloom <[email protected]>
  • Loading branch information
robertzk and jbloomAus authored Jul 10, 2024
1 parent 4e2eb94 commit 6ae4849
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sae_lens/sae_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,12 @@ def save_checkpoint(
save_file(log_feature_sparsities, log_feature_sparsity_path)

if trainer.cfg.log_to_wandb and os.path.exists(log_feature_sparsity_path):
# Avoid wandb saving errors such as:
# ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
sae_name = self.sae.get_name().replace("/", "__")

model_artifact = wandb.Artifact(
f"{self.sae.get_name()}",
sae_name,
type="model",
metadata=dict(trainer.cfg.__dict__),
)
Expand All @@ -219,7 +223,7 @@ def save_checkpoint(
wandb.log_artifact(model_artifact, aliases=wandb_aliases)

sparsity_artifact = wandb.Artifact(
f"{self.sae.get_name()}_log_feature_sparsity",
f"{sae_name}_log_feature_sparsity",
type="log_feature_sparsity",
metadata=dict(trainer.cfg.__dict__),
)
Expand Down

0 comments on commit 6ae4849

Please sign in to comment.