Skip to content

Commit

Permalink
add equal type constraint to EqualTo
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jun 1, 2015
1 parent 3affbd8 commit 6eaadff
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,12 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
s"differing types in ${this.getClass.getSimpleName} " +
s"(${left.dataType} and ${right.dataType}).")
} else {
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
checkTypesInternal(dataType)
}
}

protected def checkTypesInternal(t: DataType): TypeCheckResult

override def eval(input: Row): Any = {
val evalE1 = left.eval(input)
if (evalE1 == null) {
Expand All @@ -203,8 +205,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
override def symbol: String = "="

// EqualTo don't need 2 equal orderable types
override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success

protected override def evalInternal(l: Any, r: Any) = {
if (left.dataType != BinaryType) l == r
Expand All @@ -216,8 +217,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
override def symbol: String = "<=>"
override def nullable: Boolean = false

// EqualNullSafe don't need 2 equal orderable types
override def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success

override def eval(input: Row): Any = {
val l = left.eval(input)
Expand All @@ -235,6 +235,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
override def symbol: String = "<"

override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)

private lazy val ordering = TypeUtils.getOrdering(left.dataType)

protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2)
Expand All @@ -243,6 +246,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
override def symbol: String = "<="

override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)

private lazy val ordering = TypeUtils.getOrdering(left.dataType)

protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2)
Expand All @@ -251,6 +257,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
override def symbol: String = ">"

override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)

private lazy val ordering = TypeUtils.getOrdering(left.dataType)

protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2)
Expand All @@ -259,6 +268,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
override def symbol: String = ">="

override protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)

private lazy val ordering = TypeUtils.getOrdering(left.dataType)

protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ class ExpressionTypeCheckingSuite extends FunSuite {
}

test("check types for predicates") {
// EqualTo don't have type constraint
assertSuccess(EqualTo('intField, 'booleanField))
assertSuccess(EqualNullSafe('intField, 'booleanField))

// We will cast String to Double for binary comparison
assertSuccess(EqualTo('intField, 'stringField))
assertSuccess(EqualNullSafe('intField, 'stringField))
assertSuccess(LessThan('intField, 'stringField))
assertSuccess(LessThanOrEqual('intField, 'stringField))
assertSuccess(GreaterThan('intField, 'stringField))
assertSuccess(GreaterThanOrEqual('intField, 'stringField))

assertErrorForDifferingTypes(EqualTo('intField, 'booleanField))
assertErrorForDifferingTypes(EqualNullSafe('intField, 'booleanField))
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
Expand Down

0 comments on commit 6eaadff

Please sign in to comment.