
Change from single-mask to multi-mask support for PyTorch #10222

Conversation


@naveenjafer naveenjafer commented Feb 16, 2021

What does this PR do?

A draft PR for the feature request to change from single-mask to multi-mask support for the fill-mask pipeline.

As discussed, this is a draft PR to discuss the changes that need to be made to the output format to jointly support multiple and single masks in one pipeline call. The PR implements the change for PyTorch; the code for when the keyword argument is used has not been pushed yet.

The pipeline tests are expected to fail since the output format changed.
#10158

Example code that tests this feature is below.

import json
from transformers import pipeline

# Run the fill-mask pipeline on a sentence containing two [MASK] tokens.
unmasker = pipeline('fill-mask', model='bert-base-uncased')
t = unmasker("hi [MASK] morning I'm a [MASK] model.")
print(json.dumps(t, indent=4))

@LysandreJik

Contributor

@Narsil Narsil left a comment

Nice PR, keep up the good work.

My main comment is, I think, critical for the proposed change.

values = probs[..., target_inds]
sort_inds = list(reversed(values.argsort(dim=-1)))
values = values[..., sort_inds]
predictions = target_inds[sort_inds]

for v, p in zip(values.tolist(), predictions.tolist()):
for i, item in enumerate(values_all[0]):
Contributor

Try to avoid all indices if possible in your loops.

It's usually a code smell (not necessarily, but it's most likely avoidable here).
It will also likely lead to indexing errors at some point.

Author

I see, yeah I think I can get rid of this by zipping.

"score": v,
"token": p,
"token_str": self.tokenizer.decode(p),
"scores": [v.tolist()[i] for v in values_all],
Contributor

Try to keep this backward compatible if possible. You can simply change the return types based on the number of masks.

Author

Hmm, but then the issue would be that if, say, I were to iteratively call the pipeline over a corpus with a varying number of masks, the return types would differ between single-mask and multi-mask inputs. I wonder if that will result in a suboptimal API experience for someone looking to try it out.

Contributor

Maybe a core maintainer can pitch in on this @LysandreJik to confirm, but if your change is not backward compatible, it will need a major release to be included. So if you want to maximize your chances of getting it merged, I think making it backward compatible is important.

It's definitely something that can be taken care of later in the PR. We might even take care of it directly, I'm not sure. Just letting you know ;)

Author

Or the other option would be to keep them as two separate pipeline calls, fill-mask and fill-mask-multiple, but that's not as elegant as I would have liked.

Author

@LysandreJik What would you suggest we do here?

Member

Agreed with @Narsil that we should handle backwards compatibility as much as possible here. Having a different return according to the number of masks isn't super clean but it's definitely way better than breaking the workflow of existing fill-mask pipeline users.

Pipelines are, by design, already a bit "magical", in the sense that they return different outputs according to your inputs. If you pass the pipeline a string or a list of strings, the resulting type will be different:

>>> from transformers import pipeline
>>> mask_filler = pipeline("fill-mask")
>>> mask_filler("Where do <mask> live?")
[{'sequence': '<s>Where do you live?</s>', 'score': 0.575425386428833, 'token': 47, 'token_str': 'Ġyou'}, {'sequence': '<s>Where do YOU live?</s>', 'score': 0.1382409781217575, 'token': 10540, 'token_str': 'ĠYOU'}, {'sequence': '<s>Where do they live?</s>', 'score': 0.044609859585762024, 'token': 51, 'token_str': 'Ġthey'}, {'sequence': '<s>Where do we live?</s>', 'score': 0.0327814482152462, 'token': 52, 'token_str': 'Ġwe'}, {'sequence': '<s>Where do millennials live?</s>', 'score': 0.02294538915157318, 'token': 15100, 'token_str': 'Ġmillennials'}]
>>> mask_filler(["Where do <mask> live?", "Where <mask> I live?"])
[[{'sequence': '<s>Where do you live?</s>', 'score': 0.5754269361495972, 'token': 47, 'token_str': 'Ġyou'}, {'sequence': '<s>Where do YOU live?</s>', 'score': 0.1382400244474411, 'token': 10540, 'token_str': 'ĠYOU'}, {'sequence': '<s>Where do they live?</s>', 'score': 0.044610150158405304, 'token': 51, 'token_str': 'Ġthey'}, {'sequence': '<s>Where do we live?</s>', 'score': 0.032781537622213364, 'token': 52, 'token_str': 'Ġwe'}, {'sequence': '<s>Where do millennials live?</s>', 'score': 0.022945256903767586, 'token': 15100, 'token_str': 'Ġmillennials'}], [{'sequence': '<s>Where do I live?</s>', 'score': 0.6903642416000366, 'token': 109, 'token_str': 'Ġdo'}, {'sequence': '<s>Where should I live?</s>', 'score': 0.15854182839393616, 'token': 197, 'token_str': 'Ġshould'}, {'sequence': '<s>Where did I live?</s>', 'score': 0.04364638775587082, 'token': 222, 'token_str': 'Ġdid'}, {'sequence': '<s>Where am I live?</s>', 'score': 0.030600957572460175, 'token': 524, 'token_str': 'Ġam'}, {'sequence': '<s>Where shall I live?</s>', 'score': 0.029068272560834885, 'token': 5658, 'token_str': 'Ġshall'}]]

Returning an array of values when there are several masks in the input wouldn't be opposed to that design.

Author

@LysandreJik Okay, in that case, the return type for single mask will stay the same -

{  
   "sequence" : "the final sequence with the mask added",  
   "score" :  "the softmax score",  
   "token" : "the token ID used in filling the MASK",  
   "token_str" : "the token string used in filling the MASK"  
}  

and for the multiple-mask case it would be (notice that the key and type for token and token_str change)

{  
   "sequence" : "the final sequence with all the masks added",  
   "score" :  "the combinatorial score of each individual mask's softmax output",  
   "tokens" : ["the token ID used in filling MASK 1 ", "the token ID used in filling MASK 2", ... ],  
   "tokens_str" : ["the token string used in filling the MASK 1", "the token string used in filling the MASK 2", ...]
}  

Is this agreeable?


Other functions in transformers have parameters controlling what goes into the return dictionary, e.g. output_hidden_states for forward(). How about using multiple_mask_output=True to always use the new format (False by default)?

values_all = []
predictions_all = []
for probs in probs_multiple:
    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
Contributor

This is debatable.

Are the propositions single tokens for mask tokens, or are they tuples of answers? Consider the following:

This <mask> is to <mask> what rice is to sushi.

Here are the top-3 propositions for the 2 masks:
[apple, rhubarb, Paris]
[pie, France, biking]

With your code, IIUC, you are going to propose
(apple, pie)
(rhubarb, France)
(Paris, biking)

It's possible (not necessarily though) that the propositions you want to make are more like:

(apple, pie)
(rhubarb, pie)
(Paris, France).

My suggestion at this point is to look at joint probabilities for the tuple suggestions instead of token by token.

Author

@Narsil This is correct. I have been a little worried about how BERT's masking works for multiple masks and how one obtains the joint probability instead of a single token-specific probability. Since it is simultaneously making the prediction for all the masks, it tends to make more mistakes, both grammatically and knowledge-wise. I would be grateful if you could help me understand how one retrieves a joint probability in this case.

This issue gets worse when the masks are situated close to each other, with BERT typically predicting the same word for both mask slots.

Contributor

I don't think you can do correct joint probabilities.
The output is by design the sum of all joint probabilities at every locus.
What I meant is BERT cannot output {token1:"either", token2:"or", score:50%}, {token1:"or", token2:"either", score:50%}. It has to output {token1: ["either", 50%], ["or", 50%]} {token2: ["either", 50%], ["or", 50%]}. So you have no way of recovering the first proposed solution, and your best guess can only be (either, either, 25), (either, or, 25), (or, either, 25), (or, or, 25).

What I was suggesting, as a better guess, was simply treating them as if they were independent (see the sketch below):

  • Softmax all mask loci independently
  • Create all joint probabilities (lazily, because it's combinatorial)
    • p1_1 x p2_1
    • p1_1 x p2_2
    • ...
    • p1_2 x p2_1
    • ...
    • px_y where x is the location of the mask token and y is the rank of the proposed token
  • Softmax that joint probability list (just so that output scores are correctly scaled; could be skipped because of the combinatorial cost)
  • Order them
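
A minimal sketch of that procedure, assuming logits has shape (num_masks, vocab_size); the helper name joint_top_k is illustrative and not the pipeline's API, and the joint scores are simply renormalized here rather than softmaxed:

import torch
from itertools import product

def joint_top_k(logits: torch.Tensor, top_k: int = 5):
    """Rank tuples of tokens for several masks by the product of their
    independent per-mask probabilities (no real joint modelling)."""
    probs = logits.softmax(dim=-1)                 # softmax each mask locus independently
    per_mask = [p.topk(top_k) for p in probs]      # keep only top_k tokens per mask to stay tractable
    candidates = []
    for ranks in product(*(range(top_k) for _ in per_mask)):
        tokens = [per_mask[m].indices[r].item() for m, r in enumerate(ranks)]
        score = 1.0
        for m, r in enumerate(ranks):
            score *= per_mask[m].values[r].item()  # p1_r1 * p2_r2 * ...
        candidates.append((tokens, score))
    candidates.sort(key=lambda c: c[1], reverse=True)        # order by joint score
    total = sum(score for _, score in candidates)
    return [(t, s / total) for t, s in candidates[:top_k]]   # rescale the reported scores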

Author

This makes sense to me, awesome, I will get on this.

@jowagner jowagner Apr 1, 2021

This will rank ('apple', 'France') and ('Paris', 'pie') higher than ('Paris', 'France'). We need some measure of how happy the transformer is with each candidate sequence. I think we need additional forward passes to measure the effect of each combination. If there is some way of measuring the model's happiness with a candidate sequence, one pass per candidate sequence will suffice. If not, I'd suggest running

This apple is to <mask> what rice is to sushi.
This rhubarb is to <mask> what rice is to sushi.
This Paris is to <mask> what rice is to sushi.
This <mask> is to pie what rice is to sushi.
This <mask> is to France what rice is to sushi.
This <mask> is to biking what rice is to sushi.

and then multiplying the probabilities (see the sketch below). We will need some kind of beam search to limit the combinations tested, as the number of forward passes needed will otherwise explode for more masked tokens or a large top_k.

Edit: Actually, depending on the setting, this may run with fewer passes than trying all combinations, e.g. here 6 vs. 9.
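
A minimal sketch of that re-scoring idea for the two-mask example, assuming each candidate is a single token in the model's vocabulary; the helper names mask_prob and pair_score are illustrative:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

def mask_prob(text, target):
    # Probability the model assigns to `target` at the single [MASK] in `text`.
    inputs = tokenizer(text, return_tensors="pt")
    mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
    with torch.no_grad():
        logits = model(**inputs).logits[0, mask_pos]
    return logits.softmax(-1)[tokenizer.convert_tokens_to_ids(target)].item()

def pair_score(w1, w2):
    # Fill one mask with a candidate, score the other with a forward pass,
    # do the same in the other direction, and multiply the probabilities.
    p2_given_1 = mask_prob(f"This {w1} is to [MASK] what rice is to sushi.", w2)
    p1_given_2 = mask_prob(f"This [MASK] is to {w2} what rice is to sushi.", w1)
    return p2_given_1 * p1_given_2

for pair in [("apple", "pie"), ("paris", "france"), ("apple", "france")]:
    print(pair, pair_score(*pair))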

@naveenjafer
Author

@Narsil @LysandreJik How do you suggest we handle the targets param? At the moment, targets can either be a list of strings or a string. In the case of multiple masks, there are two ways to go about it.

  1. Provide a way for the user to define targets for each mask.
  2. One single target list that can be uniformly applied across all the positions.

The first method would be best implemented by expecting a dict as argument in the keyword param. Something like
{ "0" : "str or list of strings" , "2" : "str or list of strings" ... }

This way the user can decide to skip explicitly defining candidate keywords for some of the mask positions if needed (mask 1 is skipped in the example above).

@LysandreJik
Member

Tough question indeed regarding targets! Switching to a dict sounds a bit non-intuitive to me, but I don't see any other choice. I guess eventually the API would be the following:

Given a single input string, with a single mask:

  • A candidate as a string returns the candidate score for the mask
  • A candidate list of strings returns the candidate scores for the mask

Given a single input string, with multiple masks:

  • A candidate as a string returns the candidate scores for all masks
  • A candidate list of strings returns the candidate scores for all masks, on all candidates
  • A candidate dict of strings returns the candidate scores for the masks which are concerned by the dictionary keys. Their candidates are the dictionary value linked to that dictionary key.
  • A candidate dict of list of strings returns the candidate scores for the masks which are concerned by the dictionary keys. Their candidates are the dictionary values linked to that dictionary key.

Then there are also lists of input strings with single masks, and lists of input strings with multiple masks. This results in a very large number of possibilities, with different returns, which sounds overwhelming. I'm not too sure that's the best way to handle the issue; I'll give it a bit more thought.

@naveenjafer
Author

@LysandreJik I had a question. From what I can understand, one can only define a single set of targets at the moment, irrespective of how many input texts are given, right? Whether there is a single input text or multiple input texts, even in the base case of a single mask, we can only define a single target or a list of targets that applies across them all, right? Essentially, it is a many-to-one relation from the input texts to the targets. If that is the case, the targets functionality is currently not designed in a very useful manner, right?

@LysandreJik
Member

Hi, sorry for getting back to you so late on this. I agree with you that we can improve the targets. I'm pinging @joeddav as he's the author of the PR that added them.

@joeddav your input on this PR would be more than welcome! Thank you.

@joeddav
Contributor

joeddav commented Mar 10, 2021

Personally, I think the simplest solution would be best: only support targets in the single-mask case. If targets is passed and there are multiple mask tokens, raise a ValueError. It's a pretty narrow use case to need to pass a string with multiple masked tokens while also needing to evaluate possible target tokens for each. In my opinion, that's a complicated and rare use case and we don't need to muddle pipelines code by attempting to support it. It can always be accomplished by using the core modules instead of a pipeline.
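
A minimal sketch of what such a guard could look like (illustrative only; this is not the code that was eventually merged, and the helper name is hypothetical):

import torch

def check_targets_supported(input_ids, mask_token_id, targets):
    # Hypothetical helper: reject `targets` when the input holds more than one mask token.
    masked_index = torch.nonzero(input_ids == mask_token_id, as_tuple=False)
    if targets is not None and masked_index.shape[0] > 1:
        raise ValueError(
            "The `targets` argument is only supported for inputs with a single mask token."
        )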

@naveenjafer
Author

@joeddav That does make sense to me! The objective of a pipeline should only be to accommodate some quick test use cases. Making it cumbersome misses the point altogether. @LysandreJik What do you think?

@LysandreJik
Member

Yes, I agree with @joeddav as well!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@jowagner

[...] If you think this still needs to be addressed please comment on this thread.

This feature would have many applications and would enable comparison of MLMs in cloze tests beyond the restricted setting of targeting words in the intersection of the vocabularies of the models to be compared. There are some open questions about how top_k predictions should be made, see issue #3609, so I think it would be good to wait a few more weeks to give everybody time to read the linked paper and discuss ideas.

@naveenjafer
Author

@jowagner Just to clarify it for others who might be following, the paper you are referring to is this one https://arxiv.org/abs/2002.03079 right?

@jowagner

@jowagner Just to clarify it for others who might be following, the paper you are referring to is this one https://arxiv.org/abs/2002.03079 right?

Yes. I hope to read it soon and get a clearer picture of what is needed here. I tend to think that producing top_k predictions for multiple masked tokens is outside the scope of the BERT model and really needs an extra model on top of it, e.g. a model that predicts a ranked list of the best crystallisation points and can then be used to perform a beam search, fixing one subword unit at a time and producing a k-best list of the best crystallisation processes.

@naveenjafer
Author

@jowagner In that case I have a doubt, coming back to the basics of BERT: when some of the words are masked and a prediction is to be made on multiple masks during BERT's pre-training step, does BERT not face the same issue? Or are the masks predicted one at a time in each training sentence fed to BERT?

@jowagner

jowagner commented Apr 30, 2021

Looking at Devlin et al. (2018) again, I don't see the pre-training objective stated, but certainly they try to push as much probability mass as possible to the one completion attested in the training data. BERT is trained to get the top prediction right. Good secondary predictions for individual tokens are only a by-product. Nothing pushes the model to make the k-th predictions consistent across multiple masked subword units for k > 1.

Yes, making predictions can be expected to be harder when there are multiple masked subword units but that also happens in pre-training and BERT therefore learns to do this. Maybe BERT does this in steps, crystallising only a few decisions in each layer. A way to find out would be to fix the BERT layers, add MLM heads to each layer, tune these heads and then see how the predictions (and probabilities) change from layer to layer. (This would make a nice paper, or maybe somebody has done this already.)

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Jun 3, 2021
@naveenjafer
Author

Do we have a final verdict yet on the approach to be followed? @mitramir55 had suggested a code proposal in #3609, I believe.

@LysandreJik LysandreJik reopened this Jun 7, 2021
@github-actions

github-actions bot commented Jul 1, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@naveenjafer
Author

naveenjafer commented Jul 6, 2021

@LysandreJik Shall I replace this with the implementation suggested earlier in #3609 and raise a PR? Though I don't think we have discussed yet what scoring would be ideal for the beam search used to sort the predictions.

@LysandreJik
Member

@Narsil had good insights about your previous implementation - @Narsil could you let us know what you think of the solution proposed here #3609 (comment)?

@Narsil
Contributor

Narsil commented Jul 7, 2021

The design in #3609 (comment) seems very interesting!

Main comments:

  • I would be curious to see proof (and it would probably need to become a test) that doing n inferences instead of 1 produces better results (because it should be closer to the real joint probabilities); that's the main interest of this proposed approach.

  • I think it should output the same tokens as the fill-mask pipeline in the degenerate case (when there's only 1 mask).
    I don't think that's correct right now (see what I tried below).

  • Because we iteratively do topk for each mask, it's a bit exponential if I understand correctly. I would probably add some kind of cleanup to limit the number of "beams" to topk (I may have overlooked it, but it seems to be currently missing); see the pruning sketch after the snippet below.

  • The proposed code could probably be refactored a bit for clarity, to avoid integer indexing and deep nesting.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline
import random


def predict_seqs_dict(sequence, model, tokenizer, top_k=5, order="right-to-left"):

    ids_main = tokenizer.encode(sequence, return_tensors="pt", add_special_tokens=False)

    ids_ = ids_main.detach().clone()
    position = torch.where(ids_main == tokenizer.mask_token_id)

    positions_list = position[1].numpy().tolist()

    if order == "left-to-right":
        positions_list.reverse()

    elif order == "random":
        random.shuffle(positions_list)

    # print(positions_list)
    predictions_ids = {}
    predictions_detokenized_sents = {}

    for i in range(len(positions_list)):
        predictions_ids[i] = []
        predictions_detokenized_sents[i] = []

        # if it was the first prediction,
        # just go on and predict the first predictions

        if i == 0:
            model_logits = model(ids_main)["logits"][0][positions_list[0]]
            top_k_tokens = torch.topk(model_logits, top_k, dim=0).indices.tolist()

            for j in range(len(top_k_tokens)):
                # print(j)
                ids_t_ = ids_.detach().clone()
                ids_t_[0][positions_list[0]] = top_k_tokens[j]
                predictions_ids[i].append(ids_t_)

                pred = tokenizer.decode(ids_t_[0])
                predictions_detokenized_sents[i].append(pred)

                # append the sentences and ids of this masked token

        # if we already have some predictions, go on and fill the rest of the masks
        # by continuing the previous predictions
        if i != 0:
            for pred_ids in predictions_ids[i - 1]:

                # get the logits
                model_logits = model(pred_ids)["logits"][0][positions_list[i]]
                # get the top 5 of this prediction and masked token
                top_k_tokens = torch.topk(model_logits, top_k, dim=0).indices.tolist()

                for top_id in top_k_tokens:

                    ids_t_i = pred_ids.detach().clone()
                    ids_t_i[0][positions_list[i]] = top_id

                    pred = tokenizer.decode(ids_t_i[0])

                    # append the sentences and ids of this masked token

                    predictions_ids[i].append(ids_t_i)
                    predictions_detokenized_sents[i].append(pred)

    return predictions_detokenized_sents


sequence = "This is some super neat [MASK] !"
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

pipe = pipeline(task="fill-mask", tokenizer=tokenizer, model=model)
print(predict_seqs_dict(sequence, model, tokenizer))
print(pipe(sequence))
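
On the beam-limiting point above, a minimal sketch of the kind of cleanup that could be added, assuming each candidate carries a running log-probability (the helper is illustrative and not part of the snippet above):

import heapq

def prune_beams(beams, top_k):
    """Keep only the `top_k` highest-scoring candidates.

    `beams` is assumed to be a list of (token_ids, log_prob) pairs accumulated
    while filling the current mask; predict_seqs_dict above would need to track
    a running log-probability per candidate for this to be usable."""
    return heapq.nlargest(top_k, beams, key=lambda beam: beam[1])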

@jowagner

jowagner commented Aug 3, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Yes, we need more time or help from somebody with time to review the discussion and make recommendations.

My thoughts re-reading a few comments, including some of my own:

  • Producing k-best predictions for multiple masked tokens requires choices, i.e. a model, separate from the underlying transformer model. This is where the PR is stalled. A quick way forward would be to support only k=1 when there are multiple masked tokens for the time being. For k=1, it is undisputed that the prediction should be the transformer's top prediction for each token (see the sketch after this list).
  • This PR/feature does not directly allow comparison of cloze test predictions of models with different vocabularies. Users would have to probe with continuous sequences of masked tokens of varying length and somehow decide between the candidate predictions.
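
A minimal sketch of that k=1 fallback (take the argmax at every mask position in a single forward pass), assuming a BERT-style checkpoint:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

text = "hi [MASK] morning I'm a [MASK] model."
inputs = tokenizer(text, return_tensors="pt")
mask_positions = (inputs.input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]

with torch.no_grad():
    logits = model(**inputs).logits[0]

# k = 1: take the single top prediction at each mask position.
top_ids = logits[mask_positions].argmax(dim=-1)
filled = inputs.input_ids[0].clone()
filled[mask_positions] = top_ids
print(tokenizer.decode(filled, skip_special_tokens=True))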

@LysandreJik LysandreJik added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Aug 4, 2021
@huggingface huggingface deleted a comment from github-actions bot Aug 4, 2021
Narsil added a commit to Narsil/transformers that referenced this pull request Dec 14, 2021
Narsil added a commit that referenced this pull request Dec 14, 2021
* Adding support for multiple mask tokens.

- Original implem: #10222

Co-authored-by: njafer <[email protected]>

* In order to accommodate optionally multimodal models like Perceiver

we add information to the tasks to specify tasks where we know for sure
if we need the tokenizer/feature_extractor or not.

* Adding info in the documentation about multi masks.

+ marked as experimental.

* Add a copy() to prevent overriding the same tensor over and over.

* Fixup.

* Adding small test for multi mask with real values..

Co-authored-by: njafer <[email protected]>
@breandan

After reading this thread and skimming through #14716, I must confess I'm still a little unsure how the scores for multi-masked prompts are computed. Based on my understanding, for a prompt with k masks, it seems like you want to do a beam search over the Cartesian product mask_1_targets x mask_2_targets x ... x mask_k_targets and return the top-n most likely tuples maximizing P(mask_1=token_i_1, mask_2=token_i_2, ..., mask_k=token_i_k), i.e.:

{
   T_1=[(token_1_1, ..., token_1_k), score_t_1],
   T_2=[(token_2_1, ..., token_2_k), score_t_2],
   ...
   T_n=[(token_n_1, ..., token_n_k), score_t_n]
}

Is this accurate? Perhaps you could try to clarify the design intent and limitations of the current API in the documentation somewhere. If you intend to eventually support computing the joint probability, I think it would be beneficial to provide a way for consumers to supply a set of per-mask targets and configure the beam search parameters, e.g. beam width. Thanks!

@Narsil
Contributor

Narsil commented Feb 16, 2022

After reading this thread and skimming through #14716, I must confess I'm still a little unsure how the scores for multi-masked prompts are computed. Based on my understanding, for a prompt with k masks, it seems like you want to do a beam search over the Cartesian product mask_1_targets x mask_2_targets x ... x mask_k_targets and return the top-n most likely tuples maximizing P(mask_1=token_i_1, mask_2=token_i_2, ..., mask_k=token_i_k), i.e.:

{
   T_1=[(token_1_1, ..., token_1_k), score_t_1],
   T_2=[(token_2_1, ..., token_2_k), score_t_2],
   ...
   T_n=[(token_n_1, ..., token_n_k), score_t_n]
}

Is this accurate?

Actually no, that was the intent of this PR, which never got merged. Instead of trying to make educated guesses about mask combinations, #14716 added what seems most appropriate: what the model really answers, i.e. the various tokens at each mask location, without ANY information about correlations.

This is how the model is built, and as such, we return it raw.

from transformers import pipeline


pipe = pipeline(model="bert-base-uncased")

print(pipe("This is a [MASK] and a [MASK]", top_k=3))
[[{'score': 0.5048776268959045,
   'sequence': '[CLS] this is a. and a [MASK] [SEP]',
   'token': 1012,
   'token_str': '.'},
  {'score': 0.07435218244791031,
   'sequence': '[CLS] this is a ; and a [MASK] [SEP]',
   'token': 1025,
   'token_str': ';'},
  {'score': 0.05109349265694618,
   'sequence': '[CLS] this is a, and a [MASK] [SEP]',
   'token': 1010,
   'token_str': ','}],
 [{'score': 0.8665121793746948,
   'sequence': '[CLS] this is a [MASK] and a. [SEP]',
   'token': 1012,
   'token_str': '.'},
  {'score': 0.05160374939441681,
   'sequence': '[CLS] this is a [MASK] and a | [SEP]',
   'token': 1064,
   'token_str': '|'},
  {'score': 0.046446096152067184,
   'sequence': '[CLS] this is a [MASK] and a ; [SEP]',
   'token': 1025,
   'token_str': ';'}]]

You are then free to make all the complex attempts at combining the suggestions yourself (see the sketch below). But we don't attempt to hide anything, since the model really doesn't model that.
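
For example, a consumer could rank token tuples by multiplying the independent per-mask scores; a minimal sketch on top of the nested multi-mask output shown above (it ignores correlations entirely):

import math
from itertools import product

from transformers import pipeline

pipe = pipeline("fill-mask", model="bert-base-uncased")
out = pipe("This is a [MASK] and a [MASK]", top_k=3)

# Rank token tuples by the product of their independent per-mask scores.
combos = sorted(
    ((tuple(p["token_str"] for p in combo), math.prod(p["score"] for p in combo))
     for combo in product(*out)),
    key=lambda c: c[1],
    reverse=True,
)
print(combos[:3])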

@kaisugi
Contributor

kaisugi commented Apr 12, 2022

I appreciate this implementation for the support of multiple [MASK] tokens in the input.
However, I cannot figure out why the pipeline output is kept nested only in those cases. It forces me to do some additional coding to un-nest it.
Is there any specific reason for this?

if single_mask:
    return result[0]
return result

@Narsil
Contributor

Narsil commented Apr 12, 2022

Is there any specific reason for this?

Backward compatibility: the first pipeline wasn't built with that option in mind, making it harder to support multi-mask seamlessly like you would expect. The removal of such quirks might happen in 5.0, though. We know it's not convenient as it is, but breaking user code is even less convenient.

@ArthurZucker
Collaborator

Closing, as this PR is super old and not relevant anymore.
