diff --git a/src/classicalbacktrack.rs b/src/classicalbacktrack.rs index aef5176..4ed4b1c 100644 --- a/src/classicalbacktrack.rs +++ b/src/classicalbacktrack.rs @@ -944,7 +944,11 @@ impl<'r, Input: InputIndexer> BacktrackExecutor<'r, Input> { let inp = self.input; loop { // Find the next start location, or None if none. - pos = inp.find_bytes(pos, prefix_search)?; + // Don't try this unless CODE_UNITS_ARE_BYTES - i.e. don't do byte searches + // on UTF-16 or UCS2. + if Input::CODE_UNITS_ARE_BYTES { + pos = inp.find_bytes(pos, prefix_search)?; + } if let Some(end) = self.matcher.try_at_pos(inp, 0, pos, Forward::new()) { // If we matched the empty string, we have to increment. if end != pos { diff --git a/src/cursor.rs b/src/cursor.rs index e6dc5fd..d41d4b9 100644 --- a/src/cursor.rs +++ b/src/cursor.rs @@ -62,6 +62,10 @@ pub fn next_byte( _dir: Dir, pos: &mut Input::Position, ) -> Option { + assert!( + Input::CODE_UNITS_ARE_BYTES, + "Not implemented for non-byte input" + ); let res; if Dir::FORWARD { res = input.peek_byte_right(*pos); diff --git a/src/indexing.rs b/src/indexing.rs index ebeb46c..2e204b6 100644 --- a/src/indexing.rs +++ b/src/indexing.rs @@ -53,6 +53,11 @@ where /// A type which references a position in the input string. type Position: PositionType; + /// Whether we have bytes as code units. This can optimize some operations. + /// This is true for ASCII and UTF8, but not for UCS2 or UTF16. + const CODE_UNITS_ARE_BYTES: bool; + + /// \return whether we are using unicode for case-folding. fn unicode(&self) -> bool; /// Case-fold an element. @@ -83,11 +88,11 @@ where fn next_left_pos(&self, pos: Self::Position) -> Option; /// \return the byte to the right (starting at) \p idx, or None if we are at - /// the end. + /// the end. This panics if CODE_UNITS_ARE_BYTES is false. fn peek_byte_right(&self, pos: Self::Position) -> Option; /// \return the byte to the left (ending just before) \p idx, or None if we - /// are at the start. + /// are at the start. This panics if CODE_UNITS_ARE_BYTES is false. fn peek_byte_left(&self, pos: Self::Position) -> Option; /// \return a position at the left end of this input. @@ -268,6 +273,7 @@ impl<'a> InputIndexer for Utf8Input<'a> { type Position = DefPosition<'a>; type Element = char; type CharProps = matchers::UTF8CharProperties; + const CODE_UNITS_ARE_BYTES: bool = true; #[inline(always)] fn unicode(&self) -> bool { @@ -632,6 +638,7 @@ impl<'a> InputIndexer for AsciiInput<'a> { type Position = DefPosition<'a>; type Element = u8; type CharProps = matchers::ASCIICharProperties; + const CODE_UNITS_ARE_BYTES: bool = true; #[inline(always)] fn unicode(&self) -> bool { @@ -886,6 +893,7 @@ impl<'a> InputIndexer for Utf16Input<'a> { type Position = IndexPosition<'a>; type Element = u32; type CharProps = matchers::Utf16CharProperties; + const CODE_UNITS_ARE_BYTES: bool = false; #[inline(always)] fn unicode(&self) -> bool { @@ -1009,29 +1017,13 @@ impl<'a> InputIndexer for Utf16Input<'a> { } #[inline(always)] - fn peek_byte_right(&self, mut pos: Self::Position) -> Option { - if let Some(c) = self.next_right(&mut pos) { - if cfg!(target_endian = "big") { - Some(c.to_be_bytes()[0]) - } else { - Some(c.to_le_bytes()[0]) - } - } else { - None - } + fn peek_byte_right(&self, _pos: Self::Position) -> Option { + panic!("Should never be inspecting bytes for utf16"); } #[inline(always)] - fn peek_byte_left(&self, mut pos: Self::Position) -> Option { - if let Some(c) = self.next_left(&mut pos) { - if cfg!(target_endian = "big") { - Some(c.to_be_bytes()[0]) - } else { - Some(c.to_le_bytes()[0]) - } - } else { - None - } + fn peek_byte_left(&self, _pos: Self::Position) -> Option { + panic!("Should never be inspecting bytes for utf16"); } #[inline(always)] @@ -1076,16 +1068,10 @@ impl<'a> InputIndexer for Utf16Input<'a> { #[inline(always)] fn find_bytes( &self, - pos: Self::Position, - search: &Search, + _pos: Self::Position, + _search: &Search, ) -> Option { - let idx = search.find_in( - &self.input[self.pos_to_offset(pos)..self.pos_to_offset(self.right_end())] - .iter() - .map(|c| *c as u8) - .collect::>(), - )?; - Some(pos + idx) + panic!("Should never be finding bytes for utf16"); } fn subrange_eq( @@ -1120,32 +1106,10 @@ impl<'a> InputIndexer for Utf16Input<'a> { fn match_bytes( &self, _dir: Dir, - pos: &mut Self::Position, - bytes: &Bytes, + _pos: &mut Self::Position, + _bytes: &Bytes, ) -> bool { - let len = Bytes::LENGTH; - let (start, end) = if Dir::FORWARD { - if let Some(end) = self.try_move_right(*pos, len) { - let start = *pos; - *pos = end; - (start, end) - } else { - return false; - } - } else if let Some(start) = self.try_move_left(*pos, len) { - let end = *pos; - *pos = start; - (start, end) - } else { - return false; - }; - - bytes.equals_known_len( - &self.input[self.pos_to_offset(start)..self.pos_to_offset(end)] - .iter() - .map(|c| *c as u8) - .collect::>(), - ) + panic!("Should never be matching bytes for utf16"); } } @@ -1173,6 +1137,7 @@ impl<'a> InputIndexer for Ucs2Input<'a> { type Position = IndexPosition<'a>; type Element = u32; type CharProps = matchers::Utf16CharProperties; + const CODE_UNITS_ARE_BYTES: bool = false; #[inline(always)] fn unicode(&self) -> bool { @@ -1223,29 +1188,13 @@ impl<'a> InputIndexer for Ucs2Input<'a> { } #[inline(always)] - fn peek_byte_right(&self, mut pos: Self::Position) -> Option { - if let Some(c) = self.next_right(&mut pos) { - if cfg!(target_endian = "big") { - Some(c.to_be_bytes()[0]) - } else { - Some(c.to_le_bytes()[0]) - } - } else { - None - } + fn peek_byte_right(&self, _pos: Self::Position) -> Option { + panic!("Should never be inspecting bytes for ucs2"); } #[inline(always)] - fn peek_byte_left(&self, mut pos: Self::Position) -> Option { - if let Some(c) = self.next_left(&mut pos) { - if cfg!(target_endian = "big") { - Some(c.to_be_bytes()[0]) - } else { - Some(c.to_le_bytes()[0]) - } - } else { - None - } + fn peek_byte_left(&self, _pos: Self::Position) -> Option { + panic!("Should never be inspecting bytes for ucs2"); } #[inline(always)] @@ -1290,16 +1239,10 @@ impl<'a> InputIndexer for Ucs2Input<'a> { #[inline(always)] fn find_bytes( &self, - pos: Self::Position, - search: &Search, + _pos: Self::Position, + _search: &Search, ) -> Option { - let idx = search.find_in( - &self.input[self.pos_to_offset(pos)..self.pos_to_offset(self.right_end())] - .iter() - .map(|c| *c as u8) - .collect::>(), - )?; - Some(pos + idx) + panic!("Should never be finding bytes for ucs2"); } fn subrange_eq( @@ -1334,31 +1277,9 @@ impl<'a> InputIndexer for Ucs2Input<'a> { fn match_bytes( &self, _dir: Dir, - pos: &mut Self::Position, - bytes: &Bytes, + _pos: &mut Self::Position, + _bytes: &Bytes, ) -> bool { - let len = Bytes::LENGTH; - let (start, end) = if Dir::FORWARD { - if let Some(end) = self.try_move_right(*pos, len) { - let start = *pos; - *pos = end; - (start, end) - } else { - return false; - } - } else if let Some(start) = self.try_move_left(*pos, len) { - let end = *pos; - *pos = start; - (start, end) - } else { - return false; - }; - - bytes.equals_known_len( - &self.input[self.pos_to_offset(start)..self.pos_to_offset(end)] - .iter() - .map(|c| *c as u8) - .collect::>(), - ) + panic!("Should never be matching bytes for ucs2"); } } diff --git a/src/scm.rs b/src/scm.rs index 29a5a3b..dd09096 100644 --- a/src/scm.rs +++ b/src/scm.rs @@ -115,10 +115,14 @@ impl<'a, Input: InputIndexer, Dir: Direction, Bytes: ByteSet> SingleCharMatcher< { #[inline(always)] fn matches(&self, input: &Input, dir: Dir, pos: &mut Input::Position) -> bool { - if let Some(b) = cursor::next_byte(input, dir, pos) { - self.bytes.contains(b) + if Input::CODE_UNITS_ARE_BYTES { + // Code units are bytes so we can skip decoding the full element. + cursor::next_byte(input, dir, pos).map_or(false, |b| self.bytes.contains(b)) } else { - false + // Must decode the full element. + cursor::next(input, dir, pos) + .and_then(|c| c.as_u32().try_into().ok()) + .map_or(false, |c| self.bytes.contains(c)) } } } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 947441a..f49abde 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -81,6 +81,22 @@ impl TestCompiledRegex { self.matches(input, 0).into_iter().next() } + /// Encode a string as UTF16, and match against it as UTF16. + #[cfg(feature = "utf16")] + #[track_caller] + pub fn find_utf16(&self, input: &str) -> Option { + let input = input.encode_utf16().collect::>(); + self.re.find_from_utf16(&input, 0).into_iter().next() + } + + /// Encode a string as UTF16, and match against it as UCS2. + #[cfg(feature = "utf16")] + #[track_caller] + pub fn find_ucs2(&self, input: &str) -> Option { + let input = input.encode_utf16().collect::>(); + self.re.find_from_ucs2(&input, 0).into_iter().next() + } + /// Match against a string, returning the first formatted match. #[track_caller] pub fn match1f(&self, input: &str) -> String { diff --git a/tests/tests.rs b/tests/tests.rs index 9dcceaf..c7c04df 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1820,3 +1820,54 @@ fn test_quantifiable_assertion_followed_by_tc(tc: TestConfig) { .match1f(r#"a bZ cZZ dZZZ eZZZZ"#) .test_eq("b"); } + +#[cfg(feature = "utf16")] +mod utf16_tests { + use super::*; + + #[test] + fn test_utf16_regression_100() { + test_with_configs(test_utf16_regression_100_tc) + } + + fn test_utf16_regression_100_tc(tc: TestConfig) { + // Ensure the leading bytes of UTF-16 characters don't match against brackets. + let input = "赔"; // U+8D54 + + let re = tc.compile(r"[A-Z]"); // 0x41 - 0x5A + let matched = re.find_utf16(&input); + assert!(matched.is_none()); + + let matched = re.find_ucs2(&input); + assert!(matched.is_none()); + } + + #[test] + fn test_utf16_byte_sequences() { + test_with_configs(test_utf16_byte_sequences_tc) + } + + fn test_utf16_byte_sequences_tc(tc: TestConfig) { + // Regress emits byte sequences for e.g. 'abc'. + // Ensure these are properly decoded in UTF-16/UCS2. + let re = tc.compile(r"abc"); + + let input = "abc"; + let matched = re.find_utf16(&input); + assert!(matched.is_some()); + assert_eq!(matched.unwrap().range, 0..3); + + let matched = re.find_ucs2(&input); + assert!(matched.is_some()); + assert_eq!(matched.unwrap().range, 0..3); + + let input = "xxxabczzz"; + let matched = re.find_utf16(&input); + assert!(matched.is_some()); + assert_eq!(matched.unwrap().range, 3..6); + + let matched = re.find_ucs2(&input); + assert!(matched.is_some()); + assert_eq!(matched.unwrap().range, 3..6); + } +}