Modify Processor trait to support chaining. (#1047)
No behavioral modifications yet: every implementation simply consumes the vector of encodings.
Every test should stay green without any modifications.
Narsil authored Aug 24, 2022
1 parent 06c71d0 commit 1196e68
Showing 8 changed files with 149 additions and 81 deletions.
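In short, the commit turns the fixed-arity `process` into a provided method and makes the vector-based `process_encodings` the required one, which is what enables chaining. Side by side, as taken from the tokenizers/src/tokenizer/mod.rs hunk below:

    // Before: exactly one sequence plus an optional pair.
    fn process(
        &self,
        encoding: Encoding,
        pair_encoding: Option<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Encoding>;

    // After: any number of encodings in, any number out.
    fn process_encodings(
        &self,
        encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>>;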
9 changes: 4 additions & 5 deletions bindings/node/native/src/processors.rs
@@ -22,16 +22,15 @@ impl tk::PostProcessor for Processor {
             .added_tokens(is_pair)
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> tk::Result<Encoding> {
+    ) -> tk::Result<Vec<Encoding>> {
         self.processor
             .as_ref()
             .ok_or("Uninitialized PostProcessor")?
-            .process(encoding, pair_encoding, add_special_tokens)
+            .process_encodings(encodings, add_special_tokens)
     }
 }

9 changes: 4 additions & 5 deletions bindings/python/src/processors.rs
@@ -59,14 +59,13 @@ impl PostProcessor for PyPostProcessor {
         self.processor.added_tokens(is_pair)
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> tk::Result<Encoding> {
+    ) -> tk::Result<Vec<Encoding>> {
         self.processor
-            .process(encoding, pair_encoding, add_special_tokens)
+            .process_encodings(encodings, add_special_tokens)
     }
 }

18 changes: 5 additions & 13 deletions tokenizers/src/pre_tokenizers/byte_level.rs
@@ -174,29 +174,21 @@ impl PostProcessor for ByteLevel {
         0
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        mut pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if self.trim_offsets {
-            process_offsets(&mut encoding, self.add_prefix_space);
-            encoding
-                .get_overflowing_mut()
-                .iter_mut()
-                .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
-
-            if let Some(encoding) = pair_encoding.as_mut() {
+            for encoding in encodings.iter_mut() {
                 process_offsets(encoding, self.add_prefix_space);
                 encoding
                     .get_overflowing_mut()
                     .iter_mut()
                     .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
             }
         }
 
-        <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
+        <dyn PostProcessor>::default_process(encodings, add_special_tokens)
     }
 }
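The trim_offsets pass above now treats a single encoding and a pair uniformly. A minimal sketch of calling the new entry point (assuming the crate-root re-exports used elsewhere in this diff; the empty Encoding values are placeholders):

    use tokenizers::pre_tokenizers::byte_level::ByteLevel;
    use tokenizers::{Encoding, PostProcessor};

    fn main() -> tokenizers::Result<()> {
        let byte_level = ByteLevel::default();
        // Two placeholder encodings in; the default_process fallback
        // merges them into a single encoding.
        let processed = byte_level
            .process_encodings(vec![Encoding::default(), Encoding::default()], false)?;
        assert_eq!(processed.len(), 1);
        Ok(())
    }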

42 changes: 31 additions & 11 deletions tokenizers/src/processors/bert.rs
@@ -1,4 +1,4 @@
-use crate::tokenizer::{Encoding, PostProcessor, Result};
+use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -25,6 +25,12 @@ impl BertProcessing {
     }
 }
 
+#[derive(thiserror::Error, Debug)]
+pub enum BertProcessorError {
+    #[error("encodings vector length must be either 1 or 2")]
+    InvalidEncodingsVecLength,
+}
+
 impl PostProcessor for BertProcessing {
     fn added_tokens(&self, is_pair: bool) -> usize {
         if is_pair {
@@ -34,20 +40,34 @@ impl PostProcessor for BertProcessing {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if !add_special_tokens {
-            return <dyn PostProcessor>::default_process(
-                encoding,
-                pair_encoding,
-                add_special_tokens,
-            );
+            return Ok(encodings);
         }
 
+        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
         let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
         let tokens = [
@@ -166,7 +186,7 @@ impl PostProcessor for BertProcessing {
             new_encoding.merge_with(new_pair_encoding, false);
         }
 
-        Ok(new_encoding)
+        Ok(vec![new_encoding])
     }
 }
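RobertaProcessing below applies the identical 1-or-2 destructuring, so one usage sketch covers both. A hedged example of the new call shape (ids 101/102 are illustrative BERT defaults, not read from any vocab):

    use tokenizers::processors::bert::BertProcessing;
    use tokenizers::{Encoding, PostProcessor};

    fn main() -> tokenizers::Result<()> {
        let processor = BertProcessing::new(("[SEP]".into(), 102), ("[CLS]".into(), 101));
        // One encoding is wrapped as [CLS] ... [SEP]; two are also accepted,
        // and any other length is ProcessorError::InvalidEncodingsVecLength.
        let wrapped = processor.process_encodings(vec![Encoding::default()], true)?;
        assert_eq!(wrapped.len(), 1);
        Ok(())
    }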

17 changes: 7 additions & 10 deletions tokenizers/src/processors/mod.rs
@@ -33,19 +33,16 @@ impl PostProcessor for PostProcessorWrapper {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         match self {
-            Self::Bert(bert) => bert.process(encoding, pair_encoding, add_special_tokens),
-            Self::ByteLevel(bl) => bl.process(encoding, pair_encoding, add_special_tokens),
-            Self::Roberta(roberta) => roberta.process(encoding, pair_encoding, add_special_tokens),
-            Self::Template(template) => {
-                template.process(encoding, pair_encoding, add_special_tokens)
-            }
+            Self::Bert(bert) => bert.process_encodings(encodings, add_special_tokens),
+            Self::ByteLevel(bl) => bl.process_encodings(encodings, add_special_tokens),
+            Self::Roberta(roberta) => roberta.process_encodings(encodings, add_special_tokens),
+            Self::Template(template) => template.process_encodings(encodings, add_special_tokens),
         }
     }
 }
44 changes: 26 additions & 18 deletions tokenizers/src/processors/roberta.rs
@@ -1,5 +1,5 @@
 use crate::processors::byte_level::process_offsets;
-use crate::tokenizer::{Encoding, PostProcessor, Result};
+use crate::tokenizer::{Encoding, PostProcessor, ProcessorError, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::iter::FromIterator;
@@ -55,20 +55,13 @@ impl PostProcessor for RobertaProcessing {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        mut encoding: Encoding,
-        mut pair_encoding: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
+    ) -> Result<Vec<Encoding>> {
         if self.trim_offsets {
-            process_offsets(&mut encoding, self.add_prefix_space);
-            encoding
-                .get_overflowing_mut()
-                .iter_mut()
-                .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
-
-            if let Some(encoding) = pair_encoding.as_mut() {
+            for encoding in encodings.iter_mut() {
                 process_offsets(encoding, self.add_prefix_space);
                 encoding
                     .get_overflowing_mut()
@@ -78,13 +71,28 @@ impl PostProcessor for RobertaProcessing {
         }
 
         if !add_special_tokens {
-            return <dyn PostProcessor>::default_process(
-                encoding,
-                pair_encoding,
-                add_special_tokens,
-            );
+            return Ok(encodings);
         }
 
+        let (mut encoding, pair_encoding): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
         let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
         let tokens = [
@@ -213,7 +221,7 @@ impl PostProcessor for RobertaProcessing {
             new_encoding.merge_with(new_pair_encoding, false);
         }
 
-        Ok(new_encoding)
+        Ok(vec![new_encoding])
     }
 }

33 changes: 26 additions & 7 deletions tokenizers/src/processors/template.rs
@@ -55,7 +55,7 @@
 //!
 //! [`TemplateProcessing`]: struct.TemplateProcessing.html
 //!
-use crate::{Encoding, PostProcessor, Result};
+use crate::{tokenizer::ProcessorError, Encoding, PostProcessor, Result};
 use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 use std::collections::{HashMap, HashSet};
@@ -630,13 +630,31 @@ impl PostProcessor for TemplateProcessing {
         }
     }
 
-    fn process(
+    fn process_encodings(
         &self,
-        encoding: Encoding,
-        pair: Option<Encoding>,
+        mut encodings: Vec<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        self.apply_template(
+    ) -> Result<Vec<Encoding>> {
+        let (encoding, pair): (Encoding, Option<Encoding>) = match encodings.len() {
+            1 => (
+                encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?,
+                None,
+            ),
+            2 => {
+                let pair = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                let encoding = encodings
+                    .pop()
+                    .ok_or(ProcessorError::InvalidEncodingsVecLength)?;
+                (encoding, Some(pair))
+            }
+            _ => return Err(Box::new(ProcessorError::InvalidEncodingsVecLength)),
+        };
+
+        let encoding = self.apply_template(
             if pair.is_some() {
                 &self.pair.0
             } else {
@@ -645,7 +663,8 @@ impl PostProcessor for TemplateProcessing {
             encoding,
             pair,
             add_special_tokens,
-        )
+        )?;
+        Ok(vec![encoding])
     }
 }
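A sketch of driving TemplateProcessing through the new entry point, reusing the builder example from this module's documentation (special-token ids 1 and 2 are arbitrary here; the empty encodings are placeholders):

    use tokenizers::processors::template::TemplateProcessing;
    use tokenizers::{Encoding, PostProcessor};

    fn main() -> tokenizers::Result<()> {
        let template = TemplateProcessing::builder()
            .try_single("[CLS] $A [SEP]").unwrap()
            .try_pair("[CLS] $A [SEP] $B:1 [SEP]:1").unwrap()
            .special_tokens(vec![("[CLS]", 1), ("[SEP]", 2)])
            .build()
            .unwrap();
        // A two-element vector is destructured into (encoding, Some(pair)).
        let out = template
            .process_encodings(vec![Encoding::default(), Encoding::default()], true)?;
        assert_eq!(out.len(), 1);
        Ok(())
    }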

58 changes: 46 additions & 12 deletions tokenizers/src/tokenizer/mod.rs
@@ -99,26 +99,50 @@ pub trait PostProcessor {
         encoding: Encoding,
         pair_encoding: Option<Encoding>,
         add_special_tokens: bool,
-    ) -> Result<Encoding>;
+    ) -> Result<Encoding> {
+        let encodings = if let Some(pair_encoding) = pair_encoding {
+            vec![encoding, pair_encoding]
+        } else {
+            vec![encoding]
+        };
+
+        let encodings = self.process_encodings(encodings, add_special_tokens)?;
+
+        Ok(Encoding::merge(encodings, false))
+    }
+
+    /// Process any amount of encodings and returns a series of encoding (might merge them)
+    fn process_encodings(
+        &self,
+        encodings: Vec<Encoding>,
+        add_special_tokens: bool,
+    ) -> Result<Vec<Encoding>>;
 }
 impl dyn PostProcessor {
     pub fn default_process(
-        mut encoding: Encoding,
-        pair_encoding: Option<Encoding>,
+        encodings: Vec<Encoding>,
         _add_special_tokens: bool,
-    ) -> Result<Encoding> {
-        match pair_encoding {
-            None => Ok(encoding),
-            Some(mut pair) => {
-                encoding.set_sequence_id(0);
-                pair.set_sequence_id(1);
-                encoding.merge_with(pair, false);
-                Ok(encoding)
+    ) -> Result<Vec<Encoding>> {
+        match encodings.len() {
+            1 => Ok(encodings),
+            _ => {
+                let mut final_encoding = Encoding::default();
+                for (i, mut encoding) in encodings.into_iter().enumerate() {
+                    encoding.set_sequence_id(i);
+                    final_encoding.merge_with(encoding, false);
+                }
+                Ok(vec![final_encoding])
             }
         }
     }
 }
 
+#[derive(thiserror::Error, Debug)]
+pub enum ProcessorError {
+    #[error("encodings vector length must be either 1 or 2")]
+    InvalidEncodingsVecLength,
+}
+
 /// A `Decoder` changes the raw tokens into its more readable form.
 pub trait Decoder {
     fn decode(&self, tokens: Vec<String>) -> Result<String> {
@@ -895,7 +919,17 @@ where
         let final_encoding = if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding, add_special_tokens)?
         } else {
-            <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
+            let encodings = if let Some(pair_encoding) = pair_encoding {
+                vec![encoding, pair_encoding]
+            } else {
+                vec![encoding]
+            };
+            let mut encodings =
+                <dyn PostProcessor>::default_process(encodings, add_special_tokens)?;
+            if encodings.len() != 1 {
+                panic!("We haven't reduced the encodings like we should have");
+            }
+            encodings.pop().unwrap()
         };
 
         // 3. Then we pad if needed
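Because `process` now has a provided body that funnels into `process_encodings` and merges the result, a downstream implementor only needs the vector-based method, and the output vector of one processor can be fed straight into the next. A minimal sketch under that assumption (NoopProcessor is a hypothetical name, not part of the crate):

    use tokenizers::{Encoding, PostProcessor, Result};

    // Hypothetical pass-through post-processor: only the new required
    // method is implemented; the legacy `process` comes from the
    // provided default.
    struct NoopProcessor;

    impl PostProcessor for NoopProcessor {
        fn added_tokens(&self, _is_pair: bool) -> usize {
            0
        }

        fn process_encodings(
            &self,
            encodings: Vec<Encoding>,
            _add_special_tokens: bool,
        ) -> Result<Vec<Encoding>> {
            // Consume the vector and hand it back unchanged so another
            // processor could pick it up downstream.
            Ok(encodings)
        }
    }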
