Fix gated forward functions (#295)
* support seqpos slicing

* fix forward functions for gated

* remove seqpos changes

* fix formatting (remove my changes)

* format

---------

Co-authored-by: jbloomAus <[email protected]>
callummcdougall and jbloomAus authored Sep 20, 2024
1 parent 7e7d02e commit a708220
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions sae_lens/sae.py
@@ -427,6 +427,13 @@ def forward(
             with torch.no_grad():
                 x = x.to(self.dtype)
                 sae_in = self.reshape_fn_in(x)  # type: ignore
+
+                # handle run time activation normalization if needed
+                sae_in = self.run_time_activation_norm_fn_in(sae_in)
+
+                # apply b_dec_to_input if using that method.
+                sae_in = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)
+
                 gating_pre_activation = sae_in @ self.W_enc + self.b_gate
                 active_features = (gating_pre_activation > 0).float()
@@ -455,10 +462,10 @@ def forward(
                 sae_in = self.reshape_fn_in(x)  # type: ignore
 
                 # handle run time activation normalization if needed
-                x = self.run_time_activation_norm_fn_in(x)
+                sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
                 # apply b_dec_to_input if using that method.
-                sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input)
+                sae_in = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)
 
                 # "... d_in, d_in d_sae -> ... d_sae",
                 hidden_pre = sae_in @ self.W_enc + self.b_enc
@@ -495,11 +502,11 @@ def encode_gated(
         magnitude_pre_activation = self.hook_sae_acts_pre(
             sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
         )
-        feature_magnitudes = self.hook_sae_acts_post(
-            self.activation_fn(magnitude_pre_activation)
-        )
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
 
-        return active_features * feature_magnitudes
+        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+        return feature_acts
 
     def encode_jumprelu(
         self, x: Float[torch.Tensor, "... d_in"]
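For reference, a minimal sketch of the gated encode path as it stands after this fix. This is not the library's actual method: it assumes an object `sae` exposing the attributes used in the diff (dtype, run_time_activation_norm_fn_in, b_dec, cfg.apply_b_dec_to_input, W_enc, b_gate, r_mag, b_mag, activation_fn), and it omits the hook wrappers and input reshaping.

import torch

def encode_gated_sketch(sae, x: torch.Tensor) -> torch.Tensor:
    # Sketch only: cast and preprocess the input the way the fixed code does.
    sae_in = x.to(sae.dtype)
    sae_in = sae.run_time_activation_norm_fn_in(sae_in)
    sae_in = sae_in - (sae.b_dec * sae.cfg.apply_b_dec_to_input)

    # Gating path: a binary mask of which features fire.
    gating_pre_activation = sae_in @ sae.W_enc + sae.b_gate
    active_features = (gating_pre_activation > 0).float()

    # Magnitude path: how strongly each active feature fires.
    magnitude_pre_activation = sae_in @ (sae.W_enc * sae.r_mag.exp()) + sae.b_mag
    feature_magnitudes = sae.activation_fn(magnitude_pre_activation)

    # Gated feature activations are the mask times the magnitudes.
    return active_features * feature_magnitudes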
