Always cast softmax inputs to float32 when in training mode.
While we don't strictly need this for accurate results in bfloat16/float16, this is a
safety precaution to make sure that training accuracy does not regress.

Signed-off-by: Daniel Galvez <[email protected]>
galv committed Jun 5, 2024
1 parent a6ffea9 commit 2a6f156
Showing 1 changed file with 4 additions and 1 deletion.
@@ -669,7 +669,10 @@ def _compute_out_global_to_all(
 global_attn_scores = global_attn_scores.view(batch_size * self.h, max_num_global_attn_indices, seq_len)

 # compute global attn probs
-global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1)
+if self.training:
+    global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1, dtype=torch.float32)
+else:
+    global_attn_probs_float = nn.functional.softmax(global_attn_scores, dim=-1)

 global_attn_probs = self.dropout(global_attn_probs_float)
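For reference, below is a minimal, self-contained sketch of the same pattern outside the changed file. The class name, dropout setup, and tensor shapes are illustrative assumptions rather than the repository's code; the only point carried over from the hunk is that the softmax input is upcast to float32 while training and left in the incoming dtype otherwise.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalAttnProbs(nn.Module):
    """Toy stand-in for the attention module touched by this commit (illustrative only)."""

    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, attn_scores: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Training mode: cast the softmax input to float32 so the
            # normalization runs at full precision even when the scores
            # arrive in bfloat16/float16.
            probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32)
        else:
            # Eval mode: compute softmax in the incoming dtype.
            probs = F.softmax(attn_scores, dim=-1)
        return self.dropout(probs)


if __name__ == "__main__":
    scores = torch.randn(2, 4, 8, dtype=torch.bfloat16)
    module = GlobalAttnProbs().train()
    print(module(scores).dtype)  # torch.float32: softmax upcast the probabilities
    module.eval()
    print(module(scores).dtype)  # torch.bfloat16: dtype unchanged in eval

As in the hunk above, the float32 probabilities are passed straight to dropout without an explicit downcast; whether and where they are cast back is outside this diff.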
