[SPARK-7886] Use FunctionRegistry for built-in expressions in HiveContext.

This builds on #6710 and also uses FunctionRegistry for function lookup in HiveContext.

Author: Reynold Xin <[email protected]>

Closes #6712 from rxin/udf-registry-hive and squashes the following commits:

f4c2df0 [Reynold Xin] Fixed style violation.
0bd4127 [Reynold Xin] Fixed Python UDFs.
f9a0378 [Reynold Xin] Disable one more test.
5609494 [Reynold Xin] Disable some failing tests.
4efea20 [Reynold Xin] Don't check children resolved for UDF resolution.
2ebe549 [Reynold Xin] Removed more hardcoded functions.
aadce78 [Reynold Xin] [SPARK-7886] Use FunctionRegistry for built-in expressions in HiveContext.
rxin committed Jun 10, 2015
1 parent 778f3ca commit 57c60c5
Showing 10 changed files with 92 additions and 105 deletions.
@@ -68,7 +68,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val FULL = Keyword("FULL")
protected val GROUP = Keyword("GROUP")
protected val HAVING = Keyword("HAVING")
protected val IF = Keyword("IF")
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
protected val INSERT = Keyword("INSERT")
@@ -277,6 +276,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
lexical.normalizeKeyword(udfName) match {
case "sum" => SumDistinct(exprs.head)
case "count" => CountDistinct(exprs)
case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
}
}
| APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp =>
@@ -460,7 +460,7 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressions {
case u @ UnresolvedFunction(name, children) if u.childrenResolved =>
case u @ UnresolvedFunction(name, children) =>
withPosition(u) {
registry.lookupFunction(name, children)
}
@@ -494,20 +494,21 @@
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) => {
if aggregate.resolved && containsAggregate(havingCondition) =>

val evaluatedCondition = Alias(havingCondition, "havingCondition")()
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs

Project(aggregate.output,
Filter(evaluatedCondition.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
}
}

protected def containsAggregate(condition: Expression): Boolean =
protected def containsAggregate(condition: Expression): Boolean = {
condition
.collect { case ae: AggregateExpression => ae }
.nonEmpty
}
}

/**
@@ -35,16 +35,16 @@ trait FunctionRegistry {
def lookupFunction(name: String, children: Seq[Expression]): Expression
}

trait OverrideFunctionRegistry extends FunctionRegistry {
class OverrideFunctionRegistry(underlying: FunctionRegistry) extends FunctionRegistry {

private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false)

override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
functionBuilders.put(name, builder)
}

abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
functionBuilders.get(name).map(_(children)).getOrElse(underlying.lookupFunction(name, children))
}
}

@@ -133,6 +133,12 @@ object FunctionRegistry {
expression[Sum]("sum")
)

val builtin: FunctionRegistry = {
val fr = new SimpleFunctionRegistry
expressions.foreach { case (name, builder) => fr.registerFunction(name, builder) }
fr
}

/** See usage above. */
private def expression[T <: Expression](name: String)
(implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
@@ -25,10 +25,10 @@ import org.apache.spark.sql.types._


/**
* For Catalyst to work correctly, concrete implementations of [[Expression]]s must be case classes
* whose constructor arguments are all Expressions types. In addition, if we want to support more
* than one constructor, define those constructors explicitly as apply methods in the companion
* object.
* If an expression wants to be exposed in the function registry (so users can call it with
* "name(arguments...)"), the concrete implementation must be a case class whose constructor
* arguments are all Expression types. In addition, if it needs to support more than one
* constructor, define those constructors explicitly as apply methods in the companion object.
*
* See [[Substring]] for an example.
*/
7 changes: 2 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -120,11 +120,8 @@ class SQLContext(@transient val sparkContext: SparkContext)

// TODO how to handle the temp function per user session?
@transient
protected[sql] lazy val functionRegistry: FunctionRegistry = {
val fr = new SimpleFunctionRegistry
FunctionRegistry.expressions.foreach { case (name, func) => fr.registerFunction(name, func) }
fr
}
protected[sql] lazy val functionRegistry: FunctionRegistry =
new OverrideFunctionRegistry(FunctionRegistry.builtin)

@transient
protected[sql] lazy val analyzer: Analyzer =
@@ -57,7 +57,7 @@ private[spark] case class PythonUDF(
def nullable: Boolean = true

override def eval(input: Row): Any = {
sys.error("PythonUDFs can not be directly evaluated.")
throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
}
}

@@ -71,43 +71,49 @@
private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Skip EvaluatePython nodes.
case p: EvaluatePython => p
case plan: EvaluatePython => plan

case l: LogicalPlan =>
case plan: LogicalPlan =>
// Extract any PythonUDFs from the current operator.
val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf})
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
if (udfs.isEmpty) {
// If there aren't any, we are done.
l
plan
} else {
// Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
// If there is more than one, we will add another evaluation operator in a subsequent pass.
val udf = udfs.head

var evaluation: EvaluatePython = null

// Rewrite the child that has the input required for the UDF
val newChildren = l.children.map { child =>
// Check to make sure that the UDF can be evaluated with only the input of this child.
// Other cases are disallowed as they are ambiguous or would require a cartisian product.
if (udf.references.subsetOf(child.outputSet)) {
evaluation = EvaluatePython(udf, child)
evaluation
} else if (udf.references.intersect(child.outputSet).nonEmpty) {
sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
} else {
child
}
udfs.find(_.resolved) match {
case Some(udf) =>
var evaluation: EvaluatePython = null

// Rewrite the child that has the input required for the UDF
val newChildren = plan.children.map { child =>
// Check to make sure that the UDF can be evaluated with only the input of this child.
// Other cases are disallowed as they are ambiguous or would require a cartesian
// product.
if (udf.references.subsetOf(child.outputSet)) {
evaluation = EvaluatePython(udf, child)
evaluation
} else if (udf.references.intersect(child.outputSet).nonEmpty) {
sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
} else {
child
}
}

assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")

// Trim away the new UDF value if it was only used for filtering or something.
logical.Project(
plan.output,
plan.transformExpressions {
case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
}.withNewChildren(newChildren))

case None =>
// If there is no Python UDF that is resolved, skip this round.
plan
}

assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")

// Trim away the new UDF value if it was only used for filtering or something.
logical.Project(
l.output,
l.transformExpressions {
case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
}.withNewChildren(newChildren))
}
}
}
@@ -817,19 +817,19 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf2",
"udf5",
"udf6",
"udf7",
// "udf7", turn this on after we figure out null vs nan vs infinity
"udf8",
"udf9",
"udf_10_trims",
"udf_E",
"udf_PI",
"udf_abs",
"udf_acos",
// "udf_acos", turn this on after we figure out null vs nan vs infinity
"udf_add",
"udf_array",
"udf_array_contains",
"udf_ascii",
"udf_asin",
// "udf_asin", turn this on after we figure out null vs nan vs infinity
"udf_atan",
"udf_avg",
"udf_bigint",
@@ -917,7 +917,7 @@
"udf_repeat",
"udf_rlike",
"udf_round",
"udf_round_3",
// "udf_round_3", TODO: FIX THIS failed due to cast exception
"udf_rpad",
"udf_rtrim",
"udf_second",
@@ -931,7 +931,7 @@
"udf_stddev_pop",
"udf_stddev_samp",
"udf_string",
"udf_struct",
// "udf_struct", TODO: FIX THIS and enable it.
"udf_substring",
"udf_subtract",
"udf_sum",
@@ -374,7 +374,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
override protected[sql] lazy val functionRegistry: FunctionRegistry =
new HiveFunctionRegistry with OverrideFunctionRegistry
new OverrideFunctionRegistry(new HiveFunctionRegistry(FunctionRegistry.builtin))

/* An analyzer that uses the Hive metastore. */
@transient
30 changes: 0 additions & 30 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1307,16 +1307,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
HiveParser.DecimalLiteral)

/* Case insensitive matches */
val ARRAY = "(?i)ARRAY".r
val COALESCE = "(?i)COALESCE".r
val COUNT = "(?i)COUNT".r
val AVG = "(?i)AVG".r
val SUM = "(?i)SUM".r
val MAX = "(?i)MAX".r
val MIN = "(?i)MIN".r
val UPPER = "(?i)UPPER".r
val LOWER = "(?i)LOWER".r
val RAND = "(?i)RAND".r
val AND = "(?i)AND".r
val OR = "(?i)OR".r
val NOT = "(?i)NOT".r
@@ -1330,8 +1323,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val BETWEEN = "(?i)BETWEEN".r
val WHEN = "(?i)WHEN".r
val CASE = "(?i)CASE".r
val SUBSTR = "(?i)SUBSTR(?:ING)?".r
val SQRT = "(?i)SQRT".r

protected def nodeToExpr(node: Node): Expression = node match {
/* Attribute References */
@@ -1353,18 +1344,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
UnresolvedStar(Some(name))

/* Aggregate Functions */
case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg))
case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))
case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg))
case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg))

/* System functions about string operations */
case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg))

/* Casts */
case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
@@ -1414,7 +1396,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
case Token("TOK_FUNCTION", Token(SQRT(), Nil) :: arg :: Nil) => Sqrt(nodeToExpr(arg))

/* Comparisons */
case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
@@ -1469,17 +1450,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("[", child :: ordinal :: Nil) =>
UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))

/* Other functions */
case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
CreateArray(children.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr))

/* Window Functions */
case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) =>
val function = UnresolvedWindowFunction(name, args.map(nodeToExpr))
51 changes: 29 additions & 22 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.util.Try

import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
@@ -33,6 +34,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -41,35 +43,40 @@ import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types._


private[hive] abstract class HiveFunctionRegistry
private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
extends analysis.FunctionRegistry with HiveInspectors {

def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)

override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
// not always serializable.
val functionInfo: FunctionInfo =
Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
throw new AnalysisException(s"undefined function $name"))

val functionClassName = functionInfo.getFunctionClass.getName

if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
Try(underlying.lookupFunction(name, children)).getOrElse {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
// not always serializable.
val functionInfo: FunctionInfo =
Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
throw new AnalysisException(s"undefined function $name"))

val functionClassName = functionInfo.getFunctionClass.getName

if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
}
}

override def registerFunction(name: String, builder: FunctionBuilder): Unit =
throw new UnsupportedOperationException
}

private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
