diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 25efc691c..ac31fc8a8 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -1,7 +1,9 @@
-use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
+use super::{
+    super::OrderedVocabIter, Cache, Error, Pair, WithFirstLastIterator, Word,
+    DEFAULT_CACHE_CAPACITY,
+};
 use crate::tokenizer::{Model, Offsets, Result, Token};
 use rand::{thread_rng, Rng};
-use serde::{Serialize, Serializer};
 use serde_json::Value;
 use std::{
     collections::HashMap,
@@ -463,28 +465,6 @@ impl Model for BPE {
     }
 }
 
-/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
-/// of token ID, smallest to largest.
-struct OrderedVocabIter<'a> {
-    vocab_r: &'a HashMap<u32, String>,
-}
-
-impl<'a> OrderedVocabIter<'a> {
-    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
-        Self { vocab_r }
-    }
-}
-
-impl<'a> Serialize for OrderedVocabIter<'a> {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
-        serializer.collect_map(iter)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index e1a2ad8f0..baab582b3 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -3,3 +3,28 @@
 pub mod bpe;
 pub mod wordlevel;
 pub mod wordpiece;
+
+use serde::{Serialize, Serializer};
+use std::collections::HashMap;
+
+/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
+/// of token ID, smallest to largest.
+struct OrderedVocabIter<'a> {
+    vocab_r: &'a HashMap<u32, String>,
+}
+
+impl<'a> OrderedVocabIter<'a> {
+    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+        Self { vocab_r }
+    }
+}
+
+impl<'a> Serialize for OrderedVocabIter<'a> {
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
+        serializer.collect_map(iter)
+    }
+}
diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs
index b7bb12fac..c05861d71 100644
--- a/tokenizers/src/models/wordlevel/mod.rs
+++ b/tokenizers/src/models/wordlevel/mod.rs
@@ -1,3 +1,4 @@
+use super::OrderedVocabIter;
 use crate::tokenizer::{Model, Result, Token};
 use serde_json::Value;
 use std::collections::HashMap;
@@ -168,20 +169,14 @@ impl Model for WordLevel {
             None => "vocab.json".to_string(),
         };
 
-        // Write vocab.txt
+        // Write vocab.json
         let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
             .iter()
             .collect();
        let mut vocab_file = File::create(&vocab_path)?;
 
-        let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
-        vocab.sort_unstable_by_key(|k| *k.1);
-        vocab_file.write_all(
-            &vocab
-                .into_iter()
-                .map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
-                .flatten()
-                .collect::<Vec<u8>>()[..],
-        )?;
+        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
+        let serialized = serde_json::to_string(&order_vocab_iter)?;
+        vocab_file.write_all(&serialized.as_bytes())?;
 
         Ok(vec![vocab_path])
     }
diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs
index 8bb8c6541..3c93e62a2 100644
--- a/tokenizers/src/models/wordpiece/mod.rs
+++ b/tokenizers/src/models/wordpiece/mod.rs
@@ -252,8 +252,8 @@ impl Model for WordPiece {
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
         let vocab_file_name = match name {
-            Some(name) => format!("{}-vocab.json", name),
-            None => "vocab.json".to_string(),
+            Some(name) => format!("{}-vocab.txt", name),
+            None => "vocab.txt".to_string(),
         };
 
         // Write vocab.txt
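For reference, below is a minimal standalone sketch of the `OrderedVocabIter` technique that the diff relocates into `models/mod.rs`. The struct and `Serialize` impl are taken verbatim from the diff; the `main` function and the sample vocab entries are illustrative additions, not part of the change.

```rust
use serde::{Serialize, Serializer};
use std::collections::HashMap;

/// Serializes an ID -> token map as a JSON object whose entries are
/// ordered by token ID, smallest to largest.
struct OrderedVocabIter<'a> {
    vocab_r: &'a HashMap<u32, String>,
}

impl<'a> Serialize for OrderedVocabIter<'a> {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Walk IDs 0..len in order. Indexing with `[&i]` panics on a gap,
        // so the map is assumed to be dense (every ID in 0..len is present).
        let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
        serializer.collect_map(iter)
    }
}

fn main() -> serde_json::Result<()> {
    // Hypothetical reversed vocab (ID -> token), as stored in `vocab_r`.
    let mut vocab_r = HashMap::new();
    vocab_r.insert(1u32, "hello".to_string());
    vocab_r.insert(0u32, "[PAD]".to_string());
    vocab_r.insert(2u32, "world".to_string());

    // Prints {"[PAD]":0,"hello":1,"world":2} regardless of HashMap order.
    let json = serde_json::to_string(&OrderedVocabIter { vocab_r: &vocab_r })?;
    println!("{}", json);
    Ok(())
}
```

This ID-ordered output is what lets `WordLevel::save` above write a deterministic `vocab.json` with a single `serde_json::to_string` call, instead of collecting and sorting the vocab by hand.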