diff --git a/regex-syntax/src/hir/interval.rs b/regex-syntax/src/hir/interval.rs index 56698c53a..5bcb44087 100644 --- a/regex-syntax/src/hir/interval.rs +++ b/regex-syntax/src/hir/interval.rs @@ -1,6 +1,8 @@ use std::char; use std::cmp; +use std::collections::hash_map::DefaultHasher; use std::fmt::Debug; +use std::hash::{Hash, Hasher}; use std::slice; use std::u8; @@ -32,9 +34,10 @@ use crate::unicode; // // Tests on this are relegated to the public API of HIR in src/hir.rs. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub struct IntervalSet<I> { ranges: Vec<I>, + folded: bool, } impl<I: Interval> IntervalSet<I> { @@ -44,7 +47,10 @@ impl<I: Interval> IntervalSet<I> { /// The given ranges do not need to be in any specific order, and ranges /// may overlap. pub fn new<T: IntoIterator<Item = I>>(intervals: T) -> IntervalSet<I> { - let mut set = IntervalSet { ranges: intervals.into_iter().collect() }; + let mut set = IntervalSet { + ranges: intervals.into_iter().collect(), + folded: false, + }; set.canonicalize(); set } @@ -53,8 +59,13 @@ impl<I: Interval> IntervalSet<I> { pub fn push(&mut self, interval: I) { // TODO: This could be faster. e.g., Push the interval such that // it preserves canonicalization. + + // Hash only if the set is marked as folded; otherwise the hash is never compared. + let before = if self.folded { self.get_hash() } else { 0 }; + self.ranges.push(interval); self.canonicalize(); + self.folded = self.folded && before == self.get_hash(); } /// Return an iterator over all intervals in this set. @@ -79,6 +90,9 @@ impl<I: Interval> IntervalSet<I> { /// This returns an error if the necessary case mapping data is not /// available. 
pub fn case_fold_simple(&mut self) -> Result<(), unicode::CaseFoldError> { + if self.folded { + return Ok(()); + } let len = self.ranges.len(); for i in 0..len { let range = self.ranges[i]; @@ -88,14 +102,28 @@ impl<I: Interval> IntervalSet<I> { } } self.canonicalize(); + self.folded = true; Ok(()) } /// Union this set with the given set, in place. pub fn union(&mut self, other: &IntervalSet<I>) { + if other.ranges.is_empty() { + return; + } + + // Hash each set only if it is marked as folded; otherwise its hash is never compared. + let before_self = if self.folded { self.get_hash() } else { 0 }; + let before_other = if other.folded { other.get_hash() } else { 0 }; + // This could almost certainly be done more efficiently. self.ranges.extend(&other.ranges); self.canonicalize(); + self.folded = self.folded && other.folded || { + let current_hash = self.get_hash(); + self.folded && before_self == current_hash + || other.folded && before_other == current_hash + }; } /// Intersect this set with the given set, in place. @@ -105,9 +133,14 @@ impl<I: Interval> IntervalSet<I> { } if other.ranges.is_empty() { self.ranges.clear(); + self.folded = false; return; } + // Hash each set only if it is marked as folded; otherwise its hash is never compared. + let before_self = if self.folded { self.get_hash() } else { 0 }; + let before_other = if other.folded { other.get_hash() } else { 0 }; + // There should be a way to do this in-place with constant memory, // but I couldn't figure out a simple way to do it. So just append // the intersection to the end of this range, and then drain it before @@ -134,6 +167,11 @@ impl<I: Interval> IntervalSet<I> { } } self.ranges.drain(..drain_end); + self.folded = self.folded && other.folded || { + let current_hash = self.get_hash(); + self.folded && before_self == current_hash + || other.folded && before_other == current_hash + }; } /// Subtract the given set from this set, in place. 
@@ -142,6 +180,10 @@ impl<I: Interval> IntervalSet<I> { return; } + // Hash each set only if it is marked as folded; otherwise its hash is never compared. + let before_self = if self.folded { self.get_hash() } else { 0 }; + let before_other = if other.folded { other.get_hash() } else { 0 }; + // This algorithm is (to me) surprisingly complex. A search of the // interwebs indicate that this is a potentially interesting problem. // Folks seem to suggest interval or segment trees, but I'd like to @@ -226,6 +268,11 @@ impl<I: Interval> IntervalSet<I> { a += 1; } self.ranges.drain(..drain_end); + self.folded = self.folded && other.folded || { + let current_hash = self.get_hash(); + self.folded && before_self == current_hash + || other.folded && before_other == current_hash + }; } /// Compute the symmetric difference of the two sets, in place. @@ -276,6 +323,9 @@ impl<I: Interval> IntervalSet<I> { self.ranges.push(I::create(lower, I::Bound::max_value())); } self.ranges.drain(..drain_end); + + // `folded` is left as is: when it is set, every group of case-equivalent members is either + // entirely present or entirely absent, and that still holds after this operation. } /// Converts this set into a canonical ordering. @@ -318,6 +368,33 @@ impl<I: Interval> IntervalSet<I> { } true } + + fn get_hash(&self) -> u64 { + let mut hasher = DefaultHasher::default(); + self.hash(&mut hasher); + hasher.finish() + } +} + +impl<I> PartialEq for IntervalSet<I> +where + I: Interval, +{ + fn eq(&self, other: &Self) -> bool { + self.ranges.eq(&other.ranges) + } +} + +impl<I> Eq for IntervalSet<I> where I: Interval {} + +impl<I> Hash for IntervalSet<I> +where + I: Interval, +{ + fn hash<H: Hasher>(&self, state: &mut H) { + // `folded` is excluded so hashing stays consistent with `PartialEq`, which compares only `ranges`. + self.ranges.hash(state) + } } /// An iterator over intervals. 
@@ -333,7 +410,7 @@ impl<'a, I> Iterator for IntervalSetIter<'a, I> { } pub trait Interval: - Clone + Copy + Debug + Default + Eq + PartialEq + PartialOrd + Ord + Clone + Copy + Debug + Default + Eq + PartialEq + PartialOrd + Ord + Hash { type Bound: Bound; diff --git a/regex-syntax/src/hir/mod.rs b/regex-syntax/src/hir/mod.rs index 1096e9f05..5529be2e5 100644 --- a/regex-syntax/src/hir/mod.rs +++ b/regex-syntax/src/hir/mod.rs @@ -969,7 +969,7 @@ impl<'a> Iterator for ClassUnicodeIter<'a> { /// /// The range is closed. That is, the start and end of the range are included /// in the range. -#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord)] +#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord, Hash)] pub struct ClassUnicodeRange { start: char, end: char, @@ -1028,20 +1028,33 @@ impl Interval for ClassUnicodeRange { } let start = self.start as u32; let end = (self.end as u32).saturating_add(1); - let mut next_simple_cp = None; - for cp in (start..end).filter_map(char::from_u32) { - if next_simple_cp.map_or(false, |next| cp < next) { - continue; - } - let it = match unicode::simple_fold(cp)? { - Ok(it) => it, - Err(next) => { - next_simple_cp = next; - continue; + let mut range = start..end; + let mut idx = 0; + while let Some(cp) = range.next() { + if let Some(c) = char::from_u32(cp) { + let it = match unicode::optimised_fold(idx, c)? { + Ok((it, next_idx)) => { + idx = next_idx; + it + } + Err(next) => { + if let Some((next, next_idx)) = next { + let next = next as u32; + range = next..end; + idx = next_idx; + } + continue; + } + }; + for cp_folded in it { + if let Some(last) = ranges.last_mut() { + if last.end as u32 + 1 == cp_folded as u32 { + last.end = cp_folded; + continue; + } + } + ranges.push(ClassUnicodeRange::new(cp_folded, cp_folded)); } - }; - for cp_folded in it { - ranges.push(ClassUnicodeRange::new(cp_folded, cp_folded)); } } Ok(()) @@ -1186,7 +1199,7 @@ impl<'a> Iterator for ClassBytesIter<'a> { /// /// The range is closed. 
That is, the start and end of the range are included /// in the range. -#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord)] +#[derive(Clone, Copy, Default, Eq, PartialEq, PartialOrd, Ord, Hash)] pub struct ClassBytesRange { start: u8, end: u8, diff --git a/regex-syntax/src/unicode.rs b/regex-syntax/src/unicode.rs index 70d5954b7..ac2e56c51 100644 --- a/regex-syntax/src/unicode.rs +++ b/regex-syntax/src/unicode.rs @@ -1,5 +1,6 @@ use std::error; use std::fmt; +use std::mem::size_of; use std::result; use crate::hir; @@ -78,38 +79,82 @@ impl fmt::Display for UnicodeWordError { /// to, since there is some cost to fetching the equivalence class. /// /// This returns an error if the Unicode case folding tables are not available. +#[allow(dead_code)] pub fn simple_fold( c: char, ) -> FoldResult<result::Result<impl Iterator<Item = char>, Option<char>>> { + match optimised_fold(0, c) { + Ok(Ok((iter, _))) => Ok(Ok(iter)), + Ok(Err(Some((c, _)))) => Ok(Err(Some(c))), + Ok(Err(None)) => Ok(Err(None)), + Err(e) => Err(e), + } +} + +pub fn optimised_fold( + start: usize, + c: char, +) -> FoldResult< + result::Result<(impl Iterator<Item = char>, usize), Option<(char, usize)>>, +> { #[cfg(not(feature = "unicode-case"))] fn imp( + _: usize, _: char, - ) -> FoldResult<result::Result<impl Iterator<Item = char>, Option<char>>> - { + ) -> FoldResult< + result::Result< + (impl Iterator<Item = char>, usize), + Option<(char, usize)>, + >, + > { use std::option::IntoIter; - Err::<result::Result<IntoIter<char>, _>, _>(CaseFoldError(())) + Err::<result::Result<(IntoIter<char>, usize), _>, _>(CaseFoldError(())) } #[cfg(feature = "unicode-case")] fn imp( + start: usize, c: char, - ) -> FoldResult<result::Result<impl Iterator<Item = char>, Option<char>>> - { + ) -> FoldResult< + result::Result< + (impl Iterator<Item = char>, usize), + Option<(char, usize)>, + >, + > { use crate::unicode_tables::case_folding_simple::CASE_FOLDING_SIMPLE; + // this is the greatest number of steps 
before we are guaranteed to find our value + const DEPTH_MAX: usize = size_of::<usize>() * 8 + - CASE_FOLDING_SIMPLE.len().leading_zeros() as usize + + 1; + + // First try a linear scan of at most DEPTH_MAX entries from `start`: a caller that passes + // the index from its previous lookup is probably asking about a nearby entry. + for (i, &(other, foldings)) in + CASE_FOLDING_SIMPLE[start..].iter().take(DEPTH_MAX).enumerate() + { + if other == c { + return Ok(Ok((foldings.iter().copied(), start + i + 1))); + } else if other > c { + return Ok(Err(Some((other, start + i)))); + } + } + if start + DEPTH_MAX >= CASE_FOLDING_SIMPLE.len() { + return Ok(Err(None)); + } Ok(CASE_FOLDING_SIMPLE .binary_search_by_key(&c, |&(c1, _)| c1) - .map(|i| CASE_FOLDING_SIMPLE[i].1.iter().copied()) + .map(|i| (CASE_FOLDING_SIMPLE[i].1.iter().copied(), i + 1)) .map_err(|i| { if i >= CASE_FOLDING_SIMPLE.len() { None } else { - Some(CASE_FOLDING_SIMPLE[i].0) + Some((CASE_FOLDING_SIMPLE[i].0, i)) } })) } - imp(c) + imp(start, c) } /// Returns true if and only if the given (inclusive) range contains at least