diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9a5977b7025e6..6ea73430367fa 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -4254,7 +4254,10 @@ def _create_lambda(f):

     argnames = ["x", "y", "z"]
     args = [
-        _unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)]
+        _unresolved_named_lambda_variable(
+            expressions.UnresolvedNamedLambdaVariable.freshVarName(arg)
+        )
+        for arg in argnames[: len(parameters)]
     ]

     result = f(*args)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index f49b7b2f359e1..082d61b732429 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -493,6 +493,28 @@ def test_higher_order_function_failures(self):
         with self.assertRaises(ValueError):
             transform(col("foo"), lambda x: 1)

+    def test_nested_higher_order_function(self):
+        # SPARK-35382: lambda vars must be resolved properly in nested higher order functions
+        from pyspark.sql.functions import flatten, struct, transform
+
+        df = self.spark.sql("SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters")
+
+        actual = df.select(flatten(
+            transform(
+                "numbers",
+                lambda number: transform(
+                    "letters",
+                    lambda letter: struct(number.alias("n"), letter.alias("l"))
+                )
+            )
+        )).first()[0]
+
+        expected = [(1, "a"), (1, "b"), (1, "c"),
+                    (2, "a"), (2, "b"), (2, "c"),
+                    (3, "a"), (3, "b"), (3, "c")]
+
+        self.assertEquals(actual, expected)
+
     def test_window_functions(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.partitionBy("value").orderBy("key")
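
Why the fix works: `_create_lambda` draws Python-side lambda argument names from the fixed pool `["x", "y", "z"]`, so nesting one higher order function inside another used to create two unresolved variables that both went by "x", and the inner lambda could capture the wrong binding. Routing each name through the JVM-side `expressions.UnresolvedNamedLambdaVariable.freshVarName` makes every variable name unique before resolution. The snippet below is a hypothetical Python stand-in for that fresh-naming scheme, not the actual implementation (which lives in Catalyst on the Scala side); the counter-suffix format shown is an assumption:

    import itertools

    # Hypothetical stand-in for UnresolvedNamedLambdaVariable.freshVarName:
    # suffix each requested name with a process-wide unique id, so nested
    # lambdas that would otherwise both bind "x" get distinct variables.
    _fresh_ids = itertools.count()

    def fresh_var_name(name):
        return "{}_{}".format(name, next(_fresh_ids))

    print(fresh_var_name("x"))  # x_0
    print(fresh_var_name("x"))  # x_1

With unique names, the inner and outer `transform` lambdas resolve `number` and `letter` independently, which is exactly what the new `test_nested_higher_order_function` asserts.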