diff --git a/Cargo.lock b/Cargo.lock index 7e8ac5c..b7e70c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,6 +66,7 @@ dependencies = [ "arraystring", "bitflags", "regex", + "widestring", ] [[package]] @@ -75,6 +76,7 @@ dependencies = [ "diplomat", "diplomat-runtime", "ib-pinyin", + "widestring", ] [[package]] @@ -200,3 +202,9 @@ name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "widestring" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" diff --git a/Cargo.toml b/Cargo.toml index be11c31..344d97d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,9 @@ keywords.workspace = true arraystring = "0.3.0" bitflags = "2.4.1" regex = "1.10.2" +widestring = { version = "1.0.2", optional = true } [features] inmut-data = [] minimal = ["inmut-data"] +encoding = ["dep:widestring"] diff --git a/README.md b/README.md index b7321b4..1b61dd4 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,14 @@ assert!(matcher.is_match("拼音搜索Everything")); #include #include +// UTF-8 bool is_match = ib_pinyin_is_match_u8c(u8"pysousuoeve", u8"拼音搜索Everything", PINYIN_NOTATION_ASCII_FIRST_LETTER | PINYIN_NOTATION_ASCII); + +// UTF-16 +bool is_match = ib_pinyin_is_match_u16c(u"pysousuoeve", u"拼音搜索Everything", PINYIN_NOTATION_ASCII_FIRST_LETTER | PINYIN_NOTATION_ASCII); + +// UTF-32 +bool is_match = ib_pinyin_is_match_u32c(U"pysousuoeve", U"拼音搜索Everything", PINYIN_NOTATION_ASCII_FIRST_LETTER | PINYIN_NOTATION_ASCII); ``` ### C++ diff --git a/bindings/c/Cargo.toml b/bindings/c/Cargo.toml index 693ad42..a336a10 100644 --- a/bindings/c/Cargo.toml +++ b/bindings/c/Cargo.toml @@ -16,4 +16,5 @@ crate-type = ["staticlib", "cdylib"] [dependencies] diplomat = "0.7.0" diplomat-runtime = "0.7.0" -ib-pinyin = { path = "../..", features = 
["minimal"] } +ib-pinyin = { path = "../..", features = ["minimal", "encoding"] } +widestring = "1.0.2" diff --git a/bindings/c/examples/cmake/main.c b/bindings/c/examples/cmake/main.c index d6ae40d..ab6b666 100644 --- a/bindings/c/examples/cmake/main.c +++ b/bindings/c/examples/cmake/main.c @@ -3,9 +3,8 @@ #include #include -int main() -{ - const char *pattern = u8"pysousuoeve"; +void test_u8() { + const char *pattern = u8"pysousuoeve"; const char *haystack = u8"拼音搜索Everything"; // 0x3 const PinyinNotation notations = PINYIN_NOTATION_ASCII_FIRST_LETTER | PINYIN_NOTATION_ASCII; @@ -13,6 +12,47 @@ int main() printf("%d\n", ib_pinyin_is_match_u8(pattern, strlen(pattern), haystack, strlen(haystack), notations)); printf("%d\n", ib_pinyin_is_match_u8c(pattern, haystack, notations)); +} + +void test_u16() { + const char16_t *pattern = u"pysousuoeve"; + const char16_t *haystack = u"拼音搜索Everything"; + // 0x3 + const PinyinNotation notations = PINYIN_NOTATION_ASCII_FIRST_LETTER | PINYIN_NOTATION_ASCII; + + printf("%d\n", ib_pinyin_is_match_u16( + pattern, + sizeof(u"pysousuoeve") / sizeof(char16_t) - 1, + haystack, + sizeof(u"拼音搜索Everything") / sizeof(char16_t) - 1, + notations + )); + + printf("%d\n", ib_pinyin_is_match_u16c(pattern, haystack, notations)); +} + +void test_u32() { + const char32_t *pattern = U"pysousuoeve"; + const char32_t *haystack = U"拼音搜索Everything"; + // 0x3 + const PinyinNotation notations = PINYIN_NOTATION_ASCII_FIRST_LETTER | PINYIN_NOTATION_ASCII; + + printf("%d\n", ib_pinyin_is_match_u32( + pattern, + sizeof(U"pysousuoeve") / sizeof(char32_t) - 1, + haystack, + sizeof(U"拼音搜索Everything") / sizeof(char32_t) - 1, + notations + )); + + printf("%d\n", ib_pinyin_is_match_u32c(pattern, haystack, notations)); +} + +int main() +{ + test_u8(); + test_u16(); + test_u32(); return 0; } diff --git a/bindings/c/include/ib_pinyin/ib_pinyin.h b/bindings/c/include/ib_pinyin/ib_pinyin.h index 2bb294c..9f19e60 100644 --- a/bindings/c/include/ib_pinyin/ib_pinyin.h +++ 
b/bindings/c/include/ib_pinyin/ib_pinyin.h @@ -22,6 +22,14 @@ extern "C" { bool ib_pinyin_is_match_u8(const char* pattern_data, size_t pattern_len, const char* haystack_data, size_t haystack_len, uint32_t pinyin_notations); bool ib_pinyin_is_match_u8c(const uint8_t* pattern, const uint8_t* haystack, uint32_t pinyin_notations); + +bool ib_pinyin_is_match_u16(const uint16_t* pattern, size_t pattern_len, const uint16_t* haystack, size_t haystack_len, uint32_t pinyin_notations); + +bool ib_pinyin_is_match_u16c(const uint16_t* pattern, const uint16_t* haystack, uint32_t pinyin_notations); + +bool ib_pinyin_is_match_u32(const uint32_t* pattern, size_t pattern_len, const uint32_t* haystack, size_t haystack_len, uint32_t pinyin_notations); + +bool ib_pinyin_is_match_u32c(const uint32_t* pattern, const uint32_t* haystack, uint32_t pinyin_notations); void ib_pinyin_destroy(ib_pinyin* self); #ifdef __cplusplus diff --git a/bindings/c/src/lib.rs b/bindings/c/src/lib.rs index 35c21ce..be43729 100644 --- a/bindings/c/src/lib.rs +++ b/bindings/c/src/lib.rs @@ -5,6 +5,7 @@ mod ffi { use std::ffi::CStr; use ::ib_pinyin::{minimal, pinyin::PinyinNotation}; + use widestring::{U16CStr, U16Str, U32CStr, U32Str}; /// https://github.com/rust-diplomat/diplomat/issues/392 #[allow(non_camel_case_types)] @@ -20,6 +21,7 @@ mod ffi { ) } + /// TODO: Lossy decoding? 
pub fn is_match_u8c(pattern: &u8, haystack: &u8, pinyin_notations: u32) -> bool { (|| -> Result { Ok(Self::is_match_u8( @@ -30,5 +32,49 @@ mod ffi { })() .unwrap_or(false) } + + pub fn is_match_u16( + pattern: &u16, + pattern_len: usize, + haystack: &u16, + haystack_len: usize, + pinyin_notations: u32, + ) -> bool { + minimal::is_pinyin_match_u16( + unsafe { U16Str::from_ptr(pattern as *const u16, pattern_len) }, + unsafe { U16Str::from_ptr(haystack as *const u16, haystack_len) }, + PinyinNotation::from_bits_truncate(pinyin_notations), + ) + } + + pub fn is_match_u16c(pattern: &u16, haystack: &u16, pinyin_notations: u32) -> bool { + minimal::is_pinyin_match_u16( + unsafe { U16CStr::from_ptr_str(pattern as *const u16) }.as_ustr(), + unsafe { U16CStr::from_ptr_str(haystack as *const u16) }.as_ustr(), + PinyinNotation::from_bits_truncate(pinyin_notations), + ) + } + + pub fn is_match_u32( + pattern: &u32, + pattern_len: usize, + haystack: &u32, + haystack_len: usize, + pinyin_notations: u32, + ) -> bool { + minimal::is_pinyin_match_u32( + unsafe { U32Str::from_ptr(pattern as *const u32, pattern_len) }, + unsafe { U32Str::from_ptr(haystack as *const u32, haystack_len) }, + PinyinNotation::from_bits_truncate(pinyin_notations), + ) + } + + pub fn is_match_u32c(pattern: &u32, haystack: &u32, pinyin_notations: u32) -> bool { + minimal::is_pinyin_match_u32( + unsafe { U32CStr::from_ptr_str(pattern as *const u32) }.as_ustr(), + unsafe { U32CStr::from_ptr_str(haystack as *const u32) }.as_ustr(), + PinyinNotation::from_bits_truncate(pinyin_notations), + ) + } } } diff --git a/src/lib.rs b/src/lib.rs index 3946562..dca3438 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(return_position_impl_trait_in_trait)] + pub mod matcher; #[cfg(feature = "minimal")] pub mod minimal; diff --git a/src/matcher/encoding.rs b/src/matcher/encoding.rs new file mode 100644 index 0000000..92b2662 --- /dev/null +++ b/src/matcher/encoding.rs @@ -0,0 +1,90 @@ +/// TODO: Extended 
ASCII code pages +/// TODO: Index/SliceIndex +pub trait EncodedStr { + const ELEMENT_LEN_BYTE: usize; + + fn is_ascii(&self) -> bool; + fn as_bytes(&self) -> &[u8]; + + fn char_index_strs(&self) -> impl Iterator; + fn char_len_next_strs(&self) -> impl Iterator; +} + +impl EncodedStr for str { + const ELEMENT_LEN_BYTE: usize = core::mem::size_of::(); + + fn is_ascii(&self) -> bool { + self.is_ascii() + } + + fn as_bytes(&self) -> &[u8] { + self.as_bytes() + } + + fn char_index_strs(&self) -> impl Iterator { + self.char_indices().map(|(i, c)| (i, c, &self[i..])) + } + + fn char_len_next_strs(&self) -> impl Iterator { + self.char_indices().map(|(i, c)| { + let len = c.len_utf8(); + (c, len, &self[i + len..]) + }) + } +} + +#[cfg(feature = "encoding")] +impl EncodedStr for widestring::U16Str { + const ELEMENT_LEN_BYTE: usize = core::mem::size_of::(); + + fn is_ascii(&self) -> bool { + self.as_bytes().is_ascii() + } + + fn as_bytes(&self) -> &[u8] { + unsafe { + core::slice::from_raw_parts( + self.as_ptr() as *const u8, + self.len() * core::mem::size_of::(), + ) + } + } + + fn char_index_strs(&self) -> impl Iterator { + self.char_indices_lossy().map(|(i, c)| (i, c, &self[i..])) + } + + fn char_len_next_strs(&self) -> impl Iterator { + self.char_indices_lossy().map(|(i, c)| { + let len = c.len_utf16(); + (c, len, &self[i + len..]) + }) + } +} + +#[cfg(feature = "encoding")] +impl EncodedStr for widestring::U32Str { + const ELEMENT_LEN_BYTE: usize = core::mem::size_of::(); + + fn is_ascii(&self) -> bool { + self.as_bytes().is_ascii() + } + + fn as_bytes(&self) -> &[u8] { + unsafe { + core::slice::from_raw_parts( + self.as_ptr() as *const u8, + self.len() * core::mem::size_of::(), + ) + } + } + + fn char_index_strs(&self) -> impl Iterator { + self.char_indices_lossy().map(|(i, c)| (i, c, &self[i..])) + } + + fn char_len_next_strs(&self) -> impl Iterator { + self.char_indices_lossy() + .map(|(i, c)| (c, 1, &self[i + 1..])) + } +} diff --git a/src/matcher/mod.rs 
b/src/matcher/mod.rs index 0f6fef1..320e9a7 100644 --- a/src/matcher/mod.rs +++ b/src/matcher/mod.rs @@ -1,28 +1,45 @@ -use std::{borrow::Cow, ops::Range}; +use std::{borrow::Cow, marker::PhantomData, ops::Range}; use crate::pinyin::{PinyinData, PinyinNotation}; mod unicode; use unicode::{CharToMonoLowercase, StrToMonoLowercase}; -pub struct PinyinMatcherBuilder<'a> { +pub mod encoding; +use encoding::EncodedStr; + +mod regex_utils; + +pub struct PinyinMatcherBuilder<'a, HaystackStr = str> +where + HaystackStr: EncodedStr + ?Sized, +{ pattern: String, + pattern_bytes: Vec, case_insensitive: bool, is_pattern_partial: bool, pinyin_data: Option<&'a PinyinData>, pinyin_notations: PinyinNotation, pinyin_case_insensitive: bool, + + _haystack_str: PhantomData, } -impl<'a> PinyinMatcherBuilder<'a> { - fn new(pattern: &str) -> Self { +impl<'a, HaystackStr> PinyinMatcherBuilder<'a, HaystackStr> +where + HaystackStr: EncodedStr + ?Sized, +{ + fn new(pattern: &HaystackStr) -> Self { Self { - pattern: pattern.to_owned(), + pattern: pattern.char_index_strs().map(|(_, c, _)| c).collect(), + pattern_bytes: pattern.as_bytes().to_owned(), case_insensitive: true, is_pattern_partial: false, pinyin_data: None, pinyin_notations: PinyinNotation::Ascii | PinyinNotation::AsciiFirstLetter, pinyin_case_insensitive: false, + + _haystack_str: PhantomData, } } @@ -74,7 +91,7 @@ impl<'a> PinyinMatcherBuilder<'a> { PinyinNotation::DiletterZrm, ]; - pub fn build(self) -> PinyinMatcher<'a> { + pub fn build(self) -> PinyinMatcher<'a, HaystackStr> { let pattern_string = self.pattern.clone(); let pattern_s: &str = pattern_string.as_str(); let pattern_s: &'static str = unsafe { std::mem::transmute(pattern_s) }; @@ -85,7 +102,7 @@ impl<'a> PinyinMatcherBuilder<'a> { // TODO: If pattern does not contain any pinyin letter, then pinyin_data is not needed. 
PinyinMatcher { - regex: regex::RegexBuilder::new(&regex::escape(&self.pattern)) + regex: regex::bytes::RegexBuilder::new(&regex_utils::escape_bytes(&self.pattern_bytes)) .case_insensitive(self.case_insensitive) .build() .unwrap(), @@ -133,6 +150,8 @@ impl<'a> PinyinMatcherBuilder<'a> { notations.into_boxed_slice() }, pinyin_case_insensitive: self.pinyin_case_insensitive, + + _haystack_str: PhantomData, } } } @@ -141,12 +160,14 @@ impl<'a> PinyinMatcherBuilder<'a> { /// TODO: No-pinyin pattern optimization /// TODO: Match Ascii only after AsciiFirstLetter; get_pinyins_and_for_each /// TODO: Anchors, `*_at` -/// TODO: UTF-16 and UCS-4 /// TODO: Unicode normalization /// TODO: No-hanzi haystack optimization (0.2/0.9%) -pub struct PinyinMatcher<'a> { +pub struct PinyinMatcher<'a, HaystackStr = str> +where + HaystackStr: EncodedStr + ?Sized, +{ /// For ASCII-only haystack optimization. - regex: regex::Regex, + regex: regex::bytes::Regex, pattern: Box<[PatternChar<'a>]>, _pattern_string: String, @@ -158,6 +179,8 @@ pub struct PinyinMatcher<'a> { pinyin_data: Cow<'a, PinyinData>, pinyin_notations: Box<[PinyinNotation]>, pinyin_case_insensitive: bool, + + _haystack_str: PhantomData<HaystackStr>, } struct PatternChar<'a> { @@ -214,16 +237,19 @@ impl SubMatch { } } -impl<'a> PinyinMatcher<'a> { - pub fn builder(pattern: &str) -> PinyinMatcherBuilder<'a> { +impl<'a, HaystackStr> PinyinMatcher<'a, HaystackStr> +where + HaystackStr: EncodedStr + ?Sized, +{ + pub fn builder(pattern: &HaystackStr) -> PinyinMatcherBuilder<'a, HaystackStr> { PinyinMatcherBuilder::new(pattern) } - pub fn find(&self, haystack: &str) -> Option<Match> { + pub fn find(&self, haystack: &HaystackStr) -> Option<Match> { self.find_with_is_ascii(haystack, haystack.is_ascii()) } - fn find_with_is_ascii(&self, haystack: &str, is_ascii: bool) -> Option<Match> { + fn find_with_is_ascii(&self, haystack: &HaystackStr, is_ascii: bool) -> Option<Match> { if self.pattern.is_empty() { return Some(Match { start: 0, @@ -233,15 +259,15 @@ impl<'a> PinyinMatcher<'a>
{ } if is_ascii { - return self.regex.find(haystack).map(|m| Match { - start: m.start(), - end: m.end(), + return self.regex.find(haystack.as_bytes()).map(|m| Match { + start: m.start() / HaystackStr::ELEMENT_LEN_BYTE, + end: m.end() / HaystackStr::ELEMENT_LEN_BYTE, is_pattern_partial: false, }); } - for (i, _c) in haystack.char_indices() { - if let Some(submatch) = self.sub_test(&self.pattern, &haystack[i..], 0) { + for (i, _c, str) in haystack.char_index_strs() { + if let Some(submatch) = self.sub_test(&self.pattern, str, 0) { return Some(Match { start: i, end: i + submatch.len, @@ -253,9 +279,9 @@ impl<'a> PinyinMatcher<'a> { None } - pub fn is_match(&self, haystack: &str) -> bool { + pub fn is_match(&self, haystack: &HaystackStr) -> bool { if haystack.is_ascii() { - return self.regex.is_match(haystack); + return self.regex.is_match(haystack.as_bytes()); } self.find_with_is_ascii(haystack, false).is_some() @@ -264,7 +290,7 @@ impl<'a> PinyinMatcher<'a> { /// ## Returns /// - `Match.start()` is guaranteed to be 0. /// - If there are multiple possible matches, the longer ones are preferred. But the result is not guaranteed to be the longest one. - pub fn test(&self, haystack: &str) -> Option { + pub fn test(&self, haystack: &HaystackStr) -> Option { if self.pattern.is_empty() { return Some(Match { start: 0, @@ -275,11 +301,11 @@ impl<'a> PinyinMatcher<'a> { if haystack.is_ascii() { // TODO: Use regex-automata's anchored searches? 
- return match self.regex.find(haystack) { + return match self.regex.find(haystack.as_bytes()) { Some(m) => match m.start() { 0 => Some(Match { start: 0, - end: m.end(), + end: m.end() / HaystackStr::ELEMENT_LEN_BYTE, is_pattern_partial: false, }), _ => None, @@ -303,18 +329,18 @@ impl<'a> PinyinMatcher<'a> { fn sub_test( &self, pattern: &[PatternChar], - haystack: &str, + haystack: &HaystackStr, matched_len: usize, ) -> Option { debug_assert!(!pattern.is_empty()); - let (haystack_c, haystack_next) = { - let mut chars = haystack.chars(); - match chars.next() { - Some(c) => (c, chars.as_str()), + let (haystack_c, haystack_c_len, haystack_next) = { + match haystack.char_len_next_strs().next() { + Some(v) => v, None => return None, } }; + let matched_len = matched_len + haystack_c_len; let (pattern_c, pattern_next) = pattern.split_first().unwrap(); @@ -323,7 +349,6 @@ impl<'a> PinyinMatcher<'a> { false => haystack_c == pattern_c.c, } { // If haystack_c == pattern_c, then it is impossible that pattern_c is a pinyin letter and haystack_c is a hanzi. 
- let matched_len = matched_len + haystack_c.len_utf8(); return if pattern_next.is_empty() { Some(SubMatch::new(matched_len, false)) } else { @@ -334,7 +359,8 @@ impl<'a> PinyinMatcher<'a> { for pinyin in self.pinyin_data.get_pinyins(haystack_c) { for &notation in self.pinyin_notations.iter() { let pinyin = pinyin.notation(notation).unwrap(); - if let Some(submatch) = self.sub_test_pinyin(pattern, haystack, matched_len, pinyin) + if let Some(submatch) = + self.sub_test_pinyin(pattern, haystack_next, matched_len, pinyin) { return Some(submatch); } @@ -351,8 +377,8 @@ impl<'a> PinyinMatcher<'a> { fn sub_test_pinyin( &self, pattern: &[PatternChar], - haystack: &str, - matched_len: usize, + haystack_next: &HaystackStr, + matched_len_next: usize, pinyin: &str, ) -> Option<SubMatch> { debug_assert!(!pattern.is_empty()); @@ -363,22 +389,19 @@ impl<'a> PinyinMatcher<'a> { false => pattern[0].s, }; - let haystack_c_len = haystack.chars().next().unwrap().len_utf8(); - let matched_len = matched_len + haystack_c_len; - if pattern_s.len() < pinyin.len() { if self.is_pattern_partial && pinyin.starts_with(pattern_s) { - return Some(SubMatch::new(matched_len, true)); + return Some(SubMatch::new(matched_len_next, true)); } } else if pattern_s.starts_with(pinyin) { if pattern_s.len() == pinyin.len() { - return Some(SubMatch::new(matched_len, false)); + return Some(SubMatch::new(matched_len_next, false)); } if let Some(submatch) = self.sub_test( &pattern[pinyin.chars().count()..], - &haystack[haystack_c_len..], - matched_len, + haystack_next, + matched_len_next, ) { return Some(submatch); } @@ -400,7 +423,7 @@ mod test { fn ordered_pinyin_notations() { assert_eq!( PinyinNotation::all().iter().count(), - PinyinMatcherBuilder::ORDERED_PINYIN_NOTATIONS.len() + PinyinMatcherBuilder::<str>::ORDERED_PINYIN_NOTATIONS.len() ) } @@ -439,6 +462,44 @@ mod test { assert_match(matcher.test("柯尔"), Some((0, 6))); } + #[cfg(feature = "encoding")] + #[test] + fn test_u16() { + use widestring::u16str; + + let matcher =
PinyinMatcher::builder(u16str!("xing")) + .pinyin_notations(PinyinNotation::Ascii) + .build(); + assert_match(matcher.test(u16str!("")), None); + assert_match(matcher.test(u16str!("xing")), Some((0, 4))); + assert_match(matcher.test(u16str!("XiNG")), Some((0, 4))); + assert_match(matcher.test(u16str!("行")), Some((0, 1))); + + let matcher = PinyinMatcher::builder(u16str!("ke")) + .pinyin_notations(PinyinNotation::Ascii) + .build(); + assert_match(matcher.test(u16str!("ke")), Some((0, 2))); + assert_match(matcher.test(u16str!("科")), Some((0, 1))); + assert_match(matcher.test(u16str!("k鹅")), Some((0, 2))); + assert_match(matcher.test(u16str!("凯尔")), None); + + let matcher = PinyinMatcher::builder(u16str!("")) + .pinyin_notations(PinyinNotation::Ascii) + .build(); + assert_match(matcher.test(u16str!("")), Some((0, 0))); + assert_match(matcher.test(u16str!("abc")), Some((0, 0))); + + let matcher = PinyinMatcher::builder(u16str!("ke")) + .pinyin_notations(PinyinNotation::Ascii | PinyinNotation::AsciiFirstLetter) + .build(); + assert_match(matcher.test(u16str!("ke")), Some((0, 2))); + assert_match(matcher.test(u16str!("科")), Some((0, 1))); + assert_match(matcher.test(u16str!("k鹅")), Some((0, 2))); + assert_match(matcher.test(u16str!("凯尔")), Some((0, 2))); + // AsciiFirstLetter is preferred + assert_match(matcher.test(u16str!("柯尔")), Some((0, 2))); + } + #[test] fn test_case_insensitive() { let matcher = PinyinMatcher::builder("xing") diff --git a/src/matcher/regex_utils.rs b/src/matcher/regex_utils.rs new file mode 100644 index 0000000..b24264c --- /dev/null +++ b/src/matcher/regex_utils.rs @@ -0,0 +1,10 @@ +use core::fmt::Write; + +/// https://github.com/rust-lang/regex/issues/451 +pub fn escape_bytes(bytes: &[u8]) -> String { + let mut pattern = String::with_capacity(bytes.len() * 4); + for byte in bytes { + write!(pattern, "\\x{:02X}", byte).unwrap(); + } + return pattern; +} diff --git a/src/minimal.rs b/src/minimal.rs index 2f8fd07..be910cf 100644 --- 
a/src/minimal.rs +++ b/src/minimal.rs @@ -43,6 +43,78 @@ pub fn is_pinyin_match(pattern: &str, haystack: &str, pinyin_notations: PinyinNo cache.matcher.is_match(haystack) } +#[cfg(feature = "encoding")] +pub fn is_pinyin_match_u16( + pattern: &widestring::U16Str, + haystack: &widestring::U16Str, + pinyin_notations: PinyinNotation, +) -> bool { + struct MatcherCache { + pattern: widestring::U16String, + pinyin_notations: PinyinNotation, + matcher: PinyinMatcher<'static, widestring::U16Str>, + } + + static MATCHER_CACHE: OnceLock> = OnceLock::new(); + let init = || MatcherCache { + pattern: pattern.to_owned(), + pinyin_notations, + matcher: PinyinMatcher::builder(pattern) + .pinyin_data(pinyin_data()) + .pinyin_notations(pinyin_notations) + .build(), + }; + let lock = MATCHER_CACHE.get_or_init(|| RwLock::new(init())); + let cache = { + let guard = lock.read().unwrap(); + if guard.pattern == pattern && guard.pinyin_notations == pinyin_notations { + guard + } else { + drop(guard); + *lock.write().unwrap() = init(); + lock.read().unwrap() + } + }; + + cache.matcher.is_match(haystack) +} + +#[cfg(feature = "encoding")] +pub fn is_pinyin_match_u32( + pattern: &widestring::U32Str, + haystack: &widestring::U32Str, + pinyin_notations: PinyinNotation, +) -> bool { + struct MatcherCache { + pattern: widestring::U32String, + pinyin_notations: PinyinNotation, + matcher: PinyinMatcher<'static, widestring::U32Str>, + } + + static MATCHER_CACHE: OnceLock> = OnceLock::new(); + let init = || MatcherCache { + pattern: pattern.to_owned(), + pinyin_notations, + matcher: PinyinMatcher::builder(pattern) + .pinyin_data(pinyin_data()) + .pinyin_notations(pinyin_notations) + .build(), + }; + let lock = MATCHER_CACHE.get_or_init(|| RwLock::new(init())); + let cache = { + let guard = lock.read().unwrap(); + if guard.pattern == pattern && guard.pinyin_notations == pinyin_notations { + guard + } else { + drop(guard); + *lock.write().unwrap() = init(); + lock.read().unwrap() + } + }; + + 
cache.matcher.is_match(haystack) +} + #[cfg(test)] mod tests { use super::*;