Drop arity checks
zero323 committed Feb 4, 2020
1 parent 948deac commit 05741c8
Showing 2 changed files with 14 additions and 27 deletions.
python/pyspark/sql/functions.py (37 changes: 14 additions & 23 deletions)
@@ -2903,7 +2903,7 @@ def _get_lambda_parameters_legacy(f):
     return spec.args
 
 
-def _create_lambda(f, allowed_arities):
+def _create_lambda(f):
     """
     Create `o.a.s.sql.expressions.LambdaFunction` corresponding
     to transformation described by f
@@ -2912,20 +2912,12 @@ def _create_lambda(f, allowed_arities):
     - (Column) -> Column: ...
     - (Column, Column) -> Column: ...
     - (Column, Column, Column) -> Column: ...
-    :param allowed_arities: Set[int] Expected arities
     """
     if sys.version_info >= (3, 3):
         parameters = _get_lambda_parameters(f)
     else:
         parameters = _get_lambda_parameters_legacy(f)
 
-    if len(parameters) not in allowed_arities:
-        raise ValueError(
-            """f arity expected to be in {} but is {}""".format(
-                allowed_arities, len(parameters)
-            )
-        )
-
     sc = SparkContext._active_spark_context
     expressions = sc._jvm.org.apache.spark.sql.catalyst.expressions
 
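With the guard removed, _create_lambda still derives the lambda's arity by inspecting the callable's parameters; only the whitelist check is gone. A rough illustration of that inspection (a sketch, not the actual helper, which also validates parameter kinds):

    import inspect

    def arity(f):
        # Count the callable's positional parameters, roughly what
        # _get_lambda_parameters does on Python 3.
        return len(inspect.signature(f).parameters)

    print(arity(lambda x: x))        # 1
    print(arity(lambda x, y: x))     # 2
    print(arity(lambda x, y, z: x))  # 3
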
@@ -2953,8 +2945,7 @@ def _invoke_higher_order_function(name, cols, funs):
     :param name: Name of the expression
     :param cols: a list of columns
-    :param funs: a list of tuples ((*Column) -> Column, Iterable[int])
-        where the second element represent allowed arities
+    :param funs: a list of ((*Column) -> Column) functions.
 
     :return: a Column
     """
@@ -2963,7 +2954,7 @@ def _invoke_higher_order_function(name, cols, funs):
     expr = getattr(expressions, name)
 
     jcols = [_to_java_column(col).expr() for col in cols]
-    jfuns = [_create_lambda(f, a) for f, a in funs]
+    jfuns = [_create_lambda(f) for f in funs]
 
     return Column(sc._jvm.Column(expr(*jcols + jfuns)))

@@ -3005,7 +2996,7 @@ def transform(col, f):
     |[1, -2, 3, -4]|
     +--------------+
     """
-    return _invoke_higher_order_function("ArrayTransform", [col], [(f, {1, 2})])
+    return _invoke_higher_order_function("ArrayTransform", [col], [f])
 
 
 @since(3.1)
@@ -3030,7 +3021,7 @@ def exists(col, f):
     | true|
     +------------+
     """
-    return _invoke_higher_order_function("ArrayExists", [col], [(f, {1})])
+    return _invoke_higher_order_function("ArrayExists", [col], [f])
 
 
 @since(3.1)
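A minimal exists sketch, reusing the spark session from the previous example:

    from pyspark.sql.functions import exists

    df = spark.createDataFrame([(1, [1, -2, 3]), (2, [4, 5])], ("key", "values"))
    df.select(exists("values", lambda x: x < 0).alias("any_negative")).show()
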
@@ -3059,7 +3050,7 @@ def forall(col, f):
     | true|
     +-------+
     """
-    return _invoke_higher_order_function("ArrayForAll", [col], [(f, {1})])
+    return _invoke_higher_order_function("ArrayForAll", [col], [f])
 
 
 @since(3.1)
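forall follows the same pattern with a unary predicate; a sketch under the same assumptions:

    from pyspark.sql.functions import forall

    df = spark.createDataFrame([(["bar.csv", "foo.csv"],), (["baz.txt"],)], ["files"])
    df.select(forall("files", lambda f: f.endswith(".csv")).alias("all_csv")).show()
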
@@ -3097,7 +3088,7 @@ def filter(col, f):
     |[2018-09-20, 2019-07-01]|
     +------------------------+
     """
-    return _invoke_higher_order_function("ArrayFilter", [col], [(f, {1, 2})])
+    return _invoke_higher_order_function("ArrayFilter", [col], [f])
 
 
 @since(3.1)
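filter, like transform, accepts element-only or element-plus-index lambdas. A sketch (the import is aliased to avoid shadowing the builtin):

    from pyspark.sql.functions import filter as array_filter

    df = spark.createDataFrame([([1, -2, 3, -4],)], ["values"])
    df.select(array_filter("values", lambda x: x > 0).alias("positive")).show()
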
@@ -3150,14 +3141,14 @@ def aggregate(col, zero, merge, finish=None):
         return _invoke_higher_order_function(
             "ArrayAggregate",
             [col, zero],
-            [(merge, {2}), (finish, {1})]
+            [merge, finish]
         )
 
     else:
         return _invoke_higher_order_function(
             "ArrayAggregate",
             [col, zero],
-            [(merge, {2})]
+            [merge]
         )
 
 
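aggregate is the one call site that passes two callables: merge is binary and finish unary, matching the expectations the deleted arity sets used to encode. A usage sketch under the same assumptions:

    from pyspark.sql.functions import aggregate, lit

    df = spark.createDataFrame([(1, [20.0, 4.0, 2.0, 6.0, 10.0])], ("id", "values"))
    df.select(
        aggregate(
            "values",
            lit(0.0),
            lambda acc, x: acc + x,  # merge: (accumulator, element)
            lambda acc: acc / 5      # finish: accumulator only
        ).alias("mean")
    ).show()
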
@@ -3193,7 +3184,7 @@ def zip_with(col1, col2, f):
     |[foo_1, bar_2, 3]|
     +-----------------+
     """
-    return _invoke_higher_order_function("ZipWith", [col1, col2], [(f, {2})])
+    return _invoke_higher_order_function("ZipWith", [col1, col2], [f])
 
 
 @since(3.1)
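A zip_with sketch under the same assumptions; the binary lambda receives one element from each array:

    from pyspark.sql.functions import zip_with, concat_ws

    df = spark.createDataFrame([(["foo", "bar"], [1, 2])], ("xs", "ys"))
    df.select(zip_with("xs", "ys", lambda x, y: concat_ws("_", x, y)).alias("zipped")).show()
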
@@ -3220,7 +3211,7 @@ def transform_keys(col, f):
     |[BAR -> 2.0, FOO -> -2.0]|
     +-------------------------+
     """
-    return _invoke_higher_order_function("TransformKeys", [col], [(f, {2})])
+    return _invoke_higher_order_function("TransformKeys", [col], [f])
 
 
 @since(3.1)
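transform_keys passes each (key, value) pair to a binary lambda; a sketch:

    from pyspark.sql.functions import transform_keys, upper

    df = spark.createDataFrame([({"foo": -2.0, "bar": 2.0},)], ["data"])
    df.select(transform_keys("data", lambda k, v: upper(k)).alias("upper_keys")).show()
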
@@ -3247,7 +3238,7 @@ def transform_values(col, f):
     |[OPS -> 34.0, IT -> 20.0, SALES -> 2.0]|
     +---------------------------------------+
     """
-    return _invoke_higher_order_function("TransformValues", [col], [(f, {2})])
+    return _invoke_higher_order_function("TransformValues", [col], [f])
 
 
 @since(3.1)
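transform_values mirrors transform_keys but rewrites the values; a sketch:

    from pyspark.sql.functions import transform_values

    df = spark.createDataFrame([({"IT": 10.0, "SALES": 1.0, "OPS": 24.0},)], ["data"])
    df.select(transform_values("data", lambda k, v: v * 2).alias("doubled")).show()
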
@@ -3273,7 +3264,7 @@ def map_filter(col, f):
     |[baz -> 32.0, foo -> 42.0]|
     +--------------------------+
     """
-    return _invoke_higher_order_function("MapFilter", [col], [(f, {2})])
+    return _invoke_higher_order_function("MapFilter", [col], [f])
 
 
 @since(3.1)
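map_filter keeps the entries whose (key, value) pair satisfies the binary predicate; a sketch:

    from pyspark.sql.functions import map_filter

    df = spark.createDataFrame([({"baz": 32.0, "foo": 42.0, "bar": 1.0},)], ["data"])
    df.select(map_filter("data", lambda k, v: v > 30.0).alias("large")).show()
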
@@ -3303,7 +3294,7 @@ def map_zip_with(col1, col2, f):
     |[SALES -> 16.8, IT -> 48.0]|
     +---------------------------+
     """
-    return _invoke_higher_order_function("MapZipWith", [col1, col2], [(f, {3})])
+    return _invoke_higher_order_function("MapZipWith", [col1, col2], [f])
 
 
 # ---------------------------- User Defined Function ----------------------------------
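map_zip_with is the only ternary case: the lambda receives the key and the corresponding values from both maps. A sketch (pyspark's round shadows the builtin, hence the alias):

    from pyspark.sql.functions import map_zip_with, round as sql_round

    df = spark.createDataFrame(
        [(1, {"IT": 24.0, "SALES": 12.0}, {"IT": 2.0, "SALES": 1.4})],
        ("id", "base", "ratio")
    )
    df.select(
        map_zip_with("base", "ratio", lambda k, v1, v2: sql_round(v1 * v2, 2)).alias("data")
    ).show()
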
python/pyspark/sql/tests/test_functions.py (4 changes: 0 additions & 4 deletions)
@@ -360,10 +360,6 @@ def test_higher_order_function_failures(self):
         with self.assertRaises(ValueError):
             transform(col("foo"), lambda x: 1)
 
-        # Should fail if arity doesn't match expectations
-        with self.assertRaises(ValueError):
-            exists('numbers', lambda x, y, z: x < 0)
-
 
 if __name__ == "__main__":
     import unittest
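The deleted assertion is the user-visible consequence: a mismatched lambda like the one below no longer raises an eager Python-side ValueError. Presumably the mismatch now surfaces later, from Spark's own analysis of the generated expression; the commit does not show the replacement error, so treat this as an inference:

    from pyspark.sql.functions import exists

    # Before this commit (with an active Spark session):
    #   ValueError: f arity expected to be in {1} but is 3
    # After: the Python side builds a three-argument LambdaFunction and any
    # complaint is deferred to the JVM side (assumption, not shown here).
    expr = exists("numbers", lambda x, y, z: x < 0)
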
Expand Down
