From fc0f0656f07870a56c3352aa1a3b19ac14be26e7 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 08:08:14 +0200
Subject: [PATCH 01/16] allow assigning a new token

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 21 ++++++++++++++++++++
 tokenizers/src/tokenizer/mod.rs              | 10 ++++++++++
 2 files changed, 31 insertions(+)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index a0c2f4542..32fc76a26 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -255,6 +255,27 @@ impl AddedVocabulary {
         self.add_tokens(tokens, model, normalizer)
     }
 
+    /// Reassigns a token's content to new content. This helps users who want to
+    /// use reserved tokens (which usually are in the original vocab, and in the added vocab)
+    pub fn assign_token<N: Normalizer>(
+        &mut self,
+        old_token_content: &[AddedToken],
+        new_token_content: &[AddedToken],
+        model: &impl Model,
+        normalizer: Option<&N>,
+    ) {
+        for (old, new) in old_token_content.iter().zip(new_token_content.iter()) {
+            if let Some(id) = self.token_to_id(old.content.as_str(), model) {
+                self.added_tokens_map_r
+                    .entry(id)
+                    .and_modify(|t| *t = new.clone());
+                self.refresh_added_tokens(model, normalizer);
+            } else {
+                error!("Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", old.content.clone())
+            }
+        }
+    }
+
     /// Add some tokens to the vocabulary
     pub fn add_tokens<N: Normalizer>(
         &mut self,
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 49bc539a2..1d2ee3995 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -960,6 +960,16 @@ where
         self.added_vocabulary
             .add_tokens(tokens, &self.model, self.normalizer.as_ref())
     }
+
+    /// Reassign existing added tokens to new content
+    pub fn assign_token(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) {
+        self.added_vocabulary.assign_token(
+            old_tokens,
+            new_tokens,
+            &self.model,
+            self.normalizer.as_ref(),
+        )
+    }
 }
 
 impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>

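The mechanic introduced by this first patch boils down to: look up the id currently mapped to the old content, then overwrite the id-to-token entry so that the same id now carries the new content. Below is a minimal std-only sketch of that idea, using plain HashMaps and hypothetical names rather than the crate's actual types; the real method additionally rebuilds its matching tries via refresh_added_tokens.

    use std::collections::HashMap;

    // Simplified stand-ins for `added_tokens_map` (content -> id) and
    // `added_tokens_map_r` (id -> content).
    fn reassign(
        token_to_id: &HashMap<String, u32>,
        id_to_token: &mut HashMap<u32, String>,
        old_content: &str,
        new_content: &str,
    ) -> bool {
        match token_to_id.get(old_content) {
            Some(&id) => {
                // Keep the id, swap the content it decodes to.
                id_to_token.insert(id, new_content.to_string());
                true
            }
            // The real method logs an error and continues instead of returning a flag.
            None => false,
        }
    }

    fn main() {
        let token_to_id = HashMap::from([("<extra_id_0>".to_string(), 32099u32)]);
        let mut id_to_token = HashMap::from([(32099u32, "<extra_id_0>".to_string())]);
        assert!(reassign(&token_to_id, &mut id_to_token, "<extra_id_0>", "my_new_token"));
        assert_eq!(id_to_token[&32099], "my_new_token");
    }
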
From 97e8818ecf43b418cbd97d4ac08762a154dfe2d1 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 08:52:04 +0200
Subject: [PATCH 02/16] add python bindings as well

---
 .../python/py_src/tokenizers/__init__.pyi     | 17 ++++++
 bindings/python/src/tokenizer.rs              | 58 +++++++++++++++++++
 tokenizers/src/tokenizer/added_vocabulary.rs  |  2 +-
 tokenizers/src/tokenizer/mod.rs               |  4 +-
 4 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index 0ad96fc8a..a480923ef 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -725,6 +725,23 @@ class Tokenizer:
         """
         pass
 
+    def assing_tokens(self, old_tokens, new_tokens):
+        """
+        Reassign the content of existing added tokens
+
+        Each token in ``old_tokens`` must already exist in the added vocabulary; its content is
+        replaced by the corresponding entry of ``new_tokens`` while keeping the same id.
+
+        Args:
+            old_tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`):
+                The tokens whose content should be replaced.
+            new_tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`):
+                The replacement tokens, in the same order as ``old_tokens``.
+        Returns:
+            :obj:`None`
+        """
+        pass
+
     def decode(self, ids, skip_special_tokens=True):
         """
         Decode the given list of ids back to a string
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 24a68c6bb..34852c129 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1237,6 +1237,64 @@ impl PyTokenizer {
         Ok(self.tokenizer.add_tokens(&tokens))
     }
 
+    /// Reassign the content of existing added tokens
+    ///
+    /// Each token in ``old_tokens`` must already exist in the added vocabulary; its content is
+    /// replaced by the corresponding entry of ``new_tokens`` while keeping the same id.
+    ///
+    /// Args:
+    ///     old_tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`):
+    ///         The tokens whose content should be replaced.
+    ///     new_tokens (A :obj:`List` of :class:`~tokenizers.AddedToken` or :obj:`str`):
+    ///         The replacement tokens, in the same order as ``old_tokens``.
+    /// Returns:
+    ///     :obj:`None`
+    #[pyo3(text_signature = "(self, old_tokens, new_tokens)")]
+    fn assing_tokens(
+        &mut self,
+        old_tokens: &Bound<'_, PyList>,
+        new_tokens: &Bound<'_, PyList>,
+    ) -> PyResult<()> {
+        use pyo3::exceptions::PyTypeError;
+        if old_tokens.len() != new_tokens.len() {
+            return Err(PyTypeError::new_err(
+                "old_tokens and new_tokens must have the same length",
+            ));
+        }
+
+        let mut processed_old_tokens = Vec::with_capacity(old_tokens.len());
+        let mut processed_new_tokens = Vec::with_capacity(new_tokens.len());
+        for (old, new) in old_tokens.iter().zip(new_tokens.iter()) {
+            let old_token = if let Ok(content) = old.extract::<&str>() {
+                PyAddedToken::from(content.to_string(), Some(false)).get_token()
+            } else if let Ok(token) = old.extract::<PyRefMut<PyAddedToken>>() {
+                token.get_token()
+            } else {
+                return Err(PyTypeError::new_err(
+                    "old_tokens must be a List[Union[str, AddedToken]]",
+                ));
+            };
+
+            let new_token = if let Ok(content) = new.extract::<&str>() {
+                let mut updated_token = old_token.clone();
+                updated_token.content = content.to_string();
+                updated_token
+            } else if let Ok(token) = new.extract::<PyRefMut<PyAddedToken>>() {
+                token.get_token()
+            } else {
+                return Err(PyTypeError::new_err(
+                    "new_tokens must be a List[Union[str, AddedToken]]",
+                ));
+            };
+
+            processed_old_tokens.push(old_token);
+            processed_new_tokens.push(new_token);
+        }
+
+        Ok(self
+            .tokenizer
+            .assign_tokens(&processed_old_tokens, &processed_new_tokens))
+    }
     /// Add the given special tokens to the Tokenizer.
     ///
     /// If these tokens are already part of the vocabulary, it just let the Tokenizer know about
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 32fc76a26..379a3075b 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -257,7 +257,7 @@ impl AddedVocabulary {
 
     /// Reassigns a token's content to new content. This helps users who want to
     /// use reserved tokens (which usually are in the original vocab, and in the added vocab)
-    pub fn assign_token<N: Normalizer>(
+    pub fn assign_tokens<N: Normalizer>(
         &mut self,
         old_token_content: &[AddedToken],
         new_token_content: &[AddedToken],
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 1d2ee3995..c3cc6ede9 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -962,8 +962,8 @@ where
     }
 
     /// Reassign existing added tokens to new content
-    pub fn assign_token(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) {
-        self.added_vocabulary.assign_token(
+    pub fn assign_tokens(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) {
+        self.added_vocabulary.assign_tokens(
             old_tokens,
             new_tokens,
             &self.model,

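One subtle behavior of the binding added in this patch: when the replacement is passed as a plain string rather than an AddedToken, the old token is cloned and only its content is swapped, so flags such as special, lstrip and rstrip carry over to the replacement. A small sketch with a hypothetical, simplified token type:

    // Hypothetical simplified token; the real AddedToken carries more flags.
    #[derive(Clone, Debug, PartialEq)]
    struct Tok {
        content: String,
        special: bool,
        lstrip: bool,
    }

    // Mirrors the `&str` branch of the binding: clone the old token, replace the content only.
    fn replacement_from_str(old: &Tok, new_content: &str) -> Tok {
        let mut updated = old.clone();
        updated.content = new_content.to_string();
        updated
    }

    fn main() {
        let old = Tok { content: "<extra_id_0>".into(), special: true, lstrip: false };
        let new = replacement_from_str(&old, "my_new_token");
        assert_eq!(new.content, "my_new_token");
        assert!(new.special); // the `special` flag is preserved
    }
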
From ddab9013382767ec70975e1ef525482c734e90cb Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 09:49:42 +0200
Subject: [PATCH 03/16] current update

---
 bindings/python/src/tokenizer.rs             |   2 +-
 tokenizers/src/tokenizer/added_vocabulary.rs | 135 +++++--------------
 2 files changed, 36 insertions(+), 101 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 34852c129..a2b191673 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1250,7 +1250,7 @@ impl PyTokenizer {
     /// Returns:
     ///     :obj:`None`
     #[pyo3(text_signature = "(self, old_tokens, new_tokens)")]
-    fn assing_tokens(
+    fn assign_tokens(
         &mut self,
         old_tokens: &Bound<'_, PyList>,
         new_tokens: &Bound<'_, PyList>,
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 379a3075b..dd8f61952 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -4,8 +4,10 @@ use super::{
 use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
 use regex::Regex;
 use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
-use std::collections::{HashMap, HashSet};
-
+use std::{
+    collections::{HashMap, HashSet},
+    sync::{Arc, Mutex},
+};
 /// Represent a token added by the user on top of the existing Model vocabulary.
 /// AddedToken can be configured to specify the behavior they should have in various situations
 /// like:
@@ -142,19 +144,12 @@ fn space_rightmost_at_start(sentence: &str) -> usize {
 pub struct AddedVocabulary {
     /// Contains the mapping from String (token content) to ID. This map contains both special
     /// tokens and classic added tokens that were added to this vocabulary.
-    added_tokens_map: HashMap<String, u32>,
+    added_tokens_map: Arc<Mutex<HashMap<String, u32>>>,
     /// Contains the mapping from ID to AddedToken for all the added tokens, both special
     /// and classic.
-    added_tokens_map_r: HashMap<u32, AddedToken>,
-
+    added_tokens_map_r: Arc<Mutex<HashMap<u32, AddedToken>>>,
     /// Contains only the classic AddedToken, in the specific order the user gave them.
     added_tokens: Vec<AddedToken>,
-    /// Contains only the special AddedToken, in the specific order the user gave them.
-    special_tokens: Vec<AddedToken>,
-
-    /// A Set, containing all the special token for easy access while decoding. This let's
-    /// us remove them easily with an O(1) complexity.
-    special_tokens_set: HashSet<String>,
 
     /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
     split_trie: MatchingSet,
@@ -176,11 +171,9 @@ impl AddedVocabulary {
             .build::<_, &&[u8]>([])
             .expect("The normalized trie should build correctly");
         Self {
-            added_tokens_map: HashMap::new(),
-            added_tokens_map_r: HashMap::new(),
+            added_tokens_map: Arc::new(Mutex::new(HashMap::new())),
+            added_tokens_map_r: Arc::new(Mutex::new(HashMap::new())),
             added_tokens: vec![],
-            special_tokens: vec![],
-            special_tokens_set: HashSet::new(),
             split_trie: (trie, vec![]),
             split_normalized_trie: (normalized_trie, vec![]),
             encode_special_tokens: false,
@@ -189,30 +182,29 @@ impl AddedVocabulary {
     /// Size of the additional vocabulary
     #[allow(dead_code)] // Suppress the "method is never used" warning
     pub fn len(&self) -> usize {
-        self.added_tokens_map.len()
+        self.added_tokens_map.lock().unwrap().len()
     }
 
     /// Whether or not this vocabulary is empty
     pub fn is_empty(&self) -> bool {
-        self.added_tokens_map.is_empty()
+        self.added_tokens_map.lock().unwrap().is_empty()
     }
 
     /// Get the additional vocabulary
-    pub fn get_vocab(&self) -> &HashMap<String, u32> {
-        &self.added_tokens_map
+    pub fn get_vocab(&self) -> HashMap<String, u32> {
+        self.added_tokens_map.lock().unwrap().clone()
     }
 
     /// Get the additional vocabulary with the AddedTokens
-    pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
-        &self.added_tokens_map_r
+    pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
+        self.added_tokens_map_r.lock().unwrap().clone()
     }
 
     /// Get the id matching one of our token if it exists
     pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
-        self.added_tokens_map
-            .get(token)
-            .copied()
-            .or_else(|| model.token_to_id(token))
+        let added_tokens_map = self.added_tokens_map.lock().unwrap();
+        let id = added_tokens_map.get(token).copied();
+        id.or_else(|| model.token_to_id(token))
     }
 
     /// Get the token matching the given id if it exists
@@ -220,15 +212,6 @@ impl AddedVocabulary {
         since = "0.19.0",
         note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead"
     )]
-    pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
-        self.added_tokens_map_r
-            .get(&id)
-            .map(|t| t.content.clone())
-            .or_else(|| model.id_to_token(id))
-    }
-
-    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
-        self.added_tokens_map_r.get(&id).map(|t| t.content.clone())
     }
 
     //
@@ -253,6 +236,11 @@ impl AddedVocabulary {
         normalizer: Option<&N>,
     ) -> usize {
         self.add_tokens(tokens, model, normalizer)
+    /// Get the token matching the given id if it exists
+    pub fn simple_id_to_token(&self, id: &u32) -> Option<String> {
+        let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
+        let token = added_tokens_map_r.get(id).map(|t| t.content.clone());
+        token
     }
 
     /// Reassigns a token's content to new content. This helps users who want to
@@ -276,69 +264,16 @@ impl AddedVocabulary {
         }
     }
 
-    /// Add some tokens to the vocabulary
-    pub fn add_tokens<N: Normalizer>(
-        &mut self,
-        tokens: &[AddedToken],
-        model: &impl Model,
-        normalizer: Option<&N>,
-    ) -> usize {
-        // Handle special tokens (if any)
-        for token in tokens {
-            if token.special
-                && !token.content.is_empty()
-                && !self.special_tokens_set.contains(&token.content)
-            {
-                self.special_tokens.push(token.to_owned());
-                self.special_tokens_set.insert(token.content.clone());
-            }
-        }
-
-        // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
-        let mut ignored = 0;
-        for token in tokens {
-            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
-            {
-                ignored += 1;
-                continue;
-            }
-            // If a token is already part of the vocabulary, we mark it as added
-            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
-                new_id
-            } else {
-                self.added_tokens_map.values().cloned().max().map_or(
-                    model.get_vocab_size() as u32,
-                    |max| {
-                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
-                            max + 1
-                        } else {
-                            model.get_vocab_size() as u32
-                        }
-                    },
-                )
-            };
-            // Make sure we modify the previous entry
-            self.added_tokens_map
-                .entry(token.content.clone())
-                .and_modify(|old_id| *old_id = new_id)
-                .or_insert_with(|| new_id);
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(new_id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
-            // Make sure to remove previous entry (if the token gets a new id)
-
-            // Finally add the token to the classic set if special
-            if !self.special_tokens_set.contains(&token.content) {
-                self.added_tokens.push(token.clone());
-            }
-        }
+    /// Add a token to the added vocabulary
+    pub fn add_token(&mut self, token: &AddedToken) {
+        let mut added_tokens_map = self.added_tokens_map.lock().unwrap();
+        let mut added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
 
-        self.refresh_added_tokens(model, normalizer);
+        let id = added_tokens_map.len() as u32;
+        added_tokens_map.insert(token.content.clone(), id);
+        added_tokens_map_r.insert(id, token.clone());
 
-        // Return the number of added tokens
-        tokens.len() - ignored
+        self.refresh_added_tokens();
     }
 
     /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
@@ -348,9 +283,8 @@ impl AddedVocabulary {
     fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
         type TupleTokenId<'a> = (&'a AddedToken, u32);
         let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
-            .special_tokens
+            .added_tokens
             .iter()
-            .chain(self.added_tokens.iter())
             .map(|token| {
                 (
                     token,
@@ -402,10 +336,9 @@ impl AddedVocabulary {
             let mut stop = mat.end();
             let aho_id = mat.pattern();
             let id = split_re.1[aho_id];
-            let added_token = &self.added_tokens_map_r.get(&id).unwrap();
+            let added_token = self.added_tokens_map_r.lock().unwrap().get(&id).unwrap();
 
-            if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
-            {
+            if self.encode_special_tokens && added_token.special {
                 continue;
             }
 
@@ -543,6 +476,8 @@ impl Serialize for AddedVocabulary {
     {
         let mut added_tokens = self
             .added_tokens_map_r
+            .lock()
+            .unwrap()
             .iter()
             .map(|(id, token)| AddedTokenWithId {
                 id: *id,

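With this patch the two added-token maps live behind Arc<Mutex<...>>, so accessors lock the map and return owned clones instead of references, and updates go through the same lock. A small std-only sketch of the pattern, with hypothetical names rather than the crate's types:

    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    struct Vocab {
        // Shared, internally mutable map, like the reworked `added_tokens_map`.
        map: Arc<Mutex<HashMap<String, u32>>>,
    }

    impl Vocab {
        // Callers receive an owned snapshot; the lock guard is dropped on return.
        fn get_vocab(&self) -> HashMap<String, u32> {
            self.map.lock().unwrap().clone()
        }

        // Mutation goes through the same lock, so a shared reference is enough.
        fn insert(&self, content: &str, id: u32) {
            self.map.lock().unwrap().insert(content.to_string(), id);
        }
    }

    fn main() {
        let vocab = Vocab { map: Arc::new(Mutex::new(HashMap::new())) };
        vocab.insert("my_new_token", 32099);
        assert_eq!(vocab.get_vocab().get("my_new_token"), Some(&32099));
    }

The visible trade-off in the diff is that get_vocab and get_added_tokens_decoder now return clones rather than &HashMap, and every lookup pays for a lock.
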
From b359bde47a49b95b3262c1e8e9a646aec2c77824 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 09:58:06 +0200
Subject: [PATCH 04/16] nit

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 29 ++++++--------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index dd8f61952..4d52ce77c 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -207,13 +207,6 @@ impl AddedVocabulary {
         id.or_else(|| model.token_to_id(token))
     }
 
-    /// Get the token matching the given id if it exists
-    #[deprecated(
-        since = "0.19.0",
-        note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id)` instead"
-    )]
-    }
-
     //
     pub fn set_encode_special_tokens(&mut self, value: bool) {
         self.encode_special_tokens = value;
@@ -225,7 +218,12 @@ impl AddedVocabulary {
 
     /// Check if a token is a special token
     pub fn is_special_token(&self, token: &str) -> bool {
-        self.special_tokens_set.contains(token)
+        self.added_tokens_map_r
+            .lock()
+            .unwrap()
+            .get(self.added_tokens_map.lock().unwrap().get(token).unwrap())
+            .unwrap()
+            .special
     }
 
     /// Add some special tokens to the vocabulary
@@ -236,6 +234,7 @@ impl AddedVocabulary {
         normalizer: Option<&N>,
     ) -> usize {
         self.add_tokens(tokens, model, normalizer)
+    }
     /// Get the token matching the given id if it exists
     pub fn simple_id_to_token(&self, id: &u32) -> Option<String> {
         let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
@@ -255,6 +254,8 @@ impl AddedVocabulary {
         for (old, new) in old_token_content.iter().zip(new_token_content.iter()) {
             if let Some(id) = self.token_to_id(old.content.as_str(), model) {
                 self.added_tokens_map_r
+                    .lock()
+                    .unwrap()
                     .entry(id)
                     .and_modify(|t| *t = new.clone());
                 self.refresh_added_tokens(model, normalizer);
@@ -264,18 +265,6 @@ impl AddedVocabulary {
         }
     }
 
-    /// Add a token to the added vocabulary
-    pub fn add_token(&mut self, token: &AddedToken) {
-        let mut added_tokens_map = self.added_tokens_map.lock().unwrap();
-        let mut added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
-
-        let id = added_tokens_map.len() as u32;
-        added_tokens_map.insert(token.content.clone(), id);
-        added_tokens_map_r.insert(id, token.clone());
-
-        self.refresh_added_tokens();
-    }
-
     /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
     ///
     /// We keep two different RegexSet, one that will take care of matching against the

From 4794ed516fb19fa36cc7533179b1e7c2e043c671 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 10:07:27 +0200
Subject: [PATCH 05/16] fix

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 69 +++++++++++++++++++-
 1 file changed, 67 insertions(+), 2 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 4d52ce77c..9274b69b6 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -235,6 +235,71 @@ impl AddedVocabulary {
     ) -> usize {
         self.add_tokens(tokens, model, normalizer)
     }
+
+    /// Add some tokens to the vocabulary
+    pub fn add_tokens<N: Normalizer>(
+        &mut self,
+        tokens: &[AddedToken],
+        model: &impl Model,
+        normalizer: Option<&N>,
+    ) -> usize {
+        // Special tokens are no longer tracked in a separate set; the `special` flag on
+        // each AddedToken is used instead.
+        // Add the tokens, assigning fresh ids where needed; the matching tries are refreshed at the end.
+        let mut ignored = 0;
+        for token in tokens {
+            if token.content.is_empty()
+                || self
+                    .added_tokens_map_r
+                    .lock()
+                    .unwrap()
+                    .values()
+                    .any(|val| val == token)
+            {
+                ignored += 1;
+                continue;
+            }
+            // If a token is already part of the vocabulary, we mark it as added
+            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
+                new_id
+            } else {
+                self.added_tokens_map
+                    .lock()
+                    .unwrap()
+                    .values()
+                    .cloned()
+                    .max()
+                    .map_or(model.get_vocab_size() as u32, |max| {
+                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
+                            max + 1
+                        } else {
+                            model.get_vocab_size() as u32
+                        }
+                    })
+            };
+            // Make sure we modify the previous entry
+            self.added_tokens_map
+                .lock()
+                .unwrap()
+                .entry(token.content.clone())
+                .and_modify(|old_id| *old_id = new_id)
+                .or_insert_with(|| new_id);
+            // Update the current revert operation
+            self.added_tokens_map_r
+                .lock()
+                .unwrap()
+                .entry(new_id)
+                .and_modify(|t| *t = token.clone())
+                .or_insert_with(|| token.clone());
+            // Make sure to remove previous entry (if the token gets a new id)
+        }
+
+        self.refresh_added_tokens(model, normalizer);
+
+        // Return the number of added tokens
+        tokens.len() - ignored
+    }
+
     /// Get the token matching the given id if it exists
     pub fn simple_id_to_token(&self, id: &u32) -> Option<String> {
         let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
@@ -325,8 +390,8 @@ impl AddedVocabulary {
             let mut stop = mat.end();
             let aho_id = mat.pattern();
             let id = split_re.1[aho_id];
-            let added_token = self.added_tokens_map_r.lock().unwrap().get(&id).unwrap();
-
+            let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
+            let added_token = added_tokens_map_r.get(&id).unwrap();
             if self.encode_special_tokens && added_token.special {
                 continue;
             }

From 4190db7dddfd16133780accb503171fb8c6e7447 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 10:12:29 +0200
Subject: [PATCH 06/16] pass compilation

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 6 +++---
 tokenizers/src/tokenizer/mod.rs              | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 9274b69b6..144f2c44e 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -5,7 +5,7 @@ use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
 use regex::Regex;
 use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
     sync::{Arc, Mutex},
 };
 /// Represent a token added by the user on top of the existing Model vocabulary.
@@ -301,9 +301,9 @@ impl AddedVocabulary {
     }
 
     /// Get the token matching the given id if it exists
-    pub fn simple_id_to_token(&self, id: &u32) -> Option<String> {
+    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
         let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap();
-        let token = added_tokens_map_r.get(id).map(|t| t.content.clone());
+        let token = added_tokens_map_r.get(&id).map(|t| t.content.clone());
         token
     }
 
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index c3cc6ede9..aee256a42 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -667,7 +667,7 @@ where
             if !added_vocab.is_empty() {
                 final_vocab.reserve(added_vocab.len());
                 for (token, id) in added_vocab {
-                    final_vocab.insert(token.clone(), *id);
+                    final_vocab.insert(token.clone(), id);
                 }
             }
         }

From 2d4b3735e44ac242956adc73bd5ed0c69338f407 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 10:38:40 +0200
Subject: [PATCH 07/16] fix everything

---
 bindings/python/src/tokenizer.rs             |  1 -
 tokenizers/src/tokenizer/added_vocabulary.rs | 20 ++++++++------------
 tokenizers/src/tokenizer/mod.rs              |  2 --
 3 files changed, 8 insertions(+), 15 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index a2b191673..9b0b82dcf 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1290,7 +1290,6 @@ impl PyTokenizer {
             processed_old_tokens.push(old_token);
             processed_new_tokens.push(new_token);
         }
-
         Ok(self
             .tokenizer
             .assign_tokens(&processed_old_tokens, &processed_new_tokens))
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 144f2c44e..84f4927a7 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -322,7 +322,7 @@ impl AddedVocabulary {
                     .lock()
                     .unwrap()
                     .entry(id)
-                    .and_modify(|t| *t = new.clone());
+                    .and_modify(|t| t.content = new.content.clone());
                 self.refresh_added_tokens(model, normalizer);
             } else {
                 error!("Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", old.content.clone())
@@ -336,17 +336,12 @@ impl AddedVocabulary {
     /// non-normalized string, and one matching against the normalized one.
     fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
         type TupleTokenId<'a> = (&'a AddedToken, u32);
-        let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
-            .added_tokens
-            .iter()
-            .map(|token| {
-                (
-                    token,
-                    self.token_to_id(&token.content, model)
-                        .expect("Missing additional token"),
-                )
-            })
-            .partition(|(token, _)| token.normalized);
+        let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap().clone();
+        let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) =
+            added_tokens_map_r
+                .iter()
+                .map(|(id, token)| (token, *id))
+                .partition(|(token, _)| token.normalized);
 
         let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
         let trie = AhoCorasickBuilder::new()
@@ -363,6 +358,7 @@ impl AddedVocabulary {
                 if let Some(n) = normalizer {
                     n.normalize(&mut content).unwrap();
                 }
+                println!("{:?}", token);
                 content
             })
             .collect();
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index aee256a42..c6433dc43 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -541,9 +541,7 @@ where
             model,
             post_processor: None,
             decoder: None,
-
             added_vocabulary: AddedVocabulary::new(),
-
             truncation: None,
             padding: None,
         }

From 6d48e58219cd414a99888bfbf8e28630234654ea Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 12 Jul 2024 10:47:38 +0200
Subject: [PATCH 08/16] remove print

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 84f4927a7..d3ca1a484 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -358,7 +358,6 @@ impl AddedVocabulary {
                 if let Some(n) = normalizer {
                     n.normalize(&mut content).unwrap();
                 }
-                println!("{:?}", token);
                 content
             })
             .collect();

From b5640a65cf59cf6c4ac2458dd01fc695cb0c7504 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 4 Oct 2024 14:46:42 +0200
Subject: [PATCH 09/16] simplify the logic

---
 bindings/python/src/tokenizer.rs             | 24 +++++---------------
 tokenizers/src/tokenizer/added_vocabulary.rs | 15 ++++++------
 tokenizers/src/tokenizer/mod.rs              |  5 ++--
 3 files changed, 16 insertions(+), 28 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 9b0b82dcf..499cbd770 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1250,21 +1250,11 @@ impl PyTokenizer {
     /// Returns:
     ///     :obj:`None`
     #[pyo3(text_signature = "(self, old_tokens, new_tokens)")]
-    fn assign_tokens(
-        &mut self,
-        old_tokens: &Bound<'_, PyList>,
-        new_tokens: &Bound<'_, PyList>,
-    ) -> PyResult<()> {
+    fn assign_tokens(&mut self, old_to_new_map: &Bound<'_, PyDict>) -> PyResult<()> {
         use pyo3::exceptions::PyTypeError;
-        if old_tokens.len() != new_tokens.len() {
-            return Err(PyTypeError::new_err(
-                "old_tokens and new_tokens must have the same length",
-            ));
-        }
 
-        let mut processed_old_tokens = Vec::with_capacity(old_tokens.len());
-        let mut processed_new_tokens = Vec::with_capacity(new_tokens.len());
-        for (old, new) in old_tokens.iter().zip(new_tokens.iter()) {
+        let mut processed_old_tokens = HashMap::with_capacity(old_to_new_map.len());
+        for (old, new) in old_to_new_map.iter() {
             let old_token = if let Ok(content) = old.extract::<&str>() {
                 PyAddedToken::from(content.to_string(), Some(false)).get_token()
             } else if let Ok(token) = old.extract::<PyRefMut<PyAddedToken>>() {
@@ -1287,12 +1277,10 @@ impl PyTokenizer {
                 ));
             };
 
-            processed_old_tokens.push(old_token);
-            processed_new_tokens.push(new_token);
+            processed_old_tokens.insert(old_token, new_token);
         }
-        Ok(self
-            .tokenizer
-            .assign_tokens(&processed_old_tokens, &processed_new_tokens))
+        self.tokenizer.assign_tokens(&processed_old_tokens);
+        Ok(())
     }
     /// Add the given special tokens to the Tokenizer.
     ///
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index d3ca1a484..6f79ba660 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -311,25 +311,26 @@ impl AddedVocabulary {
     /// use reserved tokens (which usually are in the original vocab, and in the added vocab)
     pub fn assign_tokens<N: Normalizer>(
         &mut self,
-        old_token_content: &[AddedToken],
-        new_token_content: &[AddedToken],
+        token_map: &HashMap<AddedToken, AddedToken>, // HashMap of old token to new token
         model: &impl Model,
         normalizer: Option<&N>,
     ) {
-        for (old, new) in old_token_content.iter().zip(new_token_content.iter()) {
-            if let Some(id) = self.token_to_id(old.content.as_str(), model) {
+        for (old_token, new_token) in token_map.iter() {
+            if let Some(id) = self.token_to_id(old_token.content.as_str(), model) {
                 self.added_tokens_map_r
                     .lock()
                     .unwrap()
                     .entry(id)
-                    .and_modify(|t| t.content = new.content.clone());
+                    .and_modify(|t| *t = new_token.clone()); // Replace entire entry with new_token
                 self.refresh_added_tokens(model, normalizer);
             } else {
-                error!("Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab", old.content.clone())
+                error!(
+                "Error: you tried to re-assign a token that does not exist in the added vocab. Make sure {:?} is first added to the vocab",
+                old_token.content.clone()
+            )
             }
         }
     }
-
     /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
     ///
     /// We keep two different RegexSet, one that will take care of matching against the
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index c6433dc43..c24654fc8 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -960,10 +960,9 @@ where
     }
 
     /// Reassign existing added tokens to new content
-    pub fn assign_tokens(&mut self, old_tokens: &[AddedToken], new_tokens: &[AddedToken]) {
+    pub fn assign_tokens(&mut self, old_to_new_map: &HashMap<AddedToken, AddedToken>) {
         self.added_vocabulary.assign_tokens(
-            old_tokens,
-            new_tokens,
+            old_to_new_map, // HashMap of old token to new token
             &self.model,
             self.normalizer.as_ref(),
         )

From ed34ffd3342dd1c6b1226948297529dc2d6d2a8c Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 4 Oct 2024 15:00:35 +0200
Subject: [PATCH 10/16] add a small test

---
 bindings/python/py_src/tokenizers/__init__.pyi   | 2 +-
 bindings/python/tests/bindings/test_tokenizer.py | 7 +++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index a480923ef..bc2cb0cab 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -725,7 +725,7 @@ class Tokenizer:
         """
         pass
 
-    def assing_tokens(self, old_tokens, new_tokens):
+    def assign_tokens(self, old_tokens, new_tokens):
         """
         Reassign the content of existing added tokens
 
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 2118709a0..370fda087 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -562,6 +562,13 @@ def test_setting_to_none(self):
         tokenizer.pre_tokenizer = None
         assert tokenizer.pre_tokenizer == None
 
+    def test_re_assign_tokens(self):
+        tokenizer = Tokenizer.from_pretrained("t5-base")
+        tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})
+        assert tokenizer.decode([32099]) == "my_new_token"
+        assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"]
+        assert "my_new_token" in tokenizer.get_vocab(True).keys()
+
 
 class TestTokenizerRepr:
     def test_repr(self):

From 545d7230f485d2f199647c627f8b98f5a728c9bc Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 4 Oct 2024 15:24:14 +0200
Subject: [PATCH 11/16] fix unwrap errors

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 6f79ba660..984201075 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -148,8 +148,6 @@ pub struct AddedVocabulary {
     /// Contains the mapping from ID to AddedToken for all the added tokens, both special
     /// and classic.
     added_tokens_map_r: Arc<Mutex<HashMap<u32, AddedToken>>>,
-    /// Contains only the classic AddedToken, in the specific order the user gave them.
-    added_tokens: Vec<AddedToken>,
 
     /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
     split_trie: MatchingSet,
@@ -173,7 +171,6 @@ impl AddedVocabulary {
         Self {
             added_tokens_map: Arc::new(Mutex::new(HashMap::new())),
             added_tokens_map_r: Arc::new(Mutex::new(HashMap::new())),
-            added_tokens: vec![],
             split_trie: (trie, vec![]),
             split_normalized_trie: (normalized_trie, vec![]),
             encode_special_tokens: false,
@@ -218,12 +215,14 @@ impl AddedVocabulary {
 
     /// Check if a token is a special token
     pub fn is_special_token(&self, token: &str) -> bool {
-        self.added_tokens_map_r
-            .lock()
-            .unwrap()
-            .get(self.added_tokens_map.lock().unwrap().get(token).unwrap())
-            .unwrap()
-            .special
+        let hash_map = &self.added_tokens_map_r.lock().unwrap();
+        let revert_hash_map = &self.added_tokens_map.lock().unwrap();
+        if let Some(id) = revert_hash_map.get(token) {
+            if let Some(token) = hash_map.get(id) {
+                return token.special;
+            }
+        }
+        false
     }
 
     /// Add some special tokens to the vocabulary
@@ -335,7 +334,7 @@ impl AddedVocabulary {
     ///
     /// We keep two different RegexSet, one that will take care of matching against the
     /// non-normalized string, and one matching against the normalized one.
-    fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
+    fn refresh_added_tokens<N: Normalizer>(&mut self, _model: &impl Model, normalizer: Option<&N>) {
         type TupleTokenId<'a> = (&'a AddedToken, u32);
         let added_tokens_map_r = self.added_tokens_map_r.lock().unwrap().clone();
         let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) =

From ee7ce80e0b65fd53c186ab4b36e1eba38383d5a4 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Fri, 4 Oct 2024 15:55:43 +0200
Subject: [PATCH 12/16] forgot to remove from added tokens map!

---
 bindings/python/tests/bindings/test_tokenizer.py | 3 +++
 tokenizers/src/tokenizer/added_vocabulary.rs     | 4 ++++
 2 files changed, 7 insertions(+)

diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 370fda087..9a1fd7272 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -566,6 +566,9 @@ def test_re_assign_tokens(self):
         tokenizer = Tokenizer.from_pretrained("t5-base")
         tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})
         assert tokenizer.decode([32099]) == "my_new_token"
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token", "</s>"]
+        assert tokenizer.encode("my_new_token").ids == [32099, 1]
+        assert tokenizer.encode("<extra_id_0>").ids == [0, 1]
         assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"]
         assert "my_new_token" in tokenizer.get_vocab(True).keys()
 
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 984201075..e22249048 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -321,6 +321,10 @@ impl AddedVocabulary {
                     .unwrap()
                     .entry(id)
                     .and_modify(|t| *t = new_token.clone()); // Replace entire entry with new_token
+                self.added_tokens_map
+                    .lock()
+                    .unwrap()
+                    .remove(old_token.content.as_str());
                 self.refresh_added_tokens(model, normalizer);
             } else {
                 error!(

From e8933fa5b996c11d9a8b61c7549b23e639c88b8d Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Sat, 5 Oct 2024 17:16:31 +0200
Subject: [PATCH 13/16] potential initial solution for the annoying unigram
 model :)

---
 .../python/tests/bindings/test_tokenizer.py   | 22 ++++++++++++-----
 tokenizers/src/models/unigram/model.rs        | 24 ++++++++++++-------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 9a1fd7272..0fb2cbd7d 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -562,15 +562,25 @@ def test_setting_to_none(self):
         tokenizer.pre_tokenizer = None
         assert tokenizer.pre_tokenizer == None
 
-    def test_re_assign_tokens(self):
+    def test_re_assign_tokens_bpe(self):
+        tokenizer = Tokenizer.from_pretrained("gpt2")
+        tokenizer.assign_tokens({"<|endoftext|>": "my_new_token"})
+        assert tokenizer.decode([50256]) == "my_new_token"
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
+        assert tokenizer.encode("my_new_token").ids == [50256]
+        assert tokenizer.encode("<|endoftext|>").ids == [27, 91, 437, 1659, 5239, 91, 29]
+        assert tokenizer.encode("<|endoftext|>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
+        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}
+
+    def test_re_assign_tokens_unigram(self):
         tokenizer = Tokenizer.from_pretrained("t5-base")
         tokenizer.assign_tokens({"<extra_id_0>": "my_new_token"})
         assert tokenizer.decode([32099]) == "my_new_token"
-        assert tokenizer.encode("my_new_token").tokens == ["my_new_token", "</s>"]
-        assert tokenizer.encode("my_new_token").ids == [32099, 1]
-        assert tokenizer.encode("<extra_id_0>").ids == [0, 1]
-        assert tokenizer.encode("<extra_id_0>").tokens == ["▁", "<", "extra", "_", "i", "d", "_", "0", ">", "</s>"]
-        assert "my_new_token" in tokenizer.get_vocab(True).keys()
+        assert tokenizer.encode("my_new_token").tokens == ["my_new_token"]
+        assert tokenizer.encode("my_new_token").ids == [32099]
+        assert tokenizer.encode("<extra_id_0>").ids == [27, 91, 437, 1659, 5239, 91, 29]
+        assert tokenizer.encode("<extra_id_0>").tokens == ["<", "|", "end", "of", "text", "|", ">"]
+        assert "my_new_token" in {k.content for k in tokenizer.get_added_tokens_decoder().values()}
 
 
 class TestTokenizerRepr:
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index dba5a0400..4a5371738 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -3,8 +3,11 @@ use super::{
     trainer::UnigramTrainer,
     trie::{Trie, TrieBuilder},
 };
-use crate::tokenizer::{Model, Result, Token};
 use crate::utils::cache::Cache;
+use crate::{
+    tokenizer::{Model, Result, Token},
+    AddedVocabulary,
+};
 
 use std::collections::HashMap;
 use std::convert::TryInto;
@@ -81,7 +84,7 @@ pub enum UnigramError {
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(vocab, Some(0), false).unwrap()
+        Self::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap()
     }
 }
 
@@ -96,6 +99,7 @@ impl Unigram {
         vocab: Vec<(String, f64)>,
         unk_id: Option<usize>,
         byte_fallback: bool,
+        added_tokens: &AddedVocabulary,
     ) -> Result<Self> {
         let n = vocab.len();
         let mut token_to_ids: TokenMap = HashMap::new();
@@ -114,11 +118,13 @@ impl Unigram {
 
         let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocab.iter().enumerate() {
-            token_to_ids.insert(token.to_string(), id as u32);
-            let bytes: Vec<u8> = token.bytes().collect();
-            builder.push(&bytes);
-            if score < &min_score {
-                min_score = *score;
+            if !added_tokens.is_special_token(token) {
+                token_to_ids.insert(token.to_string(), id as u32);
+                let bytes: Vec<u8> = token.bytes().collect();
+                builder.push(&bytes);
+                if score < &min_score {
+                    min_score = *score;
+                }
             }
         }
         let trie = builder.build();
@@ -480,7 +486,7 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(pieces, Some(0), false).unwrap();
+        let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -505,7 +511,7 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(pieces, Some(0), false).unwrap();
+        let model = Unigram::from(pieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);

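The intent of the Unigram change above is that vocab pieces flagged as special in the added vocabulary are left out of the model's trie, so the model itself can no longer match them and only the (possibly reassigned) added-vocabulary entry does. A std-only sketch of that filtering step, with hypothetical names and simplified types:

    use std::collections::HashSet;

    // Keep only the pieces that are not special tokens before building the matcher.
    fn trie_pieces(vocab: &[(String, f64)], special: &HashSet<String>) -> Vec<Vec<u8>> {
        vocab
            .iter()
            .filter(|(token, _)| !special.contains(token))
            .map(|(token, _)| token.bytes().collect())
            .collect()
    }

    fn main() {
        let vocab = vec![("<extra_id_0>".to_string(), 0.0), ("▁hello".to_string(), -1.5)];
        let special = HashSet::from(["<extra_id_0>".to_string()]);
        assert_eq!(trie_pieces(&vocab, &special).len(), 1);
    }
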
From 0475c057dd8f5d5141b15f47ad16358ddc5a9f6b Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Sat, 5 Oct 2024 17:17:52 +0200
Subject: [PATCH 14/16] fix added vocab tests

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index e22249048..f91a7b82f 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -724,15 +724,15 @@ mod tests {
         assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
         assert_eq!(
-            *vocab.get_added_tokens_decoder(),
+            vocab.get_added_tokens_decoder(),
             HashMap::from([
                 (0, AddedToken::from("test", true)),
                 (2, AddedToken::from("added_token_1", true)),
                 (3, AddedToken::from("added_token_2", true)),
             ])
         );
-        assert!(vocab.added_tokens_map.contains_key("test"));
-        assert!(vocab.added_tokens_map_r.contains_key(&0));
+        assert!(vocab.added_tokens_map.lock().unwrap().contains_key("test"));
+        assert!(vocab.added_tokens_map_r.lock().unwrap().contains_key(&0));
 
         vocab.add_tokens(
             &[

From 167ecdebfb217316a156cf2500e44aa354d30d53 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Sat, 5 Oct 2024 17:56:06 +0200
Subject: [PATCH 15/16] small fixes

---
 tokenizers/src/models/unigram/serialization.rs |  8 +++++---
 tokenizers/src/models/unigram/trainer.rs       | 10 +++++++---
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs
index a6e56b735..1ad95002e 100644
--- a/tokenizers/src/models/unigram/serialization.rs
+++ b/tokenizers/src/models/unigram/serialization.rs
@@ -78,12 +78,14 @@ impl<'de> Visitor<'de> for UnigramVisitor {
 
 #[cfg(test)]
 mod test {
+    use crate::AddedVocabulary;
+
     use super::*;
 
     #[test]
     fn test_serialization() {
         let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, Some(0), false).unwrap();
+        let model = Unigram::from(vocab, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -94,7 +96,7 @@ mod test {
     #[test]
     fn test_serialization_unk_id_not_zero() {
         let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(vocab, Some(1), false).unwrap();
+        let model = Unigram::from(vocab, Some(1), false, &AddedVocabulary::default()).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -105,7 +107,7 @@ mod test {
     #[test]
     fn test_serialization_no_unk_id() {
         let vocab = vec![("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, None, false).unwrap();
+        let model = Unigram::from(vocab, None, false, &AddedVocabulary::default()).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 5d178e77b..b3e816a59 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -2,6 +2,7 @@ use crate::models::unigram::{lattice::Lattice, model::Unigram};
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
+use crate::AddedVocabulary;
 use log::debug;
 use serde::{Deserialize, Serialize};
 use std::cmp::Reverse;
@@ -182,6 +183,7 @@ impl UnigramTrainer {
             special_tokens.into_iter().chain(pieces).collect(),
             unk_id,
             model.byte_fallback(),
+            &AddedVocabulary::default(),
         )
     }
 
@@ -567,7 +569,8 @@ impl UnigramTrainer {
         if required_chars.len() as u32 > self.vocab_size {
             return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
         }
-        let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?;
+        let mut new_model =
+            Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {
@@ -576,7 +579,8 @@ impl UnigramTrainer {
 
                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                new_model = Unigram::from(pieces.clone(), Some(0), false)?;
+                new_model =
+                    Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
 
                 // Useful comment for checking compatibility with spm
                 debug!(
@@ -600,7 +604,7 @@ impl UnigramTrainer {
 
             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences);
-            new_model = Unigram::from(pieces.clone(), Some(0), false)?;
+            new_model = Unigram::from(pieces.clone(), Some(0), false, &AddedVocabulary::default())?;
         }
         self.finalize_progress(&progress, expected_updates);
 

From 81d83361d0bfc466616d65f3eff91d723cc48630 Mon Sep 17 00:00:00 2001
From: Arthur Zucker <arthur.zucker@gmail.com>
Date: Sat, 5 Oct 2024 17:58:22 +0200
Subject: [PATCH 16/16] fix the unigram::from calls

---
 tokenizers/src/models/unigram/model.rs         |  9 ++++++---
 tokenizers/src/models/unigram/serialization.rs | 10 ++++++++--
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index 4a5371738..c604b11c6 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -548,7 +548,8 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];
 
-        let model = Unigram::from(sentencepieces, Some(0), false).unwrap();
+        let model =
+            Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();
         let result = model.encode("abcd").unwrap();
         assert_eq!(result, vec!["abcd"]);
     }
@@ -570,7 +571,8 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];
 
-        let mut model = Unigram::from(sentencepieces, Some(0), false).unwrap();
+        let mut model =
+            Unigram::from(sentencepieces, Some(0), false, &AddedVocabulary::default()).unwrap();
 
         for is_optimized in &[true, false] {
             model.set_optimized(*is_optimized);
@@ -617,7 +619,8 @@ mod tests {
             ("<0xC3>".to_string(), -0.01),
             ("<0xA9>".to_string(), -0.03),
         ];
-        let unigram = Unigram::from(sentencepieces, Some(0), true).unwrap();
+        let unigram =
+            Unigram::from(sentencepieces, Some(0), true, &AddedVocabulary::default()).unwrap();
         let tokens: Vec<Token> = unigram.tokenize("é").unwrap();
         assert_eq!(
             tokens,
diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs
index 1ad95002e..f0ff30694 100644
--- a/tokenizers/src/models/unigram/serialization.rs
+++ b/tokenizers/src/models/unigram/serialization.rs
@@ -1,3 +1,5 @@
+use crate::AddedVocabulary;
+
 use super::model::Unigram;
 use serde::{
     de::{Error, MapAccess, Visitor},
@@ -69,8 +71,12 @@ impl<'de> Visitor<'de> for UnigramVisitor {
             }
         }
         match (vocab, unk_id, byte_fallback) {
-            (Some(vocab), unk_id, byte_fallback) => Ok(Unigram::from(vocab, unk_id, byte_fallback)
-                .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?),
+            (Some(vocab), unk_id, byte_fallback) => {
+                Ok(
+                    Unigram::from(vocab, unk_id, byte_fallback, &AddedVocabulary::default())
+                        .map_err(|err| Error::custom(format!("Unable to load vocab {err:?}")))?,
+                )
+            }
             (None, _, _) => Err(Error::custom("Missing vocab")),
         }
     }