[SPARK-17162] Range does not support SQL generation
## What changes were proposed in this pull request?

The `range` operator previously did not support SQL generation, which made it impossible to use in views. This patch adds a `toSQL` method to the logical `Range` node and makes its `numSlices` parameter optional, so the generated SQL includes a partition count only when the user specified one.
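
For example, a view defined over `range`, as in the hypothetical snippet below (the view name is made up for illustration), would previously fail when Spark tried to canonicalize the view text:

```scala
// Hypothetical reproduction: before this patch, CREATE VIEW over the
// range table-valued function failed because the logical Range node
// could not be converted back to SQL.
spark.sql("CREATE VIEW range_view AS SELECT * FROM range(100)")
spark.sql("SELECT count(*) FROM range_view").show()
```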

## How was this patch tested?

Unit tests.

cc hvanhovell

Author: Eric Liang <[email protected]>

Closes #14724 from ericl/spark-17162.
ericl authored and rxin committed Aug 22, 2016
1 parent 929cb8b commit 84770b5
Showing 8 changed files with 44 additions and 18 deletions.
ResolveTableValuedFunctions.scala

```diff
@@ -28,9 +28,6 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
  * Rule that resolves table-valued function references.
  */
 object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
-  private lazy val defaultParallelism =
-    SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
-
   /**
    * List of argument names and their types, used to declare a function.
    */
```
```diff
@@ -84,25 +81,25 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
     "range" -> Map(
       /* range(end) */
       tvf("end" -> LongType) { case Seq(end: Long) =>
-        Range(0, end, 1, defaultParallelism)
+        Range(0, end, 1, None)
       },

       /* range(start, end) */
       tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
-        Range(start, end, 1, defaultParallelism)
+        Range(start, end, 1, None)
       },

       /* range(start, end, step) */
       tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
         case Seq(start: Long, end: Long, step: Long) =>
-          Range(start, end, step, defaultParallelism)
+          Range(start, end, step, None)
       },

       /* range(start, end, step, numPartitions) */
       tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
           "numPartitions" -> IntegerType) {
         case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
-          Range(start, end, step, numPartitions)
+          Range(start, end, step, Some(numPartitions))
       })
   )
```
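
Taken together, these declarations map the four SQL-level call shapes onto the logical `Range` node; only the four-argument form pins an explicit partition count. A quick sketch of the resolved forms (the `spark` session variable is the usual shell binding, assumed here):

```scala
// The four call shapes accepted by the range table-valued function,
// matching the tvf(...) declarations above.
spark.sql("SELECT * FROM range(100)")            // range(end)
spark.sql("SELECT * FROM range(1, 100)")         // range(start, end)
spark.sql("SELECT * FROM range(1, 100, 20)")     // range(start, end, step)
spark.sql("SELECT * FROM range(1, 100, 20, 10)") // range(start, end, step, numPartitions)
```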

basicLogicalOperators.scala

```diff
@@ -422,17 +422,20 @@ case class Sort(

 /** Factory for constructing new `Range` nodes. */
 object Range {
-  def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+  def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = {
     val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
     new Range(start, end, step, numSlices, output)
   }
+  def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+    Range(start, end, step, Some(numSlices))
+  }
 }

 case class Range(
     start: Long,
     end: Long,
     step: Long,
-    numSlices: Int,
+    numSlices: Option[Int],
     output: Seq[Attribute])
   extends LeafNode with MultiInstanceRelation {

@@ -449,6 +452,14 @@ case class Range(
     }
   }

+  def toSQL(): String = {
+    if (numSlices.isDefined) {
+      s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step, ${numSlices.get})"
+    } else {
+      s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step)"
+    }
+  }
+
   override def newInstance(): Range = copy(output = output.map(_.newInstance()))

   override lazy val statistics: Statistics = {

@@ -457,11 +468,7 @@ case class Range(
   }

   override def simpleString: String = {
-    if (step == 1) {
-      s"Range ($start, $end, splits=$numSlices)"
-    } else {
-      s"Range ($start, $end, step=$step, splits=$numSlices)"
-    }
+    s"Range ($start, $end, step=$step, splits=$numSlices)"
   }
 }
```
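
With `numSlices` now an `Option`, `toSQL` only emits the fourth `range` argument when the user actually supplied one, so a view round-trips through the same overload it was created with. A sketch of both branches (the plain `id` alias comes from the factory's output schema; when SQLBuilder canonicalizes a whole query, the alias becomes a generated `gen_attr_*` name, as the golden files below show):

```scala
// Sketch of toSQL's two branches, using the Option-based factory above.
Range(0, 100, 1, None).toSQL()
// => SELECT id AS `id` FROM range(0, 100, 1)
Range(1, 100, 20, Some(10)).toSQL()
// => SELECT id AS `id` FROM range(1, 100, 20, 10)
```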

SQLBuilder.scala

```diff
@@ -208,6 +208,9 @@ class SQLBuilder private (
       case p: LocalRelation =>
         p.toSQL(newSubqueryName())

+      case p: Range =>
+        p.toSQL()
+
       case OneRowRelation =>
         ""
```
basicPhysicalOperators.scala

```diff
@@ -318,7 +318,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)

   def start: Long = range.start
   def step: Long = range.step
-  def numSlices: Int = range.numSlices
+  def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
   def numElements: BigInt = range.numElements

   override val output: Seq[Attribute] = range.output
```
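
This moves the default-parallelism lookup out of the analyzer rule (which previously had to grab a `SparkContext` eagerly) and into the physical operator, which already has one in scope. A minimal sketch of the fallback, with illustrative values:

```scala
// Illustrative only: in RangeExec, sparkContext.defaultParallelism
// supplies the slice count when the user omitted numPartitions.
def resolveSlices(numSlices: Option[Int], defaultParallelism: Int): Int =
  numSlices.getOrElse(defaultParallelism)

resolveSlices(Some(10), 8) // => 10, explicit range(..., 10)
resolveSlices(None, 8)     // => 8, range(...) without a partition count
```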
views.scala

```diff
@@ -179,8 +179,7 @@ case class CreateViewCommand(
       sparkSession.sql(viewSQL).queryExecution.assertAnalyzed()
     } catch {
       case NonFatal(e) =>
-        throw new RuntimeException(
-          "Failed to analyze the canonicalized SQL. It is possible there is a bug in Spark.", e)
+        throw new RuntimeException(s"Failed to analyze the canonicalized SQL: ${viewSQL}", e)
     }

     val viewSchema = if (userSpecifiedColumns.isEmpty) {
```
4 changes: 4 additions & 0 deletions sql/hive/src/test/resources/sqlgen/range.sql

```diff
@@ -0,0 +1,4 @@
+-- This file is automatically generated by LogicalPlanToSQLSuite.
+select * from range(100)
+--------------------------------------------------------------------------------
+SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(0, 100, 1)) AS gen_subquery_0) AS gen_subquery_1
```
4 changes: 4 additions & 0 deletions sql/hive/src/test/resources/sqlgen/range_with_splits.sql

```diff
@@ -0,0 +1,4 @@
+-- This file is automatically generated by LogicalPlanToSQLSuite.
+select * from range(1, 100, 20, 10)
+--------------------------------------------------------------------------------
+SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(1, 100, 20, 10)) AS gen_subquery_0) AS gen_subquery_1
```
LogicalPlanToSQLSuite.scala

```diff
@@ -23,7 +23,10 @@ import java.nio.file.{Files, NoSuchFileException, Paths}
 import scala.util.control.NonFatal

 import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils

@@ -180,7 +183,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
   }

   test("Test should fail if the SQL query cannot be regenerated") {
-    spark.range(10).createOrReplaceTempView("not_sql_gen_supported_table_so_far")
+    case class Unsupported() extends LeafNode with MultiInstanceRelation {
+      override def newInstance(): Unsupported = copy()
+      override def output: Seq[Attribute] = Nil
+    }
+    Unsupported().createOrReplaceTempView("not_sql_gen_supported_table_so_far")
     sql("select * from not_sql_gen_supported_table_so_far")
     val m3 = intercept[org.scalatest.exceptions.TestFailedException] {
       checkSQL("select * from not_sql_gen_supported_table_so_far", "in")

@@ -196,6 +203,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     }
   }

+  test("range") {
+    checkSQL("select * from range(100)", "range")
+    checkSQL("select * from range(1, 100, 20, 10)", "range_with_splits")
+  }
+
   test("in") {
     checkSQL("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)", "in")
   }
```
