Skip to content

Commit

Permalink
[SPARK-3891][SQL] Add array support to percentile, percentile_approx …
Browse files Browse the repository at this point in the history
…and constant inspectors support

Supported passing an array to the percentile and percentile_approx UDAFs.
To support percentile_approx, constant object inspectors are supported for GenericUDAF.
Constant folding support added to the CreateArray expression.
Avoided re-evaluation of constant UDF expressions.

Author: Venkata Ramana G <ramana.gollamudi@huawei.com>

Author: Venkata Ramana Gollamudi <[email protected]>

Closes apache#2802 from gvramana/percentile_array_support and squashes the following commits:

a0182e5 [Venkata Ramana Gollamudi] fixed review comment
a18f917 [Venkata Ramana Gollamudi] avoid constant udf expression re-evaluation - fixes failure due to return iterator and value type mismatch
c46db0f [Venkata Ramana Gollamudi] Removed TestHive reset
4d39105 [Venkata Ramana Gollamudi] Unified inspector creation, style check fixes
f37fd69 [Venkata Ramana Gollamudi] Fixed review comments
47f6365 [Venkata Ramana Gollamudi] fixed test
cb7c61e [Venkata Ramana Gollamudi] Supported ConstantInspector for UDAF Fixed HiveUdaf wrap object issue.
7f94aff [Venkata Ramana Gollamudi] Added foldable support to CreateArray
  • Loading branch information
gvramana authored and markhamstra committed Jan 17, 2015
1 parent efd3a6b commit cec958e
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
*/
case class CreateArray(children: Seq[Expression]) extends Expression {
override type EvaluatedType = Any


override def foldable = !children.exists(!_.foldable)

lazy val childTypes = children.map(_.dataType).distinct

override lazy val resolved =
Expand Down
35 changes: 25 additions & 10 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
override def foldable =
isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]

@transient
protected def constantReturnValue = unwrap(
returnInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue(),
returnInspector)

@transient
protected lazy val deferedObjects =
argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
Expand All @@ -166,6 +171,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr

override def eval(input: Row): Any = {
returnInspector // Make sure initialized.
if(foldable) return constantReturnValue

var i = 0
while (i < children.length) {
val idx = i
Expand Down Expand Up @@ -193,12 +200,13 @@ private[hive] case class HiveGenericUdaf(

@transient
protected lazy val objectInspector = {
resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
resolver.getEvaluator(parameterInfo)
.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
}

@transient
protected lazy val inspectors = children.map(_.dataType).map(toInspector)
protected lazy val inspectors = children.map(toInspector)

def dataType: DataType = inspectorToDataType(objectInspector)

Expand All @@ -223,12 +231,13 @@ private[hive] case class HiveUdaf(

@transient
protected lazy val objectInspector = {
resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
resolver.getEvaluator(parameterInfo)
.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
}

@transient
protected lazy val inspectors = children.map(_.dataType).map(toInspector)
protected lazy val inspectors = children.map(toInspector)

def dataType: DataType = inspectorToDataType(objectInspector)

Expand Down Expand Up @@ -261,7 +270,7 @@ private[hive] case class HiveGenericUdtf(
protected lazy val function: GenericUDTF = funcWrapper.createFunction()

@transient
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
protected lazy val inputInspectors = children.map(toInspector)

@transient
protected lazy val outputInspector = function.initialize(inputInspectors.toArray)
Expand Down Expand Up @@ -334,10 +343,13 @@ private[hive] case class HiveUdafFunction(
} else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}

private val inspectors = exprs.map(_.dataType).map(toInspector).toArray

private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray)

private val inspectors = exprs.map(toInspector).toArray

private val function = {
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
resolver.getEvaluator(parameterInfo)
}

private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)

Expand All @@ -350,9 +362,12 @@ private[hive] case class HiveUdafFunction(
@transient
val inputProjection = new InterpretedProjection(exprs)

@transient
protected lazy val cached = new Array[AnyRef](exprs.length)

def update(input: Row): Unit = {
val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
function.iterate(buffer, inputs)
function.iterate(buffer, wrap(inputs, inspectors, cached))
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,21 @@ class HiveUdfSuite extends QueryTest {
}

test("SPARK-2693 udaf aggregates test") {
checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"),
checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq)

checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"),
sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq)
}

// Verifies SPARK-3891: percentile_approx works both with a scalar fraction
// (exercising the new constant-inspector support for GenericUDAF) and with an
// array of fractions (exercising array-argument support plus CreateArray
// constant folding). Expected values are computed via plain SQL on the same
// `src` table so the test does not hard-code Hive's output format.
test("Generic UDAF aggregates") {
// Scalar fraction: the 0.99999 quantile of `key`, rounded up, must equal max(key).
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)

// Array fraction: a constant input of 100.0 yields 100 at every requested quantile.
checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"),
sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq)
}

test("UDFIntegerToString") {
val testData = TestHive.sparkContext.parallelize(
IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
Expand Down

0 comments on commit cec958e

Please sign in to comment.