Skip to content

Commit

Permalink
feat!: Native selector XOR set operation, guarantee consistent sele…
Browse files Browse the repository at this point in the history
…ctor column-order (#16833)
  • Loading branch information
alexander-beedie authored Jun 10, 2024
1 parent e74a63d commit e56d748
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 85 deletions.
19 changes: 16 additions & 3 deletions crates/polars-plan/src/dsl/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ impl MetaNameSpace {
}
}

/// Intersect this selector with `other`, producing a new selector expression.
///
/// `other` may itself be a selector; any other expression is wrapped in
/// `Selector::Root` before being combined.
///
/// # Errors
/// Returns a `ComputeError` when the expression held by this namespace is
/// not an `Expr::Selector`.
pub fn _selector_and(self, other: Expr) -> PolarsResult<Expr> {
    match self.0 {
        Expr::Selector(lhs) => {
            // Normalise the right-hand side to a selector first.
            let rhs = match other {
                Expr::Selector(sel) => sel,
                expr => Selector::Root(Box::new(expr)),
            };
            Ok(Expr::Selector(lhs & rhs))
        },
        not_selector => {
            polars_bail!(ComputeError: "expected selector, got {:?}", not_selector)
        },
    }
}

pub fn _selector_sub(self, other: Expr) -> PolarsResult<Expr> {
if let Expr::Selector(mut s) = self.0 {
if let Expr::Selector(s_other) = other {
Expand All @@ -122,12 +135,12 @@ impl MetaNameSpace {
}
}

pub fn _selector_and(self, other: Expr) -> PolarsResult<Expr> {
pub fn _selector_xor(self, other: Expr) -> PolarsResult<Expr> {
if let Expr::Selector(mut s) = self.0 {
if let Expr::Selector(s_other) = other {
s = s.bitand(s_other);
s = s ^ s_other;
} else {
s = s.bitand(Selector::Root(Box::new(other)))
s = s ^ Selector::Root(Box::new(other))
}
Ok(Expr::Selector(s))
} else {
Expand Down
24 changes: 17 additions & 7 deletions crates/polars-plan/src/dsl/selector.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, BitAnd, Sub};
use std::ops::{Add, BitAnd, BitXor, Sub};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand All @@ -10,6 +10,7 @@ use super::*;
/// Expression tree describing a set operation over column selectors.
///
/// Every binary variant boxes its left-hand and right-hand operand, in that order.
pub enum Selector {
/// Produced by the `+` operator (`Add`); presumably the union of both operands' columns —
/// confirm against the expansion code.
Add(Box<Selector>, Box<Selector>),
/// Produced by the `-` operator (`Sub`); expansion fills the left-hand members and then
/// subtracts the right-hand ones.
Sub(Box<Selector>, Box<Selector>),
/// Produced by the `^` operator (`BitXor`); expansion takes the bitxor of both member sets.
ExclusiveOr(Box<Selector>, Box<Selector>),
/// Produced by the `&` operator (`BitAnd`); expansion takes the intersection of both member sets.
InterSect(Box<Selector>, Box<Selector>),
/// Leaf node wrapping a plain expression (used when a non-selector operand is combined
/// with a selector).
Root(Box<Expr>),
}
Expand All @@ -29,20 +30,29 @@ impl Add for Selector {
}
}

impl Sub for Selector {
impl BitAnd for Selector {
type Output = Selector;

#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self::Output {
Selector::Sub(Box::new(self), Box::new(rhs))
fn bitand(self, rhs: Self) -> Self::Output {
Selector::InterSect(Box::new(self), Box::new(rhs))
}
}

impl BitAnd for Selector {
impl BitXor for Selector {
type Output = Selector;

#[allow(clippy::suspicious_arithmetic_impl)]
fn bitand(self, rhs: Self) -> Self::Output {
Selector::InterSect(Box::new(self), Box::new(rhs))
fn bitxor(self, rhs: Self) -> Self::Output {
Selector::ExclusiveOr(Box::new(self), Box::new(rhs))
}
}

impl Sub for Selector {
type Output = Selector;

#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self::Output {
Selector::Sub(Box::new(self), Box::new(rhs))
}
}
66 changes: 42 additions & 24 deletions crates/polars-plan/src/logical_plan/conversion/expr_expansion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
//! this contains code used for rewriting projections, expanding wildcards, regex selection etc.
use std::ops::BitXor;

use super::*;

pub(crate) fn prepare_projection(
Expand Down Expand Up @@ -787,11 +789,27 @@ fn replace_selector_inner(
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;
members.extend(rhs_members)
},
Selector::ExclusiveOr(lhs, rhs) => {
let mut lhs_members = Default::default();
replace_selector_inner(*lhs, &mut lhs_members, scratch, schema, keys)?;

let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

let xor_members = lhs_members.bitxor(&rhs_members);
*members = xor_members;
},
Selector::InterSect(lhs, rhs) => {
replace_selector_inner(*lhs, members, scratch, schema, keys)?;

let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

*members = members.intersection(&rhs_members).cloned().collect()
},
Selector::Sub(lhs, rhs) => {
// fill lhs
replace_selector_inner(*lhs, members, scratch, schema, keys)?;

// subtract rhs
let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

Expand All @@ -801,19 +819,8 @@ fn replace_selector_inner(
new_members.insert(e);
}
}

*members = new_members;
},
Selector::InterSect(lhs, rhs) => {
// fill lhs
replace_selector_inner(*lhs, members, scratch, schema, keys)?;

// fill rhs
let mut rhs_members = Default::default();
replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?;

*members = members.intersection(&rhs_members).cloned().collect()
},
}
Ok(())
}
Expand All @@ -829,17 +836,28 @@ fn replace_selector(expr: Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult<
let mut members = PlIndexSet::new();
replace_selector_inner(swapped, &mut members, &mut vec![], schema, keys)?;

Ok(Expr::Columns(
members
.into_iter()
.map(|e| {
let Expr::Column(name) = e else {
unreachable!()
};
name
})
.collect(),
))
if members.len() <= 1 {
Ok(Expr::Columns(
members
.into_iter()
.map(|e| {
let Expr::Column(name) = e else {
unreachable!()
};
name
})
.collect(),
))
} else {
// Ensure that multiple columns returned from combined/nested selectors remain in schema order
let selected = schema
.iter_fields()
.map(|field| ColumnName::from(field.name().as_ref()))
.filter(|field_name| members.contains(&Expr::Column(field_name.clone())))
.collect();

Ok(Expr::Columns(selected))
}
},
e => Ok(e),
})
Expand Down
12 changes: 11 additions & 1 deletion py-polars/docs/source/reference/selectors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ Importing
Set operations
--------------

Selectors support ``set`` operations such as:
Selectors support the following ``set`` operations:

- UNION: ``A | B``
- INTERSECTION: ``A & B``
- DIFFERENCE: ``A - B``
- EXCLUSIVE OR: ``A ^ B``
- COMPLEMENT: ``~A``

Note that both individual selector results and selector set operations will always return
matching columns in the same order as the underlying frame schema.

Examples
========
Expand Down Expand Up @@ -88,6 +91,13 @@ Examples
"Lmn": pl.Duration,
}
# Select the EXCLUSIVE OR of numeric columns and columns that contain an "e"
assert df.select(cs.contains("e") ^ cs.numeric()).schema == {
"abc": pl.UInt16,
"bbb": pl.UInt32,
"eee": pl.Boolean,
}
# Select the COMPLEMENT of all columns of dtypes Duration and Time
assert df.select(~cs.by_dtype([pl.Duration, pl.Time])).schema == {
"abc": pl.UInt16,
Expand Down
14 changes: 9 additions & 5 deletions py-polars/polars/expr/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,20 @@ def _as_selector(self) -> Expr:
return wrap_expr(self._pyexpr._meta_as_selector())

def _selector_add(self, other: Expr) -> Expr:
"""Add selectors."""
"""Add ('+') selectors."""
return wrap_expr(self._pyexpr._meta_selector_add(other._pyexpr))

def _selector_and(self, other: Expr) -> Expr:
"""And ('&') selectors."""
return wrap_expr(self._pyexpr._meta_selector_and(other._pyexpr))

def _selector_sub(self, other: Expr) -> Expr:
"""Subtract selectors."""
"""Subtract ('-') selectors."""
return wrap_expr(self._pyexpr._meta_selector_sub(other._pyexpr))

def _selector_and(self, other: Expr) -> Expr:
"""& selectors."""
return wrap_expr(self._pyexpr._meta_selector_and(other._pyexpr))
def _selector_xor(self, other: Expr) -> Expr:
"""Xor ('^') selectors."""
return wrap_expr(self._pyexpr._meta_selector_xor(other._pyexpr))

@overload
def serialize(self, file: None = ...) -> str: ...
Expand Down
35 changes: 29 additions & 6 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,15 +308,15 @@ def __repr__(self) -> str:
elif hasattr(self, "_repr_override"):
return self._repr_override
else:
selector_name, params = self._attrs["name"], self._attrs["params"]
set_ops = {"and": "&", "or": "|", "sub": "-"}
selector_name, params = self._attrs["name"], self._attrs["params"] or {}
set_ops = {"and": "&", "or": "|", "sub": "-", "xor": "^"}
if selector_name in set_ops:
op = set_ops[selector_name]
return "({})".format(f" {op} ".join(repr(p) for p in params.values()))
else:
str_params = ", ".join(
(repr(v)[1:-1] if k.startswith("*") else f"{k}={v!r}")
for k, v in (params or {}).items()
for k, v in params.items()
).rstrip(",")
return f"cs.{selector_name}({str_params})"

Expand Down Expand Up @@ -381,6 +381,24 @@ def __or__(self, other: Any) -> SelectorType | Expr:
else:
return self.as_expr().__or__(other)

@overload # type: ignore[override]
def __xor__(self, other: SelectorType) -> SelectorType: ...

@overload
def __xor__(self, other: Any) -> Expr: ...

def __xor__(self, other: Any) -> SelectorType | Expr:
    """
    Apply the ``^`` operator to this selector.

    If `other` is a column reference it is first converted into a by-name
    selector. When `other` is (or has become) a selector, the result is a
    new "xor" selector combining both operands; otherwise the selector is
    materialised as an expression and a regular expression-level XOR is
    applied.
    """
    if is_column(other):
        # Promote a plain column to a selector so set semantics apply.
        other = by_name(other.meta.output_name())
    if is_selector(other):
        return _selector_proxy_(
            self.meta._as_selector().meta._selector_xor(other),
            parameters={"self": self, "other": other},
            name="xor",
        )
    else:
        # Must dispatch to `__xor__`: the fallback previously called
        # `__or__`, silently computing OR for `selector ^ <non-selector>`.
        return self.as_expr().__xor__(other)

def __rand__(self, other: Any) -> Expr: # type: ignore[override]
if is_column(other):
colname = other.meta.output_name()
Expand All @@ -396,6 +414,11 @@ def __ror__(self, other: Any) -> Expr: # type: ignore[override]
other = by_name(other.meta.output_name())
return self.as_expr().__ror__(other)

def __rxor__(self, other: Any) -> Expr:  # type: ignore[override]
    """
    Apply the reflected ``^`` operator (``other ^ selector``).

    A column operand is converted to a by-name selector first; the selector
    is then materialised as an expression and the reflected XOR is delegated
    to it.
    """
    operand = by_name(other.meta.output_name()) if is_column(other) else other
    materialised = self.as_expr()
    return materialised.__rxor__(operand)

def as_expr(self) -> Expr:
"""
Materialize the `selector` as a normal expression.
Expand Down Expand Up @@ -1149,7 +1172,7 @@ def categorical() -> SelectorType:
return _selector_proxy_(F.col(Categorical), name="categorical")


def contains(substring: str | Collection[str]) -> SelectorType:
def contains(*substring: str) -> SelectorType:
"""
Select columns whose names contain the given literal substring(s).
Expand Down Expand Up @@ -1191,7 +1214,7 @@ def contains(substring: str | Collection[str]) -> SelectorType:
Select columns that contain the substring 'ba' or the letter 'z':
>>> df.select(cs.contains(("ba", "z")))
>>> df.select(cs.contains("ba", "z"))
shape: (2, 3)
┌─────┬─────┬───────┐
│ bar ┆ baz ┆ zap │
Expand Down Expand Up @@ -1221,7 +1244,7 @@ def contains(substring: str | Collection[str]) -> SelectorType:
return _selector_proxy_(
F.col(raw_params),
name="contains",
parameters={"substring": escaped_substring},
parameters={"*substring": escaped_substring},
)


Expand Down
14 changes: 12 additions & 2 deletions py-polars/src/expr/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ impl PyExpr {
Ok(out.into())
}

/// Python-facing wrapper around `MetaNameSpace::_selector_and`.
///
/// Clones the inner expression, combines it with `other` and converts the
/// outcome back into a `PyExpr`; a polars error is converted through
/// `PyPolarsErr` before being returned.
fn _meta_selector_and(&self, other: PyExpr) -> PyResult<PyExpr> {
    let lhs = self.inner.clone();
    match lhs.meta()._selector_and(other.inner) {
        Ok(combined) => Ok(combined.into()),
        Err(err) => Err(PyPolarsErr::from(err).into()),
    }
}

fn _meta_selector_sub(&self, other: PyExpr) -> PyResult<PyExpr> {
let out = self
.inner
Expand All @@ -81,12 +91,12 @@ impl PyExpr {
Ok(out.into())
}

fn _meta_selector_and(&self, other: PyExpr) -> PyResult<PyExpr> {
fn _meta_selector_xor(&self, other: PyExpr) -> PyResult<PyExpr> {
let out = self
.inner
.clone()
.meta()
._selector_and(other.inner)
._selector_xor(other.inner)
.map_err(PyPolarsErr::from)?;
Ok(out.into())
}
Expand Down
Loading

0 comments on commit e56d748

Please sign in to comment.