Merge pull request huggingface#4 from microsoft/raviskolli/ort
Remove data-based dependencies in T5 for ORT
raviskolli authored Apr 14, 2021
2 parents 61b2ef2 + b947293 commit bce3290
Showing 3 changed files with 37 additions and 9 deletions.
1 change: 1 addition & 0 deletions examples/seq2seq/run_seq2seq.py
@@ -341,6 +341,7 @@ def main():
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
ort=True if training_args.ort else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
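Usage note (not part of the diff): extra keyword arguments passed to from_pretrained that match an existing config attribute override its default, which is how the new flag reaches the model. A minimal sketch, assuming this patch is applied; "t5-small" stands in for model_args.model_name_or_path, and training_args.ort comes from the fork's extended TrainingArguments:

    from transformers import AutoConfig

    # "t5-small" is only a placeholder checkpoint name for illustration.
    config = AutoConfig.from_pretrained("t5-small", ort=True)
    print(config.ort)  # True -- the kwarg overrides the default set in PretrainedConfig.__init__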
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
@@ -175,6 +175,10 @@ class PretrainedConfig(object):
- **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use
BFloat16 scalars (only used by some TensorFlow models).
Onnxruntime specific parameters
- **ort** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should use ORT.
"""
model_type: str = ""
is_composition: bool = False
@@ -186,6 +190,7 @@ def __init__(self, **kwargs):
self.output_attentions = kwargs.pop("output_attentions", False)
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.ort = kwargs.pop("ort", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
self.tie_word_embeddings = kwargs.pop(
"tie_word_embeddings", True
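Because the attribute is popped with a default of False, existing configurations and serialized config.json files are unaffected unless ort is set explicitly. A minimal sketch, assuming this patch is applied:

    from transformers import T5Config

    default_config = T5Config()
    print(default_config.ort)  # False -- kwargs.pop("ort", False)

    ort_config = T5Config(ort=True)
    print(ort_config.ort)      # True -- later read by T5Block as config.ort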
40 changes: 31 additions & 9 deletions src/transformers/models/t5/modeling_t5.py
@@ -591,6 +591,7 @@ class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.ort = config.ort
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
@@ -643,9 +644,16 @@ def forward(
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
    # Remove data-based control flow for static graph
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
@@ -670,9 +678,16 @@ def forward(
hidden_states = cross_attention_outputs[0]

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
    # Remove data-based control flow for static graph
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

# Combine self attn and cross attn key value states
if present_key_value_state is not None:
@@ -685,9 +700,16 @@ def forward(
hidden_states = self.layer[-1](hidden_states)

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
if self.ort:
    # Remove data-based control flow for static graph
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
else:
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

outputs = (hidden_states,)

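For context (an illustrative sketch, not part of the commit): a Python-level branch on torch.isinf(x).any() is evaluated eagerly when the model is traced or exported, so the resulting static graph contains only whichever branch the example input happened to take. torch.where keeps the inf check inside the graph, which is why the ORT path above avoids the data-dependent if:

    import torch

    def clamp_dynamic(x):
        # Data-dependent Python branch: under tracing/export, .any() is evaluated
        # on the example input and only that branch ends up in the graph.
        if torch.isinf(x).any():
            clamp_value = torch.finfo(x.dtype).max - 1000
            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
        return x

    def clamp_static(x):
        # torch.where keeps the decision in the graph, so the same exported
        # model handles both finite and overflowing inputs.
        clamp_value = torch.where(
            torch.isinf(x).any(),
            torch.finfo(x.dtype).max - 1000,
            torch.finfo(x.dtype).max,
        )
        return torch.clamp(x, min=-clamp_value, max=clamp_value)

    x = torch.tensor([1.0, float("inf")], dtype=torch.float16)
    print(clamp_dynamic(x))
    print(clamp_static(x))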
