Change unpacking of TF inputs: layoutlm, mpnet, rag, and roformer (#16112)

Co-authored-by: ChienVM <[email protected]>
vumichien and ChienVM authored Mar 15, 2022
1 parent 0d7322c commit 611d3a0
Showing 4 changed files with 223 additions and 626 deletions.
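The pattern applied throughout this commit: previously every TF `call` method began with a call to the `input_processing` helper and then read everything back out of the returned dict (`inputs["input_ids"]`, `inputs["return_dict"]`, and so on). The `@unpack_inputs` decorator hoists that normalization out of the method body, so `call` can refer to its arguments by name. The sketch below only illustrates the idea; it is not the actual implementation in `transformers.modeling_tf_utils`, which additionally handles dict/tuple-style Keras inputs and config-driven defaults.

```python
import functools
import inspect


def unpack_inputs_sketch(call_fn):
    """Illustrative stand-in for `@unpack_inputs` (simplified; see assumptions above)."""

    @functools.wraps(call_fn)
    def wrapper(self, *args, **kwargs):
        # Bind whatever the caller passed against the signature of `call`,
        # filling in declared defaults, so every parameter has a value.
        signature = inspect.signature(call_fn)
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()

        unpacked = {}
        for name, value in bound.arguments.items():
            if name == "self":
                continue
            if signature.parameters[name].kind is inspect.Parameter.VAR_KEYWORD:
                # Flatten a trailing **kwargs back into plain keyword arguments.
                unpacked.update(value)
            else:
                unpacked[name] = value

        # The real decorator would also resolve config-driven defaults here
        # (e.g. fall back to self.config.return_dict when return_dict is None).
        return call_fn(self, **unpacked)

    return wrapper
```

With a decorator like this in place, every `inputs["..."]` lookup in the method bodies below can be replaced by the plain argument name, which is the mechanical change this diff repeats across the four model files.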
174 changes: 45 additions & 129 deletions src/transformers/models/layoutlm/modeling_tf_layoutlm.py
@@ -37,8 +37,8 @@
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
- input_processing,
keras_serializable,
+ unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
@@ -691,6 +691,7 @@ class PreTrainedModel
"""
raise NotImplementedError

+ @unpack_inputs
def call(
self,
input_ids: Optional[TFModelInputType] = None,
@@ -708,55 +709,39 @@ def call(
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
- inputs = input_processing(
- func=self.call,
- config=self.config,
- input_ids=input_ids,
- bbox=bbox,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- training=training,
- kwargs_call=kwargs,
- )

- if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+ if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif inputs["input_ids"] is not None:
- input_shape = shape_list(inputs["input_ids"])
- elif inputs["inputs_embeds"] is not None:
- input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+ elif input_ids is not None:
+ input_shape = shape_list(input_ids)
+ elif inputs_embeds is not None:
+ input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

- if inputs["attention_mask"] is None:
- inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
+ if attention_mask is None:
+ attention_mask = tf.fill(dims=input_shape, value=1)

- if inputs["token_type_ids"] is None:
- inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
- if inputs["bbox"] is None:
- inputs["bbox"] = tf.fill(dims=input_shape + [4], value=0)
+ if token_type_ids is None:
+ token_type_ids = tf.fill(dims=input_shape, value=0)
+ if bbox is None:
+ bbox = tf.fill(dims=input_shape + [4], value=0)

embedding_output = self.embeddings(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- position_ids=inputs["position_ids"],
- token_type_ids=inputs["token_type_ids"],
- inputs_embeds=inputs["inputs_embeds"],
- training=inputs["training"],
+ input_ids=input_ids,
+ bbox=bbox,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ training=training,
)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
+ extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -773,30 +758,30 @@ def call(
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
- if inputs["head_mask"] is not None:
+ if head_mask is not None:
raise NotImplementedError
else:
- inputs["head_mask"] = [None] * self.config.num_hidden_layers
+ head_mask = [None] * self.config.num_hidden_layers

encoder_outputs = self.encoder(
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
- head_mask=inputs["head_mask"],
+ head_mask=head_mask,
# Need to pass these required positional arguments to `Encoder`
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
)

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

- if not inputs["return_dict"]:
+ if not return_dict:
return (
sequence_output,
pooled_output,
@@ -924,6 +909,7 @@ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):

self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")

+ @unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC
@@ -979,36 +965,18 @@ def call(
>>> last_hidden_states = outputs.last_hidden_state
```"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
+ outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
- kwargs_call=kwargs,
- )
- outputs = self.layoutlm(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
- position_ids=inputs["position_ids"],
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
)

return outputs
@@ -1064,6 +1032,7 @@ def get_prefix_bias_name(self) -> str:
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

+ @unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1127,9 +1096,7 @@ def call(
>>> loss = outputs.loss
```"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
+ outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
@@ -1140,32 +1107,13 @@ def call(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
- labels=labels,
training=training,
- kwargs_call=kwargs,
- )
- outputs = self.layoutlm(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
- position_ids=inputs["position_ids"],
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
)
sequence_output = outputs[0]
- prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
- loss = (
- None
- if inputs["labels"] is None
- else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
- )
+ prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

- if not inputs["return_dict"]:
+ if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

@@ -1208,6 +1156,7 @@ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
name="classifier",
)

+ @unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1271,9 +1220,7 @@ def call(
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
+ outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
@@ -1284,29 +1231,14 @@ def call(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
- labels=labels,
training=training,
- kwargs_call=kwargs,
- )
- outputs = self.layoutlm(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
- position_ids=inputs["position_ids"],
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
)
pooled_output = outputs[1]
- pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
+ pooled_output = self.dropout(inputs=pooled_output, training=training)
logits = self.classifier(inputs=pooled_output)
- loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

- if not inputs["return_dict"]:
+ if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

@@ -1355,6 +1287,7 @@ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
name="classifier",
)

+ @unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFTokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1416,9 +1349,7 @@ def call(
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
- inputs = input_processing(
- func=self.call,
- config=self.config,
+ outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
@@ -1429,29 +1360,14 @@ def call(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
- labels=labels,
training=training,
- kwargs_call=kwargs,
- )
- outputs = self.layoutlm(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- attention_mask=inputs["attention_mask"],
- token_type_ids=inputs["token_type_ids"],
- position_ids=inputs["position_ids"],
- head_mask=inputs["head_mask"],
- inputs_embeds=inputs["inputs_embeds"],
- output_attentions=inputs["output_attentions"],
- output_hidden_states=inputs["output_hidden_states"],
- return_dict=inputs["return_dict"],
- training=inputs["training"],
)
sequence_output = outputs[0]
- sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
+ sequence_output = self.dropout(inputs=sequence_output, training=training)
logits = self.classifier(inputs=sequence_output)
- loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

- if not inputs["return_dict"]:
+ if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

[Diffs for the other three changed files (mpnet, rag, and roformer) are not shown here.]
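The same mechanical transformation is applied in those files: add the decorator, delete the `input_processing` block, and drop the dict lookups. As a usage example of the sketch decorator from the block near the top (toy layer and argument list are hypothetical, not an excerpt from any of the changed files), a `call` body can then use its arguments directly:

```python
import tensorflow as tf


class ToyLayer(tf.keras.layers.Layer):
    # After the refactor: no input_processing block, arguments used directly.
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, return_dict=True, training=False):
        if attention_mask is None:
            attention_mask = tf.ones_like(input_ids)
        hidden = tf.cast(input_ids, tf.float32) * tf.cast(attention_mask, tf.float32)
        return {"last_hidden_state": hidden} if return_dict else (hidden,)


layer = ToyLayer()
out = layer.call(input_ids=tf.constant([[1, 2, 3]]))
print(out["last_hidden_state"])  # tf.Tensor([[1. 2. 3.]], shape=(1, 3), dtype=float32)
```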
