From 6870de6d4b8f41e2274e401e6486d8de909b8e5b Mon Sep 17 00:00:00 2001
From: Nouamane Tazi
Date: Mon, 18 Jul 2022 15:31:44 +0200
Subject: [PATCH 1/4] fix causal_mask

- avoid having both `-inf` and `dtype.min` in causal mask due to addition
---
 src/transformers/models/bloom/modeling_bloom.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 9bf200de5bd3..9b9c66f167ff 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -595,11 +595,13 @@ def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_ke
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            if combined_attention_mask is None:
+                combined_attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            else:
+                combined_attention_mask.masked_fill_(
+                    ~attention_mask[:, None, None, :].to(torch.bool).expand_as(combined_attention_mask),
+                    -torch.inf,
             )
-
         return combined_attention_mask
 
     def set_input_embeddings(self, new_embeddings):
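A minimal sketch of the rationale behind PATCH 1/4, assuming only stock PyTorch (the tensors and values below are illustrative, not taken from the model):

import torch

# Mixing the two sentinels is the problem: when the causal mask and the
# padding mask are *added*, two dtype.min entries overflow to -inf in
# float16, while a single dtype.min entry stays finite, so masked positions
# end up holding inconsistent values.
dtype = torch.float16
neg_min = torch.finfo(dtype).min  # -65504.0 for float16

both_masked = torch.tensor([neg_min], dtype=dtype) + torch.tensor([neg_min], dtype=dtype)
one_masked = torch.tensor([neg_min], dtype=dtype) + torch.tensor([0.0], dtype=dtype)
print(both_masked)  # tensor([-inf], dtype=torch.float16)  -> fp16 overflow
print(one_masked)   # tensor([-65504.], dtype=torch.float16)

# With masked_fill_(-inf), as in this patch, every masked position holds
# exactly -inf no matter how many masks cover it.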
""" batch_size, target_length = input_ids_shape - mask = torch.full((target_length, target_length), torch.finfo(dtype).min) + mask = torch.full((target_length, target_length), -torch.inf) mask_cond = torch.arange(mask.size(-1)) intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1) mask.masked_fill_(intermediate_mask, 0) @@ -79,7 +79,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), -torch.inf) def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor: @@ -303,7 +303,7 @@ def forward( # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length] input_dtype = attention_scores.dtype attn_weights = (attention_scores * self.layer_number) + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attn_weights = torch.clip(attn_weights, torch.finfo(attn_weights.dtype).min, torch.finfo(attn_weights.dtype).max) attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) attention_probs = attention_probs * (~attention_mask.bool()) # [batch_size, num_heads, q_length, k_length] @@ -601,7 +601,7 @@ def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_ke combined_attention_mask.masked_fill_( ~attention_mask[:, None, None, :].to(torch.bool).expand_as(combined_attention_mask), -torch.inf, - ) + ) return combined_attention_mask def set_input_embeddings(self, new_embeddings): From 68156643d82538a66eb4fd9b2605e4c717286a6c Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 18 Jul 2022 15:36:24 +0200 Subject: [PATCH 3/4] styling --- src/transformers/models/bloom/modeling_bloom.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 8ecd367b7016..abbcab5f2176 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -303,7 +303,9 @@ def forward( # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length] input_dtype = attention_scores.dtype attn_weights = (attention_scores * self.layer_number) + attention_mask - attn_weights = torch.clip(attn_weights, torch.finfo(attn_weights.dtype).min, torch.finfo(attn_weights.dtype).max) + attn_weights = torch.clip( + attn_weights, torch.finfo(attn_weights.dtype).min, torch.finfo(attn_weights.dtype).max + ) attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) attention_probs = attention_probs * (~attention_mask.bool()) # [batch_size, num_heads, q_length, k_length] From 93eaaff498b40a40a9f3a3bfee1a8acb9c5f43ff Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 18 Jul 2022 16:11:03 +0200 Subject: [PATCH 4/4] revert back to using addition for some speed - it's okay to use addition since we're using `-inf` again --- src/transformers/models/bloom/modeling_bloom.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index abbcab5f2176..a2de67c54b45 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ 
From 93eaaff498b40a40a9f3a3bfee1a8acb9c5f43ff Mon Sep 17 00:00:00 2001
From: Nouamane Tazi
Date: Mon, 18 Jul 2022 16:11:03 +0200
Subject: [PATCH 4/4] revert back to using addition for some speed

- it's okay to use addition since we're using `-inf` again
---
 src/transformers/models/bloom/modeling_bloom.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index abbcab5f2176..a2de67c54b45 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -597,13 +597,10 @@ def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_ke
 
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            if combined_attention_mask is None:
-                combined_attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
-            else:
-                combined_attention_mask.masked_fill_(
-                    ~attention_mask[:, None, None, :].to(torch.bool).expand_as(combined_attention_mask),
-                    -torch.inf,
-                )
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
         return combined_attention_mask
 
     def set_input_embeddings(self, new_embeddings):
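A minimal sketch of why PATCH 4/4 can go back to plain addition, assuming only stock PyTorch (illustrative masks, not the model's real ones):

import torch

# With -inf as the only masking value, summing the expanded padding mask and
# the causal mask is safe: -inf + 0 and -inf + -inf are both exactly -inf,
# so there is no fp16 overflow and no mix of different "very negative"
# sentinels, unlike the dtype.min case shown after the first patch.
dtype = torch.float16
neg_inf = float("-inf")

causal = torch.tensor([[0.0, neg_inf], [0.0, 0.0]], dtype=dtype)
padding = torch.tensor([[0.0, 0.0], [neg_inf, neg_inf]], dtype=dtype)

combined = causal + padding
print(combined)
# tensor([[0., -inf],
#         [-inf, -inf]], dtype=torch.float16)

The single elementwise addition is presumably also the "some speed" the commit title refers to, compared with building a boolean mask and calling masked_fill_.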