From 02a4b8755023ba63315af0e9791933e8fcabbfd5 Mon Sep 17 00:00:00 2001 From: shahid Date: Tue, 11 May 2021 03:41:06 +0530 Subject: [PATCH] Update null count in the column stats for UNION stats estimation --- .../logical/statsEstimation/FilterEstimation.scala | 2 +- .../logical/statsEstimation/UnionEstimation.scala | 14 ++++++++++++-- .../statsEstimation/UnionEstimationSuite.scala | 12 ++++++------ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 2c5beef43f52a..f7453e250b048 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging { attr: Attribute, isNull: Boolean, update: Boolean): Option[Double] = { - if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) { + if (!colStatsMap.contains(attr) || colStatsMap(attr).nullCount.isEmpty) { logDebug("[CBO] No statistics for " + attr) return None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala index c89ee1e80d926..3ce66d4c7d9ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala @@ -86,11 +86,20 @@ object UnionEstimation { val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() attrToComputeMinMaxStats.foreach { case (attrs, outputIndex) => + var nullCount: Option[BigInt] = None val dataType = unionOutput(outputIndex).dataType val statComparator = createStatComparator(dataType) val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) { case ((minVal, maxVal), (attr, childIndex)) => val colStat = union.children(childIndex).stats.attributeStats(attr) + // Update null count + nullCount = if (nullCount.isDefined && colStat.nullCount.isDefined) { + Some(nullCount.get + colStat.nullCount.get) + } else if (colStat.nullCount.isDefined) { + colStat.nullCount + } else { + nullCount + } val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { colStat.min } else { @@ -103,10 +112,11 @@ object UnionEstimation { } (min, max) } - val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2) + val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2, + nullCount = nullCount) outputAttrStats += unionOutput(outputIndex) -> newStat } - AttributeMap(outputAttrStats.toSeq) + AttributeMap(outputAttrStats) } else { AttributeMap.empty[ColumnStat] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala index 06bd38fd8ad78..cdf7e4b5a49cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala @@ -68,14 +68,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase { distinctCount = Some(2), min = Some(1), max = Some(4), - nullCount = Some(0), + nullCount = Some(1), avgLen = Some(4), maxLen = Some(4)), attrDouble -> ColumnStat( distinctCount = Some(2), min = Some(5.0), max = Some(4.0), - nullCount = Some(0), + nullCount = Some(2), avgLen = Some(4), maxLen = Some(4)), attrShort -> ColumnStat(min = Some(s1), max = Some(s2)), @@ -96,14 +96,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase { distinctCount = Some(2), min = Some(3), max = Some(6), - nullCount = Some(0), + nullCount = Some(1), avgLen = Some(8), maxLen = Some(8)), AttributeReference("cdouble1", DoubleType)() -> ColumnStat( distinctCount = Some(2), min = Some(2.0), max = Some(7.0), - nullCount = Some(0), + nullCount = Some(2), avgLen = Some(8), maxLen = Some(8)), AttributeReference("cshort1", ShortType)() -> ColumnStat(min = Some(s3), max = Some(s4)), @@ -139,8 +139,8 @@ class UnionEstimationSuite extends StatsEstimationTestBase { rowCount = Some(4), attributeStats = AttributeMap( Seq( - attrInt -> ColumnStat(min = Some(1), max = Some(6)), - attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0)), + attrInt -> ColumnStat(min = Some(1), max = Some(6), nullCount = Some(2)), + attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0), nullCount = Some(4)), attrShort -> ColumnStat(min = Some(s1), max = Some(s4)), attrLong -> ColumnStat(min = Some(1L), max = Some(6L)), attrByte -> ColumnStat(min = Some(b1), max = Some(b4)),