[SPARK-18632][SQL] AggregateFunction should not implement ImplicitCastInputTypes #16066

Status: Closed · wants to merge 5 commits
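
In short: this PR moves the input-cast contract off the `AggregateFunction` base class. `AggregateFunction` no longer extends `ImplicitCastInputTypes`; each concrete aggregate now opts in to `ImplicitCastInputTypes` or `ExpectsInputTypes` only where an input-type contract is meaningful, and functions that accept any input (`Count`, `Max`, `Min`, `Collect`, `HyperLogLogPlusPlus`, ...) drop their placeholder `inputTypes` overrides. A minimal sketch of the pattern, with toy stand-ins rather than the real Catalyst classes:

```scala
// Toy model of the trait change (names mirror Spark's; bodies are stand-ins).
trait Expression
trait ExpectsInputTypes extends Expression { def inputTypes: Seq[String] }
trait ImplicitCastInputTypes extends ExpectsInputTypes

// Before this PR the base class forced a contract on every aggregate:
//   sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes
// After, the base class carries no input-type contract at all:
abstract class AggregateFunction extends Expression

// Aggregates that want implicit casts opt in explicitly...
case class Average(child: Expression) extends AggregateFunction with ImplicitCastInputTypes {
  override def inputTypes: Seq[String] = Seq("numeric")
}

// ...while an accept-anything aggregate no longer needs a meaningless
// `inputTypes = Seq(AnyDataType)` override just to satisfy the base class.
case class Max(child: Expression) extends AggregateFunction
```
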
ApproximatePercentile.scala

@@ -22,11 +22,11 @@ import java.nio.ByteBuffer
 import com.google.common.primitives.{Doubles, Ints, Longs}
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
@@ -71,7 +71,8 @@ case class ApproximatePercentile(
     percentageExpression: Expression,
     accuracyExpression: Expression,
     override val mutableAggBufferOffset: Int,
-    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] {
+    override val inputAggBufferOffset: Int)
+  extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
     this(child, percentageExpression, accuracyExpression, 0, 0)

Average.scala

@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
-case class Average(child: Expression) extends DeclarativeAggregate {
+case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def prettyName: String = "avg"
 

CentralMomentAgg.scala

@@ -42,7 +42,8 @@ import org.apache.spark.sql.types._
  *
  * @param child to compute central moments of.
  */
-abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate {
+abstract class CentralMomentAgg(child: Expression)
+  extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   /**
    * The central moment order to be computed.

Corr.scala

@@ -32,7 +32,8 @@ import org.apache.spark.sql.types._
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
 // scalastyle:on line.size.limit
-case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
+case class Corr(x: Expression, y: Expression)
+  extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = Seq(x, y)
   override def nullable: Boolean = true

Count.scala

@@ -38,9 +38,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = LongType
 
-  // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
-
   private lazy val count = AttributeReference("count", LongType, nullable = false)()
 
   override lazy val aggBufferAttributes = count :: Nil

CountMinSketchAgg.scala

@@ -22,7 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
@@ -52,7 +52,8 @@ case class CountMinSketchAgg(
     confidenceExpression: Expression,
     seedExpression: Expression,
     override val mutableAggBufferOffset: Int,
-    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {
+    override val inputAggBufferOffset: Int)
+  extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
 
   def this(
       child: Expression,

Covariance.scala

@@ -25,7 +25,8 @@ import org.apache.spark.sql.types._
  * Compute the covariance between two expressions.
  * When applied on empty data (i.e., count is zero), it returns NULL.
  */
-abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate {
+abstract class Covariance(x: Expression, y: Expression)
+  extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = Seq(x, y)
   override def nullable: Boolean = true

First.scala

@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -33,16 +34,11 @@ import org.apache.spark.sql.types._
     _FUNC_(expr[, isIgnoreNull]) - Returns the first value of `expr` for a group of rows.
     If `isIgnoreNull` is true, returns only non-null values.
   """)
-case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+case class First(child: Expression, ignoreNullsExpr: Expression)
+  extends DeclarativeAggregate with ExpectsInputTypes {
 
   def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
-  private val ignoreNulls: Boolean = ignoreNullsExpr match {
-    case Literal(b: Boolean, BooleanType) => b
-    case _ =>
-      throw new AnalysisException("The second argument of First should be a boolean literal.")
-  }
-
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true
@@ -56,6 +52,20 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!ignoreNullsExpr.foldable) {
+      TypeCheckFailure(
+        s"The second argument of First must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
+
   private lazy val first = AttributeReference("first", child.dataType)()
 
   private lazy val valueSet = AttributeReference("valueSet", BooleanType)()

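The `First` rewrite above (mirrored by the `Last` change further down) moves validation of the second argument from construction time to analysis time: instead of pattern-matching for a literal in the constructor, `checkInputDataTypes` only requires the expression to be foldable, after which `ignoreNulls` can safely be computed via `eval()`. A self-contained sketch of that flow, using toy expression types in place of Catalyst's:

```scala
// Toy expressions standing in for Catalyst's Expression hierarchy.
sealed trait Expr { def foldable: Boolean; def eval(): Any }
case class BoolLit(value: Boolean) extends Expr {
  val foldable = true
  def eval(): Any = value
}
case class ColumnRef(name: String) extends Expr {
  val foldable = false
  def eval(): Any = sys.error(s"cannot evaluate unbound column $name")
}

// Analysis-time check: any foldable boolean argument is acceptable, and only
// after the check passes is it reduced to a plain Boolean.
def checkIgnoreNulls(e: Expr): Either[String, Boolean] =
  if (e.foldable) Right(e.eval().asInstanceOf[Boolean])
  else Left(s"The second argument must be a boolean literal, but got: $e")

// checkIgnoreNulls(BoolLit(true))  => Right(true)
// checkIgnoreNulls(ColumnRef("x")) => Left("... but got: ColumnRef(x)")
```

A `TypeCheckFailure` raised during analysis also surfaces to the user as a normal `AnalysisException`, rather than as an exception thrown while the expression tree is still being constructed.
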
HyperLogLogPlusPlus.scala

@@ -140,8 +140,6 @@ case class HyperLogLogPlusPlus(
 
   override def dataType: DataType = LongType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
 
   /** Allocate enough words to store all registers. */

Last.scala

@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -33,16 +34,11 @@ import org.apache.spark.sql.types._
     _FUNC_(expr[, isIgnoreNull]) - Returns the last value of `expr` for a group of rows.
     If `isIgnoreNull` is true, returns only non-null values.
   """)
-case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+case class Last(child: Expression, ignoreNullsExpr: Expression)
+  extends DeclarativeAggregate with ExpectsInputTypes {
 
   def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
-  private val ignoreNulls: Boolean = ignoreNullsExpr match {
-    case Literal(b: Boolean, BooleanType) => b
-    case _ =>
-      throw new AnalysisException("The second argument of First should be a boolean literal.")
-  }
-
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true
@@ -56,6 +52,20 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!ignoreNullsExpr.foldable) {
+      TypeCheckFailure(
+        s"The second argument of Last must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
+
   private lazy val last = AttributeReference("last", child.dataType)()
 
   private lazy val valueSet = AttributeReference("valueSet", BooleanType)()

Max.scala

@@ -33,9 +33,6 @@ case class Max(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = child.dataType
 
-  // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForOrderingExpr(child.dataType, "function max")
 

Min.scala

@@ -33,9 +33,6 @@ case class Min(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = child.dataType
 
-  // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForOrderingExpr(child.dataType, "function min")
 

Percentile.scala

@@ -54,10 +54,11 @@ import org.apache.spark.util.collection.OpenHashMap
       be between 0.0 and 1.0.
   """)
 case class Percentile(
-  child: Expression,
-  percentageExpression: Expression,
-  mutableAggBufferOffset: Int = 0,
-  inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+    child: Expression,
+    percentageExpression: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, 0, 0)

PivotFirst.scala

@@ -77,8 +77,6 @@ case class PivotFirst(
 
   override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil
 
-  override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)
-
   override val nullable: Boolean = false
 
   val valueDataType = valueColumn.dataType

Sum.scala

@@ -24,7 +24,7 @@ import org.apache.spark.sql.types._
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.")
-case class Sum(child: Expression) extends DeclarativeAggregate {
+case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = child :: Nil
 

collect.scala

@@ -44,8 +44,6 @@ abstract class Collect extends ImperativeAggregate {
 
   override def dataType: DataType = ArrayType(child.dataType)
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def supportsPartial: Boolean = false
 
   override def aggBufferAttributes: Seq[AttributeReference] = Nil

interfaces.scala

@@ -155,7 +155,7 @@ case class AggregateExpression(
  * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
  * aggregate functions.
  */
-sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes {
+sealed abstract class AggregateFunction extends Expression {
 
   /** An aggregate function is not foldable. */
   final override def foldable: Boolean = false

windowExpressions.scala

@@ -443,7 +443,6 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction {
 
 abstract class RowNumberLike extends AggregateWindowFunction {
   override def children: Seq[Expression] = Nil
-  override def inputTypes: Seq[AbstractDataType] = Nil
   protected val zero = Literal(0)
   protected val one = Literal(1)
   protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)()
@@ -600,7 +599,6 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction {
  * This documentation has been based upon similar documentation for the Hive and Presto projects.
  */
 abstract class RankLike extends AggregateWindowFunction {
-  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
 
   /** Store the values of the window 'order' expressions. */
   protected val orderAttrs = children.map { expr =>

TypedAggregateExpression.scala

@@ -81,8 +81,6 @@ case class TypedAggregateExpression(
 
   override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
 
-  override def inputTypes: Seq[AbstractDataType] = Nil
-
   private def aggregatorLiteral =
     Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]]))
 

udaf.scala

@@ -324,7 +324,7 @@ case class ScalaUDAF(
     udaf: UserDefinedAggregateFunction,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with NonSQLExpression with Logging {
+  extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)

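`ScalaUDAF` goes the other way and keeps `ImplicitCastInputTypes`: a `UserDefinedAggregateFunction` declares an explicit `inputSchema`, so arguments should still be implicitly cast to match it. A sketch with a hypothetical UDAF (`LongSum` below is illustrative, not part of Spark):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// A deliberately simple UDAF: sums a single LongType column.
class LongSum extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
  def bufferSchema: StructType = StructType(StructField("sum", LongType) :: Nil)
  def dataType: DataType = LongType
  def deterministic: Boolean = true
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0L
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getLong(0)
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  def evaluate(buffer: Row): Any = buffer.getLong(0)
}
```

Because `ScalaUDAF` mixes in `ImplicitCastInputTypes`, applying `LongSum` to, say, an `IntegerType` column still works: the analyzer casts the int input up to the declared `LongType` instead of failing the type check.
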
CountMinSketchAggQuerySuite.scala

@@ -110,9 +110,11 @@ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
     withTempView(table) {
       val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
       spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
-      val cmsSql = schema.fieldNames.map(col => s"count_min_sketch($col, $eps, $confidence, $seed)")
-        .mkString(", ")
-      val result = sql(s"SELECT $cmsSql FROM $table").head()
+
+      val cmsSql = schema.fieldNames.map { col =>
+        s"count_min_sketch($col, ${eps}D, ${confidence}D, $seed)"
+      }
+      val result = sql(s"SELECT ${cmsSql.mkString(", ")} FROM $table").head()
       schema.indices.foreach { i =>
         val binaryData = result.getAs[Array[Byte]](i)
         val in = new ByteArrayInputStream(binaryData)

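The `D` suffixes added above are the visible cost of the type change: a bare SQL literal such as `0.001` parses as a decimal, and since `CountMinSketchAgg` now extends `ExpectsInputTypes` there is no implicit decimal-to-double cast, so the generated SQL has to produce `DoubleType` literals itself. A small sketch of the resulting string (the column name and parameter values are illustrative):

```scala
// Hypothetical parameters mirroring the shape of the test's generated SQL.
val eps = 0.001
val confidence = 0.99
val seed = 42
val call = s"count_min_sketch(id, ${eps}D, ${confidence}D, $seed)"
// call == "count_min_sketch(id, 0.001D, 0.99D, 42)"
// The trailing D makes the parser emit DoubleType literals that match the
// function's declared inputTypes without any cast.
```
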
TypedImperativeAggregateSuite.scala

@@ -21,13 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType}
+import org.apache.spark.sql.types._
 
 class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
 
@@ -231,7 +231,8 @@ object TypedImperativeAggregateSuite {
       child: Expression,
       nullable: Boolean = false,
       mutableAggBufferOffset: Int = 0,
-      inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {
+      inputAggBufferOffset: Int = 0)
+    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
 
 
     override def createAggregationBuffer(): MaxValue = {

hiveUDFs.scala

@@ -378,10 +378,6 @@ private[hive] case class HiveUDAFFunction(
   @transient
   private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe
 
-  // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
-  // catalyst type checking framework.
-  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
-
   override def nullable: Boolean = true
 
   override def supportsPartial: Boolean = true

TestingTypedCount.scala

@@ -71,8 +71,6 @@ case class TestingTypedCount(
     TestingTypedCount.State(dataStream.readLong())
   }
 
-  override def inputTypes: Seq[AbstractDataType] = AnyDataType :: Nil
-
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 