diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 0f3c024b6220e..263c9ba60d145 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -392,31 +392,34 @@ case class FilterEstimation(plan: Filter) extends Logging {
     val dataType = attr.dataType
     var newNdv = ndv
 
+    if (ndv.toDouble == 0 || colStat.min.isEmpty || colStat.max.isEmpty) {
+      return Some(0.0)
+    }
+
     // use [min, max] to filter the original hSet
     dataType match {
       case _: NumericType | BooleanType | DateType | TimestampType =>
-        if (colStat.min.isDefined && colStat.max.isDefined) {
-          val statsInterval =
-            ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval]
-          val validQuerySet = hSet.filter { v =>
-            v != null && statsInterval.contains(Literal(v, dataType))
-          }
+        val statsInterval =
+          ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval]
+        val validQuerySet = hSet.filter { v =>
+          v != null && statsInterval.contains(Literal(v, dataType))
+        }
 
-          if (validQuerySet.isEmpty) {
-            return Some(0.0)
-          }
+        if (validQuerySet.isEmpty) {
+          return Some(0.0)
+        }
 
-          val newMax = validQuerySet.maxBy(EstimationUtils.toDouble(_, dataType))
-          val newMin = validQuerySet.minBy(EstimationUtils.toDouble(_, dataType))
-          // newNdv should not be greater than the old ndv. For example, column has only 2 values
-          // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
-          newNdv = ndv.min(BigInt(validQuerySet.size))
-          if (update) {
-            val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin),
-              max = Some(newMax), nullCount = Some(0))
-            colStatsMap.update(attr, newStats)
-          }
+        val newMax = validQuerySet.maxBy(EstimationUtils.toDouble(_, dataType))
+        val newMin = validQuerySet.minBy(EstimationUtils.toDouble(_, dataType))
+        // newNdv should not be greater than the old ndv. For example, column has only 2 values
+        // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
+        newNdv = ndv.min(BigInt(validQuerySet.size))
+        if (update) {
+          val newStats = colStat.copy(distinctCount = Some(newNdv), min = Some(newMin),
+            max = Some(newMax), nullCount = Some(0))
+          colStatsMap.update(attr, newStats)
         }
+
       // We assume the whole set since there is no min/max information for String/Binary type
       case StringType | BinaryType =>
         newNdv = ndv.min(BigInt(hSet.size))
@@ -428,11 +431,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
 
     // return the filter selectivity. Without advanced statistics such as histograms,
     // we have to assume uniform distribution.
-    if (ndv.toDouble != 0) {
-      Some(math.min(newNdv.toDouble / ndv.toDouble, 1.0))
-    } else {
-      Some(0.0)
-    }
+    Some(math.min(newNdv.toDouble / ndv.toDouble, 1.0))
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index b1ca37195cc3a..16cb5d032cf57 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -360,9 +360,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   test("evaluateInSet with all zeros") {
     validateEstimatedStats(
       Filter(InSet(attrString, Set(3, 4, 5)),
-        StatsTestPlan(Seq(attrString), 10,
+        StatsTestPlan(Seq(attrString), 0,
           AttributeMap(Seq(attrString ->
-            ColumnStat(distinctCount = Some(0), min = Some(0), max = Some(0),
+            ColumnStat(distinctCount = Some(0), min = None, max = None,
               nullCount = Some(0), avgLen = Some(0), maxLen = Some(0)))))),
       Seq(attrString -> ColumnStat(distinctCount = Some(0))),
       expectedRowCount = 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 14a565863d66c..877746beb79f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -382,4 +382,34 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
       }
     }
   }
+
+  test("Simple queries must be working, if CBO is turned on") {
+    withSQLConf(("spark.sql.cbo.enabled", "true")) {
+      withTable("TBL1", "TBL") {
+        import org.apache.spark.sql.functions._
+        val df = spark.range(1000L).select('id,
+          'id * 2 as "FLD1",
+          'id * 12 as "FLD2",
+          lit("aaa") + 'id as "fld3")
+        df.write
+          .mode(SaveMode.Overwrite)
+          .bucketBy(10, "id", "FLD1", "FLD2")
+          .sortBy("id", "FLD1", "FLD2")
+          .saveAsTable("TBL")
+        spark.sql("ANALYZE TABLE TBL COMPUTE STATISTICS ")
+        spark.sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3")
+        val df2 = spark.sql(
+          """
+             SELECT t1.id, t1.fld1, t1.fld2, t1.fld3
+             FROM tbl t1
+             JOIN tbl t2 on t1.id=t2.id
+             WHERE t1.fld3 IN (-123.23,321.23)
+          """.stripMargin)
+        df2.createTempView("TBL2")
+        spark.sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ").explain()
+      }
+    }
+
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CBOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CBOSuite.scala
deleted file mode 100644
index bdc1610cb7c74..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CBOSuite.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.{QueryTest, SaveMode}
-import org.apache.spark.sql.test.SharedSparkSession
-
-class CBOSuite extends QueryTest with SharedSparkSession {
-
-  import testImplicits._
-
-  test("Simple queries must be working, if CBO is turned on") {
-    withSQLConf(("spark.sql.cbo.enabled", "true")) {
-      withTable("TBL1", "TBL") {
-        import org.apache.spark.sql.functions._
-        val df = spark.range(1000L).select('id,
-          'id * 2 as "FLD1",
-          'id * 12 as "FLD2",
-          lit("aaa") + 'id as "fld3")
-        df.write
-          .mode(SaveMode.Overwrite)
-          .bucketBy(10, "id", "FLD1", "FLD2")
-          .sortBy("id", "FLD1", "FLD2")
-          .saveAsTable("TBL")
-        spark.sql("ANALYZE TABLE TBL COMPUTE STATISTICS ")
-        spark.sql("ANALYZE TABLE TBL COMPUTE STATISTICS FOR COLUMNS ID, FLD1, FLD2, FLD3")
-        val df2 = spark.sql(
-          """
-             SELECT t1.id, t1.fld1, t1.fld2, t1.fld3
-             FROM tbl t1
-             JOIN tbl t2 on t1.id=t2.id
-             WHERE t1.fld3 IN (-123.23,321.23)
-          """.stripMargin)
-        df2.createTempView("TBL2")
-        val df3 = spark.sql("SELECT * FROM tbl2 WHERE fld3 IN ('qqq', 'qwe') ")
-        assertResult(0, "") {
-          df3.count()
-        }
-      }
-    }
-
-  }
-
-}
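
For reference, the patch collapses the old trailing zero-check into a single early return at the top of evaluateInSet, so the final expression can divide by ndv unconditionally. Below is a minimal standalone sketch of that selectivity logic; the ColStats class and estimateInSetSelectivity function are illustrative stand-ins, not Spark's API — only the guard-then-ratio structure mirrors the patched code.

// Standalone sketch of the patched evaluateInSet selectivity logic.
// ColStats and estimateInSetSelectivity are hypothetical names for illustration.
object InSetSelectivitySketch {

  // Simplified stand-in for Spark's ColumnStat: distinct count plus optional bounds.
  case class ColStats(distinctCount: BigInt, min: Option[Double], max: Option[Double])

  def estimateInSetSelectivity(stats: ColStats, hSet: Set[Double]): Option[Double] = {
    // The patch's early return: with zero distinct values or missing bounds nothing
    // can match, and it also keeps the final newNdv/ndv ratio from dividing by zero.
    if (stats.distinctCount == 0 || stats.min.isEmpty || stats.max.isEmpty) {
      return Some(0.0)
    }
    // Keep only set members that fall inside [min, max], as the numeric branch does.
    val valid = hSet.filter(v => v >= stats.min.get && v <= stats.max.get)
    if (valid.isEmpty) return Some(0.0)
    // newNdv is capped by the old ndv, matching the comment in the diff.
    val newNdv = stats.distinctCount.min(BigInt(valid.size))
    // Uniform-distribution assumption: selectivity = newNdv / ndv, clamped to 1.0.
    Some(math.min(newNdv.toDouble / stats.distinctCount.toDouble, 1.0))
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the regression scenario: ndv = 0 with no min/max used to reach the
    // division; after the patch it short-circuits to selectivity 0.
    println(estimateInSetSelectivity(ColStats(0, None, None), Set(3.0, 4.0)))  // Some(0.0)
    // The diff's own example: column holds only {1, 6}, predicate IN (1..5).
    println(estimateInSetSelectivity(
      ColStats(2, Some(1.0), Some(6.0)), Set(1.0, 2.0, 3.0, 4.0, 5.0)))        // Some(1.0)
  }
}

The sketch shows why the early return is preferable to the removed trailing if/else: the guard handles the degenerate statistics once, before any branch can compute a new ndv, instead of patching over the division at the very end.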