Skip to content

Commit

Permalink
still hadn't fixed the issue, now fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Feb 1, 2024
1 parent b4546db commit a36ee21
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sae_training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ def forward(self, x, dead_neuron_mask = None):
mse_rescaling_factor = (mse_loss / mse_loss_ghost_resid).detach()
mse_loss_ghost_resid = mse_rescaling_factor * mse_loss_ghost_resid

mse_loss_ghost_resid = mse_loss_ghost_resid.mean()
mse_loss = mse_loss.mean()
sparsity = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
l1_loss = self.l1_coefficient * sparsity
loss = mse_loss + l1_loss + mse_loss_ghost_resid.mean()
loss = mse_loss + l1_loss + mse_loss_ghost_resid

return sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid

Expand Down

0 comments on commit a36ee21

Please sign in to comment.