Skip to content

Commit

Permalink
[SPARK-17077][SQL] Cardinality estimation for project operator
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Support cardinality estimation for the Project operator.

## How was this patch tested?

Add a test suite and a base class in the catalyst package.

Author: Zhenhua Wang <[email protected]>

Closes #16430 from wzhfy/projectEstimation.
  • Loading branch information
wzhfy authored and rxin committed Jan 9, 2017
1 parent 19d9d4c commit 3ccabdf
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])

// Looks the attribute up by its exprId and returns only the stored value, dropping the
// attribute half of the stored pair.
override def get(k: Attribute): Option[A] =
  baseMap.get(k.exprId).map { case (_, value) => value }

// An attribute is contained exactly when a lookup through `get` yields a value.
override def contains(k: Attribute): Boolean = get(k).nonEmpty

// Adds `kv` to the stored (attribute, value) pairs. The result is an ordinary
// Map[Attribute, B1] built with `toMap`, not another AttributeMap.
override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = baseMap.values.toMap + kv

// Iterates over the stored (attribute, value) pairs; the ExprId keys of baseMap are not exposed.
override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -53,6 +54,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend

// Constraints of the child, extended with those derived for the aliases in the project list.
override def validConstraints: Set[Expression] = {
  val inherited = child.constraints
  inherited.union(getAliasedConstraints(projectList))
}

// Prefer the estimate from ProjectEstimation; when it yields nothing, use the parent's statistics.
override lazy val statistics: Statistics = ProjectEstimation.estimate(this) match {
  case Some(estimated) => estimated
  case None => super.statistics
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
import org.apache.spark.sql.types.StringType


/** Helper functions shared by the statistics estimators in this package. */
object EstimationUtils {

  /** Returns true only when every given plan reports a row count in its statistics. */
  def rowCountsExist(plans: LogicalPlan*): Boolean =
    !plans.exists(_.statistics.rowCount.isEmpty)

  /**
   * Builds the column stats for `output`: each output attribute that has an entry in `inputMap`
   * is kept with that entry, and attributes without one are left out.
   */
  def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute])
    : AttributeMap[ColumnStat] = {
    val retained = for {
      attr <- output
      stat <- inputMap.get(attr)
    } yield attr -> stat
    AttributeMap(retained)
  }

  /**
   * Estimates the size of one row made of `attributes`.
   *
   * The result is a fixed overhead of 8 plus one size per attribute: the average length from
   * `attrStats` when a column stat is present, otherwise the default size of the data type.
   */
  def getRowSize(attributes: Seq[Attribute], attrStats: AttributeMap[ColumnStat]): Long = {
    // A generic overhead is charged for the Row object itself; the real overhead differs
    // between Row formats.
    8 + attributes.map(attr => columnSize(attr, attrStats)).sum
  }

  /** Size contributed by a single attribute to the estimated row size. */
  private def columnSize(attr: Attribute, attrStats: AttributeMap[ColumnStat]): Long =
    attrStats.get(attr) match {
      case Some(stat) =>
        attr.dataType match {
          // UTF8String: base + offset + numBytes
          case StringType => stat.avgLen + 8 + 4
          case _ => stat.avgLen
        }
      case None =>
        attr.dataType.defaultSize
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}

/** Statistics estimation for the [[Project]] operator. */
object ProjectEstimation {
  import EstimationUtils._

  /**
   * Estimates the statistics of `project` from the statistics of its child.
   *
   * @return None when the child has no row count. Otherwise a copy of the child's statistics
   *         whose column stats are restricted to the project's output and whose size is
   *         recomputed as row count times estimated row size.
   */
  def estimate(project: Project): Option[Statistics] = {
    if (!rowCountsExist(project.child)) {
      None
    } else {
      val childStats = project.child.statistics
      val childColStats = childStats.attributeStats
      // An alias directly over an attribute inherits the column stat of that attribute.
      val aliasedColStats = project.expressions.collect {
        case alias @ Alias(source: Attribute, _) if childColStats.contains(source) =>
          alias.toAttribute -> childColStats(source)
      }
      val candidateStats = AttributeMap(childColStats.toSeq ++ aliasedColStats)
      val outputColStats = getOutputMap(candidateStats, project.output)
      // The row count was verified above, so `get` cannot fail here.
      val estimatedSize =
        childStats.rowCount.get * getRowSize(project.output, outputColStats)
      Some(childStats.copy(sizeInBytes = estimatedSize, attributeStats = outputColStats))
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
import org.apache.spark.sql.types.IntegerType


/** Checks the statistics estimated for a Project over a plan with known column stats. */
class ProjectEstimationSuite extends StatsEstimationTestBase {

  test("estimate project with alias") {
    val key1 = AttributeReference("key1", IntegerType)()
    val key2 = AttributeReference("key2", IntegerType)()
    val key1Stat = ColumnStat(2, Some(1), Some(2), 0, 4, 4)
    val key2Stat = ColumnStat(1, Some(10), Some(10), 0, 4, 4)

    // Child plan: two rows over two integer columns, each with a column stat.
    val childStats = Statistics(
      sizeInBytes = 2 * (4 + 4),
      rowCount = Some(2),
      attributeStats = AttributeMap(Seq(key1 -> key1Stat, key2 -> key2Stat)))
    val child = StatsTestPlan(outputList = Seq(key1, key2), stats = childStats)

    // key2 is projected under the alias "abc" and is expected to keep its column stat.
    val project = Project(Seq(key1, Alias(key2, "abc")()), child)
    val expectedAttrStats =
      toAttributeMap(Seq("key1" -> key1Stat, "abc" -> key2Stat), project)

    // The number of rows won't change for project.
    val expectedStats = Statistics(
      sizeInBytes = 2 * getRowSize(project.output, expectedAttrStats),
      rowCount = Some(2),
      attributeStats = expectedAttrStats)
    assert(project.statistics == expectedStats)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}


/** Base class for the statistics estimation test suites. */
class StatsEstimationTestBase extends SparkFunSuite {

  /**
   * Builds an AttributeMap from (column name, column stat) pairs by resolving each name against
   * the output attributes of `plan`. A name that is not in the plan's output makes the lookup
   * throw.
   */
  def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan)
    : AttributeMap[ColumnStat] = {
    val attrByName: Map[String, Attribute] = plan.output.map(attr => attr.name -> attr).toMap
    AttributeMap(colStats.map { case (name, stat) => attrByName(name) -> stat })
  }
}

/**
 * A leaf logical plan for unit tests. It reports exactly the output attributes and the
 * statistics it was constructed with.
 */
protected case class StatsTestPlan(outputList: Seq[Attribute], stats: Statistics) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override lazy val statistics: Statistics = stats
}

0 comments on commit 3ccabdf

Please sign in to comment.