diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 6cc5bc85bcfb4..f7955bc923395 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -3112,29 +3112,29 @@ case class ArrayDistinct(child: Expression) (data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]]) } else { (data: Array[AnyRef]) => { - var foundNullElement = false - var pos = 0 + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef] + var alreadyStoredNull = false for (i <- 0 until data.length) { - if (data(i) == null) { - if (!foundNullElement) { - foundNullElement = true - pos = pos + 1 - } - } else { + if (data(i) != null) { + var found = false var j = 0 - var done = false - while (j <= i && !done) { - if (data(j) != null && ordering.equiv(data(j), data(i))) { - done = true - } - j = j + 1 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + found = (va != null) && ordering.equiv(va, data(i)) + j += 1 } - if (i == j - 1) { - pos = pos + 1 + if (!found) { + arrayBuffer += data(i) + } + } else { + // De-duplicate the null values. + if (!alreadyStoredNull) { + arrayBuffer += data(i) + alreadyStoredNull = true } } } - new GenericArrayData(data.slice(0, pos)) + new GenericArrayData(arrayBuffer) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index c1c045938cad8..603073b40d7aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -1364,6 +1364,8 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(DoubleType)) val a7 = Literal.create(Seq(1.123f, 0.1234f, 1.121f, 1.123f, 1.1230f, 1.121f, 0.1234f), ArrayType(FloatType)) + val a8 = + Literal.create(Seq(2, 1, 2, 3, 4, 4, 5).map(_.toString.getBytes), ArrayType(BinaryType)) checkEvaluation(new ArrayDistinct(a0), Seq(2, 1, 3, 4, 5)) checkEvaluation(new ArrayDistinct(a1), Seq.empty[Integer]) @@ -1373,6 +1375,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(new ArrayDistinct(a5), Seq(true, false)) checkEvaluation(new ArrayDistinct(a6), Seq(1.123, 0.1234, 1.121)) checkEvaluation(new ArrayDistinct(a7), Seq(1.123f, 0.1234f, 1.121f)) + checkEvaluation(new ArrayDistinct(a8), Seq(2, 1, 3, 4, 5).map(_.toString.getBytes)) // complex data types val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2), @@ -1393,9 +1396,17 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(ArrayType(IntegerType))) val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null), ArrayType(ArrayType(IntegerType))) + val c3 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](1, 2), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](4, 5)), ArrayType(ArrayType(IntegerType))) + val c4 = Literal.create(Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](1, 2), + Seq[Int](3, 4), Seq[Int](4, 5), null), ArrayType(ArrayType(IntegerType))) checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4))) checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1))) checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1))) + checkEvaluation(ArrayDistinct(c3), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4), + Seq[Int](4, 5))) + checkEvaluation(ArrayDistinct(c4), Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](3, 4), + Seq[Int](4, 5))) } test("Array Union") {