From 5f77a1c52e8c11257f4cf2c34e603e99c1bbc282 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 23 Aug 2022 18:20:43 +0200
Subject: [PATCH] Modify `Processor` trait to support chaining.

0 modifications yet, everything will consume the vector.
Every test should be green without any modifications.
---
 bindings/node/native/src/processors.rs      |  9 ++--
 bindings/python/src/processors.rs           |  9 ++--
 tokenizers/src/pre_tokenizers/byte_level.rs | 18 ++-----
 tokenizers/src/processors/bert.rs           | 42 +++++++++++----
 tokenizers/src/processors/mod.rs            | 17 +++---
 tokenizers/src/processors/roberta.rs        | 44 +++++++++-------
 tokenizers/src/processors/template.rs       | 33 +++++++++---
 tokenizers/src/tokenizer/mod.rs             | 58 ++++++++++++++++-----
 8 files changed, 149 insertions(+), 81 deletions(-)

diff --git a/bindings/node/native/src/processors.rs b/bindings/node/native/src/processors.rs
index 28154735e..013984605 100644
--- a/bindings/node/native/src/processors.rs
+++ b/bindings/node/native/src/processors.rs
@@ -22,16 +22,15 @@ impl tk::PostProcessor for Processor {
             .added_tokens(is_pair)
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> tk::Result<Encoding> {
+    ) -> tk::Result<Vec<Encoding>> {
         self.processor
             .as_ref()
             .ok_or("Uninitialized PostProcessor")?
-            .process(encoding, pair_encoding, add_special_tokens)
+            .process_encodings(encodings, add_special_tokens)
     }
 }
 
diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs
index 4e346f7d6..f1973bb64 100644
--- a/bindings/python/src/processors.rs
+++ b/bindings/python/src/processors.rs
@@ -59,14 +59,13 @@ impl PostProcessor for PyPostProcessor {
         self.processor.added_tokens(is_pair)
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> tk::Result<Encoding> {
+    ) -> tk::Result<Vec<Encoding>> {
         self.processor
-            .process(encoding, pair_encoding, add_special_tokens)
+            .process_encodings(encodings, add_special_tokens)
     }
 }
 
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index afa3d3727..5338189fd 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -174,20 +174,13 @@ impl PostProcessor for ByteLevel {
         0
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        mut pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if self.trim_offsets {
-            process_offsets(&mut encoding, self.add_prefix_space);
-            encoding
-                .get_overflowing_mut()
-                .iter_mut()
-                .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
-
-            if let Some(encoding) = pair_encoding.as_mut() {
+            for encoding in encodings.iter_mut() {
                 process_offsets(encoding, self.add_prefix_space);
                 encoding
                     .get_overflowing_mut()
                     .iter_mut()
                     .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
             }
         }
-
-        <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
+        <dyn PostProcessor>::default_process(encodings, add_special_tokens)
     }
 }
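Aside: these first three hunks already show the whole shape of the change. Both bindings become one-for-one delegations to the new `process_encodings`, and `ByteLevel` collapses its duplicated offset-trimming (once for the main encoding, once for the optional pair) into a single loop over the vector. Below is a minimal, self-contained sketch of that uniformity win; `Enc`, `process_offsets`, and `trim_all` are toy stand-ins, not the crate's real API.

    // One loop covers the lone encoding, the pair, and any longer chain,
    // where the old code special-cased the pair slot.
    struct Enc {
        offsets: Vec<(usize, usize)>,
        overflowing: Vec<Enc>,
    }

    // Toy stand-in for the crate's offset trimming.
    fn process_offsets(enc: &mut Enc) {
        for o in &mut enc.offsets {
            o.0 = o.0.min(o.1);
        }
    }

    fn trim_all(encodings: &mut [Enc]) {
        for encoding in encodings.iter_mut() {
            process_offsets(encoding);
            encoding.overflowing.iter_mut().for_each(process_offsets);
        }
    }

    fn main() {
        let mut encs = vec![
            Enc { offsets: vec![(3, 1)], overflowing: vec![] },
            Enc { offsets: vec![(5, 2)], overflowing: vec![] },
        ];
        trim_all(&mut encs);
        assert_eq!(encs[0].offsets[0], (1, 1));
    }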
diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs
index bbfb45770..d93f03849 100644
--- a/tokenizers/src/processors/bert.rs
+++ b/tokenizers/src/processors/bert.rs
@@ -1,4 +1,4 @@
-use crate::tokenizer::{Encoding, PostProcessor, Result};
+use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -25,6 +25,12 @@ impl BertProcessing {
     }
 }
 
+#[derive(thiserror::Error, Debug)]
+pub enum BertProcessorError {
+    #[error("encodings vector length must be either 1 or 2")]
+    InvalidEncodingsVecLength,
+}
+
 impl PostProcessor for BertProcessing {
     fn added_tokens(&self, is_pair: bool) -> usize {
         if is_pair {
@@ -34,20 +40,34 @@ impl PostProcessor for BertProcessing {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if !add_special_tokens {
-            return <dyn PostProcessor>::default_process(
-                encoding,
-                pair_encoding,
-                add_special_tokens,
-            );
+            return Ok(encodings);
         }
 
+        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
         let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
         let tokens = [
@@ -166,7 +186,7 @@ impl PostProcessor for BertProcessing {
             new_encoding.merge_with(new_pair_encoding, false);
         }
 
-        Ok(new_encoding)
+        Ok(vec![new_encoding])
     }
 }
 
diff --git a/tokenizers/src/processors/mod.rs b/tokenizers/src/processors/mod.rs
index a74eec75d..ccacc7de5 100644
--- a/tokenizers/src/processors/mod.rs
+++ b/tokenizers/src/processors/mod.rs
@@ -33,19 +33,16 @@ impl PostProcessor for PostProcessorWrapper {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         match self {
-            Self::Bert(bert) => bert.process(encoding, pair_encoding, add_special_tokens),
-            Self::ByteLevel(bl) => bl.process(encoding, pair_encoding, add_special_tokens),
-            Self::Roberta(roberta) => roberta.process(encoding, pair_encoding, add_special_tokens),
-            Self::Template(template) => {
-                template.process(encoding, pair_encoding, add_special_tokens)
-            }
+            Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
+            Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
+            Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
+            Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
         }
     }
 }
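Aside: `BertProcessing` (and `RobertaProcessing` and `TemplateProcessing` below) recover the old `(encoding, pair)` view from the vector with the same match-and-pop idiom. A generic sketch of it, with illustrative names (`split_pair`, `SplitError`) that are not part of the patch:

    #[derive(Debug, PartialEq)]
    enum SplitError {
        InvalidLength(usize),
    }

    // Pop from the back so nothing is cloned; reject every other length.
    fn split_pair<T>(mut encodings: Vec<T>) -> Result<(T, Option<T>), SplitError> {
        match encodings.len() {
            1 => Ok((encodings.pop().unwrap(), None)),
            2 => {
                let pair = encodings.pop().unwrap();
                let first = encodings.pop().unwrap();
                Ok((first, Some(pair)))
            }
            n => Err(SplitError::InvalidLength(n)),
        }
    }

    fn main() {
        assert_eq!(split_pair(vec!["a", "b"]), Ok(("a", Some("b"))));
        assert!(split_pair(Vec::<&str>::new()).is_err());
    }

Because the length is matched first, the pops are infallible; the patch's `ok_or(ProcessorError::InvalidEncodingsVecLength)` exists only to stay panic-free. Note that `bert.rs` introduces a `BertProcessorError`, but the body reaches for the shared `ProcessorError` defined in `tokenizer/mod.rs` (last diff below).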
diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs
index a83843060..41be29b59 100644
--- a/tokenizers/src/processors/roberta.rs
+++ b/tokenizers/src/processors/roberta.rs
@@ -1,5 +1,5 @@
 use crate::processors::byte_level::process_offsets;
-use crate::tokenizer::{Encoding, PostProcessor, Result};
+use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -55,20 +55,13 @@ impl PostProcessor for RobertaProcessing {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        mut pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if self.trim_offsets {
-            process_offsets(&mut encoding, self.add_prefix_space);
-            encoding
-                .get_overflowing_mut()
-                .iter_mut()
-                .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
-
-            if let Some(encoding) = pair_encoding.as_mut() {
+            for encoding in encodings.iter_mut() {
                 process_offsets(encoding, self.add_prefix_space);
                 encoding
                     .get_overflowing_mut()
@@ -78,13 +71,28 @@
         }
 
         if !add_special_tokens {
-            return <dyn PostProcessor>::default_process(
-                encoding,
-                pair_encoding,
-                add_special_tokens,
-            );
+            return Ok(encodings);
         }
 
+        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
         let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
         let tokens = [
@@ -213,7 +221,7 @@ impl PostProcessor for RobertaProcessing {
             new_encoding.merge_with(new_pair_encoding, false);
         }
 
-        Ok(new_encoding)
+        Ok(vec![new_encoding])
     }
 }
 
diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs
index f2463a802..860461af9 100644
--- a/tokenizers/src/processors/template.rs
+++ b/tokenizers/src/processors/template.rs
@@ -55,7 +55,7 @@
 //!
 //! [`TemplateProcessing`]: struct.TemplateProcessing.html
 //!
-use crate::{Encoding, PostProcessor, Result};
+use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
@@ -630,13 +630,31 @@ impl PostProcessor for TemplateProcessing {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        self.apply_template(
+    ) -> Result<Vec<Encoding>> {
+        let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
+        let encoding = self.apply_template(
             if pair.is_some() {
                 &self.pair.0
             } else {
@@ -645,7 +663,8 @@ impl PostProcessor for TemplateProcessing {
             encoding,
             pair,
             add_special_tokens,
-        )
+        )?;
+        Ok(vec![encoding])
     }
 }
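Aside: every concrete processor still ends by wrapping its single merged result as `Ok(vec![new_encoding])`. That uniform `Vec<Encoding>` in, `Result<Vec<Encoding>>` out shape is what the subject line's chaining refers to: stages with identical signatures compose by plain function composition, which the old `(encoding, pair_encoding)` signature could not express. An illustrative sketch, not an API this patch adds:

    type Stage<T> = Box<dyn Fn(Vec<T>) -> Result<Vec<T>, String>>;

    // Feed the output vector of each stage into the next.
    fn chain<T>(stages: &[Stage<T>], input: Vec<T>) -> Result<Vec<T>, String> {
        let mut current = input;
        for stage in stages {
            current = stage(current)?;
        }
        Ok(current)
    }

    fn main() {
        let stages: Vec<Stage<i32>> = vec![
            Box::new(|v| Ok(v.into_iter().map(|x| x + 1).collect())), // a "trimming" stage
            Box::new(|v| Ok(vec![v.into_iter().sum()])),              // a "merging" stage
        ];
        assert_eq!(chain(&stages, vec![1, 2, 3]).unwrap(), vec![9]);
    }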
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 4a7061359..6ccec4d73 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -99,26 +99,50 @@ pub trait PostProcessor {
         encoding: Encoding,
         pair_encoding: Option<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding>;
+    ) -> Result<Encoding> {
+        let encodings = if let Some(pair_encoding) = pair_encoding {
+            vec![encoding, pair_encoding]
+        } else {
+            vec![encoding]
+        };
+
+        let encodings = self.process_encodings(encodings, add_special_tokens)?;
+
+        Ok(Encoding::merge(encodings, false))
+    }
+
+    /// Process any amount of encodings and returns a series of encoding (might merge them)
+    fn process_encodings(
+        &self,
+        encodings: Vec<Encoding>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<Encoding>>;
 }
 
 impl dyn PostProcessor {
     pub fn default_process(
-        mut encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         _add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        match pair_encoding {
-            None => Ok(encoding),
-            Some(mut pair) => {
-                encoding.set_sequence_id(0);
-                pair.set_sequence_id(1);
-                encoding.merge_with(pair, false);
-                Ok(encoding)
+    ) -> Result<Vec<Encoding>> {
+        match encodings.len() {
+            1 => Ok(encodings),
+            _ => {
+                let mut final_encoding = Encoding::default();
+                for (i, mut encoding) in encodings.into_iter().enumerate() {
+                    encoding.set_sequence_id(i);
+                    final_encoding.merge_with(encoding, false);
+                }
+                Ok(vec![final_encoding])
             }
         }
     }
 }
 
+#[derive(thiserror::Error, Debug)]
+pub enum ProcessorError {
+    #[error("encodings vector length must be either 1 or 2")]
+    InvalidEncodingsVecLength,
+}
+
 /// A `Decoder` changes the raw tokens into its more readable form.
 pub trait Decoder {
     fn decode(&self, tokens: Vec<String>) -> Result<String> {
@@ -895,7 +919,17 @@ where
         let final_encoding = if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding, add_special_tokens)?
         } else {
-            <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
+            let encodings = if let Some(pair_encoding) = pair_encoding {
+                vec![encoding, pair_encoding]
+            } else {
+                vec![encoding]
+            };
+            let mut encodings =
+                <dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
+            if encodings.len() != 1 {
+                panic!("We haven't reduced the encodings like we should have");
+            }
+            encodings.pop().unwrap()
         };
 
         // 3. Then we pad if needed
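Aside: the last diff holds the actual trait trick. `process` keeps its old two-slot signature but becomes a default method that packs its arguments into a `Vec`, defers to the one required method `process_encodings`, and merges whatever comes back, so every existing caller compiles unchanged. A self-contained sketch of the pattern; `Encoding`, `merge`, and `Identity` here are toy stand-ins, and `add_special_tokens` is dropped for brevity:

    #[derive(Debug, Default, Clone)]
    struct Encoding {
        ids: Vec<u32>,
    }

    impl Encoding {
        // Toy stand-in for the crate's `Encoding::merge(encodings, false)`.
        fn merge(encodings: Vec<Encoding>) -> Encoding {
            let mut out = Encoding::default();
            for e in encodings {
                out.ids.extend(e.ids);
            }
            out
        }
    }

    type Result<T> = std::result::Result<T, String>;

    trait PostProcessor {
        fn process_encodings(&self, encodings: Vec<Encoding>) -> Result<Vec<Encoding>>;

        // Old callers keep working; implementors only move logic under the new name.
        fn process(&self, encoding: Encoding, pair: Option<Encoding>) -> Result<Encoding> {
            let encodings = match pair {
                Some(p) => vec![encoding, p],
                None => vec![encoding],
            };
            let encodings = self.process_encodings(encodings)?;
            Ok(Encoding::merge(encodings))
        }
    }

    struct Identity;

    impl PostProcessor for Identity {
        fn process_encodings(&self, encodings: Vec<Encoding>) -> Result<Vec<Encoding>> {
            Ok(encodings)
        }
    }

    fn main() {
        let merged = Identity
            .process(Encoding { ids: vec![1] }, Some(Encoding { ids: vec![2] }))
            .unwrap();
        assert_eq!(merged.ids, vec![1, 2]);
    }

The `panic!` guard in the `encode` path is safe for now because `default_process` always reduces its input to exactly one encoding; a processor that legitimately returns several would have to lift that assumption.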