Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-34796][SQL] Initialize counter variable for LIMIT code-gen in doProduce() #31892

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,18 @@ trait BaseLimitExec extends LimitExec with CodegenSupport {
}

protected override def doProduce(ctx: CodegenContext): String = {
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
// The counter name is already obtained by the upstream operators via `limitNotReachedChecks`.
// Here we have to inline it to not change its name. This is fine as we won't have many limit
// operators in one query.
//
// Note: create counter variable here instead of `doConsume()` to avoid compilation error,
// because upstream operators might not call `doConsume()` here
// (e.g. `HashJoin.codegenInner()`).
ctx.addMutableState(CodeGenerator.JAVA_INT, countTerm, forceInline = true, useFreshName = false)
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;
Expand Down
19 changes: 19 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4097,6 +4097,25 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(df2, Seq(Row(2, 1, 1), Row(4, 2, 2)))
}
}

test("SPARK-34796: Avoid code-gen compilation error for LIMIT query") {
withTable("left_table", "empty_right_table", "output_table") {
spark.range(5).toDF("k").write.saveAsTable("left_table")
spark.range(0).toDF("k").write.saveAsTable("empty_right_table")

withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
spark.sql("CREATE TABLE output_table (k INT) USING parquet")
spark.sql(
"""
|INSERT INTO TABLE output_table
|SELECT t1.k FROM left_table t1
|JOIN empty_right_table t2
|ON t1.k = t2.k
|LIMIT 3
""".stripMargin)
}
}
}
}

case class Foo(bar: Option[String])