The change from single mask to multi mask support for PyTorch #10222

Closed
84 changes: 55 additions & 29 deletions src/transformers/pipelines/fill_mask.py
@@ -1,3 +1,4 @@
+import itertools
 from typing import TYPE_CHECKING, Optional, Union

 import numpy as np
@@ -75,15 +76,9 @@ def __init__(
         self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
         self.top_k = top_k

-    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
-        numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+    def ensure_atleast_one_mask_token(self, masked_indices: np.ndarray):
+        numel = np.prod(masked_indices.shape)
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
@@ -141,12 +136,12 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
             result = []

             if self.framework == "tf":
-                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+                masked_indices = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()

                 # Fill mask pipeline supports only one ${mask_token} per sample
-                self.ensure_exactly_one_mask_token(masked_index)
+                self.ensure_atleast_one_mask_token(masked_indices)

-                logits = outputs[i, masked_index.item(), :]
+                logits = outputs[i, masked_indices.item(), :]
                 probs = tf.nn.softmax(logits)
                 if targets is None:
                     topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
@@ -157,38 +152,69 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
                     values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
                     predictions = target_inds[sort_inds.numpy()]
             else:
-                masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+                masked_indices = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+                # Fill mask pipeline supports at least one ${mask_token} per sample
+                self.ensure_atleast_one_mask_token(masked_indices.numpy())

-                # Fill mask pipeline supports only one ${mask_token} per sample
-                self.ensure_exactly_one_mask_token(masked_index.numpy())
+                logits_multiple = [outputs[i, index.item(), :] for index in masked_indices]

+                probs_multiple = [logits.softmax(dim=0) for logits in logits_multiple]

-                logits = outputs[i, masked_index.item(), :]
-                probs = logits.softmax(dim=0)
                 if targets is None:
-                    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
+                    values_all = []
+                    predictions_all = []
+                    for probs in probs_multiple:
+                        values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
Contributor:

This is debatable.

Are the propositions single tokens per mask token, or are they tuples of answers? Consider the following:

This <mask> is to <mask> what rice is to sushi.

Here are the top-3 propositions for the two masks:

[apple, rhubarb, Paris]
[pie, France, biking]

With your code, IIUC, you are going to propose

(apple, pie)
(rhubarb, France)
(Paris, biking)

It's possible (though not necessarily) that the propositions you want to make are more like:

(apple, pie)
(rhubarb, pie)
(Paris, France)

My suggestion at this point is to look at joint probabilities for the tuple suggestions instead of going token per token.

Author:

@Narsil This is correct. I have been a little worried about how BERT handles multiple masks and how one obtains the joint probability instead of a single token-specific probability. Since it makes its predictions for all the masks simultaneously, it tends to make more mistakes, both grammatically and knowledge-wise. I would be grateful if you could help me understand how one retrieves a joint probability in this case.

This issue gets worse when the masks are situated closer to each other, with BERT typically predicting the same word for both mask slots.

Contributor:

I don't think you can compute correct joint probabilities.
The output is by design the sum of all joint probabilities at every locus.
What I meant is that BERT cannot output {token1: "either", token2: "or", score: 50%}, {token1: "or", token2: "either", score: 50%}. It has to output {token1: ["either", 50%], ["or", 50%]}, {token2: ["either", 50%], ["or", 50%]}. So you have no way of recovering the first proposed solution, and your best guess can only be (either, either, 25), (either, or, 25), (or, either, 25), (or, or, 25).

What I was suggesting, as a better guess, was simply treating the masks as if they were independent:

  • Softmax each mask locus independently
  • Create all joint probabilities (lazily, because it's combinatorial)
    • p1_1 x p2_1
    • p1_1 x p2_2
    • ...
    • p1_2 x p2_1
    • ...
    • px_y, where x is the location of the mask token and y is the rank of the proposed token
  • Softmax that list of joint probabilities (just so the output scores are correctly scaled; could be skipped because of the combinatorics)
  • Order them
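
A minimal sketch of this combinatorial scoring, assuming the per-mask top-k probabilities and token ids have already been collected (the helper name is hypothetical, and the variables mirror the PR's values_all / predictions_all; the optional re-softmax of the joint list is omitted):

import itertools

import numpy as np

def rank_joint_candidates(values_all, predictions_all, top_k):
    # values_all[m]: top-k probabilities for mask m, softmaxed independently
    # predictions_all[m]: the matching top-k token ids for mask m
    candidates = []
    # Enumerate the Cartesian product of the per-mask top-k lists and
    # score each tuple by the product p1_x * p2_y * ... of its members.
    for probs, token_ids in zip(
        itertools.product(*values_all), itertools.product(*predictions_all)
    ):
        candidates.append((float(np.prod(probs)), list(token_ids)))
    # Order by joint probability and keep only the best top_k tuples.
    candidates.sort(key=lambda c: c[0], reverse=True)
    return candidates[:top_k]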

Author:

This makes sense to me, awesome. I will get on this.

@jowagner (Apr 1, 2021):

This will rank ('apple', 'France') and ('Paris', 'pie') higher than ('Paris', 'France'). We need some measure of how happy the transformer is with each candidate sequence. I think we need additional forward passes to measure the effect of each combination. If there is some way of measuring the model's happiness with a candidate sequence, one pass per candidate sequence will suffice. If not, I'd suggest running

This apple is to <mask> what rice is to sushi.
This rhubarb is to <mask> what rice is to sushi.
This Paris is to <mask> what rice is to sushi.
This <mask> is to pie what rice is to sushi.
This <mask> is to France what rice is to sushi.
This <mask> is to biking what rice is to sushi.

and then multiplying the probabilities. We will need some kind of beam search to limit the combinations tested, as the number of forward passes needed will otherwise explode for more masked tokens or a large top_k.

Edit: Actually, depending on the setting, this may run with fewer passes than trying all combinations, e.g. here 6 vs. 9.
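
A rough sketch of this refill-and-rescore idea for the two-mask example, reusing the single-mask pipeline for the extra forward passes (the function, template handling, and candidate lists are hypothetical; a real version would add beam search and handle tokenizer whitespace more carefully):

def rescore_pairs(fill_mask, template, left_candidates, right_candidates):
    # template is e.g. "This {} is to {} what rice is to sushi."
    mask = fill_mask.tokenizer.mask_token
    left_scores = {}
    for left in left_candidates:
        # One forward pass with the left slot fixed and the right slot masked.
        for out in fill_mask(template.format(left, mask)):
            left_scores[(left, out["token_str"].strip())] = out["score"]
    pair_scores = {}
    for right in right_candidates:
        # One forward pass with the right slot fixed and the left slot masked.
        for out in fill_mask(template.format(mask, right)):
            key = (out["token_str"].strip(), right)
            if key in left_scores:
                # Multiply the probabilities observed from the two directions.
                pair_scores[key] = left_scores[key] * out["score"]
    # Best joint candidates first; here 6 forward passes cover 9 combinations.
    return sorted(pair_scores.items(), key=lambda kv: kv[1], reverse=True)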

+                        values_all.append(values)
+                        predictions_all.append(predictions)
                 else:
+                    # pending for when the target tokens are specifically input to the model.
                     values = probs[..., target_inds]
                     sort_inds = list(reversed(values.argsort(dim=-1)))
                     values = values[..., sort_inds]
                     predictions = target_inds[sort_inds]

-                for v, p in zip(values.tolist(), predictions.tolist()):
+                values_indices = [[i for i in range(value.size()[0])] for value in values_all]
+                values_combinatorial_val = list(itertools.product(*values_all))
+                values_combinatorial_ind = list(itertools.product(*values_indices))
+                values_combinatorial = []
+                for values_comb_val, values_comb_ind in zip(values_combinatorial_val, values_combinatorial_ind):
+                    values_combinatorial.append([np.prod(values_comb_val), list(values_comb_ind)])
+                values_combinatorial = sorted(values_combinatorial, key=lambda x: x[0], reverse=True)[0 : self.top_k]

+                for value_combinatorial in values_combinatorial:
                     tokens = input_ids.numpy()
-                    tokens[masked_index] = p
                     # Filter padding out:
+                    tokens_collated = []
+                    for mask_iter, element_index in enumerate(value_combinatorial[1]):
+                        masked_index = masked_indices.tolist()[mask_iter][0]
+                        tokens[masked_index] = predictions_all[mask_iter].tolist()[element_index]
+                        tokens_collated.append(predictions_all[mask_iter].tolist()[element_index])
                     tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-                    result.append(
-                        {
-                            "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                            "score": v,
-                            "token": p,
-                            "token_str": self.tokenizer.decode(p),
-                        }
-                    )
+                    if len(tokens_collated) == 1:
+                        result.append(
+                            {
+                                "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
+                                "score": value_combinatorial[0],
+                                "token": tokens_collated[0],
+                                "token_str": self.tokenizer.decode(tokens_collated[0])
+                            }
+                        )
+                    else:
+                        result.append(
+                            {
+                                "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
+                                "score": value_combinatorial[0],
+                                "tokens": tokens_collated,
+                                "tokens_strs": [self.tokenizer.decode(token) for token in tokens_collated],
+                            }
+                        )

             # Append
             results += [result]

             if len(results) == 1:
                 return results[0]
             return results
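
For reference, a hedged sketch of how the modified pipeline would be exercised; the model choice and inputs are illustrative, and only the result keys are taken from the diff above:

from transformers import pipeline

fill_mask = pipeline("fill-mask", model="bert-base-uncased", framework="pt")

# One mask: each candidate keeps the old "token" / "token_str" keys.
fill_mask("Paris is the capital of [MASK].")

# Several masks: each candidate now carries "tokens" / "tokens_strs",
# one entry per mask, scored by the product of the per-mask probabilities.
fill_mask("This [MASK] is to [MASK] what rice is to sushi.")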

