From 23061811f9c755e9270090d94a54d2c2a8d28eef Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 31 Jul 2018 20:06:33 +0100 Subject: [PATCH 1/4] follow the approach using ArrayExcept --- .../expressions/collectionOperations.scala | 367 +++++++++--------- .../CollectionExpressionsSuite.scala | 21 +- 2 files changed, 203 insertions(+), 185 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e385c2d9782e8..149ab2cb5809c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3767,158 +3767,108 @@ object ArraySetLike { """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike - with ComplexTypeMergingExpression { - var hsInt: OpenHashSet[Int] = _ - var hsLong: OpenHashSet[Long] = _ - - def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getInt(idx) - if (!hsInt.contains(elem)) { - if (resultArray != null) { - resultArray.setInt(pos, elem) - } - hsInt.add(elem) - true - } else { - false - } - } - - def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getLong(idx) - if (!hsLong.contains(elem)) { - if (resultArray != null) { - resultArray.setLong(pos, elem) - } - hsLong.add(elem) - true - } else { - false - } - } + with ComplexTypeMergingExpression { - def evalIntLongPrimitiveType( - array1: ArrayData, - array2: ArrayData, - resultArray: ArrayData, - isLongType: Boolean): Int = { - // store elements into resultArray - var nullElementSize = 0 - var pos = 0 - Seq(array1, array2).foreach { array => - var i = 0 - while (i < array.numElements()) { - val size = if (!isLongType) hsInt.size else hsLong.size - if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(size) - } - if (array.isNullAt(i)) { - if (nullElementSize == 0) { - if (resultArray != null) { - resultArray.setNullAt(pos) + @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } } - pos += 1 - nullElementSize = 1 + i += 1 } - } else { - val assigned = if (!isLongType) { - assignInt(array, i, resultArray, pos) + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } } else { - assignLong(array, i, resultArray, pos) + // check elem is already stored in arrayBuffer or not? + var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } } - if (assigned) { - pos += 1 + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem } - } - i += 1 - } + })) + new GenericArrayData(arrayBuffer) } - pos } override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - if (elementTypeSupportEquals) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - // calculate result array size - hsInt = new OpenHashSet[Int] - val elements = evalIntLongPrimitiveType(array1, array2, null, false) - hsInt = new OpenHashSet[Int] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - IntegerType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) - } - evalIntLongPrimitiveType(array1, array2, resultArray, false) - resultArray - case LongType => - // avoid boxing of primitive long array elements - // calculate result array size - hsLong = new OpenHashSet[Long] - val elements = evalIntLongPrimitiveType(array1, array2, null, true) - hsLong = new OpenHashSet[Long] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - LongType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) - } - evalIntLongPrimitiveType(array1, array2, resultArray, true) - resultArray - case _ => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val hs = new OpenHashSet[Any] - var foundNullElement = false - Seq(array1, array2).foreach { array => - var i = 0 - while (i < array.numElements()) { - if (array.isNullAt(i)) { - if (!foundNullElement) { - arrayBuffer += null - foundNullElement = true - } - } else { - val elem = array.get(i, elementType) - if (!hs.contains(elem)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) - } - arrayBuffer += elem - hs.add(elem) - } - } - i += 1 - } - } - new GenericArrayData(arrayBuffer) - } - } else { - ArrayUnion.unionOrdering(array1, array2, elementType, ordering) - } + evalUnion(array1, array2) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") val pos = ctx.freshName("pos") val value = ctx.freshName("value") + val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") - val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = - if (elementTypeSupportEquals) { + if (elementTypeSupportEquals) { + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + val (postFix, openHashElementType, hsJavaTypeName, genHsValue, + getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = elementType match { - case ByteType | ShortType | IntegerType | LongType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - if (elementType == LongType) "Long" else "Int", - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), - if (elementType == LongType) "(long)" else "(int)", + case ByteType | ShortType | IntegerType => + ("$mcI$sp", "Int", "int", s"(int) $value", + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case LongType | FloatType | DoubleType => + val signature = elementType match { + case LongType => "$mcJ$sp" + case FloatType => "$mcF$sp" + case DoubleType => "$mcD$sp" + } + (signature, CodeGenerator.boxedType(elementType), + CodeGenerator.javaType(elementType), value, + s"get$ptName($i)", s"set$ptName($pos, $value)", + CodeGenerator.javaType(elementType), ptName, s""" |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} |${ev.value} = $unsafeArray; @@ -3926,71 +3876,130 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike case _ => val genericArrayData = classOf[GenericArrayData].getName val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", - s"get($i, $et)", s"update($pos, $value)", "Object", "", + ("", "Object", "Object", value, + s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", s"${ev.value} = new $genericArrayData(new Object[$size]);") } - } else { - ("", "", "", "", "", "", "") - } - nullSafeCodeGen(ctx, ev, (array1, array2) => { - if (openHashElementType != "") { - // Here, we ensure elementTypeSupportEquals is true + nullSafeCodeGen(ctx, ev, (array1, array2) => { val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val array = ctx.freshName("array") + val arrays = ctx.freshName("arrays") + val arrayDataIdx = ctx.freshName("arrayDataIdx") val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" val hs = ctx.freshName("hs") - val arrayData = classOf[ArrayData].getName - val arrays = ctx.freshName("arrays") - val array = ctx.freshName("array") - val arrayDataIdx = ctx.freshName("arrayDataIdx") + val genericArrayData = classOf[GenericArrayData].getName + val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" + val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { + s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" + } else { + s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" + } + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + val arrayBody = + s""" + |$javaTypeName $value = $array.$getter; + |$hsJavaTypeName $hsValue = $genHsValue; + |if (!$hs.contains($hsValue)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hs.add$postFix($hsValue); + | $builder.$$plus$$eq($value); + |} + """.stripMargin + + val nonNullArrayDataBuild = { + val build = if (postFix != "") { + val defaultSize = elementType.defaultSize + s""" + |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { + | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); + |} else { + | ${ev.value} = new $genericArrayData($builder.result()); + |} + """.stripMargin + } else { + s"${ev.value} = new $genericArrayData($builder.result());" + } + s""" + |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful try create array with " + $size + + | " bytes of data due to exceeding the limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + + | " $prettyName failed."); + |} + |$build + """.stripMargin + } + + def buildResultArrayData(nonNullArrayDataBuild: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($nullElementIndex < 0) { + | // result has no null element + | $nonNullArrayDataBuild + |} else { + | // result has null element + | $arrayDataBuilder + | $javaTypeName[] $array = $builder.result(); + | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { + | if ($pos == $nullElementIndex) { + | ${ev.value}.setNullAt($pos); + | } else { + | $javaTypeName $value = $array[$i++]; + | ${ev.value}.$setter; + | } + | } + |} + """.stripMargin + } else { + nonNullArrayDataBuild + } + s""" |$openHashSet $hs = new $openHashSet$postFix($classTag); |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + |int $size = 0; + |$arrayBuilderClass $builder = + | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { | $arrayData $array = $arrays[$arrayDataIdx]; | for (int $i = 0; $i < $array.numElements(); $i++) { - | if ($array.isNullAt($i)) { - | $foundNullElement = true; - | } else { - | $hs.add$postFix($array.$getter); - | } - | } - |} - |int $size = $hs.size() + ($foundNullElement ? 1 : 0); - |$arrayBuilder - |$hs = new $openHashSet$postFix($classTag); - |$foundNullElement = false; - |int $pos = 0; - |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { - | $arrayData $array = $arrays[$arrayDataIdx]; - | for (int $i = 0; $i < $array.numElements(); $i++) { - | if ($array.isNullAt($i)) { - | if (!$foundNullElement) { - | ${ev.value}.setNullAt($pos++); - | $foundNullElement = true; - | } - | } else { - | $javaTypeName $value = $array.$getter; - | if (!$hs.contains($castOp $value)) { - | $hs.add$postFix($value); - | ${ev.value}.$setter; - | $pos++; - | } - | } + | ${withArrayNullAssignment(arrayBody)} | } |} + |${buildResultArrayData(nonNullArrayDataBuild)} """.stripMargin - } else { - val arrayUnion = classOf[ArrayUnion].getName - val et = ctx.addReferenceObj("elementTypeUnion", elementType) - val order = ctx.addReferenceObj("orderingUnion", ordering) - val method = "unionOrdering" - s"${ev.value} = $arrayUnion$$.MODULE$$.$method($array1, $array2, $et, $order);" - } - }) + }) + } else { + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val expr = ctx.addReferenceObj("arrayUnionExpr", this) + s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + }) + } } override def prettyName: String = "array_union" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 4daa113869b5d..c6b3f9502f2bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1362,10 +1362,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val a02 = Literal.create(Seq(1, 2, null, 4, 5), ArrayType(IntegerType, containsNull = true)) val a03 = Literal.create(Seq(-5, 4, -3, 2, 4), ArrayType(IntegerType, containsNull = false)) val a04 = Literal.create(Seq.empty[Int], ArrayType(IntegerType, containsNull = false)) - val a05 = Literal.create(Seq[Byte](1, 2, 3), ArrayType(ByteType, containsNull = false)) - val a06 = Literal.create(Seq[Byte](4, 2), ArrayType(ByteType, containsNull = false)) - val a07 = Literal.create(Seq[Short](1, 2, 3), ArrayType(ShortType, containsNull = false)) - val a08 = Literal.create(Seq[Short](4, 2), ArrayType(ShortType, containsNull = false)) + val abl0 = Literal.create(Seq[Boolean](true, true), ArrayType(BooleanType, false)) + val abl1 = Literal.create(Seq[Boolean](false, false), ArrayType(BooleanType, false)) + val ab0 = Literal.create(Seq[Byte](1, 2, 3, 2), ArrayType(ByteType, false)) + val ab1 = Literal.create(Seq[Byte](4, 2, 4), ArrayType(ByteType, false)) + val as0 = Literal.create(Seq[Short](1, 2, 3, 2), ArrayType(ShortType, false)) + val as1 = Literal.create(Seq[Short](4, 2, 4), ArrayType(ShortType, false)) + val af0 = Literal.create(Seq[Float](1.1F, 2.2F, 3.3F, 2.2F), ArrayType(FloatType, false)) + val af1 = Literal.create(Seq[Float](4.4F, 2.2F, 4.4F), ArrayType(FloatType, false)) + val ad0 = Literal.create(Seq[Double](1.1, 2.2, 3.3, 2.2), ArrayType(DoubleType, false)) + val ad1 = Literal.create(Seq[Double](4.4, 2.2, 4.4), ArrayType(DoubleType, false)) val a10 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, containsNull = false)) val a11 = Literal.create(Seq(4L, 2L), ArrayType(LongType, containsNull = false)) @@ -1384,8 +1390,11 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayUnion(a02, a03), Seq(1, 2, null, 4, 5, -5, -3)) checkEvaluation(ArrayUnion(a03, a02), Seq(-5, 4, -3, 2, 1, null, 5)) checkEvaluation(ArrayUnion(a02, a04), Seq(1, 2, null, 4, 5)) - checkEvaluation(ArrayUnion(a05, a06), Seq[Byte](1, 2, 3, 4)) - checkEvaluation(ArrayUnion(a07, a08), Seq[Short](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(abl0, abl1), Seq[Boolean](true, false)) + checkEvaluation(ArrayUnion(ab0, ab1), Seq[Byte](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(as0, as1), Seq[Short](1, 2, 3, 4)) + checkEvaluation(ArrayUnion(af0, af1), Seq[Float](1.1F, 2.2F, 3.3F, 4.4F)) + checkEvaluation(ArrayUnion(ad0, ad1), Seq[Double](1.1, 2.2, 3.3, 4.4)) checkEvaluation(ArrayUnion(a10, a11), Seq(1L, 2L, 3L, 4L)) checkEvaluation(ArrayUnion(a12, a13), Seq(1L, 2L, null, 4L, 5L, -5L, -3L, -1L)) From d5e9158249e28f184ffd861258ac2dad07abdd0e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 4 Aug 2018 08:28:25 +0100 Subject: [PATCH 2/4] update --- .../expressions/collectionOperations.scala | 125 ++++-------------- 1 file changed, 24 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 149ab2cb5809c..87a78fbd95b70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3841,45 +3841,11 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") - val pos = ctx.freshName("pos") val value = ctx.freshName("value") - val hsValue = ctx.freshName("hsValue") val size = ctx.freshName("size") - if (elementTypeSupportEquals) { - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - val (postFix, openHashElementType, hsJavaTypeName, genHsValue, - getter, setter, javaTypeName, primitiveTypeName, arrayDataBuilder) = - elementType match { - case ByteType | ShortType | IntegerType => - ("$mcI$sp", "Int", "int", s"(int) $value", - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case LongType | FloatType | DoubleType => - val signature = elementType match { - case LongType => "$mcJ$sp" - case FloatType => "$mcF$sp" - case DoubleType => "$mcD$sp" - } - (signature, CodeGenerator.boxedType(elementType), - CodeGenerator.javaType(elementType), value, - s"get$ptName($i)", s"set$ptName($pos, $value)", - CodeGenerator.javaType(elementType), ptName, - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", "Object", value, - s"get($i, $et)", s"update($pos, $value)", "Object", "Ref", - s"${ev.value} = new $genericArrayData(new Object[$size]);") - } + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) nullSafeCodeGen(ctx, ev, (array1, array2) => { val foundNullElement = ctx.freshName("foundNullElement") @@ -3889,16 +3855,11 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val arrays = ctx.freshName("arrays") val arrayDataIdx = ctx.freshName("arrayDataIdx") val openHashSet = classOf[OpenHashSet[_]].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" - val hs = ctx.freshName("hs") - val genericArrayData = classOf[GenericArrayData].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") val arrayBuilder = "scala.collection.mutable.ArrayBuilder" - val arrayBuilderClass = s"$arrayBuilder$$of$primitiveTypeName" - val arrayBuilderClassTag = if (primitiveTypeName != "Ref") { - s"scala.reflect.ClassTag$$.MODULE$$.$primitiveTypeName()" - } else { - s"scala.reflect.ClassTag$$.MODULE$$.AnyRef()" - } + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" def withArrayNullAssignment(body: String) = if (dataType.asInstanceOf[ArrayType].containsNull) { @@ -3908,6 +3869,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike | $nullElementIndex = $size; | $foundNullElement = true; | $size++; + | $builder.$$plus$$eq($nullValueHolder); | } |} else { | $body @@ -3916,71 +3878,32 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } else { body } - val arrayBody = + + val processArray = withArrayNullAssignment( s""" - |$javaTypeName $value = $array.$getter; - |$hsJavaTypeName $hsValue = $genHsValue; - |if (!$hs.contains($hsValue)) { + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { | break; | } - | $hs.add$postFix($hsValue); + | $hashSet.add$hsPostFix($hsValueCast$value); | $builder.$$plus$$eq($value); |} - """.stripMargin + """.stripMargin) - val nonNullArrayDataBuild = { - val build = if (postFix != "") { - val defaultSize = elementType.defaultSize - s""" - |if (!UnsafeArrayData.shouldUseGenericArrayData($defaultSize, $size)) { - | ${ev.value} = UnsafeArrayData.fromPrimitiveArray($builder.result()); - |} else { - | ${ev.value} = new $genericArrayData($builder.result()); - |} - """.stripMargin - } else { - s"${ev.value} = new $genericArrayData($builder.result());" - } + // Only need to track null element index when result array's element is nullable. + val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { s""" - |if ($size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful try create array with " + $size + - | " bytes of data due to exceeding the limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} elements for GenericArrayData." + - | " $prettyName failed."); - |} - |$build + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; """.stripMargin + } else { + "" } - def buildResultArrayData(nonNullArrayDataBuild: String) = - if (dataType.asInstanceOf[ArrayType].containsNull) { - s""" - |if ($nullElementIndex < 0) { - | // result has no null element - | $nonNullArrayDataBuild - |} else { - | // result has null element - | $arrayDataBuilder - | $javaTypeName[] $array = $builder.result(); - | for (int $i = 0, $pos = 0; $pos < $size; $pos++) { - | if ($pos == $nullElementIndex) { - | ${ev.value}.setNullAt($pos); - | } else { - | $javaTypeName $value = $array[$i++]; - | ${ev.value}.$setter; - | } - | } - |} - """.stripMargin - } else { - nonNullArrayDataBuild - } - s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |boolean $foundNullElement = false; - |int $nullElementIndex = -1; + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables |int $size = 0; |$arrayBuilderClass $builder = | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); @@ -3988,10 +3911,10 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { | $arrayData $array = $arrays[$arrayDataIdx]; | for (int $i = 0; $i < $array.numElements(); $i++) { - | ${withArrayNullAssignment(arrayBody)} + | $processArray | } |} - |${buildResultArrayData(nonNullArrayDataBuild)} + |${buildResultArray(builder, ev.value, size, nullElementIndex)} """.stripMargin }) } else { From 93ec9ecf58db4c401eb695afe69cec0e0aea5ade Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 6 Aug 2018 16:34:47 +0100 Subject: [PATCH 3/4] rebase and update --- .../expressions/collectionOperations.scala | 6 ++--- .../spark/sql/DataFrameFunctionsSuite.scala | 24 +++++++++---------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 87a78fbd95b70..66d3a8d18e6ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3857,9 +3857,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike val openHashSet = classOf[OpenHashSet[_]].getName val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" val hashSet = ctx.freshName("hashSet") - val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName val arrayBuilderClass = s"$arrayBuilder$$of$ptName" - val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" def withArrayNullAssignment(body: String) = if (dataType.asInstanceOf[ArrayType].containsNull) { @@ -3905,8 +3904,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); |$declareNullTrackVariables |int $size = 0; - |$arrayBuilderClass $builder = - | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); + |$arrayBuilderClass $builder = new $arrayBuilderClass(); |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { | $arrayData $array = $arrays[$arrayDataIdx]; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 3c5831f33b23c..c04780db4e525 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1148,28 +1148,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df5.selectExpr("array_union(a, b)"), ans5) val df6 = Seq((null, Array("a"))).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df6.select(array_union($"a", $"b")) - } - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df6.selectExpr("array_union(a, b)") - } + }.getMessage.contains("data type mismatch")) val df7 = Seq((null, null)).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df7.select(array_union($"a", $"b")) - } - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df7.selectExpr("array_union(a, b)") - } + }.getMessage.contains("data type mismatch")) val df8 = Seq((Array(Array(1)), Array("a"))).toDF("a", "b") - intercept[AnalysisException] { + assert(intercept[AnalysisException] { df8.select(array_union($"a", $"b")) - } - intercept[AnalysisException] { + }.getMessage.contains("data type mismatch")) + assert(intercept[AnalysisException] { df8.selectExpr("array_union(a, b)") - } + }.getMessage.contains("data type mismatch")) } test("concat function - arrays") { From 26c1f8a586cc8bd6f01d5519c7c9fda47453fb08 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 6 Aug 2018 19:11:43 +0100 Subject: [PATCH 4/4] address review comment --- .../catalyst/expressions/collectionOperations.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 66d3a8d18e6ce..fbb182631eefa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3839,7 +3839,6 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") @@ -3905,9 +3904,9 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike |$declareNullTrackVariables |int $size = 0; |$arrayBuilderClass $builder = new $arrayBuilderClass(); - |$arrayData[] $arrays = new $arrayData[]{$array1, $array2}; + |ArrayData[] $arrays = new ArrayData[]{$array1, $array2}; |for (int $arrayDataIdx = 0; $arrayDataIdx < 2; $arrayDataIdx++) { - | $arrayData $array = $arrays[$arrayDataIdx]; + | ArrayData $array = $arrays[$arrayDataIdx]; | for (int $i = 0; $i < $array.numElements(); $i++) { | $processArray | } @@ -3918,7 +3917,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike } else { nullSafeCodeGen(ctx, ev, (array1, array2) => { val expr = ctx.addReferenceObj("arrayUnionExpr", this) - s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" }) } } @@ -4084,7 +4083,6 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") @@ -4198,7 +4196,7 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetL } else { nullSafeCodeGen(ctx, ev, (array1, array2) => { val expr = ctx.addReferenceObj("arrayIntersectExpr", this) - s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" }) } } @@ -4317,7 +4315,6 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") val value = ctx.freshName("value") val size = ctx.freshName("size") @@ -4420,7 +4417,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike } else { nullSafeCodeGen(ctx, ev, (array1, array2) => { val expr = ctx.addReferenceObj("arrayExceptExpr", this) - s"${ev.value} = ($arrayData)$expr.nullSafeEval($array1, $array2);" + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array1, $array2);" }) } }