diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 9bf200de5bd3..a2de67c54b45 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -56,7 +56,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
     Make causal mask used for bi-directional self-attention.
     """
     batch_size, target_length = input_ids_shape
-    mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
+    mask = torch.full((target_length, target_length), -torch.inf)
     mask_cond = torch.arange(mask.size(-1))
     intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
     mask.masked_fill_(intermediate_mask, 0)
@@ -79,7 +79,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):

     inverted_mask = 1.0 - expanded_mask

-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), -torch.inf)


 def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
@@ -303,7 +303,9 @@ def forward(
         # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length]
         input_dtype = attention_scores.dtype
         attn_weights = (attention_scores * self.layer_number) + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attn_weights = torch.clip(
+            attn_weights, torch.finfo(attn_weights.dtype).min, torch.finfo(attn_weights.dtype).max
+        )
         attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
         attention_probs = attention_probs * (~attention_mask.bool())  # [batch_size, num_heads, q_length, k_length]
@@ -599,7 +601,6 @@ def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_ke
         combined_attention_mask = (
             expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
         )
-        return combined_attention_mask

     def set_input_embeddings(self, new_embeddings):
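
The patch builds the additive attention masks with `-torch.inf` instead of `torch.finfo(dtype).min` and replaces the one-sided `torch.max` clamp with a `torch.clip` over the full finite range before the softmax. The snippet below is a minimal, standalone sketch of that masking path, not the model's own code; the tensor shapes and names (`q_len`, `scores`) are made up for illustration.

```python
import torch
from torch import nn

q_len = k_len = 4
scores = torch.randn(1, 1, q_len, k_len)  # [batch_size, num_heads, q_length, k_length]

# Additive causal mask: 0 on and below the diagonal, -inf above it.
mask = torch.full((q_len, k_len), -torch.inf).triu(diagonal=1)

attn_weights = scores + mask  # masked positions become -inf
# Clip to the finite range of the dtype, mirroring the torch.clip replacement in the patch.
attn_weights = torch.clip(
    attn_weights, torch.finfo(attn_weights.dtype).min, torch.finfo(attn_weights.dtype).max
)
attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
print(attention_probs[0, 0])  # masked positions come out as (numerically) zero
```

Clipping maps the `-inf` entries back to `torch.finfo(dtype).min`, so after the float32 softmax the masked positions still underflow to zero probability, the same numerical behavior the earlier `torch.max` lower clamp provided.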