Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jun 1, 2015
1 parent b917275 commit b5ff31b
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ trait CheckAnalysis {
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")

case e: Expression if e.checkInputDataTypes.hasError =>
e.failAnalysis(
s"cannot resolve '${e.prettyString}' due to data type mismatch: " +
e.checkInputDataTypes.errorMessage)
case e: Expression if e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
e.failAnalysis(
s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
}

case c: Cast if !c.resolved =>
failAnalysis(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object HiveTypeCoercion {
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
*/
val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
Expand All @@ -57,6 +57,17 @@ object HiveTypeCoercion {

case _ => None
}

/**
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
*/
private def findTightestCommonType(types: Seq[DataType]) = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonTypeOfTwo(d, c)
})
}
}

/**
Expand Down Expand Up @@ -180,7 +191,7 @@ trait HiveTypeCoercion {

case (l, r) if l.dataType != r.dataType =>
logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
findTightestCommonType(l.dataType, r.dataType).map { widestType =>
findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType =>
val newLeft =
if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
val newRight =
Expand Down Expand Up @@ -217,7 +228,7 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e

case b: BinaryExpression if b.left.dataType != b.right.dataType =>
findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType =>
findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType =>
val newLeft =
if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
val newRight =
Expand Down Expand Up @@ -323,7 +334,6 @@ trait HiveTypeCoercion {
* e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2)
* sum(e1) p1 + 10 s1
* avg(e1) p1 + 4 s1 + 4
* compare max(p1, p2) max(s1, s2)
*
* Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
*
Expand Down Expand Up @@ -442,10 +452,18 @@ trait HiveTypeCoercion {
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
)

// When we compare 2 decimal types with different precisions, cast them to the smallest
// common precision.
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
val resultType = DecimalType(max(p1, p2), max(s1, s2))
b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2)
if e2.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2))
case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _))
if e1.dataType == DecimalType.Unlimited =>
b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited)))

// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
Expand Down Expand Up @@ -560,7 +578,7 @@ trait HiveTypeCoercion {

case a @ CreateArray(children) if !a.resolved =>
val commonType = a.childTypes.reduce(
(a, b) => findTightestCommonType(a, b).getOrElse(StringType))
(a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
CreateArray(
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))

Expand Down Expand Up @@ -590,12 +608,8 @@ trait HiveTypeCoercion {
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
val rt = types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonType(d, c)
})
rt match {
case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt)))
findTightestCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
}
Expand All @@ -608,7 +622,7 @@ trait HiveTypeCoercion {
*/
object Division extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip Divisions who has not been resolved yet,
// Skip nodes who has not been resolved yet,
// as this is an extra rule which should be applied at last.
case e if !e.resolved => e

Expand All @@ -624,47 +638,36 @@ trait HiveTypeCoercion {
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CaseWhenCoercion extends Rule[LogicalPlan] {

import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw: CaseWhenLike if cw.childrenResolved && cw.checkInputDataTypes().hasError =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
cw.valueTypes.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonType(d, c)
}).map { commonType =>
val transformedBranches = cw.branches.sliding(2, 2).map {
case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
val maybeCommonType = findTightestCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
case other => other
}.reduce(_ ++ _)
cw match {
case _: CaseWhen =>
CaseWhen(transformedBranches)
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
c match {
case _: CaseWhen => CaseWhen(castedBranches)
case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches)
}
}.getOrElse(cw)

case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = ckw.branches.sliding(2, 2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case s => s
}.reduce(_ ++ _)
val transformedKey = if (ckw.key.dataType != commonType) {
Cast(ckw.key, commonType)
} else {
ckw.key
}
CaseKeyWhen(transformedKey, transformedBranches)
}.getOrElse(c)

case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType))
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
case Seq(when, then) if when.dataType != commonType =>
Seq(Cast(when, commonType), then)
case other => other
}.reduce(_ ++ _)
CaseKeyWhen(Cast(c.key, commonType), castedBranches)
}.getOrElse(c)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,27 @@ package org.apache.spark.sql.catalyst.analysis

/**
* Represents the result of `Expression.checkInputDataTypes`.
* We will throw `AnalysisException` in `CheckAnalysis` if error message is not null.
* Use [[TypeCheckResult.success]] and [[TypeCheckResult.fail]] to instantiate this.
*
* We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true.
*/
class TypeCheckResult private (val errorMessage: String) extends AnyVal {
def hasError: Boolean = errorMessage != null
trait TypeCheckResult {
def isFailure: Boolean = !isSuccess
def isSuccess: Boolean
}

object TypeCheckResult {
val success: TypeCheckResult = new TypeCheckResult(null)
def fail(msg: String): TypeCheckResult = new TypeCheckResult(msg)

/**
* Represents the successful result of `Expression.checkInputDataTypes`.
*/
object TypeCheckSuccess extends TypeCheckResult {
def isSuccess: Boolean = true
}

/**
* Represents the failing result of `Expression.checkInputDataTypes`,
* with a error message to show the reason of failure.
*/
case class TypeCheckFailure(message: String) extends TypeCheckResult {
def isSuccess: Boolean = false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ abstract class Expression extends TreeNode[Expression] {

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and `false` if it still contains any unresolved placeholders. Implementations of expressions
* should override this if the resolution of this type of expression involves more than just
* the resolution of its children.
* and input data types checking passed, and `false` if it still contains any unresolved
* placeholders or has data types mismatch.
* Implementations of expressions should override this if the resolution of this type of
* expression involves more than just the resolution of its children and type checking.
*/
lazy val resolved: Boolean = childrenResolved && !checkInputDataTypes().hasError
lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

/**
* Returns the [[DataType]] of the result of evaluating this expression. It is
Expand Down Expand Up @@ -88,18 +89,19 @@ abstract class Expression extends TreeNode[Expression] {
}

/**
* Check the input data types, returns `TypeCheckResult.success` if it's valid,
* or return a `TypeCheckResult` with an error message if invalid.
* Checks the input data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`
* TODO: we should remove the default implementation and implement it for all
* expressions with proper error message.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.success
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}

abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
self: Product =>

def symbol: String = sys.error(s"BinaryExpressions must either override toString or symbol")
def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")

override def foldable: Boolean = left.foldable && right.foldable

Expand Down Expand Up @@ -137,9 +139,9 @@ trait ExpectsInputTypes {

def expectedChildTypes: Seq[DataType]

override def checkInputDataTypes: TypeCheckResult = {
override def checkInputDataTypes(): TypeCheckResult = {
// We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
// so type mismatch error won't be reported here, but for underling `Cast`s.
TypeCheckResult.success
TypeCheckResult.TypeCheckSuccess
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ abstract class UnaryArithmetic extends UnaryExpression {
}

protected def evalInternal(evalE: Any): Any =
sys.error(s"UnaryArithmetics must either override eval or evalInternal")
sys.error(s"UnaryArithmetics must override either eval or evalInternal")
}

case class UnaryMinus(child: Expression) extends UnaryArithmetic {
Expand Down Expand Up @@ -90,7 +90,7 @@ abstract class BinaryArithmetic extends BinaryExpression {

override def checkInputDataTypes(): TypeCheckResult = {
if (left.dataType != right.dataType) {
TypeCheckResult.fail(
TypeCheckResult.TypeCheckFailure(
s"differing types in ${this.getClass.getSimpleName} " +
s"(${left.dataType} and ${right.dataType}).")
} else {
Expand All @@ -115,12 +115,15 @@ abstract class BinaryArithmetic extends BinaryExpression {
}

protected def evalInternal(evalE1: Any, evalE2: Any): Any =
sys.error(s"BinaryArithmetics must either override eval or evalInternal")
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

Expand All @@ -132,6 +135,9 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "-"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

Expand All @@ -143,6 +149,9 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "*"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

Expand All @@ -155,6 +164,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
override def symbol: String = "/"
override def nullable: Boolean = true

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

Expand Down Expand Up @@ -182,6 +194,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def symbol: String = "%"
override def nullable: Boolean = true

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

protected def checkTypesInternal(t: DataType) =
TypeUtils.checkForNumericExpr(t, "operator " + symbol)

Expand Down
Loading

0 comments on commit b5ff31b

Please sign in to comment.