Skip to content

Commit

Permalink
Ensured that even detatched SAEs are returned to former state
Browse files Browse the repository at this point in the history
  • Loading branch information
curt-tigges committed Jul 16, 2024
1 parent 1531c1f commit 90ac661
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions sae_lens/analysis/hooked_sae_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def add_sae(self, sae: SAE, use_error_term: Optional[bool] = None):
Args:
sae: SparseAutoencoderBase. The SAE to attach to the model
use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
"""
act_name = sae.cfg.hook_name
if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
Expand Down Expand Up @@ -176,7 +177,9 @@ def run_with_saes(
use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
**model_kwargs: Keyword arguments for the model forward pass
"""
with self.saes(saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term):
with self.saes(
saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
):
return self(*model_args, **model_kwargs)

def run_with_cache_with_saes(
Expand Down Expand Up @@ -206,13 +209,16 @@ def run_with_cache_with_saes(
*model_args: Positional arguments for the model forward pass
saes: (Union[HookedSAE, List[HookedSAE]]) The SAEs to be attached for this forward pass
reset_saes_end: (bool) If True, all SAEs added during this run are removed at the end, and previously attached SAEs are restored to their original state. Default is True.
use_error_term: (Optional[bool]) If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
return_cache_object: (bool) if True, this will return an ActivationCache object, with a bunch of
useful HookedTransformer specific methods, otherwise it will return a dictionary of
activations as in HookedRootModule.
remove_batch_dim: (bool) Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
**kwargs: Keyword arguments for the model forward pass
"""
with self.saes(saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term):
with self.saes(
saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
):
return self.run_with_cache( # type: ignore
*model_args,
return_cache_object=return_cache_object, # type: ignore
Expand Down Expand Up @@ -284,39 +290,27 @@ def saes(
Args:
saes (Union[HookedSAE, List[HookedSAE]]): SAEs to be attached.
reset_saes_end (bool): If True, removes all SAEs added by this context manager when the context manager exits, returning previously attached SAEs to their original state.
use_error_term (Optional[bool]): If provided, will set the use_error_term attribute of all SAEs attached during this run to this value. Defaults to None.
"""
# act_names_to_reset = []
# prev_saes = []
# if isinstance(saes, SAE):
# saes = [saes]
# try:
# for sae in saes:
# act_names_to_reset.append(sae.cfg.hook_name)
# prev_saes.append(self.acts_to_saes.get(sae.cfg.hook_name, None))
# self.add_sae(sae, use_error_term=use_error_term)
# yield self
# finally:
# if reset_saes_end:
# self.reset_saes(act_names_to_reset, prev_saes)

act_names_to_reset = []
prev_saes = []
original_use_error_terms = {}
modified_saes = {}
if isinstance(saes, SAE):
saes = [saes]
try:
for sae in saes:
act_names_to_reset.append(sae.cfg.hook_name)
prev_sae = self.acts_to_saes.get(sae.cfg.hook_name, None)
prev_saes.append(prev_sae)
if prev_sae and use_error_term is not None:
original_use_error_terms[sae.cfg.hook_name] = prev_sae.use_error_term
prev_sae.use_error_term = use_error_term
if use_error_term is not None:
if hasattr(sae, "use_error_term"):
modified_saes[sae] = sae.use_error_term
sae.use_error_term = use_error_term

Check warning on line 308 in sae_lens/analysis/hooked_sae_transformer.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/analysis/hooked_sae_transformer.py#L307-L308

Added lines #L307 - L308 were not covered by tests
self.add_sae(sae, use_error_term=use_error_term)
yield self
finally:
if reset_saes_end:
self.reset_saes(act_names_to_reset, prev_saes)
for hook_name, original_value in original_use_error_terms.items():
if hook_name in self.acts_to_saes:
self.acts_to_saes[hook_name].use_error_term = original_value
# Restore original use_error_term for all modified SAEs
for sae, original_value in modified_saes.items():
sae.use_error_term = original_value

Check warning on line 316 in sae_lens/analysis/hooked_sae_transformer.py

View check run for this annotation

Codecov / codecov/patch

sae_lens/analysis/hooked_sae_transformer.py#L316

Added line #L316 was not covered by tests

0 comments on commit 90ac661

Please sign in to comment.