From 908e96f9da7300c7cafc90b86f984a5451afd96f Mon Sep 17 00:00:00 2001
From: Jonathan Soo
Date: Thu, 4 May 2023 19:17:29 -0400
Subject: [PATCH 1/2] Convert text to CString

---
 src/whisper_ctx.rs | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs
index 98afa9940a6..10ce2dd33bb 100644
--- a/src/whisper_ctx.rs
+++ b/src/whisper_ctx.rs
@@ -88,12 +88,15 @@ impl WhisperContext {
         text: &str,
         max_tokens: usize,
     ) -> Result<Vec<WhisperToken>, WhisperError> {
+        // convert the text to a nul-terminated C string. Will raise an error if the text contains
+        // any nul bytes.
+        let text = CString::new(text)?;
         // allocate at least max_tokens to ensure the memory is valid
         let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens);
         let ret = unsafe {
             whisper_rs_sys::whisper_tokenize(
                 self.ctx,
-                text.as_ptr() as *const _,
+                text.as_ptr(),
                 tokens.as_mut_ptr(),
                 max_tokens as c_int,
             )

From 44e34ba3012e876030003bb860c3c4d6c327d8ba Mon Sep 17 00:00:00 2001
From: Jonathan Soo
Date: Thu, 4 May 2023 19:18:51 -0400
Subject: [PATCH 2/2] Add test using tiny.en model behind a feature flag

---
 Cargo.toml         |  1 +
 src/whisper_ctx.rs | 23 +++++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/Cargo.toml b/Cargo.toml
index 35a1bcb380f..5cac3fb6c1f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,6 +23,7 @@ hound = "3.5.0"
 [features]
 simd = []
 coreml = ["whisper-rs-sys/coreml"]
+test-with-tiny-model = []

 [package.metadata.docs.rs]
 features = ["simd"]
diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs
index 10ce2dd33bb..8783737f706 100644
--- a/src/whisper_ctx.rs
+++ b/src/whisper_ctx.rs
@@ -431,3 +431,26 @@ impl Drop for WhisperContext {
 // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
 unsafe impl Send for WhisperContext {}
 unsafe impl Sync for WhisperContext {}
+
+#[cfg(test)]
+#[cfg(feature = "test-with-tiny-model")]
+mod test_with_tiny_model {
+    use super::*;
+    const MODEL_PATH: &str = "./sys/whisper.cpp/models/ggml-tiny.en.bin";
+
+    // These tests expect that the tiny.en model has been downloaded
+    // using the script `sys/whisper.cpp/models/download-ggml-model.sh tiny.en`
+
+    #[test]
+    fn test_tokenize_round_trip() {
+        let ctx = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
+        let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.";
+        let tokens = ctx.tokenize(text_in, 1024).unwrap();
+        let text_out = tokens
+            .into_iter()
+            .map(|t| ctx.token_to_str(t).unwrap())
+            .collect::<Vec<String>>()
+            .join("");
+        assert_eq!(text_in, text_out);
+    }
+}