Skip to content

Commit

Permalink
refactor: Divide ChunkCompare into Eq and Ineq variants
Browse files Browse the repository at this point in the history
Divide the `ChunkCompare` trait into two traits `ChunkCompareEq` and
`ChunkCompareIneq`, which allows us to statistically verify that there are no
calls to the inequality methods when these are not available (e.g. for `List`,
`Array` and `Struct`). This makes error handling a lot better as well.

For example, the following was a panic exception before.

```python
import polars as pl

a = pl.Series('a', [[1]], pl.Array(pl.Int8, 1))
b = pl.Series('b', [[1]], pl.Array(pl.Int8, 1))

c = a < b
```

Now, it returns:

```
polars.exceptions.InvalidOperationError: cannot perform '<' comparison between series 'a' of dtype: array[i8, 1] and series 'b' of dtype: array[i8, 1]
```

Fixes pola-rs#18938.
  • Loading branch information
coastalwhite committed Sep 27, 2024
1 parent d097d3c commit 7bab314
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 158 deletions.
18 changes: 15 additions & 3 deletions crates/polars-core/src/chunked_array/comparison/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ where
}
}

impl ChunkCompare<&CategoricalChunked> for CategoricalChunked {
impl ChunkCompareEq<&CategoricalChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn equal(&self, rhs: &CategoricalChunked) -> Self::Item {
Expand Down Expand Up @@ -134,6 +134,10 @@ impl ChunkCompare<&CategoricalChunked> for CategoricalChunked {
UInt32Chunked::not_equal_missing,
)
}
}

impl ChunkCompareIneq<&CategoricalChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn gt(&self, rhs: &CategoricalChunked) -> Self::Item {
cat_compare_helper(self, rhs, UInt32Chunked::gt, |l, r| l > r)
Expand Down Expand Up @@ -217,7 +221,7 @@ where
}
}

impl ChunkCompare<&StringChunked> for CategoricalChunked {
impl ChunkCompareEq<&StringChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn equal(&self, rhs: &StringChunked) -> Self::Item {
Expand Down Expand Up @@ -265,6 +269,10 @@ impl ChunkCompare<&StringChunked> for CategoricalChunked {
StringChunked::not_equal_missing,
)
}
}

impl ChunkCompareIneq<&StringChunked> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn gt(&self, rhs: &StringChunked) -> Self::Item {
cat_str_compare_helper(
Expand Down Expand Up @@ -376,7 +384,7 @@ where
}
}

impl ChunkCompare<&str> for CategoricalChunked {
impl ChunkCompareEq<&str> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn equal(&self, rhs: &str) -> Self::Item {
Expand Down Expand Up @@ -414,6 +422,10 @@ impl ChunkCompare<&str> for CategoricalChunked {
UInt32Chunked::equal_missing,
)
}
}

impl ChunkCompareIneq<&str> for CategoricalChunked {
type Item = PolarsResult<BooleanChunked>;

fn gt(&self, rhs: &str) -> Self::Item {
cat_single_str_compare_helper(
Expand Down
75 changes: 33 additions & 42 deletions crates/polars-core/src/chunked_array/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::series::implementations::null::NullChunked;
use crate::series::IsSorted;
use crate::utils::align_chunks_binary;

impl<T> ChunkCompare<&ChunkedArray<T>> for ChunkedArray<T>
impl<T> ChunkCompareEq<&ChunkedArray<T>> for ChunkedArray<T>
where
T: PolarsNumericType,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
Expand Down Expand Up @@ -126,6 +126,14 @@ where
),
}
}
}

impl<T> ChunkCompareIneq<&ChunkedArray<T>> for ChunkedArray<T>
where
T: PolarsNumericType,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
{
type Item = BooleanChunked;

fn lt(&self, rhs: &ChunkedArray<T>) -> BooleanChunked {
// Broadcast.
Expand Down Expand Up @@ -188,7 +196,7 @@ where
}
}

impl ChunkCompare<&NullChunked> for NullChunked {
impl ChunkCompareEq<&NullChunked> for NullChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &NullChunked) -> Self::Item {
Expand All @@ -206,6 +214,10 @@ impl ChunkCompare<&NullChunked> for NullChunked {
fn not_equal_missing(&self, rhs: &NullChunked) -> Self::Item {
BooleanChunked::full(self.name().clone(), false, get_broadcast_length(self, rhs))
}
}

impl ChunkCompareIneq<&NullChunked> for NullChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &NullChunked) -> Self::Item {
BooleanChunked::full_null(self.name().clone(), get_broadcast_length(self, rhs))
Expand Down Expand Up @@ -234,7 +246,7 @@ fn get_broadcast_length(lhs: &NullChunked, rhs: &NullChunked) -> usize {
}
}

impl ChunkCompare<&BooleanChunked> for BooleanChunked {
impl ChunkCompareEq<&BooleanChunked> for BooleanChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &BooleanChunked) -> BooleanChunked {
Expand Down Expand Up @@ -348,6 +360,10 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked {
),
}
}
}

impl ChunkCompareIneq<&BooleanChunked> for BooleanChunked {
type Item = BooleanChunked;

fn lt(&self, rhs: &BooleanChunked) -> BooleanChunked {
// Broadcast.
Expand Down Expand Up @@ -410,7 +426,7 @@ impl ChunkCompare<&BooleanChunked> for BooleanChunked {
}
}

impl ChunkCompare<&StringChunked> for StringChunked {
impl ChunkCompareEq<&StringChunked> for StringChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &StringChunked) -> BooleanChunked {
Expand All @@ -424,9 +440,14 @@ impl ChunkCompare<&StringChunked> for StringChunked {
fn not_equal(&self, rhs: &StringChunked) -> BooleanChunked {
self.as_binary().not_equal(&rhs.as_binary())
}

fn not_equal_missing(&self, rhs: &StringChunked) -> BooleanChunked {
self.as_binary().not_equal_missing(&rhs.as_binary())
}
}

impl ChunkCompareIneq<&StringChunked> for StringChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &StringChunked) -> BooleanChunked {
self.as_binary().gt(&rhs.as_binary())
Expand All @@ -445,7 +466,7 @@ impl ChunkCompare<&StringChunked> for StringChunked {
}
}

impl ChunkCompare<&BinaryChunked> for BinaryChunked {
impl ChunkCompareEq<&BinaryChunked> for BinaryChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &BinaryChunked) -> BooleanChunked {
Expand Down Expand Up @@ -551,6 +572,10 @@ impl ChunkCompare<&BinaryChunked> for BinaryChunked {
),
}
}
}

impl ChunkCompareIneq<&BinaryChunked> for BinaryChunked {
type Item = BooleanChunked;

fn lt(&self, rhs: &BinaryChunked) -> BooleanChunked {
// Broadcast.
Expand Down Expand Up @@ -644,7 +669,7 @@ where
}
}

impl ChunkCompare<&ListChunked> for ListChunked {
impl ChunkCompareEq<&ListChunked> for ListChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &ListChunked) -> BooleanChunked {
let _series_equals = |lhs: Option<&Series>, rhs: Option<&Series>| match (lhs, rhs) {
Expand Down Expand Up @@ -684,23 +709,6 @@ impl ChunkCompare<&ListChunked> for ListChunked {

_list_comparison_helper(self, rhs, _series_not_equal_missing)
}

// The following are not implemented because gt, lt comparison of series don't make sense.
fn gt(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}

fn gt_eq(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}

fn lt(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}

fn lt_eq(&self, _rhs: &ListChunked) -> BooleanChunked {
unimplemented!()
}
}

#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -741,7 +749,7 @@ where
}

#[cfg(feature = "dtype-struct")]
impl ChunkCompare<&StructChunked> for StructChunked {
impl ChunkCompareEq<&StructChunked> for StructChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &StructChunked) -> BooleanChunked {
struct_helper(
Expand Down Expand Up @@ -785,7 +793,7 @@ impl ChunkCompare<&StructChunked> for StructChunked {
}

#[cfg(feature = "dtype-array")]
impl ChunkCompare<&ArrayChunked> for ArrayChunked {
impl ChunkCompareEq<&ArrayChunked> for ArrayChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &ArrayChunked) -> BooleanChunked {
if self.width() != rhs.width() {
Expand Down Expand Up @@ -834,23 +842,6 @@ impl ChunkCompare<&ArrayChunked> for ArrayChunked {
PlSmallStr::EMPTY,
)
}

// following are not implemented because gt, lt comparison of series don't make sense
fn gt(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}

fn gt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}

fn lt(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}

fn lt_eq(&self, _rhs: &ArrayChunked) -> BooleanChunked {
unimplemented!()
}
}

impl Not for &BooleanChunked {
Expand Down
24 changes: 21 additions & 3 deletions crates/polars-core/src/chunked_array/comparison/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ where
ca
}

impl<T, Rhs> ChunkCompare<Rhs> for ChunkedArray<T>
impl<T, Rhs> ChunkCompareEq<Rhs> for ChunkedArray<T>
where
T: PolarsNumericType,
Rhs: ToPrimitive,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
{
type Item = BooleanChunked;

fn equal(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa = Some(|x: T::Native| x.tot_ge(&rhs));
Expand Down Expand Up @@ -111,6 +112,15 @@ where
})
}
}
}

impl<T, Rhs> ChunkCompareIneq<Rhs> for ChunkedArray<T>
where
T: PolarsNumericType,
Rhs: ToPrimitive,
T::Array: TotalOrdKernel<Scalar = T::Native> + TotalEqKernel<Scalar = T::Native>,
{
type Item = BooleanChunked;

fn gt(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
Expand Down Expand Up @@ -157,7 +167,7 @@ where
}
}

impl ChunkCompare<&[u8]> for BinaryChunked {
impl ChunkCompareEq<&[u8]> for BinaryChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &[u8]) -> BooleanChunked {
Expand All @@ -175,6 +185,10 @@ impl ChunkCompare<&[u8]> for BinaryChunked {
fn not_equal_missing(&self, rhs: &[u8]) -> BooleanChunked {
arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
}
}

impl ChunkCompareIneq<&[u8]> for BinaryChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &[u8]) -> BooleanChunked {
arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
Expand All @@ -193,7 +207,7 @@ impl ChunkCompare<&[u8]> for BinaryChunked {
}
}

impl ChunkCompare<&str> for StringChunked {
impl ChunkCompareEq<&str> for StringChunked {
type Item = BooleanChunked;

fn equal(&self, rhs: &str) -> BooleanChunked {
Expand All @@ -211,6 +225,10 @@ impl ChunkCompare<&str> for StringChunked {
fn not_equal_missing(&self, rhs: &str) -> BooleanChunked {
arity::unary_mut_with_options(self, |arr| arr.tot_ne_missing_kernel_broadcast(rhs).into())
}
}

impl ChunkCompareIneq<&str> for StringChunked {
type Item = BooleanChunked;

fn gt(&self, rhs: &str) -> BooleanChunked {
arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(rhs).into())
Expand Down
29 changes: 11 additions & 18 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ pub(crate) mod unique;
#[cfg(feature = "zip_with")]
pub mod zip;

use polars_utils::no_call_const;
#[cfg(feature = "serde-lazy")]
use serde::{Deserialize, Serialize};
pub use sort::options::*;
Expand Down Expand Up @@ -312,7 +311,7 @@ pub trait ChunkVar {
/// df.filter(&mask)
/// }
/// ```
pub trait ChunkCompare<Rhs> {
pub trait ChunkCompareEq<Rhs> {
type Item;

/// Check for equality.
Expand All @@ -326,30 +325,24 @@ pub trait ChunkCompare<Rhs> {

/// Check for inequality where `None == None`.
fn not_equal_missing(&self, rhs: Rhs) -> Self::Item;
}

/// Compare [`Series`] and [`ChunkedArray`]'s using inequality operators (`<`, `>=`, etc.) and get
/// a `boolean` mask that can be used to filter rows.
pub trait ChunkCompareIneq<Rhs> {
type Item;

/// Greater than comparison.
#[allow(unused_variables)]
fn gt(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn gt(&self, rhs: Rhs) -> Self::Item;

/// Greater than or equal comparison.
#[allow(unused_variables)]
fn gt_eq(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn gt_eq(&self, rhs: Rhs) -> Self::Item;

/// Less than comparison.
#[allow(unused_variables)]
fn lt(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn lt(&self, rhs: Rhs) -> Self::Item;

/// Less than or equal comparison
#[allow(unused_variables)]
fn lt_eq(&self, rhs: Rhs) -> Self::Item {
no_call_const!()
}
fn lt_eq(&self, rhs: Rhs) -> Self::Item;
}

/// Get unique values in a `ChunkedArray`
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-core/src/chunked_array/ops/unique/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ where
T: PolarsNumericType,
T::Native: TotalHash + TotalEq + ToTotalOrd,
<T::Native as ToTotalOrd>::TotalOrdItem: Hash + Eq + Ord,
ChunkedArray<T>: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray<T>, Item = BooleanChunked>,
ChunkedArray<T>:
IntoSeries + for<'a> ChunkCompareEq<&'a ChunkedArray<T>, Item = BooleanChunked>,
{
fn unique(&self) -> PolarsResult<Self> {
// prevent stackoverflow repeated sorted.unique call
Expand Down
Loading

0 comments on commit 7bab314

Please sign in to comment.