Fix hook z training reshape bug (#165)
* remove file duplicate

* fix: hook-z evals working, and reshaping mode more explicit
jbloomAus authored May 29, 2024
1 parent 21ac24d commit 0550ae3
Showing 6 changed files with 110 additions and 468 deletions.
21 changes: 19 additions & 2 deletions sae_lens/evals.py
@@ -18,11 +18,20 @@ def run_evals(
eval_batch_size_prompts: int | None = None,
model_kwargs: Mapping[str, Any] = {},
) -> Mapping[str, Any]:

hook_name = sae.cfg.hook_name
hook_head_index = sae.cfg.hook_head_index
### Evals
eval_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)

+    # TODO: Come up with a cleaner long term strategy here for SAEs that do reshaping.
+    # turn off hook_z reshaping mode if it's on, and restore it after evals
+    if "hook_z" in hook_name:
+        previous_hook_z_reshaping_mode = sae.hook_z_reshaping_mode
+        sae.turn_off_forward_pass_hook_z_reshaping()
+    else:
+        previous_hook_z_reshaping_mode = None

# Get Reconstruction Score
losses_df = recons_loss_batched(
sae,
@@ -47,7 +56,7 @@ def run_evals(

# we would include hook z, except that we now have base SAE's
# which will do their own reshaping for hook z.
has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v"]
has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"]
if hook_head_index is not None:
original_act = cache[hook_name][:, :, hook_head_index]
elif any(substring in hook_name for substring in has_head_dim_key_substrings):
@@ -81,6 +90,13 @@ def run_evals(
"metrics/ce_loss_with_ablation": zero_abl_loss,
}

+    # restore previous hook z reshaping mode if necessary
+    if "hook_z" in hook_name:
+        if previous_hook_z_reshaping_mode and not sae.hook_z_reshaping_mode:
+            sae.turn_on_forward_pass_hook_z_reshaping()
+        elif not previous_hook_z_reshaping_mode and sae.hook_z_reshaping_mode:
+            sae.turn_off_forward_pass_hook_z_reshaping()

return metrics


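The explicit save/restore of hook_z_reshaping_mode above is one way to handle the TODO; another would be a small context manager. The following is only an illustrative sketch, not part of the commit, and it assumes an SAE object exposing sae.cfg.hook_name, the hook_z_reshaping_mode flag, and the two toggle methods introduced here:

from contextlib import contextmanager

@contextmanager
def hook_z_reshaping_disabled(sae):
    # Temporarily disable hook_z reshaping, then restore whatever mode was active before.
    # Sketch only: relies on the turn_on/turn_off methods added in this commit.
    previous_mode = getattr(sae, "hook_z_reshaping_mode", False)
    if "hook_z" in sae.cfg.hook_name and previous_mode:
        sae.turn_off_forward_pass_hook_z_reshaping()
    try:
        yield sae
    finally:
        if "hook_z" in sae.cfg.hook_name and previous_mode and not sae.hook_z_reshaping_mode:
            sae.turn_on_forward_pass_hook_z_reshaping()

With something like this, run_evals could wrap its reconstruction and downstream-loss sections in a single with-block instead of tracking the previous mode by hand.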
@@ -160,6 +176,7 @@ def all_head_replacement_hook(activations: torch.Tensor, hook: Any):
# Unscale if activations were scaled prior to going into the SAE
if activation_store.normalize_activations:
new_activations = activation_store.unscale(new_activations)

return new_activations

def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
@@ -183,7 +200,7 @@ def zero_ablate_hook(activations: torch.Tensor, hook: Any):

# we would include hook z, except that we now have base SAE's
# which will do their own reshaping for hook z.
has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v"]
has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"]
if any(substring in hook_name for substring in has_head_dim_key_substrings):
if head_index is None:
replacement_hook = all_head_replacement_hook
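For context on the has_head_dim_key_substrings change: hook_z activations leave the model with a separate head dimension, shape (..., n_heads, d_head), while the SAE operates on the flattened d_in = n_heads * d_head. With the SAE's own reshaping switched off during evals, the eval code must flatten the head dimension itself, which is why "hook_z" rejoins the substring list. A minimal sketch of that flattening, with hypothetical shapes:

import einops
import torch

# Hypothetical hook_z activation: (batch, seq, n_heads, d_head) = (2, 16, 8, 64).
z = torch.randn(2, 16, 8, 64)

# Flatten the head dimension so the activation matches an SAE with d_in = 8 * 64 = 512.
flat = einops.rearrange(z, "... n_heads d_head -> ... (n_heads d_head)")
assert flat.shape == (2, 16, 512)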
47 changes: 30 additions & 17 deletions sae_lens/sae.py
@@ -138,24 +138,11 @@ def __init__(
# the z activations for hook_z SAEs. but don't know d_head if we split up the forward pass
# into a separate encode and decode function.
# this will cause errors if we call decode before encode.
-        self.reshape_fn_in = lambda x: x
-        self.reshape_fn_out = lambda x, d_head: x
-        self.d_head = None
        if self.cfg.hook_name.endswith("_z"):
-
-            def reshape_fn_in(x: torch.Tensor):
-                self.d_head = x.shape[-1]  # type: ignore
-                self.reshape_fn_in = lambda x: einops.rearrange(
-                    x, "... n_heads d_head -> ... (n_heads d_head)"
-                )
-                return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
-
-            self.reshape_fn_in = reshape_fn_in
-
-        if self.cfg.hook_name.endswith("_z"):
-            self.reshape_fn_out = lambda x, d_head: einops.rearrange(
-                x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
-            )
+            self.turn_on_forward_pass_hook_z_reshaping()
+        else:
+            # need to default the reshape fns
+            self.turn_off_forward_pass_hook_z_reshaping()

self.setup() # Required for `HookedRootModule`s

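The refactored __init__ above now simply picks one of the two toggle methods (defined further down in this diff), so a freshly constructed hook_z SAE starts with reshaping enabled. A sketch of the resulting behaviour, assuming sae is an already-constructed SAE whose hook_name ends in "_z" (construction omitted here):

# Assumes `sae` is a hook_z SAE instance; how it was built is not shown.
assert sae.cfg.hook_name.endswith("_z")
assert sae.hook_z_reshaping_mode          # reshaping is on straight after __init__

sae.turn_off_forward_pass_hook_z_reshaping()
assert not sae.hook_z_reshaping_mode      # identity reshape fns, d_head reset
assert sae.d_head is None

sae.turn_on_forward_pass_hook_z_reshaping()
assert sae.hook_z_reshaping_mode          # d_head is re-captured on the next forward pass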
@@ -404,6 +391,32 @@ def get_name(self):
def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
return cls(SAEConfig.from_dict(config_dict))

+    def turn_on_forward_pass_hook_z_reshaping(self):
+
+        assert self.cfg.hook_name.endswith(
+            "_z"
+        ), "This method should only be called for hook_z SAEs."
+
+        def reshape_fn_in(x: torch.Tensor):
+            self.d_head = x.shape[-1]  # type: ignore
+            self.reshape_fn_in = lambda x: einops.rearrange(
+                x, "... n_heads d_head -> ... (n_heads d_head)"
+            )
+            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
+
+        self.reshape_fn_in = reshape_fn_in
+
+        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
+            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
+        )
+        self.hook_z_reshaping_mode = True
+
+    def turn_off_forward_pass_hook_z_reshaping(self):
+        self.reshape_fn_in = lambda x: x
+        self.reshape_fn_out = lambda x, d_head: x
+        self.d_head = None
+        self.hook_z_reshaping_mode = False
+

def get_activation_fn(activation_fn: str) -> Callable[[torch.Tensor], torch.Tensor]:
if activation_fn == "relu":
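The lazy trick in turn_on_forward_pass_hook_z_reshaping is that the first call to reshape_fn_in records d_head from the incoming activation and then replaces itself with a plain flattening lambda. A sketch of the round trip, again with hypothetical shapes and assuming sae is a hook_z SAE with reshaping turned on:

import torch

# Hypothetical activation with n_heads = 8, d_head = 64, so the SAE's d_in would be 512.
z = torch.randn(2, 4, 8, 64)

flat = sae.reshape_fn_in(z)               # first call: records sae.d_head = 64
assert flat.shape == (2, 4, 512)
assert sae.d_head == 64

restored = sae.reshape_fn_out(flat, sae.d_head)
assert restored.shape == (2, 4, 8, 64)    # back to (..., n_heads, d_head)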
