diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 96a11e352ec50..1504a522798b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -33,6 +33,8 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
 
   override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2)
 
+  override def contains(k: Attribute): Boolean = get(k).isDefined
+
   override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv
 
   override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c977e788b0106..9b52a9cc817c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -53,6 +54,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
 
   override def validConstraints: Set[Expression] =
     child.constraints.union(getAliasedConstraints(projectList))
+
+  override lazy val statistics: Statistics =
+    ProjectEstimation.estimate(this).getOrElse(super.statistics)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
new file mode 100644
index 0000000000000..f099e32267461
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
+import org.apache.spark.sql.types.StringType
+
+
+object EstimationUtils {
+
+  /** Returns true only when every given plan has a rowCount defined in its statistics. */
+  def rowCountsExist(plans: LogicalPlan*): Boolean =
+    plans.forall(_.statistics.rowCount.isDefined)
+
+  /** Builds the stats map for output, keeping only attributes that have an entry in inputMap. */
+  def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute])
+    : AttributeMap[ColumnStat] = {
+    AttributeMap(output.flatMap(a => inputMap.get(a).map(a -> _)))
+  }
+
+  def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = {
+    // Start with a generic Row overhead of 8 (the actual overhead differs between Row formats),
+    // then add each attribute's avgLen if it has column stats, else its data type's defaultSize.
+    8 + attributes.map { attr =>
+      if (attrStats.contains(attr)) {
+        attr.dataType match {
+          case StringType =>
+            // Strings add 8 + 4 on top of avgLen for UTF8String: base + offset + numBytes
+            attrStats(attr).avgLen + 8 + 4
+          case _ =>
+            attrStats(attr).avgLen
+        }
+      } else {
+        attr.dataType.defaultSize
+      }
+    }.sum
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
new file mode 100644
index 0000000000000..6d63b09fd41b8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}
+
+object ProjectEstimation {
+  import EstimationUtils._
+
+  def estimate(project: Project): Option[Statistics] = {
+    if (rowCountsExist(project.child)) {
+      val childStats = project.child.statistics
+      val inputAttrStats = childStats.attributeStats
+      // Aliases of a bare attribute with known stats inherit that attribute's column stat
+      val aliasStats = project.expressions.collect {
+        case alias @ Alias(attr: Attribute, _) if inputAttrStats.contains(attr) =>
+          alias.toAttribute -> inputAttrStats(attr)
+      }
+      val outputAttrStats =
+        getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
+      Some(childStats.copy(
+        sizeInBytes = childStats.rowCount.get * getRowSize(project.output, outputAttrStats),
+        attributeStats = outputAttrStats))
+    } else {
+      None
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
new file mode 100644
index 0000000000000..4a1bed84f84e8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.types.IntegerType
+
+
+class ProjectEstimationSuite extends StatsEstimationTestBase {
+
+  test("estimate project with alias") {
+    val ar1 = AttributeReference("key1", IntegerType)()
+    val ar2 = AttributeReference("key2", IntegerType)()
+    val colStat1 = ColumnStat(2, Some(1), Some(2), 0, 4, 4)
+    val colStat2 = ColumnStat(1, Some(10), Some(10), 0, 4, 4)
+
+    val child = StatsTestPlan(
+      outputList = Seq(ar1, ar2),
+      stats = Statistics(
+        sizeInBytes = 2 * (4 + 4),
+        rowCount = Some(2),
+        attributeStats = AttributeMap(Seq(ar1 -> colStat1, ar2 -> colStat2))))
+
+    val project = Project(Seq(ar1, Alias(ar2, "abc")()), child)
+    val expectedColStats = Seq("key1" -> colStat1, "abc" -> colStat2)
+    val expectedAttrStats = toAttributeMap(expectedColStats, project)
+    // Projection does not change the number of rows, so rowCount stays at the child's 2.
+    val expectedStats = Statistics(
+      sizeInBytes = 2 * getRowSize(project.output, expectedAttrStats),
+      rowCount = Some(2),
+      attributeStats = expectedAttrStats)
+    assert(project.statistics == expectedStats)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
new file mode 100644
index 0000000000000..fa5b290ecb17c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
+
+
+class StatsEstimationTestBase extends SparkFunSuite {
+
+  /** Maps each (column name, stat) pair to the same-named attribute of plan.output. */
+  def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan)
+    : AttributeMap[ColumnStat] = {
+    val nameToAttr: Map[String, Attribute] = plan.output.map(a => (a.name, a)).toMap
+    AttributeMap(colStats.map(kv => nameToAttr(kv._1) -> kv._2))
+  }
+}
+
+/**
+ * Test-only leaf plan that simply reports the output attributes and statistics it was given.
+ */
+protected case class StatsTestPlan(outputList: Seq[Attribute], stats: Statistics) extends LeafNode {
+  override def output: Seq[Attribute] = outputList
+  override lazy val statistics = stats
+}