Adding support for multiple mask tokens.
- Original implementation: huggingface#10222

Co-authored-by: njafer <[email protected]>
Narsil and njafer committed Dec 14, 2021
1 parent 2a606f9 commit da96fd0
Showing 2 changed files with 62 additions and 34 deletions.
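
In short: before this commit, the fill-mask pipeline raised a PipelineException whenever an input contained more than one mask token; afterwards it scores every mask position and returns one list of candidates per mask, while single-mask inputs keep the old flat-list shape. A minimal usage sketch of the new behaviour — the checkpoint name and prompt are illustrative, not taken from the commit:

    from transformers import pipeline

    # "distilroberta-base" is only an example checkpoint; any fill-mask model works.
    fill_masker = pipeline("fill-mask", model="distilroberta-base")

    # Two mask tokens in one input no longer raise PipelineException.
    outputs = fill_masker("Paris is the <mask> of <mask>.", top_k=2)
    for mask_candidates in outputs:  # one inner list per mask position
        for c in mask_candidates:
            print(c["token_str"], c["score"], c["sequence"])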
68 changes: 37 additions & 31 deletions src/transformers/pipelines/fill_mask.py
@@ -59,13 +59,13 @@ def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
     def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
         masked_index = self.get_masked_index(input_ids)
         numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+        # if numel > 1:
+        #     raise PipelineException(
+        #         "fill-mask",
+        #         self.model.base_model_prefix,
+        #         f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
+        #     )
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
@@ -98,46 +98,52 @@ def postprocess(self, model_outputs, top_k=5, target_ids=None):
         top_k = target_ids.shape[0]
         input_ids = model_outputs["input_ids"][0]
         outputs = model_outputs["logits"]
-        result = []

         if self.framework == "tf":
-            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]

             # Fill mask pipeline supports only one ${mask_token} per sample
             outputs = outputs.numpy()

-            logits = outputs[0, masked_index.item(), :]
-            probs = tf.nn.softmax(logits)
+            logits = outputs[0, masked_index, :]
+            probs = tf.nn.softmax(logits, axis=-1)
             if target_ids is not None:
-                probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+                probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
+                probs = tf.expand_dims(probs, 0)

             topk = tf.math.top_k(probs, k=top_k)
             values, predictions = topk.values.numpy(), topk.indices.numpy()
         else:
-            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
             # Fill mask pipeline supports only one ${mask_token} per sample

-            logits = outputs[0, masked_index.item(), :]
-            probs = logits.softmax(dim=0)
+            logits = outputs[0, masked_index, :]
+            probs = logits.softmax(dim=-1)
             if target_ids is not None:
                 probs = probs[..., target_ids]

             values, predictions = probs.topk(top_k)

-        for v, p in zip(values.tolist(), predictions.tolist()):
-            tokens = input_ids.numpy()
-            if target_ids is not None:
-                p = target_ids[p].tolist()
-            tokens[masked_index] = p
-            # Filter padding out:
-            tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-            result.append(
-                {
-                    "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                    "score": v,
-                    "token": p,
-                    "token_str": self.tokenizer.decode(p),
-                }
-            )
+        result = []
+        single_mask = values.shape[0] == 1
+        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+            row = []
+            for v, p in zip(_values, _predictions):
+                tokens = input_ids.numpy()
+                if target_ids is not None:
+                    p = target_ids[p].tolist()
+
+                tokens[masked_index[i]] = p
+                # Filter padding out:
+                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+                # Originally we skip special tokens to give readable output.
+                # For multi masks though, the other [MASK] would be removed otherwise
+                # making the output look odd, so we add them back
+                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence}
+                row.append(proposition)
+            result.append(row)
+        if single_mask:
+            return result[0]
         return result

     def get_target_ids(self, targets, top_k=None):
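The core of the pipeline change above: masked_index is now kept as a vector of positions instead of being collapsed with .item(), so outputs[0, masked_index, :] selects one row of logits per mask and the softmax/top-k run per row (hence dim=-1 rather than dim=0). A toy PyTorch sketch of that per-mask top-k logic, with shapes and values invented purely for illustration:

    import torch

    vocab_size, seq_len = 10, 6
    logits = torch.randn(1, seq_len, vocab_size)  # model output for one sample
    masked_index = torch.tensor([1, 4])           # positions of two mask tokens

    per_mask_logits = logits[0, masked_index, :]  # shape (2, vocab_size): one row per mask
    probs = per_mask_logits.softmax(dim=-1)       # normalize over the vocab axis
    values, predictions = probs.topk(2)           # top-2 candidates for each mask
    print(values.shape, predictions.shape)        # torch.Size([2, 2]) for both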
28 changes: 25 additions & 3 deletions tests/test_pipelines_fill_mask.py
@@ -231,9 +231,6 @@ def run_pipeline_test(self, fill_masker, examples):

         with self.assertRaises(ValueError):
             fill_masker([None])
-        # Multiple masks
-        with self.assertRaises(PipelineException):
-            fill_masker(f"This is {tokenizer.mask_token} {tokenizer.mask_token}")
         # No mask_token is not supported
         with self.assertRaises(PipelineException):
             fill_masker("This is")
@@ -242,6 +239,7 @@ def run_pipeline_test(self, fill_masker, examples):
         self.run_test_targets(model, tokenizer)
         self.run_test_top_k_targets(model, tokenizer)
         self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+        self.fill_mask_with_multiple_masks(model, tokenizer)

     def run_test_targets(self, model, tokenizer):
         vocab = tokenizer.get_vocab()
@@ -340,3 +338,27 @@ def fill_mask_with_duplicate_targets_and_top_k(self, model, tokenizer):
         # The target list contains duplicates, so we can't output more
         # than them
         self.assertEqual(len(outputs), 3)
+
+    def fill_mask_with_multiple_masks(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )
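
Two details worth noting from the postprocess rewrite that this test exercises: the single_mask branch returns result[0], so single-mask callers still get a flat list of dicts while multi-mask inputs get a list of such lists; and skip_special_tokens=single_mask deliberately leaves the remaining mask tokens visible in each decoded "sequence" when several masks are present. A sketch of the two return shapes, using the same illustrative checkpoint as above:

    from transformers import pipeline

    fill_masker = pipeline("fill-mask", model="distilroberta-base")  # illustrative checkpoint

    single = fill_masker("Paris is the <mask> of France.", top_k=2)
    # flat list, unchanged from before:   [{...}, {...}]

    multi = fill_masker("Paris is the <mask> of <mask>.", top_k=2)
    # one inner list per mask position:   [[{...}, {...}], [{...}, {...}]]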
