[SPARK-12706] [SQL] grouping() and grouping_id() #10677

Closed · wants to merge 11 commits
22 changes: 11 additions & 11 deletions python/pyspark/sql/dataframe.py
@@ -887,8 +887,8 @@ def groupBy(self, *cols):
[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
>>> sorted(df.groupBy(df.name).avg().collect())
[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
>>> df.groupBy(['name', df.age]).count().collect()
[Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
>>> sorted(df.groupBy(['name', df.age]).count().collect())
[Row(name=u'Alice', age=2, count=1), Row(name=u'Bob', age=5, count=1)]
"""
jgd = self._jdf.groupBy(self._jcols(*cols))
from pyspark.sql.group import GroupedData
@@ -900,15 +900,15 @@ def rollup(self, *cols):
Create a multi-dimensional rollup for the current :class:`DataFrame` using
the specified columns, so we can run aggregation on them.

>>> df.rollup('name', df.age).count().show()
>>> df.rollup("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
|Alice| 2| 1|
| Bob| 5| 1|
| Bob|null| 1|
| null|null| 2|
|Alice|null| 1|
|Alice| 2| 1|
| Bob|null| 1|
| Bob| 5| 1|
+-----+----+-----+
"""
jgd = self._jdf.rollup(self._jcols(*cols))
@@ -921,17 +921,17 @@ def cube(self, *cols):
Create a multi-dimensional cube for the current :class:`DataFrame` using
the specified columns, so we can run aggregation on them.

>>> df.cube('name', df.age).count().show()
>>> df.cube("name", df.age).count().orderBy("name", "age").show()
+-----+----+-----+
| name| age|count|
+-----+----+-----+
| null|null| 2|
| null| 2| 1|
|Alice| 2| 1|
| Bob| 5| 1|
| null| 5| 1|
| Bob|null| 1|
| null|null| 2|
|Alice|null| 1|
|Alice| 2| 1|
| Bob|null| 1|
| Bob| 5| 1|
+-----+----+-----+
"""
jgd = self._jdf.cube(self._jcols(*cols))
44 changes: 44 additions & 0 deletions python/pyspark/sql/functions.py
@@ -288,6 +288,50 @@ def first(col, ignorenulls=False):
return Column(jc)


@since(2.0)
def grouping(col):
"""
Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
or not, returns 1 for aggregated or 0 for not aggregated in the result set.

>>> df.cube("name").agg(grouping("name"), sum("age")).orderBy("name").show()
+-----+--------------+--------+
| name|grouping(name)|sum(age)|
+-----+--------------+--------+
| null| 1| 7|
|Alice| 0| 2|
| Bob| 0| 5|
+-----+--------------+--------+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.grouping(_to_java_column(col))
return Column(jc)


@since(2.0)
def grouping_id(*cols):
"""
Aggregate function: returns the level of grouping, equals to

(grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)

Note: the list of columns should match with grouping columns exactly, or empty (means all the
grouping columns).

>>> df.cube("name").agg(grouping_id(), sum("age")).orderBy("name").show()
+-----+------------+--------+
| name|groupingid()|sum(age)|
+-----+------------+--------+
| null| 1| 7|
|Alice| 0| 2|
| Bob| 0| 5|
+-----+------------+--------+
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column))
return Column(jc)


@since(1.6)
def input_file_name():
"""Creates a string column for the file name of the current Spark task.
@@ -186,8 +186,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
*
* The bitmask denotes the validity of the grouping expressions for a grouping set;
* the bitmask is also called the grouping id (`GROUPING__ID`, the virtual column in Hive).
* e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of
* GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively.
* e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of
* GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively.
*/
protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
val (keyASTs, setASTs) = children.partition {
@@ -198,12 +198,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val keys = keyASTs.map(nodeToExpr)
val keyMap = keyASTs.zipWithIndex.toMap

val mask = (1 << keys.length) - 1
val bitmasks: Seq[Int] = setASTs.map {
case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
columns.foldLeft(0)((bitmap, col) => {
val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2)
bitmap | 1 << keyIndex.getOrElse(
columns.foldLeft(mask)((bitmap, col) => {
val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse(
throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list"))
// 0 means that the column at the given index is a grouping column, 1 means it is not,
// so we unset the bit in bitmap.
bitmap & ~(1 << (keys.length - 1 - keyIndex))
})
case _ => sys.error("Expect GROUPING SETS clause")
}
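
To make the new bit layout concrete, here is a minimal, self-contained sketch (not the PR code itself) of the same fold, using plain strings in place of key expressions; it reproduces the ids 1 and 5 quoted in the comment above.

```scala
// Sketch only: bit (n - 1 - i) is cleared when key i appears in the grouping set,
// so a 1 bit means "aggregated away" (matching the updated comment semantics).
object GroupingIdBitmaskSketch {
  def bitmask(keys: Seq[String], groupingSet: Seq[String]): Int = {
    val mask = (1 << keys.length) - 1                   // start with every bit set
    groupingSet.foldLeft(mask) { (bitmap, col) =>
      val keyIndex = keys.indexOf(col)                  // position in the GROUP BY list
      bitmap & ~(1 << (keys.length - 1 - keyIndex))     // clear the bit for this key
    }
  }

  def main(args: Array[String]): Unit = {
    val keys = Seq("k1", "k2", "k3")
    println(bitmask(keys, Seq("k1", "k2")))             // 1: only k3 is aggregated away
    println(bitmask(keys, Seq("k2")))                   // 5: k1 and k3 are aggregated away
  }
}
```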
@@ -238,14 +238,39 @@ class Analyzer(
}
}.toMap

val aggregations: Seq[NamedExpression] = x.aggregations.map {
// If an expression is an aggregate (contains a AggregateExpression) then we dont change
// it so that the aggregation is computed on the unmodified value of its argument
// expressions.
case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr
// If not then its a grouping expression and we need to use the modified (with nulls from
// Expand) value of the expression.
case expr => expr.transformDown {
val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
// collect all the found AggregateExpression, so we can check an expression is part of
// any AggregateExpression or not.
val aggsBuffer = ArrayBuffer[Expression]()
// Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
Contributor: I had to take a long hard look to understand this. You are comparing references, not structure. Maybe add a small piece of documentation?

}
expr.transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
case e: AggregateExpression =>
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
case e: GroupingID =>
Contributor: This is probably a dumb question. What happens if we use these functions without grouping sets? Do we get a nice analysis exception?

Contributor (author): Right now it will fail to resolve; agreed that it should have a better error message.

Contributor (author): Have captured this in CheckAnalysis.

if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) {
gid
} else {
throw new AnalysisException(
s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
s"grouping columns (${x.groupByExprs.mkString(",")})")
}
case Grouping(col: Expression) =>
val idx = x.groupByExprs.indexOf(col)
if (idx >= 0) {
Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
Literal(1)), ByteType)
} else {
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
s"in grouping columns ${x.groupByExprs.mkString(",")}")
}
case e =>
groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e)
}.asInstanceOf[NamedExpression]
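
As a side note, the `ShiftRight`/`BitwiseAnd` rewrite above boils down to plain integer bit extraction. A small sketch under the assumption of a two-column cube over (course, year); the names and the standalone `groupingBit` helper are illustrative only, not part of this diff.

```scala
// Sketch only: `gid` stands in for the grouping id column produced by Expand.
object GroupingBitSketch {
  def groupingBit(gid: Int, numGroupByExprs: Int, idx: Int): Int =
    (gid >> (numGroupByExprs - 1 - idx)) & 1            // mirrors ShiftRight + BitwiseAnd

  def main(args: Array[String]): Unit = {
    // gid = 2 (binary 10) for cube(course, year) means course was aggregated away:
    println(groupingBit(2, 2, 0))                       // 1 -> grouping(course)
    println(groupingBit(2, 2, 1))                       // 0 -> grouping(year)
  }
}
```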
@@ -814,8 +839,11 @@ class Analyzer(
}
}

private def isAggregateExpression(e: Expression): Boolean = {
e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID]
}
def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
condition.find(isAggregateExpression).isDefined
}
}

@@ -997,7 +1025,7 @@ class Analyzer(
_.transform {
// Extracts children expressions of a WindowFunction (input parameters of
// a WindowFunction).
case wf : WindowFunction =>
case wf: WindowFunction =>
val newChildren = wf.children.map(extractExpr)
wf.withNewChildren(newChildren)

@@ -70,6 +70,11 @@ trait CheckAnalysis {
failAnalysis(
s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")

case g: Grouping =>
failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup")
case g: GroupingID =>
failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")

case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")

@@ -291,6 +291,8 @@ object FunctionRegistry {
// grouping sets
expression[Cube]("cube"),
expression[Rollup]("rollup"),
expression[Grouping]("grouping"),
expression[GroupingID]("grouping_id"),

// window functions
expression[Lead]("lead"),
@@ -344,4 +344,3 @@ abstract class DeclarativeAggregate
def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a))
}
}

@@ -41,3 +41,26 @@ trait GroupingSet extends Expression with CodegenFallback {
case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {}

case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {}

/**
* Indicates whether a specified column expression in a GROUP BY list is aggregated or not.
* GROUPING returns 1 for aggregated or 0 for not aggregated in the result set.
*/
case class Grouping(child: Expression) extends Expression with Unevaluable {
override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
override def children: Seq[Expression] = child :: Nil
override def dataType: DataType = ByteType
override def nullable: Boolean = false
}

/**
* GroupingID is a function that computes the level of grouping.
*
* If groupByExprs is empty, it means all grouping expressions in GroupingSets.
*/
case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable {
override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
override def children: Seq[Expression] = groupByExprs
override def dataType: DataType = IntegerType
override def nullable: Boolean = false
}
@@ -412,7 +412,7 @@ private[sql] object Expand {

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 0) set += exprs(bit)
if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
bit -= 1
}
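
For reference, a standalone sketch of the corrected loop, using strings instead of expressions; with the (k1, k2, k3) superset from the HiveQl comment, bitmask 5 (grouping set (k2)) nulls out k1 and k3. This is not PR code, just an illustration of the fix.

```scala
// Sketch only: a 1 bit now means "not grouped", and the bit order is reversed
// relative to the expression list, so Expand nulls out the matching expressions.
object ExpandBitmaskSketch {
  def nulledExprs(exprs: Seq[String], bitmask: Int): Set[String] = {
    var set = Set.empty[String]
    var bit = exprs.length - 1
    while (bit >= 0) {
      if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
      bit -= 1
    }
    set
  }

  def main(args: Array[String]): Unit = {
    println(nulledExprs(Seq("k1", "k2", "k3"), 5))      // Set(k1, k3)
  }
}
```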

46 changes: 46 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -396,6 +396,52 @@ object functions extends LegacyFunctions {
*/
def first(columnName: String): Column = first(Column(columnName))


/**
* Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
* or not, returns 1 for aggregated or 0 for not aggregated in the result set.
*
* @group agg_funcs
* @since 2.0.0
*/
def grouping(e: Column): Column = Column(Grouping(e.expr))

/**
* Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
* or not, returns 1 for aggregated or 0 for not aggregated in the result set.
*
* @group agg_funcs
* @since 2.0.0
*/
def grouping(columnName: String): Column = grouping(Column(columnName))

/**
* Aggregate function: returns the level of grouping, equals to
*
* (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
*
* Note: the list of columns should match with grouping columns exactly, or empty (means all the
* grouping columns).
*
* @group agg_funcs
* @since 2.0.0
*/
def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr)))

/**
* Aggregate function: returns the level of grouping, equals to
*
* (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
*
* Note: the list of columns should match with grouping columns exactly.
*
* @group agg_funcs
* @since 2.0.0
*/
def grouping_id(colName: String, colNames: String*): Column = {
grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*)
}
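
A hedged usage sketch of the new API (not part of this diff), assuming a Spark 2.x `SparkSession`; the column names and rows are illustrative and loosely mirror the `courseSales` fixture used in the tests further down.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{grouping, grouping_id, sum}

object GroupingFunctionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("grouping-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(("dotNET", 2012, 15000), ("Java", 2012, 20000), ("Java", 2013, 30000))
      .toDF("course", "year", "earnings")

    df.cube("course", "year")
      .agg(grouping("course"), grouping("year"), grouping_id("course", "year"), sum("earnings"))
      .orderBy("course", "year")
      .show()
    // grouping(c) is 1 when c is aggregated away; grouping_id packs those bits, e.g. the
    // (course, null) subtotal rows have grouping_id = (0 << 1) + 1 = 1.

    spark.stop()
  }
}
```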

/**
* Aggregate function: returns the kurtosis of the values in a group.
*
@@ -17,6 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.DecimalType
@@ -98,6 +99,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
assert(cube0.where("date IS NULL").count > 0)
}

test("grouping and grouping_id") {
checkAnswer(
courseSales.cube("course", "year")
.agg(grouping("course"), grouping("year"), grouping_id("course", "year")),
Row("Java", 2012, 0, 0, 0) ::
Row("Java", 2013, 0, 0, 0) ::
Row("Java", null, 0, 1, 1) ::
Row("dotNET", 2012, 0, 0, 0) ::
Row("dotNET", 2013, 0, 0, 0) ::
Row("dotNET", null, 0, 1, 1) ::
Row(null, 2012, 1, 0, 2) ::
Row(null, 2013, 1, 0, 2) ::
Row(null, null, 1, 1, 3) :: Nil
)

intercept[AnalysisException] {
courseSales.groupBy().agg(grouping("course")).explain()
}
intercept[AnalysisException] {
courseSales.groupBy().agg(grouping_id("course")).explain()
}
}

test("grouping/grouping_id inside window function") {

val w = Window.orderBy(sum("earnings"))
checkAnswer(
courseSales.cube("course", "year")
.agg(sum("earnings"),
grouping_id("course", "year"),
rank().over(Window.partitionBy(grouping_id("course", "year")).orderBy(sum("earnings")))),
Row("Java", 2012, 20000.0, 0, 2) ::
Row("Java", 2013, 30000.0, 0, 3) ::
Row("Java", null, 50000.0, 1, 1) ::
Row("dotNET", 2012, 15000.0, 0, 1) ::
Row("dotNET", 2013, 48000.0, 0, 4) ::
Row("dotNET", null, 63000.0, 1, 2) ::
Row(null, 2012, 35000.0, 2, 1) ::
Row(null, 2013, 78000.0, 2, 2) ::
Row(null, null, 113000.0, 3, 1) :: Nil
)
}

test("rollup overlapping columns") {
checkAnswer(
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),