Skip to content

Commit

Permalink
[SPARK-23898][SQL] Simplify add & subtract code generation
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
Code generation for the `Add` and `Subtract` expressions was not done using the `BinaryArithmetic.doCodeGen` method because these expressions also support `CalendarInterval`. This leads to a bit of duplication.

This PR gets rid of that duplication by adding `calendarIntervalMethod` to `BinaryArithmetic` and doing the code generation for `CalendarInterval` in `BinaryArithmetic` instead.

## How was this patch tested?
Existing tests.

Author: Herman van Hovell <[email protected]>

Closes #21005 from hvanhovell/SPARK-23898.
  • Loading branch information
hvanhovell authored and gatorsmile committed Apr 10, 2018
1 parent f94f362 commit 6498884
Showing 1 changed file with 20 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
private lazy val numeric = TypeUtils.getNumeric(dataType)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
val originValue = ctx.freshName("origin")
// codegen would fail to compile if we just write (-($c))
Expand All @@ -52,7 +52,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression
${CodeGenerator.javaType(dt)} $originValue = (${CodeGenerator.javaType(dt)})($eval);
${ev.value} = (${CodeGenerator.javaType(dt)})(-($originValue));
"""})
case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
case _: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
}

protected override def nullSafeEval(input: Any): Any = {
Expand Down Expand Up @@ -104,7 +104,7 @@ case class Abs(child: Expression)
private lazy val numeric = TypeUtils.getNumeric(dataType)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
case _: DecimalType =>
defineCodeGen(ctx, ev, c => s"$c.abs()")
case dt: NumericType =>
defineCodeGen(ctx, ev, c => s"(${CodeGenerator.javaType(dt)})(java.lang.Math.abs($c))")
Expand All @@ -117,15 +117,21 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

override def dataType: DataType = left.dataType

override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess
override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
sys.error("BinaryArithmetics must override either decimalMethod or genCode")

/** Name of the function for this expression on a [[CalendarInterval]] type. */
def calendarIntervalMethod: String =
sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
case _: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
// byte and short are casted into int when add, minus, times or divide
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
Expand All @@ -152,6 +158,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

override def symbol: String = "+"

override def decimalMethod: String = "$plus"

override def calendarIntervalMethod: String = "add"

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
Expand All @@ -161,18 +171,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
numeric.plus(input1, input2)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
}

@ExpressionDescription(
Expand All @@ -188,6 +186,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti

override def symbol: String = "-"

override def decimalMethod: String = "$minus"

override def calendarIntervalMethod: String = "subtract"

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
Expand All @@ -197,18 +199,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
numeric.minus(input1, input2)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case dt: DecimalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)")
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
case CalendarIntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
}

@ExpressionDescription(
Expand Down Expand Up @@ -416,7 +406,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

override def symbol: String = "pmod"

protected def checkTypesInternal(t: DataType) =
protected def checkTypesInternal(t: DataType): TypeCheckResult =
TypeUtils.checkForNumericExpr(t, "pmod")

override def inputType: AbstractDataType = NumericType
Expand Down

0 comments on commit 6498884

Please sign in to comment.