diff --git a/arrow/benches/equal.rs b/arrow/benches/equal.rs index af535506e86d..f54aff1b5cc7 100644 --- a/arrow/benches/equal.rs +++ b/arrow/benches/equal.rs @@ -20,6 +20,7 @@ #[macro_use] extern crate criterion; +use arrow::compute::eq_utf8_scalar; use criterion::Criterion; extern crate arrow; @@ -31,6 +32,10 @@ fn bench_equal>(arr_a: &A) { criterion::black_box(arr_a == arr_a); } +fn bench_equal_utf8_scalar(arr_a: &GenericStringArray, right: &str) { + criterion::black_box(eq_utf8_scalar(arr_a, right).unwrap()); +} + fn add_benchmark(c: &mut Criterion) { let arr_a = create_primitive_array::(512, 0.0); c.bench_function("equal_512", |b| b.iter(|| bench_equal(&arr_a))); @@ -41,6 +46,11 @@ fn add_benchmark(c: &mut Criterion) { let arr_a = create_string_array::(512, 0.0); c.bench_function("equal_string_512", |b| b.iter(|| bench_equal(&arr_a))); + let arr_a = create_string_array::(512, 0.0); + c.bench_function("equal_string_scalar_empty_512", |b| { + b.iter(|| bench_equal_utf8_scalar(&arr_a, "")) + }); + let arr_a_nulls = create_string_array::(512, 0.5); c.bench_function("equal_string_nulls_512", |b| { b.iter(|| bench_equal(&arr_a_nulls)) diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs index 5a79c2e82df1..d4eb5a3e1d2b 100644 --- a/arrow/src/compute/kernels/comparison.rs +++ b/arrow/src/compute/kernels/comparison.rs @@ -233,61 +233,35 @@ pub fn like_utf8( } #[inline] -fn like_scalar<'a, L: ArrayAccessor>( +fn like_scalar_op<'a, F: Fn(bool) -> bool, L: ArrayAccessor>( left: L, right: &str, + op: F, ) -> Result { - let null_bit_buffer = left.data().null_buffer().cloned(); - let bytes = bit_util::ceil(left.len(), 8); - let mut bool_buf = MutableBuffer::from_len_zeroed(bytes); - let bool_slice = bool_buf.as_slice_mut(); - if !right.contains(is_like_pattern) { // fast path, can use equals - for i in 0..left.len() { - unsafe { - if left.value_unchecked(i) == right { - bit_util::set_bit(bool_slice, i); - } - } - } + compare_op_scalar(left, |item| op(item == right)) } else if right.ends_with('%') && !right.ends_with("\\%") && !right[..right.len() - 1].contains(is_like_pattern) { // fast path, can use starts_with let starts_with = &right[..right.len() - 1]; - for i in 0..left.len() { - unsafe { - if left.value_unchecked(i).starts_with(starts_with) { - bit_util::set_bit(bool_slice, i); - } - } - } + + compare_op_scalar(left, |item| op(item.starts_with(starts_with))) } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { // fast path, can use ends_with let ends_with = &right[1..]; - for i in 0..left.len() { - unsafe { - if left.value_unchecked(i).ends_with(ends_with) { - bit_util::set_bit(bool_slice, i); - } - } - } + compare_op_scalar(left, |item| op(item.ends_with(ends_with))) } else if right.starts_with('%') && right.ends_with('%') + && !right.ends_with("\\%") && !right[1..right.len() - 1].contains(is_like_pattern) { - // fast path, can use contains let contains = &right[1..right.len() - 1]; - for i in 0..left.len() { - unsafe { - if left.value_unchecked(i).contains(contains) { - bit_util::set_bit(bool_slice, i); - } - } - } + + compare_op_scalar(left, |item| op(item.contains(contains))) } else { let re_pattern = replace_like_wildcards(right)?; let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { @@ -297,26 +271,16 @@ fn like_scalar<'a, L: ArrayAccessor>( )) })?; - for i in 0..left.len() { - let haystack = unsafe { left.value_unchecked(i) }; - if re.is_match(haystack) { - bit_util::set_bit(bool_slice, i); - } - } - }; + compare_op_scalar(left, |item| op(re.is_match(item))) + } +} - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - left.len(), - None, - null_bit_buffer, - 0, - vec![bool_buf.into()], - vec![], - ) - }; - Ok(BooleanArray::from(data)) +#[inline] +fn like_scalar<'a, L: ArrayAccessor>( + left: L, + right: &str, +) -> Result { + like_scalar_op(left, right, |x| x) } /// Perform SQL `left LIKE right` operation on [`StringArray`] / @@ -415,86 +379,7 @@ fn nlike_scalar<'a, L: ArrayAccessor>( left: L, right: &str, ) -> Result { - let null_bit_buffer = left.data().null_buffer().cloned(); - let bytes = bit_util::ceil(left.len(), 8); - let mut bool_buf = MutableBuffer::from_len_zeroed(bytes); - let bool_slice = bool_buf.as_slice_mut(); - - if !right.contains(is_like_pattern) { - // fast path, can use equals - for i in 0..left.len() { - unsafe { - if left.value_unchecked(i) != right { - bit_util::set_bit(bool_slice, i); - } - } - } - } else if right.ends_with('%') - && !right.ends_with("\\%") - && !right[..right.len() - 1].contains(is_like_pattern) - { - // fast path, can use starts_with - let starts_with = &right[..right.len() - 1]; - for i in 0..left.len() { - unsafe { - if !(left.value_unchecked(i).starts_with(starts_with)) { - bit_util::set_bit(bool_slice, i); - } - } - } - } else if right.starts_with('%') && !right[1..].contains(is_like_pattern) { - // fast path, can use ends_with - let ends_with = &right[1..]; - - for i in 0..left.len() { - unsafe { - if !(left.value_unchecked(i).ends_with(ends_with)) { - bit_util::set_bit(bool_slice, i); - } - } - } - } else if right.starts_with('%') - && right.ends_with('%') - && !right[1..right.len() - 1].contains(is_like_pattern) - { - // fast path, can use contains - let contains = &right[1..right.len() - 1]; - for i in 0..left.len() { - unsafe { - if !(left.value_unchecked(i).contains(contains)) { - bit_util::set_bit(bool_slice, i); - } - } - } - } else { - let re_pattern = replace_like_wildcards(right)?; - let re = Regex::new(&format!("^{}$", re_pattern)).map_err(|e| { - ArrowError::ComputeError(format!( - "Unable to build regex from LIKE pattern: {}", - e - )) - })?; - - for i in 0..left.len() { - let haystack = unsafe { left.value_unchecked(i) }; - if !re.is_match(haystack) { - bit_util::set_bit(bool_slice, i); - } - } - }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - left.len(), - None, - null_bit_buffer, - 0, - vec![bool_buf.into()], - vec![], - ) - }; - Ok(BooleanArray::from(data)) + like_scalar_op(left, right, |x| !x) } /// Perform SQL `left NOT LIKE right` operation on [`StringArray`] / @@ -966,11 +851,48 @@ pub fn eq_utf8( compare_op(left, right, |a, b| a == b) } +fn utf8_empty( + left: &GenericStringArray, +) -> Result { + let null_bit_buffer = left + .data() + .null_buffer() + .map(|b| b.bit_slice(left.offset(), left.len())); + + let buffer = unsafe { + MutableBuffer::from_trusted_len_iter_bool(left.value_offsets().windows(2).map( + |offset| { + if EQ { + offset[1].to_usize().unwrap() == offset[0].to_usize().unwrap() + } else { + offset[1].to_usize().unwrap() > offset[0].to_usize().unwrap() + } + }, + )) + }; + + let data = unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + left.len(), + None, + null_bit_buffer, + 0, + vec![Buffer::from(buffer)], + vec![], + ) + }; + Ok(BooleanArray::from(data)) +} + /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { + if right.is_empty() { + return utf8_empty::<_, true>(left); + } compare_op_scalar(left, |a| a == right) } @@ -1167,6 +1089,9 @@ pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { + if right.is_empty() { + return utf8_empty::<_, false>(left); + } compare_op_scalar(left, |a| a != right) } @@ -4324,13 +4249,22 @@ mod tests { #[test] fn test_utf8_eq_scalar_on_slice() { - let a = StringArray::from(vec![Some("hi"), None, Some("hello"), Some("world")]); - let a = a.slice(1, 3); + let a = StringArray::from( + vec![Some("hi"), None, Some("hello"), Some("world"), Some("")], + ); + let a = a.slice(1, 4); let a = as_string_array(&a); let a_eq = eq_utf8_scalar(a, "hello").unwrap(); assert_eq!( a_eq, - BooleanArray::from(vec![None, Some(true), Some(false)]) + BooleanArray::from(vec![None, Some(true), Some(false), Some(false)]) + ); + + let a_eq2 = eq_utf8_scalar(a, "").unwrap(); + + assert_eq!( + a_eq2, + BooleanArray::from(vec![None, Some(false), Some(false), Some(true)]) ); } @@ -4528,6 +4462,14 @@ mod tests { vec![true, false] ); + test_utf8_scalar!( + test_utf8_scalar_like_escape_contains, + vec!["ba%", "ba\\x"], + "%a\\%", + like_utf8_scalar, + vec![true, false] + ); + test_utf8!( test_utf8_scalar_ilike_regex, vec!["%%%"],