Skip to content

Commit

Permalink
[SPARK-27134] array_distinct function does not work correctly with co…
Browse files Browse the repository at this point in the history
…lumns containing array of array
  • Loading branch information
dilipbiswal committed Mar 12, 2019
1 parent f1e223b commit 1e76669
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3088,6 +3088,7 @@ case class ArrayDistinct(child: Expression)

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)


override def dataType: DataType = child.dataType

@transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
Expand All @@ -3112,29 +3113,29 @@ case class ArrayDistinct(child: Expression)
(data: Array[AnyRef]) => new GenericArrayData(data.distinct.asInstanceOf[Array[Any]])
} else {
(data: Array[AnyRef]) => {
var foundNullElement = false
var pos = 0
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
var alreadyStoredNull = false
var found = false
for (i <- 0 until data.length) {
if (data(i) == null) {
if (!foundNullElement) {
foundNullElement = true
pos = pos + 1
if (data(i) != null) {
found = false
var j = 0;
while (!found && j < arrayBuffer.size) {
val va = arrayBuffer(j)
found = (va != null) && ordering.equiv(va, data(i))
j += 1
}
} else {
var j = 0
var done = false
while (j <= i && !done) {
if (data(j) != null && ordering.equiv(data(j), data(i))) {
done = true
}
j = j + 1
if (!found) {
arrayBuffer += data(i)
}
if (i == j - 1) {
pos = pos + 1
} else {
if (!alreadyStoredNull) {
arrayBuffer += data(i)
alreadyStoredNull = true
}
}
}
new GenericArrayData(data.slice(0, pos))
new GenericArrayData(arrayBuffer)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1393,9 +1393,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayType(ArrayType(IntegerType)))
val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1), null, null, Seq[Int](2, 1), null),
ArrayType(ArrayType(IntegerType)))
val c4 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](1, 2), Seq[Int](1, 2),
Seq[Int](3, 4), Seq[Int](4, 5)), ArrayType(ArrayType(IntegerType)))
checkEvaluation(ArrayDistinct(c0), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)))
checkEvaluation(ArrayDistinct(c1), Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c2), Seq[Seq[Int]](null, Seq[Int](2, 1)))
checkEvaluation(ArrayDistinct(c4), Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4),
Seq[Int](4, 5)))
}

test("Array Union") {
Expand Down

0 comments on commit 1e76669

Please sign in to comment.