diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi index 4b80a7f75..7c21c5b56 100644 --- a/bindings/python/py_src/tokenizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/__init__.pyi @@ -836,6 +836,18 @@ class Tokenizer: Returns: A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch + """ + pass + @property + def encode_special_tokens(self): + """ + Modifies the tokenizer in order to use or not the special tokens + during encoding. + + Args: + value (:obj:`bool`): + Whether to use the special tokens or not + """ pass @staticmethod diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 77e071314..4e792ef54 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1109,6 +1109,25 @@ impl PyTokenizer { self.tokenizer.id_to_token(id) } + /// Modifies the tokenizer in order to use or not the special tokens + /// during encoding. + /// + /// Args: + /// value (:obj:`bool`): + /// Whether to use the special tokens or not + /// + #[setter] + fn set_encode_special_tokens(&mut self, value: bool) { + self.tokenizer.set_encode_special_tokens(value); + } + /// Get the value of the `encode_special_tokens` attribute + /// + /// Returns: + /// :obj:`bool`: the tokenizer's encode_special_tokens attribute + #[getter] + fn get_encode_special_tokens(&self) -> bool { + self.tokenizer.get_encode_special_tokens() + } /// Add the given tokens to the vocabulary /// /// The given tokens are added only if they don't already exist in the vocabulary. diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index a1e41c283..2eb5ce59c 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -457,3 +457,34 @@ def test_unigram_byte_fallback(self): output = tokenizer.encode("A sentence 🤗") assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9] assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"] + + def test_encode_special_tokens(self): + tokenizer = Tokenizer.from_pretrained("t5-base") + tokenizer.add_tokens([""]) + tokenizer.add_special_tokens([""]) + output = tokenizer.encode("Hey there dearfriend!", add_special_tokens=False) + assert output.tokens == ["▁Hey", "▁there", "", "▁dear", "", "▁friend", "!"] + + tokenizer.encode_special_tokens = True + assert tokenizer.encode_special_tokens == True + + output = tokenizer.encode("Hey there dearfriend!", add_special_tokens=False) + assert output.tokens == [ + "▁Hey", + "▁there", + "<", + "end", + "_", + "of", + "_", + "text", + ">", + "▁dear", + "", + "▁friend", + "!", + ] + + tokenizer.add_tokens(["of_text>"]) + output = tokenizer.encode("Hey there dearfriend!", add_special_tokens=False) + assert output.tokens == ["▁Hey", "▁there", "<", "end", "_", "of_text>", "▁dear", "", "▁friend", "!"] diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 6870b7ec7..487fb4479 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -160,6 +160,9 @@ pub(super) struct AddedVocabulary { split_trie: MatchingSet, /// A RegexSet containing all the normalized patterns used to split on AddedTokens split_normalized_trie: MatchingSet, + + /// Whether or not special tokens should be splitted when encoding. This is equivalent to ignoring them + encode_special_tokens: bool, } impl AddedVocabulary { @@ -180,6 +183,7 @@ impl AddedVocabulary { special_tokens_set: HashSet::new(), split_trie: (trie, vec![]), split_normalized_trie: (normalized_trie, vec![]), + encode_special_tokens: false, } } /// Size of the additional vocabulary @@ -214,6 +218,15 @@ impl AddedVocabulary { .or_else(|| model.id_to_token(id)) } + // + pub fn set_encode_special_tokens(&mut self, value: bool) { + self.encode_special_tokens = value; + } + + pub fn get_encode_special_tokens(&self) -> bool { + self.encode_special_tokens + } + /// Check if a token is a special token pub fn is_special_token(&self, token: &str) -> bool { self.special_tokens_set.contains(token) @@ -356,6 +369,12 @@ impl AddedVocabulary { let aho_id = mat.pattern(); let id = split_re.1[aho_id]; let added_token = &self.added_tokens_map_r.get(&id).unwrap(); + + if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content) + { + continue; + } + if added_token.single_word { let start_space = start == 0 || !ends_with_word(&sentence[..start]); let stop_space = stop == sentence.len() || !starts_with_word(&sentence[stop..]); @@ -436,6 +455,18 @@ impl AddedVocabulary { .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie))) .expect("AddedVocabulary bad split"); + // normalized = False + // "I read a book Hey" -> "I read a book", " ", "Hey" + + // normalized = True -> "▁" + // "I read a bookHey" -> "I read a bookHey" + + // Day normalized = True -> "Day" + // "I read a book monday" -> "I read a book monday" + + // [DAY] normalized = False -> "Day" + // "I read a [DAY] monday" -> "I read a " "[DAY]", "book monday" + // 320055 // 2. Then extract the normalized tokens from the normalized pieces of the string pretokenized .split(|_, mut sequence| { @@ -444,6 +475,14 @@ impl AddedVocabulary { }) .expect("AddedVocabulary bad split"); + // ["I read a book", " ", "Hey"] -> ["▁I read a book", "▁ ", "▁Hey"] + // ["▁I read a book", "▁ ", "▁Hey"] -> [.., "▁ ", "", "▁Hey"] + + // normalized = True -> "▁" + // "I read a bookHey" -> ["▁I read a book", "<","/","s",">", "Hey"] + + // "I read a " "[DAY]", "book monday" -> "i read a " "[day]", "book monday" + pretokenized } } @@ -880,4 +919,66 @@ mod tests { ] ); } + + #[test] + fn test_encode_special_tokens() { + let model = ModelMock::new(&[]); + let mut vocab = AddedVocabulary::new(); + let normalizer = Lowercase; + + vocab.add_tokens( + &[ + AddedToken::from("", true) + .lstrip(true) + .rstrip(true) + .single_word(true), + AddedToken::from("ask>", false), + AddedToken::from("", true), + ], + &model, + Some(&normalizer), + ); + vocab.set_encode_special_tokens(true); + + let result = vocab.extract_and_normalize( + Some(&normalizer), + "Hi there\t\t\u{2000} ", + ); + + assert_eq!( + simplify_output(&result), + vec![ + ("hi ", Some(vec![1])), + (" there\t", Some(vec![1])), + ("\t", Some(vec![1])), + ("\u{2000} ", Some(vec![1])), + ("", None) + ] + ); + + vocab.set_encode_special_tokens(false); + + let result = vocab.extract_and_normalize( + Some(&normalizer), + "Hi there\t\t\u{2000} ", + ); + assert_eq!( + simplify_output(&result), + vec![ + ("hi", None), + (" ", Some(vec![0])), + ("there", None), + ("\t\t", Some(vec![0])), + ("\u{2000} ", Some(vec![0])), + ("", Some(vec![2])), + (" ", Some(vec![0])), + ("", Some(vec![2])), + ("", Some(vec![2])) + ] + ); + } } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 2d7e10f73..ae6a64362 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -685,6 +685,16 @@ where self.added_vocabulary.id_to_token(id, &self.model) } + /// set the added bocab's splitting scheme + pub fn set_encode_special_tokens(&mut self, value: bool) { + self.added_vocabulary.set_encode_special_tokens(value); + } + + /// Get added token value + pub fn get_encode_special_tokens(&self) -> bool { + self.added_vocabulary.get_encode_special_tokens() + } + /// Encode a single sequence fn encode_single_sequence( &self,