generalize UnresolvedGetField to support all map, struct, and array
cloud-fan committed May 8, 2015
1 parent cd1d411 commit c9d85f5
Showing 13 changed files with 229 additions and 166 deletions.
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected lazy val primary: PackratParser[Expression] =
( literal
| expression ~ ("[" ~> expression <~ "]") ^^
{ case base ~ ordinal => GetItem(base, ordinal) }
{ case base ~ ordinal => UnresolvedGetField(base, ordinal) }
| (expression <~ ".") ~ ident ^^
{ case base ~ fieldName => UnresolvedGetField(base, fieldName) }
{ case base ~ fieldName => UnresolvedGetField(base, Literal(fieldName)) }
| cast
| "(" ~> expression <~ ")"
| function
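
With this change both access syntaxes funnel into the same unresolved node at parse time. A rough, hedged sketch of what the two productions build (the attribute names are made up for illustration):

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedGetField}
import org.apache.spark.sql.catalyst.expressions.Literal

// "arr[0]"  -> ordinal access; the ordinal stays an arbitrary expression
val byOrdinal = UnresolvedGetField(UnresolvedAttribute("arr"), Literal(0))
// "st.name" -> field access; the field name is wrapped in a string Literal
val byName = UnresolvedGetField(UnresolvedAttribute("st"), Literal("name"))
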
@@ -311,8 +311,8 @@ class Analyzer(
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldName) if child.resolved =>
GetField(child, fieldName, resolver)
case UnresolvedGetField(child, fieldExpr) if child.resolved =>
GetField(child, fieldExpr, resolver)
}
}

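For reference, a minimal sketch of what the rewritten rule produces once the child is resolved; the struct attribute and the case-insensitive resolver below are illustrative stand-ins, not part of the patch:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// stand-in for the analyzer's Resolver, which is just (String, String) => Boolean
val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)
val struct = AttributeReference("s", StructType(Seq(StructField("a", IntegerType))))()
// UnresolvedGetField(struct, Literal("a")) is rewritten by the rule above into:
val resolved = GetField(struct, Literal("a"), resolver)
// => SimpleStructGetField(struct, StructField("a", IntegerType, true), 0)
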
@@ -184,7 +184,7 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
}

case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
case class UnresolvedGetField(child: Expression, fieldExpr: Expression) extends UnaryExpression {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -193,5 +193,5 @@ case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
override def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

override def toString: String = s"$child.$fieldName"
override def toString: String = s"$child.getField($fieldExpr)"
}
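
Because `fieldExpr` is now an arbitrary expression rather than a fixed field name, the ordinal or key no longer has to be a constant. A hedged example of a map lookup keyed by another column (both column names are hypothetical):

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedGetField}

// m[k]: read map column "m" at whatever key column "k" evaluates to for each row
val dynamicLookup = UnresolvedGetField(UnresolvedAttribute("m"), UnresolvedAttribute("k"))
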
@@ -100,8 +100,8 @@ package object dsl {
def isNull: Predicate = IsNull(expr)
def isNotNull: Predicate = IsNotNull(expr)

def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
def getItem(ordinal: Expression): UnresolvedGetField = UnresolvedGetField(expr, ordinal)
def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, Literal(fieldName))

def cast(to: DataType): Expression = Cast(expr, to)

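A quick, hedged illustration of the DSL after this change; the attributes are hypothetical and only the `getItem`/`getField` calls come from the patch:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Literal

val arr = UnresolvedAttribute("arr")        // hypothetical columns
val person = UnresolvedAttribute("person")
val item = arr.getItem(Literal(0))          // UnresolvedGetField(arr, Literal(0))
val field = person.getField("name")         // UnresolvedGetField(person, Literal("name"))
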
@@ -0,0 +1,195 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection.Map

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._

object GetField {
/**
* Returns the resolved `GetField`. It will return one kind of concrete `GetField`,
* depending on the types of `child` and `fieldExpr`.
*/
def apply(
child: Expression,
fieldExpr: Expression,
resolver: Resolver): GetField = {

(child.dataType, fieldExpr) match {
case (StructType(fields), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
SimpleStructGetField(child, fields(ordinal), ordinal)
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
ArrayStructGetField(child, fields(ordinal), ordinal, containsNull)
case (_: ArrayType, _) if fieldExpr.dataType.isInstanceOf[IntegralType] =>
ArrayOrdinalGetField(child, fieldExpr)
case (_: MapType, _) =>
MapOrdinalGetField(child, fieldExpr)
case (otherType, _) =>
throw new AnalysisException(
s"GetField is not valid on child of type $otherType with fieldExpr of type ${fieldExpr.dataType}")
}
}

def unapply(g: GetField): Option[(Expression, Expression)] = {
g match {
case _: StructGetField => Some((g.child, null))
case o: OrdinalGetField => Some((o.child, o.ordinal))
case _ => None
}
}

/**
* Finds the ordinal of the requested StructField, reporting an error if no matching field
* is found or more than one matching field is found.
*/
private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
val checkField = (f: StructField) => resolver(f.name, fieldName)
val ordinal = fields.indexWhere(checkField)
if (ordinal == -1) {
throw new AnalysisException(
s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
} else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
throw new AnalysisException(
s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
} else {
ordinal
}
}
}

trait GetField extends UnaryExpression {
self: Product =>

type EvaluatedType = Any
}

abstract class StructGetField extends GetField {
self: Product =>

def field: StructField

override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"
}

abstract class OrdinalGetField extends GetField {
self: Product =>

def ordinal: Expression

/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
override def foldable: Boolean = child.foldable && ordinal.foldable
override def toString: String = s"$child[$ordinal]"
override def children: Seq[Expression] = child :: ordinal :: Nil

override def eval(input: Row): Any = {
val value = child.eval(input)
if (value == null) {
null
} else {
val o = ordinal.eval(input)
if (o == null) {
null
} else {
evalNotNull(value, o)
}
}
}

protected def evalNotNull(value: Any, ordinal: Any): Any
}

/**
* Returns the value of the field at `ordinal` in the Struct `child`.
*/
case class SimpleStructGetField(child: Expression, field: StructField, ordinal: Int)
extends StructGetField {

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}
}

/**
* Returns an array containing the value of the field at `ordinal` for each Struct in the Array `child`.
*/
case class ArrayStructGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
extends StructGetField {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
if (baseValue == null) null else {
baseValue.map { row =>
if (row == null) null else row(ordinal)
}
}
}
}

/**
* Returns the element at `ordinal` in the Array `child`.
*/
case class ArrayOrdinalGetField(child: Expression, ordinal: Expression)
extends OrdinalGetField {

override def dataType = child.dataType.asInstanceOf[ArrayType].elementType

override lazy val resolved = childrenResolved &&
child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]

protected def evalNotNull(value: Any, ordinal: Any) = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
val index = ordinal.asInstanceOf[Int]
if (index >= baseValue.size || index < 0) {
null
} else {
baseValue(index)
}
}
}

/**
* Returns the value for the key `ordinal` in the Map `child`.
*/
case class MapOrdinalGetField(child: Expression, ordinal: Expression)
extends OrdinalGetField {

override def dataType = child.dataType.asInstanceOf[MapType].valueType

override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]

protected def evalNotNull(value: Any, ordinal: Any) = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull
}
}
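
To make the dispatch concrete, a small, self-contained sketch of how `GetField.apply` picks an accessor for array and map children; the literal columns and the plain-equality resolver are illustrative only:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val resolver: (String, String) => Boolean = _ == _   // stand-in Resolver
val arrayCol = Literal.create(Seq(10, 20, 30), ArrayType(IntegerType, containsNull = false))
val mapCol = Literal.create(Map("a" -> 1), MapType(StringType, IntegerType))

GetField(arrayCol, Literal(1), resolver)               // => ArrayOrdinalGetField(arrayCol, Literal(1))
GetField(mapCol, Literal("a"), resolver)               // => MapOrdinalGetField(mapCol, Literal("a"))
// out-of-range ordinals evaluate to null rather than failing
GetField(arrayCol, Literal(99), resolver).eval(null)   // => null
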
@@ -17,139 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.Map

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.types._

/**
* Returns the item at `ordinal` in the Array `child` or the Key `ordinal` in Map `child`.
*/
case class GetItem(child: Expression, ordinal: Expression) extends Expression {
type EvaluatedType = Any

val children: Seq[Expression] = child :: ordinal :: Nil
/** `Null` is returned for invalid ordinals. */
override def nullable: Boolean = true
override def foldable: Boolean = child.foldable && ordinal.foldable

override def dataType: DataType = child.dataType match {
case ArrayType(dt, _) => dt
case MapType(_, vt, _) => vt
}
override lazy val resolved =
childrenResolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

override def toString: String = s"$child[$ordinal]"

override def eval(input: Row): Any = {
val value = child.eval(input)
if (value == null) {
null
} else {
val key = ordinal.eval(input)
if (key == null) {
null
} else {
if (child.dataType.isInstanceOf[ArrayType]) {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
val baseValue = value.asInstanceOf[Seq[_]]
val o = key.asInstanceOf[Int]
if (o >= baseValue.size || o < 0) {
null
} else {
baseValue(o)
}
} else {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(key).orNull
}
}
}
}
}


trait GetField extends UnaryExpression {
self: Product =>

type EvaluatedType = Any
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"

def field: StructField
}

object GetField {
/**
* Returns the resolved `GetField`, and report error if no desired field or over one
* desired fields are found.
*/
def apply(
expr: Expression,
fieldName: String,
resolver: Resolver): GetField = {
def findField(fields: Array[StructField]): Int = {
val checkField = (f: StructField) => resolver(f.name, fieldName)
val ordinal = fields.indexWhere(checkField)
if (ordinal == -1) {
throw new AnalysisException(
s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
} else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
throw new AnalysisException(
s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
} else {
ordinal
}
}
expr.dataType match {
case StructType(fields) =>
val ordinal = findField(fields)
StructGetField(expr, fields(ordinal), ordinal)
case ArrayType(StructType(fields), containsNull) =>
val ordinal = findField(fields)
ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
case otherType =>
throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
}
}
}

/**
* Returns the value of fields in the Struct `child`.
*/
case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {

override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}
}

/**
* Returns the array of value of fields in the Array of Struct `child`.
*/
case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
extends GetField {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
if (baseValue == null) null else {
baseValue.map { row =>
if (row == null) null else row(ordinal)
}
}
}
}

/**
* Returns an Array containing the evaluation of all children expressions.
@@ -227,10 +227,8 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal.create(null, e.dataType)
case e @ GetField(_, Literal(null, _)) => Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
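
These two rules replace the four removed ones because `GetField.unapply` (defined in the new file above) reports a struct accessor as `(child, null)` and an ordinal accessor as `(child, ordinal)`, so a null child or a null ordinal/key is caught uniformly across struct, array, and map accessors. A hedged sketch of the first rule firing, with an illustrative null array literal:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

val nullArray = Literal.create(null, ArrayType(IntegerType))
val accessor = ArrayOrdinalGetField(nullArray, Literal(0))

val simplified = accessor match {
  // same shape as the first rule above: the child is a null literal
  case e @ GetField(Literal(null, _), _) => Literal.create(null, e.dataType)
  case other => other
}
// simplified == Literal(null, IntegerType)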