Skip to content

Commit

Permalink
feat!: Do not propagate nulls in clip bounds (#14413)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jun 4, 2024
1 parent 62a1577 commit 3c9f984
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 49 deletions.
90 changes: 58 additions & 32 deletions crates/polars-ops/src/series/ops/clip.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use num_traits::{clamp, clamp_max, clamp_min};
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise};
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
Expand All @@ -25,7 +24,7 @@ pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult<Series> {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref();
let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref();
let out = clip_helper(ca, min, max).into_series();
let out = clip_helper_both_bounds(ca, min, max).into_series();
if original_type.is_logical() {
out.cast(original_type)
} else {
Expand Down Expand Up @@ -54,7 +53,7 @@ pub fn clip_max(s: &Series, max: &Series) -> PolarsResult<Series> {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref();
let out = clip_min_max_helper(ca, max, clamp_max).into_series();
let out = clip_helper_single_bound(ca, max, num_traits::clamp_max).into_series();
if original_type.is_logical() {
out.cast(original_type)
} else {
Expand Down Expand Up @@ -83,7 +82,7 @@ pub fn clip_min(s: &Series, min: &Series) -> PolarsResult<Series> {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref();
let out = clip_min_max_helper(ca, min, clamp_min).into_series();
let out = clip_helper_single_bound(ca, min, num_traits::clamp_min).into_series();
if original_type.is_logical() {
out.cast(original_type)
} else {
Expand All @@ -95,7 +94,7 @@ pub fn clip_min(s: &Series, min: &Series) -> PolarsResult<Series> {
}
}

fn clip_helper<T>(
fn clip_helper_both_bounds<T>(
ca: &ChunkedArray<T>,
min: &ChunkedArray<T>,
max: &ChunkedArray<T>,
Expand All @@ -106,35 +105,24 @@ where
{
match (min.len(), max.len()) {
(1, 1) => match (min.get(0), max.get(0)) {
(Some(min), Some(max)) => {
ca.apply_generic(|s| s.map(|s| num_traits::clamp(s, min, max)))
},
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
(Some(min), Some(max)) => clip_unary(ca, |v| num_traits::clamp(v, min, max)),
(Some(min), None) => clip_unary(ca, |v| num_traits::clamp_min(v, min)),
(None, Some(max)) => clip_unary(ca, |v| num_traits::clamp_max(v, max)),
(None, None) => ca.clone(),
},
(1, _) => match min.get(0) {
Some(min) => binary_elementwise(ca, max, |opt_s, opt_max| match (opt_s, opt_max) {
(Some(s), Some(max)) => Some(clamp(s, min, max)),
_ => None,
}),
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
Some(min) => clip_binary(ca, max, |v, b| num_traits::clamp(v, min, b)),
None => clip_binary(ca, max, num_traits::clamp_max),
},
(_, 1) => match max.get(0) {
Some(max) => binary_elementwise(ca, min, |opt_s, opt_min| match (opt_s, opt_min) {
(Some(s), Some(min)) => Some(clamp(s, min, max)),
_ => None,
}),
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
Some(max) => clip_binary(ca, min, |v, b| num_traits::clamp(v, b, max)),
None => clip_binary(ca, min, num_traits::clamp_min),
},
_ => ternary_elementwise(ca, min, max, |opt_s, opt_min, opt_max| {
match (opt_s, opt_min, opt_max) {
(Some(s), Some(min), Some(max)) => Some(clamp(s, min, max)),
_ => None,
}
}),
_ => clip_ternary(ca, min, max),
}
}

fn clip_min_max_helper<T, F>(
fn clip_helper_single_bound<T, F>(
ca: &ChunkedArray<T>,
bound: &ChunkedArray<T>,
op: F,
Expand All @@ -146,12 +134,50 @@ where
{
match bound.len() {
1 => match bound.get(0) {
Some(bound) => ca.apply_generic(|s| s.map(|s| op(s, bound))),
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
Some(bound) => clip_unary(ca, |v| op(v, bound)),
None => ca.clone(),
},
_ => binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) {
(Some(s), Some(bound)) => Some(op(s, bound)),
_ => None,
}),
_ => clip_binary(ca, bound, op),
}
}

fn clip_unary<T, F>(ca: &ChunkedArray<T>, op: F) -> ChunkedArray<T>
where
T: PolarsNumericType,
F: Fn(T::Native) -> T::Native + Copy,
{
ca.apply_generic(|v| v.map(op))
}

fn clip_binary<T, F>(ca: &ChunkedArray<T>, bound: &ChunkedArray<T>, op: F) -> ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: PartialOrd,
F: Fn(T::Native, T::Native) -> T::Native,
{
binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) {
(Some(s), Some(bound)) => Some(op(s, bound)),
(Some(s), None) => Some(s),
(None, _) => None,
})
}

fn clip_ternary<T>(
ca: &ChunkedArray<T>,
min: &ChunkedArray<T>,
max: &ChunkedArray<T>,
) -> ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: PartialOrd,
{
ternary_elementwise(ca, min, max, |opt_v, opt_min, opt_max| {
match (opt_v, opt_min, opt_max) {
(Some(v), Some(min), Some(max)) => Some(num_traits::clamp(v, min, max)),
(Some(v), Some(min), None) => Some(num_traits::clamp_min(v, min)),
(Some(v), None, Some(max)) => Some(num_traits::clamp_max(v, max)),
(Some(v), None, None) => Some(v),
(None, _, _) => None,
}
})
}
34 changes: 17 additions & 17 deletions py-polars/tests/unit/operations/test_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ def clip_exprs() -> list[pl.Expr]:
def test_clip_int(clip_exprs: list[pl.Expr]) -> None:
lf = pl.LazyFrame(
{
"a": [1, 2, 3, 4, 5],
"min": [0, -1, 4, None, 4],
"max": [2, 1, 8, 5, None],
"a": [1, 2, 3, 4, 5, None],
"min": [0, -1, 4, None, 4, -10],
"max": [2, 1, 8, 5, None, 10],
}
)
result = lf.select(clip_exprs)
expected = pl.LazyFrame(
{
"clip": [1, 1, 4, None, None],
"clip_min": [1, 2, 4, None, 5],
"clip_max": [1, 1, 3, 4, None],
"clip": [1, 1, 4, 4, 5, None],
"clip_min": [1, 2, 4, 4, 5, None],
"clip_max": [1, 1, 3, 4, 5, None],
}
)
assert_frame_equal(result, expected)
Expand All @@ -39,17 +39,17 @@ def test_clip_int(clip_exprs: list[pl.Expr]) -> None:
def test_clip_float(clip_exprs: list[pl.Expr]) -> None:
lf = pl.LazyFrame(
{
"a": [1.0, 2.0, 3.0, 4.0, 5.0],
"min": [0.0, -1.0, 4.0, None, 4.0],
"max": [2.0, 1.0, 8.0, 5.0, None],
"a": [1.0, 2.0, 3.0, 4.0, 5.0, None],
"min": [0.0, -1.0, 4.0, None, 4.0, None],
"max": [2.0, 1.0, 8.0, 5.0, None, None],
}
)
result = lf.select(clip_exprs)
expected = pl.LazyFrame(
{
"clip": [1.0, 1.0, 4.0, None, None],
"clip_min": [1.0, 2.0, 4.0, None, 5.0],
"clip_max": [1.0, 1.0, 3.0, 4.0, None],
"clip": [1.0, 1.0, 4.0, 4.0, 5.0, None],
"clip_min": [1.0, 2.0, 4.0, 4.0, 5.0, None],
"clip_max": [1.0, 1.0, 3.0, 4.0, 5.0, None],
}
)
assert_frame_equal(result, expected)
Expand Down Expand Up @@ -92,15 +92,15 @@ def test_clip_datetime(clip_exprs: list[pl.Expr]) -> None:
datetime(1996, 6, 5),
datetime(2023, 9, 20, 18, 30, 6),
None,
None,
None,
datetime(1993, 3, 13),
datetime(2000, 1, 10),
],
"clip_min": [
datetime(1995, 6, 5, 10, 30),
datetime(1996, 6, 5),
datetime(2023, 10, 20, 18, 30, 6),
None,
None,
datetime(2023, 9, 24),
datetime(2000, 1, 10),
],
"clip_max": [
Expand All @@ -109,7 +109,7 @@ def test_clip_datetime(clip_exprs: list[pl.Expr]) -> None:
datetime(2023, 9, 20, 18, 30, 6),
None,
datetime(1993, 3, 13),
None,
datetime(2000, 1, 10),
],
}
)
Expand All @@ -127,7 +127,7 @@ def test_clip_non_numeric_dtype_fails() -> None:
def test_clip_string_input() -> None:
df = pl.DataFrame({"a": [0, 1, 2], "min": [1, None, 1]})
result = df.select(pl.col("a").clip("min"))
expected = pl.DataFrame({"a": [1, None, 2]})
expected = pl.DataFrame({"a": [1, 1, 2]})
assert_frame_equal(result, expected)


Expand Down

0 comments on commit 3c9f984

Please sign in to comment.