diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 0e525b1e22eb9..42aac06a43c7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -140,7 +140,7 @@ trait CodegenSupport extends SparkPlan {
    * Note that `outputVars` and `row` can't both be null.
    */
   final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
-    val inputVars =
+    val inputVarsCandidate =
       if (outputVars != null) {
         assert(outputVars.length == output.length)
         // outputVars will be used to generate the code for UnsafeRow, so we should copy them
@@ -154,6 +154,11 @@ trait CodegenSupport extends SparkPlan {
         }
       }

+    val inputVars = inputVarsCandidate match {
+      case stream: Stream[ExprCode] => stream.force
+      case other => other
+    }
+
     val rowVar = prepareRowVar(ctx, row, outputVars)

     // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 9180a22c260f1..161a6bfcaefe0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -319,4 +319,24 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
       assert(df.limit(1).collect() === Array(Row("bat", 8.0)))
     }
   }
+
+  test("SPARK-25767: Lazy evaluated stream of expressions handled correctly") {
+    val a = Seq(1).toDF("key")
+    val b = Seq((1, "a")).toDF("key", "value")
+    val c = Seq(1).toDF("key")
+
+    val ab = a.join(b, Stream("key"), "left")
+    val abc = ab.join(c, Seq("key"), "left")
+
+    checkAnswer(abc, Row(1, "a"))
+  }
+
+  test("SPARK-26680: Stream in groupBy does not cause StackOverflowError") {
+    val groupByCols = Stream(col("key"))
+    val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
+      .groupBy(groupByCols: _*)
+      .max("value")
+
+    checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
+  }
 }
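
Why `force` is needed (background for reviewers): Scala's `Stream` is head-strict but tail-lazy, so the `map` calls inside `consume` evaluate only their first element eagerly; the rest are deferred until someone walks the stream. Code generation for `inputVars` must happen while `ctx.currentVars` and `ctx.INPUT_ROW` still describe this operator's input, and deferring it until after that mutable state has been repointed at another operator is the failure mode the two JIRAs describe. The standalone snippet below (hypothetical demo code, not part of the patch; assumes Scala 2.12, where `Stream` is not yet deprecated) only illustrates the `Stream`/`force` semantics the fix relies on:

```scala
object StreamForceDemo {
  def main(args: Array[String]): Unit = {
    var evaluated = 0

    // Stream is head-strict but tail-lazy: map applies the function to the
    // head immediately and defers the tail until each element is demanded.
    val mapped = Stream(1, 2, 3).map { i => evaluated += 1; i * 10 }
    println(evaluated) // 1 -- only the head has been mapped so far

    // force materializes the whole stream, mirroring the patch's
    // `case stream: Stream[ExprCode] => stream.force`.
    mapped.force
    println(evaluated) // 3 -- every element has now been mapped eagerly
  }
}
```

With `force`, all of `inputVars` is materialized up front, so the code for each expression is generated against a consistent codegen context regardless of whether callers hand `consume` a strict `Seq` or a lazy `Stream`.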