Skip to content

Commit

Permalink
Add SplitDelimiterBehavior to Punctuation constructor
Browse files Browse the repository at this point in the history
Resolves: #642
  • Loading branch information
vladdy committed Mar 17, 2021
1 parent f5e9bb8 commit 6162e2c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
26 changes: 21 additions & 5 deletions tokenizers/src/pre_tokenizers/punctuation.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
use serde::{Deserialize, Serialize};

use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use unicode_categories::UnicodeCategories;

fn is_punc(x: char) -> bool {
char::is_ascii_punctuation(&x) || x.is_punctuation()
}

#[derive(Copy, Clone, Debug)]
pub struct Punctuation;
impl_serde_unit_struct!(PunctuationVisitor, Punctuation);
#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
#[serde(tag = "type")]
pub struct Punctuation {
behavior: SplitDelimiterBehavior,
}

impl Punctuation {
pub fn new(behavior: SplitDelimiterBehavior) -> Self {
Self { behavior }
}
}

impl Default for Punctuation {
fn default() -> Self {
Self::new(SplitDelimiterBehavior::Isolated)
}
}

impl PreTokenizer for Punctuation {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
pretokenized.split(|_, s| s.split(is_punc, SplitDelimiterBehavior::Isolated))
pretokenized.split(|_, s| s.split(is_punc, self.behavior))
}
}

Expand All @@ -22,7 +38,7 @@ mod tests {

#[test]
fn punctuation_basic() {
let pretok = Punctuation;
let pretok = Punctuation::default();
let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into();
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion tokenizers/src/pre_tokenizers/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mod tests {
fn sequence_basic() {
let pretokenizers = vec![
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit),
PreTokenizerWrapper::Punctuation(Punctuation),
PreTokenizerWrapper::Punctuation(Punctuation::default()),
];
let pretok = Sequence::new(pretokenizers);
let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into();
Expand Down

0 comments on commit 6162e2c

Please sign in to comment.