From 9774381196e211c2ef1779c26ef2d882475db5c1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 2 Aug 2024 18:41:54 +0200 Subject: [PATCH 01/10] Using serde (serde_pyo3) to get __str__ and __repr__ easily. --- bindings/python/Cargo.toml | 1 + bindings/python/src/decoders.rs | 2 +- bindings/python/src/models.rs | 2 +- bindings/python/src/normalizers.rs | 2 +- bindings/python/src/pre_tokenizers.rs | 2 +- bindings/python/src/processors.rs | 2 +- bindings/python/src/tokenizer.rs | 21 ++++++++++++++++++++- bindings/python/src/trainers.rs | 2 +- 8 files changed, 27 insertions(+), 7 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index b494e4085..8a81ac3d2 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -18,6 +18,7 @@ pyo3 = { version = "0.21" } numpy = "0.21" ndarray = "0.15" itertools = "0.12" +serde_pyo3 = { git = "https://github.com/Narsil/serde_pyo3" } [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index ed21f3469..4a4af94dd 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -29,8 +29,8 @@ use super::error::ToPyResult; /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyDecoder { - #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, } diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index bffa1bc21..2bfaafd34 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -26,8 +26,8 @@ use super::error::{deprecation_warning, ToPyResult}; /// This class cannot be constructed directly. Please use one of the concrete models. #[pyclass(module = "tokenizers.models", name = "Model", subclass)] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyModel { - #[serde(flatten)] pub model: Arc>, } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 864947e39..724e79b85 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -44,8 +44,8 @@ impl PyNormalizedStringMut<'_> { /// Normalizer will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyNormalizer { - #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a2bd9b39c..a9060ec3b 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,8 +35,8 @@ use super::utils::*; subclass )] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyPreTokenizer { - #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index c46d8ea49..aceb1d446 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -28,8 +28,8 @@ use tokenizers as tk; subclass )] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyPostProcessor { - #[serde(flatten)] pub processor: Arc, } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 1c6bc9cc1..5bc57f777 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,3 +1,4 @@ +use serde::Serialize; use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; @@ -462,7 +463,8 @@ type Tokenizer = TokenizerImpl PyResult { + serde_pyo3::to_string(self).map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + /// Return the number of special tokens that would be added for single/pair sentences. /// :param is_pair: Boolean indicating if the input would be a single sentence or a pair /// :return: @@ -1434,4 +1441,16 @@ mod test { Tokenizer::from_file(&tmp).unwrap(); } + + #[test] + fn serde_pyo3() { + let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default())); + tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![ + Arc::new(RwLock::new(NFKC.into())), + Arc::new(RwLock::new(Lowercase.into())), + ]))); + + let output = serde_pyo3::to_string(&tokenizer).unwrap(); + assert_eq!(output, ""); + } } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 716e4cfeb..cbce2aef9 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -16,8 +16,8 @@ use tokenizers as tk; /// Trainer will return an instance of this class when instantiated. #[pyclass(module = "tokenizers.trainers", name = "Trainer", subclass)] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyTrainer { - #[serde(flatten)] pub trainer: Arc>, } From 32de311c51bc44a280252d9d5fda05a7dda821ea Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 2 Aug 2024 19:51:44 +0200 Subject: [PATCH 02/10] Putting it within tokenizers, it needs to be too specific. --- bindings/python/Cargo.toml | 1 - bindings/python/src/tokenizer.rs | 5 +- bindings/python/src/utils/mod.rs | 1 + bindings/python/src/utils/serde_pyo3.rs | 642 ++++++++++++++++++ .../python/tests/bindings/test_tokenizer.py | 23 +- 5 files changed, 668 insertions(+), 4 deletions(-) create mode 100644 bindings/python/src/utils/serde_pyo3.rs diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 8a81ac3d2..b494e4085 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -18,7 +18,6 @@ pyo3 = { version = "0.21" } numpy = "0.21" ndarray = "0.15" itertools = "0.12" -serde_pyo3 = { git = "https://github.com/Narsil/serde_pyo3" } [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 5bc57f777..22eb3def4 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -642,7 +642,8 @@ impl PyTokenizer { #[pyo3(signature = ())] fn repr(&self) -> PyResult { - serde_pyo3::to_string(self).map_err(|e| exceptions::PyException::new_err(e.to_string())) + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) } /// Return the number of special tokens that would be added for single/pair sentences. @@ -1450,7 +1451,7 @@ mod test { Arc::new(RwLock::new(Lowercase.into())), ]))); - let output = serde_pyo3::to_string(&tokenizer).unwrap(); + let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap(); assert_eq!(output, ""); } } diff --git a/bindings/python/src/utils/mod.rs b/bindings/python/src/utils/mod.rs index 1e409a504..43352a7fa 100644 --- a/bindings/python/src/utils/mod.rs +++ b/bindings/python/src/utils/mod.rs @@ -5,6 +5,7 @@ mod iterators; mod normalization; mod pretokenization; mod regex; +pub mod serde_pyo3; pub use iterators::*; pub use normalization::*; diff --git a/bindings/python/src/utils/serde_pyo3.rs b/bindings/python/src/utils/serde_pyo3.rs new file mode 100644 index 000000000..bf010f21f --- /dev/null +++ b/bindings/python/src/utils/serde_pyo3.rs @@ -0,0 +1,642 @@ +use serde::de::value::Error; +use serde::{ser, Serialize}; +type Result = ::std::result::Result; + +pub struct Serializer { + // This string starts empty and JSON is appended as values are serialized. + output: String, +} + +// By convention, the public API of a Serde serializer is one or more `to_abc` +// functions such as `to_string`, `to_bytes`, or `to_writer` depending on what +// Rust types the serializer is able to produce as output. +// +// This basic serializer supports only `to_string`. +pub fn to_string(value: &T) -> Result +where + T: Serialize, +{ + let mut serializer = Serializer { + output: String::new(), + }; + value.serialize(&mut serializer)?; + Ok(serializer.output) +} + +impl<'a> ser::Serializer for &'a mut Serializer { + // The output type produced by this `Serializer` during successful + // serialization. Most serializers that produce text or binary output should + // set `Ok = ()` and serialize into an `io::Write` or buffer contained + // within the `Serializer` instance, as happens here. Serializers that build + // in-memory data structures may be simplified by using `Ok` to propagate + // the data structure around. + type Ok = (); + + // The error type when some error occurs during serialization. + type Error = Error; + + // Associated types for keeping track of additional state while serializing + // compound data structures like sequences and maps. In this case no + // additional state is required beyond what is already stored in the + // Serializer struct. + type SerializeSeq = Self; + type SerializeTuple = Self; + type SerializeTupleStruct = Self; + type SerializeTupleVariant = Self; + type SerializeMap = Self; + type SerializeStruct = Self; + type SerializeStructVariant = Self; + + // Here we go with the simple methods. The following 12 methods receive one + // of the primitive types of the data model and map it to JSON by appending + // into the output string. + fn serialize_bool(self, v: bool) -> Result<()> { + self.output += if v { "True" } else { "False" }; + Ok(()) + } + + // JSON does not distinguish between different sizes of integers, so all + // signed integers will be serialized the same and all unsigned integers + // will be serialized the same. Other formats, especially compact binary + // formats, may need independent logic for the different sizes. + fn serialize_i8(self, v: i8) -> Result<()> { + self.serialize_i64(i64::from(v)) + } + + fn serialize_i16(self, v: i16) -> Result<()> { + self.serialize_i64(i64::from(v)) + } + + fn serialize_i32(self, v: i32) -> Result<()> { + self.serialize_i64(i64::from(v)) + } + + // Not particularly efficient but this is example code anyway. A more + // performant approach would be to use the `itoa` crate. + fn serialize_i64(self, v: i64) -> Result<()> { + self.output += &v.to_string(); + Ok(()) + } + + fn serialize_u8(self, v: u8) -> Result<()> { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u16(self, v: u16) -> Result<()> { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u32(self, v: u32) -> Result<()> { + self.serialize_u64(u64::from(v)) + } + + fn serialize_u64(self, v: u64) -> Result<()> { + self.output += &v.to_string(); + Ok(()) + } + + fn serialize_f32(self, v: f32) -> Result<()> { + self.serialize_f64(f64::from(v)) + } + + fn serialize_f64(self, v: f64) -> Result<()> { + self.output += &v.to_string(); + Ok(()) + } + + // Serialize a char as a single-character string. Other formats may + // represent this differently. + fn serialize_char(self, v: char) -> Result<()> { + self.serialize_str(&v.to_string()) + } + + // This only works for strings that don't require escape sequences but you + // get the idea. For example it would emit invalid JSON if the input string + // contains a '"' character. + fn serialize_str(self, v: &str) -> Result<()> { + self.output += "\""; + self.output += v; + self.output += "\""; + Ok(()) + } + + // Serialize a byte array as an array of bytes. Could also use a base64 + // string here. Binary formats will typically represent byte arrays more + // compactly. + fn serialize_bytes(self, v: &[u8]) -> Result<()> { + use serde::ser::SerializeSeq; + let mut seq = self.serialize_seq(Some(v.len()))?; + for byte in v { + seq.serialize_element(byte)?; + } + seq.end() + } + + // An absent optional is represented as the JSON `null`. + fn serialize_none(self) -> Result<()> { + self.serialize_unit() + } + + // A present optional is represented as just the contained value. Note that + // this is a lossy representation. For example the values `Some(())` and + // `None` both serialize as just `null`. Unfortunately this is typically + // what people expect when working with JSON. Other formats are encouraged + // to behave more intelligently if possible. + fn serialize_some(self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + // In Serde, unit means an anonymous value containing no data. Map this to + // JSON as `null`. + fn serialize_unit(self) -> Result<()> { + self.output += "None"; + Ok(()) + } + + // Unit struct means a named value containing no data. Again, since there is + // no data, map this to JSON as `null`. There is no need to serialize the + // name in most formats. + fn serialize_unit_struct(self, _name: &'static str) -> Result<()> { + self.serialize_unit() + } + + // When serializing a unit variant (or any other kind of variant), formats + // can choose whether to keep track of it by index or by name. Binary + // formats typically use the index of the variant and human-readable formats + // typically use the name. + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result<()> { + // self.serialize_str(variant) + self.output += variant; + Ok(()) + } + + // As is done here, serializers are encouraged to treat newtype structs as + // insignificant wrappers around the data they contain. + fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + // Note that newtype variant (and all of the other variant serialization + // methods) refer exclusively to the "externally tagged" enum + // representation. + // + // Serialize this to JSON in externally tagged form as `{ NAME: VALUE }`. + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + value: &T, + ) -> Result<()> + where + T: ?Sized + Serialize, + { + // variant.serialize(&mut *self)?; + self.output += variant; + self.output += "("; + value.serialize(&mut *self)?; + self.output += ")"; + Ok(()) + } + + // Now we get to the serialization of compound types. + // + // The start of the sequence, each value, and the end are three separate + // method calls. This one is responsible only for serializing the start, + // which in JSON is `[`. + // + // The length of the sequence may or may not be known ahead of time. This + // doesn't make a difference in JSON because the length is not represented + // explicitly in the serialized form. Some serializers may only be able to + // support sequences for which the length is known up front. + fn serialize_seq(self, _len: Option) -> Result { + self.output += "["; + Ok(self) + } + + // Tuples look just like sequences in JSON. Some formats may be able to + // represent tuples more efficiently by omitting the length, since tuple + // means that the corresponding `Deserialize implementation will know the + // length without needing to look at the serialized data. + fn serialize_tuple(self, _len: usize) -> Result { + self.output += "("; + Ok(self) + } + + // Tuple structs look just like sequences in JSON. + fn serialize_tuple_struct( + self, + _name: &'static str, + len: usize, + ) -> Result { + self.serialize_tuple(len) + } + + // Tuple variants are represented in JSON as `{ NAME: [DATA...] }`. Again + // this method is only responsible for the externally tagged representation. + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + // variant.serialize(&mut *self)?; + self.output += variant; + self.output += "("; + Ok(self) + } + + // Maps are represented in JSON as `{ K: V, K: V, ... }`. + fn serialize_map(self, _len: Option) -> Result { + println!("Serialize map"); + self.output += "{"; + Ok(self) + } + + // Structs look just like maps in JSON. In particular, JSON requires that we + // serialize the field names of the struct. Other formats may be able to + // omit the field names when serializing structs because the corresponding + // Deserialize implementation is required to know what the keys are without + // looking at the serialized data. + fn serialize_struct(self, name: &'static str, _len: usize) -> Result { + // self.serialize_map(Some(len)) + // name.serialize(&mut *self)?; + if name.ends_with("Helper") { + self.output += &name[..name.len() - "Helper".len()]; + } else { + self.output += name; + } + self.output += "("; + Ok(self) + } + + // Struct variants are represented in JSON as `{ NAME: { K: V, ... } }`. + // This is the externally tagged representation. + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + _len: usize, + ) -> Result { + // variant.serialize(&mut *self)?; + self.output += variant; + self.output += "("; + Ok(self) + } +} + +// The following 7 impls deal with the serialization of compound types like +// sequences and maps. Serialization of such types is begun by a Serializer +// method and followed by zero or more calls to serialize individual elements of +// the compound type and one call to end the compound type. +// +// This impl is SerializeSeq so these methods are called after `serialize_seq` +// is called on the Serializer. +impl<'a> ser::SerializeSeq for &'a mut Serializer { + // Must match the `Ok` type of the serializer. + type Ok = (); + // Must match the `Error` type of the serializer. + type Error = Error; + + // Serialize a single element of the sequence. + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('[') { + self.output += ", "; + } + value.serialize(&mut **self) + } + + // Close the sequence. + fn end(self) -> Result<()> { + self.output += "]"; + Ok(()) + } +} + +// Same thing but for tuples. +impl<'a> ser::SerializeTuple for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + } + + fn end(self) -> Result<()> { + self.output += ")"; + Ok(()) + } +} + +// Same thing but for tuple structs. +impl<'a> ser::SerializeTupleStruct for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + } + + fn end(self) -> Result<()> { + self.output += ")"; + Ok(()) + } +} + +// Tuple variants are a little different. Refer back to the +// `serialize_tuple_variant` method above: +// +// self.output += "{"; +// variant.serialize(&mut *self)?; +// self.output += ":["; +// +// So the `end` method in this impl is responsible for closing both the `]` and +// the `}`. +impl<'a> ser::SerializeTupleVariant for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + } + + fn end(self) -> Result<()> { + self.output += ")"; + Ok(()) + } +} + +// Some `Serialize` types are not able to hold a key and value in memory at the +// same time so `SerializeMap` implementations are required to support +// `serialize_key` and `serialize_value` individually. +// +// There is a third optional method on the `SerializeMap` trait. The +// `serialize_entry` method allows serializers to optimize for the case where +// key and value are both available simultaneously. In JSON it doesn't make a +// difference so the default behavior for `serialize_entry` is fine. +impl<'a> ser::SerializeMap for &'a mut Serializer { + type Ok = (); + type Error = Error; + + // The Serde data model allows map keys to be any serializable type. JSON + // only allows string keys so the implementation below will produce invalid + // JSON if the key serializes as something other than a string. + // + // A real JSON serializer would need to validate that map keys are strings. + // This can be done by using a different Serializer to serialize the key + // (instead of `&mut **self`) and having that other serializer only + // implement `serialize_str` and return an error on any other data type. + fn serialize_key(&mut self, key: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('{') { + self.output += ", "; + } + key.serialize(&mut **self) + } + + // It doesn't make a difference whether the colon is printed at the end of + // `serialize_key` or at the beginning of `serialize_value`. In this case + // the code is a bit simpler having it here. + fn serialize_value(&mut self, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + self.output += ":"; + value.serialize(&mut **self) + } + + fn end(self) -> Result<()> { + self.output += "}"; + Ok(()) + } +} + +// Structs are like maps in which the keys are constrained to be compile-time +// constant strings. +impl<'a> ser::SerializeStruct for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('(') { + self.output += ", "; + } + // key.serialize(&mut **self)?; + if key != "type" { + self.output += key; + self.output += "="; + value.serialize(&mut **self) + } else { + Ok(()) + } + } + + fn end(self) -> Result<()> { + self.output += ")"; + Ok(()) + } +} + +// Similar to `SerializeTupleVariant`, here the `end` method is responsible for +// closing both of the curly braces opened by `serialize_struct_variant`. +impl<'a> ser::SerializeStructVariant for &'a mut Serializer { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, key: &'static str, value: &T) -> Result<()> + where + T: ?Sized + Serialize, + { + if !self.output.ends_with('(') { + self.output += ", "; + } + // key.serialize(&mut **self)?; + self.output += key; + self.output += "="; + value.serialize(&mut **self) + } + + fn end(self) -> Result<()> { + self.output += ")"; + Ok(()) + } +} + +//////////////////////////////////////////////////////////////////////////////// + +#[test] +fn test_basic() { + assert_eq!(to_string(&true).unwrap(), "True"); + assert_eq!(to_string(&Some(1)).unwrap(), "1"); + assert_eq!(to_string(&None::).unwrap(), "None"); +} + +#[test] +fn test_struct() { + #[derive(Serialize)] + struct Test { + int: u32, + seq: Vec<&'static str>, + } + + let test = Test { + int: 1, + seq: vec!["a", "b"], + }; + let expected = r#"Test(int=1, seq=["a", "b"])"#; + assert_eq!(to_string(&test).unwrap(), expected); +} + +#[test] +fn test_enum() { + #[derive(Serialize)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let u = E::Unit; + let expected = r#"Unit"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let n = E::Newtype(1); + let expected = r#"Newtype(1)"#; + assert_eq!(to_string(&n).unwrap(), expected); + + let t = E::Tuple(1, 2); + let expected = r#"Tuple(1, 2)"#; + assert_eq!(to_string(&t).unwrap(), expected); + + let s = E::Struct { a: 1 }; + let expected = r#"Struct(a=1)"#; + assert_eq!(to_string(&s).unwrap(), expected); +} + +#[test] +fn test_enum_untagged() { + #[derive(Serialize)] + #[serde(untagged)] + enum E { + Unit, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let u = E::Unit; + let expected = r#"None"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let n = E::Newtype(1); + let expected = r#"1"#; + assert_eq!(to_string(&n).unwrap(), expected); + + let t = E::Tuple(1, 2); + let expected = r#"(1, 2)"#; + assert_eq!(to_string(&t).unwrap(), expected); + + let s = E::Struct { a: 1 }; + let expected = r#"E(a=1)"#; + assert_eq!(to_string(&s).unwrap(), expected); +} + +#[test] +fn test_struct_tagged() { + #[derive(Serialize)] + #[serde(untagged)] + enum E { + A(A), + } + + #[derive(Serialize)] + #[serde(tag = "type")] + struct A { + a: bool, + b: usize, + } + + let u = A { a: true, b: 1 }; + let expected = r#"A(type="A", a=True, b=1)"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let u = E::A(A { a: true, b: 1 }); + let expected = r#"A(type="A", a=True, b=1)"#; + assert_eq!(to_string(&u).unwrap(), expected); +} + +#[test] +fn test_flatten() { + #[derive(Serialize)] + struct A { + a: bool, + b: usize, + } + + #[derive(Serialize)] + struct B { + c: A, + d: usize, + } + + #[derive(Serialize)] + struct C { + #[serde(flatten)] + c: A, + d: usize, + } + + let u = B { + c: A { a: true, b: 1 }, + d: 2, + }; + let expected = r#"B(c=A(a=True, b=1), d=2)"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let u = C { + c: A { a: true, b: 1 }, + d: 2, + }; + let expected = r#"C(a=True, b=1, d=2)"#; + assert_eq!(to_string(&u).unwrap(), expected); +} diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 39f110d07..b5539618d 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -7,7 +7,8 @@ from tokenizers.implementations import BertWordPieceTokenizer from tokenizers.models import BPE, Model, Unigram from tokenizers.pre_tokenizers import ByteLevel -from tokenizers.processors import RobertaProcessing +from tokenizers.processors import RobertaProcessing, TemplateProcessing +from tokenizers.normalizers import Strip, Lowercase, Sequence from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files @@ -549,3 +550,23 @@ def test_decode_special(self): output = tokenizer.decode([0, 1, 2, 3], skip_special_tokens=True) assert output == "name is john" assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True) + +class TestTokenizerRepr: + def test_repr(self): + tokenizer = Tokenizer(BPE()) + out = tokenizer.repr() + print(out) + assert out == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' + + def test_repr_complete(self): + tokenizer = Tokenizer(BPE()) + tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True) + tokenizer.post_processor = TemplateProcessing( + single=["[CLS]", "$0", "[SEP]"], + pair=["[CLS]:0", "$A", "[SEP]:0", "$B:1", "[SEP]:1"], + special_tokens=[("[CLS]", 1), ("[SEP]", 0)], + ) + tokenizer.normalizer = Sequence([Lowercase(), Strip()]) + out = tokenizer.repr() + print(out) + assert out == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[Lowercase(), Strip(strip_left=True, strip_right=True)]), pre_tokenizer=ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[1], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[0], tokens=["[SEP]"])}), decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' From bcbea438c218dc9bfb81c7a507d798a9e041b4a2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 2 Aug 2024 20:02:32 +0200 Subject: [PATCH 03/10] Clippy is our friend. --- bindings/python/src/utils/serde_pyo3.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bindings/python/src/utils/serde_pyo3.rs b/bindings/python/src/utils/serde_pyo3.rs index bf010f21f..0dc336f2d 100644 --- a/bindings/python/src/utils/serde_pyo3.rs +++ b/bindings/python/src/utils/serde_pyo3.rs @@ -273,10 +273,10 @@ impl<'a> ser::Serializer for &'a mut Serializer { fn serialize_struct(self, name: &'static str, _len: usize) -> Result { // self.serialize_map(Some(len)) // name.serialize(&mut *self)?; - if name.ends_with("Helper") { - self.output += &name[..name.len() - "Helper".len()]; + if let Some(stripped) = name.strip_suffix("Helper") { + self.output += stripped; } else { - self.output += name; + self.output += name } self.output += "("; Ok(self) From d8ee213c885b27d79dbe883f6a22a84f1022c06a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 2 Aug 2024 20:08:34 +0200 Subject: [PATCH 04/10] Ruff. --- bindings/python/tests/bindings/test_tokenizer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index b5539618d..9c9c4f12d 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -551,12 +551,16 @@ def test_decode_special(self): assert output == "name is john" assert tokenizer.get_added_tokens_decoder()[0] == AddedToken("my", special=True) + class TestTokenizerRepr: def test_repr(self): tokenizer = Tokenizer(BPE()) out = tokenizer.repr() print(out) - assert out == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' + assert ( + out + == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' + ) def test_repr_complete(self): tokenizer = Tokenizer(BPE()) @@ -569,4 +573,7 @@ def test_repr_complete(self): tokenizer.normalizer = Sequence([Lowercase(), Strip()]) out = tokenizer.repr() print(out) - assert out == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[Lowercase(), Strip(strip_left=True, strip_right=True)]), pre_tokenizer=ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[1], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[0], tokens=["[SEP]"])}), decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' + assert ( + out + == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[Lowercase(), Strip(strip_left=True, strip_right=True)]), pre_tokenizer=ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[1], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[0], tokens=["[SEP]"])}), decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' + ) From 18fad02066f728682daba0c4af8957a8295d3989 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 2 Aug 2024 20:18:35 +0200 Subject: [PATCH 05/10] Update the tests. --- bindings/python/src/tokenizer.rs | 2 +- bindings/python/src/utils/serde_pyo3.rs | 23 ++++++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 22eb3def4..6faeb7ad7 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1452,6 +1452,6 @@ mod test { ]))); let output = crate::utils::serde_pyo3::to_string(&tokenizer).unwrap(); - assert_eq!(output, ""); + assert_eq!(output, "Tokenizer(version=\"1.0\", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[NFKC(), Lowercase()]), pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))"); } } diff --git a/bindings/python/src/utils/serde_pyo3.rs b/bindings/python/src/utils/serde_pyo3.rs index 0dc336f2d..3969277d5 100644 --- a/bindings/python/src/utils/serde_pyo3.rs +++ b/bindings/python/src/utils/serde_pyo3.rs @@ -597,11 +597,13 @@ fn test_struct_tagged() { } let u = A { a: true, b: 1 }; - let expected = r#"A(type="A", a=True, b=1)"#; + // let expected = r#"A(type="A", a=True, b=1)"#; + // No we skip all `type` manually inserted variants. + let expected = r#"A(a=True, b=1)"#; assert_eq!(to_string(&u).unwrap(), expected); let u = E::A(A { a: true, b: 1 }); - let expected = r#"A(type="A", a=True, b=1)"#; + let expected = r#"A(a=True, b=1)"#; assert_eq!(to_string(&u).unwrap(), expected); } @@ -626,6 +628,12 @@ fn test_flatten() { d: usize, } + #[derive(Serialize)] + #[serde(transparent)] + struct D { + e: A, + } + let u = B { c: A { a: true, b: 1 }, d: 2, @@ -637,6 +645,15 @@ fn test_flatten() { c: A { a: true, b: 1 }, d: 2, }; - let expected = r#"C(a=True, b=1, d=2)"#; + // XXX This is unfortunate but true, flatten forces the serialization + // to use the serialize_map without any means for the Serializer to know about this + // flattening attempt + let expected = r#"{"a":True, "b":1, "d":2}"#; + assert_eq!(to_string(&u).unwrap(), expected); + + let u = D { + e: A { a: true, b: 1 }, + }; + let expected = r#"A(a=True, b=1)"#; assert_eq!(to_string(&u).unwrap(), expected); } From 1c195313e8f2d5591e026b70dde89eadf5915c74 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Wed, 7 Aug 2024 04:16:41 -0400 Subject: [PATCH 06/10] Pretty sure this is wrong (#1589) --- bindings/python/src/utils/serde_pyo3.rs | 152 +++++++++++++++++++++++- 1 file changed, 151 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/utils/serde_pyo3.rs b/bindings/python/src/utils/serde_pyo3.rs index 3969277d5..c5fc6453c 100644 --- a/bindings/python/src/utils/serde_pyo3.rs +++ b/bindings/python/src/utils/serde_pyo3.rs @@ -2,9 +2,12 @@ use serde::de::value::Error; use serde::{ser, Serialize}; type Result = ::std::result::Result; +const MAX_DEPTH: usize = 5; + pub struct Serializer { // This string starts empty and JSON is appended as values are serialized. output: String, + level: usize, } // By convention, the public API of a Serde serializer is one or more `to_abc` @@ -18,6 +21,7 @@ where { let mut serializer = Serializer { output: String::new(), + level: 0, }; value.serialize(&mut serializer)?; Ok(serializer.output) @@ -51,6 +55,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // of the primitive types of the data model and map it to JSON by appending // into the output string. fn serialize_bool(self, v: bool) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += if v { "True" } else { "False" }; Ok(()) } @@ -74,6 +83,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // Not particularly efficient but this is example code anyway. A more // performant approach would be to use the `itoa` crate. fn serialize_i64(self, v: i64) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += &v.to_string(); Ok(()) } @@ -91,6 +105,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { } fn serialize_u64(self, v: u64) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += &v.to_string(); Ok(()) } @@ -100,6 +119,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { } fn serialize_f64(self, v: f64) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += &v.to_string(); Ok(()) } @@ -114,6 +138,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // get the idea. For example it would emit invalid JSON if the input string // contains a '"' character. fn serialize_str(self, v: &str) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += "\""; self.output += v; self.output += "\""; @@ -152,6 +181,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // In Serde, unit means an anonymous value containing no data. Map this to // JSON as `null`. fn serialize_unit(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += "None"; Ok(()) } @@ -173,6 +207,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { _variant_index: u32, variant: &'static str, ) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } // self.serialize_str(variant) self.output += variant; Ok(()) @@ -202,6 +241,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; @@ -221,6 +265,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // explicitly in the serialized form. Some serializers may only be able to // support sequences for which the length is known up front. fn serialize_seq(self, _len: Option) -> Result { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(self); + } self.output += "["; Ok(self) } @@ -230,6 +279,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // means that the corresponding `Deserialize implementation will know the // length without needing to look at the serialized data. fn serialize_tuple(self, _len: usize) -> Result { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(self); + } self.output += "("; Ok(self) } @@ -252,6 +306,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { variant: &'static str, _len: usize, ) -> Result { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(self); + } // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; @@ -260,6 +319,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // Maps are represented in JSON as `{ K: V, K: V, ... }`. fn serialize_map(self, _len: Option) -> Result { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(self); + } println!("Serialize map"); self.output += "{"; Ok(self) @@ -271,6 +335,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { // Deserialize implementation is required to know what the keys are without // looking at the serialized data. fn serialize_struct(self, name: &'static str, _len: usize) -> Result { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(self); + } // self.serialize_map(Some(len)) // name.serialize(&mut *self)?; if let Some(stripped) = name.strip_suffix("Helper") { @@ -291,6 +360,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { variant: &'static str, _len: usize, ) -> Result { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(self); + } // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; @@ -316,6 +390,11 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } if !self.output.ends_with('[') { self.output += ", "; } @@ -324,6 +403,11 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer { // Close the sequence. fn end(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += "]"; Ok(()) } @@ -338,6 +422,11 @@ impl<'a> ser::SerializeTuple for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } if !self.output.ends_with('(') { self.output += ", "; } @@ -345,6 +434,11 @@ impl<'a> ser::SerializeTuple for &'a mut Serializer { } fn end(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += ")"; Ok(()) } @@ -359,6 +453,11 @@ impl<'a> ser::SerializeTupleStruct for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } if !self.output.ends_with('(') { self.output += ", "; } @@ -366,6 +465,11 @@ impl<'a> ser::SerializeTupleStruct for &'a mut Serializer { } fn end(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += ")"; Ok(()) } @@ -388,6 +492,11 @@ impl<'a> ser::SerializeTupleVariant for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } if !self.output.ends_with('(') { self.output += ", "; } @@ -395,6 +504,11 @@ impl<'a> ser::SerializeTupleVariant for &'a mut Serializer { } fn end(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += ")"; Ok(()) } @@ -424,6 +538,11 @@ impl<'a> ser::SerializeMap for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } if !self.output.ends_with('{') { self.output += ", "; } @@ -437,11 +556,21 @@ impl<'a> ser::SerializeMap for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += ":"; value.serialize(&mut **self) } fn end(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += "}"; Ok(()) } @@ -457,6 +586,11 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } if !self.output.ends_with('(') { self.output += ", "; } @@ -471,6 +605,11 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer { } fn end(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += ")"; Ok(()) } @@ -486,6 +625,11 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer { where T: ?Sized + Serialize, { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } if !self.output.ends_with('(') { self.output += ", "; } @@ -496,6 +640,11 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer { } fn end(self) -> Result<()> { + self.level += 1; + if self.level > MAX_DEPTH { + self.output += "..."; + return Ok(()); + } self.output += ")"; Ok(()) } @@ -525,7 +674,7 @@ fn test_struct() { let expected = r#"Test(int=1, seq=["a", "b"])"#; assert_eq!(to_string(&test).unwrap(), expected); } - +/* #[test] fn test_enum() { #[derive(Serialize)] @@ -657,3 +806,4 @@ fn test_flatten() { let expected = r#"A(a=True, b=1)"#; assert_eq!(to_string(&u).unwrap(), expected); } +*/ \ No newline at end of file From 5d15c0cb07fb7a0febdb09865b325862ad5c2041 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 7 Aug 2024 11:45:43 +0200 Subject: [PATCH 07/10] Adding support for ellipsis. --- bindings/python/src/decoders.rs | 10 + bindings/python/src/models.rs | 10 + bindings/python/src/normalizers.rs | 10 + bindings/python/src/pre_tokenizers.rs | 10 + bindings/python/src/processors.rs | 10 + bindings/python/src/tokenizer.rs | 8 +- bindings/python/src/trainers.rs | 10 + bindings/python/src/utils/serde_pyo3.rs | 296 +++++++++----------- bindings/python/tests/test_serialization.py | 41 +++ 9 files changed, 237 insertions(+), 168 deletions(-) diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index 4a4af94dd..1a03a7721 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -114,6 +114,16 @@ impl PyDecoder { fn decode(&self, tokens: Vec) -> PyResult { ToPyResult(self.decoder.decode(tokens)).into() } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } macro_rules! getter { diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 2bfaafd34..424be9f57 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -220,6 +220,16 @@ impl PyModel { fn get_trainer(&self, py: Python<'_>) -> PyResult { PyTrainer::from(self.model.read().unwrap().get_trainer()).get_as_subtype(py) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } /// An implementation of the BPE (Byte-Pair Encoding) algorithm diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 724e79b85..51c1e8bfe 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -169,6 +169,16 @@ impl PyNormalizer { ToPyResult(self.normalizer.normalize(&mut normalized)).into_py()?; Ok(normalized.get().to_owned()) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } macro_rules! getter { diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a9060ec3b..4b97319d3 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -181,6 +181,16 @@ impl PyPreTokenizer { .map(|(s, o, _)| (s.to_owned(), o)) .collect()) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } macro_rules! getter { diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index aceb1d446..1d8e8dfac 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -139,6 +139,16 @@ impl PyPostProcessor { .into_py()?; Ok(final_encoding.into()) } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } /// This post-processor takes care of adding the special tokens needed by diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 6faeb7ad7..f41bf335f 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -640,8 +640,12 @@ impl PyTokenizer { ToPyResult(self.tokenizer.save(path, pretty)).into() } - #[pyo3(signature = ())] - fn repr(&self) -> PyResult { + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { crate::utils::serde_pyo3::to_string(self) .map_err(|e| exceptions::PyException::new_err(e.to_string())) } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index cbce2aef9..c71442298 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -69,6 +69,16 @@ impl PyTrainer { Err(e) => Err(e), } } + + fn __repr__(&self) -> PyResult { + crate::utils::serde_pyo3::repr(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } + + fn __str__(&self) -> PyResult { + crate::utils::serde_pyo3::to_string(self) + .map_err(|e| exceptions::PyException::new_err(e.to_string())) + } } impl Trainer for PyTrainer { diff --git a/bindings/python/src/utils/serde_pyo3.rs b/bindings/python/src/utils/serde_pyo3.rs index c5fc6453c..b7a189bd7 100644 --- a/bindings/python/src/utils/serde_pyo3.rs +++ b/bindings/python/src/utils/serde_pyo3.rs @@ -2,12 +2,17 @@ use serde::de::value::Error; use serde::{ser, Serialize}; type Result = ::std::result::Result; -const MAX_DEPTH: usize = 5; - pub struct Serializer { // This string starts empty and JSON is appended as values are serialized. output: String, + /// Each levels remembers its own number of elements + num_elements: Vec, + max_elements: usize, level: usize, + max_depth: usize, + /// Maximum string representation + /// Useful to ellipsis precompiled_charmap + max_string: usize, } // By convention, the public API of a Serde serializer is one or more `to_abc` @@ -19,9 +24,34 @@ pub fn to_string(value: &T) -> Result where T: Serialize, { + let max_depth = 20; + let max_elements = 6; + let max_string = 100; + let mut serializer = Serializer { + output: String::new(), + level: 0, + max_depth, + max_elements, + num_elements: vec![0; max_depth], + max_string + }; + value.serialize(&mut serializer)?; + Ok(serializer.output) +} + +pub fn repr(value: &T) -> Result +where + T: Serialize, +{ + let max_depth = 200; + let max_string = usize::MAX; let mut serializer = Serializer { output: String::new(), level: 0, + max_depth, + max_elements: 100, + num_elements: vec![0; max_depth], + max_string }; value.serialize(&mut serializer)?; Ok(serializer.output) @@ -55,11 +85,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { // of the primitive types of the data model and map it to JSON by appending // into the output string. fn serialize_bool(self, v: bool) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } self.output += if v { "True" } else { "False" }; Ok(()) } @@ -83,11 +108,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { // Not particularly efficient but this is example code anyway. A more // performant approach would be to use the `itoa` crate. fn serialize_i64(self, v: i64) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } self.output += &v.to_string(); Ok(()) } @@ -105,11 +125,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { } fn serialize_u64(self, v: u64) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } self.output += &v.to_string(); Ok(()) } @@ -119,11 +134,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { } fn serialize_f64(self, v: f64) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } self.output += &v.to_string(); Ok(()) } @@ -138,14 +148,14 @@ impl<'a> ser::Serializer for &'a mut Serializer { // get the idea. For example it would emit invalid JSON if the input string // contains a '"' character. fn serialize_str(self, v: &str) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { + self.output += "\""; + if v.len() > self.max_string{ + self.output += &v[..self.max_string]; self.output += "..."; - return Ok(()); + }else{ + self.output += v; } self.output += "\""; - self.output += v; - self.output += "\""; Ok(()) } @@ -181,11 +191,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { // In Serde, unit means an anonymous value containing no data. Map this to // JSON as `null`. fn serialize_unit(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } self.output += "None"; Ok(()) } @@ -207,11 +212,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { _variant_index: u32, variant: &'static str, ) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } // self.serialize_str(variant) self.output += variant; Ok(()) @@ -241,11 +241,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; @@ -265,12 +260,9 @@ impl<'a> ser::Serializer for &'a mut Serializer { // explicitly in the serialized form. Some serializers may only be able to // support sequences for which the length is known up front. fn serialize_seq(self, _len: Option) -> Result { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(self); - } self.output += "["; + self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.num_elements[self.level] = 0; Ok(self) } @@ -279,12 +271,9 @@ impl<'a> ser::Serializer for &'a mut Serializer { // means that the corresponding `Deserialize implementation will know the // length without needing to look at the serialized data. fn serialize_tuple(self, _len: usize) -> Result { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(self); - } self.output += "("; + self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.num_elements[self.level] = 0; Ok(self) } @@ -306,26 +295,19 @@ impl<'a> ser::Serializer for &'a mut Serializer { variant: &'static str, _len: usize, ) -> Result { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(self); - } // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; + self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.num_elements[self.level] = 0; Ok(self) } // Maps are represented in JSON as `{ K: V, K: V, ... }`. fn serialize_map(self, _len: Option) -> Result { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(self); - } - println!("Serialize map"); self.output += "{"; + self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.num_elements[self.level] = 0; Ok(self) } @@ -335,11 +317,6 @@ impl<'a> ser::Serializer for &'a mut Serializer { // Deserialize implementation is required to know what the keys are without // looking at the serialized data. fn serialize_struct(self, name: &'static str, _len: usize) -> Result { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(self); - } // self.serialize_map(Some(len)) // name.serialize(&mut *self)?; if let Some(stripped) = name.strip_suffix("Helper") { @@ -348,6 +325,8 @@ impl<'a> ser::Serializer for &'a mut Serializer { self.output += name } self.output += "("; + self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.num_elements[self.level] = 0; Ok(self) } @@ -360,14 +339,11 @@ impl<'a> ser::Serializer for &'a mut Serializer { variant: &'static str, _len: usize, ) -> Result { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(self); - } // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; + self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.num_elements[self.level] = 0; Ok(self) } } @@ -390,24 +366,25 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } - if !self.output.ends_with('[') { - self.output += ", "; + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements{ + if !self.output.ends_with('[') { + self.output += ", "; + } + value.serialize(&mut **self) + }else{ + if num_elements == self.max_elements{ + self.output += ", ..."; + } + Ok(()) } - value.serialize(&mut **self) } // Close the sequence. fn end(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); self.output += "]"; Ok(()) } @@ -422,23 +399,24 @@ impl<'a> ser::SerializeTuple for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } - if !self.output.ends_with('(') { - self.output += ", "; + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements{ + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + }else{ + if num_elements == self.max_elements{ + self.output += ", ..."; + } + Ok(()) } - value.serialize(&mut **self) } fn end(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); self.output += ")"; Ok(()) } @@ -453,23 +431,24 @@ impl<'a> ser::SerializeTupleStruct for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } - if !self.output.ends_with('(') { - self.output += ", "; + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements{ + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + }else{ + if num_elements == self.max_elements{ + self.output += ", ..."; + } + Ok(()) } - value.serialize(&mut **self) } fn end(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); self.output += ")"; Ok(()) } @@ -492,23 +471,24 @@ impl<'a> ser::SerializeTupleVariant for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } - if !self.output.ends_with('(') { - self.output += ", "; + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements{ + if !self.output.ends_with('(') { + self.output += ", "; + } + value.serialize(&mut **self) + }else{ + if num_elements == self.max_elements{ + self.output += ", ..."; + } + Ok(()) } - value.serialize(&mut **self) } fn end(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); self.output += ")"; Ok(()) } @@ -538,15 +518,19 @@ impl<'a> ser::SerializeMap for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } - if !self.output.ends_with('{') { - self.output += ", "; + self.num_elements[self.level] += 1; + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements{ + if !self.output.ends_with('{') { + self.output += ", "; + } + key.serialize(&mut **self) + }else{ + if num_elements == self.max_elements{ + self.output += ", ..."; + } + Ok(()) } - key.serialize(&mut **self) } // It doesn't make a difference whether the colon is printed at the end of @@ -556,21 +540,18 @@ impl<'a> ser::SerializeMap for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); + let num_elements = self.num_elements[self.level]; + if num_elements < self.max_elements{ + self.output += ":"; + value.serialize(&mut **self) + }else{ + Ok(()) } - self.output += ":"; - value.serialize(&mut **self) } fn end(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); self.output += "}"; Ok(()) } @@ -586,11 +567,6 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } if !self.output.ends_with('(') { self.output += ", "; } @@ -605,11 +581,8 @@ impl<'a> ser::SerializeStruct for &'a mut Serializer { } fn end(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); self.output += ")"; Ok(()) } @@ -625,11 +598,6 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer { where T: ?Sized + Serialize, { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } if !self.output.ends_with('(') { self.output += ", "; } @@ -640,11 +608,8 @@ impl<'a> ser::SerializeStructVariant for &'a mut Serializer { } fn end(self) -> Result<()> { - self.level += 1; - if self.level > MAX_DEPTH { - self.output += "..."; - return Ok(()); - } + self.num_elements[self.level] = 0; + self.level = self.level.saturating_sub(1); self.output += ")"; Ok(()) } @@ -674,7 +639,7 @@ fn test_struct() { let expected = r#"Test(int=1, seq=["a", "b"])"#; assert_eq!(to_string(&test).unwrap(), expected); } -/* + #[test] fn test_enum() { #[derive(Serialize)] @@ -806,4 +771,3 @@ fn test_flatten() { let expected = r#"A(a=True, b=1)"#; assert_eq!(to_string(&u).unwrap(), expected); } -*/ \ No newline at end of file diff --git a/bindings/python/tests/test_serialization.py b/bindings/python/tests/test_serialization.py index 4434e6304..d39282d67 100644 --- a/bindings/python/tests/test_serialization.py +++ b/bindings/python/tests/test_serialization.py @@ -5,6 +5,7 @@ import tqdm from huggingface_hub import hf_hub_download from tokenizers import Tokenizer +from tokenizers.models import BPE, Unigram from .utils import albert_base, data_dir @@ -16,6 +17,46 @@ def test_full_serialization_albert(self, albert_base): # file exceeds the buffer capacity Tokenizer.from_file(albert_base) + def test_str_big(self, albert_base): + tokenizer = Tokenizer.from_file(albert_base) + assert str(tokenizer) == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":1, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":2, "content":"[CLS]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":3, "content":"[SEP]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":4, "content":"[MASK]", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=Sequence(normalizers=[Replace(pattern=String("``"), content="\""), Replace(pattern=String("''"), content="\""), NFKD(), StripAccents(), Lowercase(), ...]), pre_tokenizer=Sequence(pretokenizers=[WhitespaceSplit(), Metaspace(replacement="▁", prepend_scheme=always, split=True)]), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[2], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[3], tokens=["[SEP]"])}), decoder=Metaspace(replacement="▁", prepend_scheme=always, split=True), model=Unigram(unk_id=1, vocab=[("", 0), ("", 0), ("[CLS]", 0), ("[SEP]", 0), ("[MASK]", 0), ...], byte_fallback=False))""" + + def test_repr_str(self): + tokenizer = Tokenizer(BPE()) + tokenizer.add_tokens(["my"]) + assert repr(tokenizer) == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, "normalized":True, "special":False}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" + assert str(tokenizer) == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" + + def test_repr_str_ellipsis(self): + model = BPE() + assert repr(model) == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])""" + assert str(model) == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])""" + + vocab = [ + ("A", 0.0), + ("B", -0.01), + ("C", -0.02), + ("D", -0.03), + ("E", -0.04), + ] + # No ellispsis yet + model = Unigram(vocab, 0, byte_fallback=False) + assert repr(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)""" + assert str(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)""" + + # Ellispis for longer than 5 elements only on `str`. + vocab = [ + ("A", 0.0), + ("B", -0.01), + ("C", -0.02), + ("D", -0.03), + ("E", -0.04), + ("F", -0.04), + ] + model = Unigram(vocab, 0, byte_fallback=False) + assert repr(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ("F", -0.04)], byte_fallback=False)""" + assert str(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ...], byte_fallback=False)""" + def check(tokenizer_file) -> bool: with open(tokenizer_file, "r") as f: From c4901e857d3a64e7f46129fbe45a8bd00c9f3cb9 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 7 Aug 2024 11:47:07 +0200 Subject: [PATCH 08/10] Fmt. --- bindings/python/src/utils/serde_pyo3.rs | 54 ++++++++++++------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/bindings/python/src/utils/serde_pyo3.rs b/bindings/python/src/utils/serde_pyo3.rs index b7a189bd7..471993614 100644 --- a/bindings/python/src/utils/serde_pyo3.rs +++ b/bindings/python/src/utils/serde_pyo3.rs @@ -33,7 +33,7 @@ where max_depth, max_elements, num_elements: vec![0; max_depth], - max_string + max_string, }; value.serialize(&mut serializer)?; Ok(serializer.output) @@ -51,7 +51,7 @@ where max_depth, max_elements: 100, num_elements: vec![0; max_depth], - max_string + max_string, }; value.serialize(&mut serializer)?; Ok(serializer.output) @@ -149,10 +149,10 @@ impl<'a> ser::Serializer for &'a mut Serializer { // contains a '"' character. fn serialize_str(self, v: &str) -> Result<()> { self.output += "\""; - if v.len() > self.max_string{ + if v.len() > self.max_string { self.output += &v[..self.max_string]; self.output += "..."; - }else{ + } else { self.output += v; } self.output += "\""; @@ -261,7 +261,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { // support sequences for which the length is known up front. fn serialize_seq(self, _len: Option) -> Result { self.output += "["; - self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); self.num_elements[self.level] = 0; Ok(self) } @@ -272,7 +272,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { // length without needing to look at the serialized data. fn serialize_tuple(self, _len: usize) -> Result { self.output += "("; - self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); self.num_elements[self.level] = 0; Ok(self) } @@ -298,7 +298,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; - self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); self.num_elements[self.level] = 0; Ok(self) } @@ -306,7 +306,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { // Maps are represented in JSON as `{ K: V, K: V, ... }`. fn serialize_map(self, _len: Option) -> Result { self.output += "{"; - self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); self.num_elements[self.level] = 0; Ok(self) } @@ -325,7 +325,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { self.output += name } self.output += "("; - self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); self.num_elements[self.level] = 0; Ok(self) } @@ -342,7 +342,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { // variant.serialize(&mut *self)?; self.output += variant; self.output += "("; - self.level = std::cmp::min(self.max_depth- 1, self.level + 1); + self.level = std::cmp::min(self.max_depth - 1, self.level + 1); self.num_elements[self.level] = 0; Ok(self) } @@ -368,13 +368,13 @@ impl<'a> ser::SerializeSeq for &'a mut Serializer { { self.num_elements[self.level] += 1; let num_elements = self.num_elements[self.level]; - if num_elements < self.max_elements{ + if num_elements < self.max_elements { if !self.output.ends_with('[') { self.output += ", "; } value.serialize(&mut **self) - }else{ - if num_elements == self.max_elements{ + } else { + if num_elements == self.max_elements { self.output += ", ..."; } Ok(()) @@ -401,13 +401,13 @@ impl<'a> ser::SerializeTuple for &'a mut Serializer { { self.num_elements[self.level] += 1; let num_elements = self.num_elements[self.level]; - if num_elements < self.max_elements{ + if num_elements < self.max_elements { if !self.output.ends_with('(') { self.output += ", "; } value.serialize(&mut **self) - }else{ - if num_elements == self.max_elements{ + } else { + if num_elements == self.max_elements { self.output += ", ..."; } Ok(()) @@ -433,13 +433,13 @@ impl<'a> ser::SerializeTupleStruct for &'a mut Serializer { { self.num_elements[self.level] += 1; let num_elements = self.num_elements[self.level]; - if num_elements < self.max_elements{ + if num_elements < self.max_elements { if !self.output.ends_with('(') { self.output += ", "; } value.serialize(&mut **self) - }else{ - if num_elements == self.max_elements{ + } else { + if num_elements == self.max_elements { self.output += ", ..."; } Ok(()) @@ -473,13 +473,13 @@ impl<'a> ser::SerializeTupleVariant for &'a mut Serializer { { self.num_elements[self.level] += 1; let num_elements = self.num_elements[self.level]; - if num_elements < self.max_elements{ + if num_elements < self.max_elements { if !self.output.ends_with('(') { self.output += ", "; } value.serialize(&mut **self) - }else{ - if num_elements == self.max_elements{ + } else { + if num_elements == self.max_elements { self.output += ", ..."; } Ok(()) @@ -520,13 +520,13 @@ impl<'a> ser::SerializeMap for &'a mut Serializer { { self.num_elements[self.level] += 1; let num_elements = self.num_elements[self.level]; - if num_elements < self.max_elements{ + if num_elements < self.max_elements { if !self.output.ends_with('{') { self.output += ", "; } key.serialize(&mut **self) - }else{ - if num_elements == self.max_elements{ + } else { + if num_elements == self.max_elements { self.output += ", ..."; } Ok(()) @@ -541,10 +541,10 @@ impl<'a> ser::SerializeMap for &'a mut Serializer { T: ?Sized + Serialize, { let num_elements = self.num_elements[self.level]; - if num_elements < self.max_elements{ + if num_elements < self.max_elements { self.output += ":"; value.serialize(&mut **self) - }else{ + } else { Ok(()) } } From 183702c1bf06dc777f24ba6ced334ced93680f7a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 7 Aug 2024 11:52:07 +0200 Subject: [PATCH 09/10] Ruff. --- bindings/python/tests/test_serialization.py | 45 ++++++++++++++++----- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/bindings/python/tests/test_serialization.py b/bindings/python/tests/test_serialization.py index d39282d67..9da2c3e27 100644 --- a/bindings/python/tests/test_serialization.py +++ b/bindings/python/tests/test_serialization.py @@ -19,18 +19,33 @@ def test_full_serialization_albert(self, albert_base): def test_str_big(self, albert_base): tokenizer = Tokenizer.from_file(albert_base) - assert str(tokenizer) == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":1, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":2, "content":"[CLS]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":3, "content":"[SEP]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":4, "content":"[MASK]", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=Sequence(normalizers=[Replace(pattern=String("``"), content="\""), Replace(pattern=String("''"), content="\""), NFKD(), StripAccents(), Lowercase(), ...]), pre_tokenizer=Sequence(pretokenizers=[WhitespaceSplit(), Metaspace(replacement="▁", prepend_scheme=always, split=True)]), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[2], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[3], tokens=["[SEP]"])}), decoder=Metaspace(replacement="▁", prepend_scheme=always, split=True), model=Unigram(unk_id=1, vocab=[("", 0), ("", 0), ("[CLS]", 0), ("[SEP]", 0), ("[MASK]", 0), ...], byte_fallback=False))""" + assert ( + str(tokenizer) + == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":1, "content":"", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":2, "content":"[CLS]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":3, "content":"[SEP]", "single_word":False, "lstrip":False, "rstrip":False, ...}, {"id":4, "content":"[MASK]", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=Sequence(normalizers=[Replace(pattern=String("``"), content="\""), Replace(pattern=String("''"), content="\""), NFKD(), StripAccents(), Lowercase(), ...]), pre_tokenizer=Sequence(pretokenizers=[WhitespaceSplit(), Metaspace(replacement="▁", prepend_scheme=always, split=True)]), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[2], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[3], tokens=["[SEP]"])}), decoder=Metaspace(replacement="▁", prepend_scheme=always, split=True), model=Unigram(unk_id=1, vocab=[("", 0), ("", 0), ("[CLS]", 0), ("[SEP]", 0), ("[MASK]", 0), ...], byte_fallback=False))""" + ) def test_repr_str(self): tokenizer = Tokenizer(BPE()) tokenizer.add_tokens(["my"]) - assert repr(tokenizer) == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, "normalized":True, "special":False}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" - assert str(tokenizer) == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" + assert ( + repr(tokenizer) + == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, "normalized":True, "special":False}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" + ) + assert ( + str(tokenizer) + == """Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"my", "single_word":False, "lstrip":False, "rstrip":False, ...}], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))""" + ) def test_repr_str_ellipsis(self): model = BPE() - assert repr(model) == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])""" - assert str(model) == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])""" + assert ( + repr(model) + == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])""" + ) + assert ( + str(model) + == """BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[])""" + ) vocab = [ ("A", 0.0), @@ -41,8 +56,14 @@ def test_repr_str_ellipsis(self): ] # No ellispsis yet model = Unigram(vocab, 0, byte_fallback=False) - assert repr(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)""" - assert str(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)""" + assert ( + repr(model) + == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)""" + ) + assert ( + str(model) + == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04)], byte_fallback=False)""" + ) # Ellispis for longer than 5 elements only on `str`. vocab = [ @@ -54,8 +75,14 @@ def test_repr_str_ellipsis(self): ("F", -0.04), ] model = Unigram(vocab, 0, byte_fallback=False) - assert repr(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ("F", -0.04)], byte_fallback=False)""" - assert str(model) == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ...], byte_fallback=False)""" + assert ( + repr(model) + == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ("F", -0.04)], byte_fallback=False)""" + ) + assert ( + str(model) + == """Unigram(unk_id=0, vocab=[("A", 0), ("B", -0.01), ("C", -0.02), ("D", -0.03), ("E", -0.04), ...], byte_fallback=False)""" + ) def check(tokenizer_file) -> bool: From 61cfeb7fceb9d80bbedc0a61b034865d80577826 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 7 Aug 2024 12:01:05 +0200 Subject: [PATCH 10/10] Fixing tokenizer. --- bindings/python/tests/bindings/test_tokenizer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 9c9c4f12d..3851f0764 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -555,8 +555,7 @@ def test_decode_special(self): class TestTokenizerRepr: def test_repr(self): tokenizer = Tokenizer(BPE()) - out = tokenizer.repr() - print(out) + out = repr(tokenizer) assert ( out == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=None, pre_tokenizer=None, post_processor=None, decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))' @@ -571,8 +570,7 @@ def test_repr_complete(self): special_tokens=[("[CLS]", 1), ("[SEP]", 0)], ) tokenizer.normalizer = Sequence([Lowercase(), Strip()]) - out = tokenizer.repr() - print(out) + out = repr(tokenizer) assert ( out == 'Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[], normalizer=Sequence(normalizers=[Lowercase(), Strip(strip_left=True, strip_right=True)]), pre_tokenizer=ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0)], pair=[SpecialToken(id="[CLS]", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="[SEP]", type_id=0), Sequence(id=B, type_id=1), SpecialToken(id="[SEP]", type_id=1)], special_tokens={"[CLS]":SpecialToken(id="[CLS]", ids=[1], tokens=["[CLS]"]), "[SEP]":SpecialToken(id="[SEP]", ids=[0], tokens=["[SEP]"])}), decoder=None, model=BPE(dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=False, byte_fallback=False, ignore_merges=False, vocab={}, merges=[]))'