diff --git a/bindings/node/native/src/tasks/tokenizer.rs b/bindings/node/native/src/tasks/tokenizer.rs
index 2a7e3e0ad..495ae53a7 100644
--- a/bindings/node/native/src/tasks/tokenizer.rs
+++ b/bindings/node/native/src/tasks/tokenizer.rs
@@ -106,14 +106,17 @@ impl Task for DecodeTask {
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode(ids.to_vec(), *skip_special_tokens)
+                .decode(ids.as_slice(), *skip_special_tokens)
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Single),
             DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode_batch(ids.to_vec(), *skip_special_tokens)
+                .decode_batch(
+                    &ids.iter().map(|v| v.as_slice()).collect::<Vec<&[u32]>>(),
+                    *skip_special_tokens,
+                )
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Batch),
         }
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 95a954a27..1fe296ed0 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1009,7 +1009,7 @@ impl PyTokenizer {
     #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
     }
 
     /// Decode a batch of ids back to their corresponding string
@@ -1032,7 +1032,8 @@ impl PyTokenizer {
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
         py.allow_threads(|| {
-            ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
     }
diff --git a/tokenizers/src/cli.rs b/tokenizers/src/cli.rs
index 6bf523ef8..54b82357f 100644
--- a/tokenizers/src/cli.rs
+++ b/tokenizers/src/cli.rs
@@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
+            tokenizer.decode(encoded.get_ids(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index a88306f3a..01ec187cd 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -795,12 +795,12 @@ where
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
+    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
-            .into_iter()
+            .iter()
             .filter_map(|id| {
                 self.added_vocabulary
-                    .id_to_token(id, &self.model)
+                    .id_to_token(*id, &self.model)
                     .filter(|token| {
                         !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
@@ -1008,7 +1008,7 @@ where
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
-        sentences: Vec<Vec<u32>>,
+        sentences: &[&[u32]],
         skip_special_tokens: bool,
     ) -> Result<Vec<String>>
     where
diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs
index 605f8a4bd..7cf04debe 100644
--- a/tokenizers/tests/documentation.rs
+++ b/tokenizers/tests/documentation.rs
@@ -54,7 +54,7 @@ fn load_tokenizer() {
     assert_eq!(encodings.get_ids(), ids);
     assert_eq!(encodings.get_tokens(), tokens);
 
-    let decoded = tokenizer.decode(ids, false).unwrap();
+    let decoded = tokenizer.decode(&ids, false).unwrap();
     assert_eq!(decoded, example);
 }
 
@@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]
 
     let decoded = tokenizer.decode(
-        vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
+        &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
         true,
     )?;
     println!("{}", decoded);
@@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     println!("{:?}", output.get_tokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]
 
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     println!("{}", decoded);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
@@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
     bert_tokenizer.with_decoder(WordPieceDecoder::default());
 
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
     assert_eq!(decoded, "welcome to the tokenizers library.");
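
For downstream Rust callers, here is a minimal migration sketch for the borrowed-slice signatures introduced above (`decode(&[u32], ...)` and `decode_batch(&[&[u32]], ...)`). The `Tokenizer::from_file` call, the `tokenizer.json` path, and the id values are illustrative stand-ins, not part of this diff:

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Illustrative path; any serialized tokenizer works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    let ids: Vec<u32> = vec![1, 27253, 16, 2];

    // Before this change, `decode` took `Vec<u32>` by value, so reusing
    // `ids` afterwards required a clone. It now borrows, and `&Vec<u32>`
    // coerces to `&[u32]`:
    let decoded = tokenizer.decode(&ids, true)?;
    println!("{}", decoded);

    // `decode_batch` now takes `&[&[u32]]`, so owned `Vec<Vec<u32>>` data
    // needs a slice view first (the same pattern the Python and Node
    // bindings use in the hunks above):
    let batch: Vec<Vec<u32>> = vec![ids.clone(), ids];
    let views: Vec<&[u32]> = batch.iter().map(|v| v.as_slice()).collect();
    let decoded_batch = tokenizer.decode_batch(&views, true)?;
    println!("{:?}", decoded_batch);

    Ok(())
}
```

The trade-off is visible in the bindings: callers that already own `Vec`s pay one extra `collect` to build the slice view, while callers holding borrowed ids (like `encoded.get_ids()` in cli.rs) no longer pay a `to_vec()` copy per decode.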