Skip to content

Commit

Permalink
Implement wasm32 port
Browse files Browse the repository at this point in the history
  • Loading branch information
marmeladema committed Sep 8, 2021
1 parent 69ddae2 commit f0d333a
Show file tree
Hide file tree
Showing 5 changed files with 640 additions and 333 deletions.
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@ repository = "https://github.com/cloudflare/sliceslice-rs"
license = "MIT"
keywords = ["search", "text", "string", "single", "simd"]

[lib]
crate-type = ["cdylib", "rlib"]

[dependencies]
memchr = "2.3"
seq-macro = "0.2"

[dev-dependencies]
cfg-if = "1"

[profile.release]
debug = true
314 changes: 314 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub mod x86;

/// Substring search implementations using wasm32 architecture features.
#[cfg(target_arch = "wasm32")]
pub mod wasm32;

mod bits;
mod memcmp;

Expand Down Expand Up @@ -106,6 +110,24 @@ impl Needle for Vec<u8> {
}
}

trait NeedleWithSize: Needle {
#[inline]
fn size(&self) -> usize {
if let Some(size) = Self::SIZE {
size
} else {
self.as_bytes().len()
}
}

#[inline]
fn is_empty(&self) -> bool {
self.size() == 0
}
}

impl<N: Needle + ?Sized> NeedleWithSize for N {}

/// Single-byte searcher using `memchr` for faster matching.
pub struct MemchrSearcher(u8);

Expand All @@ -131,6 +153,139 @@ impl MemchrSearcher {
}
}

/// Represents a generic SIMD register type.
trait Vector: Copy {
const LANES: usize;

unsafe fn set1_epi8(a: i8) -> Self;

unsafe fn loadu_si(a: *const u8) -> Self;

unsafe fn cmpeq_epi8(a: Self, b: Self) -> Self;

unsafe fn and_si(a: Self, b: Self) -> Self;

unsafe fn movemask_epi8(a: Self) -> i32;
}

/// Hash of the first and "last" bytes in the needle for use with the SIMD
/// algorithm implemented by `Searcher::vector_search_in`. As explained, any
/// byte can be chosen to represent the "last" byte of the hash to prevent
/// worst-case attacks.
struct VectorHash<V: Vector> {
first: V,
last: V,
}

impl<T: Vector, V: Vector + From<T>> From<&VectorHash<T>> for VectorHash<V> {
#[inline]
fn from(hash: &VectorHash<T>) -> Self {
Self {
first: V::from(hash.first),
last: V::from(hash.last),
}
}
}

impl<V: Vector> VectorHash<V> {
#[inline]
unsafe fn new(first: u8, last: u8) -> Self {
Self {
first: Vector::set1_epi8(first as i8),
last: Vector::set1_epi8(last as i8),
}
}
}

trait Searcher<N: NeedleWithSize + ?Sized> {
fn needle(&self) -> &N;

fn position(&self) -> usize;

#[inline]
unsafe fn vector_search_in_chunk<V: Vector>(
&self,
haystack: &[u8],
hash: &VectorHash<V>,
start: *const u8,
mask: i32,
) -> bool {
let first = Vector::loadu_si(start);
let last = Vector::loadu_si(start.add(self.position()));

let eq_first = Vector::cmpeq_epi8(hash.first, first);
let eq_last = Vector::cmpeq_epi8(hash.last, last);

let eq = Vector::and_si(eq_first, eq_last);
let mut eq = (Vector::movemask_epi8(eq) & mask) as u32;

let start = start as usize - haystack.as_ptr() as usize;
let chunk = haystack.as_ptr().add(start + 1);
let needle = self.needle().as_bytes().as_ptr().add(1);

while eq != 0 {
let chunk = chunk.add(eq.trailing_zeros() as usize);
let equal = match N::SIZE {
Some(0) => unreachable!(),
Some(1) => memcmp::specialized::<0>(chunk, needle),
Some(2) => memcmp::specialized::<1>(chunk, needle),
Some(3) => memcmp::specialized::<2>(chunk, needle),
Some(4) => memcmp::specialized::<3>(chunk, needle),
Some(5) => memcmp::specialized::<4>(chunk, needle),
Some(6) => memcmp::specialized::<5>(chunk, needle),
Some(7) => memcmp::specialized::<6>(chunk, needle),
Some(8) => memcmp::specialized::<7>(chunk, needle),
Some(9) => memcmp::specialized::<8>(chunk, needle),
Some(10) => memcmp::specialized::<9>(chunk, needle),
Some(11) => memcmp::specialized::<10>(chunk, needle),
Some(12) => memcmp::specialized::<11>(chunk, needle),
Some(13) => memcmp::specialized::<12>(chunk, needle),
Some(14) => memcmp::specialized::<13>(chunk, needle),
Some(15) => memcmp::specialized::<14>(chunk, needle),
Some(16) => memcmp::specialized::<15>(chunk, needle),
_ => memcmp::generic(chunk, needle, self.needle().size() - 1),
};
if equal {
return true;
}

eq = bits::clear_leftmost_set(eq);
}

false
}

#[inline]
unsafe fn vector_search_in<V: Vector>(
&self,
haystack: &[u8],
end: usize,
hash: &VectorHash<V>,
) -> bool {
debug_assert!(haystack.len() >= self.needle().size());

let mut chunks = haystack[..end].chunks_exact(V::LANES);
// while let Some(chunk) = chunks.next() {
for chunk in &mut chunks {
if self.vector_search_in_chunk(haystack, hash, chunk.as_ptr(), -1) {
return true;
}
}

let remainder = chunks.remainder().len();
if remainder > 0 {
let start = haystack.as_ptr().add(end - V::LANES);
let mask = -1 << (V::LANES - remainder);

if self.vector_search_in_chunk(haystack, hash, start, mask) {
return true;
}
}

false
}
}

#[cfg(test)]
mod tests {
use super::{MemchrSearcher, Needle};
Expand Down Expand Up @@ -200,4 +355,163 @@ mod tests {

assert_eq!(<&[u8] as Needle>::SIZE, None);
}

fn search(haystack: &[u8], needle: &[u8]) -> bool {
let result = haystack
.windows(needle.len())
.any(|window| window == needle);

for position in 0..needle.len() {
cfg_if::cfg_if! {
if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
use crate::x86::{Avx2Searcher, DynamicAvx2Searcher};

let searcher = unsafe { Avx2Searcher::with_position(needle, position) };
assert_eq!(unsafe { searcher.search_in(haystack) }, result);

let searcher = unsafe { DynamicAvx2Searcher::with_position(needle, position) };
assert_eq!(unsafe { searcher.search_in(haystack) }, result);
} else if #[cfg(target_arch = "wasm32")] {
use crate::wasm32::Wasm32Searcher;

let searcher = unsafe { Wasm32Searcher::with_position(needle, position) };
assert_eq!(unsafe { searcher.search_in(haystack) }, result);
} else {
compile_error!("Unsupported architecture");
}
}
}

result
}

#[test]
fn search_same() {
assert!(search(b"x", b"x"));

assert!(search(b"xy", b"xy"));

assert!(search(b"foo", b"foo"));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit",
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus"
));
}

#[test]
fn search_different() {
assert!(!search(b"x", b"y"));

assert!(!search(b"xy", b"xz"));

assert!(!search(b"bar", b"foo"));

assert!(!search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit",
b"foo"
));

assert!(!search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"foo"
));

assert!(!search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"foo bar baz qux quux quuz corge grault garply waldo fred plugh xyzzy thud"
));
}

#[test]
fn search_prefix() {
assert!(search(b"xy", b"x"));

assert!(search(b"foobar", b"foo"));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit",
b"Lorem"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"Lorem"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit"
));
}

#[test]
fn search_suffix() {
assert!(search(b"xy", b"y"));

assert!(search(b"foobar", b"bar"));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit",
b"elit"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"purus"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"Aliquam iaculis fringilla mi, nec aliquet purus"
));
}

#[test]
fn search_multiple() {
assert!(search(b"xx", b"x"));

assert!(search(b"xyxy", b"xy"));

assert!(search(b"foobarfoo", b"foo"));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit",
b"it"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"conse"
));
}

#[test]
fn search_middle() {
assert!(search(b"xyz", b"y"));

assert!(search(b"wxyz", b"xy"));

assert!(search(b"foobarfoo", b"bar"));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit",
b"consectetur"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"orci"
));

assert!(search(
b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Maecenas commodo posuere orci a consectetur. Ut mattis turpis ut auctor consequat. Aliquam iaculis fringilla mi, nec aliquet purus",
b"Maecenas commodo posuere orci a consectetur"
));
}
}
Loading

0 comments on commit f0d333a

Please sign in to comment.