Skip to content

Commit

Permalink
add test case
Browse files Browse the repository at this point in the history
  • Loading branch information
guowei2 committed Dec 23, 2014
1 parent b4985a2 commit 3214e0a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,19 +361,21 @@ trait HiveTypeCoercion {
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)

// Cast is no need for logical operator
case LessThan(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
// Cast is not needed for binary comparison
case LessThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThan(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))

// Promote integers inside a binary expression with fixed-precision decimals to decimals,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
assert(analyzer(plan).schema.fields(0).dataType === expectedType)
}

private def checkComparison(expression: Expression, expectedType: DataType): Unit = {
val plan = Project(Seq(Alias(expression, "c")()), relation)
val comparison = analyzer(plan).expressions(0).children(0).asInstanceOf[BinaryComparison]
assert(comparison.left.dataType === expectedType)
assert(comparison.right.dataType === expectedType)
}

test("basic operations") {
checkType(Add(d1, d2), DecimalType(6, 2))
checkType(Subtract(d1, d2), DecimalType(6, 2))
Expand All @@ -65,6 +72,14 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
}

test("Comparison operations") {
checkComparison(LessThan(i, d1), DecimalType.Unlimited)
checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited)
checkComparison(GreaterThan(d2, u), DecimalType.Unlimited)
checkComparison(GreaterThanOrEqual(d1, f), DoubleType)
checkComparison(GreaterThan(d2, d2), DecimalType(5, 2))
}

test("bringing in primitive types") {
checkType(Add(d1, i), DecimalType(12, 1))
checkType(Add(d1, f), DoubleType)
Expand Down

0 comments on commit 3214e0a

Please sign in to comment.