-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23931][SQL] Adds arrays_zip function to sparksql #21045
Changes from 35 commits
7bf45dd
99848fe
27b0bc2
93826b6
a7e29f6
7130fec
d552216
1fecef4
f71151a
6b4bc94
1549928
9f7bba1
3ba2b4f
3a59201
6462fa8
8b1eb7c
2bfba80
c3b062c
d9b95c4
26bbf66
d9ad04d
f29ee1c
c58d09c
38fa996
5b3066b
759a4d4
68e69db
12b3835
643cb9b
5876082
0223960
2b88387
bbc20ee
8d3a838
d8f3dea
3d68ea9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,6 +128,170 @@ case class MapKeys(child: Expression) | |
override def prettyName: String = "map_keys" | ||
} | ||
|
||
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all | ||
N-th values of input arrays. | ||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); | ||
[[1, 2], [2, 3], [3, 4]] | ||
> SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); | ||
[[1, 2, 3], [2, 3, 4]] | ||
""", | ||
since = "2.4.0") | ||
case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { | ||
|
||
/** Expected input types: one `ArrayType` entry for each child expression. */
override def inputTypes: Seq[AbstractDataType] = children.map(_ => ArrayType)
|
||
/**
 * Result type: an array of structs whose fields are described by `mountSchema`.
 *
 * The array never holds null elements: both evaluation paths create a row for every
 * position of the result, and missing or null input values only become null *fields*
 * of that row. A null input array makes the whole result null, which is reported by
 * `nullable`, not by the element nullability. Hence `containsNull = false`.
 */
override def dataType: DataType = ArrayType(mountSchema, containsNull = false)
|
||
/** The result is nullable as soon as at least one child expression is nullable. */
override def nullable: Boolean = children.exists(child => child.nullable)
|
||
// Data type of every child, cast to ArrayType, in child order.
private lazy val arrayTypes = children.map { child => child.dataType.asInstanceOf[ArrayType] }
|
||
// Element type of every input array, in child order.
private lazy val arrayElementTypes = arrayTypes.map(arrayType => arrayType.elementType)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we have more than one |
||
|
||
/**
 * Schema of the structs stored in the result array: one field per child, in child order.
 *
 * A field takes the name of its child when the child is a `NamedExpression`; otherwise it
 * is named after the child's zero-based position. The field type is the element type of
 * the corresponding input array, and every field is declared nullable.
 */
@transient private lazy val mountSchema: StructType = { | ||
val fields = children.zip(arrayElementTypes).zipWithIndex.map { | ||
case ((expr: NamedExpression, elementType), _) => | ||
StructField(expr.name, elementType, nullable = true) | ||
case ((_, elementType), idx) => | ||
StructField(idx.toString, elementType, nullable = true) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to make val fields = arrayTypes.zipWithIndex.map { case (arr, idx) =>
StructField( ... )
} ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, thank you. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about: val fields = children.zip(arrayElementTypes).zipWithIndex.map {
case ((expr: NamedExpression, elementType), _) =>
StructField(expr.name, elementType, nullable = true)
case ((_, elementType), idx) =>
StructField(s"$idx", elementType, nullable = true)
} ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Way better, thanks! |
||
StructType(fields) | ||
} | ||
|
||
/** Number of input arrays, i.e. the number of child expressions. */
@transient lazy val numberOfArrays: Int = children.size
|
||
/** Class name of `GenericArrayData`, spliced into the generated Java code. */
@transient lazy val genericArrayData: String = classOf[GenericArrayData].getName
|
||
/**
 * Generated code for the zero-argument case: the result is an empty, non-null array.
 *
 * @param ev the expression code whose value and null flag are declared by the snippet
 * @return a copy of `ev` carrying the generated snippet
 */
def emptyInputGenCode(ev: ExprCode): ExprCode = {
  val resultType = CodeGenerator.javaType(dataType)
  val emptyZip = code"""
    |$resultType ${ev.value} = new $genericArrayData(new Object[0]);
    |boolean ${ev.isNull} = false;
  """.stripMargin
  ev.copy(code = emptyZip)
}
|
||
/**
 * Builds the generated Java code for the case with at least one input array.
 *
 * The generated code works in two phases:
 *  1. Each child is evaluated in turn. A non-null result is stored in an `ArrayData[]` and
 *     the largest `numElements()` seen so far is tracked. As soon as a child evaluates to
 *     null the tracker is set to -1, which skips the remaining children and marks the
 *     whole result as null.
 *  2. If the result is not null, one `GenericInternalRow` is created per index up to the
 *     largest cardinality. A field is null when the corresponding array is shorter than
 *     the index or holds a null at that index. The rows are wrapped in a `GenericArrayData`.
 *
 * Both the per-child and the per-field snippets go through `ctx.splitExpressions`. The
 * per-child helper takes the cardinality tracker as an `int` argument and returns it, and
 * the fold function reassigns the tracker from each call's result.
 *
 * NOTE(review): the fresh name `j` is not referenced anywhere in the visible code.
 *
 * @param ctx codegen context used for fresh names, child code and function splitting
 * @param ev  the expression code whose value and null flag are declared by the snippet
 * @return a copy of `ev` carrying the generated snippet
 */
def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val genericInternalRow = classOf[GenericInternalRow].getName | ||
val arrVals = ctx.freshName("arrVals") | ||
val biggestCardinality = ctx.freshName("biggestCardinality") | ||
|
||
val currentRow = ctx.freshName("currentRow") | ||
val j = ctx.freshName("j") | ||
val i = ctx.freshName("i") | ||
val args = ctx.freshName("args") | ||
|
||
val evals = children.map(_.genCode(ctx)) | ||
val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) => | ||
s""" | ||
|if ($biggestCardinality != -1) { | ||
| ${eval.code} | ||
| if (!${eval.isNull}) { | ||
| $arrVals[$index] = ${eval.value}; | ||
| $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements()); | ||
| } else { | ||
| $biggestCardinality = -1; | ||
| } | ||
|} | ||
""".stripMargin | ||
} | ||
|
||
val splittedGetValuesAndCardinalities = ctx.splitExpressions( | ||
expressions = getValuesAndCardinalities, | ||
funcName = "getValuesAndCardinalities", | ||
returnType = "int", | ||
makeSplitFunction = body => | ||
s""" | ||
|$body | ||
|return $biggestCardinality; | ||
""".stripMargin, | ||
foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), | ||
arguments = | ||
("ArrayData[]", arrVals) :: | ||
("int", biggestCardinality) :: Nil) | ||
|
||
val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) => | ||
val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i) | ||
s""" | ||
|if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) { | ||
| $currentRow[$idx] = $g; | ||
|} else { | ||
| $currentRow[$idx] = null; | ||
|} | ||
""".stripMargin | ||
} | ||
|
||
val getValueForTypeSplitted = ctx.splitExpressions( | ||
expressions = getValueForType, | ||
funcName = "extractValue", | ||
arguments = | ||
("int", i) :: | ||
("Object[]", currentRow) :: | ||
("ArrayData[]", arrVals) :: Nil) | ||
|
||
val initVariables = s""" | ||
|ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; | ||
|int $biggestCardinality = 0; | ||
|${CodeGenerator.javaType(dataType)} ${ev.value} = null; | ||
""".stripMargin | ||
|
||
ev.copy(code""" | ||
|$initVariables | ||
|$splittedGetValuesAndCardinalities | ||
|boolean ${ev.isNull} = $biggestCardinality == -1; | ||
|if (!${ev.isNull}) { | ||
| Object[] $args = new Object[$biggestCardinality]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We usually don't set a value if the result is null. |
||
| for (int $i = 0; $i < $biggestCardinality; $i ++) { | ||
| Object[] $currentRow = new Object[$numberOfArrays]; | ||
| $getValueForTypeSplitted | ||
| $args[$i] = new $genericInternalRow($currentRow); | ||
| } | ||
| ${ev.value} = new $genericArrayData($args); | ||
|} | ||
""".stripMargin) | ||
} | ||
|
||
/**
 * Chooses the code generator: the trivial empty-array snippet when there are no input
 * arrays, the full zipping code otherwise.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = numberOfArrays match {
  case 0 => emptyInputGenCode(ev)
  case _ => nonEmptyInputGenCode(ctx, ev)
}
|
||
/**
 * Interpreted evaluation path.
 *
 * Evaluates every child against `input`. If any child yields null the result is null.
 * Otherwise the result is a `GenericArrayData` with one row per index up to the largest
 * `numElements()` among the inputs (no rows when there are no children). A field of a row
 * is null when the corresponding array is shorter than the index or holds a null there.
 */
override def eval(input: InternalRow): Any = { | ||
val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) | ||
if (inputArrays.contains(null)) { | ||
null | ||
} else { | ||
val biggestCardinality = if (inputArrays.isEmpty) { | ||
0 | ||
} else { | ||
inputArrays.map(_.numElements()).max | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can compute the val biggestCardinality = if (inputArrays.isEmpty) {
0
} else {
inputArrays.map(_.numElements()).max
} |
||
|
||
val result = new Array[InternalRow](biggestCardinality) | ||
val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex | ||
|
||
for (i <- 0 until biggestCardinality) { | ||
val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) => | ||
if (i < arr.numElements() && !arr.isNullAt(i)) { | ||
arr.get(i, arrayElementTypes(index)) | ||
} else { | ||
null | ||
} | ||
} | ||
|
||
result(i) = InternalRow.apply(currentLayer: _*) | ||
} | ||
new GenericArrayData(result) | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
} | ||
|
||
/** | ||
* Returns an unordered array containing the values of the map. | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import org.apache.spark.SparkFunSuite | ||
import org.apache.spark.sql.Row | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.types._ | ||
|
||
|
@@ -315,6 +316,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
Some(Literal.create(null, StringType))), null) | ||
} | ||
|
||
// Exercises ArraysZip through checkEvaluation. Covered cases: pairs of arrays with
// different element types and lengths, NullType arrays, nested-array and binary element
// types, a 1001-element array zipped with a 4-element one, 1001 single-element arrays
// zipped together with a 4-element array, a null array input (expects null), and no
// inputs at all (expects an empty result).
test("ArraysZip") { | ||
val literals = Seq( | ||
Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)), | ||
Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)), | ||
Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)), | ||
Literal.create(Seq("a", null, "c"), ArrayType(StringType)), | ||
Literal.create(Seq(null, false, true), ArrayType(BooleanType)), | ||
Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)), | ||
Literal.create(Seq(), ArrayType(NullType)), | ||
Literal.create(Seq(null), ArrayType(NullType)), | ||
Literal.create(Seq(192.toByte), ArrayType(ByteType)), | ||
Literal.create( | ||
Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))), | ||
Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType)) | ||
) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(1))), | ||
List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: why do you use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that at some point I was using |
||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(2))), | ||
List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(3))), | ||
List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(4))), | ||
List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(5))), | ||
List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(6))), | ||
List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(7))), | ||
List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), literals(1), literals(2), literals(3))), | ||
List( | ||
Row(9001, null, -1, "a"), | ||
Row(9002, 1L, -3, null), | ||
Row(9003, null, 900, "c"), | ||
Row(null, 4L, null, null), | ||
Row(null, 11L, null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))), | ||
List( | ||
Row(null, 1.1, null, null, 192.toByte), | ||
Row(false, null, null, null, null), | ||
Row(true, 1.3, null, null, null), | ||
Row(null, null, null, null, null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(9), literals(0))), | ||
List( | ||
Row(List(1, 2, 3), 9001), | ||
Row(null, 9002), | ||
Row(List(4, 5), 9003), | ||
Row(List(1, null, 3), null))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(7), literals(10))), | ||
List(Row(null, Array[Byte](1.toByte, 5.toByte)))) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add some tests with many input arrays? eg.100 or 1000 if needed in order to force the splitExpression to actually split the generated code in order to have test coverage for that? you can check if splitExpression is splitting the code in debug mode to be sure that it happens. Thanks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I checked and looks like the functions are being correctly splitted. |
||
|
||
val longLiteral = | ||
Literal.create((0 to 1000).toSeq, ArrayType(IntegerType)) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), longLiteral)), | ||
List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++ | ||
(3 to 1000).map { Row(null, _) }.toList) | ||
|
||
val manyLiterals = (0 to 1000).map { _ => | ||
Literal.create(Seq(1), ArrayType(IntegerType)) | ||
}.toSeq | ||
|
||
val numbers = List( | ||
Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*), | ||
Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*), | ||
Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*), | ||
Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*)) | ||
checkEvaluation(ArraysZip(Seq(literals(0)) ++ manyLiterals), | ||
List(numbers(0), numbers(1), numbers(2), numbers(3))) | ||
|
||
checkEvaluation(ArraysZip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null) | ||
checkEvaluation(ArraysZip(Seq()), List()) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a case for something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also what if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried to test the case with Should I change something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant the test like CollectionExpressionsSuite.scala#L522-L524. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried, but looks like I can't create an empty genericarray. I'm pushing a code with the tests that cover these scenarios commented so maybe anyone could give me suggestions while I look for another solution. |
||
|
||
test("Array Min") { | ||
checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11) | ||
checkEvaluation( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a quick follow-up question: under what circumstances can the output array contain `null` elements? Shouldn't the output dataType be `ArrayType(mountSchema, false)`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the first test case I zipped `Seq(9001, 9002, 9003, null)` with `Seq(null, 1L, null, 4L, 11L)`, and expected the result to be `Seq(Seq(9001, null), Seq(9002, 1L), ..., Seq(null, 11L))`, for instance.
I tried to define the nullability (does this word exist? haha) of the output at runtime, but I thought that it was not possible, since I can't evaluate every result before defining the schema.
What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand that fields of the nested `struct` can be `null`, but can you give me an example of an input that would lead to something like `Seq(null, Seq(9002, 1L))`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, you are correct then; I don't think that such a scenario could happen (correctly, at least). Does that mean that the dataType should always reject null values?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it seems the struct which is the element of the array is never null, so the data type would be `ArrayType(mountSchema, containsNull = false)`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I will fix it as a part of #21352.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The struct can be null if any of the input elements is null, IIUC. So probably `ArrayType(mountSchema, containsNull = children.exists(_.nullable))`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case, the array itself will be null, and `def nullable` is already `children.exists(_.nullable)`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, you're right, sorry!