Skip to content

Commit

Permalink
ZeroTrie: Refactor helpers to not require as many Options (#4382)
Browse files Browse the repository at this point in the history
  • Loading branch information
sffc authored Dec 1, 2023
1 parent db3884c commit 75848da
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 117 deletions.
16 changes: 7 additions & 9 deletions experimental/zerotrie/src/byte_phf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,12 @@ where
if n == 0 {
return None;
}
let (qq, eks) = debug_split_at(buffer, n)?;
let (qq, eks) = buffer.debug_split_at(n);
debug_assert_eq!(qq.len(), eks.len());
let q = debug_get(qq, f1(key, *p, n))?;
let l2 = f2(key, q, n);
let ek = debug_get(eks, l2)?;
if ek == key {
let q = debug_unwrap!(qq.get(f1(key, *p, n)), return None);
let l2 = f2(key, *q, n);
let ek = debug_unwrap!(eks.get(l2), return None);
if *ek == key {
Some(l2)
} else {
None
Expand All @@ -232,9 +232,7 @@ where
/// Get the keys as a byte slice, in the order in which they are stored in the map.
pub fn keys(&self) -> &[u8] {
let n = self.num_items();
debug_split_at(self.0.as_ref(), 1 + n)
.map(|s| s.1)
.unwrap_or(&[])
self.0.as_ref().debug_split_at(1 + n).1
}
/// Diagnostic function that returns `p` and the maximum value of `q`
#[cfg(test)]
Expand All @@ -244,7 +242,7 @@ where
if n == 0 {
return None;
}
let (qq, _) = debug_split_at(buffer, n)?;
let (qq, _) = buffer.debug_split_at(n);
Some((*p, *qq.iter().max().unwrap()))
}
/// Returns the map as bytes. The map can be recovered with [`Self::from_store`]
Expand Down
113 changes: 73 additions & 40 deletions experimental/zerotrie/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,88 @@
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).

use core::ops::Range;
/// Extension trait providing non-panicking variants of `slice::split_at`.
///
/// Both methods take the split position `mid` and differ only in how they
/// report a `mid` that lies past the end of the data.
pub(crate) trait MaybeSplitAt<T> {
/// Like slice::split_at but returns an Option instead of panicking
/// if the index is out of range.
///
/// The `[T]` implementation in this file returns `None` when
/// `mid > self.len()` and `Some(self.split_at(mid))` otherwise.
fn maybe_split_at(&self, mid: usize) -> Option<(&Self, &Self)>;
/// Like slice::split_at but debug-panics and returns an empty second slice
/// if the index is out of range.
///
/// In the `[T]` implementation in this file, an out-of-range `mid` first hits
/// `debug_assert!(false, ..)` and then falls back to `(self, &[])`, i.e. the
/// whole input ends up in the first half.
fn debug_split_at(&self, mid: usize) -> (&Self, &Self);
}

/// Like slice::split_at but returns an Option instead of panicking.
///
/// Debug-panics if `mid` is out of range.
#[inline]
pub(crate) fn debug_split_at(slice: &[u8], mid: usize) -> Option<(&[u8], &[u8])> {
if mid > slice.len() {
debug_assert!(false, "debug_split_at: index expected to be in range");
None
} else {
// Note: We're trusting the compiler to inline this and remove the assertion
// hiding on the top of slice::split_at: `assert(mid <= self.len())`
Some(slice.split_at(mid))
impl<T> MaybeSplitAt<T> for [T] {
    /// Splits at `mid`, yielding `None` instead of panicking when `mid` is
    /// greater than the slice length.
    #[inline]
    fn maybe_split_at(&self, mid: usize) -> Option<(&Self, &Self)> {
        // The bound is checked here so that, once inlined, the compiler can
        // drop the `assert(mid <= self.len())` at the top of slice::split_at.
        (mid <= self.len()).then(|| self.split_at(mid))
    }

    /// Splits at `mid`. An out-of-range `mid` is treated as a bug: it trips a
    /// debug assertion, and when assertions are compiled out the whole slice
    /// is returned as the first half with an empty second half.
    #[inline]
    fn debug_split_at(&self, mid: usize) -> (&Self, &Self) {
        if mid <= self.len() {
            // Same reasoning as above: the explicit check lets the compiler
            // remove the assertion inside slice::split_at after inlining.
            return self.split_at(mid);
        }
        debug_assert!(false, "debug_split_at: index expected to be in range");
        (self, &[])
    }
}

/// Like slice::split_at but returns an Option instead of panicking.
#[inline]
pub(crate) fn maybe_split_at(slice: &[u8], mid: usize) -> Option<(&[u8], &[u8])> {
if mid > slice.len() {
None
} else {
// Note: We're trusting the compiler to inline this and remove the assertion
// hiding on the top of slice::split_at: `assert(mid <= self.len())`
Some(slice.split_at(mid))
}
/// Extension trait for unwrapping a value that is expected to always be present,
/// with a caller-supplied fallback instead of an unconditional panic.
pub(crate) trait DebugUnwrapOr<T> {
/// Unwraps the option or panics in debug mode, returning the `gigo_value`
///
/// The `Option<T>` implementation in this file returns the contained value
/// for `Some`. For `None` it hits `debug_assert!(false, ..)` and then
/// evaluates to `gigo_value`, so the fallback is only observable when debug
/// assertions are disabled.
/// NOTE(review): "gigo" presumably stands for "garbage in, garbage out" — confirm.
fn debug_unwrap_or(self, gigo_value: T) -> T;
}

/// Gets the item at the specified index, panicking in debug mode if it is not there.
#[inline]
pub(crate) fn debug_get(slice: &[u8], index: usize) -> Option<u8> {
match slice.get(index) {
Some(x) => Some(*x),
None => {
debug_assert!(false, "debug_get: index expected to be in range");
None
impl<T> DebugUnwrapOr<T> for Option<T> {
    /// Yields the contained value. A `None` is treated as a bug: it trips a
    /// debug assertion, and when assertions are compiled out `gigo_value` is
    /// returned in its place.
    #[inline]
    fn debug_unwrap_or(self, gigo_value: T) -> T {
        if let Some(value) = self {
            return value;
        }
        debug_assert!(false, "debug_unwrap_or called on a None value");
        gigo_value
    }
}

/// Gets the range between the specified indices, panicking in debug mode if not in bounds.
#[inline]
pub(crate) fn debug_get_range(slice: &[u8], range: Range<usize>) -> Option<&[u8]> {
match slice.get(range) {
Some(x) => Some(x),
None => {
debug_assert!(false, "debug_get_range: indices expected to be in range");
None
/// Unwraps an `Option`, leaving the enclosing function or loop early on `None`.
///
/// A `None` is treated as a bug: `debug_assert!(false, ..)` fires first, so
/// builds with debug assertions panic. When assertions are compiled out, the
/// requested exit (`return` or `break`) is taken instead.
///
/// Accepted forms:
///
/// - `debug_unwrap!(expr, return retval, "fmt", args..)`: on `None`, assert
///   with the given message, then `return retval`.
/// - `debug_unwrap!(expr, return retval)`: same, with the message `"invalid trie"`.
/// - `debug_unwrap!(expr, break, "fmt", args..)`: on `None`, assert with the
///   given message, then `break` out of the enclosing loop.
/// - `debug_unwrap!(expr, break)`: same, with the message `"invalid trie"`.
/// - `debug_unwrap!(expr, "fmt", args..)`: shorthand for `return ()` with a message.
/// - `debug_unwrap!(expr)`: shorthand for `return ()` with the default message.
macro_rules! debug_unwrap {
    ($expr:expr, return $retval:expr, $($arg:tt)+) => {
        match $expr {
            Some(x) => x,
            None => {
                debug_assert!(false, $($arg)+);
                return $retval;
            }
        }
    };
    ($expr:expr, return $retval:expr) => {
        debug_unwrap!($expr, return $retval, "invalid trie")
    };
    ($expr:expr, break, $($arg:tt)+) => {
        match $expr {
            Some(x) => x,
            None => {
                debug_assert!(false, $($arg)+);
                break;
            }
        }
    };
    ($expr:expr, break) => {
        debug_unwrap!($expr, break, "invalid trie")
    };
    // The message-only and bare forms must come last: their patterns would
    // otherwise also accept the `return ..` / `break` forms above.
    ($expr:expr, $($arg:tt)+) => {
        debug_unwrap!($expr, return (), $($arg)+)
    };
    ($expr:expr) => {
        debug_unwrap!($expr, return ())
    };
}

pub(crate) use debug_unwrap;
1 change: 1 addition & 0 deletions experimental/zerotrie/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ extern crate alloc;
mod builder;
mod byte_phf;
mod error;
#[macro_use]
mod helpers;
mod reader;
#[cfg(feature = "serde")]
Expand Down
74 changes: 37 additions & 37 deletions experimental/zerotrie/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,17 @@ use alloc::string::String;
/// - `n` = the number of items in the offset table
/// - `w` = the width of the offset table items minus one
#[inline]
fn get_branch(mut trie: &[u8], i: usize, n: usize, mut w: usize) -> Option<&[u8]> {
fn get_branch(mut trie: &[u8], i: usize, n: usize, mut w: usize) -> &[u8] {
let mut p = 0usize;
let mut q = 0usize;
loop {
let indices;
(indices, trie) = debug_split_at(trie, n - 1)?;
(indices, trie) = trie.debug_split_at(n - 1);
p = (p << 8)
+ if i == 0 {
0
} else {
debug_get(indices, i - 1)? as usize
*indices.get(i - 1).debug_unwrap_or(&0) as usize
};
q = match indices.get(i) {
Some(x) => (q << 8) + *x as usize,
Expand All @@ -240,24 +240,24 @@ fn get_branch(mut trie: &[u8], i: usize, n: usize, mut w: usize) -> Option<&[u8]
}
w -= 1;
}
debug_get_range(trie, p..q)
trie.get(p..q).debug_unwrap_or(&[])
}

/// Version of [`get_branch()`] specialized for the case `w == 0` for performance
#[inline]
fn get_branch_w0(mut trie: &[u8], i: usize, n: usize) -> Option<&[u8]> {
fn get_branch_w0(mut trie: &[u8], i: usize, n: usize) -> &[u8] {
let indices;
(indices, trie) = debug_split_at(trie, n - 1)?;
(indices, trie) = trie.debug_split_at(n - 1);
let p = if i == 0 {
0
} else {
debug_get(indices, i - 1)? as usize
*indices.get(i - 1).debug_unwrap_or(&0) as usize
};
let q = match indices.get(i) {
Some(x) => *x as usize,
None => trie.len(),
};
debug_get_range(trie, p..q)
trie.get(p..q).debug_unwrap_or(&[])
}

/// The node type. See the module-level docs for more explanation of the four node types.
Expand Down Expand Up @@ -311,8 +311,8 @@ pub fn get_bsearch_only(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
let byte_type = byte_type(*b);
(x, trie) = match byte_type {
NodeType::Ascii => (0, trie),
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie)?,
NodeType::Branch => read_varint_meta2(*b, trie)?,
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie),
NodeType::Branch => read_varint_meta2(*b, trie),
};
if let Some((c, temp)) = ascii.split_first() {
if matches!(byte_type, NodeType::Ascii) {
Expand All @@ -331,8 +331,8 @@ pub fn get_bsearch_only(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
}
if matches!(byte_type, NodeType::Span) {
let (trie_span, ascii_span);
(trie_span, trie) = debug_split_at(trie, x)?;
(ascii_span, ascii) = maybe_split_at(ascii, x)?;
(trie_span, trie) = trie.debug_split_at(x);
(ascii_span, ascii) = ascii.maybe_split_at(x)?;
if trie_span == ascii_span {
// Matched a byte span
continue;
Expand All @@ -348,13 +348,13 @@ pub fn get_bsearch_only(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
let w = w & 0x3;
let x = if x == 0 { 256 } else { x };
// Always use binary search
(search, trie) = debug_split_at(trie, x)?;
(search, trie) = trie.debug_split_at(x);
i = search.binary_search(c).ok()?;
trie = if w == 0 {
get_branch_w0(trie, i, x)
} else {
get_branch(trie, i, x, w)
}?;
};
ascii = temp;
continue;
} else {
Expand All @@ -375,8 +375,8 @@ pub fn get_phf_limited(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
let byte_type = byte_type(*b);
(x, trie) = match byte_type {
NodeType::Ascii => (0, trie),
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie)?,
NodeType::Branch => read_varint_meta2(*b, trie)?,
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie),
NodeType::Branch => read_varint_meta2(*b, trie),
};
if let Some((c, temp)) = ascii.split_first() {
if matches!(byte_type, NodeType::Ascii) {
Expand All @@ -395,8 +395,8 @@ pub fn get_phf_limited(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
}
if matches!(byte_type, NodeType::Span) {
let (trie_span, ascii_span);
(trie_span, trie) = debug_split_at(trie, x)?;
(ascii_span, ascii) = maybe_split_at(ascii, x)?;
(trie_span, trie) = trie.debug_split_at(x);
(ascii_span, ascii) = ascii.maybe_split_at(x)?;
if trie_span == ascii_span {
// Matched a byte span
continue;
Expand All @@ -413,18 +413,18 @@ pub fn get_phf_limited(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
let x = if x == 0 { 256 } else { x };
if x < 16 {
// binary search
(search, trie) = debug_split_at(trie, x)?;
(search, trie) = trie.debug_split_at(x);
i = search.binary_search(c).ok()?;
} else {
// phf
(search, trie) = debug_split_at(trie, x * 2 + 1)?;
(search, trie) = trie.debug_split_at(x * 2 + 1);
i = PerfectByteHashMap::from_store(search).get(*c)?;
}
trie = if w == 0 {
get_branch_w0(trie, i, x)
} else {
get_branch(trie, i, x, w)
}?;
};
ascii = temp;
continue;
} else {
Expand All @@ -445,8 +445,8 @@ pub fn get_phf_extended(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
let byte_type = byte_type(*b);
(x, trie) = match byte_type {
NodeType::Ascii => (0, trie),
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie)?,
NodeType::Branch => read_varint_meta2(*b, trie)?,
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie),
NodeType::Branch => read_varint_meta2(*b, trie),
};
if let Some((c, temp)) = ascii.split_first() {
if matches!(byte_type, NodeType::Ascii) {
Expand All @@ -465,8 +465,8 @@ pub fn get_phf_extended(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
}
if matches!(byte_type, NodeType::Span) {
let (trie_span, ascii_span);
(trie_span, trie) = debug_split_at(trie, x)?;
(ascii_span, ascii) = maybe_split_at(ascii, x)?;
(trie_span, trie) = trie.debug_split_at(x);
(ascii_span, ascii) = ascii.maybe_split_at(x)?;
if trie_span == ascii_span {
// Matched a byte span
continue;
Expand All @@ -480,18 +480,18 @@ pub fn get_phf_extended(mut trie: &[u8], mut ascii: &[u8]) -> Option<usize> {
let x = if x == 0 { 256 } else { x };
if x < 16 {
// binary search
(search, trie) = debug_split_at(trie, x)?;
(search, trie) = trie.debug_split_at(x);
i = search.binary_search(c).ok()?;
} else {
// phf
(search, trie) = debug_split_at(trie, x * 2 + 1)?;
(search, trie) = trie.debug_split_at(x * 2 + 1);
i = PerfectByteHashMap::from_store(search).get(*c)?;
}
trie = if w == 0 {
get_branch_w0(trie, i, x)
} else {
get_branch(trie, i, x, w)
}?;
};
ascii = temp;
continue;
} else {
Expand Down Expand Up @@ -554,11 +554,11 @@ impl<'a> Iterator for ZeroTrieIterator<'a> {
}
(x, trie) = match byte_type {
NodeType::Ascii => (0, trie),
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie)?,
NodeType::Branch => read_varint_meta2(*b, trie)?,
NodeType::Span | NodeType::Value => read_varint_meta3(*b, trie),
NodeType::Branch => read_varint_meta2(*b, trie),
};
if matches!(byte_type, NodeType::Span) {
(span, trie) = debug_split_at(trie, x)?;
(span, trie) = trie.debug_split_at(x);
string.extend(span);
continue;
}
Expand All @@ -578,19 +578,19 @@ impl<'a> Iterator for ZeroTrieIterator<'a> {
}
let byte = if x < 16 || !self.use_phf {
// binary search
(search, trie) = debug_split_at(trie, x)?;
debug_get(search, branch_idx)?
(search, trie) = trie.debug_split_at(x);
debug_unwrap!(search.get(branch_idx), return None)
} else {
// phf
(search, trie) = debug_split_at(trie, x * 2 + 1)?;
debug_get(search, branch_idx + x + 1)?
(search, trie) = trie.debug_split_at(x * 2 + 1);
debug_unwrap!(search.get(branch_idx + x + 1), return None)
};
string.push(byte);
string.push(*byte);
trie = if w == 0 {
get_branch_w0(trie, branch_idx, x)
} else {
get_branch(trie, branch_idx, x, w)
}?;
};
branch_idx = 0;
}
}
Expand Down
Loading

0 comments on commit 75848da

Please sign in to comment.