Skip to content

Commit

Permalink
Make function spec explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
zero323 committed Jan 31, 2020
1 parent 243794a commit 4e95e12
Showing 1 changed file with 45 additions and 21 deletions.
66 changes: 45 additions & 21 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sys
import functools
import warnings
from collections import namedtuple

if sys.version < "3":
from itertools import imap as map
Expand Down Expand Up @@ -2903,7 +2904,7 @@ def _get_lambda_parameters_legacy(f):
return spec.args


def _create_lambda(f, allowed_arities):
def _create_lambda(f, expected_nargs):
"""
Create `o.a.s.sql.expressions.LambdaFunction` corresponding
to transformation described by f
Expand All @@ -2912,17 +2913,17 @@ def _create_lambda(f, allowed_arities):
- (Column) -> Column: ...
- (Column, Column) -> Column: ...
- (Column, Column, Column) -> Column: ...
:param allowed_arities: Set[int] Expected arities
:param expected_nargs: Set[int] Expected arities
"""
if sys.version_info >= (3, 3):
parameters = _get_lambda_parameters(f)
else:
parameters = _get_lambda_parameters_legacy(f)

if len(parameters) not in allowed_arities:
if len(parameters) not in expected_nargs:
raise ValueError(
"""f arity expected to be in {} but is {}""".format(
allowed_arities, len(parameters)
expected_nargs, len(parameters)
)
)

Expand All @@ -2945,16 +2946,20 @@ def _create_lambda(f, allowed_arities):
return expressions.LambdaFunction(jexpr, jargs, False)


def _invoke_higher_order_function(name, cols, funs):
# The first element should be a Python function (*Column) -> Column
# The second element is a Set[int] used to validate arity
_LambdaSpec = namedtuple("_LambdaSpec", ["func", "expected_nargs"])


def _invoke_higher_order_function(name, cols, fun_specs):
"""
Invokes expression identified by name,
(relative to ```org.apache.spark.sql.catalyst.expressions``)
and wraps the result with Column (first Scala one, then Python).
:param name: Name of the expression
:param cols: a list of columns
:param funs: a list of tuples ((*Column) -> Column, Iterable[int])
where the second element represent allowed arities
:param fun_specs: a list of _LambdaSpec objects
:return: a Column
"""
Expand All @@ -2963,7 +2968,7 @@ def _invoke_higher_order_function(name, cols, funs):
expr = getattr(expressions, name)

jcols = [_to_java_column(col).expr() for col in cols]
jfuns = [_create_lambda(f, a) for f, a in funs]
jfuns = [_create_lambda(spec.func, spec.expected_nargs) for spec in fun_specs]

return Column(sc._jvm.Column(expr(*jcols + jfuns)))

Expand Down Expand Up @@ -3005,7 +3010,9 @@ def transform(col, f):
|[1, -2, 3, -4]|
+--------------+
"""
return _invoke_higher_order_function("ArrayTransform", [col], [(f, {1, 2})])
return _invoke_higher_order_function(
"ArrayTransform", [col], [_LambdaSpec(func=f, expected_nargs={1, 2})]
)


@since(3.0)
Expand All @@ -3030,7 +3037,9 @@ def exists(col, f):
| true|
+------------+
"""
return _invoke_higher_order_function("ArrayExists", [col], [(f, {1})])
return _invoke_higher_order_function(
"ArrayExists", [col], [_LambdaSpec(func=f, expected_nargs={1})]
)


@since(3.0)
Expand Down Expand Up @@ -3059,7 +3068,9 @@ def forall(col, f):
| true|
+-------+
"""
return _invoke_higher_order_function("ArrayForAll", [col], [(f, {1})])
return _invoke_higher_order_function(
"ArrayForAll", [col], [_LambdaSpec(func=f, expected_nargs={1})]
)


@since(3.0)
Expand Down Expand Up @@ -3097,7 +3108,9 @@ def filter(col, f):
|[2018-09-20, 2019-07-01]|
+------------------------+
"""
return _invoke_higher_order_function("ArrayFilter", [col], [(f, {1, 2})])
return _invoke_higher_order_function(
"ArrayFilter", [col], [_LambdaSpec(func=f, expected_nargs={1, 2})]
)


@since(3.0)
Expand Down Expand Up @@ -3150,14 +3163,15 @@ def aggregate(col, zero, merge, finish=None):
return _invoke_higher_order_function(
"ArrayAggregate",
[col, zero],
[(merge, {2}), (finish, {1})]
[
_LambdaSpec(func=merge, expected_nargs={2}),
_LambdaSpec(func=finish, expected_nargs={1}),
],
)

else:
return _invoke_higher_order_function(
"ArrayAggregate",
[col, zero],
[(merge, {2})]
"ArrayAggregate", [col, zero], [_LambdaSpec(func=merge, expected_nargs={2})]
)


Expand Down Expand Up @@ -3193,7 +3207,9 @@ def zip_with(col1, col2, f):
|[foo_1, bar_2, 3]|
+-----------------+
"""
return _invoke_higher_order_function("ZipWith", [col1, col2], [(f, {2})])
return _invoke_higher_order_function(
"ZipWith", [col1, col2], [_LambdaSpec(func=f, expected_nargs={2})]
)


@since(3.0)
Expand All @@ -3220,7 +3236,9 @@ def transform_keys(col, f):
|[BAR -> 2.0, FOO -> -2.0]|
+-------------------------+
"""
return _invoke_higher_order_function("TransformKeys", [col], [(f, {2})])
return _invoke_higher_order_function(
"TransformKeys", [col], [_LambdaSpec(func=f, expected_nargs={2})]
)


@since(3.0)
Expand All @@ -3247,7 +3265,9 @@ def transform_values(col, f):
|[OPS -> 34.0, IT -> 20.0, SALES -> 2.0]|
+---------------------------------------+
"""
return _invoke_higher_order_function("TransformValues", [col], [(f, {2})])
return _invoke_higher_order_function(
"TransformValues", [col], [_LambdaSpec(func=f, expected_nargs={2})]
)


@since(3.0)
Expand All @@ -3273,7 +3293,9 @@ def map_filter(col, f):
|[baz -> 32.0, foo -> 42.0]|
+--------------------------+
"""
return _invoke_higher_order_function("MapFilter", [col], [(f, {2})])
return _invoke_higher_order_function(
"MapFilter", [col], [_LambdaSpec(func=f, expected_nargs={2})]
)


@since(3.0)
Expand Down Expand Up @@ -3303,7 +3325,9 @@ def map_zip_with(col1, col2, f):
|[SALES -> 16.8, IT -> 48.0]|
+---------------------------+
"""
return _invoke_higher_order_function("MapZipWith", [col1, col2], [(f, {3})])
return _invoke_higher_order_function(
"MapZipWith", [col1, col2], [_LambdaSpec(func=f, expected_nargs={3})]
)


# ---------------------------- User Defined Function ----------------------------------
Expand Down

0 comments on commit 4e95e12

Please sign in to comment.