Skip to content

Commit

Permalink
Address comment
Browse files: browse the repository at this point in the history
  • Loading branch information
shahidki31 committed May 19, 2021
1 parent 3afaf32 commit 06fbbec
Showing 1 changed file with 31 additions and 44 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.AttributeMap
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Statistics, Union}
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -71,16 +71,13 @@ object UnionEstimation {
val newMinMaxStats = computeMinMaxStats(union)
val newNullCountStats = computeNullCountStats(union)
val newAttrStats = {
val updatedNullCountStats = newNullCountStats.keys.map { key =>
if (newMinMaxStats.get(key).isDefined) {
val updatedColsStats = newMinMaxStats(key)
.copy(nullCount = newNullCountStats(key).nullCount)
key -> updatedColsStats
} else {
key -> newNullCountStats(key)
}
val baseStats = AttributeMap(newMinMaxStats)
val overwriteStats = newNullCountStats.map { case attrStat@(attr, stat) =>
baseStats.get(attr).map { baseStat =>
attr -> baseStat.copy(nullCount = stat.nullCount)
}.getOrElse(attrStat)
}
AttributeMap(newMinMaxStats.toSeq ++ updatedNullCountStats)
AttributeMap(newMinMaxStats ++ overwriteStats)
}

Some(
Expand All @@ -90,7 +87,7 @@ object UnionEstimation {
attributeStats = newAttrStats))
}

private def computeMinMaxStats(union: Union) = {
private def computeMinMaxStats(union: Union): Seq[(Attribute, ColumnStat)] = {
val unionOutput = union.output
val attrToComputeMinMaxStats = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, outputIndex) => isTypeSupported(unionOutput(outputIndex).dataType) &&
Expand All @@ -101,36 +98,31 @@ object UnionEstimation {
attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
}
}
val outputAttrStats = attrToComputeMinMaxStats.map {
case (attrs, outputIndex) =>
val dataType = unionOutput(outputIndex).dataType
val statComparator = createStatComparator(dataType)
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
}
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
unionOutput(outputIndex) -> newStat
}
if (outputAttrStats.nonEmpty) {
AttributeMap(outputAttrStats)
} else {
AttributeMap.empty[ColumnStat]
attrToComputeMinMaxStats.map {
case (attrs, outputIndex) =>
val dataType = unionOutput(outputIndex).dataType
val statComparator = createStatComparator(dataType)
val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) {
case ((minVal, maxVal), (attr, childIndex)) =>
val colStat = union.children(childIndex).stats.attributeStats(attr)
val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) {
colStat.min
} else {
minVal
}
val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) {
colStat.max
} else {
maxVal
}
(min, max)
}
val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
unionOutput(outputIndex) -> newStat
}
}

private def computeNullCountStats(union: Union) = {
private def computeNullCountStats(union: Union): Seq[(Attribute, ColumnStat)] = {
val unionOutput = union.output
val attrToComputeNullCount = union.children.map(_.output).transpose.zipWithIndex.filter {
case (attrs, _) => attrs.zipWithIndex.forall {
Expand All @@ -139,7 +131,7 @@ object UnionEstimation {
attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined
}
}
val outputAttrStats = attrToComputeNullCount.map {
attrToComputeNullCount.map {
case (attrs, outputIndex) =>
val firstStat = union.children.head.stats.attributeStats(attrs.head)
val firstNullCount = firstStat.nullCount.get
Expand All @@ -151,10 +143,5 @@ object UnionEstimation {
val newStat = ColumnStat(nullCount = Some(colWithNullStatValues))
unionOutput(outputIndex) -> newStat
}
if (outputAttrStats.nonEmpty) {
AttributeMap(outputAttrStats)
} else {
AttributeMap.empty[ColumnStat]
}
}
}

0 comments on commit 06fbbec

Please sign in to comment.