From 1a0a2d8f875743f857bdd66886ed4c5e85680768 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 27 Jul 2018 00:49:43 +0800
Subject: [PATCH 1/3] Explicitly whitelist supported types in upper/lower
 bounds for in-memory partition pruning

---
 .../columnar/InMemoryTableScanExec.scala      | 42 +++++++++++++------
 .../columnar/PartitionBatchPruningSuite.scala | 31 +++++++++++++-
 2 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 997cf92449c68..6012aba1acbca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -183,6 +183,18 @@ case class InMemoryTableScanExec(
   private val stats = relation.partitionStatistics
   private def statsFor(a: Attribute) = stats.forAttribute(a)
 
+  // Currently, only use statistics from atomic types except binary type.
+  private object ExtractableLiteral {
+    def unapply(expr: Expression): Option[Literal] = expr match {
+      case lit: Literal => lit.dataType match {
+        case BinaryType => None
+        case _: AtomicType => Some(lit)
+        case _ => None
+      }
+      case _ => None
+    }
+  }
+
   // Returned filter predicate should return false iff it is impossible for the input expression
   // to evaluate to `true' based on statistics collected about this partition batch.
   @transient lazy val buildFilter: PartialFunction[Expression, Expression] = {
@@ -194,33 +206,37 @@ case class InMemoryTableScanExec(
       if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
       buildFilter(lhs) || buildFilter(rhs)
 
-    case EqualTo(a: AttributeReference, l: Literal) =>
+    case EqualTo(a: AttributeReference, ExtractableLiteral(l)) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualTo(l: Literal, a: AttributeReference) =>
+    case EqualTo(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case EqualNullSafe(a: AttributeReference, l: Literal) =>
+    case EqualNullSafe(a: AttributeReference, ExtractableLiteral(l)) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-    case EqualNullSafe(l: Literal, a: AttributeReference) =>
+    case EqualNullSafe(ExtractableLiteral(l), a: AttributeReference) =>
       statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
 
-    case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
-    case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
+    case LessThan(a: AttributeReference, ExtractableLiteral(l)) => statsFor(a).lowerBound < l
+    case LessThan(ExtractableLiteral(l), a: AttributeReference) => l < statsFor(a).upperBound
 
-    case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
-    case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
+    case LessThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      statsFor(a).lowerBound <= l
+    case LessThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      l <= statsFor(a).upperBound
 
-    case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
-    case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
+    case GreaterThan(a: AttributeReference, ExtractableLiteral(l)) => l < statsFor(a).upperBound
+    case GreaterThan(ExtractableLiteral(l), a: AttributeReference) => statsFor(a).lowerBound < l
 
-    case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
-    case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
+    case GreaterThanOrEqual(a: AttributeReference, ExtractableLiteral(l)) =>
+      l <= statsFor(a).upperBound
+    case GreaterThanOrEqual(ExtractableLiteral(l), a: AttributeReference) =>
+      statsFor(a).lowerBound <= l
 
     case IsNull(a: Attribute) => statsFor(a).nullCount > 0
     case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
 
     case In(a: AttributeReference, list: Seq[Expression])
-      if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty =>
+      if list.forall(ExtractableLiteral.unapply(_).isDefined) && list.nonEmpty =>
       list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
         l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 9d862cfdecb21..5907ab3c9935f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.columnar
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
@@ -35,6 +36,12 @@ class PartitionBatchPruningSuite
   private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE)
   private lazy val originalInMemoryPartitionPruning =
     spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING)
+  private val testArrayData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key))
+  }
+  private val testBinaryData = (1 to 100).map { key =>
+    Tuple1(Array.fill(key)(key.toByte))
+  }
 
   override protected def beforeAll(): Unit = {
     super.beforeAll()
@@ -71,12 +78,21 @@ class PartitionBatchPruningSuite
     }, 5).toDF()
     pruningStringData.createOrReplaceTempView("pruningStringData")
     spark.catalog.cacheTable("pruningStringData")
+
+    val pruningArrayData = sparkContext.makeRDD(testArrayData, 5).toDF()
+    pruningArrayData.createOrReplaceTempView("pruningArrayData")
+    spark.catalog.cacheTable("pruningArrayData")
+
+    val pruningBinaryData = sparkContext.makeRDD(testBinaryData, 5).toDF()
+    pruningBinaryData.createOrReplaceTempView("pruningBinaryData")
+    spark.catalog.cacheTable("pruningBinaryData")
   }
 
   override protected def afterEach(): Unit = {
     try {
       spark.catalog.uncacheTable("pruningData")
       spark.catalog.uncacheTable("pruningStringData")
+      spark.catalog.uncacheTable("pruningArrayData")
     } finally {
       super.afterEach()
     }
@@ -95,6 +111,16 @@ class PartitionBatchPruningSuite
     checkBatchPruning("SELECT key FROM pruningData WHERE 11 >= key", 1, 2)(1 to 11)
     checkBatchPruning("SELECT key FROM pruningData WHERE 88 < key", 1, 2)(89 to 100)
     checkBatchPruning("SELECT key FROM pruningData WHERE 89 <= key", 1, 2)(89 to 100)
+    // Do not filter on array type
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 = array(1)", 5, 10)(Seq(Array(1)))
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 <= array(1)", 5, 10)(Seq(Array(1)))
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 >= array(1)", 5, 10)(
+      testArrayData.map(_._1))
+    // Do not filter on binary type
+    checkBatchPruning(
+      query = "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))",
+      expectedReadPartitions = 5,
+      expectedReadBatches = 10)(Seq(Array(1.toByte)))
 
     // IS NULL
     checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {
@@ -131,6 +157,9 @@ class PartitionBatchPruningSuite
     checkBatchPruning(
       "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
       Seq(150))
+    // Do not filter on array type
+    checkBatchPruning("SELECT _1 FROM pruningArrayData WHERE _1 IN (array(1), array(2, 2))", 5, 10)(
+      Seq(Array(1), Array(2, 2)))
 
     // With unsupported `InSet` predicate
     {
@@ -161,7 +190,7 @@ class PartitionBatchPruningSuite
       query: String,
       expectedReadPartitions: Int,
       expectedReadBatches: Int)(
-      expectedQueryResult: => Seq[Int]): Unit = {
+      expectedQueryResult: => Seq[Any]): Unit = {
 
     test(query) {
       val df = sql(query)

From fe3c0a0254d4db767750d2afa0294ecd2d8ee24f Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sat, 28 Jul 2018 16:18:32 +0800
Subject: [PATCH 2/3] style

---
 .../sql/execution/columnar/PartitionBatchPruningSuite.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 5907ab3c9935f..71fca68395633 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -118,9 +118,7 @@ class PartitionBatchPruningSuite
       testArrayData.map(_._1))
     // Do not filter on binary type
     checkBatchPruning(
-      query = "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))",
-      expectedReadPartitions = 5,
-      expectedReadBatches = 10)(Seq(Array(1.toByte)))
+      "SELECT _1 FROM pruningBinaryData WHERE _1 == binary(chr(1))", 5, 10)(Seq(Array(1.toByte)))
 
     // IS NULL
     checkBatchPruning("SELECT key FROM pruningData WHERE value IS NULL", 5, 5) {

From deb20eff74252a617790eb595592dd5f80eceb71 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Mon, 30 Jul 2018 09:11:01 +0800
Subject: [PATCH 3/3] Uncache table

---
 .../sql/execution/columnar/PartitionBatchPruningSuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 71fca68395633..af493e93b5192 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -93,6 +93,7 @@ class PartitionBatchPruningSuite
       spark.catalog.uncacheTable("pruningData")
       spark.catalog.uncacheTable("pruningStringData")
       spark.catalog.uncacheTable("pruningArrayData")
+      spark.catalog.uncacheTable("pruningBinaryData")
     } finally {
       super.afterEach()
    }
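
--
Note for reviewers (an illustrative sketch, not part of the patches): `ExtractableLiteral`
is a plain Scala extractor object. Because its `unapply` returns `Some(literal)` only for
supported types, a pattern such as `EqualTo(a: AttributeReference, ExtractableLiteral(l))`
simply fails to match binary or complex-typed literals; `buildFilter` is then undefined
for that predicate, and the batch is scanned instead of pruned. The standalone sketch
below models that mechanism with simplified stand-in types -- `IntType`, `ArrayType`, and
`Demo` are hypothetical and not Spark's real `Expression`/`DataType` hierarchy:

  // Stand-in model of the extractor-object pattern (assumed, simplified types).
  sealed trait DataType
  case object IntType extends DataType // stand-in for any of Spark's AtomicTypes
  case object BinaryType extends DataType
  case class ArrayType(elementType: DataType) extends DataType

  sealed trait Expression
  case class Literal(value: Any, dataType: DataType) extends Expression
  case class AttributeReference(name: String) extends Expression

  // Matches only literals whose type has usable ordered statistics.
  object ExtractableLiteral {
    def unapply(expr: Expression): Option[Literal] = expr match {
      case lit: Literal => lit.dataType match {
        case BinaryType => None   // binary bounds are not meaningful for pruning
        case IntType => Some(lit) // the real code accepts any AtomicType here
        case _ => None            // arrays, structs, maps: never extract
      }
      case _ => None
    }
  }

  object Demo extends App {
    val exprs: Seq[Expression] = Seq(
      Literal(42, IntType),
      Literal(Array[Byte](1), BinaryType),
      Literal(Seq(1, 2), ArrayType(IntType)),
      AttributeReference("key"))

    exprs.foreach {
      case ExtractableLiteral(lit) => println(s"prunable literal: ${lit.value}")
      case e => println(s"falls through, no pruning: $e")
    }
  }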