Skip to content

Commit

Permalink
feat: Implement general array equality checks (#17043)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Jun 18, 2024
1 parent 7aa7854 commit 2d3b0c2
Show file tree
Hide file tree
Showing 21 changed files with 825 additions and 562 deletions.
16 changes: 16 additions & 0 deletions crates/polars-arrow/src/bitmap/bitmap_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,22 @@ pub fn intersects_with_mut(lhs: &MutableBitmap, rhs: &MutableBitmap) -> bool {
)
}

/// Compute `out[i] = if selector[i] { truthy[i] } else { falsy }`.
pub fn select_constant(selector: &Bitmap, truthy: &Bitmap, falsy: bool) -> Bitmap {
let falsy_mask: u64 = if falsy {
0xFFFF_FFFF_FFFF_FFFF
} else {
0x0000_0000_0000_0000
};

binary(selector, truthy, |s, t| (s & t) | (!s & falsy_mask))
}

/// Compute `out[i] = if selector[i] { truthy[i] } else { falsy[i] }`.
pub fn select(selector: &Bitmap, truthy: &Bitmap, falsy: &Bitmap) -> Bitmap {
ternary(selector, truthy, falsy, |s, t, f| (s & t) | (!s & f))
}

impl PartialEq for Bitmap {
fn eq(&self, other: &Self) -> bool {
eq(self, other)
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-arrow/src/bitmap/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,24 @@ impl Bitmap {
pub fn num_intersections_with(&self, other: &Self) -> usize {
num_intersections_with(self, other)
}

/// Select between `truthy` and `falsy` based on `self`.
///
/// This essentially performs:
///
/// `out[i] = if self[i] { truthy[i] } else { falsy[i] }`
pub fn select(&self, truthy: &Self, falsy: &Self) -> Self {
super::bitmap_ops::select(self, truthy, falsy)
}

/// Select between `truthy` and constant `falsy` based on `self`.
///
/// This essentially performs:
///
/// `out[i] = if self[i] { truthy[i] } else { falsy }`
pub fn select_constant(&self, truthy: &Self, falsy: bool) -> Self {
super::bitmap_ops::select_constant(self, truthy, falsy)
}
}

impl<P: AsRef<[bool]>> From<P> for Bitmap {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-compute/src/arithmetic/signed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use strength_reduce::*;

use super::PrimitiveArithmeticKernelImpl;
use crate::arity::{prim_binary_values, prim_unary_values};
use crate::comparisons::TotalOrdKernel;
use crate::comparisons::TotalEqKernel;

macro_rules! impl_signed_arith_kernel {
($T:ty, $StrRed:ty) => {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-compute/src/arithmetic/unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use strength_reduce::*;

use super::PrimitiveArithmeticKernelImpl;
use crate::arity::{prim_binary_values, prim_unary_values};
use crate::comparisons::TotalOrdKernel;
use crate::comparisons::TotalEqKernel;

macro_rules! impl_unsigned_arith_kernel {
($T:ty, $StrRed:ty) => {
Expand Down
156 changes: 46 additions & 110 deletions crates/polars-compute/src/comparisons/array.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use arrow::array::{
Array, BinaryViewArray, BooleanArray, FixedSizeListArray, NullArray, PrimitiveArray,
StructArray, Utf8ViewArray,
};
use arrow::array::{Array, FixedSizeListArray};
use arrow::bitmap::utils::count_zeros;
use arrow::bitmap::Bitmap;
use arrow::datatypes::ArrowDataType;

use crate::comparisons::TotalOrdKernel;
use super::TotalEqKernel;
use crate::comparisons::dyn_array::{array_tot_eq_missing_kernel, array_tot_ne_missing_kernel};

/// Condenses a bitmap of n * width elements into one with n elements.
///
Expand All @@ -16,114 +14,68 @@ fn agg_array_bitmap<F>(bm: Bitmap, width: usize, true_zero_count: F) -> Bitmap
where
F: Fn(usize) -> bool,
{
assert!(width > 0 && bm.len() % width == 0);
let (slice, offset, _len) = bm.as_slice();

(0..bm.len() / width)
.map(|i| true_zero_count(count_zeros(slice, offset + i * width, width)))
.collect()
if bm.len() == 1 {
bm
} else {
assert!(width > 0 && bm.len() % width == 0);

let (slice, offset, _len) = bm.as_slice();
(0..bm.len() / width)
.map(|i| true_zero_count(count_zeros(slice, offset + i * width, width)))
.collect()
}
}

macro_rules! call_binary {
($T:ty, $lhs:expr, $rhs:expr, $op:path) => {{
let lhs: &$T = $lhs.as_any().downcast_ref().unwrap();
let rhs: &$T = $rhs.as_any().downcast_ref().unwrap();
$op(lhs, rhs)
}};
}
impl TotalEqKernel for FixedSizeListArray {
type Scalar = Box<dyn Array>;

macro_rules! compare {
($lhs:expr, $rhs:expr, $wrong_width:expr, $op:path) => {{
let lhs = $lhs;
let rhs = $rhs;
assert_eq!(lhs.len(), rhs.len());
let ArrowDataType::FixedSizeList(lhs_type, lhs_width) = lhs.data_type().to_logical_type()
fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
// Nested comparison always done with eq_missing, propagating doesn't
// make any sense.

assert_eq!(self.len(), other.len());
let ArrowDataType::FixedSizeList(self_type, self_width) =
self.data_type().to_logical_type()
else {
panic!("array comparison called with non-array type");
};
let ArrowDataType::FixedSizeList(rhs_type, rhs_width) = rhs.data_type().to_logical_type()
let ArrowDataType::FixedSizeList(other_type, other_width) =
other.data_type().to_logical_type()
else {
panic!("array comparison called with non-array type");
};
assert_eq!(lhs_type.data_type(), rhs_type.data_type());
assert_eq!(self_type.data_type(), other_type.data_type());

if lhs_width != rhs_width {
return Bitmap::new_with_value($wrong_width, lhs.len());
if self_width != other_width {
return Bitmap::new_with_value(false, self.len());
}

use arrow::datatypes::{PhysicalType as PH, PrimitiveType as PR};
let lv = lhs.values();
let rv = rhs.values();
match lhs_type.data_type().to_physical_type() {
PH::Boolean => call_binary!(BooleanArray, lv, rv, $op),
PH::BinaryView => call_binary!(BinaryViewArray, lv, rv, $op),
PH::Utf8View => call_binary!(Utf8ViewArray, lv, rv, $op),
PH::Primitive(PR::Int8) => call_binary!(PrimitiveArray<i8>, lv, rv, $op),
PH::Primitive(PR::Int16) => call_binary!(PrimitiveArray<i16>, lv, rv, $op),
PH::Primitive(PR::Int32) => call_binary!(PrimitiveArray<i32>, lv, rv, $op),
PH::Primitive(PR::Int64) => call_binary!(PrimitiveArray<i64>, lv, rv, $op),
PH::Primitive(PR::Int128) => call_binary!(PrimitiveArray<i128>, lv, rv, $op),
PH::Primitive(PR::UInt8) => call_binary!(PrimitiveArray<u8>, lv, rv, $op),
PH::Primitive(PR::UInt16) => call_binary!(PrimitiveArray<u16>, lv, rv, $op),
PH::Primitive(PR::UInt32) => call_binary!(PrimitiveArray<u32>, lv, rv, $op),
PH::Primitive(PR::UInt64) => call_binary!(PrimitiveArray<u64>, lv, rv, $op),
PH::Primitive(PR::UInt128) => call_binary!(PrimitiveArray<u128>, lv, rv, $op),
PH::Primitive(PR::Float16) => {
todo!("Comparison of Arrays with Primitive(Float16) are not yet supported")
},
PH::Primitive(PR::Float32) => call_binary!(PrimitiveArray<f32>, lv, rv, $op),
PH::Primitive(PR::Float64) => call_binary!(PrimitiveArray<f64>, lv, rv, $op),
PH::Primitive(PR::Int256) => {
todo!("Comparison of Arrays with Primitive(Int256) are not yet supported")
},
PH::Primitive(PR::DaysMs) => {
todo!("Comparison of Arrays with Primitive(DaysMs) are not yet supported")
},
PH::Primitive(PR::MonthDayNano) => {
todo!("Comparison of Arrays with Primitive(MonthDayNano) are not yet supported")
},
PH::FixedSizeList => call_binary!(FixedSizeListArray, lv, rv, $op),
PH::Null => call_binary!(NullArray, lv, rv, $op),
PH::Binary => todo!("Comparison of Arrays with Binary are not yet supported"),
PH::FixedSizeBinary => {
todo!("Comparison of Arrays with FixedSizeBinary are not yet supported")
},
PH::LargeBinary => todo!("Comparison of Arrays with LargeBinary are not yet supported"),
PH::Utf8 => todo!("Comparison of Arrays with Utf8 are not yet supported"),
PH::LargeUtf8 => todo!("Comparison of Arrays with LargeUtf8 are not yet supported"),
PH::List => todo!("Comparison of Arrays with List are not yet supported"),
PH::LargeList => todo!("Comparison of Arrays with LargeList are not yet supported"),
PH::Struct => call_binary!(StructArray, lv, rv, $op),
PH::Union => todo!("Comparison of Arrays with Union are not yet supported"),
PH::Map => todo!("Comparison of Arrays with Map are not yet supported"),
PH::Dictionary(_) => {
todo!("Comparison of Arrays with Dictionary are not yet supported")
},
}
}};
}

impl TotalOrdKernel for FixedSizeListArray {
type Scalar = Box<dyn Array>;
let inner = array_tot_eq_missing_kernel(self.values().as_ref(), other.values().as_ref());

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
// Nested comparison always done with eq_missing, propagating doesn't
// make any sense.
let inner = compare!(self, other, false, TotalOrdKernel::tot_eq_missing_kernel);
agg_array_bitmap(inner, self.size(), |zeroes| zeroes == 0)
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
let inner = compare!(self, other, true, TotalOrdKernel::tot_eq_missing_kernel);
agg_array_bitmap(inner, self.size(), |zeroes| zeroes > 0)
}
assert_eq!(self.len(), other.len());
let ArrowDataType::FixedSizeList(self_type, self_width) =
self.data_type().to_logical_type()
else {
panic!("array comparison called with non-array type");
};
let ArrowDataType::FixedSizeList(other_type, other_width) =
other.data_type().to_logical_type()
else {
panic!("array comparison called with non-array type");
};
assert_eq!(self_type.data_type(), other_type.data_type());

fn tot_lt_kernel(&self, _other: &Self) -> Bitmap {
unimplemented!()
}
if self_width != other_width {
return Bitmap::new_with_value(true, self.len());
}

let inner = array_tot_ne_missing_kernel(self.values().as_ref(), other.values().as_ref());

fn tot_le_kernel(&self, _other: &Self) -> Bitmap {
unimplemented!()
agg_array_bitmap(inner, self.size(), |zeroes| zeroes < self.size())
}

fn tot_eq_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap {
Expand All @@ -133,20 +85,4 @@ impl TotalOrdKernel for FixedSizeListArray {
fn tot_ne_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap {
todo!()
}

fn tot_lt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap {
unimplemented!()
}

fn tot_le_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap {
unimplemented!()
}

fn tot_gt_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap {
unimplemented!()
}

fn tot_ge_kernel_broadcast(&self, _other: &Self::Scalar) -> Bitmap {
unimplemented!()
}
}
114 changes: 114 additions & 0 deletions crates/polars-compute/src/comparisons/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use arrow::array::{BinaryArray, FixedSizeBinaryArray};
use arrow::bitmap::Bitmap;
use arrow::types::Offset;
use polars_utils::total_ord::{TotalEq, TotalOrd};

use super::{TotalEqKernel, TotalOrdKernel};

impl<O: Offset> TotalEqKernel for BinaryArray<O> {
type Scalar = [u8];

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_eq(&r))
.collect()
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_ne(&r))
.collect()
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_eq(&other)).collect()
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_ne(&other)).collect()
}
}

impl<O: Offset> TotalOrdKernel for BinaryArray<O> {
type Scalar = [u8];

fn tot_lt_kernel(&self, other: &Self) -> Bitmap {
assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_lt(&r))
.collect()
}

fn tot_le_kernel(&self, other: &Self) -> Bitmap {
assert!(self.len() == other.len());
self.values_iter()
.zip(other.values_iter())
.map(|(l, r)| l.tot_le(&r))
.collect()
}

fn tot_lt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_lt(&other)).collect()
}

fn tot_le_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_le(&other)).collect()
}

fn tot_gt_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_gt(&other)).collect()
}

fn tot_ge_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
self.values_iter().map(|l| l.tot_ge(&other)).collect()
}
}

impl TotalEqKernel for FixedSizeBinaryArray {
type Scalar = [u8];

fn tot_eq_kernel(&self, other: &Self) -> Bitmap {
assert!(self.len() == other.len());

if self.size() != other.size() {
return Bitmap::new_zeroed(self.len());
}

(0..self.len())
.map(|i| self.value(i) == other.value(i))
.collect()
}

fn tot_ne_kernel(&self, other: &Self) -> Bitmap {
assert!(self.len() == other.len());

if self.size() != other.size() {
return Bitmap::new_zeroed(self.len());
}

(0..self.len())
.map(|i| self.value(i) == other.value(i))
.collect()
}

fn tot_eq_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
if self.size() != other.len() {
return Bitmap::new_zeroed(self.len());
}

(0..self.len()).map(|i| self.value(i) == other).collect()
}

fn tot_ne_kernel_broadcast(&self, other: &Self::Scalar) -> Bitmap {
if self.size() != other.len() {
return Bitmap::new_zeroed(self.len());
}

(0..self.len()).map(|i| self.value(i) != other).collect()
}
}
Loading

0 comments on commit 2d3b0c2

Please sign in to comment.