Change unpacking of TF layoutlm inputs to use decorator #16112

Merged: 3 commits, Mar 15, 2022
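The gist of the change: every `call` method in `modeling_tf_layoutlm.py` previously rebuilt its arguments through `input_processing(...)` and then read them back out of an `inputs` dict; with this PR the methods are wrapped in the shared `@unpack_inputs` decorator (imported in the first hunk below) and use their parameters directly. The sketch below only illustrates the idea behind such a decorator; the name `unpack_inputs_sketch`, the toy class, and the simplified merging logic are mine, not the `transformers` implementation.

```python
import functools
import inspect


def unpack_inputs_sketch(call):
    """Illustrative stand-in for `unpack_inputs`: merge whatever was passed
    positionally or by keyword into one mapping keyed by parameter name, then
    re-invoke `call` with plain keyword arguments, so the body can use
    `input_ids` directly instead of `inputs["input_ids"]`."""

    @functools.wraps(call)
    def wrapper(self, *args, **kwargs):
        # The real decorator delegates to `input_processing`, which also accepts
        # a single dict/tuple of inputs (Keras style) and applies config-driven
        # defaults; this toy version only merges args and kwargs by name.
        params = list(inspect.signature(call).parameters)[1:]  # drop `self`
        merged = dict(zip(params, args))
        merged.update(kwargs)
        return call(self, **merged)

    return wrapper


class ToyLayer:
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, training=False):
        # Arguments arrive under their plain names, no `inputs` dict needed.
        return input_ids, attention_mask, training


print(ToyLayer().call([101, 102], attention_mask=[1, 1]))  # ([101, 102], [1, 1], False)
```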
174 changes: 45 additions & 129 deletions src/transformers/models/layoutlm/modeling_tf_layoutlm.py
@@ -37,8 +37,8 @@
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
@@ -691,6 +691,7 @@ class PreTrainedModel
"""
raise NotImplementedError

@unpack_inputs
def call(
self,
input_ids: Optional[TFModelInputType] = None,
@@ -708,55 +709,39 @@ def call(
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)

if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif inputs["input_ids"] is not None:
input_shape = shape_list(inputs["input_ids"])
elif inputs["inputs_embeds"] is not None:
input_shape = shape_list(inputs["inputs_embeds"])[:-1]
elif input_ids is not None:
input_shape = shape_list(input_ids)
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if inputs["attention_mask"] is None:
inputs["attention_mask"] = tf.fill(dims=input_shape, value=1)
if attention_mask is None:
attention_mask = tf.fill(dims=input_shape, value=1)

if inputs["token_type_ids"] is None:
inputs["token_type_ids"] = tf.fill(dims=input_shape, value=0)
if inputs["bbox"] is None:
inputs["bbox"] = tf.fill(dims=input_shape + [4], value=0)
if token_type_ids is None:
token_type_ids = tf.fill(dims=input_shape, value=0)
if bbox is None:
bbox = tf.fill(dims=input_shape + [4], value=0)

embedding_output = self.embeddings(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
position_ids=inputs["position_ids"],
token_type_ids=inputs["token_type_ids"],
inputs_embeds=inputs["inputs_embeds"],
training=inputs["training"],
input_ids=input_ids,
bbox=bbox,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
training=training,
)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
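The hunk above is the core of the change to `TFLayoutLMMainLayer.call`: the same default-filling and mask handling as before, but on plain arguments instead of `inputs[...]` lookups. Here is a small standalone walk-through of what those defaults and the mask reshape produce; the shapes are invented for the example, and the additive-mask step only approximates the collapsed lines that follow this hunk.

```python
import tensorflow as tf

input_shape = [2, 8]                                  # batch_size, seq_length (made up)
attention_mask = tf.fill(dims=input_shape, value=1)   # default: attend to every position
token_type_ids = tf.fill(dims=input_shape, value=0)   # default: single segment
bbox = tf.fill(dims=input_shape + [4], value=0)       # default: one (0, 0, 0, 0) box per token

# 2D mask -> broadcastable [batch_size, 1, 1, seq_length] mask for the attention scores
extended_attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))

# Roughly what the collapsed lines below do: 0.0 where we attend, a large
# negative number where we do not, so it can simply be added to the scores.
additive_mask = (1.0 - tf.cast(extended_attention_mask, tf.float32)) * -10000.0
print(additive_mask.shape)  # (2, 1, 1, 8)
```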
@@ -773,30 +758,30 @@ def call(
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None:
if head_mask is not None:
raise NotImplementedError
else:
inputs["head_mask"] = [None] * self.config.num_hidden_layers
head_mask = [None] * self.config.num_hidden_layers

encoder_outputs = self.encoder(
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"],
head_mask=head_mask,
# Need to pass these required positional arguments to `Encoder`
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)

sequence_output = encoder_outputs[0]
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

if not inputs["return_dict"]:
if not return_dict:
return (
sequence_output,
pooled_output,
@@ -924,6 +909,7 @@ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):

self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")

@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC
@@ -979,36 +965,18 @@ def call(

>>> last_hidden_states = outputs.last_hidden_state
```"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)

return outputs
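With `@unpack_inputs` on `TFLayoutLMModel.call`, the public calling conventions should be unchanged: keyword arguments and the Keras-style dict input are both normalized before the body runs. A quick sketch of both styles; it downloads the `microsoft/layoutlm-base-uncased` checkpoint, and the dummy tensors are mine.

```python
import tensorflow as tf
from transformers import TFLayoutLMModel

model = TFLayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")

input_ids = tf.constant([[101, 7592, 2088, 102]])     # dummy token ids
bbox = tf.zeros((1, 4, 4), dtype=tf.int32)            # dummy word boxes (0-1000 scale)

outputs_kw = model(input_ids=input_ids, bbox=bbox)             # keyword style
outputs_dict = model({"input_ids": input_ids, "bbox": bbox})   # Keras dict style
print(outputs_kw.last_hidden_state.shape)                      # (1, 4, 768)
```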
@@ -1064,6 +1032,7 @@ def get_prefix_bias_name(self) -> str:
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1127,9 +1096,7 @@ def call(

>>> loss = outputs.loss
```"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
@@ -1140,32 +1107,13 @@ def call(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output=sequence_output, training=inputs["training"])
loss = (
None
if inputs["labels"] is None
else self.hf_compute_loss(labels=inputs["labels"], logits=prediction_scores)
)
prediction_scores = self.mlm(sequence_output=sequence_output, training=training)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=prediction_scores)

if not inputs["return_dict"]:
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
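The masked-LM head itself is untouched by this hunk; only the argument plumbing changed. For readers unfamiliar with the shared masked-LM loss helper behind `hf_compute_loss`, it computes roughly the following, where positions labelled -100 are ignored. This is a standalone approximation with made-up shapes, not the library code.

```python
import tensorflow as tf

vocab_size = 30522
prediction_scores = tf.random.normal((1, 4, vocab_size))   # [batch, seq_len, vocab]
labels = tf.constant([[-100, 2023, -100, 3185]])           # only two supervised positions

active = tf.not_equal(tf.reshape(labels, (-1,)), -100)
flat_logits = tf.boolean_mask(tf.reshape(prediction_scores, (-1, vocab_size)), active)
flat_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(flat_labels, flat_logits)
print(float(loss))
```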

@@ -1208,6 +1156,7 @@ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
name="classifier",
)

@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1271,9 +1220,7 @@ def call(
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
@@ -1284,29 +1231,14 @@ def call(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
pooled_output = outputs[1]
pooled_output = self.dropout(inputs=pooled_output, training=inputs["training"])
pooled_output = self.dropout(inputs=pooled_output, training=training)
logits = self.classifier(inputs=pooled_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
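Again only the unpacking changed here; the sequence-classification head is still dropout over the pooled output, a dense classifier, and sparse cross-entropy when `labels` are given, and the token-classification variant further down applies the same pattern per token to `sequence_output`. A standalone sketch with invented sizes, not the model's exact layers:

```python
import tensorflow as tf

batch_size, hidden_size, num_labels = 4, 768, 2
pooled_output = tf.random.normal((batch_size, hidden_size))

dropout = tf.keras.layers.Dropout(0.1)
classifier = tf.keras.layers.Dense(num_labels)

logits = classifier(dropout(pooled_output, training=True))   # [batch_size, num_labels]
labels = tf.constant([0, 1, 1, 0])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, logits)
print(logits.shape, float(loss))
```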

@@ -1355,6 +1287,7 @@ def __init__(self, config: LayoutLMConfig, *inputs, **kwargs):
name="classifier",
)

@unpack_inputs
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFTokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
@@ -1416,9 +1349,7 @@ def call(
>>> loss = outputs.loss
>>> logits = outputs.logits
```"""
inputs = input_processing(
func=self.call,
config=self.config,
outputs = self.layoutlm(
input_ids=input_ids,
bbox=bbox,
attention_mask=attention_mask,
@@ -1429,29 +1360,14 @@ def call(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
outputs = self.layoutlm(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
attention_mask=inputs["attention_mask"],
token_type_ids=inputs["token_type_ids"],
position_ids=inputs["position_ids"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
sequence_output = self.dropout(inputs=sequence_output, training=training)
logits = self.classifier(inputs=sequence_output)
loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

if not inputs["return_dict"]:
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
