diff --git a/src/transformers/models/bert/tokenization_bert.py b/src/transformers/models/bert/tokenization_bert.py
index 536eb08640c07c..a24f39564264df 100644
--- a/src/transformers/models/bert/tokenization_bert.py
+++ b/src/transformers/models/bert/tokenization_bert.py
@@ -238,10 +238,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/convbert/tokenization_convbert.py b/src/transformers/models/convbert/tokenization_convbert.py
index 4fbed8fe10fd15..800848caaf1cc7 100644
--- a/src/transformers/models/convbert/tokenization_convbert.py
+++ b/src/transformers/models/convbert/tokenization_convbert.py
@@ -177,10 +177,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/deprecated/retribert/tokenization_retribert.py b/src/transformers/models/deprecated/retribert/tokenization_retribert.py
index 4529e8e9029bba..de50c74b70bd02 100644
--- a/src/transformers/models/deprecated/retribert/tokenization_retribert.py
+++ b/src/transformers/models/deprecated/retribert/tokenization_retribert.py
@@ -178,10 +178,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/distilbert/tokenization_distilbert.py b/src/transformers/models/distilbert/tokenization_distilbert.py
index 02596825863e96..5e96e4972d3fac 100644
--- a/src/transformers/models/distilbert/tokenization_distilbert.py
+++ b/src/transformers/models/distilbert/tokenization_distilbert.py
@@ -195,10 +195,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/electra/tokenization_electra.py b/src/transformers/models/electra/tokenization_electra.py
index e202f773efa857..aabeccba7d630e 100644
--- a/src/transformers/models/electra/tokenization_electra.py
+++ b/src/transformers/models/electra/tokenization_electra.py
@@ -194,10 +194,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/funnel/tokenization_funnel.py b/src/transformers/models/funnel/tokenization_funnel.py
index f085fd7c47762b..37a913d0a01bae 100644
--- a/src/transformers/models/funnel/tokenization_funnel.py
+++ b/src/transformers/models/funnel/tokenization_funnel.py
@@ -205,10 +205,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/layoutlm/tokenization_layoutlm.py b/src/transformers/models/layoutlm/tokenization_layoutlm.py
index 57c29d5870ed6e..b518874224a42c 100644
--- a/src/transformers/models/layoutlm/tokenization_layoutlm.py
+++ b/src/transformers/models/layoutlm/tokenization_layoutlm.py
@@ -176,10 +176,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/lxmert/tokenization_lxmert.py b/src/transformers/models/lxmert/tokenization_lxmert.py
index daa761878d9403..e651b8f4454a11 100644
--- a/src/transformers/models/lxmert/tokenization_lxmert.py
+++ b/src/transformers/models/lxmert/tokenization_lxmert.py
@@ -168,10 +168,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/mobilebert/tokenization_mobilebert.py b/src/transformers/models/mobilebert/tokenization_mobilebert.py
index 63c0ab28a7309d..389e38bce61933 100644
--- a/src/transformers/models/mobilebert/tokenization_mobilebert.py
+++ b/src/transformers/models/mobilebert/tokenization_mobilebert.py
@@ -166,10 +166,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/roc_bert/tokenization_roc_bert.py b/src/transformers/models/roc_bert/tokenization_roc_bert.py
index cee778dc878152..d665b91a0680df 100644
--- a/src/transformers/models/roc_bert/tokenization_roc_bert.py
+++ b/src/transformers/models/roc_bert/tokenization_roc_bert.py
@@ -210,10 +210,12 @@ def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/models/squeezebert/tokenization_squeezebert.py b/src/transformers/models/squeezebert/tokenization_squeezebert.py
index ccce92809e321c..f061a1a53c2577 100644
--- a/src/transformers/models/squeezebert/tokenization_squeezebert.py
+++ b/src/transformers/models/squeezebert/tokenization_squeezebert.py
@@ -180,10 +180,12 @@ def vocab_size(self):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
 
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index c1dd9c329a1cc2..e26c0c6d52898e 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -498,6 +498,7 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
         all_special_tokens_extended = {
             str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
         }
+        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
 
         text, kwargs = self.prepare_for_tokenization(text, **kwargs)
 
@@ -513,8 +514,14 @@ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
             pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
             text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
 
-        no_split_token = set(self.unique_no_split_tokens)
-        tokens = self.tokens_trie.split(text)
+        # split_special_tokens: empty `no_split_token`
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(self.unique_no_split_tokens)
+            tokens = self.tokens_trie.split(text)
+
         # ["This is something", "<special_token_1>", "  else"]
         for i, token in enumerate(tokens):
             if token in no_split_token:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index c3d2c4eb8999b0..0490bec3975495 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1492,6 +1492,11 @@ def all_special_ids(self) -> List[int]:
     clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
         Whether or not the model should cleanup the spaces that were added when splitting the input text during the
         tokenization process.
+    split_special_tokens (`bool`, *optional*, defaults to `False`):
+        Whether or not the special tokens should be split during the tokenization process. The default behavior is
+        to not split special tokens. This means that if `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") =
+        ['<s>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will give `['<',
+        's', '>']`. This argument is only supported for `slow` tokenizers for the moment.
 """
 
 
@@ -1546,6 +1551,9 @@ def __init__(self, **kwargs):
         # By default, cleaning tokenization spaces for both fast and slow tokenizers
         self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
 
+        # By default, do not split special tokens for both fast and slow tokenizers
+        self.split_special_tokens = kwargs.pop("split_special_tokens", False)
+
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
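For context before the test changes below, this is roughly how the new flag is exercised end to end. A minimal sketch, assuming a slow sentencepiece checkpoint such as `xlm-roberta-base`; the exact sub-pieces shown in the comments depend on that model's vocabulary and are not part of this patch:

from transformers import XLMRobertaTokenizer

# Slow (Python) tokenizer; the flag is not supported by fast tokenizers yet.
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

# Default behavior: the special token is protected by the trie and returned whole.
print(tokenizer.tokenize("<s>"))  # ['<s>']

# With split_special_tokens=True the text bypasses the trie and goes through the
# normal tokenization path, so the special token is broken into ordinary pieces,
# e.g. something like ['<', 's', '>'] (exact pieces depend on the vocab).
print(tokenizer.tokenize("<s>", split_special_tokens=True))

# The default can also be set once at load time, since __init__ now stores it:
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base", split_special_tokens=True)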
diff --git a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
index 9224fbd87ea49c..942cceaf7cd0d4 100644
--- a/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
+++ b/tests/models/layoutlmv2/test_tokenization_layoutlmv2.py
@@ -384,6 +384,10 @@ def test_encode_decode_with_spaces(self):
     def test_right_and_left_truncation(self):
         pass
 
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
diff --git a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
index 63d86f280cc007..58092834e5a160 100644
--- a/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
+++ b/tests/models/layoutlmv3/test_tokenization_layoutlmv3.py
@@ -264,6 +264,10 @@ def test_encode_decode_with_spaces(self):
     def test_right_and_left_truncation(self):
         pass
 
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
+
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
diff --git a/tests/models/layoutxlm/test_tokenization_layoutxlm.py b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
index bf295c9c925e0e..f7f8329706dff2 100644
--- a/tests/models/layoutxlm/test_tokenization_layoutxlm.py
+++ b/tests/models/layoutxlm/test_tokenization_layoutxlm.py
@@ -144,6 +144,19 @@ def test_save_sentencepiece_tokenizer(self) -> None:
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
 
+    def test_split_special_tokens(self):
+        tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
+        _, _, boxes = self.get_question_words_and_boxes()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
+
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
diff --git a/tests/models/markuplm/test_tokenization_markuplm.py b/tests/models/markuplm/test_tokenization_markuplm.py
index 533a3429a8dc73..73979b255e08db 100644
--- a/tests/models/markuplm/test_tokenization_markuplm.py
+++ b/tests/models/markuplm/test_tokenization_markuplm.py
@@ -1344,6 +1344,19 @@ def test_special_tokens_initialization(self):
                 self.assertTrue(special_token_id in p_output)
                 self.assertTrue(special_token_id in cr_output)
 
+    def test_split_special_tokens(self):
+        # TODO this is only possible for slow tokenizers currently
+        tokenizer = self.get_tokenizer()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
+
     def test_training_new_tokenizer(self):
         # This feature only exists for fast tokenizers
         if not self.test_rust_tokenizer:
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 3b17c6ea4f6983..aec5e493c57c00 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -3909,6 +3909,7 @@ def test_save_slow_from_fast_and_reload_fast(self):
                     # Should not raise an error
                     self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
 
+    # TODO This is run for all models but only tests bert...
     def test_clean_up_tokenization_spaces(self):
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         assert tokenizer.clean_up_tokenization_spaces is True
@@ -3953,3 +3954,29 @@ def test_clean_up_tokenization_spaces(self):
         tokenizer.clean_up_tokenization_spaces = True
         decoded = tokenizer.decode(tokens)
         assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+
+    def test_split_special_tokens(self):
+        if not self.test_slow_tokenizer:
+            return
+
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            special_token = "[SPECIAL_TOKEN]"
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+
+                if not tokenizer.is_fast:
+                    # bloom, gptneox etc. only have a fast tokenizer
+                    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+                    encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
+                    self.assertEqual(len(encoded_special_token), 1)
+
+                    encoded_split_special_token = tokenizer.encode(
+                        special_token, add_special_tokens=False, split_special_tokens=True
+                    )
+                    if len(encoded_split_special_token) == 1:
+                        # if we have subword tokenization or special vocab
+                        self.assertTrue(
+                            encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
+                        )
+                    else:
+                        self.assertTrue(len(encoded_split_special_token) > 1)
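The new common test boils down to the following check. A minimal sketch, assuming any slow tokenizer (here `bert-base-uncased` is used for illustration); concrete ids and lengths depend on the vocabulary:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
special_token = "[SPECIAL_TOKEN]"
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})

# Protected: the added special token maps to exactly one id.
ids_default = tokenizer.encode(special_token, add_special_tokens=False)
assert len(ids_default) == 1

# Split: either the token decomposes into several ids, or (when the pieces
# collapse to [UNK] or existing vocab entries) it at least no longer maps to
# the special token's own id.
ids_split = tokenizer.encode(special_token, add_special_tokens=False, split_special_tokens=True)
assert len(ids_split) > 1 or ids_split[0] != tokenizer.convert_tokens_to_ids(special_token)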