From b5ff31b0dde66ed24634dc8773dfafb11b95ee50 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 1 Jun 2015 13:56:05 +0800 Subject: [PATCH] address comments --- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 +- .../catalyst/analysis/HiveTypeCoercion.scala | 91 ++++++++++--------- .../catalyst/analysis/TypeCheckResult.scala | 26 ++++-- .../sql/catalyst/expressions/Expression.scala | 22 +++-- .../sql/catalyst/expressions/arithmetic.scala | 21 ++++- .../sql/catalyst/expressions/predicates.scala | 41 ++++++--- .../spark/sql/catalyst/util/TypeUtils.scala | 12 +-- .../analysis/DecimalPrecisionSuite.scala | 6 +- .../analysis/HiveTypeCoercionSuite.scala | 15 ++- .../ExpressionTypeCheckingSuite.scala | 9 +- .../apache/spark/sql/json/InferSchema.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- 12 files changed, 153 insertions(+), 104 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5b689f22bedbb..c0695ae369421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,10 +62,12 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") - case e: Expression if e.checkInputDataTypes.hasError => - e.failAnalysis( - s"cannot resolve '${e.prettyString}' due to data type mismatch: " + - e.checkInputDataTypes.errorMessage) + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } case c: Cast if !c.resolved => failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6ce582919e9c9..b064600e94fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -41,7 +41,7 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -57,6 +57,17 @@ object HiveTypeCoercion { case _ => None } + + /** + * Find the tightest common type of a set of types by continuously applying + * `findTightestCommonTypeOfTwo` on these types. 
+ */ + private def findTightestCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonTypeOfTwo(d, c) + }) + } } /** @@ -180,7 +191,7 @@ trait HiveTypeCoercion { case (l, r) if l.dataType != r.dataType => logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonType(l.dataType, r.dataType).map { widestType => + findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() val newRight = @@ -217,7 +228,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType => val newLeft = if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) val newRight = @@ -323,7 +334,6 @@ trait HiveTypeCoercion { * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * sum(e1) p1 + 10 s1 * avg(e1) p1 + 4 s1 + 4 - * compare max(p1, p2) max(s1, s2) * * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision. * @@ -442,10 +452,18 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) + // When we compare 2 decimal types with different precisions, cast them to the smallest + // common precision. case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = DecimalType(max(p1, p2), max(s1, s2)) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) + case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2) + if e2.dataType == DecimalType.Unlimited => + b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2)) + case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _)) + if e1.dataType == DecimalType.Unlimited => + b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles @@ -560,7 +578,7 @@ trait HiveTypeCoercion { case a @ CreateArray(children) if !a.resolved => val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonType(a, b).getOrElse(StringType)) + (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) CreateArray( children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) @@ -590,12 +608,8 @@ trait HiveTypeCoercion { // compatible with every child column. 
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
- val rt = types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
- case None => None
- case Some(d) => findTightestCommonType(d, c)
- })
- rt match {
- case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt)))
+ findTightestCommonType(types) match {
+ case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
}
@@ -608,7 +622,7 @@ trait HiveTypeCoercion {
 */
object Division extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- // Skip Divisions who has not been resolved yet,
+ // Skip nodes that have not been resolved yet,
// as this is an extra rule which should be applied at last.
case e if !e.resolved => e
@@ -624,47 +638,36 @@ trait HiveTypeCoercion {
 * Coerces the type of different branches of a CASE WHEN statement to a common type.
 */
object CaseWhenCoercion extends Rule[LogicalPlan] {
- import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case cw: CaseWhenLike if cw.childrenResolved && cw.checkInputDataTypes().hasError =>
- logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
- cw.valueTypes.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
- case None => None
- case Some(d) => findTightestCommonType(d, c)
- }).map { commonType =>
- val transformedBranches = cw.branches.sliding(2, 2).map {
+ case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
+ logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
+ val maybeCommonType = findTightestCommonType(c.valueTypes)
+ maybeCommonType.map { commonType =>
+ val castedBranches = c.branches.grouped(2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
- case s => s
+ case other => other
}.reduce(_ ++ _)
- cw match {
- case _: CaseWhen =>
- CaseWhen(transformedBranches)
- case CaseKeyWhen(key, _) =>
- CaseKeyWhen(key, transformedBranches)
+ c match {
+ case _: CaseWhen => CaseWhen(castedBranches)
+ case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches)
}
- }.getOrElse(cw)
-
- case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
- val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
- findTightestCommonType(v1, v2).getOrElse(sys.error(
- s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
- }
- val transformedBranches = ckw.branches.sliding(2, 2).map {
- case Seq(when, then) if when.dataType != commonType =>
- Seq(Cast(when, commonType), then)
- case s => s
- }.reduce(_ ++ _)
- val transformedKey = if (ckw.key.dataType != commonType) {
- Cast(ckw.key, commonType)
- } else {
- ckw.key
- }
- CaseKeyWhen(transformedKey, transformedBranches)
+ }.getOrElse(c)
+
+ case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
+ val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType))
+ maybeCommonType.map { commonType =>
+ val castedBranches = c.branches.grouped(2).map {
+ case Seq(when, then) if when.dataType != commonType =>
+ Seq(Cast(when, commonType), then)
+ case other => other
+ }.reduce(_ ++ _)
+ CaseKeyWhen(Cast(c.key, commonType), castedBranches)
+ }.getOrElse(c)
}
}
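The crux of both rewrites above is the new multi-type helper: it threads `findTightestCommonTypeOfTwo` through a fold that starts from `NullType` (the identity for widening here) and collapses to `None` at the first incompatible pair. Below is a minimal, self-contained Scala sketch of that fold; the stub resolver and its tiny promotion table are illustrative stand-ins, not Spark's real coercion rules.

```scala
// Flat stand-ins for Catalyst's DataType tree, just enough to run the fold.
sealed trait DataType
case object NullType extends DataType
case object IntegerType extends DataType
case object LongType extends DataType
case object DoubleType extends DataType
case object StringType extends DataType

// Stub resembling findTightestCommonTypeOfTwo: equal types unify, NullType
// yields the other side, and a toy promotion table covers a few numeric pairs.
def findTightestCommonTypeOfTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
  case (a, b) if a == b => Some(a)
  case (NullType, t) => Some(t)
  case (t, NullType) => Some(t)
  case (IntegerType, LongType) | (LongType, IntegerType) => Some(LongType)
  case (IntegerType, DoubleType) | (DoubleType, IntegerType) => Some(DoubleType)
  case _ => None
}

// The fold added by this patch, in the same shape as the diff above.
def findTightestCommonType(types: Seq[DataType]): Option[DataType] =
  types.foldLeft[Option[DataType]](Some(NullType)) { (r, c) =>
    r match {
      case None => None // a previous pair had no common type; stay failed
      case Some(d) => findTightestCommonTypeOfTwo(d, c)
    }
  }

// findTightestCommonType(Seq(NullType, IntegerType, LongType)) == Some(LongType)
// findTightestCommonType(Seq(IntegerType, StringType))         == None
```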
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
index 653015154fc16..79c3528a522d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala
@@ -19,15 +19,27 @@ package org.apache.spark.sql.catalyst.analysis

/**
 * Represents the result of `Expression.checkInputDataTypes`.
- * We will throw `AnalysisException` in `CheckAnalysis` if error message is not null.
- * Use [[TypeCheckResult.success]] and [[TypeCheckResult.fail]] to instantiate this.
- *
+ * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true.
 */
-class TypeCheckResult private (val errorMessage: String) extends AnyVal {
- def hasError: Boolean = errorMessage != null
+trait TypeCheckResult {
+ def isFailure: Boolean = !isSuccess
+ def isSuccess: Boolean
}

object TypeCheckResult {
- val success: TypeCheckResult = new TypeCheckResult(null)
- def fail(msg: String): TypeCheckResult = new TypeCheckResult(msg)
+
+ /**
+ * Represents the successful result of `Expression.checkInputDataTypes`.
+ */
+ object TypeCheckSuccess extends TypeCheckResult {
+ def isSuccess: Boolean = true
+ }
+
+ /**
+ * Represents the failing result of `Expression.checkInputDataTypes`,
+ * with an error message describing the reason for the failure.
+ */
+ case class TypeCheckFailure(message: String) extends TypeCheckResult {
+ def isSuccess: Boolean = false
+ }
}
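To make the intended call pattern concrete, here is a minimal, self-contained sketch of consuming the new result type. The flattened `sealed trait` mirror and the `assertValid` helper are illustrative stand-ins: the patch nests both cases inside the `TypeCheckResult` companion object, and `CheckAnalysis` reports failures through `failAnalysis` rather than throwing directly.

```scala
// Flattened stand-in for the TypeCheckResult ADT introduced by this patch.
sealed trait TypeCheckResult {
  def isSuccess: Boolean
  def isFailure: Boolean = !isSuccess
}
case object TypeCheckSuccess extends TypeCheckResult {
  def isSuccess: Boolean = true
}
case class TypeCheckFailure(message: String) extends TypeCheckResult {
  def isSuccess: Boolean = false
}

// Hypothetical caller in the style of CheckAnalysis: surface a failure as an
// analysis error, let a success fall through silently.
def assertValid(prettyString: String, result: TypeCheckResult): Unit = result match {
  case TypeCheckFailure(message) =>
    throw new IllegalArgumentException(
      s"cannot resolve '$prettyString' due to data type mismatch: $message")
  case TypeCheckSuccess => // input types are fine
}

// assertValid("(a + b)", TypeCheckFailure("differing types in Add ..."))
// throws with the same message shape CheckAnalysis produces.
```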
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8c8a3fde9cb8e..4ed0697b6f824 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -45,11 +45,12 @@ abstract class Expression extends TreeNode[Expression] {

/**
 * Returns `true` if this expression and all its children have been resolved to a specific schema
- * and `false` if it still contains any unresolved placeholders. Implementations of expressions
- * should override this if the resolution of this type of expression involves more than just
- * the resolution of its children.
+ * and all input data type checks have passed, and `false` if it still contains any unresolved
+ * placeholders or a data type mismatch.
+ * Implementations of expressions should override this if the resolution of this type of
+ * expression involves more than just the resolution of its children and type checking.
 */
- lazy val resolved: Boolean = childrenResolved && !checkInputDataTypes().hasError
+ lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

/**
 * Returns the [[DataType]] of the result of evaluating this expression. It is
@@ -88,18 +89,19 @@ abstract class Expression extends TreeNode[Expression] {
}

/**
- * Check the input data types, returns `TypeCheckResult.success` if it's valid,
- * or return a `TypeCheckResult` with an error message if invalid.
+ * Checks the input data types, returns `TypeCheckResult.TypeCheckSuccess` if they are valid,
+ * or returns a `TypeCheckResult` with an error message if invalid.
+ * Note: it's not valid to call this method until `childrenResolved == true`.
 * TODO: we should remove the default implementation and implement it for all
 * expressions with proper error message.
 */
- def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
+ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

- def symbol: String = sys.error(s"BinaryExpressions must either override toString or symbol")
+ def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")

override def foldable: Boolean = left.foldable && right.foldable

@@ -137,9 +139,9 @@ trait ExpectsInputTypes {

def expectedChildTypes: Seq[DataType]

- override def checkInputDataTypes: TypeCheckResult = {
+ override def checkInputDataTypes(): TypeCheckResult = {
// We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
// so type mismatch error won't be reported here, but for underlying `Cast`s.
- TypeCheckResult.success
+ TypeCheckResult.TypeCheckSuccess
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 0c2b7b4351dac..2ac53f8f6613f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -38,7 +38,7 @@ abstract class UnaryArithmetic extends UnaryExpression {
}

protected def evalInternal(evalE: Any): Any =
- sys.error(s"UnaryArithmetics must either override eval or evalInternal")
+ sys.error(s"UnaryArithmetics must override either eval or evalInternal")
}

case class UnaryMinus(child: Expression) extends UnaryArithmetic {
@@ -90,7 +90,7 @@ abstract class BinaryArithmetic extends BinaryExpression {

override def checkInputDataTypes(): TypeCheckResult = {
if (left.dataType != right.dataType) {
- TypeCheckResult.fail(
+ TypeCheckResult.TypeCheckFailure(
s"differing types in ${this.getClass.getSimpleName} " +
s"(${left.dataType} and ${right.dataType}).")
} else {
@@ -115,12 +115,15 @@ abstract class BinaryArithmetic extends BinaryExpression {
}

protected def evalInternal(evalE1: Any, evalE2: Any): Any =
- sys.error(s"BinaryArithmetics must either override eval or evalInternal")
+ sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"

+ override lazy val resolved =
+ childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

@@ -132,6 +135,9 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "-"

+ override lazy val resolved =
+ childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
+
protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

@@ -143,6 +149,9 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "*"

+ override lazy val
resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) @@ -155,6 +164,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def symbol: String = "/" override def nullable: Boolean = true + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) @@ -182,6 +194,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def symbol: String = "%" override def nullable: Boolean = true + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + protected def checkTypesInternal(t: DataType) = TypeUtils.checkForNumericExpr(t, "operator " + symbol) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 54a5ae9c3bb46..807021d50e8e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -174,7 +174,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { override def checkInputDataTypes(): TypeCheckResult = { if (left.dataType != right.dataType) { - TypeCheckResult.fail( + TypeCheckResult.TypeCheckFailure( s"differing types in ${this.getClass.getSimpleName} " + s"(${left.dataType} and ${right.dataType}).") } else { @@ -199,7 +199,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } protected def evalInternal(evalE1: Any, evalE2: Any): Any = - sys.error(s"BinaryComparisons must either override eval or evalInternal") + sys.error(s"BinaryComparisons must override either eval or evalInternal") } object BinaryComparison { @@ -210,7 +210,7 @@ object BinaryComparison { case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "=" - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess protected override def evalInternal(l: Any, r: Any) = { if (left.dataType != BinaryType) l == r @@ -220,9 +220,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=>" + override def nullable: Boolean = false - override protected def checkTypesInternal(t: DataType) = TypeCheckResult.success + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess override def eval(input: Row): Any = { val l = left.eval(input) @@ -289,13 +290,13 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def checkInputDataTypes(): TypeCheckResult = { if (predicate.dataType != BooleanType) { - TypeCheckResult.fail( + TypeCheckResult.TypeCheckFailure( s"type of predicate expression in If should be boolean, not ${predicate.dataType}") } else if (trueValue.dataType != falseValue.dataType) { - TypeCheckResult.fail( + TypeCheckResult.TypeCheckFailure( s"differing types in If 
(${trueValue.dataType} and ${falseValue.dataType}).")
} else {
- TypeCheckResult.success
+ TypeCheckResult.TypeCheckSuccess
}
}

@@ -326,16 +327,16 @@ trait CaseWhenLike extends Expression {
branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)

- // both then and else val should be considered.
+ // both then and else expressions should be considered.
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
def valueTypesEqual: Boolean = valueTypes.distinct.size == 1

override def checkInputDataTypes(): TypeCheckResult = {
- if (valueTypes.distinct.size > 1) {
- TypeCheckResult.fail(
- "THEN and ELSE expressions should all be same type or coercible to a common type")
- } else {
+ if (valueTypesEqual) {
checkTypesInternal()
+ } else {
+ TypeCheckResult.TypeCheckFailure(
+ "THEN and ELSE expressions should all be the same type or coercible to a common type")
}
}

@@ -365,9 +366,12 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {

override protected def checkTypesInternal(): TypeCheckResult = {
if (whenList.forall(_.dataType == BooleanType)) {
- TypeCheckResult.success
+ TypeCheckResult.TypeCheckSuccess
} else {
- TypeCheckResult.fail(s"WHEN expressions in CaseWhen should all be boolean type")
+ val index = whenList.indexWhere(_.dataType != BooleanType)
+ TypeCheckResult.TypeCheckFailure(
+ s"WHEN expressions in CaseWhen should all be boolean type, " +
+ s"but the ${index + 1}th WHEN expression's type is ${whenList(index).dataType}")
}
}

@@ -412,7 +416,14 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW

override def children: Seq[Expression] = key +: branches

- override protected def checkTypesInternal(): TypeCheckResult = TypeCheckResult.success
+ override protected def checkTypesInternal(): TypeCheckResult = {
+ if ((key +: whenList).map(_.dataType).distinct.size > 1) {
+ TypeCheckResult.TypeCheckFailure(
+ "key and WHEN expressions should all be the same type or coercible to a common type")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }

/** Written in imperative fashion for performance considerations. 
*/ override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 26df4fbfcf316..0bb12d2039ffc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -26,25 +26,25 @@ import org.apache.spark.sql.types._ object TypeUtils { def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[NumericType] || t == NullType) { - TypeCheckResult.success + TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.fail(s"$caller accepts numeric types, not $t") + TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t") } } def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[IntegralType] || t == NullType) { - TypeCheckResult.success + TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.fail(s"$caller accepts integral types, not $t") + TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") } } def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { if (t.isInstanceOf[AtomicType] || t == NullType) { - TypeCheckResult.success + TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.fail(s"$caller accepts non-complex types, not $t") + TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 1b8d18ded2257..7bac97b7894f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -92,8 +92,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(LessThan(i, d1), DecimalType.Unlimited) - checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index a0798428db094..0df446636ea89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -28,11 +28,11 @@ class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = HiveTypeCoercion.findTightestCommonType(t1, t2) + var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure 
the widening is symmetric. - found = HiveTypeCoercion.findTightestCommonType(t2, t1) + found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } @@ -140,13 +140,10 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - // Will remove exception expectation in PR#6405 - intercept[RuntimeException] { - ruleTest(cwc, - CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), - CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) - ) - } + ruleTest(cwc, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) } test("type coercion simplification for equal to") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index 0aca2ea2111ab..dcb3635c5ccae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -103,8 +103,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(GreaterThan('intField, 'stringField)) assertSuccess(GreaterThanOrEqual('intField, 'stringField)) - assertErrorForDifferingTypes(EqualTo('intField, 'booleanField)) - assertErrorForDifferingTypes(EqualNullSafe('intField, 'booleanField)) + // We will transform EqualTo with numeric and boolean types to CaseKeyWhen + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) + + assertError(EqualTo('intField, 'complexField), "differing types") + assertError(EqualNullSafe('intField, 'complexField), "differing types") + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 06aa19ef09bd2..565d10247f10e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -147,7 +147,7 @@ private[sql] object InferSchema { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { case (other: DataType, NullType) => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 95eb1174b1dd6..7e1e21f5fbb99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -155,7 +155,7 @@ private[sql] object JsonRDD extends Logging { * Returns the most general data type for two given data types. 
*/ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2) match { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { case Some(commonType) => commonType case None => // t1 or t2 is a StructType, ArrayType, or an unexpected type.
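For context on what the renamed helper buys JSON schema inference: `compatibleType` first asks for the tightest common type of two observed types, and only falls back to a most general type when none exists. Below is a compressed, self-contained sketch of that two-step shape; the flat type tree and stub resolver are illustrative, and the real method also merges `StructType` and `ArrayType` cases before defaulting to `StringType`.

```scala
// Flat stand-ins for Catalyst's DataType tree; Spark's real hierarchy is richer.
sealed trait DataType
case object NullType extends DataType
case object IntegerType extends DataType
case object DoubleType extends DataType
case object StringType extends DataType

// Toy version of HiveTypeCoercion.findTightestCommonTypeOfTwo.
def findTightestCommonTypeOfTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
  case (a, b) if a == b => Some(a)
  case (NullType, t) => Some(t)
  case (t, NullType) => Some(t)
  case (IntegerType, DoubleType) | (DoubleType, IntegerType) => Some(DoubleType)
  case _ => None
}

// Shape of JsonRDD.compatibleType: prefer the tightest common type, and widen
// to the most general type (StringType here) when no tight common type exists.
def compatibleType(t1: DataType, t2: DataType): DataType =
  findTightestCommonTypeOfTwo(t1, t2).getOrElse(StringType)

// Inferring over one JSON column's observed value types:
// Seq(NullType, IntegerType, DoubleType).reduce(compatibleType) == DoubleType
// Seq(IntegerType, StringType).reduce(compatibleType)           == StringType
```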