Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35362][SQL] Update null count in the column stats for UNION operator stats estimation #32494

Closed
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
attr: Attribute,
isNull: Boolean,
update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) {
if (!colStatsMap.contains(attr) || colStatsMap(attr).nullCount.isEmpty) {
logDebug("[CBO] No statistics for " + attr)
return None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.expressions.AttributeMap
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics, Union}
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -70,8 +68,20 @@ object UnionEstimation {
None
}

val unionOutput = union.output
val newMinMaxStats = computeMinMaxStats(union)
val newNullCountStats = computeNullCountStats(union)
val newAttrStats = combineStats(newMinMaxStats, newNullCountStats)

Some(
Statistics(
sizeInBytes = sizeInBytes,
rowCount = outputRows,
attributeStats = newAttrStats))
}

// This method computes the min-max statistics and return the attribute stats Map.
private def computeMinMaxStats(union: Union) = {
val unionOutput = union.output
val attrToComputeMinMaxStats = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, outputIndex) => isTypeSupported(unionOutput(outputIndex).dataType) &&
// checks if all the children has min/max stats for an attribute
Expand All @@ -81,40 +91,76 @@ object UnionEstimation {
attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
}
}

val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
attrToComputeMinMaxStats.foreach {
val outputAttrStats = attrToComputeMinMaxStats.map {
case (attrs, outputIndex) =>
val dataType = unionOutput(outputIndex).dataType
val statComparator = createStatComparator(dataType)
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
}
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
}
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
outputAttrStats += unionOutput(outputIndex) -> newStat
unionOutput(outputIndex) -> newStat
}
if (outputAttrStats.nonEmpty) {
AttributeMap(outputAttrStats.toSeq)
} else {
AttributeMap.empty[ColumnStat]
}
}

Some(
Statistics(
sizeInBytes = sizeInBytes,
rowCount = outputRows,
attributeStats = newAttrStats))
/** This method computes the null count statistics and return the attribute stats Map. */
private def computeNullCountStats(union: Union) = {
val unionOutput = union.output
val attrToComputeNullCount = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, _) => attrs.zipWithIndex.forall {
case (attr, childIndex) =>
val attrStats = union.children(childIndex).stats.attributeStats
attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined
}
}
val outputAttrStats = attrToComputeNullCount.map {
case (attrs, outputIndex) =>
val firstStat = union.children.head.stats.attributeStats(attrs.head)
val firstNullCount = firstStat.nullCount.get
val colWithNullStatValues = attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) {
case (totalNullCount, (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
totalNullCount + colStat.nullCount.get
}
val newStat = ColumnStat(nullCount = Some(colWithNullStatValues))
unionOutput(outputIndex) -> newStat
}
if (outputAttrStats.nonEmpty) {
AttributeMap(outputAttrStats.toSeq)
} else {
AttributeMap.empty[ColumnStat]
}
}

// Combine the two Maps by updating the min-max stats Map with null count stats.
private def combineStats(
minMaxStats: AttributeMap[ColumnStat],
nullCountStats: AttributeMap[ColumnStat]) = {
val updatedNullCountStats = nullCountStats.keys.map { key =>
if (minMaxStats.get(key).isDefined) {
val updatedColsStats = minMaxStats(key).copy(nullCount = nullCountStats(key).nullCount)
key -> updatedColsStats
} else {
key -> nullCountStats(key)
}
}
AttributeMap(minMaxStats.toSeq ++ updatedNullCountStats)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
val rowCount = Some(plan.rowCount * childrenSize)
val attributeStats = AttributeMap(
Seq(
attribute -> ColumnStat(min = Some(1), max = Some(10))))
attribute -> ColumnStat(min = Some(1), max = Some(10), nullCount = Some(0))))
checkStats(
union,
expectedStatsCboOn = Statistics(sizeInBytes = sizeInBytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
distinctCount = Some(2),
min = Some(1),
max = Some(4),
nullCount = Some(0),
nullCount = Some(1),
avgLen = Some(4),
maxLen = Some(4)),
attrDouble -> ColumnStat(
distinctCount = Some(2),
min = Some(5.0),
max = Some(4.0),
nullCount = Some(0),
nullCount = Some(2),
avgLen = Some(4),
maxLen = Some(4)),
attrShort -> ColumnStat(min = Some(s1), max = Some(s2)),
Expand All @@ -96,14 +96,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
distinctCount = Some(2),
min = Some(3),
max = Some(6),
nullCount = Some(0),
nullCount = Some(1),
avgLen = Some(8),
maxLen = Some(8)),
AttributeReference("cdouble1", DoubleType)() -> ColumnStat(
distinctCount = Some(2),
min = Some(2.0),
max = Some(7.0),
nullCount = Some(0),
nullCount = Some(2),
avgLen = Some(8),
maxLen = Some(8)),
AttributeReference("cshort1", ShortType)() -> ColumnStat(min = Some(s3), max = Some(s4)),
Expand Down Expand Up @@ -139,8 +139,8 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(
attrInt -> ColumnStat(min = Some(1), max = Some(6)),
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0)),
attrInt -> ColumnStat(min = Some(1), max = Some(6), nullCount = Some(2)),
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0), nullCount = Some(4)),
attrShort -> ColumnStat(min = Some(s1), max = Some(s4)),
attrLong -> ColumnStat(min = Some(1L), max = Some(6L)),
attrByte -> ColumnStat(min = Some(b1), max = Some(b4)),
Expand Down Expand Up @@ -188,7 +188,58 @@ class UnionEstimationSuite extends StatsEstimationTestBase {

val union = Union(Seq(child1, child2))

val expectedStats = logical.Statistics(sizeInBytes = 2 * 1024, rowCount = Some(4))
// Only null count is present in the attribute stats
val expectedStats = logical.Statistics(
sizeInBytes = 2 * 1024,
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(attrInt -> ColumnStat(nullCount = Some(0)))))
assert(union.stats === expectedStats)
}

test("col stats estimation when null count stats are not present for one child") {
val sz = Some(BigInt(1024))
val attrInt = AttributeReference("cint", IntegerType)()
val columnInfo = AttributeMap(
Seq(
attrInt -> ColumnStat(
distinctCount = Some(2),
min = Some(1),
max = Some(2),
nullCount = Some(2),
avgLen = Some(4),
maxLen = Some(4))))

// No null count
val columnInfo1 = AttributeMap(
Seq(
AttributeReference("cint1", IntegerType)() -> ColumnStat(
distinctCount = Some(2),
min = Some(3),
max = Some(4),
avgLen = Some(8),
maxLen = Some(8))))

val child1 = StatsTestPlan(
outputList = columnInfo.keys.toSeq,
rowCount = 2,
attributeStats = columnInfo,
size = sz)

val child2 = StatsTestPlan(
outputList = columnInfo1.keys.toSeq,
rowCount = 2,
attributeStats = columnInfo1,
size = sz)

val union = Union(Seq(child1, child2))

// Null count should not present in the stats.
val expectedStats = logical.Statistics(
sizeInBytes = 2 * 1024,
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(attrInt -> ColumnStat(min = Some(1), max = Some(4), nullCount = None))))
assert(union.stats === expectedStats)
}
}