Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35362][SQL] Update null count in the column stats for UNION operator stats estimation #32494

Closed
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
attr: Attribute,
isNull: Boolean,
update: Boolean): Option[Double] = {
if (!colStatsMap.contains(attr) || !colStatsMap(attr).hasCountStats) {
if (!colStatsMap.contains(attr) || colStatsMap(attr).nullCount.isEmpty) {
logDebug("[CBO] No statistics for " + attr)
return None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics, Union}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -70,8 +68,27 @@ object UnionEstimation {
None
}

val unionOutput = union.output
val newMinMaxStats = computeMinMaxStats(union)
val newNullCountStats = computeNullCountStats(union)
val newAttrStats = {
val baseStats = AttributeMap(newMinMaxStats)
val overwriteStats = newNullCountStats.map { case attrStat@(attr, stat) =>
baseStats.get(attr).map { baseStat =>
attr -> baseStat.copy(nullCount = stat.nullCount)
}.getOrElse(attrStat)
}
AttributeMap(newMinMaxStats ++ overwriteStats)
}

Some(
Statistics(
sizeInBytes = sizeInBytes,
rowCount = outputRows,
attributeStats = newAttrStats))
}

private def computeMinMaxStats(union: Union): Seq[(Attribute, ColumnStat)] = {
val unionOutput = union.output
val attrToComputeMinMaxStats = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, outputIndex) => isTypeSupported(unionOutput(outputIndex).dataType) &&
// checks if all the children has min/max stats for an attribute
Expand All @@ -81,40 +98,50 @@ object UnionEstimation {
attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
}
}

val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
attrToComputeMinMaxStats.foreach {
case (attrs, outputIndex) =>
val dataType = unionOutput(outputIndex).dataType
val statComparator = createStatComparator(dataType)
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
attrToComputeMinMaxStats.map {
case (attrs, outputIndex) =>
val dataType = unionOutput(outputIndex).dataType
val statComparator = createStatComparator(dataType)
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
outputAttrStats += unionOutput(outputIndex) -> newStat
}
AttributeMap(outputAttrStats.toSeq)
} else {
AttributeMap.empty[ColumnStat]
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
}
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
unionOutput(outputIndex) -> newStat
}
}

Some(
Statistics(
sizeInBytes = sizeInBytes,
rowCount = outputRows,
attributeStats = newAttrStats))
private def computeNullCountStats(union: Union): Seq[(Attribute, ColumnStat)] = {
val unionOutput = union.output
val attrToComputeNullCount = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, _) => attrs.zipWithIndex.forall {
case (attr, childIndex) =>
val attrStats = union.children(childIndex).stats.attributeStats
attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined
}
}
attrToComputeNullCount.map {
case (attrs, outputIndex) =>
val firstStat = union.children.head.stats.attributeStats(attrs.head)
val firstNullCount = firstStat.nullCount.get
val colWithNullStatValues = attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) {
case (totalNullCount, (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
totalNullCount + colStat.nullCount.get
}
val newStat = ColumnStat(nullCount = Some(colWithNullStatValues))
unionOutput(outputIndex) -> newStat
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
val rowCount = Some(plan.rowCount * childrenSize)
val attributeStats = AttributeMap(
Seq(
attribute -> ColumnStat(min = Some(1), max = Some(10))))
attribute -> ColumnStat(min = Some(1), max = Some(10), nullCount = Some(0))))
checkStats(
union,
expectedStatsCboOn = Statistics(sizeInBytes = sizeInBytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
distinctCount = Some(2),
min = Some(1),
max = Some(4),
nullCount = Some(0),
nullCount = Some(1),
avgLen = Some(4),
maxLen = Some(4)),
attrDouble -> ColumnStat(
distinctCount = Some(2),
min = Some(5.0),
max = Some(4.0),
nullCount = Some(0),
nullCount = Some(2),
avgLen = Some(4),
maxLen = Some(4)),
attrShort -> ColumnStat(min = Some(s1), max = Some(s2)),
Expand All @@ -96,14 +96,14 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
distinctCount = Some(2),
min = Some(3),
max = Some(6),
nullCount = Some(0),
nullCount = Some(1),
avgLen = Some(8),
maxLen = Some(8)),
AttributeReference("cdouble1", DoubleType)() -> ColumnStat(
distinctCount = Some(2),
min = Some(2.0),
max = Some(7.0),
nullCount = Some(0),
nullCount = Some(2),
avgLen = Some(8),
maxLen = Some(8)),
AttributeReference("cshort1", ShortType)() -> ColumnStat(min = Some(s3), max = Some(s4)),
Expand Down Expand Up @@ -139,8 +139,8 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(
attrInt -> ColumnStat(min = Some(1), max = Some(6)),
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0)),
attrInt -> ColumnStat(min = Some(1), max = Some(6), nullCount = Some(2)),
attrDouble -> ColumnStat(min = Some(2.0), max = Some(7.0), nullCount = Some(4)),
attrShort -> ColumnStat(min = Some(s1), max = Some(s4)),
attrLong -> ColumnStat(min = Some(1L), max = Some(6L)),
attrByte -> ColumnStat(min = Some(b1), max = Some(b4)),
Expand Down Expand Up @@ -188,7 +188,58 @@ class UnionEstimationSuite extends StatsEstimationTestBase {

val union = Union(Seq(child1, child2))

val expectedStats = logical.Statistics(sizeInBytes = 2 * 1024, rowCount = Some(4))
// Only null count is present in the attribute stats
val expectedStats = logical.Statistics(
sizeInBytes = 2 * 1024,
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(attrInt -> ColumnStat(nullCount = Some(0)))))
assert(union.stats === expectedStats)
}

test("col stats estimation when null count stats are not present for one child") {
val sz = Some(BigInt(1024))
val attrInt = AttributeReference("cint", IntegerType)()
val columnInfo = AttributeMap(
Seq(
attrInt -> ColumnStat(
distinctCount = Some(2),
min = Some(1),
max = Some(2),
nullCount = Some(2),
avgLen = Some(4),
maxLen = Some(4))))

// No null count
val columnInfo1 = AttributeMap(
Seq(
AttributeReference("cint1", IntegerType)() -> ColumnStat(
distinctCount = Some(2),
min = Some(3),
max = Some(4),
avgLen = Some(8),
maxLen = Some(8))))

val child1 = StatsTestPlan(
outputList = columnInfo.keys.toSeq,
rowCount = 2,
attributeStats = columnInfo,
size = sz)

val child2 = StatsTestPlan(
outputList = columnInfo1.keys.toSeq,
rowCount = 2,
attributeStats = columnInfo1,
size = sz)

val union = Union(Seq(child1, child2))

// Null count should not present in the stats.
val expectedStats = logical.Statistics(
sizeInBytes = 2 * 1024,
rowCount = Some(4),
attributeStats = AttributeMap(
Seq(attrInt -> ColumnStat(min = Some(1), max = Some(4), nullCount = None))))
assert(union.stats === expectedStats)
}
}