From 05741c8422364a77389d13ba0a85b7ad92392be6 Mon Sep 17 00:00:00 2001
From: zero323
Date: Tue, 4 Feb 2020 02:18:24 +0100
Subject: [PATCH] Drop arity checks

---
 python/pyspark/sql/functions.py            | 37 ++++++++--------------
 python/pyspark/sql/tests/test_functions.py |  4 ---
 2 files changed, 14 insertions(+), 27 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8351b383d42e7..b519e070791bf 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2903,7 +2903,7 @@ def _get_lambda_parameters_legacy(f):
     return spec.args


-def _create_lambda(f, allowed_arities):
+def _create_lambda(f):
     """
     Create `o.a.s.sql.expressions.LambdaFunction` corresponding
     to transformation described by f
@@ -2912,20 +2912,12 @@
         - (Column) -> Column: ...
         - (Column, Column) -> Column: ...
         - (Column, Column, Column) -> Column: ...
-    :param allowed_arities: Set[int] Expected arities
     """
     if sys.version_info >= (3, 3):
         parameters = _get_lambda_parameters(f)
     else:
         parameters = _get_lambda_parameters_legacy(f)

-    if len(parameters) not in allowed_arities:
-        raise ValueError(
-            """f arity expected to be in {} but is {}""".format(
-                allowed_arities, len(parameters)
-            )
-        )
-
     sc = SparkContext._active_spark_context
     expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions

@@ -2953,8 +2945,7 @@ def _invoke_higher_order_function(name, cols, funs):

     :param name: Name of the expression
     :param cols: a list of columns
-    :param funs: a list of tuples ((*Column) -> Column, Iterable[int])
-        where the second element represent allowed arities
+    :param funs: a list of ((*Column) -> Column) functions.

     :return: a Column
     """
@@ -2963,7 +2954,7 @@ def _invoke_higher_order_function(name, cols, funs):
     expr = getattr(expressions, name)
     jcols = [_to_java_column(col).expr() for col in cols]

-    jfuns = [_create_lambda(f, a) for f, a in funs]
+    jfuns = [_create_lambda(f) for f in funs]

     return Column(sc._jvm.Column(expr(*jcols + jfuns)))
@@ -3005,7 +2996,7 @@ def transform(col, f):
     |[1, -2, 3, -4]|
     +--------------+
     """
-    return _invoke_higher_order_function("ArrayTransform", [col], [(f, {1, 2})])
+    return _invoke_higher_order_function("ArrayTransform", [col], [f])


 @since(3.1)
@@ -3030,7 +3021,7 @@ def exists(col, f):
     |        true|
     +------------+
     """
-    return _invoke_higher_order_function("ArrayExists", [col], [(f, {1})])
+    return _invoke_higher_order_function("ArrayExists", [col], [f])


 @since(3.1)
@@ -3059,7 +3050,7 @@ def forall(col, f):
     |   true|
     +-------+
     """
-    return _invoke_higher_order_function("ArrayForAll", [col], [(f, {1})])
+    return _invoke_higher_order_function("ArrayForAll", [col], [f])


 @since(3.1)
@@ -3097,7 +3088,7 @@ def filter(col, f):
     |[2018-09-20, 2019-07-01]|
     +------------------------+
     """
-    return _invoke_higher_order_function("ArrayFilter", [col], [(f, {1, 2})])
+    return _invoke_higher_order_function("ArrayFilter", [col], [f])


 @since(3.1)
@@ -3150,14 +3141,14 @@ def aggregate(col, zero, merge, finish=None):
         return _invoke_higher_order_function(
             "ArrayAggregate",
             [col, zero],
-            [(merge, {2}), (finish, {1})]
+            [merge, finish]
         )

     else:
         return _invoke_higher_order_function(
             "ArrayAggregate",
             [col, zero],
-            [(merge, {2})]
+            [merge]
         )
@@ -3193,7 +3184,7 @@ def zip_with(col1, col2, f):
     |[foo_1, bar_2, 3]|
     +-----------------+
     """
-    return _invoke_higher_order_function("ZipWith", [col1, col2], [(f, {2})])
+    return _invoke_higher_order_function("ZipWith", [col1, col2], [f])


 @since(3.1)
@@ -3220,7 +3211,7 @@ def transform_keys(col, f):
     |[BAR -> 2.0, FOO -> -2.0]|
     +-------------------------+
     """
-    return _invoke_higher_order_function("TransformKeys", [col], [(f, {2})])
+    return _invoke_higher_order_function("TransformKeys", [col], [f])


 @since(3.1)
@@ -3247,7 +3238,7 @@ def transform_values(col, f):
     |[OPS -> 34.0, IT -> 20.0, SALES -> 2.0]|
     +---------------------------------------+
     """
-    return _invoke_higher_order_function("TransformValues", [col], [(f, {2})])
+    return _invoke_higher_order_function("TransformValues", [col], [f])


 @since(3.1)
@@ -3273,7 +3264,7 @@ def map_filter(col, f):
     |[baz -> 32.0, foo -> 42.0]|
     +--------------------------+
     """
-    return _invoke_higher_order_function("MapFilter", [col], [(f, {2})])
+    return _invoke_higher_order_function("MapFilter", [col], [f])


 @since(3.1)
@@ -3303,7 +3294,7 @@ def map_zip_with(col1, col2, f):
     |[SALES -> 16.8, IT -> 48.0]|
     +---------------------------+
     """
-    return _invoke_higher_order_function("MapZipWith", [col1, col2], [(f, {3})])
+    return _invoke_higher_order_function("MapZipWith", [col1, col2], [f])


 # ---------------------------- User Defined Function ----------------------------------
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 110f713aff5f8..ddb8283cafa85 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -360,10 +360,6 @@ def test_higher_order_function_failures(self):
         with self.assertRaises(ValueError):
             transform(col("foo"), lambda x: 1)

-        # Should fail if arity doesn't match expectations
-        with self.assertRaises(ValueError):
-            exists('numbers', lambda x, y, z: x < 0)
-

 if __name__ == "__main__":
     import unittest
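
Note, with a minimal sketch of what the change permits at the call site. This is not part of the patch: it assumes a local SparkSession and the `transform` helper from this patch series; the data and column names are illustrative only.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import transform

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, [1, 2, 3, 4])], ("key", "values"))

    # Unary lambda: f is applied to each array element.
    df.select(transform("values", lambda x: x * 2).alias("doubled")).show()

    # Binary lambda: f receives (element, index). With the client-side arity
    # check dropped, both forms are forwarded to the JVM as-is, and an arity
    # the underlying Catalyst expression does not support is rejected there
    # during analysis rather than as a Python ValueError up front (hence the
    # removed arity test case above).
    df.select(transform("values", lambda x, i: x + i).alias("shifted")).show()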