diff --git a/polars/polars-lazy/polars-plan/src/dsl/functions.rs b/polars/polars-lazy/polars-plan/src/dsl/functions.rs index e7075298f82e..5fae923aed18 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/functions.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/functions.rs @@ -899,51 +899,26 @@ pub fn sum_exprs>(exprs: E) -> Expr { /// Get the the maximum value per row pub fn max_exprs>(exprs: E) -> Expr { let exprs = exprs.as_ref().to_vec(); - max_exprs_impl(exprs) -} - -fn max_exprs_impl(mut exprs: Vec) -> Expr { - if exprs.len() == 1 { - return std::mem::take(&mut exprs[0]); + if exprs.is_empty() { + return Expr::Columns(Vec::new()); } - - let first = std::mem::take(&mut exprs[0]); - first - .map_many( - |s| { - let s = s.to_vec(); - let df = DataFrame::new_no_checks(s); - df.hmax().map(|s| s.unwrap()) - }, - &exprs[1..], - GetOutput::super_type(), - ) - .alias("max") + let func = |s1, s2| { + let df = DataFrame::new_no_checks(vec![s1, s2]); + df.hmax().map(|s| s.unwrap()) + }; + reduce_exprs(func, exprs).alias("max") } -/// Get the the minimum value per row pub fn min_exprs>(exprs: E) -> Expr { let exprs = exprs.as_ref().to_vec(); - min_exprs_impl(exprs) -} - -fn min_exprs_impl(mut exprs: Vec) -> Expr { - if exprs.len() == 1 { - return std::mem::take(&mut exprs[0]); + if exprs.is_empty() { + return Expr::Columns(Vec::new()); } - - let first = std::mem::take(&mut exprs[0]); - first - .map_many( - |s| { - let s = s.to_vec(); - let df = DataFrame::new_no_checks(s); - df.hmin().map(|s| s.unwrap()) - }, - &exprs[1..], - GetOutput::super_type(), - ) - .alias("min") + let func = |s1, s2| { + let df = DataFrame::new_no_checks(vec![s1, s2]); + df.hmin().map(|s| s.unwrap()) + }; + reduce_exprs(func, exprs).alias("min") } /// Evaluate all the expressions with a bitwise or diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 0ecb2527818c..5bc98fab0439 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1283,6 +1283,23 @@ def test_max_min_multiple_columns(fruits_cars: pl.DataFrame) -> None: assert res.to_series(0).series_equal(pl.Series("min", [1, 2, 3, 2, 1])) +def test_max_min_wildcard_columns(fruits_cars: pl.DataFrame) -> None: + res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.min(["*"])) + assert res.to_series(0).series_equal(pl.Series("min", [1, 2, 3, 2, 1])) + res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.min([pl.all()])) + assert res.to_series(0).series_equal(pl.Series("min", [1, 2, 3, 2, 1])) + + res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.max(["*"])) + assert res.to_series(0).series_equal(pl.Series("max", [5, 4, 3, 4, 5])) + res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select(pl.max([pl.all()])) + assert res.to_series(0).series_equal(pl.Series("max", [5, 4, 3, 4, 5])) + + res = fruits_cars.select([pl.col(pl.datatypes.Int64)]).select( + pl.max([pl.all(), "A", "*"]) + ) + assert res.to_series(0).series_equal(pl.Series("max", [5, 4, 3, 4, 5])) + + def test_head_tail(fruits_cars: pl.DataFrame) -> None: res_expr = fruits_cars.select([pl.head("A", 2)]) res_series = pl.head(fruits_cars["A"], 2)