From 2d3b0c2e503d9a4c2e8a5d427a6a55f9fb5faf29 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Tue, 18 Jun 2024 15:45:22 +0200 Subject: [PATCH] feat: Implement general array equality checks (#17043) --- crates/polars-arrow/src/bitmap/bitmap_ops.rs | 16 ++ crates/polars-arrow/src/bitmap/immutable.rs | 18 ++ .../polars-compute/src/arithmetic/signed.rs | 2 +- .../polars-compute/src/arithmetic/unsigned.rs | 2 +- .../polars-compute/src/comparisons/array.rs | 156 ++++--------- .../polars-compute/src/comparisons/binary.rs | 114 ++++++++++ .../polars-compute/src/comparisons/boolean.rs | 72 ++++++ .../src/comparisons/dictionary.rs | 75 +++++++ .../src/comparisons/dyn_array.rs | 85 +++++++ crates/polars-compute/src/comparisons/list.rs | 86 ++++++++ crates/polars-compute/src/comparisons/mod.rs | 62 ++++-- crates/polars-compute/src/comparisons/null.rs | 24 +- .../polars-compute/src/comparisons/scalar.rs | 202 ++--------------- crates/polars-compute/src/comparisons/simd.rs | 74 ++++--- .../polars-compute/src/comparisons/struct_.rs | 208 +++++++----------- crates/polars-compute/src/comparisons/utf8.rs | 53 +++++ crates/polars-compute/src/comparisons/view.rs | 101 +++++---- .../src/chunked_array/comparison/mod.rs | 4 +- .../src/chunked_array/comparison/scalar.rs | 2 +- py-polars/tests/unit/io/test_parquet.py | 29 +-- .../unit/testing/test_assert_series_equal.py | 2 +- 21 files changed, 825 insertions(+), 562 deletions(-) create mode 100644 crates/polars-compute/src/comparisons/binary.rs create mode 100644 crates/polars-compute/src/comparisons/boolean.rs create mode 100644 crates/polars-compute/src/comparisons/dictionary.rs create mode 100644 crates/polars-compute/src/comparisons/dyn_array.rs create mode 100644 crates/polars-compute/src/comparisons/list.rs create mode 100644 crates/polars-compute/src/comparisons/utf8.rs diff --git a/crates/polars-arrow/src/bitmap/bitmap_ops.rs b/crates/polars-arrow/src/bitmap/bitmap_ops.rs index 4392bdf25ee6..9e5ac502e6b5 100644 --- a/crates/polars-arrow/src/bitmap/bitmap_ops.rs +++ b/crates/polars-arrow/src/bitmap/bitmap_ops.rs @@ -300,6 +300,22 @@ pub fn intersects_with_mut(lhs: &MutableBitmap, rhs: &MutableBitmap) -> bool { ) } +/// Compute `out[i] = if selector[i] { truthy[i] } else { falsy }`. +pub fn select_constant(selector: &Bitmap, truthy: &Bitmap, falsy: bool) -> Bitmap { + let falsy_mask: u64 = if falsy { + 0xFFFF_FFFF_FFFF_FFFF + } else { + 0x0000_0000_0000_0000 + }; + + binary(selector, truthy, |s, t| (s & t) | (!s & falsy_mask)) +} + +/// Compute `out[i] = if selector[i] { truthy[i] } else { falsy[i] }`. +pub fn select(selector: &Bitmap, truthy: &Bitmap, falsy: &Bitmap) -> Bitmap { + ternary(selector, truthy, falsy, |s, t, f| (s & t) | (!s & f)) +} + impl PartialEq for Bitmap { fn eq(&self, other: &Self) -> bool { eq(self, other) diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index 0b85af625c14..2d7f6a4a776a 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -485,6 +485,24 @@ impl Bitmap { pub fn num_intersections_with(&self, other: &Self) -> usize { num_intersections_with(self, other) } + + /// Select between `truthy` and `falsy` based on `self`. + /// + /// This essentially performs: + /// + /// `out[i] = if self[i] { truthy[i] } else { falsy[i] }` + pub fn select(&self, truthy: &Self, falsy: &Self) -> Self { + super::bitmap_ops::select(self, truthy, falsy) + } + + /// Select between `truthy` and constant `falsy` based on `self`. + /// + /// This essentially performs: + /// + /// `out[i] = if self[i] { truthy[i] } else { falsy }` + pub fn select_constant(&self, truthy: &Self, falsy: bool) -> Self { + super::bitmap_ops::select_constant(self, truthy, falsy) + } } impl> From

for Bitmap { diff --git a/crates/polars-compute/src/arithmetic/signed.rs b/crates/polars-compute/src/arithmetic/signed.rs index a4af597d6639..a19f6b231526 100644 --- a/crates/polars-compute/src/arithmetic/signed.rs +++ b/crates/polars-compute/src/arithmetic/signed.rs @@ -5,7 +5,7 @@ use strength_reduce::*; use super::PrimitiveArithmeticKernelImpl; use crate::arity::{prim_binary_values, prim_unary_values}; -use crate::comparisons::TotalOrdKernel; +use crate::comparisons::TotalEqKernel; macro_rules! impl_signed_arith_kernel { ($T:ty, $StrRed:ty) => { diff --git a/crates/polars-compute/src/arithmetic/unsigned.rs b/crates/polars-compute/src/arithmetic/unsigned.rs index 283adc48b21c..2ae40332e820 100644 --- a/crates/polars-compute/src/arithmetic/unsigned.rs +++ b/crates/polars-compute/src/arithmetic/unsigned.rs @@ -4,7 +4,7 @@ use strength_reduce::*; use super::PrimitiveArithmeticKernelImpl; use crate::arity::{prim_binary_values, prim_unary_values}; -use crate::comparisons::TotalOrdKernel; +use crate::comparisons::TotalEqKernel; macro_rules! impl_unsigned_arith_kernel { ($T:ty, $StrRed:ty) => { diff --git a/crates/polars-compute/src/comparisons/array.rs b/crates/polars-compute/src/comparisons/array.rs index a965a3a3a222..b981a50b3547 100644 --- a/crates/polars-compute/src/comparisons/array.rs +++ b/crates/polars-compute/src/comparisons/array.rs @@ -1,12 +1,10 @@ -use arrow::array::{ - Array, BinaryViewArray, BooleanArray, FixedSizeListArray, NullArray, PrimitiveArray, - StructArray, Utf8ViewArray, -}; +use arrow::array::{Array, FixedSizeListArray}; use arrow::bitmap::utils::count_zeros; use arrow::bitmap::Bitmap; use arrow::datatypes::ArrowDataType; -use crate::comparisons::TotalOrdKernel; +use super::TotalEqKernel; +use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; /// Condenses a bitmap of n * width elements into one with n elements. /// @@ -16,114 +14,68 @@ fn agg_array_bitmap(bm: Bitmap, width: usize, true_zero_count: F) -> Bitmap where F: Fn(usize) -> bool, { - assert!(width > 0 && bm.len() % width == 0); - let (slice, offset, _len) = bm.as_slice(); - - (0..bm.len() / width) - .map(|i| true_zero_count(count_zeros(slice, offset + i * width, width))) - .collect() + if bm.len() == 1 { + bm + } else { + assert!(width > 0 && bm.len() % width == 0); + + let (slice, offset, _len) = bm.as_slice(); + (0..bm.len() / width) + .map(|i| true_zero_count(count_zeros(slice, offset + i * width, width))) + .collect() + } } -macro_rules! call_binary { - ($T:ty, $lhs:expr, $rhs:expr, $op:path) => {{ - let lhs: &$T = $lhs.as_any().downcast_ref().unwrap(); - let rhs: &$T = $rhs.as_any().downcast_ref().unwrap(); - $op(lhs, rhs) - }}; -} +impl TotalEqKernel for FixedSizeListArray { + type Scalar = Box; -macro_rules! compare { - ($lhs:expr, $rhs:expr, $wrong_width:expr, $op:path) => {{ - let lhs = $lhs; - let rhs = $rhs; - assert_eq!(lhs.len(), rhs.len()); - let ArrowDataType::FixedSizeList(lhs_type, lhs_width) = lhs.data_type().to_logical_type() + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + // Nested comparison always done with eq_missing, propagating doesn't + // make any sense. + + assert_eq!(self.len(), other.len()); + let ArrowDataType::FixedSizeList(self_type, self_width) = + self.data_type().to_logical_type() else { panic!("array comparison called with non-array type"); }; - let ArrowDataType::FixedSizeList(rhs_type, rhs_width) = rhs.data_type().to_logical_type() + let ArrowDataType::FixedSizeList(other_type, other_width) = + other.data_type().to_logical_type() else { panic!("array comparison called with non-array type"); }; - assert_eq!(lhs_type.data_type(), rhs_type.data_type()); + assert_eq!(self_type.data_type(), other_type.data_type()); - if lhs_width != rhs_width { - return Bitmap::new_with_value($wrong_width, lhs.len()); + if self_width != other_width { + return Bitmap::new_with_value(false, self.len()); } - use arrow::datatypes::{PhysicalType as PH, PrimitiveType as PR}; - let lv = lhs.values(); - let rv = rhs.values(); - match lhs_type.data_type().to_physical_type() { - PH::Boolean => call_binary!(BooleanArray, lv, rv, $op), - PH::BinaryView => call_binary!(BinaryViewArray, lv, rv, $op), - PH::Utf8View => call_binary!(Utf8ViewArray, lv, rv, $op), - PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Float16) => { - todo!("Comparison of Arrays with Primitive(Float16) are not yet supported") - }, - PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int256) => { - todo!("Comparison of Arrays with Primitive(Int256) are not yet supported") - }, - PH::Primitive(PR::DaysMs) => { - todo!("Comparison of Arrays with Primitive(DaysMs) are not yet supported") - }, - PH::Primitive(PR::MonthDayNano) => { - todo!("Comparison of Arrays with Primitive(MonthDayNano) are not yet supported") - }, - PH::FixedSizeList => call_binary!(FixedSizeListArray, lv, rv, $op), - PH::Null => call_binary!(NullArray, lv, rv, $op), - PH::Binary => todo!("Comparison of Arrays with Binary are not yet supported"), - PH::FixedSizeBinary => { - todo!("Comparison of Arrays with FixedSizeBinary are not yet supported") - }, - PH::LargeBinary => todo!("Comparison of Arrays with LargeBinary are not yet supported"), - PH::Utf8 => todo!("Comparison of Arrays with Utf8 are not yet supported"), - PH::LargeUtf8 => todo!("Comparison of Arrays with LargeUtf8 are not yet supported"), - PH::List => todo!("Comparison of Arrays with List are not yet supported"), - PH::LargeList => todo!("Comparison of Arrays with LargeList are not yet supported"), - PH::Struct => call_binary!(StructArray, lv, rv, $op), - PH::Union => todo!("Comparison of Arrays with Union are not yet supported"), - PH::Map => todo!("Comparison of Arrays with Map are not yet supported"), - PH::Dictionary(_) => { - todo!("Comparison of Arrays with Dictionary are not yet supported") - }, - } - }}; -} - -impl TotalOrdKernel for FixedSizeListArray { - type Scalar = Box; + let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref()); - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - // Nested comparison always done with eq_missing, propagating doesn't - // make any sense. - let inner = compare!(self, other, false, TotalOrdKernel::tot_eq_missing_kernel); agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0) } fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - let inner = compare!(self, other, true, TotalOrdKernel::tot_eq_missing_kernel); - agg_array_bitmap(inner, self.size(), |zeroes| zeroes > 0) - } + assert_eq!(self.len(), other.len()); + let ArrowDataType::FixedSizeList(self_type, self_width) = + self.data_type().to_logical_type() + else { + panic!("array comparison called with non-array type"); + }; + let ArrowDataType::FixedSizeList(other_type, other_width) = + other.data_type().to_logical_type() + else { + panic!("array comparison called with non-array type"); + }; + assert_eq!(self_type.data_type(), other_type.data_type()); - fn tot_lt_kernel(&self, _other: &Self) -> Bitmap { - unimplemented!() - } + if self_width != other_width { + return Bitmap::new_with_value(true, self.len()); + } + + let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref()); - fn tot_le_kernel(&self, _other: &Self) -> Bitmap { - unimplemented!() + agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size()) } fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { @@ -133,20 +85,4 @@ impl TotalOrdKernel for FixedSizeListArray { fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { todo!() } - - fn tot_lt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() - } - - fn tot_le_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() - } - - fn tot_gt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() - } - - fn tot_ge_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() - } } diff --git a/crates/polars-compute/src/comparisons/binary.rs b/crates/polars-compute/src/comparisons/binary.rs new file mode 100644 index 000000000000..a9b1f17c68c5 --- /dev/null +++ b/crates/polars-compute/src/comparisons/binary.rs @@ -0,0 +1,114 @@ +use arrow::array::{BinaryArray, FixedSizeBinaryArray}; +use arrow::bitmap::Bitmap; +use arrow::types::Offset; +use polars_utils::total_ord::{TotalEq, TotalOrd}; + +use super::{TotalEqKernel, TotalOrdKernel}; + +impl TotalEqKernel for BinaryArray { + type Scalar = [u8]; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_eq(&r)) + .collect() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_ne(&r)) + .collect() + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_eq(&other)).collect() + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_ne(&other)).collect() + } +} + +impl TotalOrdKernel for BinaryArray { + type Scalar = [u8]; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_lt(&r)) + .collect() + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + self.values_iter() + .zip(other.values_iter()) + .map(|(l, r)| l.tot_le(&r)) + .collect() + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_lt(&other)).collect() + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_le(&other)).collect() + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_gt(&other)).collect() + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.values_iter().map(|l| l.tot_ge(&other)).collect() + } +} + +impl TotalEqKernel for FixedSizeBinaryArray { + type Scalar = [u8]; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + + if self.size() != other.size() { + return Bitmap::new_zeroed(self.len()); + } + + (0..self.len()) + .map(|i| self.value(i) == other.value(i)) + .collect() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert!(self.len() == other.len()); + + if self.size() != other.size() { + return Bitmap::new_zeroed(self.len()); + } + + (0..self.len()) + .map(|i| self.value(i) == other.value(i)) + .collect() + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if self.size() != other.len() { + return Bitmap::new_zeroed(self.len()); + } + + (0..self.len()).map(|i| self.value(i) == other).collect() + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if self.size() != other.len() { + return Bitmap::new_zeroed(self.len()); + } + + (0..self.len()).map(|i| self.value(i) != other).collect() + } +} diff --git a/crates/polars-compute/src/comparisons/boolean.rs b/crates/polars-compute/src/comparisons/boolean.rs new file mode 100644 index 000000000000..39a8f9b3814a --- /dev/null +++ b/crates/polars-compute/src/comparisons/boolean.rs @@ -0,0 +1,72 @@ +use arrow::array::BooleanArray; +use arrow::bitmap::{self, Bitmap}; + +use super::{TotalEqKernel, TotalOrdKernel}; + +impl TotalEqKernel for BooleanArray { + type Scalar = bool; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + bitmap::binary(self.values(), other.values(), |l, r| !(l ^ r)) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + self.values() ^ other.values() + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + self.values().clone() + } else { + !self.values() + } + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.tot_eq_kernel_broadcast(&!*other) + } +} + +impl TotalOrdKernel for BooleanArray { + type Scalar = bool; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + bitmap::binary(self.values(), other.values(), |l, r| !l & r) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + bitmap::binary(self.values(), other.values(), |l, r| !l | r) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + !self.values() + } else { + Bitmap::new_zeroed(self.len()) + } + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + Bitmap::new_with_value(true, self.len()) + } else { + !self.values() + } + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + Bitmap::new_zeroed(self.len()) + } else { + self.values().clone() + } + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if *other { + self.values().clone() + } else { + Bitmap::new_with_value(true, self.len()) + } + } +} diff --git a/crates/polars-compute/src/comparisons/dictionary.rs b/crates/polars-compute/src/comparisons/dictionary.rs new file mode 100644 index 000000000000..18d82c9efea1 --- /dev/null +++ b/crates/polars-compute/src/comparisons/dictionary.rs @@ -0,0 +1,75 @@ +use arrow::array::{Array, DictionaryArray, DictionaryKey}; +use arrow::bitmap::{Bitmap, MutableBitmap}; + +use super::TotalEqKernel; +use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; + +impl TotalEqKernel for DictionaryArray { + type Scalar = Box; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert_eq!(self.len(), other.len()); + + let mut bitmap = MutableBitmap::with_capacity(self.len()); + + for i in 0..self.len() { + let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); + let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + + if !lval || !rval { + bitmap.push(true); + continue; + } + + let lkey = self.key_value(i); + let rkey = other.key_value(i); + + let mut lhs_value = self.values().clone(); + lhs_value.slice(lkey, 1); + let mut rhs_value = other.values().clone(); + rhs_value.slice(rkey, 1); + + let result = array_tot_eq_missing_kernel(lhs_value.as_ref(), rhs_value.as_ref()); + bitmap.push(result.unset_bits() == 0); + } + + bitmap.freeze() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert_eq!(self.len(), other.len()); + + let mut bitmap = MutableBitmap::with_capacity(self.len()); + + for i in 0..self.len() { + let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); + let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + + if !lval || !rval { + bitmap.push(false); + continue; + } + + let lkey = self.key_value(i); + let rkey = other.key_value(i); + + let mut lhs_value = self.values().clone(); + lhs_value.slice(lkey, 1); + let mut rhs_value = other.values().clone(); + rhs_value.slice(rkey, 1); + + let result = array_tot_ne_missing_kernel(lhs_value.as_ref(), rhs_value.as_ref()); + bitmap.push(result.set_bits() > 0); + } + + bitmap.freeze() + } + + fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> arrow::bitmap::Bitmap { + todo!() + } + + fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> arrow::bitmap::Bitmap { + todo!() + } +} diff --git a/crates/polars-compute/src/comparisons/dyn_array.rs b/crates/polars-compute/src/comparisons/dyn_array.rs new file mode 100644 index 000000000000..693293f4e2c5 --- /dev/null +++ b/crates/polars-compute/src/comparisons/dyn_array.rs @@ -0,0 +1,85 @@ +use arrow::array::{ + Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, FixedSizeBinaryArray, + ListArray, NullArray, PrimitiveArray, StructArray, Utf8Array, Utf8ViewArray, +}; +use arrow::bitmap::Bitmap; +use arrow::types::{days_ms, f16, i256, months_days_ns}; + +use crate::comparisons::TotalEqKernel; + +macro_rules! call_binary { + ($T:ty, $lhs:expr, $rhs:expr, $op:path) => {{ + let lhs: &$T = $lhs.as_any().downcast_ref().unwrap(); + let rhs: &$T = $rhs.as_any().downcast_ref().unwrap(); + $op(lhs, rhs) + }}; +} + +macro_rules! compare { + ($lhs:expr, $rhs:expr, $op:path) => {{ + let lhs = $lhs; + let rhs = $rhs; + + assert_eq!(lhs.data_type(), rhs.data_type()); + + use arrow::datatypes::{IntegerType as I, PhysicalType as PH, PrimitiveType as PR}; + match lhs.data_type().to_physical_type() { + PH::Boolean => call_binary!(BooleanArray, lhs, rhs, $op), + PH::BinaryView => call_binary!(BinaryViewArray, lhs, rhs, $op), + PH::Utf8View => call_binary!(Utf8ViewArray, lhs, rhs, $op), + PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Float16) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::Int256) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::DaysMs) => call_binary!(PrimitiveArray, lhs, rhs, $op), + PH::Primitive(PR::MonthDayNano) => { + call_binary!(PrimitiveArray, lhs, rhs, $op) + }, + + #[cfg(feature = "dtype-array")] + PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray, lhs, rhs, $op), + #[cfg(not(feature = "dtype-array"))] + PH::FixedSizeList => todo!( + "Comparison of FixedSizeListArray is not supported without dtype-array feature" + ), + + PH::Null => call_binary!(NullArray, lhs, rhs, $op), + PH::FixedSizeBinary => call_binary!(FixedSizeBinaryArray, lhs, rhs, $op), + PH::Binary => call_binary!(BinaryArray, lhs, rhs, $op), + PH::LargeBinary => call_binary!(BinaryArray, lhs, rhs, $op), + PH::Utf8 => call_binary!(Utf8Array, lhs, rhs, $op), + PH::LargeUtf8 => call_binary!(Utf8Array, lhs, rhs, $op), + PH::List => call_binary!(ListArray, lhs, rhs, $op), + PH::LargeList => call_binary!(ListArray, lhs, rhs, $op), + PH::Struct => call_binary!(StructArray, lhs, rhs, $op), + PH::Union => todo!("Comparison of UnionArrays is not yet supported"), + PH::Map => todo!("Comparison of MapArrays is not yet supported"), + PH::Dictionary(I::Int8) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::Int16) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::Int32) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::Int64) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt8) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt16) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt32) => call_binary!(DictionaryArray, lhs, rhs, $op), + PH::Dictionary(I::UInt64) => call_binary!(DictionaryArray, lhs, rhs, $op), + } + }}; +} + +pub fn array_tot_eq_missing_kernel(lhs: &dyn Array, rhs: &dyn Array) -> Bitmap { + compare!(lhs, rhs, TotalEqKernel::tot_eq_missing_kernel) +} + +pub fn array_tot_ne_missing_kernel(lhs: &dyn Array, rhs: &dyn Array) -> Bitmap { + compare!(lhs, rhs, TotalEqKernel::tot_ne_missing_kernel) +} diff --git a/crates/polars-compute/src/comparisons/list.rs b/crates/polars-compute/src/comparisons/list.rs new file mode 100644 index 000000000000..a66ad4f4312a --- /dev/null +++ b/crates/polars-compute/src/comparisons/list.rs @@ -0,0 +1,86 @@ +use arrow::array::ListArray; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::types::Offset; + +use super::TotalEqKernel; +use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel}; + +impl TotalEqKernel for ListArray { + type Scalar = (); + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + assert_eq!(self.len(), other.len()); + + let mut bitmap = MutableBitmap::with_capacity(self.len()); + + for i in 0..self.len() { + let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); + let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + + if !lval || !rval { + bitmap.push(true); + continue; + } + + let (lstart, lend) = self.offsets().start_end(i); + let (rstart, rend) = other.offsets().start_end(i); + + if lend - lstart != rend - rstart { + bitmap.push(false); + continue; + } + + let mut lhs_values = self.values().clone(); + lhs_values.slice(lstart, lend - lstart); + let mut rhs_values = self.values().clone(); + rhs_values.slice(rstart, rend - rstart); + + let result = array_tot_eq_missing_kernel(lhs_values.as_ref(), rhs_values.as_ref()); + bitmap.push(result.unset_bits() == 0); + } + + bitmap.freeze() + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + assert_eq!(self.len(), other.len()); + + let mut bitmap = MutableBitmap::with_capacity(self.len()); + + for i in 0..self.len() { + let (lstart, lend) = self.offsets().start_end(i); + let (rstart, rend) = other.offsets().start_end(i); + + let lval = self.validity().map_or(true, |v| v.get(i).unwrap()); + let rval = other.validity().map_or(true, |v| v.get(i).unwrap()); + + if !lval || !rval { + bitmap.push(false); + continue; + } + + if lend - lstart != rend - rstart { + bitmap.push(true); + continue; + } + + let mut lhs_values = self.values().clone(); + lhs_values.slice(lstart, lend - lstart); + let mut rhs_values = self.values().clone(); + rhs_values.slice(rstart, rend - rstart); + + let result = array_tot_ne_missing_kernel(lhs_values.as_ref(), rhs_values.as_ref()); + bitmap.push(result.set_bits() > 0); + } + + bitmap.freeze() + } + + fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() + } + + fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() + } +} diff --git a/crates/polars-compute/src/comparisons/mod.rs b/crates/polars-compute/src/comparisons/mod.rs index f0721a1ac7e8..f10d36f92213 100644 --- a/crates/polars-compute/src/comparisons/mod.rs +++ b/crates/polars-compute/src/comparisons/mod.rs @@ -1,31 +1,15 @@ use arrow::array::Array; use arrow::bitmap::{self, Bitmap}; -// Low-level comparison kernel. -pub trait TotalOrdKernel: Sized + Array { +pub trait TotalEqKernel: Sized + Array { type Scalar: ?Sized; // These kernels ignore validity entirely (results for nulls are unspecified // but initialized). fn tot_eq_kernel(&self, other: &Self) -> Bitmap; fn tot_ne_kernel(&self, other: &Self) -> Bitmap; - fn tot_lt_kernel(&self, other: &Self) -> Bitmap; - fn tot_le_kernel(&self, other: &Self) -> Bitmap; - fn tot_gt_kernel(&self, other: &Self) -> Bitmap { - other.tot_lt_kernel(self) - } - fn tot_ge_kernel(&self, other: &Self) -> Bitmap { - other.tot_le_kernel(self) - } - - // These kernels ignore validity entirely (results for nulls are unspecified - // but initialized). fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; - fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; - fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; - fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; - fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; // These kernels treat null as any other value equal to itself but unequal // to anything else. @@ -50,9 +34,6 @@ pub trait TotalOrdKernel: Sized + Array { }; combined } - - // These kernels treat null as any other value equal to itself but unequal - // to anything else. other is assumed to be non-null. fn tot_eq_missing_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { let q = self.tot_eq_kernel_broadcast(other); if let Some(valid) = self.validity() { @@ -72,11 +53,52 @@ pub trait TotalOrdKernel: Sized + Array { } } +// Low-level comparison kernel. +pub trait TotalOrdKernel: Sized + Array { + type Scalar: ?Sized; + + // These kernels ignore validity entirely (results for nulls are unspecified + // but initialized). + fn tot_lt_kernel(&self, other: &Self) -> Bitmap; + fn tot_le_kernel(&self, other: &Self) -> Bitmap; + fn tot_gt_kernel(&self, other: &Self) -> Bitmap { + other.tot_lt_kernel(self) + } + fn tot_ge_kernel(&self, other: &Self) -> Bitmap { + other.tot_le_kernel(self) + } + + // These kernels ignore validity entirely (results for nulls are unspecified + // but initialized). + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap; +} + +mod binary; +mod boolean; +mod dictionary; +mod dyn_array; +mod list; mod null; mod scalar; mod struct_; +mod utf8; mod view; +#[cfg(feature = "simd")] +mod _simd_dtypes { + use arrow::types::{days_ms, f16, i256, months_days_ns}; + + use crate::NotSimdPrimitive; + + impl NotSimdPrimitive for f16 {} + impl NotSimdPrimitive for i256 {} + impl NotSimdPrimitive for days_ms {} + impl NotSimdPrimitive for months_days_ns {} +} + #[cfg(feature = "simd")] mod simd; diff --git a/crates/polars-compute/src/comparisons/null.rs b/crates/polars-compute/src/comparisons/null.rs index 7fd5ba6f2b06..9d6e9e3dcefd 100644 --- a/crates/polars-compute/src/comparisons/null.rs +++ b/crates/polars-compute/src/comparisons/null.rs @@ -1,9 +1,9 @@ use arrow::array::{Array, NullArray}; use arrow::bitmap::Bitmap; -use super::TotalOrdKernel; +use super::{TotalEqKernel, TotalOrdKernel}; -impl TotalOrdKernel for NullArray { +impl TotalEqKernel for NullArray { type Scalar = Box; fn tot_eq_kernel(&self, other: &Self) -> Bitmap { @@ -16,14 +16,6 @@ impl TotalOrdKernel for NullArray { Bitmap::new_zeroed(self.len()) } - fn tot_lt_kernel(&self, _other: &Self) -> Bitmap { - unimplemented!() - } - - fn tot_le_kernel(&self, _other: &Self) -> Bitmap { - unimplemented!() - } - fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { todo!() } @@ -31,6 +23,18 @@ impl TotalOrdKernel for NullArray { fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { todo!() } +} + +impl TotalOrdKernel for NullArray { + type Scalar = Box; + + fn tot_lt_kernel(&self, _other: &Self) -> Bitmap { + unimplemented!() + } + + fn tot_le_kernel(&self, _other: &Self) -> Bitmap { + unimplemented!() + } fn tot_lt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { unimplemented!() diff --git a/crates/polars-compute/src/comparisons/scalar.rs b/crates/polars-compute/src/comparisons/scalar.rs index f2f245ed34c5..d792503b4225 100644 --- a/crates/polars-compute/src/comparisons/scalar.rs +++ b/crates/polars-compute/src/comparisons/scalar.rs @@ -1,31 +1,13 @@ -use arrow::array::{BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; -use arrow::bitmap::{self, Bitmap}; -use polars_utils::total_ord::{TotalEq, TotalOrd}; +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use polars_utils::total_ord::TotalOrd; -use super::TotalOrdKernel; +use super::{TotalEqKernel, TotalOrdKernel}; use crate::NotSimdPrimitive; -impl TotalOrdKernel for PrimitiveArray { +impl TotalEqKernel for PrimitiveArray { type Scalar = T; - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - assert!(self.len() == other.len()); - self.values() - .iter() - .zip(other.values().iter()) - .map(|(l, r)| l.tot_lt(r)) - .collect() - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - assert!(self.len() == other.len()); - self.values() - .iter() - .zip(other.values().iter()) - .map(|(l, r)| l.tot_le(r)) - .collect() - } - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { assert!(self.len() == other.len()); self.values() @@ -51,188 +33,42 @@ impl TotalOrdKernel for PrimitiveArray { fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { self.values().iter().map(|l| l.tot_ne(other)).collect() } - - fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values().iter().map(|l| l.tot_lt(other)).collect() - } - - fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values().iter().map(|l| l.tot_le(other)).collect() - } - - fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values().iter().map(|l| l.tot_gt(other)).collect() - } - - fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values().iter().map(|l| l.tot_ge(other)).collect() - } } -impl TotalOrdKernel for BinaryArray { - type Scalar = [u8]; +impl TotalOrdKernel for PrimitiveArray { + type Scalar = T; fn tot_lt_kernel(&self, other: &Self) -> Bitmap { assert!(self.len() == other.len()); - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_lt(&r)) + self.values() + .iter() + .zip(other.values().iter()) + .map(|(l, r)| l.tot_lt(r)) .collect() } fn tot_le_kernel(&self, other: &Self) -> Bitmap { assert!(self.len() == other.len()); - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_le(&r)) - .collect() - } - - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - assert!(self.len() == other.len()); - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_eq(&r)) - .collect() - } - - fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - assert!(self.len() == other.len()); - self.values_iter() - .zip(other.values_iter()) - .map(|(l, r)| l.tot_ne(&r)) + self.values() + .iter() + .zip(other.values().iter()) + .map(|(l, r)| l.tot_le(r)) .collect() } - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_eq(&other)).collect() - } - - fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_ne(&other)).collect() - } - fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_lt(&other)).collect() - } - - fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_le(&other)).collect() - } - - fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_gt(&other)).collect() - } - - fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.values_iter().map(|l| l.tot_ge(&other)).collect() - } -} - -impl TotalOrdKernel for Utf8Array { - type Scalar = str; - - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - self.to_binary().tot_lt_kernel(&other.to_binary()) - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - self.to_binary().tot_le_kernel(&other.to_binary()) - } - - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - self.to_binary().tot_eq_kernel(&other.to_binary()) - } - - fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - self.to_binary().tot_ne_kernel(&other.to_binary()) - } - - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.to_binary().tot_eq_kernel_broadcast(other.as_bytes()) - } - - fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.to_binary().tot_ne_kernel_broadcast(other.as_bytes()) - } - - fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.to_binary().tot_lt_kernel_broadcast(other.as_bytes()) - } - - fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.to_binary().tot_le_kernel_broadcast(other.as_bytes()) - } - - fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.to_binary().tot_gt_kernel_broadcast(other.as_bytes()) - } - - fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.to_binary().tot_ge_kernel_broadcast(other.as_bytes()) - } -} - -impl TotalOrdKernel for BooleanArray { - type Scalar = bool; - - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - bitmap::binary(self.values(), other.values(), |l, r| !l & r) - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - bitmap::binary(self.values(), other.values(), |l, r| !l | r) - } - - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - bitmap::binary(self.values(), other.values(), |l, r| !(l ^ r)) - } - - fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - self.values() ^ other.values() - } - - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - if *other { - self.values().clone() - } else { - !self.values() - } - } - - fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - self.tot_eq_kernel_broadcast(&!*other) - } - - fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - if *other { - !self.values() - } else { - Bitmap::new_zeroed(self.len()) - } + self.values().iter().map(|l| l.tot_lt(other)).collect() } fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - if *other { - Bitmap::new_with_value(true, self.len()) - } else { - !self.values() - } + self.values().iter().map(|l| l.tot_le(other)).collect() } fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - if *other { - Bitmap::new_zeroed(self.len()) - } else { - self.values().clone() - } + self.values().iter().map(|l| l.tot_gt(other)).collect() } fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - if *other { - self.values().clone() - } else { - Bitmap::new_with_value(true, self.len()) - } + self.values().iter().map(|l| l.tot_ge(other)).collect() } } diff --git a/crates/polars-compute/src/comparisons/simd.rs b/crates/polars-compute/src/comparisons/simd.rs index f3aed5ed18e1..f855ed4ad1c0 100644 --- a/crates/polars-compute/src/comparisons/simd.rs +++ b/crates/polars-compute/src/comparisons/simd.rs @@ -6,7 +6,7 @@ use arrow::bitmap::Bitmap; use arrow::types::NativeType; use bytemuck::Pod; -use super::TotalOrdKernel; +use super::{TotalEqKernel, TotalOrdKernel}; fn apply_binary_kernel( lhs: &PrimitiveArray, @@ -99,7 +99,7 @@ where macro_rules! impl_int_total_ord_kernel { ($T: ty, $width: literal, $mask: ty) => { - impl TotalOrdKernel for PrimitiveArray<$T> { + impl TotalEqKernel for PrimitiveArray<$T> { type Scalar = $T; fn tot_eq_kernel(&self, other: &Self) -> Bitmap { @@ -114,18 +114,6 @@ macro_rules! impl_int_total_ord_kernel { }) } - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { - Simd::from(*l).simd_lt(Simd::from(*r)).to_bitmask() as $mask - }) - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { - Simd::from(*l).simd_le(Simd::from(*r)).to_bitmask() as $mask - }) - } - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { let r = Simd::splat(*other); apply_unary_kernel::<$width, $mask, _, _>(self, |l| { @@ -139,6 +127,22 @@ macro_rules! impl_int_total_ord_kernel { Simd::from(*l).simd_ne(r).to_bitmask() as $mask }) } + } + + impl TotalOrdKernel for PrimitiveArray<$T> { + type Scalar = $T; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + Simd::from(*l).simd_lt(Simd::from(*r)).to_bitmask() as $mask + }) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + Simd::from(*l).simd_le(Simd::from(*r)).to_bitmask() as $mask + }) + } fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { let r = Simd::splat(*other); @@ -173,7 +177,7 @@ macro_rules! impl_int_total_ord_kernel { macro_rules! impl_float_total_ord_kernel { ($T: ty, $width: literal, $mask: ty) => { - impl TotalOrdKernel for PrimitiveArray<$T> { + impl TotalEqKernel for PrimitiveArray<$T> { type Scalar = $T; fn tot_eq_kernel(&self, other: &Self) -> Bitmap { @@ -196,24 +200,6 @@ macro_rules! impl_float_total_ord_kernel { }) } - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { - let ls = Simd::from(*l); - let rs = Simd::from(*r); - let lhs_is_nan = ls.simd_ne(ls); - (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask - }) - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { - let ls = Simd::from(*l); - let rs = Simd::from(*r); - let rhs_is_nan = rs.simd_ne(rs); - (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask - }) - } - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { let rs = Simd::splat(*other); apply_unary_kernel::<$width, $mask, _, _>(self, |l| { @@ -233,6 +219,28 @@ macro_rules! impl_float_total_ord_kernel { (!((lhs_is_nan & rhs_is_nan) | ls.simd_eq(rs))).to_bitmask() as $mask }) } + } + + impl TotalOrdKernel for PrimitiveArray<$T> { + type Scalar = $T; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + let ls = Simd::from(*l); + let rs = Simd::from(*r); + let lhs_is_nan = ls.simd_ne(ls); + (!(lhs_is_nan | ls.simd_ge(rs))).to_bitmask() as $mask + }) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + apply_binary_kernel::<$width, $mask, _, _>(self, other, |l, r| { + let ls = Simd::from(*l); + let rs = Simd::from(*r); + let rhs_is_nan = rs.simd_ne(rs); + (rhs_is_nan | ls.simd_le(rs)).to_bitmask() as $mask + }) + } fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { let rs = Simd::splat(*other); diff --git a/crates/polars-compute/src/comparisons/struct_.rs b/crates/polars-compute/src/comparisons/struct_.rs index ef3cb9ca52b6..f7c7a3a21684 100644 --- a/crates/polars-compute/src/comparisons/struct_.rs +++ b/crates/polars-compute/src/comparisons/struct_.rs @@ -1,152 +1,108 @@ -use arrow::array::{ - Array, BinaryViewArray, BooleanArray, NullArray, PrimitiveArray, StructArray, Utf8ViewArray, -}; -use arrow::bitmap::Bitmap; -use arrow::datatypes::ArrowDataType; +use arrow::array::{Array, StructArray}; +use arrow::bitmap::{Bitmap, MutableBitmap}; -use super::TotalOrdKernel; +use super::TotalEqKernel; +use crate::comparisons::dyn_array::array_tot_eq_missing_kernel; -macro_rules! call_binary { - ($T:ty, $lhs:expr, $rhs:expr, $op:path) => {{ - let lhs: &$T = $lhs.as_any().downcast_ref().unwrap(); - let rhs: &$T = $rhs.as_any().downcast_ref().unwrap(); - - $op(lhs, rhs) - }}; -} +impl TotalEqKernel for StructArray { + type Scalar = Box; -macro_rules! compare { - ($lhs:expr, $rhs:expr, $op:path, $fold:expr) => {{ - let lhs = $lhs; - let rhs = $rhs; + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + let lhs = self; + let rhs = other; assert_eq!(lhs.len(), rhs.len()); - let ArrowDataType::Struct(lhs_type) = lhs.data_type().to_logical_type() - else { - panic!("array comparison called with non-array type"); - }; - let ArrowDataType::Struct(rhs_type) = rhs.data_type().to_logical_type() - else { - panic!("array comparison called with non-array type"); - }; - assert_eq!(lhs_type.len(), rhs_type.len()); + + if lhs.fields() != rhs.fields() { + return Bitmap::new_zeroed(lhs.len()); + } + + let ln = lhs.validity(); + let rn = rhs.validity(); let lv = lhs.values(); let rv = rhs.values(); - let mut fold = None; - - for i in 0..lhs_type.len() { - assert_eq!(lhs_type[i].data_type(), rhs_type[i].data_type()); - - use arrow::datatypes::PhysicalType as PH; - use arrow::datatypes::PrimitiveType as PR; - - let lv = &lv[i]; - let rv = &rv[i]; - - let new = match lhs_type[i].data_type().to_physical_type() { - PH::Boolean => call_binary!(BooleanArray, lv, rv, $op), - PH::BinaryView => call_binary!(BinaryViewArray, lv, rv, $op), - PH::Utf8View => call_binary!(Utf8ViewArray, lv, rv, $op), - PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Float16) => todo!("Comparison of Struct with Primitive(Float16) are not yet supported"), - PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray, lv, rv, $op), - PH::Primitive(PR::Int256) => todo!("Comparison of Struct with Primitive(Int256) are not yet supported"), - PH::Primitive(PR::DaysMs) => todo!("Comparison of Struct with Primitive(DaysMs) are not yet supported"), - PH::Primitive(PR::MonthDayNano) => todo!("Comparison of Struct with Primitive(MonthDayNano) are not yet supported"), - - #[cfg(feature = "dtype-array")] - PH::FixedSizeList => call_binary!(arrow::array::FixedSizeListArray, lv, rv, $op), - #[cfg(not(feature = "dtype-array"))] - PH::FixedSizeList => todo!("Comparison of Struct with FixedSizeList are not supported without the `dtype-array` feature"), - - PH::Null => call_binary!(NullArray, lv, rv, $op), - PH::Binary => todo!("Comparison of Struct with Binary are not yet supported"), - PH::FixedSizeBinary => todo!("Comparison of Struct with FixedSizeBinary are not yet supported"), - PH::LargeBinary => todo!("Comparison of Struct with LargeBinary are not yet supported"), - PH::Utf8 => todo!("Comparison of Struct with Utf8 are not yet supported"), - PH::LargeUtf8 => todo!("Comparison of Struct with LargeUtf8 are not yet supported"), - PH::List => todo!("Comparison of Struct with List are not yet supported"), - PH::LargeList => todo!("Comparison of Struct with LargeList are not yet supported"), - PH::Struct => call_binary!(StructArray, lv, rv, $op), - PH::Union => todo!("Comparison of Struct with Union are not yet supported"), - PH::Map => todo!("Comparison of Struct with Map are not yet supported"), - PH::Dictionary(_) => todo!("Comparison of Struct with Dictionary are not yet supported"), - }; - - fold = if let Some(fold) = fold { - Some($fold(fold, new)) - } else { - Some(new) - }; - } + let mut bitmap = MutableBitmap::with_capacity(lhs.len()); - fold.unwrap() - }}; -} + for i in 0..lhs.len() { + let mut is_equal = true; -impl TotalOrdKernel for StructArray { - type Scalar = Box; + if !ln.map_or(true, |v| v.get(i).unwrap()) || !rn.map_or(true, |v| v.get(i).unwrap()) { + bitmap.push(true); + continue; + } - fn tot_eq_kernel(&self, other: &Self) -> Bitmap { - use std::ops::BitAnd; - compare!( - self, - other, - TotalOrdKernel::tot_eq_missing_kernel, - |a: Bitmap, b: Bitmap| a.bitand(&b) - ) + for j in 0..lhs.values().len() { + if lv[j].len() != rv[j].len() { + is_equal = false; + break; + } + + let result = array_tot_eq_missing_kernel(lv[j].as_ref(), rv[j].as_ref()); + if result.unset_bits() != 0 { + is_equal = false; + break; + } + } + + bitmap.push(is_equal); + } + + bitmap.freeze() } fn tot_ne_kernel(&self, other: &Self) -> Bitmap { - use std::ops::BitOr; - compare!( - self, - other, - TotalOrdKernel::tot_ne_missing_kernel, - |a: Bitmap, b: Bitmap| a.bitor(&b) - ) - } + let lhs = self; + let rhs = other; - fn tot_lt_kernel(&self, _other: &Self) -> Bitmap { - unimplemented!() - } + if lhs.fields() != rhs.fields() { + return Bitmap::new_with_value(true, lhs.len()); + } - fn tot_le_kernel(&self, _other: &Self) -> Bitmap { - unimplemented!() - } + if lhs.values().len() != rhs.values().len() { + return Bitmap::new_with_value(true, lhs.len()); + } - fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() - } + let ln = lhs.validity(); + let rn = rhs.validity(); - fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - todo!() - } + let lv = lhs.values(); + let rv = rhs.values(); - fn tot_lt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() - } + let mut bitmap = MutableBitmap::with_capacity(lhs.len()); + + for i in 0..lhs.len() { + let mut is_equal = true; - fn tot_le_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() + if !ln.map_or(true, |v| v.get(i).unwrap()) || !rn.map_or(true, |v| v.get(i).unwrap()) { + bitmap.push(false); + continue; + } + + for j in 0..lhs.values().len() { + if lv[j].len() != rv[j].len() { + is_equal = false; + break; + } + + let result = array_tot_eq_missing_kernel(lv[j].as_ref(), rv[j].as_ref()); + if result.unset_bits() != 0 { + is_equal = false; + break; + } + } + + bitmap.push(!is_equal); + } + + bitmap.freeze() } - fn tot_gt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() + fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() } - fn tot_ge_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { - unimplemented!() + fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap { + todo!() } } diff --git a/crates/polars-compute/src/comparisons/utf8.rs b/crates/polars-compute/src/comparisons/utf8.rs new file mode 100644 index 000000000000..dbb8aa2cc3fa --- /dev/null +++ b/crates/polars-compute/src/comparisons/utf8.rs @@ -0,0 +1,53 @@ +use arrow::array::Utf8Array; +use arrow::bitmap::Bitmap; +use arrow::types::Offset; + +use super::{TotalEqKernel, TotalOrdKernel}; + +impl TotalEqKernel for Utf8Array { + type Scalar = str; + + fn tot_eq_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_eq_kernel(&other.to_binary()) + } + + fn tot_ne_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_ne_kernel(&other.to_binary()) + } + + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_eq_kernel_broadcast(other.as_bytes()) + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_ne_kernel_broadcast(other.as_bytes()) + } +} + +impl TotalOrdKernel for Utf8Array { + type Scalar = str; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_lt_kernel(&other.to_binary()) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + self.to_binary().tot_le_kernel(&other.to_binary()) + } + + fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_lt_kernel_broadcast(other.as_bytes()) + } + + fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_le_kernel_broadcast(other.as_bytes()) + } + + fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_gt_kernel_broadcast(other.as_bytes()) + } + + fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + self.to_binary().tot_ge_kernel_broadcast(other.as_bytes()) + } +} diff --git a/crates/polars-compute/src/comparisons/view.rs b/crates/polars-compute/src/comparisons/view.rs index 3a822428dc6c..c39187e90c60 100644 --- a/crates/polars-compute/src/comparisons/view.rs +++ b/crates/polars-compute/src/comparisons/view.rs @@ -1,6 +1,7 @@ use arrow::array::{BinaryViewArray, Utf8ViewArray}; use arrow::bitmap::Bitmap; +use super::TotalEqKernel; use crate::comparisons::TotalOrdKernel; // If s fits in 12 bytes, returns the view encoding it would have in a @@ -43,7 +44,7 @@ fn broadcast_inequality( })) } -impl TotalOrdKernel for BinaryViewArray { +impl TotalEqKernel for BinaryViewArray { type Scalar = [u8]; fn tot_eq_kernel(&self, other: &Self) -> Bitmap { @@ -102,6 +103,46 @@ impl TotalOrdKernel for BinaryViewArray { })) } + fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if let Some(val) = small_view_encoding(other) { + Bitmap::from_trusted_len_iter(self.views().iter().map(|v| v.as_u128() == val)) + } else { + let slf_views = self.views().as_slice(); + let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); + let prefix_len = ((prefix as u64) << 32) | other.len() as u64; + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let v_prefix_len = slf_views.get_unchecked(i).as_u128() as u64; + if v_prefix_len != prefix_len { + false + } else { + self.value_unchecked(i) == other + } + })) + } + } + + fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { + if let Some(val) = small_view_encoding(other) { + Bitmap::from_trusted_len_iter(self.views().iter().map(|v| v.as_u128() != val)) + } else { + let slf_views = self.views().as_slice(); + let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); + let prefix_len = ((prefix as u64) << 32) | other.len() as u64; + Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { + let v_prefix_len = slf_views.get_unchecked(i).as_u128() as u64; + if v_prefix_len != prefix_len { + true + } else { + self.value_unchecked(i) != other + } + })) + } + } +} + +impl TotalOrdKernel for BinaryViewArray { + type Scalar = [u8]; + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { debug_assert!(self.len() == other.len()); @@ -146,42 +187,6 @@ impl TotalOrdKernel for BinaryViewArray { })) } - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - if let Some(val) = small_view_encoding(other) { - Bitmap::from_trusted_len_iter(self.views().iter().map(|v| v.as_u128() == val)) - } else { - let slf_views = self.views().as_slice(); - let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); - let prefix_len = ((prefix as u64) << 32) | other.len() as u64; - Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { - let v_prefix_len = slf_views.get_unchecked(i).as_u128() as u64; - if v_prefix_len != prefix_len { - false - } else { - self.value_unchecked(i) == other - } - })) - } - } - - fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { - if let Some(val) = small_view_encoding(other) { - Bitmap::from_trusted_len_iter(self.views().iter().map(|v| v.as_u128() != val)) - } else { - let slf_views = self.views().as_slice(); - let prefix = u32::from_le_bytes(other[..4].try_into().unwrap()); - let prefix_len = ((prefix as u64) << 32) | other.len() as u64; - Bitmap::from_trusted_len_iter((0..self.len()).map(|i| unsafe { - let v_prefix_len = slf_views.get_unchecked(i).as_u128() as u64; - if v_prefix_len != prefix_len { - true - } else { - self.value_unchecked(i) != other - } - })) - } - } - fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { broadcast_inequality(self, other, |a, b| a < b, |a, b| a < b) } @@ -199,7 +204,7 @@ impl TotalOrdKernel for BinaryViewArray { } } -impl TotalOrdKernel for Utf8ViewArray { +impl TotalEqKernel for Utf8ViewArray { type Scalar = str; fn tot_eq_kernel(&self, other: &Self) -> Bitmap { @@ -210,14 +215,6 @@ impl TotalOrdKernel for Utf8ViewArray { self.to_binview().tot_ne_kernel(&other.to_binview()) } - fn tot_lt_kernel(&self, other: &Self) -> Bitmap { - self.to_binview().tot_lt_kernel(&other.to_binview()) - } - - fn tot_le_kernel(&self, other: &Self) -> Bitmap { - self.to_binview().tot_le_kernel(&other.to_binview()) - } - fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { self.to_binview().tot_eq_kernel_broadcast(other.as_bytes()) } @@ -225,6 +222,18 @@ impl TotalOrdKernel for Utf8ViewArray { fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { self.to_binview().tot_ne_kernel_broadcast(other.as_bytes()) } +} + +impl TotalOrdKernel for Utf8ViewArray { + type Scalar = str; + + fn tot_lt_kernel(&self, other: &Self) -> Bitmap { + self.to_binview().tot_lt_kernel(&other.to_binview()) + } + + fn tot_le_kernel(&self, other: &Self) -> Bitmap { + self.to_binview().tot_le_kernel(&other.to_binview()) + } fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap { self.to_binview().tot_lt_kernel_broadcast(other.as_bytes()) diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 9fae970632f5..acdc2607b87e 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -9,7 +9,7 @@ use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; use arrow::compute; use num_traits::{NumCast, ToPrimitive}; -use polars_compute::comparisons::TotalOrdKernel; +use polars_compute::comparisons::{TotalEqKernel, TotalOrdKernel}; use crate::prelude::*; use crate::series::implementations::null::NullChunked; @@ -18,7 +18,7 @@ use crate::series::IsSorted; impl ChunkCompare<&ChunkedArray> for ChunkedArray where T: PolarsNumericType, - T::Array: TotalOrdKernel, + T::Array: TotalOrdKernel + TotalEqKernel, { type Item = BooleanChunked; diff --git a/crates/polars-core/src/chunked_array/comparison/scalar.rs b/crates/polars-core/src/chunked_array/comparison/scalar.rs index 8f3dc22f72de..f47f23780f82 100644 --- a/crates/polars-core/src/chunked_array/comparison/scalar.rs +++ b/crates/polars-core/src/chunked_array/comparison/scalar.rs @@ -65,7 +65,7 @@ impl ChunkCompare for ChunkedArray where T: PolarsNumericType, Rhs: ToPrimitive, - T::Array: TotalOrdKernel, + T::Array: TotalOrdKernel + TotalEqKernel, { type Item = BooleanChunked; fn equal(&self, rhs: Rhs) -> BooleanChunked { diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 66a12202a250..634e28abb6c3 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -909,13 +909,7 @@ def test_complex_types(tmp_path: Path, series: list[Any], dtype: pl.DataType) -> xs = pl.Series(series, dtype=dtype) df = pl.DataFrame({"x": xs}) - tmp_path.mkdir(exist_ok=True) - file_path = tmp_path / "complex-types.parquet" - - df.write_parquet(file_path) - after = pl.read_parquet(file_path) - - assert str(after) == str(df) + test_round_trip(df) @pytest.mark.xfail() @@ -924,27 +918,6 @@ def test_placeholder_zero_array() -> None: pl.Series([[]], dtype=pl.Array(pl.Int8, 0)) -@pytest.mark.xfail() -def test_placeholder_no_array_equals() -> None: - # @TODO: if this does not fail anymore please just call - # `test_round_trip` instead of comparing the strings. - test_round_trip( - pl.DataFrame( - { - "x": pl.Series( - [ - [ - [1, 2], - [3, 4], - ] - ], - dtype=pl.Array(pl.List(pl.Int8), 2), - ) - } - ) - ) - - @pytest.mark.write_disk() def test_parquet_array_statistics(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index 9f79991e9c6c..2e1a8ec42c40 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -24,7 +24,7 @@ def test_assert_series_equal_parametric(s: pl.Series) -> None: @given(data=st.data()) def test_assert_series_equal_parametric_array(data: st.DataObject) -> None: - inner = data.draw(dtypes(excluded_dtypes=[pl.Struct, pl.Categorical])) + inner = data.draw(dtypes(excluded_dtypes=[pl.Categorical])) shape = data.draw(st.integers(min_value=1, max_value=3)) dtype = pl.Array(inner, shape=shape) s = data.draw(series(dtype=dtype))