[SPARK-23587][SQL] Add interpreted execution for MapObjects expression
## What changes were proposed in this pull request?

Add interpreted execution for `MapObjects` expression.

## How was this patch tested?

Added unit test.
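
A minimal sketch of the behavior this adds, mirroring the new unit test (everything used here is Spark's existing Catalyst API; run with Spark on the classpath):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.types.{IntegerType, ObjectType}

// Add 1 to every element of an input Seq[Int], without code generation.
val function = (lambda: Expression) => Add(lambda, Literal(1))
val inputObject = BoundReference(0, ObjectType(classOf[Seq[Int]]), nullable = true)
val mapObj = MapObjects(function, inputObject, IntegerType, true, None)

// Before this patch, eval threw UnsupportedOperationException.
val result = mapObj.eval(InternalRow.fromSeq(Seq(Seq(1, 2, 3))))
// result is a GenericArrayData holding 2, 3, 4
```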

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#20771 from viirya/SPARK-23587.
viirya authored and Robert Kruszewski committed Apr 4, 2018
1 parent b72b848 commit 0bcf7e4
Showing 2 changed files with 165 additions and 12 deletions.
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.Modifier

import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
@@ -501,12 +502,22 @@ case class LambdaVariable(
value: String,
isNull: String,
dataType: DataType,
nullable: Boolean = true) extends LeafExpression
with Unevaluable with NonSQLExpression {
nullable: Boolean = true) extends LeafExpression with NonSQLExpression {

// Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
override def eval(input: InternalRow): Any = {
assert(input.numFields == 1,
"The input row of interpreted LambdaVariable should have only 1 field.")
input.get(0, dataType)
}

override def genCode(ctx: CodegenContext): ExprCode = {
ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
}

// This won't be called since `genCode` is overridden; it is overridden here only
// to make `LambdaVariable` non-abstract.
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
}

/**
@@ -599,8 +610,92 @@ case class MapObjects private(

override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
// Data with a UserDefinedType is actually stored using the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use that sqlType.
lazy private val inputDataType = inputData.dataType match {
case u: UserDefinedType[_] => u.sqlType
case _ => inputData.dataType
}

private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
val row = new GenericInternalRow(1)
inputCollection.toIterator.map { element =>
row.update(0, element)
lambdaFunction.eval(row)
}
}

private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
_.asInstanceOf[Seq[_]]
case ObjectType(cls) if cls.isArray =>
_.asInstanceOf[Array[_]].toSeq
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
_.asInstanceOf[java.util.List[_]].asScala
case ObjectType(cls) if cls == classOf[Object] =>
(inputCollection) => {
if (inputCollection.getClass.isArray) {
inputCollection.asInstanceOf[Array[_]].toSeq
} else {
inputCollection.asInstanceOf[Seq[_]]
}
}
case ArrayType(et, _) =>
_.asInstanceOf[ArrayData].array
}

private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
executeFuncOnCollection(_).toSeq
case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala set
executeFuncOnCollection(_).toSet
case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
// Java list
if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
cls == classOf[java.util.AbstractSequentialList[_]]) {
// Specifying non-concrete implementations of `java.util.List`
executeFuncOnCollection(_).toSeq.asJava
} else {
val constructors = cls.getConstructors()
val intParamConstructor = constructors.find { constructor =>
constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
}
val noParamConstructor = constructors.find { constructor =>
constructor.getParameterCount == 0
}

val constructor = intParamConstructor.map { intConstructor =>
(len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
}.getOrElse {
(_: Int) => noParamConstructor.get.newInstance()
}

// Specifying concrete implementations of `java.util.List`
(inputs) => {
val results = executeFuncOnCollection(inputs)
val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
results.foreach(builder.add(_))
builder
}
}
case None =>
// array
x => new GenericArrayData(executeFuncOnCollection(x).toArray)
case Some(cls) =>
throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
"resulting collection.")
}

override def eval(input: InternalRow): Any = {
val inputCollection = inputData.eval(input)

if (inputCollection == null) {
return null
}
mapElements(convertToSeq(inputCollection))
}

override def dataType: DataType =
customCollectionCls.map(ObjectType.apply).getOrElse(
@@ -647,13 +742,6 @@ case class MapObjects private(
case _ => ""
}

// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use it.
val inputDataType = inputData.dataType match {
case p: PythonUserDefinedType => p.sqlType
case _ => inputData.dataType
}

// `MapObjects` generates a while loop to traverse the elements of the input collection. We
// need to take care of Seq and List because they may have O(n) complexity for indexed access
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
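A standalone sketch (plain Scala, no Spark dependency; the names are illustrative) of the evaluation strategy introduced above: the input collection is first normalized to a `Seq`, and the lambda is then evaluated once per element against a reused one-field row, which is why the interpreted `LambdaVariable` always reads index 0.

```scala
object MapObjectsSketch {
  // A reusable one-slot "row": each element is written into the slot before the
  // lambda body runs, mirroring how MapObjects feeds its LambdaVariable.
  final class OneSlotRow { var slot: Any = _ }

  def mapElements(input: Seq[Any], lambda: OneSlotRow => Any): List[Any] = {
    val row = new OneSlotRow
    input.iterator.map { element =>
      row.slot = element   // overwrite the single field for this element
      lambda(row)          // the "lambda variable" always reads the single slot
    }.toList
  }

  def main(args: Array[String]): Unit = {
    println(mapElements(Seq(1, 2, 3), r => r.slot.asInstanceOf[Int] + 1)) // List(2, 3, 4)
  }
}
```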
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkFunSuite}
@@ -25,7 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.types._


@@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("SPARK-23587: MapObjects should support interpreted execution") {
def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = {
val function = (lambda: Expression) => Add(lambda, Literal(1))
val elementType = IntegerType
val expected = Seq(2, 3, 4)

val inputObject = BoundReference(0, inputType, nullable = true)
val optClass = Option(collectionCls)
val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
val row = InternalRow.fromSeq(Seq(collection))
val result = mapObj.eval(row)

collectionCls match {
case null =>
assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
case s if classOf[Seq[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[Seq[_]].toSeq == expected)
case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
}
}

val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
classOf[java.util.Stack[Int]], null)

val list = new java.util.ArrayList[Int]()
list.add(1)
list.add(2)
list.add(3)
val arrayData = new GenericArrayData(Array(1, 2, 3))
val vector = new java.util.Vector[Int]()
vector.add(1)
vector.add(2)
vector.add(3)
val stack = new java.util.Stack[Int]()
stack.add(1)
stack.add(2)
stack.add(3)

Seq(
(Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
(Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
(Seq(1, 2, 3), ObjectType(classOf[Object])),
(Array(1, 2, 3), ObjectType(classOf[Object])),
(list, ObjectType(classOf[java.util.List[Int]])),
(vector, ObjectType(classOf[java.util.Vector[Int]])),
(stack, ObjectType(classOf[java.util.Stack[Int]])),
(arrayData, ArrayType(IntegerType))
).foreach { case (collection, inputType) =>
customCollectionClasses.foreach(testMapObjects(collection, _, inputType))

// Unsupported custom collection class
val errMsg = intercept[RuntimeException] {
testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
}.getMessage()
assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " +
"as resulting collection."))
}
}

test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") {
val cls = classOf[java.lang.Integer]
val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true)
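In the test above, the `java.util.Vector` and `java.util.Stack` cases exercise the reflection branch of `mapElements`: a concrete `java.util.List` implementation is built through a one-`Int`-argument constructor when available, otherwise through its no-arg constructor. A minimal standalone sketch of that lookup (the `listBuilder` helper name is illustrative, not Spark code):

```scala
import java.util.{List => JList}

// Prefer a constructor taking an initial capacity (Int); fall back to the no-arg one.
def listBuilder(cls: Class[_]): Int => JList[Any] = {
  val constructors = cls.getConstructors
  val intParamCtor = constructors.find { c =>
    c.getParameterCount == 1 && c.getParameterTypes()(0) == classOf[Int]
  }
  val noParamCtor = constructors.find(_.getParameterCount == 0)

  (len: Int) => {
    intParamCtor
      .map(_.newInstance(len.asInstanceOf[Object]))
      .getOrElse(noParamCtor.get.newInstance())
      .asInstanceOf[JList[Any]]
  }
}

// Usage: build a java.util.Vector sized for three mapped results.
val vec = listBuilder(classOf[java.util.Vector[_]])(3)
Seq(2, 3, 4).foreach(vec.add(_))
// vec now holds [2, 3, 4]
```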
