Makes decode and decode_batch work on borrowed content. #1251

Merged: 11 commits, May 17, 2023
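In short, `decode` now takes `&[u32]` instead of `Vec<u32>`, and `decode_batch` takes `&[&[u32]]` instead of `Vec<Vec<u32>>`, so callers lend the ids rather than hand over owned vectors; since `Encoding::get_ids()` already returns `&[u32]`, the `.to_vec()` copies scattered through the callers disappear. A minimal calling sketch against the new signatures (not part of the diff; the `tokenizer.json` path and the input text are placeholders):

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Placeholder path; any serialized tokenizer file works.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    let encoding = tokenizer.encode("Hello, world!", false)?;

    // `decode` now borrows the ids; no `.to_vec()` copy is needed.
    let decoded = tokenizer.decode(encoding.get_ids(), true)?;
    println!("{}", decoded);

    // `decode_batch` takes a slice of borrowed id slices.
    let batch: Vec<&[u32]> = vec![encoding.get_ids(), encoding.get_ids()];
    let decoded_batch = tokenizer.decode_batch(&batch, true)?;
    println!("{:?}", decoded_batch);

    Ok(())
}
```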
bindings/node/native/src/tasks/tokenizer.rs (7 changes: 5 additions & 2 deletions)
@@ -106,14 +106,17 @@ impl Task for DecodeTask {
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode(ids.to_vec(), *skip_special_tokens)
+                .decode(ids.as_slice(), *skip_special_tokens)
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Single),
             DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode_batch(ids.to_vec(), *skip_special_tokens)
+                .decode_batch(
+                    &ids.iter().map(|v| v.as_slice()).collect::<Vec<&[u32]>>(),
+                    *skip_special_tokens,
+                )
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Batch),
         }
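The Node task keeps its batch ids as `Vec<Vec<u32>>`, so the change above builds a temporary `Vec<&[u32]>` view with `as_slice()` before calling `decode_batch`. A standalone sketch of that borrowing pattern (the function and variable names here are illustrative, not from the bindings):

```rust
// Illustrative: the same Vec<Vec<u32>> -> Vec<&[u32]> borrowing pattern,
// outside of the Neon task type.
fn as_batch_view(ids: &[Vec<u32>]) -> Vec<&[u32]> {
    ids.iter().map(|v| v.as_slice()).collect()
}

fn main() {
    let owned: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
    let view: Vec<&[u32]> = as_batch_view(&owned);
    // `view` only borrows from `owned`: nothing is cloned and `owned` stays usable.
    assert_eq!(view[0], &[1, 2, 3][..]);
    assert_eq!(owned[1].len(), 2);
}
```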
bindings/python/src/tokenizer.rs (5 changes: 3 additions & 2 deletions)
@@ -1009,7 +1009,7 @@ impl PyTokenizer {
     #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
     }

     /// Decode a batch of ids back to their corresponding string
@@ -1032,7 +1032,8 @@ impl PyTokenizer {
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
         py.allow_threads(|| {
-            ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
     }

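The Python binding does the same conversion with `&v[..]`, which is just another spelling of `v.as_slice()`, and the work still happens inside `py.allow_threads`, so the GIL stays released while decoding. A tiny sketch of the equivalent spellings (illustrative only, not from the bindings):

```rust
// Illustrative: three equivalent ways to borrow a Vec<u32> as &[u32].
fn takes_slice(ids: &[u32]) -> usize {
    ids.len()
}

fn main() {
    let v: Vec<u32> = vec![10, 20, 30];
    assert_eq!(takes_slice(&v[..]), 3);       // full-range index, as in the Python binding
    assert_eq!(takes_slice(v.as_slice()), 3); // explicit as_slice, as in the Node binding
    assert_eq!(takes_slice(&v), 3);           // deref coercion from &Vec<u32>
}
```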
tokenizers/src/cli.rs (2 changes: 1 addition & 1 deletion)
@@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
+            tokenizer.decode(encoded.get_ids(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
tokenizers/src/tokenizer/mod.rs (8 changes: 4 additions & 4 deletions)
@@ -795,12 +795,12 @@ where
     }

     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
+    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
-            .into_iter()
+            .iter()
             .filter_map(|id| {
                 self.added_vocabulary
-                    .id_to_token(id, &self.model)
+                    .id_to_token(*id, &self.model)
                     .filter(|token| {
                         !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
@@ -1008,7 +1008,7 @@ where
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
-        sentences: Vec<Vec<u32>>,
+        sentences: &[&[u32]],
         skip_special_tokens: bool,
     ) -> Result<Vec<String>>
     where
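These are the core signature changes: `decode` borrows the ids and iterates with `.iter()`, dereferencing each `&u32` before the vocabulary lookup, and `decode_batch` accepts a slice of borrowed slices. A self-contained sketch of the same iterate-and-dereference pattern, with a hypothetical `lookup` closure standing in for `added_vocabulary.id_to_token` (not the library's actual code):

```rust
// Illustrative sketch: with a borrowed `&[u32]`, `.iter()` yields `&u32`,
// so each id is copied out with `*id` before being looked up.
fn decode_like(ids: &[u32], lookup: impl Fn(u32) -> Option<String>) -> String {
    let tokens: Vec<String> = ids
        .iter()                       // yields &u32; the borrowed slice is not consumed
        .filter_map(|id| lookup(*id)) // *id copies the u32 out of the reference
        .collect();
    tokens.join(" ")
}

fn main() {
    let vocab = ["hello", "world"];
    let text = decode_like(&[0, 1], |id| vocab.get(id as usize).map(|s| s.to_string()));
    println!("{}", text); // "hello world"
}
```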
tokenizers/tests/documentation.rs (8 changes: 4 additions & 4 deletions)
@@ -54,7 +54,7 @@ fn load_tokenizer() {
     assert_eq!(encodings.get_ids(), ids);
     assert_eq!(encodings.get_tokens(), tokens);

-    let decoded = tokenizer.decode(ids, false).unwrap();
+    let decoded = tokenizer.decode(&ids, false).unwrap();
     assert_eq!(decoded, example);
 }

@@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]

     let decoded = tokenizer.decode(
-        vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
+        &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
         true,
     )?;
     println!("{}", decoded);
@@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     println!("{:?}", output.get_tokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]

-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     println!("{}", decoded);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
@@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;

     bert_tokenizer.with_decoder(WordPieceDecoder::default());
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
     assert_eq!(decoded, "welcome to the tokenizers library.");