Fix attention mask handling in the Hybrid Engine Bloom flow (#5101)
The Bloom flow in Hybrid Engine applies the same transformation of the
input mask that is already performed earlier by the transformers
BloomModel::forward.

This results in non-convergence of scores, observed in DeepSpeed Chat
on different accelerators, including CUDA and HPU.

The fix removes the redundant mask transformation and application,
restoring correct convergence.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Lev Kurilenko <[email protected]>
3 people authored Mar 12, 2024
1 parent 2989cf7 commit d9e12d3
Showing 3 changed files with 11 additions and 2 deletions.
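
Note on the failure mode described above: Hugging Face's BloomModel.forward already converts the 0/1 padding mask into a "1 = masked" form before it reaches the attention block, so inverting it again inside the DeepSpeed kernel flips valid and padded positions. A minimal, hypothetical PyTorch sketch of that arithmetic (illustrative only; ignores alibi and the causal component):

import torch

minus_inf = -10000.0

# Hugging Face convention: 1 = real token, 0 = padding (last two tokens padded).
hf_mask = torch.tensor([[1, 1, 1, 0, 0]])

# What BloomModel.forward hands down is already inverted: 1 = masked.
bloom_mask = 1 - hf_mask

# Old Hybrid Engine path: invert once more before building the additive bias,
# which ends up penalizing the valid tokens instead of the padding.
wrong_bias = (1 - bloom_mask).float() * minus_inf    # [-1e4, -1e4, -1e4, 0, 0]

# Fixed path (invert_mask=False for Bloom): use the mask as-is,
# so only the padded positions receive the large negative bias.
right_bias = bloom_mask.float() * minus_inf          # [0, 0, 0, -1e4, -1e4]
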
2 changes: 2 additions & 0 deletions deepspeed/module_inject/containers/bloom.py
@@ -23,12 +23,14 @@ def __init__(self, **kwargs):

         # All model specific things should be defined here instead of the base class.
         self.bigscience_bloom = True
+        self.triangular_masking = False

     def create_module(self, config=None):
         _config = config if config is not None else self.ds_model_config

         self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group)
         self.module.config.scale_attention = self.scale_attention
+        self.module.config.invert_mask = False
         return self.module

     def attention_qkv_mp(self, mp_replace, reversed_dim=False):
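
Both flags are turned off for Bloom because the mask that BloomModel prepares already folds the causal triangle and the padding together (with 1 meaning "do not attend"), so the kernel should neither re-invert it nor re-apply triangular masking. A rough, hypothetical stand-in for that preparation step (not the exact transformers code):

import torch

def bloom_style_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    # Combine the causal triangle with the padding positions; 1 = masked.
    bsz, seq_len = padding_mask.shape
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    padding = (padding_mask == 0)[:, None, :].expand(bsz, seq_len, seq_len)
    return (causal[None, :, :] | padding).long()

pad = torch.tensor([[1, 1, 1, 0]])   # HF convention: 1 = real token
print(bloom_style_mask(pad)[0])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 1]])
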
5 changes: 4 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -43,6 +43,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
             return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
             bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
             use_triton: This flag is to enable triton kernels in inference or not.
+            invert_mask: If True, the attention mask is inverted when passed to attention block.
     """

     def __init__(self,
@@ -80,7 +81,8 @@ def __init__(self,
                  use_triton=False,
                  triton_autotune=False,
                  num_kv=-1,
-                 rope_theta=10000):
+                 rope_theta=10000,
+                 invert_mask=True):
         super(DeepSpeedInferenceConfig,
               self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
                              num_hidden_layers)
@@ -116,6 +118,7 @@ def __init__(self,
         self.triton_autotune = triton_autotune
         self.num_kv = num_kv
         self.rope_theta = rope_theta
+        self.invert_mask = invert_mask

     @classmethod
     def from_dict(cls, json_object):
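
Since invert_mask is an ordinary constructor argument stored on the config, usage is straightforward; a hedged sketch, assuming only the signature and defaults shown in this diff:

from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig

# Default keeps the pre-existing behaviour: the kernel inverts the mask itself.
default_cfg = DeepSpeedInferenceConfig(hidden_size=1024, heads=16)
assert default_cfg.invert_mask is True

# Bloom containers opt out, matching bloom.py above.
bloom_cfg = DeepSpeedInferenceConfig(hidden_size=1024, heads=16, invert_mask=False)
assert bloom_cfg.invert_mask is False
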
6 changes: 5 additions & 1 deletion deepspeed/ops/transformer/inference/ds_attention.py
@@ -254,8 +254,12 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
         if input_mask.dtype == torch.bool:
             input_mask = input_mask.long()

+        # Invert input_mask per transformer implementation (eg, in BLOOM, it's already inverted)
+        if self.config.invert_mask:
+            input_mask = 1 - input_mask
+
         attention_probs = self.softmax_func(attn_scores=attention_scores,
-                                            attn_mask=((1 - input_mask).to(target_dtype) * minus_inf),
+                                            attn_mask=input_mask.to(target_dtype) * minus_inf,
                                             alibi=alibi,
                                             triangular=(self.config.triangular_masking
                                                         and (attention_scores.shape[-2] > 1)),
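
The effect of the corrected attn_mask term can be checked in isolation; a simplified, hypothetical stand-in for what softmax_func does with the additive mask (no alibi, no triangular masking):

import torch

minus_inf = -10000.0
scores = torch.randn(1, 4, 4)               # toy (batch, query, key) attention scores
input_mask = torch.tensor([[0, 0, 0, 1]])   # Bloom-style, already inverted: 1 = masked

# With invert_mask=False the kernel adds mask * minus_inf directly.
bias = input_mask.to(scores.dtype) * minus_inf
probs = torch.softmax(scores + bias[:, None, :], dim=-1)

# Every query now assigns (numerically) zero probability to the padded key.
print(probs[0, :, -1])
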
