diff --git a/src/transformers/pipelines/fill_mask.py b/src/transformers/pipelines/fill_mask.py
index 251c7f09732f..acc9ed59c0a6 100644
--- a/src/transformers/pipelines/fill_mask.py
+++ b/src/transformers/pipelines/fill_mask.py
@@ -1,3 +1,4 @@
+import itertools
 from typing import TYPE_CHECKING, Optional, Union
 
 import numpy as np
@@ -75,15 +76,9 @@ def __init__(
         self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
         self.top_k = top_k
 
-    def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
-        numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+    def ensure_at_least_one_mask_token(self, masked_indices: np.ndarray):
+        numel = np.prod(masked_indices.shape)
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
@@ -141,12 +136,12 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
             result = []
 
             if self.framework == "tf":
-                masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+                masked_indices = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
 
-                # Fill mask pipeline supports only one ${mask_token} per sample
-                self.ensure_exactly_one_mask_token(masked_index)
+                # At least one ${mask_token} is required; the TF path still handles a single mask only
+                self.ensure_at_least_one_mask_token(masked_indices)
 
-                logits = outputs[i, masked_index.item(), :]
+                logits = outputs[i, masked_indices.item(), :]
                 probs = tf.nn.softmax(logits)
                 if targets is None:
                     topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
@@ -157,34 +152,73 @@ def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
                     values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
                     predictions = target_inds[sort_inds.numpy()]
+                # Wrap the single-mask TF results so the shared ranking below can consume them
+                values_all = [values]
+                predictions_all = [predictions]
             else:
-                masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+                masked_indices = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+                # Fill mask pipeline supports at least one ${mask_token} per sample
+                self.ensure_at_least_one_mask_token(masked_indices.numpy())
 
-                # Fill mask pipeline supports only one ${mask_token} per sample
-                self.ensure_exactly_one_mask_token(masked_index.numpy())
+                # One logits/probability vector per mask position
+                logits_multiple = [outputs[i, index.item(), :] for index in masked_indices]
+                probs_multiple = [logits.softmax(dim=0) for logits in logits_multiple]
 
-                logits = outputs[i, masked_index.item(), :]
-                probs = logits.softmax(dim=0)
                 if targets is None:
-                    values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
+                    values_all = []
+                    predictions_all = []
+                    for probs in probs_multiple:
+                        values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
+                        values_all.append(values)
+                        predictions_all.append(predictions)
                 else:
-                    values = probs[..., target_inds]
-                    sort_inds = list(reversed(values.argsort(dim=-1)))
-                    values = values[..., sort_inds]
-                    predictions = target_inds[sort_inds]
+                    # Apply the single-mask target ranking independently at each mask position
+                    values_all = []
+                    predictions_all = []
+                    for probs in probs_multiple:
+                        values = probs[..., target_inds]
+                        sort_inds = list(reversed(values.argsort(dim=-1)))
+                        values_all.append(values[..., sort_inds])
+                        predictions_all.append(target_inds[sort_inds])
 
-            for v, p in zip(values.tolist(), predictions.tolist()):
+            # Enumerate every combination of per-mask candidates and rank the
+            # combinations by the product of their probabilities
+            values_indices = [list(range(len(values))) for values in values_all]
+            values_combinatorial_val = list(itertools.product(*values_all))
+            values_combinatorial_ind = list(itertools.product(*values_indices))
+            values_combinatorial = []
+            for values_comb_val, values_comb_ind in zip(values_combinatorial_val, values_combinatorial_ind):
+                values_combinatorial.append([np.prod(values_comb_val), list(values_comb_ind)])
+            values_combinatorial = sorted(values_combinatorial, key=lambda x: x[0], reverse=True)
+            values_combinatorial = values_combinatorial[: top_k if top_k is not None else self.top_k]
+
+            for value_combinatorial in values_combinatorial:
                 tokens = input_ids.numpy()
-                tokens[masked_index] = p
-                # Filter padding out:
+                tokens_collated = []
+                for mask_iter, element_index in enumerate(value_combinatorial[1]):
+                    masked_index = masked_indices.tolist()[mask_iter][0]
+                    tokens[masked_index] = predictions_all[mask_iter].tolist()[element_index]
+                    tokens_collated.append(predictions_all[mask_iter].tolist()[element_index])
+                # Filter padding out:
                 tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-                result.append(
-                    {
-                        "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                        "score": v,
-                        "token": p,
-                        "token_str": self.tokenizer.decode(p),
-                    }
-                )
+                if len(tokens_collated) == 1:
+                    # Single mask: keep the original output schema
+                    result.append(
+                        {
+                            "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
+                            "score": float(value_combinatorial[0]),
+                            "token": tokens_collated[0],
+                            "token_str": self.tokenizer.decode(tokens_collated[0]),
+                        }
+                    )
+                else:
+                    result.append(
+                        {
+                            "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
+                            "score": float(value_combinatorial[0]),
+                            "tokens": tokens_collated,
+                            "tokens_strs": [self.tokenizer.decode(token) for token in tokens_collated],
+                        }
+                    )
 
             # Append
             results += [result]