Skip to content

Commit

Permalink
Added use_error_term to hooked sae transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
curt-tigges committed Jul 16, 2024
1 parent 8d38d96 commit d172e79
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion sae_lens/analysis/hooked_sae_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
super().__init__(*model_args, **model_kwargs)
self.acts_to_saes: Dict[str, SAE] = {}

def add_sae(self, sae: SAE):
def add_sae(self, sae: SAE, use_error_term: Optional[bool] = None):
"""Attaches an SAE to the model
WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
Expand All @@ -90,6 +90,9 @@ def add_sae(self, sae: SAE):
)
return

if use_error_term is not None:
sae.use_error_term = True

self.acts_to_saes[act_name] = sae
set_deep_attr(self, act_name, sae)
self.setup()
Expand Down

0 comments on commit d172e79

Please sign in to comment.