# [SPARK-24934][SQL] Explicitly whitelist supported types in upper/lower bounds for in-memory partition pruning

## What changes were proposed in this pull request?

It looks like we intentionally set `null` for the upper/lower bounds of complex types and never use them. However, these bounds are in fact consulted during in-memory partition pruning, which ends up producing incorrect results.

This PR proposes to explicitly whitelist the supported types. Reproducers for the array and binary cases:

```scala
val df = Seq(Array("a", "b"), Array("c", "d")).toDF("arrayCol")
df.cache().filter("arrayCol > array('a', 'b')").show()
```

```scala
val df = sql("select cast('a' as binary) as a")
df.cache().filter("a == cast('a' as binary)").show()
```
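
For context (my reading of the bug, not spelled out in the patch): the per-batch pruning predicate compares the filter's literal against the batch's `lowerBound`/`upperBound` statistics, and for array and binary columns those bounds are `null`. Under SQL's three-valued logic, a comparison against `null` yields `null`, which the pruning filter treats like `false`, so every batch gets skipped. A minimal illustration of the comparison semantics:

```scala
// Comparing against NULL yields NULL (not false) in Spark SQL. A filter
// that evaluates to NULL drops the row; analogously, a pruning predicate
// that evaluates to NULL drops the whole batch.
spark.sql("SELECT CAST(NULL AS INT) <= 1 AS cmp").show()
// +----+
// | cmp|
// +----+
// |null|
// +----+
```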

**Before:**

```
+--------+
|arrayCol|
+--------+
+--------+
```

```
+---+
|  a|
+---+
+---+
```

**After:**

```
+--------+
|arrayCol|
+--------+
|  [c, d]|
+--------+
```

```
+----+
|   a|
+----+
|[61]|
+----+
```

In the binary output, `[61]` is the single byte 0x61, i.e. ASCII `'a'`.

## How was this patch tested?

Unit tests were added, and the change was also verified manually.

Author: hyukjinkwon <[email protected]>

Closes #21882 from HyukjinKwon/stats-filter.
HyukjinKwon authored and cloud-fan committed Jul 30, 2018
1 parent 65a4bc1 commit bfe60fc
Showing 2 changed files with 58 additions and 14 deletions.
**`InMemoryTableScanExec.scala`**

```diff
@@ -183,6 +183,18 @@ case class InMemoryTableScanExec(
   private val stats = relation.partitionStatistics
   private def statsFor(a: Attribute) = stats.forAttribute(a)
 
+  // Currently, only use statistics from atomic types except binary type only.
+  private object ExtractableLiteral {
+    def unapply(expr: Expression): Option[Literal] = expr match {
+      case lit: Literal => lit.dataType match {
+        case BinaryType => None
+        case _: AtomicType => Some(lit)
+        case _ => None
+      }
+      case _ => None
+    }
+  }
+
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.
   @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
@@ -194,33 +206,37 @@ case class InMemoryTableScanExec(
       if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
       buildFilter(lhs) || buildFilter(rhs)
 
-    case EqualTo(a: AttributeReference, l: Literal) =>
+    case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualTo(l: Literal, a: AttributeReference) =>
+    case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case EqualNullSafe(a: AttributeReference, l: Literal) =>
+    case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualNullSafe(l: Literal, a: AttributeReference) =>
+    case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
-    case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
+    case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
+    case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
 
-    case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
-    case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
+    case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      statsFor(a).lowerBound <= l
+    case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      l <= statsFor(a).upperBound
 
-    case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
-    case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
+    case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
+    case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
 
-    case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
-    case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
+    case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      l <= statsFor(a).upperBound
+    case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      statsFor(a).lowerBound <= l
 
     case IsNull(a: Attribute) => statsFor(a).nullCount > 0
     case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
 
     case In(a: AttributeReference, list: Seq[Expression])
-      if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
+      if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
       list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
```
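
As a side note on the pattern: `ExtractableLiteral` is a custom Scala extractor, which is why it can appear directly inside the `case` patterns above. A self-contained sketch of the same idiom, using simplified stand-ins rather than Catalyst's actual classes:

```scala
// Simplified stand-ins for Catalyst's Expression/Literal hierarchy,
// purely to illustrate the extractor idiom; not Spark's real types.
sealed trait Expr
case class Lit(value: Any, dataType: String) extends Expr
case class Attr(name: String) extends Expr

object ExtractableLit {
  // Match only literals whose type has usable min/max statistics.
  def unapply(e: Expr): Option[Lit] = e match {
    case l: Lit if l.dataType != "binary" && l.dataType != "array" => Some(l)
    case _ => None
  }
}

def describe(e: Expr): String = e match {
  case ExtractableLit(l) => s"stats-prunable literal: ${l.value}"
  case _ => "skipped: no stats-based pruning"
}

println(describe(Lit(1, "int")))                 // stats-prunable literal: 1
println(describe(Lit(Array[Byte](1), "binary"))) // skipped: no stats-based pruning
```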
**`PartitionBatchPruningSuite.scala`**

```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
@@ -35,6 +36,12 @@ class PartitionBatchPruningSuite
   private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
   private lazy val originalInMemoryPartitionPruning =
     spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
+  private val testArrayData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key))
+  }
+  private val testBinaryData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key.toByte))
+  }
 
   override protected def beforeAll(): Unit = {
     super.beforeAll()
@@ -71,12 +78,22 @@ class PartitionBatchPruningSuite
     }, 5).toDF()
     pruningStringData.createOrReplaceTempView("pruningStringData")
     spark.catalog.cacheTable("pruningStringData")
+
+    val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF()
+    pruningArrayData.createOrReplaceTempView("pruningArrayData")
+    spark.catalog.cacheTable("pruningArrayData")
+
+    val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF()
+    pruningBinaryData.createOrReplaceTempView("pruningBinaryData")
+    spark.catalog.cacheTable("pruningBinaryData")
   }
 
   override protected def afterEach(): Unit = {
     try {
       spark.catalog.uncacheTable("pruningData")
       spark.catalog.uncacheTable("pruningStringData")
+      spark.catalog.uncacheTable("pruningArrayData")
+      spark.catalog.uncacheTable("pruningBinaryData")
     } finally {
       super.afterEach()
     }
@@ -95,6 +112,14 @@ class PartitionBatchPruningSuite
     checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
     checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
     checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
+    // Do not filter on array type
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1)))
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1)))
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)(
+      testArrayData.map(_._1))
+    // Do not filter on binary type
+    checkBatchPruning(
+      "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte)))
 
     // IS NULL
     checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
@@ -131,6 +156,9 @@ class PartitionBatchPruningSuite
     checkBatchPruning(
       "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
       Seq(150))
+    // Do not filter on array type
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)(
+      Seq(Array(1), Array(2, 2)))
 
     // With unsupported `InSet` predicate
     {
@@ -161,7 +189,7 @@ class PartitionBatchPruningSuite
       query: String,
       expectedReadPartitions: Int,
       expectedReadBatches: Int)(
-      expectedQueryResult: => Seq[Int]): Unit = {
+      expectedQueryResult: => Seq[Any]): Unit = {
 
     test(query) {
       val df = sql(query)
```
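
For readers outside the suite: `checkBatchPruning(query, expectedReadPartitions, expectedReadBatches)(expectedResult)` registers a test that runs the query and asserts both the rows returned and how many partitions/batches were actually read after pruning. A rough sketch of its shape, with a hypothetical `scanMetricsOf` helper standing in for however the suite really reads those counters:

```scala
// Sketch only: `scanMetricsOf` is a hypothetical accessor, not Spark's API.
def checkBatchPruning(
    query: String,
    expectedReadPartitions: Int,
    expectedReadBatches: Int)(
    expectedQueryResult: => Seq[Any]): Unit = {
  test(query) {
    val df = sql(query)
    // The rows must be correct no matter how much was pruned.
    // (Array-typed rows need element-wise comparison; elided here.)
    assert(df.collect().length === expectedQueryResult.length)
    // And pruning must have read exactly the expected partitions/batches.
    val (partitions, batches) = scanMetricsOf(df)
    assert(partitions === expectedReadPartitions)
    assert(batches === expectedReadBatches)
  }
}
```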
