Skip to content

Commit

Permalink
add error message and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jun 1, 2015
1 parent c71d02c commit 69ca3fe
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -619,18 +619,13 @@ trait HiveTypeCoercion {
*/
object Division extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
// Skip nodes who's children have not been resolved yet or input types do not match.
case e if !e.childrenResolved || e.checkInputDataTypes().hasError => e

// Decimal and Double remain the same
case d: Divide if d.resolved && d.dataType == DoubleType => d
case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
Divide(l, Cast(r, DecimalType.Unlimited))
case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
Divide(Cast(l, DecimalType.Unlimited), r)

case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,10 @@ abstract class Expression extends TreeNode[Expression] {
}

/**
* todo
* Check the input data types, returns `TypeCheckResult.success` if it's valid,
* or return a `TypeCheckResult` with an error message if invalid.
*/
def checkInputDataTypes: TypeCheckResult = TypeCheckResult.success
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@ abstract class UnaryArithmetic extends UnaryExpression {
override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType

override def checkInputDataTypes: TypeCheckResult = {
if (TypeUtils.validForNumericExpr(child.dataType)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}

override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
Expand All @@ -52,6 +44,9 @@ abstract class UnaryArithmetic extends UnaryExpression {
case class UnaryMinus(child: Expression) extends UnaryArithmetic {
override def toString: String = s"-$child"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "operator -")

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
Expand All @@ -62,6 +57,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
override def nullable: Boolean = true
override def toString: String = s"SQRT($child)"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sqrt")

private lazy val numeric = TypeUtils.getNumeric(child.dataType)

protected override def evalInternal(evalE: Any) = {
Expand All @@ -77,6 +75,9 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
case class Abs(child: Expression) extends UnaryArithmetic {
override def toString: String = s"Abs($child)"

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function abs")

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def evalInternal(evalE: Any) = numeric.abs(evalE)
Expand All @@ -87,10 +88,10 @@ abstract class BinaryArithmetic extends BinaryExpression {

override def dataType: DataType = left.dataType

override def checkInputDataTypes: TypeCheckResult = {
override def checkInputDataTypes(): TypeCheckResult = {
if (left.dataType != right.dataType) {
TypeCheckResult.fail(
s"differing types in BinaryArithmetics -- ${left.dataType}, ${right.dataType}")
s"differing types in BinaryArithmetic, ${left.dataType} != ${right.dataType}")
} else {
checkTypesInternal(dataType)
}
Expand Down Expand Up @@ -123,13 +124,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
// for `Add` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

Expand All @@ -143,13 +139,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
// for `Subtract` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

Expand All @@ -163,13 +154,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
// for `Multiply` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

Expand All @@ -184,13 +170,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
// for `Divide` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
Expand Down Expand Up @@ -220,13 +201,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
// for `Remainder` in `HiveTypeCoercion`
override lazy val resolved = childrenResolved && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForNumericExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

private lazy val integral = dataType match {
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
Expand Down Expand Up @@ -254,13 +230,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "&"

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForBitwiseExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val and: (Any, Any) => Any = dataType match {
case ByteType =>
Expand All @@ -282,13 +253,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "|"

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForBitwiseExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val or: (Any, Any) => Any = dataType match {
case ByteType =>
Expand All @@ -310,13 +276,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "^"

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForBitwiseExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)

private lazy val xor: (Any, Any) => Any = dataType match {
case ByteType =>
Expand All @@ -338,13 +299,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseNot(child: Expression) extends UnaryArithmetic {
override def toString: String = s"~$child"

override def checkInputDataTypes: TypeCheckResult = {
if (TypeUtils.validForBitwiseExpr(dataType)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")

private lazy val not: (Any) => Any = dataType match {
case ByteType =>
Expand All @@ -363,13 +319,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForOrderingExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function maxOf")

private lazy val ordering = TypeUtils.getOrdering(dataType)

Expand All @@ -395,13 +346,8 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
override def nullable: Boolean = left.nullable && right.nullable

protected def checkTypesInternal(t: DataType) = {
if (TypeUtils.validForOrderingExpr(t)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail("todo")
}
}
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(t, "function minOf")

private lazy val ordering = TypeUtils.getOrdering(dataType)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)

override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"

Expand Down
Loading

0 comments on commit 69ca3fe

Please sign in to comment.