Skip to content

Commit

Permalink
Correct certain UTF-16 / UCS2 errors
Browse files Browse the repository at this point in the history
When emitting a bracket like [a-z], we emit a byte bitmap. However these
bytes may be found in certain UTF-16 sequences, causing us to
incorrectly match when performing UTF-16 or UCS2. Fix this and in
general prohibit byte-level operations on these indexers.

Fixes #100
  • Loading branch information
ridiculousfish committed Jan 12, 2025
1 parent ec659e3 commit ff467d2
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 114 deletions.
6 changes: 5 additions & 1 deletion src/classicalbacktrack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,11 @@ impl<'r, Input: InputIndexer> BacktrackExecutor<'r, Input> {
let inp = self.input;
loop {
// Find the next start location, or None if none.
pos = inp.find_bytes(pos, prefix_search)?;
// Don't try this unless CODE_UNITS_ARE_BYTES - i.e. don't do byte searches
// on UTF-16 or UCS2.
if Input::CODE_UNITS_ARE_BYTES {
pos = inp.find_bytes(pos, prefix_search)?;
}
if let Some(end) = self.matcher.try_at_pos(inp, 0, pos, Forward::new()) {
// If we matched the empty string, we have to increment.
if end != pos {
Expand Down
4 changes: 4 additions & 0 deletions src/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ pub fn next_byte<Input: InputIndexer, Dir: Direction>(
_dir: Dir,
pos: &mut Input::Position,
) -> Option<u8> {
assert!(
Input::CODE_UNITS_ARE_BYTES,
"Not implemented for non-byte input"
);
let res;
if Dir::FORWARD {
res = input.peek_byte_right(*pos);
Expand Down
141 changes: 31 additions & 110 deletions src/indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ where
/// A type which references a position in the input string.
type Position: PositionType;

/// Whether we have bytes as code units. This can optimize some operations.
/// This is true for ASCII and UTF8, but not for UCS2 or UTF16.
const CODE_UNITS_ARE_BYTES: bool;

/// \return whether we are using unicode for case-folding.
fn unicode(&self) -> bool;

/// Case-fold an element.
Expand Down Expand Up @@ -83,11 +88,11 @@ where
fn next_left_pos(&self, pos: Self::Position) -> Option<Self::Position>;

/// \return the byte to the right (starting at) \p idx, or None if we are at
/// the end.
/// the end. This panics if CODE_UNITS_ARE_BYTES is false.
fn peek_byte_right(&self, pos: Self::Position) -> Option<u8>;

/// \return the byte to the left (ending just before) \p idx, or None if we
/// are at the start.
/// are at the start. This panics if CODE_UNITS_ARE_BYTES is false.
fn peek_byte_left(&self, pos: Self::Position) -> Option<u8>;

/// \return a position at the left end of this input.
Expand Down Expand Up @@ -268,6 +273,7 @@ impl<'a> InputIndexer for Utf8Input<'a> {
type Position = DefPosition<'a>;
type Element = char;
type CharProps = matchers::UTF8CharProperties;
const CODE_UNITS_ARE_BYTES: bool = true;

#[inline(always)]
fn unicode(&self) -> bool {
Expand Down Expand Up @@ -632,6 +638,7 @@ impl<'a> InputIndexer for AsciiInput<'a> {
type Position = DefPosition<'a>;
type Element = u8;
type CharProps = matchers::ASCIICharProperties;
const CODE_UNITS_ARE_BYTES: bool = true;

#[inline(always)]
fn unicode(&self) -> bool {
Expand Down Expand Up @@ -886,6 +893,7 @@ impl<'a> InputIndexer for Utf16Input<'a> {
type Position = IndexPosition<'a>;
type Element = u32;
type CharProps = matchers::Utf16CharProperties;
const CODE_UNITS_ARE_BYTES: bool = false;

#[inline(always)]
fn unicode(&self) -> bool {
Expand Down Expand Up @@ -1009,29 +1017,13 @@ impl<'a> InputIndexer for Utf16Input<'a> {
}

#[inline(always)]
fn peek_byte_right(&self, mut pos: Self::Position) -> Option<u8> {
if let Some(c) = self.next_right(&mut pos) {
if cfg!(target_endian = "big") {
Some(c.to_be_bytes()[0])
} else {
Some(c.to_le_bytes()[0])
}
} else {
None
}
fn peek_byte_right(&self, _pos: Self::Position) -> Option<u8> {
panic!("Should never be inspecting bytes for utf16");
}

#[inline(always)]
fn peek_byte_left(&self, mut pos: Self::Position) -> Option<u8> {
if let Some(c) = self.next_left(&mut pos) {
if cfg!(target_endian = "big") {
Some(c.to_be_bytes()[0])
} else {
Some(c.to_le_bytes()[0])
}
} else {
None
}
fn peek_byte_left(&self, _pos: Self::Position) -> Option<u8> {
panic!("Should never be inspecting bytes for utf16");
}

#[inline(always)]
Expand Down Expand Up @@ -1076,16 +1068,10 @@ impl<'a> InputIndexer for Utf16Input<'a> {
#[inline(always)]
fn find_bytes<Search: bytesearch::ByteSearcher>(
&self,
pos: Self::Position,
search: &Search,
_pos: Self::Position,
_search: &Search,
) -> Option<Self::Position> {
let idx = search.find_in(
&self.input[self.pos_to_offset(pos)..self.pos_to_offset(self.right_end())]
.iter()
.map(|c| *c as u8)
.collect::<Vec<_>>(),
)?;
Some(pos + idx)
panic!("Should never be finding bytes for utf16");
}

fn subrange_eq<Dir: Direction>(
Expand Down Expand Up @@ -1120,32 +1106,10 @@ impl<'a> InputIndexer for Utf16Input<'a> {
fn match_bytes<Dir: Direction, Bytes: ByteSeq>(
&self,
_dir: Dir,
pos: &mut Self::Position,
bytes: &Bytes,
_pos: &mut Self::Position,
_bytes: &Bytes,
) -> bool {
let len = Bytes::LENGTH;
let (start, end) = if Dir::FORWARD {
if let Some(end) = self.try_move_right(*pos, len) {
let start = *pos;
*pos = end;
(start, end)
} else {
return false;
}
} else if let Some(start) = self.try_move_left(*pos, len) {
let end = *pos;
*pos = start;
(start, end)
} else {
return false;
};

bytes.equals_known_len(
&self.input[self.pos_to_offset(start)..self.pos_to_offset(end)]
.iter()
.map(|c| *c as u8)
.collect::<Vec<_>>(),
)
panic!("Should never be matching bytes for utf16");
}
}

Expand Down Expand Up @@ -1173,6 +1137,7 @@ impl<'a> InputIndexer for Ucs2Input<'a> {
type Position = IndexPosition<'a>;
type Element = u32;
type CharProps = matchers::Utf16CharProperties;
const CODE_UNITS_ARE_BYTES: bool = false;

#[inline(always)]
fn unicode(&self) -> bool {
Expand Down Expand Up @@ -1223,29 +1188,13 @@ impl<'a> InputIndexer for Ucs2Input<'a> {
}

#[inline(always)]
fn peek_byte_right(&self, mut pos: Self::Position) -> Option<u8> {
if let Some(c) = self.next_right(&mut pos) {
if cfg!(target_endian = "big") {
Some(c.to_be_bytes()[0])
} else {
Some(c.to_le_bytes()[0])
}
} else {
None
}
fn peek_byte_right(&self, _pos: Self::Position) -> Option<u8> {
panic!("Should never be inspecting bytes for ucs2");
}

#[inline(always)]
fn peek_byte_left(&self, mut pos: Self::Position) -> Option<u8> {
if let Some(c) = self.next_left(&mut pos) {
if cfg!(target_endian = "big") {
Some(c.to_be_bytes()[0])
} else {
Some(c.to_le_bytes()[0])
}
} else {
None
}
fn peek_byte_left(&self, _pos: Self::Position) -> Option<u8> {
panic!("Should never be inspecting bytes for ucs2");
}

#[inline(always)]
Expand Down Expand Up @@ -1290,16 +1239,10 @@ impl<'a> InputIndexer for Ucs2Input<'a> {
#[inline(always)]
fn find_bytes<Search: bytesearch::ByteSearcher>(
&self,
pos: Self::Position,
search: &Search,
_pos: Self::Position,
_search: &Search,
) -> Option<Self::Position> {
let idx = search.find_in(
&self.input[self.pos_to_offset(pos)..self.pos_to_offset(self.right_end())]
.iter()
.map(|c| *c as u8)
.collect::<Vec<_>>(),
)?;
Some(pos + idx)
panic!("Should never be finding bytes for ucs2");
}

fn subrange_eq<Dir: Direction>(
Expand Down Expand Up @@ -1334,31 +1277,9 @@ impl<'a> InputIndexer for Ucs2Input<'a> {
fn match_bytes<Dir: Direction, Bytes: ByteSeq>(
&self,
_dir: Dir,
pos: &mut Self::Position,
bytes: &Bytes,
_pos: &mut Self::Position,
_bytes: &Bytes,
) -> bool {
let len = Bytes::LENGTH;
let (start, end) = if Dir::FORWARD {
if let Some(end) = self.try_move_right(*pos, len) {
let start = *pos;
*pos = end;
(start, end)
} else {
return false;
}
} else if let Some(start) = self.try_move_left(*pos, len) {
let end = *pos;
*pos = start;
(start, end)
} else {
return false;
};

bytes.equals_known_len(
&self.input[self.pos_to_offset(start)..self.pos_to_offset(end)]
.iter()
.map(|c| *c as u8)
.collect::<Vec<_>>(),
)
panic!("Should never be matching bytes for ucs2");
}
}
10 changes: 7 additions & 3 deletions src/scm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,14 @@ impl<'a, Input: InputIndexer, Dir: Direction, Bytes: ByteSet> SingleCharMatcher<
{
#[inline(always)]
fn matches(&self, input: &Input, dir: Dir, pos: &mut Input::Position) -> bool {
if let Some(b) = cursor::next_byte(input, dir, pos) {
self.bytes.contains(b)
if Input::CODE_UNITS_ARE_BYTES {
// Code units are bytes so we can skip decoding the full element.
cursor::next_byte(input, dir, pos).map_or(false, |b| self.bytes.contains(b))
} else {
false
// Must decode the full element.
cursor::next(input, dir, pos)
.and_then(|c| c.as_u32().try_into().ok())
.map_or(false, |c| self.bytes.contains(c))
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ impl TestCompiledRegex {
self.matches(input, 0).into_iter().next()
}

/// Encode a string as UTF16, and match against it as UTF16.
#[cfg(feature = "utf16")]
#[track_caller]
pub fn find_utf16(&self, input: &str) -> Option<regress::Match> {
let input = input.encode_utf16().collect::<Vec<_>>();
self.re.find_from_utf16(&input, 0).into_iter().next()
}

/// Encode a string as UTF16, and match against it as UCS2.
#[cfg(feature = "utf16")]
#[track_caller]
pub fn find_ucs2(&self, input: &str) -> Option<regress::Match> {
let input = input.encode_utf16().collect::<Vec<_>>();
self.re.find_from_ucs2(&input, 0).into_iter().next()
}

/// Match against a string, returning the first formatted match.
#[track_caller]
pub fn match1f(&self, input: &str) -> String {
Expand Down
51 changes: 51 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1820,3 +1820,54 @@ fn test_quantifiable_assertion_followed_by_tc(tc: TestConfig) {
.match1f(r#"a bZ cZZ dZZZ eZZZZ"#)
.test_eq("b");
}

#[cfg(feature = "utf16")]
mod utf16_tests {
use super::*;

#[test]
fn test_utf16_regression_100() {
test_with_configs(test_utf16_regression_100_tc)
}

fn test_utf16_regression_100_tc(tc: TestConfig) {
// Ensure the leading bytes of UTF-16 characters don't match against brackets.
let input = "赔"; // U+8D54

let re = tc.compile(r"[A-Z]"); // 0x41 - 0x5A
let matched = re.find_utf16(&input);
assert!(matched.is_none());

let matched = re.find_ucs2(&input);
assert!(matched.is_none());
}

#[test]
fn test_utf16_byte_sequences() {
test_with_configs(test_utf16_byte_sequences_tc)
}

fn test_utf16_byte_sequences_tc(tc: TestConfig) {
// Regress emits byte sequences for e.g. 'abc'.
// Ensure these are properly decoded in UTF-16/UCS2.
let re = tc.compile(r"abc");

let input = "abc";
let matched = re.find_utf16(&input);
assert!(matched.is_some());
assert_eq!(matched.unwrap().range, 0..3);

let matched = re.find_ucs2(&input);
assert!(matched.is_some());
assert_eq!(matched.unwrap().range, 0..3);

let input = "xxxabczzz";
let matched = re.find_utf16(&input);
assert!(matched.is_some());
assert_eq!(matched.unwrap().range, 3..6);

let matched = re.find_ucs2(&input);
assert!(matched.is_some());
assert_eq!(matched.unwrap().range, 3..6);
}
}

0 comments on commit ff467d2

Please sign in to comment.