Skip to content

Commit

Permalink
feat(python): emit suggestion for how to replace map_elements sigmoid…
Browse files Browse the repository at this point in the history
… function with expressions (#13347)
  • Loading branch information
MarcoGorelli authored Jan 2, 2024
1 parent ffa3fc6 commit 2f0a4da
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
27 changes: 24 additions & 3 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class OpNames:
"upper": "str.to_uppercase",
}

FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [
_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [
# lambda x: module.func(CONSTANT)
{
"argument_1_opname": [{"LOAD_CONST"}],
Expand Down Expand Up @@ -192,6 +192,15 @@ class OpNames:
"function_name": [{"strptime"}],
},
]
# In addition to `lambda x: func(x)`, also support cases when a unary operation
# has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`.
_FUNCTION_KINDS = [
# Dict entry 1 has incompatible type "str": "object";
# expected "str": "list[AbstractSet[str]]"
{**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item]
for kind in _FUNCTION_KINDS
for unary in [[set(OpNames.UNARY)], []]
]


def _get_all_caller_variables() -> dict[str, Any]:
Expand Down Expand Up @@ -518,6 +527,9 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str
if e1.startswith("pl.col("):
call = "" if op.endswith(")") else "()"
return f"{e1}.{op}{call}"
if e1[0] in OpNames.UNARY_VALUES and e1[1:].startswith("pl.col("):
call = "" if op.endswith(")") else "()"
return f"({e1}).{op}{call}"

# support use of consts as numpy/builtin params, eg:
# "np.sin(3) + np.cos(x)", or "len('const_string') + len(x)"
Expand Down Expand Up @@ -722,12 +734,13 @@ def _rewrite_functions(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace function calls with a synthetic POLARS_EXPRESSION op."""
for function_kind in FUNCTION_KINDS:
for function_kind in _FUNCTION_KINDS:
opnames: list[AbstractSet[str]] = [
{"LOAD_GLOBAL", "LOAD_DEREF"},
*function_kind["module_opname"],
*function_kind["attribute_opname"],
*function_kind["argument_1_opname"],
*function_kind["argument_1_unary_opname"],
*function_kind["argument_2_opname"],
OpNames.CALL,
]
Expand Down Expand Up @@ -766,7 +779,15 @@ def _rewrite_functions(
)
# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst3._replace(offset=inst1.offset)
updated_instructions.extend((operand, synthetic_call))
updated_instructions.extend(
(
operand,
matching_instructions[3 + attribute_count],
synthetic_call,
)
if function_kind["argument_1_unary_opname"]
else (operand, synthetic_call)
)
return len(matching_instructions)

return 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@
"lambda x: (float(x) * int(x)) // 2",
'(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',
),
(
"a",
"lambda x: 1 / (1 + np.exp(-x))",
'1 / (1 + (-pl.col("a")).exp())',
),
# ---------------------------------------------
# numpy
# ---------------------------------------------
Expand Down

0 comments on commit 2f0a4da

Please sign in to comment.