Adding support for multiple mask tokens. (#14716)
* Adding support for multiple mask tokens.

- Original implementation: #10222

Co-authored-by: njafer <[email protected]>

* In order to accommodate optionally multimodal models like Perceiver,

we add type information to the tasks so that, for tasks where we know for sure,
we can specify whether the tokenizer/feature_extractor is needed.

* Adding info in the documentation about multi masks.

+ marked as experimental.

* Add a copy() to prevent overwriting the same tensor over and over.

* Fixup.

* Adding a small test for multiple masks with real values.

Co-authored-by: njafer <[email protected]>
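
For context, a minimal sketch of the new behavior from the user side (the output
structure matches the tests below; the prompt and scores here are only an example):

from transformers import pipeline

unmasker = pipeline("fill-mask", model="distilroberta-base")

# Single mask: unchanged behavior, a flat list of top_k proposals.
print(unmasker("Paris is the <mask> of France.", top_k=2))

# Multiple masks (experimental): one list of proposals per mask token,
# each scored independently (disjoint, not joint, probabilities).
for proposals in unmasker("Paris is the <mask> of <mask>.", top_k=2):
    print([(p["token_str"], round(p["score"], 4)) for p in proposals])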
Narsil and njafer authored Dec 14, 2021
1 parent 2a606f9 commit e7ed7ff
Showing 3 changed files with 116 additions and 36 deletions.
31 changes: 30 additions & 1 deletion src/transformers/pipelines/__init__.py
@@ -125,18 +125,21 @@
         "tf": (),
         "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
+        "type": "audio",
     },
     "automatic-speech-recognition": {
         "impl": AutomaticSpeechRecognitionPipeline,
         "tf": (),
         "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
+        "type": "multimodal",
     },
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
         "tf": (TFAutoModel,) if is_tf_available() else (),
         "pt": (AutoModel,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
+        "type": "multimodal",
     },
     "text-classification": {
         "impl": TextClassificationPipeline,
@@ -148,6 +151,7 @@
                 "tf": "distilbert-base-uncased-finetuned-sst-2-english",
             },
         },
+        "type": "text",
     },
     "token-classification": {
         "impl": TokenClassificationPipeline,
@@ -159,6 +163,7 @@
                 "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
             },
         },
+        "type": "text",
     },
     "question-answering": {
         "impl": QuestionAnsweringPipeline,
@@ -167,6 +172,7 @@
         "default": {
             "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
         },
+        "type": "text",
     },
     "table-question-answering": {
         "impl": TableQuestionAnsweringPipeline,
@@ -179,18 +185,21 @@
                 "tf": "google/tapas-base-finetuned-wtq",
             },
         },
+        "type": "text",
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
         "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
+        "type": "text",
     },
     "summarization": {
         "impl": SummarizationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
+        "type": "text",
     },
     # This task is a special case as it's parametrized by SRC, TGT languages.
     "translation": {
@@ -202,18 +211,21 @@
             ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
             ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
         },
+        "type": "text",
     },
     "text2text-generation": {
         "impl": Text2TextGenerationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
+        "type": "text",
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
         "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
         "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
+        "type": "text",
     },
     "zero-shot-classification": {
         "impl": ZeroShotClassificationPipeline,
@@ -224,33 +236,48 @@
             "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
             "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
         },
+        "type": "text",
     },
     "conversational": {
         "impl": ConversationalPipeline,
         "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
         "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
+        "type": "text",
     },
     "image-classification": {
         "impl": ImageClassificationPipeline,
         "tf": (),
         "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
+        "type": "image",
     },
     "image-segmentation": {
         "impl": ImageSegmentationPipeline,
         "tf": (),
         "pt": (AutoModelForImageSegmentation,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
+        "type": "image",
     },
     "object-detection": {
         "impl": ObjectDetectionPipeline,
         "tf": (),
         "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+        "type": "image",
     },
 }
 
+NO_FEATURE_EXTRACTOR_TASKS = set()
+NO_TOKENIZER_TASKS = set()
+for task, values in SUPPORTED_TASKS.items():
+    if values["type"] == "text":
+        NO_FEATURE_EXTRACTOR_TASKS.add(task)
+    elif values["type"] in {"audio", "image"}:
+        NO_TOKENIZER_TASKS.add(task)
+    elif values["type"] != "multimodal":
+        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
+
 
 def get_supported_tasks() -> List[str]:
     """
@@ -528,12 +555,14 @@ def pipeline(
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
 
-    if task in {"audio-classification", "image-classification"}:
+    if task in NO_TOKENIZER_TASKS:
         # These will never require a tokenizer.
         # the model on the other hand might have a tokenizer, but
         # the files could be missing from the hub, instead of failing
         # on such repos, we just force to not load it.
         load_tokenizer = False
+    if task in NO_FEATURE_EXTRACTOR_TASKS:
+        load_feature_extractor = False
 
     if load_tokenizer:
         # Try to infer tokenizer from model or config name (if provided as str)
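
The effect of this gating can be seen in a small self-contained sketch (toy task
table and a hypothetical loaders_for helper for illustration; this is not the
library's full pipeline() logic):

# Toy reproduction of the gating above, with an abbreviated task table.
SUPPORTED_TASKS = {
    "fill-mask": {"type": "text"},
    "image-classification": {"type": "image"},
    "automatic-speech-recognition": {"type": "multimodal"},
}

NO_FEATURE_EXTRACTOR_TASKS = {t for t, v in SUPPORTED_TASKS.items() if v["type"] == "text"}
NO_TOKENIZER_TASKS = {t for t, v in SUPPORTED_TASKS.items() if v["type"] in {"audio", "image"}}

def loaders_for(task, load_tokenizer=True, load_feature_extractor=True):
    # Text tasks never need a feature extractor; audio/image tasks never need a
    # tokenizer, even if the model repo happens to ship (or lack) those files.
    if task in NO_TOKENIZER_TASKS:
        load_tokenizer = False
    if task in NO_FEATURE_EXTRACTOR_TASKS:
        load_feature_extractor = False
    return load_tokenizer, load_feature_extractor

assert loaders_for("fill-mask") == (True, False)
assert loaders_for("image-classification") == (False, True)
assert loaders_for("automatic-speech-recognition") == (True, True)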
67 changes: 35 additions & 32 deletions src/transformers/pipelines/fill_mask.py
@@ -44,7 +44,9 @@ class FillMaskPipeline(Pipeline):
     .. note::
 
-        This pipeline only works for inputs with exactly one token masked.
+        This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
+        masks. The returned values are raw model output, and correspond to disjoint probabilities where one might
+        expect joint probabilities (See `discussion <https://github.com/huggingface/transformers/pull/10222>`__).
     """
 
     def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
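
To make the disjoint-vs-joint caveat concrete, consider a toy example with
invented per-mask distributions (illustration only, not real model output):

# Invented top-2 distributions for the two masks in "My <mask> is <mask>."
mask1 = {"name": 0.6, "dog": 0.4}
mask2 = {"Sam": 0.5, "happy": 0.5}

# The pipeline reports each mask's scores independently, as above. A joint model
# could instead rank pairs, where e.g. P("dog", "happy") might beat any pair
# containing "name"; the independent (disjoint) output cannot express that.
best_independent = (max(mask1, key=mask1.get), max(mask2, key=mask2.get))
print(best_independent)  # ('name', 'Sam'), chosen per mask, not per pair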
@@ -59,13 +61,7 @@ def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
     def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
         masked_index = self.get_masked_index(input_ids)
         numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
@@ -98,46 +94,53 @@ def postprocess(self, model_outputs, top_k=5, target_ids=None):
             top_k = target_ids.shape[0]
         input_ids = model_outputs["input_ids"][0]
         outputs = model_outputs["logits"]
-        result = []
 
         if self.framework == "tf":
-            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]
 
-            # Fill mask pipeline supports only one ${mask_token} per sample
             outputs = outputs.numpy()
 
-            logits = outputs[0, masked_index.item(), :]
-            probs = tf.nn.softmax(logits)
+            logits = outputs[0, masked_index, :]
+            probs = tf.nn.softmax(logits, axis=-1)
             if target_ids is not None:
-                probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+                probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
+                probs = tf.expand_dims(probs, 0)
 
             topk = tf.math.top_k(probs, k=top_k)
             values, predictions = topk.values.numpy(), topk.indices.numpy()
         else:
-            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
             # Fill mask pipeline supports only one ${mask_token} per sample
 
-            logits = outputs[0, masked_index.item(), :]
-            probs = logits.softmax(dim=0)
+            logits = outputs[0, masked_index, :]
+            probs = logits.softmax(dim=-1)
             if target_ids is not None:
                 probs = probs[..., target_ids]
 
             values, predictions = probs.topk(top_k)
 
-        for v, p in zip(values.tolist(), predictions.tolist()):
-            tokens = input_ids.numpy()
-            if target_ids is not None:
-                p = target_ids[p].tolist()
-            tokens[masked_index] = p
-            # Filter padding out:
-            tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-            result.append(
-                {
-                    "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                    "score": v,
-                    "token": p,
-                    "token_str": self.tokenizer.decode(p),
-                }
-            )
+        result = []
+        single_mask = values.shape[0] == 1
+        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+            row = []
+            for v, p in zip(_values, _predictions):
+                # Copy is important since we're going to modify this array in place
+                tokens = input_ids.numpy().copy()
+                if target_ids is not None:
+                    p = target_ids[p].tolist()
+
+                tokens[masked_index[i]] = p
+                # Filter padding out:
+                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+                # Originally we skip special tokens to give readable output.
+                # For multi masks though, the other [MASK] would be removed otherwise
+                # making the output look odd, so we add them back
+                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence}
+                row.append(proposition)
+            result.append(row)
+        if single_mask:
+            return result[0]
         return result
 
     def get_target_ids(self, targets, top_k=None):
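
The .copy() in the loop above guards against aliasing: input_ids.numpy() shares
memory with the underlying tensor, so without a copy every candidate row would
write into the same buffer. A minimal demonstration with plain NumPy and invented
token ids (the values and positions are illustrative only):

import numpy as np

input_ids = np.array([101, 103, 103, 102])  # toy ids; pretend 103 is the [MASK] id
masked_index = [1, 2]                        # positions of the two masks
predictions = [[11, 12], [21, 22]]           # invented top-2 candidate ids per mask

rows = []
for i, preds in enumerate(predictions):
    for p in preds:
        tokens = input_ids.copy()    # without .copy() this would alias input_ids,
        tokens[masked_index[i]] = p  # and this write would corrupt later iterations
        rows.append(tokens)

# Each row differs only at its own mask position, and the input is untouched:
assert (input_ids == [101, 103, 103, 102]).all()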
54 changes: 51 additions & 3 deletions tests/test_pipelines_fill_mask.py
@@ -104,6 +104,32 @@ def test_small_model_pt(self):
             ],
         )
 
+        outputs = unmasker("My name is <mask> <mask>", top_k=2)
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is Maul<mask></s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name isELS<mask></s>"},
+                ],
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is<mask> Maul</s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name is<mask>ELS</s>"},
+                ],
+            ],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt(self):
@@ -231,9 +257,6 @@ def run_pipeline_test(self, fill_masker, examples):
 
         with self.assertRaises(ValueError):
             fill_masker([None])
-        # Multiple masks
-        with self.assertRaises(PipelineException):
-            fill_masker(f"This is {tokenizer.mask_token} {tokenizer.mask_token}")
         # No mask_token is not supported
         with self.assertRaises(PipelineException):
             fill_masker("This is")
@@ -242,6 +265,7 @@ def run_pipeline_test(self, fill_masker, examples):
         self.run_test_targets(model, tokenizer)
         self.run_test_top_k_targets(model, tokenizer)
         self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+        self.fill_mask_with_multiple_masks(model, tokenizer)
 
     def run_test_targets(self, model, tokenizer):
         vocab = tokenizer.get_vocab()
@@ -340,3 +364,27 @@ def fill_mask_with_duplicate_targets_and_top_k(self, model, tokenizer):
         # The target list contains duplicates, so we can't output more
         # than them
         self.assertEqual(len(outputs), 3)
+
+    def fill_mask_with_multiple_masks(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )
