diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8cbbaa43a76c2..d6661f2d0f4d2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -21,6 +21,7 @@ import sys import functools import warnings +from collections import namedtuple if sys.version < "3": from itertools import imap as map @@ -2903,7 +2904,7 @@ def _get_lambda_parameters_legacy(f): return spec.args -def _create_lambda(f, allowed_arities): +def _create_lambda(f, expected_nargs): """ Create `o.a.s.sql.expressions.LambdaFunction` corresponding to transformation described by f @@ -2912,17 +2913,17 @@ def _create_lambda(f, allowed_arities): - (Column) -> Column: ... - (Column, Column) -> Column: ... - (Column, Column, Column) -> Column: ... - :param allowed_arities: Set[int] Expected arities + :param expected_nargs: Set[int] Expected arities """ if sys.version_info >= (3, 3): parameters = _get_lambda_parameters(f) else: parameters = _get_lambda_parameters_legacy(f) - if len(parameters) not in allowed_arities: + if len(parameters) not in expected_nargs: raise ValueError( """f arity expected to be in {} but is {}""".format( - allowed_arities, len(parameters) + expected_nargs, len(parameters) ) ) @@ -2945,7 +2946,12 @@ def _create_lambda(f, allowed_arities): return expressions.LambdaFunction(jexpr, jargs, False) -def _invoke_higher_order_function(name, cols, funs): +# The first element should be Python function (*Column) -> Column +# The second argument is a Set[int] used to validate arity +_LambdaSpec = namedtuple("_LambdaSpec", ["func", "expected_nargs"]) + + +def _invoke_higher_order_function(name, cols, fun_specs): """ Invokes expression identified by name, (relative to ```org.apache.spark.sql.catalyst.expressions``) @@ -2953,8 +2959,7 @@ def _invoke_higher_order_function(name, cols, funs): :param name: Name of the expression :param cols: a list of columns - :param funs: a list of tuples ((*Column) -> Column, Iterable[int]) - where the second element represent allowed arities + :param fun_specs: a List[_LambdaSpec] objects :return: a Column """ @@ -2963,7 +2968,7 @@ def _invoke_higher_order_function(name, cols, funs): expr = getattr(expressions, name) jcols = [_to_java_column(col).expr() for col in cols] - jfuns = [_create_lambda(f, a) for f, a in funs] + jfuns = [_create_lambda(spec.func, spec.expected_nargs) for spec in fun_specs] return Column(sc._jvm.Column(expr(*jcols + jfuns))) @@ -3005,7 +3010,9 @@ def transform(col, f): |[1, -2, 3, -4]| +--------------+ """ - return _invoke_higher_order_function("ArrayTransform", [col], [(f, {1, 2})]) + return _invoke_higher_order_function( + "ArrayTransform", [col], [_LambdaSpec(func=f, expected_nargs={1, 2})] + ) @since(3.0) @@ -3030,7 +3037,9 @@ def exists(col, f): | true| +------------+ """ - return _invoke_higher_order_function("ArrayExists", [col], [(f, {1})]) + return _invoke_higher_order_function( + "ArrayExists", [col], [_LambdaSpec(func=f, expected_nargs={1})] + ) @since(3.0) @@ -3059,7 +3068,9 @@ def forall(col, f): | true| +-------+ """ - return _invoke_higher_order_function("ArrayForAll", [col], [(f, {1})]) + return _invoke_higher_order_function( + "ArrayForAll", [col], [_LambdaSpec(func=f, expected_nargs={1})] + ) @since(3.0) @@ -3097,7 +3108,9 @@ def filter(col, f): |[2018-09-20, 2019-07-01]| +------------------------+ """ - return _invoke_higher_order_function("ArrayFilter", [col], [(f, {1, 2})]) + return _invoke_higher_order_function( + "ArrayFilter", [col], [_LambdaSpec(func=f, expected_nargs={1, 2})] + ) @since(3.0) @@ -3150,14 +3163,15 @@ def aggregate(col, zero, merge, finish=None): return _invoke_higher_order_function( "ArrayAggregate", [col, zero], - [(merge, {2}), (finish, {1})] + [ + _LambdaSpec(func=merge, expected_nargs={2}), + _LambdaSpec(func=finish, expected_nargs={1}), + ], ) else: return _invoke_higher_order_function( - "ArrayAggregate", - [col, zero], - [(merge, {2})] + "ArrayAggregate", [col, zero], [_LambdaSpec(func=merge, expected_nargs={2})] ) @@ -3193,7 +3207,9 @@ def zip_with(col1, col2, f): |[foo_1, bar_2, 3]| +-----------------+ """ - return _invoke_higher_order_function("ZipWith", [col1, col2], [(f, {2})]) + return _invoke_higher_order_function( + "ZipWith", [col1, col2], [_LambdaSpec(func=f, expected_nargs={2})] + ) @since(3.0) @@ -3220,7 +3236,9 @@ def transform_keys(col, f): |[BAR -> 2.0, FOO -> -2.0]| +-------------------------+ """ - return _invoke_higher_order_function("TransformKeys", [col], [(f, {2})]) + return _invoke_higher_order_function( + "TransformKeys", [col], [_LambdaSpec(func=f, expected_nargs={2})] + ) @since(3.0) @@ -3247,7 +3265,9 @@ def transform_values(col, f): |[OPS -> 34.0, IT -> 20.0, SALES -> 2.0]| +---------------------------------------+ """ - return _invoke_higher_order_function("TransformValues", [col], [(f, {2})]) + return _invoke_higher_order_function( + "TransformValues", [col], [_LambdaSpec(func=f, expected_nargs={2})] + ) @since(3.0) @@ -3273,7 +3293,9 @@ def map_filter(col, f): |[baz -> 32.0, foo -> 42.0]| +--------------------------+ """ - return _invoke_higher_order_function("MapFilter", [col], [(f, {2})]) + return _invoke_higher_order_function( + "MapFilter", [col], [_LambdaSpec(func=f, expected_nargs={2})] + ) @since(3.0) @@ -3303,7 +3325,9 @@ def map_zip_with(col1, col2, f): |[SALES -> 16.8, IT -> 48.0]| +---------------------------+ """ - return _invoke_higher_order_function("MapZipWith", [col1, col2], [(f, {3})]) + return _invoke_higher_order_function( + "MapZipWith", [col1, col2], [_LambdaSpec(func=f, expected_nargs={3})] + ) # ---------------------------- User Defined Function ----------------------------------