From 7c146d9ce5a428600dc63e6dc2db4972ec430191 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 16 Sep 2022 11:20:59 +0200 Subject: [PATCH] Turns out we introduced a regression because bad code. (#1060) --- tokenizers/src/decoders/wordpiece.rs | 34 ++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs index 952108d65..8ecd3987c 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -49,10 +49,12 @@ impl Decoder for WordPiece { .iter_mut() .enumerate() .map(|(i, token)| { - if token.starts_with(&self.prefix) { - *token = token.replacen(&self.prefix, "", 1); - } else if i != 0 { - *token = format!(" {}", token); + if i != 0 { + if token.starts_with(&self.prefix) { + *token = token.replacen(&self.prefix, "", 1); + } else { + *token = format!(" {}", token); + } } if self.cleanup { *token = cleanup(token); @@ -62,3 +64,27 @@ impl Decoder for WordPiece { .collect::>() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn wordpiece_decoder() { + let decoder = WordPiece::new("##".to_string(), false); + + assert_eq!( + decoder + .decode(vec![ + "##uelo".to_string(), + "Ara".to_string(), + "##új".to_string(), + "##o".to_string(), + "No".to_string(), + "##guera".to_string() + ]) + .unwrap(), + "##uelo Araújo Noguera" + ); + } +}