Skip to content

Commit

Permalink
[SPARK-34794][SQL] Fix lambda variable name issues in nested DataFram…
Browse files Browse the repository at this point in the history
…e functions

### What changes were proposed in this pull request?

To fix lambda variable name issues in nested DataFrame functions, this PR modifies code to use a global counter for `LambdaVariables` names created by higher order functions.

This is the rework of apache#31887. Closes apache#31887.

### Why are the changes needed?

This moves away from the current hard-coded variable names, which break on nested function calls. There is currently a bug where nested transforms in particular fail because the inner lambda variable shadows the outer one.

For this query:
```
val df = Seq(
    (Seq(1,2,3), Seq("a", "b", "c"))
).toDF("numbers", "letters")

df.select(
    f.flatten(
        f.transform(
            $"numbers",
            (number: Column) => { f.transform(
                $"letters",
                (letter: Column) => { f.struct(
                    number.as("number"),
                    letter.as("letter")
                ) }
            ) }
        )
    ).as("zipped")
).show(10, false)
```
This is the current (incorrect) output:
```
+------------------------------------------------------------------------+
|zipped                                                                  |
+------------------------------------------------------------------------+
|[{a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}]|
+------------------------------------------------------------------------+
```
And this is the correct output after fix:
```
+------------------------------------------------------------------------+
|zipped                                                                  |
+------------------------------------------------------------------------+
|[{1, a}, {1, b}, {1, c}, {2, a}, {2, b}, {2, c}, {3, a}, {3, b}, {3, c}]|
+------------------------------------------------------------------------+
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added the new test in `DataFrameFunctionsSuite`.

Closes apache#32424 from maropu/pr31887.

Lead-authored-by: dsolow <[email protected]>
Co-authored-by: Takeshi Yamamuro <[email protected]>
Co-authored-by: dmsolow <[email protected]>
Signed-off-by: Takeshi Yamamuro <[email protected]>
(cherry picked from commit f550e03)
Signed-off-by: Takeshi Yamamuro <[email protected]>
(cherry picked from commit 6df4ec0)
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
3 people authored and dongjoon-hyun committed May 5, 2021
1 parent 65e0773 commit f8068fc
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import java.util.Comparator
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

import scala.collection.mutable

Expand Down Expand Up @@ -52,6 +52,16 @@ case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
override def sql: String = name
}

object UnresolvedNamedLambdaVariable {

  // Shared counter whose values are used as suffixes to keep generated
  // lambda variable names unique.
  private val varNameSeq = new AtomicInteger(0)

  /**
   * Builds a lambda variable name by appending an underscore and the next
   * value of the shared counter to `name`.
   *
   * @param name base name to decorate
   * @return `name` followed by `_` and the counter value taken by this call
   */
  def freshVarName(name: String): String = {
    // getAndIncrement hands out the current value and advances the counter atomically.
    val id = varNameSeq.getAndIncrement()
    s"${name}_$id"
  }
}

/**
* A named lambda variable.
*/
Expand Down
12 changes: 6 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3644,22 +3644,22 @@ object functions {
}

// Wraps a one-argument Column function as a LambdaFunction over a single
// unresolved lambda variable.
// NOTE(review): this listing is a flattened diff, so the variable appears twice:
// the first `val x` is the pre-change line with the hard-coded name "x", the
// second is its replacement, which takes the name from
// UnresolvedNamedLambdaVariable.freshVarName.
private def createLambda(f: Column => Column) = {
val x = UnresolvedNamedLambdaVariable(Seq("x"))
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
// Apply f to the variable wrapped in a Column; the resulting expression is the lambda body.
val function = f(Column(x)).expr
LambdaFunction(function, Seq(x))
}

// Wraps a two-argument Column function as a LambdaFunction over two
// unresolved lambda variables, passed in the order (x, y).
// NOTE(review): this listing is a flattened diff, so each variable appears twice:
// the first two `val` lines are the pre-change lines with the hard-coded names
// "x" and "y", the next two are their replacements, which take their names from
// UnresolvedNamedLambdaVariable.freshVarName.
private def createLambda(f: (Column, Column) => Column) = {
val x = UnresolvedNamedLambdaVariable(Seq("x"))
val y = UnresolvedNamedLambdaVariable(Seq("y"))
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
// Apply f to the variables wrapped in Columns; the resulting expression is the lambda body.
val function = f(Column(x), Column(y)).expr
LambdaFunction(function, Seq(x, y))
}

// Wraps a three-argument Column function as a LambdaFunction over three
// unresolved lambda variables, passed in the order (x, y, z).
// NOTE(review): this listing is a flattened diff, so each variable appears twice:
// the first three `val` lines are the pre-change lines with the hard-coded names
// "x", "y" and "z", the next three are their replacements, which take their names
// from UnresolvedNamedLambdaVariable.freshVarName.
private def createLambda(f: (Column, Column, Column) => Column) = {
val x = UnresolvedNamedLambdaVariable(Seq("x"))
val y = UnresolvedNamedLambdaVariable(Seq("y"))
val z = UnresolvedNamedLambdaVariable(Seq("z"))
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z")))
// Apply f to the variables wrapped in Columns; the resulting expression is the lambda body.
val function = f(Column(x), Column(y), Column(z)).expr
LambdaFunction(function, Seq(x, y, z))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3629,6 +3629,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
df.select(map(map_entries($"m"), lit(1))),
Row(Map(Seq(Row(1, "a")) -> 1)))
}

// Regression test for SPARK-34794: a lambda nested inside another higher-order
// function must still be able to refer to the outer lambda's variables.
test("SPARK-34794: lambda variable name issues in nested functions") {
val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters")

// Nested one-argument transforms: the inner lambda references the outer element,
// so every number is expected to be paired with every letter.
checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) =>
transform($"letters", (letter: Column) =>
struct(number, letter))))),
Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b"))))
)
// Nested two-argument transforms: the inner lambda combines the outer element (number)
// and outer index (i) with its own element (letter) and index (j).
checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) =>
transform($"letters", (letter: Column, j: Column) =>
struct(number + j, concat(letter, i)))))),
Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1"))))
)

val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2")

// Nested three-argument map_zip_with: the inner lambda sums the outer values (ov1, ov2)
// with its own values (iv1, iv2); the keys k1 and k2 are not used in the result values.
checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) =>
map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) =>
ov1 + iv1 + ov2 + iv2))),
Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10))))
)
}
}

object DataFrameFunctionsSuite {
Expand Down

0 comments on commit f8068fc

Please sign in to comment.