diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py
index bf290de9f..ab03caa73 100644
--- a/bindings/python/tests/bindings/test_trainers.py
+++ b/bindings/python/tests/bindings/test_trainers.py
@@ -238,6 +238,28 @@ def test_train_with_special_tokens(self):
             "[SEP]",
         ]
 
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False,
+            special_tokens=["[PAD]", "[SEP]", "[CLS]"],
+            unk_token="[UNK]",
+            vocab_size=100,
+        )
+        tokenizer.train([filename], trainer=trainer)
+
+        assert tokenizer.get_vocab_size() == 100
+
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False,
+            special_tokens=["[PAD]", "[SEP]", "[CLS]", "[UNK]"],
+            unk_token="[UNK]",
+            vocab_size=100,
+        )
+        tokenizer.train([filename], trainer=trainer)
+
+        assert tokenizer.get_vocab_size() == 100
+
     def test_cannot_train_different_model(self):
         tokenizer = Tokenizer(models.BPE())
         trainer = trainers.UnigramTrainer(show_progress=False)
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index c76448b63..dc5e536f6 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -126,19 +126,7 @@ impl UnigramTrainer {
                 min_score_penalty += min_score_penalty_delta;
             }
         }
-        for (token, score) in model.iter() {
-            if inserted.contains::<str>(token) {
-                continue;
-            }
-            inserted.insert(token.to_string());
-            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
-            if pieces.len() == self.vocab_size as usize {
-                break;
-            }
-        }
-        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
 
-        // Insert the necessary tokens
         let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
             let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
                 if t.content == *unk {
@@ -154,6 +142,26 @@ impl UnigramTrainer {
         } else {
             (None, false)
         };
+
+        let vocab_size_without_special_tokens = if need_add_unk {
+            self.vocab_size as usize - self.special_tokens.len() - 1
+        } else {
+            self.vocab_size as usize - self.special_tokens.len()
+        };
+        for (token, score) in model.iter() {
+            if inserted.contains::<str>(token) {
+                continue;
+            }
+            inserted.insert(token.to_string());
+            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
+
+            if pieces.len() == vocab_size_without_special_tokens {
+                break;
+            }
+        }
+        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
+
+        // Insert the necessary tokens
         let mut special_tokens = self
             .special_tokens
             .iter()