Skip to content

Commit

Permalink
make GROUPING__ID compatible with Hive
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jan 9, 2016
1 parent bcb8d9e commit 736e8d2
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case _ => sys.error("Expect GROUPING SETS clause")
}

val mask = (1 << keys.length) - 1
(keys, bitmasks.map(x => (~x) & mask))
(keys, bitmasks)
}

protected def nodeToPlan(node: ASTNode): LogicalPlan = node match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ class Analyzer(
case e if isPartOfAggregation(e) => e
case e: GroupingID =>
if (e.groupByExprs == x.groupByExprs) {
gid
// The bitmask follows Hive's convention, which is wrong, so we need to reverse it here.
// TODO: do not follow Hive
BitwiseReverse(BitwiseNot(gid), e.groupByExprs.length)
} else {
throw new AnalysisException(
s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
Expand All @@ -259,8 +261,7 @@ class Analyzer(
case Grouping(col: Expression) =>
val idx = x.groupByExprs.indexOf(col)
if (idx >= 0) {
Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
Literal(1)), ByteType)
Cast(BitwiseAnd(ShiftRight(BitwiseNot(gid), Literal(idx)), Literal(1)), ByteType)
} else {
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
s"in grouping columns ${x.groupByExprs.mkString(",")}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,47 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp

protected override def nullSafeEval(input: Any): Any = not(input)
}

/**
 * An expression that reverses the order of the lowest `width` bits of an integer.
 *
 * Only the lowest `width` bits of the input contribute to the result. For example, with
 * `width = 4` the input 3 (binary 0011) evaluates to 12 (binary 1100).
 *
 * Note: this is only used for grouping_id()
 *
 * @param child the integer expression whose low bits are reversed
 * @param width how many of the lowest bits take part in the reversal
 */
case class BitwiseReverse(child: Expression, width: Int)
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)

  override def dataType: DataType = IntegerType

  // NOTE(review): the rendering does not include `width` -- confirm that this is intended.
  override def toString: String = s"^$child"

  /**
   * Emits a Java loop that runs `width` times; each pass shifts the result left by one and
   * moves the current lowest bit of the working copy of the input into it.
   */
  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    nullSafeCodeGen(ctx, ev, input => {
      // Fresh names for the working copy of the input and for the loop counter.
      val remaining = ctx.freshName("v")
      val idx = ctx.freshName("i")
      s"""
         | int $remaining = $input;
         | ${ev.value} = 0;
         | for (int $idx = 0; $idx < $width; $idx ++) {
         | ${ev.value} <<= 1;
         | ${ev.value} |= $remaining & 1;
         | $remaining >>>= 1;
         | }
         |""".stripMargin
    })
  }

  /** Interpreted counterpart of the loop emitted by `genCode`. */
  protected override def nullSafeEval(input: Any): Any = {
    // Moves one bit per step from the low end of `remaining` onto the low end of `acc`.
    @scala.annotation.tailrec
    def reverse(remaining: Int, acc: Int, bitsLeft: Int): Int = {
      if (bitsLeft <= 0) {
        acc
      } else {
        reverse(remaining >>> 1, (acc << 1) | (remaining & 1), bitsLeft - 1)
      }
    }

    reverse(input.asInstanceOf[Int], 0, width)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ private[sql] object Expand {

var bit = exprs.length - 1
while (bit >= 0) {
if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
if (((bitmask >> bit) & 1) == 0) set += exprs(bit)
bit -= 1
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,17 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(BitwiseXor, dt, dt)
}
}

test("BitwiseReverse") {
  // Each case is (input, width, expected value after reversing the lowest `width` bits).
  val cases = Seq(
    (1, 1, 1),
    (0, 1, 0),
    (1, 2, 2),
    (3, 4, 12),
    (9, 4, 9)
  )

  cases.foreach { case (input, width, expected) =>
    checkEvaluation(BitwiseReverse(Literal(input), width), expected)
  }
}
}
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ object functions extends LegacyFunctions {
* @group agg_funcs
* @since 2.0.0
*/
def grouping_id(colName:String, colNames: String*): Column = {
grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*)
def grouping_id(colName: String, colNames: String*): Column = {
grouping_id((Seq(colName) ++ colNames).map(n => Column(n)): _*)
}

/**
Expand Down

0 comments on commit 736e8d2

Please sign in to comment.