diff --git a/sae_lens/sae.py b/sae_lens/sae.py
index 55cd03ed..8998d82f 100644
--- a/sae_lens/sae.py
+++ b/sae_lens/sae.py
@@ -433,6 +433,13 @@ def forward(
             with torch.no_grad():
                 x = x.to(self.dtype)
                 sae_in = self.reshape_fn_in(x)  # type: ignore
+
+                # handle run time activation normalization if needed
+                sae_in = self.run_time_activation_norm_fn_in(sae_in)
+
+                # apply b_dec_to_input if using that method.
+                sae_in = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)
+
                 gating_pre_activation = sae_in @ self.W_enc + self.b_gate
                 active_features = (gating_pre_activation > 0).float()
 
@@ -461,10 +468,10 @@ def forward(
                 sae_in = self.reshape_fn_in(x)  # type: ignore
 
                 # handle run time activation normalization if needed
-                x = self.run_time_activation_norm_fn_in(x)
+                sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
                 # apply b_dec_to_input if using that method.
-                sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input)
+                sae_in = sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)
 
                 # "... d_in, d_in d_sae -> ... d_sae",
                 hidden_pre = sae_in @ self.W_enc + self.b_enc
@@ -501,11 +508,11 @@ def encode_gated(
         magnitude_pre_activation = self.hook_sae_acts_pre(
             sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
         )
-        feature_magnitudes = self.hook_sae_acts_post(
-            self.activation_fn(magnitude_pre_activation)
-        )
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
 
-        return active_features * feature_magnitudes
+        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+        return feature_acts
 
     def encode_jumprelu(
         self, x: Float[torch.Tensor, "... d_in"]
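
A minimal standalone sketch of the corrected gated encode path, for reference alongside the diff. This is not the sae_lens implementation: the parameter names (W_enc, b_gate, r_mag, b_mag, b_dec) mirror the diff, torch.relu stands in for self.activation_fn, the hook wrappers are omitted, and the input is assumed to already be reshaped and run-time normalized as the diff now does before this point.

import torch


def encode_gated_sketch(
    sae_in: torch.Tensor,      # input already reshaped and run-time normalized
    W_enc: torch.Tensor,
    b_gate: torch.Tensor,
    r_mag: torch.Tensor,
    b_mag: torch.Tensor,
    b_dec: torch.Tensor,
    apply_b_dec_to_input: bool,
) -> torch.Tensor:
    # center the input by the decoder bias only when the config requests it
    sae_in = sae_in - (b_dec * float(apply_b_dec_to_input))

    # gating path: decides which features are active
    gating_pre_activation = sae_in @ W_enc + b_gate
    active_features = (gating_pre_activation > 0).float()

    # magnitude path: shares W_enc, rescaled elementwise by exp(r_mag)
    magnitude_pre_activation = sae_in @ (W_enc * r_mag.exp()) + b_mag
    feature_magnitudes = torch.relu(magnitude_pre_activation)

    # after the change above, hook_sae_acts_post observes this final product,
    # not just the magnitude path
    return active_features * feature_magnitudes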