feat: Expose returns_scalar to map_elements (#17613)
ritchie46 authored Jul 13, 2024
1 parent f304a0c commit 6816707
Showing 4 changed files with 50 additions and 4 deletions.
6 changes: 5 additions & 1 deletion crates/polars-plan/src/plans/aexpr/schema.rs
@@ -249,15 +249,19 @@ impl AExpr {
                 options,
                 ..
             } => {
+                *nested = nested.saturating_sub(options.returns_scalar as _);
                 let tmp = function.get_output();
                 let output_type = tmp.as_ref().unwrap_or(output_type);
                 let fields = func_args_to_fields(input, schema, arena, nested)?;
                 polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str);
                 output_type.get_field(schema, Context::Default, &fields)
             },
             Function {
-                function, input, ..
+                function,
+                input,
+                options,
             } => {
+                *nested = nested.saturating_sub(options.returns_scalar as _);
                 let fields = func_args_to_fields(input, schema, arena, nested)?;
                 polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
                 function.get_field(schema, Context::Default, &fields)
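Both match arms now decrement the tracked nesting level when the function is flagged as returning a scalar: inside an aggregation each input arrives as a list, and a reducing UDF strips one level of nesting from the output, so the resolved field must too. A rough Python analogue of the saturating_sub bookkeeping (illustrative only; the names mirror the Rust fields):

def resolve_nesting(nested: int, returns_scalar: bool) -> int:
    # Mirrors `nested.saturating_sub(options.returns_scalar as _)`:
    # subtract one level when the UDF reduces to a scalar, never below zero.
    return max(nested - int(returns_scalar), 0)

# Inside group_by().agg() the inputs are lists (nested == 1); a reducing
# UDF unwraps that level, so the output field is Float64, not List(Float64).
assert resolve_nesting(1, returns_scalar=True) == 0   # one scalar per group
assert resolve_nesting(1, returns_scalar=False) == 1  # one list per group
assert resolve_nesting(0, returns_scalar=True) == 0   # saturates at zero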
4 changes: 4 additions & 0 deletions py-polars/polars/datatypes/_parse.py
@@ -21,6 +21,7 @@
     Object,
     String,
     Time,
+    Unknown,
 )
 from polars.datatypes.convert import is_polars_dtype

@@ -93,6 +94,9 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsDataType:
         return Null()
     elif input is list or input is tuple:
         return List
+    # this is required as pass through. Don't remove
+    elif input == Unknown:
+        return Unknown

     elif hasattr(input, "__origin__") and hasattr(input, "__args__"):
         return _parse_generic_into_dtype(input)
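The pass-through matters because a UDF with no declared return_dtype carries an Unknown output type through the planner, and the parser must hand it back unchanged rather than raise. A minimal round-trip sketch using the private helper this diff touches (internal API, shown only for illustration):

from polars.datatypes import Unknown
from polars.datatypes._parse import parse_py_type_into_dtype

# The new branch lets Unknown pass straight through, so an undeclared
# UDF return type stays Unknown in the plan instead of raising here.
assert parse_py_type_into_dtype(Unknown) == Unknown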
22 changes: 19 additions & 3 deletions py-polars/polars/expr/expr.py
@@ -4413,6 +4413,7 @@ def map_elements(
         skip_nulls: bool = True,
         pass_name: bool = False,
         strategy: MapElementsStrategy = "thread_local",
+        returns_scalar: bool = False,
     ) -> Expr:
         """
         Map a custom/user-defined function (UDF) to each element of a column.
@@ -4459,6 +4460,10 @@
             Don't map the function over values that contain nulls (this is faster).
         pass_name
             Pass the Series name to the custom function (this is more expensive).
+        returns_scalar
+            If the function passed does a reduction
+            (e.g. sum, min, etc), Polars must be informed of this otherwise
+            the schema might be incorrect.
         strategy : {'thread_local', 'threading'}
             The threading strategy to use.
@@ -4637,14 +4642,22 @@ def wrap_f(x: Series) -> Series:  # pragma: no cover
             )

         if strategy == "thread_local":
-            return self.map_batches(wrap_f, agg_list=True, return_dtype=return_dtype)
+            return self.map_batches(
+                wrap_f,
+                agg_list=True,
+                return_dtype=return_dtype,
+                returns_scalar=returns_scalar,
+            )
         elif strategy == "threading":

             def wrap_threading(x: Series) -> Series:
                 def get_lazy_promise(df: DataFrame) -> LazyFrame:
                     return df.lazy().select(
                         F.col("x").map_batches(
-                            wrap_f, agg_list=True, return_dtype=return_dtype
+                            wrap_f,
+                            agg_list=True,
+                            return_dtype=return_dtype,
+                            returns_scalar=returns_scalar,
                         )
                     )
@@ -4678,7 +4691,10 @@ def get_lazy_promise(df: DataFrame) -> LazyFrame:
                 return F.concat(out, rechunk=False)

             return self.map_batches(
-                wrap_threading, agg_list=True, return_dtype=return_dtype
+                wrap_threading,
+                agg_list=True,
+                return_dtype=return_dtype,
+                returns_scalar=returns_scalar,
             )
         else:
             msg = f"strategy {strategy!r} is not supported"
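For users, the new keyword matters mainly inside group_by().agg(), where the planner would otherwise infer a list column for a UDF that actually reduces each group. A hedged sketch of the contrast (column names are illustrative; the returns_scalar=True case mirrors the test below):

import polars as pl

lf = pl.LazyFrame({"g": ["a", "a", "b"], "x": [1.0, 2.0, 3.0]})

def total(s: pl.Series) -> float:
    return float(s.sum())  # reduces each group's Series to one float

# Without the flag, the planner keeps one level of list nesting, so the
# lazy schema for "t" is expected to be List(Float64) inside agg().
q_list = lf.group_by("g").agg(
    pl.col("x").map_elements(total, return_dtype=pl.Float64).alias("t")
)

# With returns_scalar=True, the inferred dtype drops the list level.
q_scalar = lf.group_by("g").agg(
    pl.col("x")
    .map_elements(total, return_dtype=pl.Float64, returns_scalar=True)
    .alias("t")
)

print(q_list.collect_schema())    # expected: {'g': String, 't': List(Float64)}
print(q_scalar.collect_schema())  # expected: {'g': String, 't': Float64}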
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_schema.py
@@ -43,3 +43,25 @@ def test_schema_picklable() -> None:
     s2 = pickle.loads(pickled)

     assert s == s2
+
+
+def test_schema_in_map_elements_returns_scalar() -> None:
+    schema = pl.Schema([("portfolio", pl.String()), ("irr", pl.Float64())])
+
+    ldf = pl.LazyFrame(
+        {
+            "portfolio": ["A", "A", "B", "B"],
+            "amounts": [100.0, -110.0] * 2,
+        }
+    )
+
+    q = ldf.group_by("portfolio").agg(
+        pl.col("amounts")
+        .map_elements(
+            lambda x: float(x.sum()), return_dtype=pl.Float64, returns_scalar=True
+        )
+        .alias("irr")
+    )
+
+    assert (q.collect_schema()) == schema
+    assert q.collect().schema == schema
