[SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning #21882

Closed · wants to merge 3 commits
InMemoryTableScanExec.scala
@@ -183,6 +183,18 @@ case class InMemoryTableScanExec(
   private val stats = relation.partitionStatistics
   private def statsFor(a: Attribute) = stats.forAttribute(a)
 
+  // Currently, only use statistics from atomic types, except binary type.
+  private object ExtractableLiteral {
+    def unapply(expr: Expression): Option[Literal] = expr match {
+      case lit: Literal => lit.dataType match {
+        case BinaryType => None
Contributor commented: can we also add test for binary type?

Author replied: Will add late tonight or tomorrow
+        case _: AtomicType => Some(lit)
+        case _ => None
+      }
+      case _ => None
+    }
+  }
+
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.
   @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
@@ -194,33 +206,37 @@
         if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
       buildFilter(lhs) || buildFilter(rhs)
 
-    case EqualTo(a: AttributeReference, l: Literal) =>
+    case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualTo(l: Literal, a: AttributeReference) =>
+    case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case EqualNullSafe(a: AttributeReference, l: Literal) =>
+    case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualNullSafe(l: Literal, a: AttributeReference) =>
+    case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
-    case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
+    case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
+    case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
 
-    case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
-    case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
+    case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      statsFor(a).lowerBound <= l
+    case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      l <= statsFor(a).upperBound
 
-    case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
-    case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
+    case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
+    case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
 
-    case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
-    case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
+    case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      l <= statsFor(a).upperBound
+    case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      statsFor(a).lowerBound <= l
 
     case IsNull(a: Attribute) => statsFor(a).nullCount > 0
     case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
 
     case In(a: AttributeReference, list: Seq[Expression])
-      if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
+      if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
       list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
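For readers unfamiliar with Scala extractor objects, here is a minimal, self-contained sketch of the whitelisting pattern ExtractableLiteral relies on. Everything in it (ExtractorDemo, IntType, the simplified Literal) is a stand-in invented for illustration, not Spark's real Expression/DataType hierarchy:

object ExtractorDemo extends App {
  sealed trait DataType
  case object IntType extends DataType     // atomic: min/max bounds are meaningful
  case object BinaryType extends DataType  // excluded: bounds are not usable
  final case class ArrayType(elem: DataType) extends DataType

  final case class Literal(value: Any, dataType: DataType)

  // Match only literals whose type is whitelisted for bound-based pruning.
  object ExtractableLiteral {
    def unapply(lit: Literal): Option[Literal] = lit.dataType match {
      case BinaryType => None      // explicitly rejected
      case IntType    => Some(lit) // whitelisted atomic type
      case _          => None      // complex types (arrays, ...) fall through
    }
  }

  // When the extractor returns None, the enclosing `case` does not apply,
  // so no pruning predicate is built for unsupported literal types.
  def prunable(lit: Literal): Boolean = lit match {
    case ExtractableLiteral(_) => true
    case _ => false
  }

  println(prunable(Literal(42, IntType)))                // true
  println(prunable(Literal(Array[Byte](1), BinaryType))) // false
  println(prunable(Literal(Seq(1), ArrayType(IntType)))) // false
}

Because buildFilter is a PartialFunction, a non-matching extractor silently drops the bound check for binary and complex types instead of producing a predicate over statistics that are not valid for those types, which is exactly the failure mode this PR fixes.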
PartitionBatchPruningSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
@@ -35,6 +36,12 @@ class PartitionBatchPruningSuite
   private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
   private lazy val originalInMemoryPartitionPruning =
     spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
+  private val testArrayData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key))
+  }
+  private val testBinaryData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key.toByte))
+  }
 
   override protected def beforeAll(): Unit = {
     super.beforeAll()
@@ -71,12 +78,22 @@
     }, 5).toDF()
     pruningStringData.createOrReplaceTempView("pruningStringData")
     spark.catalog.cacheTable("pruningStringData")
+
+    val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF()
+    pruningArrayData.createOrReplaceTempView("pruningArrayData")
+    spark.catalog.cacheTable("pruningArrayData")
+
+    val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF()
Author commented:

    scala> spark.sparkContext.makeRDD((1 to 100).map { key => Tuple1(Array.fill(key)(key.toByte)) }, 5).toDF().printSchema()
    root
     |-- _1: binary (nullable = true)

+    pruningBinaryData.createOrReplaceTempView("pruningBinaryData")
+    spark.catalog.cacheTable("pruningBinaryData")
   }
 
   override protected def afterEach(): Unit = {
     try {
       spark.catalog.uncacheTable("pruningData")
       spark.catalog.uncacheTable("pruningStringData")
+      spark.catalog.uncacheTable("pruningArrayData")
Contributor commented: uncache the pruningBinaryData too

+      spark.catalog.uncacheTable("pruningBinaryData")
     } finally {
       super.afterEach()
     }
@@ -95,6 +112,14 @@
     checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
     checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
     checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
+    // Do not filter on array type
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1)))
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1)))
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)(
+      testArrayData.map(_._1))
+    // Do not filter on binary type
+    checkBatchPruning(
Author commented: Before this change, this test failed with a wrong query result: Expected Array(Array(1)), but got Array()

"SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte)))

// IS NULL
checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
Expand Down Expand Up @@ -131,6 +156,9 @@ class PartitionBatchPruningSuite
checkBatchPruning(
"SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
Seq(150))
// Do not filter on array type
checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)(
Seq(Array(1), Array(2, 2)))

// With unsupported `InSet` predicate
{
Expand Down Expand Up @@ -161,7 +189,7 @@ class PartitionBatchPruningSuite
query: String,
expectedReadPartitions: Int,
expectedReadBatches: Int)(
expectedQueryResult: => Seq[Int]): Unit = {
expectedQueryResult: => Seq[Any]): Unit = {

test(query) {
val df = sql(query)
Expand Down
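As a rough reproduction of the bug outside the test suite, a spark-shell session along these lines should show the behavior the author reported (a sketch only; binData is an arbitrary temp-view name, and in-memory partition pruning is assumed to be on, which is its default). Before this patch the final query returned an empty result because every batch was pruned against meaningless binary bounds; after it, the matching row comes back:

// paste into spark-shell
import spark.implicits._

val df = spark.sparkContext
  .makeRDD((1 to 100).map(k => Tuple1(Array.fill(k)(k.toByte))), 5)
  .toDF("_1")  // a single binary column, spread over 5 partitions
df.createOrReplaceTempView("binData")
spark.catalog.cacheTable("binData")  // cache so in-memory pruning applies

spark.sql("SELECT _1 FROM binData WHERE _1 = binary(chr(1))").show()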