Skip to content

Commit

Permalink
apply type check interface to CaseWhen
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jun 1, 2015
1 parent cffb67c commit 8883025
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,8 @@ trait HiveTypeCoercion {
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val dt: Option[DataType] = Some(NullType)
val types = es.map(_.dataType)
val rt = types.foldLeft(dt)((r, c) => r match {
val rt = types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonType(d, c)
})
Expand Down Expand Up @@ -635,28 +634,30 @@ trait HiveTypeCoercion {
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CaseWhenCoercion extends Rule[LogicalPlan] {

import HiveTypeCoercion._

def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
case cw: CaseWhenLike if cw.childrenResolved && cw.checkInputDataTypes().hasError =>
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
val commonType = cw.valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2).getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = cw.branches.sliding(2, 2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
cw match {
case _: CaseWhen =>
CaseWhen(transformedBranches)
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}
cw.valueTypes.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) => findTightestCommonType(d, c)
}).map { commonType =>
val transformedBranches = cw.branches.sliding(2, 2).map {
case Seq(when, value) if value.dataType != commonType =>
Seq(when, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
cw match {
case _: CaseWhen =>
CaseWhen(transformedBranches)
case CaseKeyWhen(key, _) =>
CaseKeyWhen(key, transformedBranches)
}
}.getOrElse(cw)

case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,18 @@ trait CaseWhenLike extends Expression {
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
def valueTypesEqual: Boolean = valueTypes.distinct.size == 1

override def dataType: DataType = valueTypes.head
override def checkInputDataTypes(): TypeCheckResult = {
if (valueTypes.distinct.size > 1) {
TypeCheckResult.fail(
"THEN and ELSE expressions should all be same type or coercible to a common type")
} else {
checkTypesInternal()
}
}

protected def checkTypesInternal(): TypeCheckResult

override def dataType: DataType = thenList.head.dataType

override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
Expand All @@ -347,14 +358,11 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {

override def children: Seq[Expression] = branches

override def checkInputDataTypes(): TypeCheckResult = {
if (!whenList.forall(_.dataType == BooleanType)) {
TypeCheckResult.fail(s"WHEN expressions should all be boolean type")
} else if (!valueTypesEqual) {
TypeCheckResult.fail(
"THEN and ELSE expressions should all be same type or coercible to a common type")
} else {
override protected def checkTypesInternal(): TypeCheckResult = {
if (whenList.forall(_.dataType == BooleanType)) {
TypeCheckResult.success
} else {
TypeCheckResult.fail(s"WHEN expressions in CaseWhen should all be boolean type")
}
}

Expand Down Expand Up @@ -399,14 +407,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW

override def children: Seq[Expression] = key +: branches

override def checkInputDataTypes(): TypeCheckResult = {
if (!valueTypesEqual) {
TypeCheckResult.fail(
"THEN and ELSE expressions should all be same type or coercible to a common type")
} else {
TypeCheckResult.success
}
}
override protected def checkTypesInternal(): TypeCheckResult = TypeCheckResult.success

/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,15 @@ class ExpressionTypeCheckingSuite extends FunSuite {
"type of predicate expression in If should be boolean")
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))

// Will write tests for CaseWhen later,
// as the error reporting of it is not handle by the new interface for now
assertError(
CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
"WHEN expressions in CaseWhen should all be boolean type")

}
}

0 comments on commit 8883025

Please sign in to comment.