diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index fff8ace2..5856f7b1 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -207,8 +207,12 @@ def save_checkpoint( save_file(log_feature_sparsities, log_feature_sparsity_path) if trainer.cfg.log_to_wandb and os.path.exists(log_feature_sparsity_path): + # Avoid wandb saving errors such as: + # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc + sae_name = self.sae.get_name().replace("/", "__") + model_artifact = wandb.Artifact( - f"{self.sae.get_name()}", + sae_name, type="model", metadata=dict(trainer.cfg.__dict__), ) @@ -219,7 +223,7 @@ def save_checkpoint( wandb.log_artifact(model_artifact, aliases=wandb_aliases) sparsity_artifact = wandb.Artifact( - f"{self.sae.get_name()}_log_feature_sparsity", + f"{sae_name}_log_feature_sparsity", type="log_feature_sparsity", metadata=dict(trainer.cfg.__dict__), )