
Commit

[split_special_tokens] Add support for split_special_tokens argument to encode (huggingface#25081)

* draft changes

* update and add tests

* styling for now

* move test

* path to usable model

* update test

* small update

* update bert-based tokenizers

* don't use kwargs for _tokenize

* don't use kwargs for _tokenize

* fix copies

* update

* update test for special tokenizers

* fixup

* skip two tests

* remove pdb breakpoint()

* wowo

* rewrite custom tests

* nits

* revert change in target keys

* fix markup lm

* update documentation of the argument
ArthurZucker authored and blbadger committed Nov 8, 2023
1 parent fa5b15d commit f0ac2d7
Showing 18 changed files with 122 additions and 24 deletions.
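
For reviewers, a quick usage sketch of the argument this commit adds, mirroring the new common test further down (the checkpoint name is illustrative and, per the updated docstring, only slow tokenizers honour the flag):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # any slow (Python) tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": ["[SPECIAL_TOKEN]"]})

# Default: the added special token is protected and maps back to its own id.
default_ids = tokenizer.encode("[SPECIAL_TOKEN]", add_special_tokens=False)

# New flag: the string is tokenized like ordinary text, so it no longer maps to the
# added special token's id (for a WordPiece vocab it may collapse to the unk token).
split_ids = tokenizer.encode("[SPECIAL_TOKEN]", add_special_tokens=False, split_special_tokens=True)

print(default_ids, split_ids)  # expected to differ
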
6 changes: 4 additions & 2 deletions src/transformers/models/bert/tokenization_bert.py
@@ -238,10 +238,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
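
The change above threads the flag into the BasicTokenizer through its never_split argument; a standalone sketch of what that parameter controls (the input string and lowercasing setting are illustrative):

from transformers.models.bert.tokenization_bert import BasicTokenizer

basic = BasicTokenizer(do_lower_case=True)

# Special token protected (the existing behaviour, still the default):
print(basic.tokenize("[CLS] hello", never_split=["[CLS]"]))  # expected: ['[CLS]', 'hello']

# never_split=None, which the new split_special_tokens path is meant to select:
# the special token is lowercased and split on punctuation like any other text.
print(basic.tokenize("[CLS] hello", never_split=None))  # expected: ['[', 'cls', ']', 'hello']
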
6 changes: 4 additions & 2 deletions src/transformers/models/convbert/tokenization_convbert.py
@@ -177,10 +177,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
@@ -178,10 +178,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/distilbert/tokenization_distilbert.py
@@ -195,10 +195,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/electra/tokenization_electra.py
@@ -194,10 +194,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/funnel/tokenization_funnel.py
@@ -205,10 +205,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/layoutlm/tokenization_layoutlm.py
@@ -176,10 +176,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/lxmert/tokenization_lxmert.py
@@ -168,10 +168,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/mobilebert/tokenization_mobilebert.py
@@ -166,10 +166,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
6 changes: 4 additions & 2 deletions src/transformers/models/roc_bert/tokenization_roc_bert.py
@@ -210,10 +210,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
@@ -180,10 +180,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)

-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
11 changes: 9 additions & 2 deletions src/transformers/tokenization_utils.py
@@ -498,6 +498,7 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
         all_special_tokens_extended = {
             str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
         }
+        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)

         text, kwargs = self.prepare_for_tokenization(text, **kwargs)

@@ -513,8 +514,14 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
             pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
             text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

-        no_split_token = set(self.unique_no_split_tokens)
-        tokens = self.tokens_trie.split(text)
+        # split_special_tokens: empty `no_split_token`
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(self.unique_no_split_tokens)
+            tokens = self.tokens_trie.split(text)
+
         # ["This is something", "<special_token_1>", " else"]
         for i, token in enumerate(tokens):
             if token in no_split_token:
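
At this level the flag simply bypasses the token trie, which is what normally keeps added/special tokens in one piece before the model-specific `_tokenize` runs. A small illustration reusing the example string from the comment above:

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("<special_token_1>")

# Normal path: the trie isolates the special token so `_tokenize` never sees it.
print(trie.split("This is something<special_token_1> else"))
# expected: ['This is something', '<special_token_1>', ' else']

# With split_special_tokens=True this step is skipped: `no_split_token` is empty and the
# raw text is handed to `_tokenize` as a single chunk.
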
8 changes: 8 additions & 0 deletions src/transformers/tokenization_utils_base.py
@@ -1492,6 +1492,11 @@ def all_special_ids(self) -> List[int]:
         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
             Whether or not the model should cleanup the spaces that were added when splitting the input text during the
             tokenization process.
+        split_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not the special tokens should be split during the tokenization process. The default behavior is
+            to not split special tokens. This means that if `<s>` is the `bos_token`, then
+            `tokenizer.tokenize("<s>") = ['<s>']`. Otherwise, if `split_special_tokens=True`, then
+            `tokenizer.tokenize("<s>")` will give `['<', 's', '>']`. This argument is only supported for `slow`
+            tokenizers for the moment.
 """


@@ -1546,6 +1551,9 @@ def __init__(self, **kwargs):
         # By default, cleaning tokenization spaces for both fast and slow tokenizers
         self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)

+        # By default, do not split special tokens for both fast and slow tokenizers
+        self.split_special_tokens = kwargs.pop("split_special_tokens", False)
+
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
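
Because the flag is popped from the init kwargs above, it should also be settable once as an instance-level default and overridden per call; a sketch under that assumption (checkpoint name illustrative, slow tokenizer required):

from transformers import BertTokenizer

# Assumption: kwargs passed to from_pretrained reach __init__, so this sets the instance default.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", split_special_tokens=True)
print(tokenizer.split_special_tokens)  # True -> used as the default inside `tokenize`/`encode`

# A per-call value takes precedence over the instance default:
ids = tokenizer.encode("hello world", add_special_tokens=False, split_special_tokens=False)
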
4 changes: 4 additions & 0 deletions tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -384,6 +384,10 @@ def test_encode_decode_with_spaces(self):
     def test_right_and_left_truncation(self):
         pass

+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
4 changes: 4 additions & 0 deletions tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -264,6 +264,10 @@ def test_encode_decode_with_spaces(self):
     def test_right_and_left_truncation(self):
         pass

+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
13 changes: 13 additions & 0 deletions tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -144,6 +144,19 @@ def test_save_sentencepiece_tokenizer(self) -> None:
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)

+    def test_split_special_tokens(self):
+        tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
+        _, _, boxes = self.get_question_words_and_boxes()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
13 changes: 13 additions & 0 deletions tests/models/markuplm/test_tokenization_markuplm.py
@@ -1344,6 +1344,19 @@ def test_special_tokens_initialization(self):
                 self.assertTrue(special_token_id in p_output)
                 self.assertTrue(special_token_id in cr_output)

+    def test_split_special_tokens(self):
+        # TODO this is only possible for slow currently
+        tokenizer = self.get_tokenizer()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
+
     def test_training_new_tokenizer(self):
         # This feature only exists for fast tokenizers
         if not self.test_rust_tokenizer:
27 changes: 27 additions & 0 deletions tests/test_tokenization_common.py
@@ -3909,6 +3909,7 @@ def test_save_slow_from_fast_and_reload_fast(self):
                 # Should not raise an error
                 self.rust_tokenizer_class.from_pretrained(tmp_dir_2)

+    # TODO This is run for all models but only tests bert...
     def test_clean_up_tokenization_spaces(self):
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         assert tokenizer.clean_up_tokenization_spaces is True
@@ -3953,3 +3954,29 @@ def test_clean_up_tokenization_spaces(self):
         tokenizer.clean_up_tokenization_spaces = True
         decoded = tokenizer.decode(tokens)
         assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
+    def test_split_special_tokens(self):
+        if not self.test_slow_tokenizer:
+            return
+
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            special_token = "[SPECIAL_TOKEN]"
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                if not tokenizer.is_fast:
+                    # bloom, gptneox etc only have a fast
+                    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+                    encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
+                    self.assertEqual(len(encoded_special_token), 1)
+
+                    encoded_split_special_token = tokenizer.encode(
+                        special_token, add_special_tokens=False, split_special_tokens=True
+                    )
+                    if len(encoded_split_special_token) == 1:
+                        # if we have subword tokenization or special vocab
+                        self.assertTrue(
+                            encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
+                        )
+                    else:
+                        self.assertTrue(len(encoded_split_special_token) > 1)
