diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index 0ad96fc8a..bc2cb0cab 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -725,6 +725,23 @@ class Tokenizer:
         """
         pass
 
+    def assign_tokens(self, old_to_new_map):
+        """
+        Reassign the content of existing added tokens
+
+        Each token to replace must already exist in the added vocabulary. Its content is
+        swapped for the new one while the token keeps its original id.
+
+        Args:
+            old_to_new_map (:obj:`Dict[Union[str, AddedToken], Union[str, AddedToken]]`):
+                A mapping from the tokens to replace to their new content. Each token can be
+                either a string or an instance of :class:`~tokenizers.AddedToken` for more
+                customization.
+        """
+        pass
+
     def decode(self, ids, skip_special_tokens=True):
         """
         Decode the given list of ids back to a string
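For reference, a minimal usage sketch of the Python API declared above, mirroring what the new tests later in this patch exercise (the model name and token strings come from those tests; everything else is illustrative only and not part of the patch):

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("gpt2")
    # Swap the content of an existing added token; its id (50256) is preserved.
    tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"})
    assert tokenizer.decode([50256]) == "my_new_token"
    # The old content is no longer matched as a single added token.
    assert tokenizer.encode("<|endoftext|>").tokens != ["<|endoftext|>"]
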
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 24a68c6bb..499cbd770 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1237,6 +1237,51 @@ impl PyTokenizer {
         Ok(self.tokenizer.add_tokens(&tokens))
     }
 
+    /// Reassign the content of existing added tokens
+    ///
+    /// Each token to replace must already exist in the added vocabulary. Its content is
+    /// swapped for the new one while the token keeps its original id.
+    ///
+    /// Args:
+    ///     old_to_new_map (:obj:`Dict[Union[str, AddedToken], Union[str, AddedToken]]`):
+    ///         A mapping from the tokens to replace to their new content. Each token can be
+    ///         either a string or an instance of :class:`~tokenizers.AddedToken` for more
+    ///         customization.
+    #[pyo3(text_signature = "(self, old_to_new_map)")]
+    fn assign_tokens(&mut self, old_to_new_map: &Bound<'_, PyDict>) -> PyResult<()> {
+        use pyo3::exceptions::PyTypeError;
+
+        let mut processed_old_tokens = HashMap::with_capacity(old_to_new_map.len());
+        for (old, new) in old_to_new_map.iter() {
+            let old_token = if let Ok(content) = old.extract::<&str>() {
+                PyAddedToken::from(content.to_string(), Some(false)).get_token()
+            } else if let Ok(token) = old.extract::<PyRefMut<PyAddedToken>>() {
+                token.get_token()
+            } else {
+                return Err(PyTypeError::new_err(
+                    "keys of old_to_new_map must be Union[str, AddedToken]",
+                ));
+            };
+
+            let new_token = if let Ok(content) = new.extract::<&str>() {
+                let mut updated_token = old_token.clone();
+                updated_token.content = content.to_string();
+                updated_token
+            } else if let Ok(token) = new.extract::<PyRefMut<PyAddedToken>>() {
+                token.get_token()
+            } else {
+                return Err(PyTypeError::new_err(
+                    "values of old_to_new_map must be Union[str, AddedToken]",
+                ));
+            };
+
+            processed_old_tokens.insert(old_token, new_token);
+        }
+        self.tokenizer.assign_tokens(&processed_old_tokens);
+        Ok(())
+    }
+
     /// Add the given special tokens to the Tokenizer.
     ///
     /// If these tokens are already part of the vocabulary, it just let the Tokenizer know about
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 2118709a0..0fb2cbd7d 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -562,6 +562,26 @@ def test_setting_to_none(self):
         tokenizer.pre_tokenizer = None
         assert tokenizer.pre_tokenizer == None
 
+    def test_re_assign_tokens_bpe(self):
+        tokenizer = Tokenizer.from_pretrained("gpt2")
+        tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"})
+        assert tokenizer.decode([50256]) == "my_new_token"
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
+        assert tokenizer.encode("my_new_token").ids == [50256]
+        assert tokenizer.encode("<|endoftext|>").ids == [27, 91, 437, 1659, 5239, 91, 29]
+        assert tokenizer.encode("<|endoftext|>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
+        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}
+
+    def test_re_assign_tokens_unigram(self):
+        tokenizer = Tokenizer.from_pretrained("t5-base")
+        tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})
+        assert tokenizer.decode([32099]) == "my_new_token"
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
+        assert tokenizer.encode("my_new_token").ids == [32099]
+        assert tokenizer.encode("<extra_id_0>").ids == [27, 91, 437, 1659, 5239, 91, 29]
+        assert tokenizer.encode("<extra_id_0>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
+        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}
+
 
 class TestTokenizerRepr:
     def test_repr(self):
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index dba5a0400..c604b11c6 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -3,8 +3,11 @@ use super::{
     trainer::UnigramTrainer,
     trie::{Trie, TrieBuilder},
 };
-use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::Cache;
+use crate::{
+    tokenizer::{Model, Result, Token},
+    AddedVocabulary,
+};
 
 use std::collections::HashMap;
 use std::convert::TryInto;
@@ -81,7 +84,7 @@ pub enum UnigramError {
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(vocab, Some(0), false).unwrap()
+        Self::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap()
     }
 }
 
@@ -96,6 +99,7 @@ impl Unigram {
         vocab: Vec<(String, f64)>,
         unk_id: Option<usize>,
         byte_fallback: bool,
+        added_tokens: &AddedVocabulary,
     ) -> Result<Self> {
         let n = vocab.len();
         let mut token_to_ids: TokenMap = HashMap::new();
@@ -114,11 +118,13 @@ impl Unigram {
 
         let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocab.iter().enumerate() {
-            token_to_ids.insert(token.to_string(), id as u32);
-            let bytes: Vec<u8> = token.bytes().collect();
-            builder.push(&bytes);
-            if score < &min_score {
-                min_score = *score;
+            if !added_tokens.is_special_token(token) {
+                token_to_ids.insert(token.to_string(), id as u32);
+                let bytes: Vec<u8> = token.bytes().collect();
+                builder.push(&bytes);
+                if score < &min_score {
+                    min_score = *score;
+                }
             }
         }
         let trie = builder.build();
@@ -480,7 +486,7 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(pieces, Some(0), false).unwrap();
+        let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -505,7 +511,7 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(pieces, Some(0), false).unwrap();
+        let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -542,7 +548,8 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];
 
-        let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
+        let model =
+            Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();
         let result = model.encode("abcd").unwrap();
         assert_eq!(result, vec!["abcd"]);
     }
@@ -564,7 +571,8 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];
 
-        let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();
+        let mut model =
+            Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         for is_optimized in &[true, false] {
             model.set_optimized(*is_optimized);
@@ -611,7 +619,8 @@ mod tests {
             ("<0xC3>".to_string(), -0.01),
             ("<0xA9>".to_string(), -0.03),
         ];
-        let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
+        let unigram =
+            Unigram::from(sentencepieces, Some(0), true, &AddedVocabulary::default()).unwrap();
         let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
         assert_eq!(
             tokens,
diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs
index a6e56b735..f0ff30694 100644
--- a/tokenizers/src/models/unigram/serialization.rs
+++ b/tokenizers/src/models/unigram/serialization.rs
@@ -1,3 +1,5 @@
+use crate::AddedVocabulary;
+
 use super::model::Unigram;
 use serde::{
     de::{Error, MapAccess, Visitor},
@@ -69,8 +71,12 @@ impl<'de> Visitor<'de> for UnigramVisitor {
             }
         }
         match (vocab, unk_id, byte_fallback) {
-            (Some(vocab), unk_id, byte_fallback) => Ok(Unigram::from(vocab, unk_id, byte_fallback)
-                .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?),
+            (Some(vocab), unk_id, byte_fallback) => {
+                Ok(
+                    Unigram::from(vocab, unk_id, byte_fallback, &AddedVocabulary::default())
+                        .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?,
+                )
+            }
             (None, _, _) => Err(Error::custom("Missing vocab")),
         }
     }
@@ -78,12 +84,14 @@ impl<'de> Visitor<'de> for UnigramVisitor {
 
 #[cfg(test)]
 mod test {
+    use crate::AddedVocabulary;
+
     use super::*;
 
     #[test]
     fn test_serialization() {
         let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, Some(0), false).unwrap();
+        let model = Unigram::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -94,7 +102,7 @@ mod test {
     #[test]
     fn test_serialization_unk_id_not_zero() {
         let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(vocab, Some(1), false).unwrap();
+        let model = Unigram::from(vocab, Some(1), false, &AddedVocabulary::default()).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -105,7 +113,7 @@ mod test {
     #[test]
     fn test_serialization_no_unk_id() {
         let vocab = vec![("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, None, false).unwrap();
+        let model = Unigram::from(vocab, None, false, &AddedVocabulary::default()).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 5d178e77b..b3e816a59 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -2,6 +2,7 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram};
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use crate::AddedVocabulary;
 use log::debug;
 use serde::{Deserialize, Serialize};
 use std::cmp::Reverse;
@@ -182,6 +183,7 @@ impl UnigramTrainer {
             special_tokens.into_iter().chain(pieces).collect(),
             unk_id,
             model.byte_fallback(),
+            &AddedVocabulary::default(),
         )
     }
 
@@ -567,7 +569,8 @@ impl UnigramTrainer {
         if required_chars.len() as u32 > self.vocab_size {
             return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
         }
-        let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?;
+        let mut new_model =
+            Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {
@@ -576,7 +579,8 @@ impl UnigramTrainer {
 
                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                new_model = Unigram::from(pieces.clone(), Some(0), false)?;
+                new_model =
+                    Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
 
                 // Useful comment for checking compatibility with spm
                 debug!(
@@ -600,7 +604,7 @@ impl UnigramTrainer {
 
             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences);
-            new_model = Unigram::from(pieces.clone(), Some(0), false)?;
+            new_model = Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
         }
 
         self.finalize_progress(&progress, expected_updates);
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index a0c2f4542..f91a7b82f 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -4,8 +4,10 @@ use super::{
 use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
 use regex::Regex;
 use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
-use std::collections::{HashMap, HashSet};
-
+use std::{
+    collections::HashMap,
+    sync::{Arc, Mutex},
+};
 /// Represent a token added by the user on top of the existing Model vocabulary.
 /// AddedToken can be configured to specify the behavior they should have in various situations
 /// like:
@@ -142,19 +144,10 @@ fn space_rightmost_at_start(sentence: &str) -> usize {
 pub struct AddedVocabulary {
     /// Contains the mapping from String (token content) to ID. This map contains both special
     /// tokens and classic added tokens that were added to the this vocabulary.
-    added_tokens_map: HashMap<String, u32>,
+    added_tokens_map: Arc<Mutex<HashMap<String, u32>>>,
     /// Contains the mapping from ID to AddedToken for all the added tokens, both special
    /// and classic.
-    added_tokens_map_r: HashMap<u32, AddedToken>,
-
-    /// Contains only the classic AddedToken, in the specific order the user gave them.
-    added_tokens: Vec<AddedToken>,
-    /// Contains only the special AddedToken, in the specific order the user gave them.
-    special_tokens: Vec<AddedToken>,
-
-    /// A Set, containing all the special token for easy access while decoding. This let's
-    /// us remove them easily with an O(1) complexity.
-    special_tokens_set: HashSet<String>,
+    added_tokens_map_r: Arc<Mutex<HashMap<u32, AddedToken>>>,
 
     /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
     split_trie: MatchingSet,
@@ -176,11 +169,8 @@ impl AddedVocabulary {
             .build::<_, &&[u8]>([])
             .expect("The normalized trie should build correctly");
         Self {
-            added_tokens_map: HashMap::new(),
-            added_tokens_map_r: HashMap::new(),
-            added_tokens: vec![],
-            special_tokens: vec![],
-            special_tokens_set: HashSet::new(),
+            added_tokens_map: Arc::new(Mutex::new(HashMap::new())),
+            added_tokens_map_r: Arc::new(Mutex::new(HashMap::new())),
             split_trie: (trie, vec![]),
             split_normalized_trie: (normalized_trie, vec![]),
             encode_special_tokens: false,
@@ -189,46 +179,29 @@ impl AddedVocabulary {
         }
     }
 
     /// Size of the additional vocabulary
     #[allow(dead_code)] // Suppress the "method is never used" warning
     pub fn len(&self) -> usize {
-        self.added_tokens_map.len()
+        self.added_tokens_map.lock().unwrap().len()
     }
 
     /// Whether or not this vocabulary is empty
     pub fn is_empty(&self) -> bool {
-        self.added_tokens_map.is_empty()
+        self.added_tokens_map.lock().unwrap().is_empty()
     }
 
     /// Get the additional vocabulary
-    pub fn get_vocab(&self) -> &HashMap<String, u32> {
-        &self.added_tokens_map
+    pub fn get_vocab(&self) -> HashMap<String, u32> {
+        self.added_tokens_map.lock().unwrap().clone()
     }
 
     /// Get the additional vocabulary with the AddedTokens
-    pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
-        &self.added_tokens_map_r
+    pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
+        self.added_tokens_map_r.lock().unwrap().clone()
    }
 
     /// Get the id matching one of our token if it exists
     pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
-        self.added_tokens_map
-            .get(token)
-            .copied()
-            .or_else(|| model.token_to_id(token))
-    }
-
-    /// Get the token matching the given id if it exists
-    #[deprecated(
-        since = "0.19.0",
-        note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead"
-    )]
-    pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
-        self.added_tokens_map_r
-            .get(&id)
-            .map(|t| t.content.clone())
-            .or_else(|| model.id_to_token(id))
-    }
-
-    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
-        self.added_tokens_map_r.get(&id).map(|t| t.content.clone())
+        let added_tokens_map = self.added_tokens_map.lock().unwrap();
+        let id = added_tokens_map.get(token).copied();
+        id.or_else(|| model.token_to_id(token))
     }
 
     //
@@ -242,7 +215,14 @@ impl AddedVocabulary {
 
     /// Check if a token is a special token
     pub fn is_special_token(&self, token: &str) -> bool {
-        self.special_tokens_set.contains(token)
+        let hash_map = &self.added_tokens_map_r.lock().unwrap();
+        let revert_hash_map = &self.added_tokens_map.lock().unwrap();
+        if let Some(id) = revert_hash_map.get(token) {
+            if let Some(token) = hash_map.get(id) {
+                return token.special;
+            }
+        }
+        false
     }
 
     /// Add some special tokens to the vocabulary
@@ -263,20 +243,17 @@ impl AddedVocabulary {
         normalizer: Option<&N>,
     ) -> usize {
         // Handle special tokens (if any)
-        for token in tokens {
-            if token.special
-                && !token.content.is_empty()
-                && !self.special_tokens_set.contains(&token.content)
-            {
-                self.special_tokens.push(token.to_owned());
-                self.special_tokens_set.insert(token.content.clone());
-            }
-        }
 
         // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
         let mut ignored = 0;
         for token in tokens {
-            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
+            if token.content.is_empty()
+                || self
+                    .added_tokens_map_r
+                    .lock()
+                    .unwrap()
+                    .values()
+                    .any(|val| val == token)
             {
                 ignored += 1;
                 continue;
            }
@@ -285,33 +262,35 @@ impl AddedVocabulary {
             let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
                 new_id
             } else {
-                self.added_tokens_map.values().cloned().max().map_or(
-                    model.get_vocab_size() as u32,
-                    |max| {
+                self.added_tokens_map
+                    .lock()
+                    .unwrap()
+                    .values()
+                    .cloned()
+                    .max()
+                    .map_or(model.get_vocab_size() as u32, |max| {
                         if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
                             max + 1
                         } else {
                             model.get_vocab_size() as u32
                         }
-                    },
-                )
+                    })
             };
             // Make sure we modify the previous entry
             self.added_tokens_map
+                .lock()
+                .unwrap()
                 .entry(token.content.clone())
                 .and_modify(|old_id| *old_id = new_id)
                 .or_insert_with(|| new_id);
             // Update the current revert operation
             self.added_tokens_map_r
+                .lock()
+                .unwrap()
                 .entry(new_id)
                 .and_modify(|t| *t = token.clone())
                 .or_insert_with(|| token.clone());
             // Make sure to remove previous entry (if the token gets a new id)
-
-            // Finally add the token to the classic set if special
-            if !self.special_tokens_set.contains(&token.content) {
-                self.added_tokens.push(token.clone());
-            }
         }
 
         self.refresh_added_tokens(model, normalizer);
@@ -320,24 +299,53 @@ impl AddedVocabulary {
         tokens.len() - ignored
     }
 
+    /// Get the token matching the given id if it exists
+    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
+        let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
+        let token = added_tokens_map_r.get(&id).map(|t| t.content.clone());
+        token
+    }
+
+    /// Reassigns a token's content. This helps users who want to repurpose
+    /// reserved tokens (which usually exist both in the model vocab and in the added vocab)
+    pub fn assign_tokens<N: Normalizer>(
+        &mut self,
+        token_map: &HashMap<AddedToken, AddedToken>, // HashMap of old token to new token
+        model: &impl Model,
+        normalizer: Option<&N>,
+    ) {
+        for (old_token, new_token) in token_map.iter() {
+            if let Some(id) = self.token_to_id(old_token.content.as_str(), model) {
+                self.added_tokens_map_r
+                    .lock()
+                    .unwrap()
+                    .entry(id)
+                    .and_modify(|t| *t = new_token.clone()); // Replace the entire entry with new_token
+                self.added_tokens_map
+                    .lock()
+                    .unwrap()
+                    .remove(old_token.content.as_str());
+                self.refresh_added_tokens(model, normalizer);
+            } else {
+                error!(
+                    "Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab",
+                    old_token.content.clone()
+                )
+            }
+        }
+    }
+
     /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
     ///
     /// We keep two different RegexSet, one that will take care of matching against the
     /// non-normalized string, and one matching against the normalized one.
-    fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
+    fn refresh_added_tokens<N: Normalizer>(&mut self, _model: &impl Model, normalizer: Option<&N>) {
         type TupleTokenId<'a> = (&'a AddedToken, u32);
-        let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
-            .special_tokens
-            .iter()
-            .chain(self.added_tokens.iter())
-            .map(|token| {
-                (
-                    token,
-                    self.token_to_id(&token.content, model)
-                        .expect("Missing additional token"),
-                )
-            })
-            .partition(|(token, _)| token.normalized);
+        let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap().clone();
+        let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) =
+            added_tokens_map_r
+                .iter()
+                .map(|(id, token)| (token, *id))
+                .partition(|(token, _)| token.normalized);
 
         let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
         let trie = AhoCorasickBuilder::new()
@@ -381,10 +389,9 @@ impl AddedVocabulary {
             let mut stop = mat.end();
             let aho_id = mat.pattern();
             let id = split_re.1[aho_id];
-            let added_token = &self.added_tokens_map_r.get(&id).unwrap();
-
-            if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
-            {
+            let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
+            let added_token = added_tokens_map_r.get(&id).unwrap();
+            if self.encode_special_tokens && added_token.special {
                 continue;
             }
 
@@ -522,6 +529,8 @@ impl Serialize for AddedVocabulary {
     {
         let mut added_tokens = self
             .added_tokens_map_r
+            .lock()
+            .unwrap()
             .iter()
             .map(|(id, token)| AddedTokenWithId {
                 id: *id,
@@ -715,15 +724,15 @@ mod tests {
         assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
         assert_eq!(
-            *vocab.get_added_tokens_decoder(),
+            vocab.get_added_tokens_decoder(),
             HashMap::from([
                 (0, AddedToken::from("test", true)),
                 (2, AddedToken::from("added_token_1", true)),
                 (3, AddedToken::from("added_token_2", true)),
             ])
         );
-        assert!(vocab.added_tokens_map.contains_key("test"));
-        assert!(vocab.added_tokens_map_r.contains_key(&0));
+        assert!(vocab.added_tokens_map.lock().unwrap().contains_key("test"));
+        assert!(vocab.added_tokens_map_r.lock().unwrap().contains_key(&0));
 
         vocab.add_tokens(
             &[
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 49bc539a2..c24654fc8 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -541,9 +541,7 @@ where
             model,
             post_processor: None,
             decoder: None,
-
             added_vocabulary: AddedVocabulary::new(),
-
             truncation: None,
             padding: None,
         }
@@ -667,7 +665,7 @@ where
         if !added_vocab.is_empty() {
             final_vocab.reserve(added_vocab.len());
             for (token, id) in added_vocab {
-                final_vocab.insert(token.clone(), *id);
+                final_vocab.insert(token.clone(), id);
             }
         }
     }
@@ -960,6 +958,15 @@ where
         self.added_vocabulary
             .add_tokens(tokens, &self.model, self.normalizer.as_ref())
     }
+
+    /// Reassign the content of existing added tokens, keeping their ids
+    pub fn assign_tokens(&mut self, old_to_new_map: &HashMap<AddedToken, AddedToken>) {
+        self.added_vocabulary.assign_tokens(
+            old_to_new_map, // HashMap of old token to new token
+            &self.model,
+            self.normalizer.as_ref(),
+        )
+    }
 }
 
 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
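A behavioural note on the Rust implementation above: `AddedVocabulary::assign_tokens` only rewrites tokens it can resolve to an id; for a token found in neither the added vocabulary nor the model it logs an error and leaves both maps untouched. A minimal sketch of that expectation from the Python side (the token string below is purely illustrative and not part of the patch):

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("gpt2")
    before = {t.content for t in tokenizer.get_added_tokens_decoder().values()}
    # "never_added_token" was never registered, so this should be a no-op that
    # only emits a log message (the `error!` branch in `assign_tokens`).
    tokenizer.assign_tokens({"never_added_token": "whatever"})
    after = {t.content for t in tokenizer.get_added_tokens_decoder().values()}
    assert before == after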