From 654d46abd2e5a988775edd0c50c395be63a4163f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 May 2015 14:13:55 +0800 Subject: [PATCH] improve tests --- .../sql/catalyst/expressions/arithmetic.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../ExpressionTypeCheckingSuite.scala | 126 +++++++++--------- 3 files changed, 67 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7282a877f0531..72dc8cc866797 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -91,7 +91,7 @@ abstract class BinaryArithmetic extends BinaryExpression { override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { TypeCheckResult.fail( - s"differing types in BinaryArithmetic, ${left.dataType} != ${right.dataType}") + s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}") } else { checkTypesInternal(dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5f29036083e34..874283c0f5a34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -175,7 +175,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { TypeCheckResult.fail( - s"differing types in BinaryComparison, ${left.dataType} != ${right.dataType}") + s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}") } else { TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index c241d05063efd..c9481a5f96254 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -28,89 +28,93 @@ import org.scalatest.FunSuite class ExpressionTypeCheckingSuite extends FunSuite { - val testRelation = LocalRelation('a.int, 'b.string, 'c.boolean, 'd.array(StringType)) + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) - def checkError(expr: Expression, errorMessage: String): Unit = { + def assertError(expr: Expression, errorMessage: String): Unit = { val e = intercept[AnalysisException] { - checkAnalysis(expr) + assertSuccess(expr) } assert(e.getMessage.contains( s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) assert(e.getMessage.contains(errorMessage)) } - def checkAnalysis(expr: Expression): Unit = { - val analyzed = testRelation.select(expr.as("_c")).analyze + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze SimpleAnalyzer.checkAnalysis(analyzed) } test("check types for unary arithmetic") { - checkError(UnaryMinus('b), "operator - accepts numeric type") - checkAnalysis(Sqrt('b)) // We will cast String to Double for sqrt - checkError(Sqrt('c), "function sqrt accepts numeric type") - checkError(Abs('b), "function abs accepts numeric type") - checkError(BitwiseNot('b), "operator ~ accepts integral type") + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt + assertError(Sqrt('booleanField), "function sqrt accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") } test("check types for binary arithmetic") { // We will cast String to Double for binary arithmetic - checkAnalysis(Add('a, 'b)) - checkAnalysis(Subtract('a, 'b)) - checkAnalysis(Multiply('a, 'b)) - checkAnalysis(Divide('a, 'b)) - checkAnalysis(Remainder('a, 'b)) - // checkAnalysis(BitwiseAnd('a, 'b)) - - val msg = "differing types in BinaryArithmetic, IntegerType != BooleanType" - checkError(Add('a, 'c), msg) - checkError(Subtract('a, 'c), msg) - checkError(Multiply('a, 'c), msg) - checkError(Divide('a, 'c), msg) - checkError(Remainder('a, 'c), msg) - checkError(BitwiseAnd('a, 'c), msg) - checkError(BitwiseOr('a, 'c), msg) - checkError(BitwiseXor('a, 'c), msg) - checkError(MaxOf('a, 'c), msg) - checkError(MinOf('a, 'c), msg) - - checkError(Add('c, 'c), "operator + accepts numeric type") - checkError(Subtract('c, 'c), "operator - accepts numeric type") - checkError(Multiply('c, 'c), "operator * accepts numeric type") - checkError(Divide('c, 'c), "operator / accepts numeric type") - checkError(Remainder('c, 'c), "operator % accepts numeric type") - - checkError(BitwiseAnd('c, 'c), "operator & accepts integral type") - checkError(BitwiseOr('c, 'c), "operator | accepts integral type") - checkError(BitwiseXor('c, 'c), "operator ^ accepts integral type") - - checkError(MaxOf('d, 'd), "function maxOf accepts non-complex type") - checkError(MinOf('d, 'd), "function minOf accepts non-complex type") + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType" + assertError(Add('intField, 'booleanField), msg("Add")) + assertError(Subtract('intField, 'booleanField), msg("Subtract")) + assertError(Multiply('intField, 'booleanField), msg("Multiply")) + assertError(Divide('intField, 'booleanField), msg("Divide")) + assertError(Remainder('intField, 'booleanField), msg("Remainder")) + assertError(BitwiseAnd('intField, 'booleanField), msg("BitwiseAnd")) + assertError(BitwiseOr('intField, 'booleanField), msg("BitwiseOr")) + assertError(BitwiseXor('intField, 'booleanField), msg("BitwiseXor")) + assertError(MaxOf('intField, 'booleanField), msg("MaxOf")) + assertError(MinOf('intField, 'booleanField), msg("MinOf")) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") } test("check types for predicates") { // EqualTo don't have type constraint - checkAnalysis(EqualTo('a, 'c)) - checkAnalysis(EqualNullSafe('a, 'c)) + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) // We will cast String to Double for binary comparison - checkAnalysis(LessThan('a, 'b)) - checkAnalysis(LessThanOrEqual('a, 'b)) - checkAnalysis(GreaterThan('a, 'b)) - checkAnalysis(GreaterThanOrEqual('a, 'b)) - - val msg = "differing types in BinaryComparison, IntegerType != BooleanType" - checkError(LessThan('a, 'c), msg) - checkError(LessThanOrEqual('a, 'c), msg) - checkError(GreaterThan('a, 'c), msg) - checkError(GreaterThanOrEqual('a, 'c), msg) - - checkError(LessThan('d, 'd), "operator < accepts non-complex type") - checkError(LessThanOrEqual('d, 'd), "operator <= accepts non-complex type") - checkError(GreaterThan('d, 'd), "operator > accepts non-complex type") - checkError(GreaterThanOrEqual('d, 'd), "operator >= accepts non-complex type") - - checkError(If('a, 'a, 'a), "type of predicate expression in If should be boolean") - checkError(If('c, 'a, 'b), "differing types in If, IntegerType != StringType") + assertSuccess(LessThan('intField, 'stringField)) + assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType" + assertError(LessThan('intField, 'booleanField), msg("LessThan")) + assertError(LessThanOrEqual('intField, 'booleanField), msg("LessThanOrEqual")) + assertError(GreaterThan('intField, 'booleanField), msg("GreaterThan")) + assertError(GreaterThanOrEqual('intField, 'booleanField), msg("GreaterThanOrEqual")) + + assertError(LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError(LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError(GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError(GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") + assertError(If('booleanField, 'intField, 'stringField), "differing types in If, IntegerType != StringType") // Will write tests for CaseWhen later, // as the error reporting of it is not handle by the new interface for now