[SPARK-23931][SQL] Adds arrays_zip function to sparksql #21045
Closed
36 commits
7bf45dd Adds zip function to sparksql (DylanGuedes)
99848fe Changes zip construction (DylanGuedes)
27b0bc2 Changes tests and uses builtin namespace in pyspark (DylanGuedes)
93826b6 fixes examples string and uses struct instead of arrays (DylanGuedes)
a7e29f6 working pyspark zip_lists (DylanGuedes)
7130fec Fixes java version when arrays have different lengths (DylanGuedes)
d552216 remove unused variables (DylanGuedes)
1fecef4 rename zip_lists to zip (DylanGuedes)
f71151a adds expression tests and uses strip margin syntax (DylanGuedes)
6b4bc94 Adds variable number of inputs to zip function (DylanGuedes)
1549928 uses foldleft instead of while for iterating (DylanGuedes)
9f7bba1 rewritten some notation (DylanGuedes)
3ba2b4f fix dogencode generation (DylanGuedes)
3a59201 Adds new tests, uses lazy val and split calls (DylanGuedes)
6462fa8 uses splitFunction (DylanGuedes)
8b1eb7c move arraytypes to private member (DylanGuedes)
2bfba80 adds binary and array of array tests (DylanGuedes)
c3b062c uses stored array types names (DylanGuedes)
d9b95c4 split input function using ctxsplitexpression (DylanGuedes)
26bbf66 uses splitexpression for inputs (DylanGuedes)
d9ad04d Refactor cases, add new tests with empty seq, check size of array (DylanGuedes)
f29ee1c Check empty seq as input (DylanGuedes)
c58d09c Uses switch instead of if (DylanGuedes)
38fa996 refactor switch and else methods (DylanGuedes)
5b3066b uses if instead of switch (DylanGuedes)
759a4d4 Not using storedarrtype anymore (DylanGuedes)
68e69db split between empty and nonempty codegen (DylanGuedes)
12b3835 remove ternary if (DylanGuedes)
643cb9b Fixes null values evaluation and adds back tests (DylanGuedes)
5876082 move to else (DylanGuedes)
0223960 remove unused lines (DylanGuedes)
2b88387 use zip alias (DylanGuedes)
bbc20ee using same docs for all apis (DylanGuedes)
8d3a838 adds transient to method (DylanGuedes)
d8f3dea rename zip function to arrays_zip (DylanGuedes)
3d68ea9 adds pretty_name for arrays_zip (DylanGuedes)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,6 +128,172 @@ case class MapKeys(child: Expression) | |
override def prettyName: String = "map_keys" | ||
} | ||
|
||
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all | ||
N-th values of input arrays. | ||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4)); | ||
[[1, 2], [2, 3], [3, 4]] | ||
> SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4)); | ||
[[1, 2, 3], [2, 3, 4]] | ||
""", | ||
since = "2.4.0") | ||
case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes { | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType) | ||
|
||
override def dataType: DataType = ArrayType(mountSchema) | ||
|
||
override def nullable: Boolean = children.exists(_.nullable) | ||
|
||
private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType]) | ||
|
||
private lazy val arrayElementTypes = arrayTypes.map(_.elementType) | ||
|
||
@transient private lazy val mountSchema: StructType = { | ||
val fields = children.zip(arrayElementTypes).zipWithIndex.map { | ||
case ((expr: NamedExpression, elementType), _) => | ||
StructField(expr.name, elementType, nullable = true) | ||
case ((_, elementType), idx) => | ||
StructField(idx.toString, elementType, nullable = true) | ||
} | ||
StructType(fields) | ||
} | ||
|
||
@transient lazy val numberOfArrays: Int = children.length | ||
|
||
@transient lazy val genericArrayData = classOf[GenericArrayData].getName | ||
|
||
def emptyInputGenCode(ev: ExprCode): ExprCode = { | ||
ev.copy(code""" | ||
|${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]); | ||
|boolean ${ev.isNull} = false; | ||
""".stripMargin) | ||
} | ||
|
||
def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val genericInternalRow = classOf[GenericInternalRow].getName | ||
val arrVals = ctx.freshName("arrVals") | ||
val biggestCardinality = ctx.freshName("biggestCardinality") | ||
|
||
val currentRow = ctx.freshName("currentRow") | ||
val j = ctx.freshName("j") | ||
val i = ctx.freshName("i") | ||
val args = ctx.freshName("args") | ||
|
||
val evals = children.map(_.genCode(ctx)) | ||
val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) => | ||
s""" | ||
|if ($biggestCardinality != -1) { | ||
| ${eval.code} | ||
| if (!${eval.isNull}) { | ||
| $arrVals[$index] = ${eval.value}; | ||
| $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements()); | ||
| } else { | ||
| $biggestCardinality = -1; | ||
| } | ||
|} | ||
""".stripMargin | ||
} | ||
|
||
val splittedGetValuesAndCardinalities = ctx.splitExpressions( | ||
expressions = getValuesAndCardinalities, | ||
funcName = "getValuesAndCardinalities", | ||
returnType = "int", | ||
makeSplitFunction = body => | ||
s""" | ||
|$body | ||
|return $biggestCardinality; | ||
""".stripMargin, | ||
foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), | ||
arguments = | ||
("ArrayData[]", arrVals) :: | ||
("int", biggestCardinality) :: Nil) | ||
|
||
val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) => | ||
val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i) | ||
s""" | ||
|if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) { | ||
| $currentRow[$idx] = $g; | ||
|} else { | ||
| $currentRow[$idx] = null; | ||
|} | ||
""".stripMargin | ||
} | ||
|
||
val getValueForTypeSplitted = ctx.splitExpressions( | ||
expressions = getValueForType, | ||
funcName = "extractValue", | ||
arguments = | ||
("int", i) :: | ||
("Object[]", currentRow) :: | ||
("ArrayData[]", arrVals) :: Nil) | ||
|
||
val initVariables = s""" | ||
|ArrayData[] $arrVals = new ArrayData[$numberOfArrays]; | ||
|int $biggestCardinality = 0; | ||
|${CodeGenerator.javaType(dataType)} ${ev.value} = null; | ||
""".stripMargin | ||
|
||
ev.copy(code""" | ||
|$initVariables | ||
|$splittedGetValuesAndCardinalities | ||
|boolean ${ev.isNull} = $biggestCardinality == -1; | ||
|if (!${ev.isNull}) { | ||
| Object[] $args = new Object[$biggestCardinality]; | ||
| for (int $i = 0; $i < $biggestCardinality; $i ++) { | ||
| Object[] $currentRow = new Object[$numberOfArrays]; | ||
| $getValueForTypeSplitted | ||
| $args[$i] = new $genericInternalRow($currentRow); | ||
| } | ||
| ${ev.value} = new $genericArrayData($args); | ||
|} | ||
""".stripMargin) | ||
} | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
if (numberOfArrays == 0) { | ||
emptyInputGenCode(ev) | ||
} else { | ||
nonEmptyInputGenCode(ctx, ev) | ||
} | ||
} | ||
|
||
override def eval(input: InternalRow): Any = { | ||
val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData]) | ||
if (inputArrays.contains(null)) { | ||
null | ||
} else { | ||
val biggestCardinality = if (inputArrays.isEmpty) { | ||
0 | ||
} else { | ||
inputArrays.map(_.numElements()).max | ||
} | ||
|
||
val result = new Array[InternalRow](biggestCardinality) | ||
val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex | ||
|
||
for (i <- 0 until biggestCardinality) { | ||
val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) => | ||
if (i < arr.numElements() && !arr.isNullAt(i)) { | ||
arr.get(i, arrayElementTypes(index)) | ||
} else { | ||
null | ||
} | ||
} | ||
|
||
result(i) = InternalRow.apply(currentLayer: _*) | ||
} | ||
new GenericArrayData(result) | ||
} | ||
} | ||
We need
Done! |
||
|
||
override def prettyName: String = "arrays_zip" | ||
} | ||
|
||
/** | ||
* Returns an unordered array containing the values of the map. | ||
*/ | ||
|
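For readers who want to check the zip semantics without a Spark build, the interpreted `eval` path in the diff above can be approximated in plain Scala. This is only a sketch under stated assumptions, not Spark's implementation: `ArraysZipModel` and `arraysZip` are hypothetical names, `None` stands in for SQL NULL, and tuples of Catalyst rows are replaced by ordinary `Seq`s.

```scala
// Plain-Scala model of the arrays_zip semantics (hypothetical helper, not Spark code).
// Outer Option: a missing (NULL) input array. Inner Option: a NULL element.
object ArraysZipModel {
  def arraysZip(arrays: Seq[Option[Seq[Option[Any]]]]): Option[Seq[Seq[Option[Any]]]] = {
    if (arrays.exists(_.isEmpty)) {
      // Any NULL input array makes the whole result NULL, mirroring
      // `if (inputArrays.contains(null)) null` in eval.
      None
    } else {
      val present = arrays.flatten
      // With no inputs at all the result is an empty array, not NULL.
      val longest = if (present.isEmpty) 0 else present.map(_.length).max
      Some((0 until longest).map { i =>
        // Shorter arrays are padded with NULL fields.
        present.map(arr => if (i < arr.length) arr(i) else None)
      })
    }
  }
}
```

Zipping `[1, 2, 3]` with `["a"]` in this model yields three rows, the last two carrying `None` in the second field, while passing `None` as one of the inputs yields `None` for the whole result.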
Just a quick follow-up question: under what circumstances can the output array contain `null` elements? Shouldn't the output dataType be `ArrayType(mountSchema, false)`?

In the first test case I zipped `Seq(9001, 9002, 9003, null)` with `Seq(null, 1L, null, 4L, 11L)` and expected the result to be `Seq(Seq(9001, null), Seq(9002, 1L), ..., Seq(null, 11L))`, for instance. I tried to define the nullability (does this word exist? haha) of the output at runtime, but I concluded that it was not possible, since I can't eval every result before defining the schema.
What do you think?

I understand that fields of the nested `struct` can be `null`, but can you give me an example of an input that would lead to something like `Seq(null, Seq(9002, 1L))`?

Hmmm, you are correct then; I don't think such a scenario could happen (correctly, at least). Does that mean the dataType should always reject null values?

Yeah, it seems the struct that is the element of the array is never null, so the data type would be `ArrayType(mountSchema, containsNull = false)`.

OK, I will fix it as part of #21352.

The struct can be null if any of the input elements is null, IIUC. So probably `ArrayType(mountSchema, containsNull = children.exists(_.nullable))`?

In that case, the array itself will be null, and `def nullable` is already `children.exists(_.nullable)`.

yes, you're right, sorry!
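
To make the conclusion of this thread concrete, here is a small plain-Scala sketch of the test case mentioned above. It is an illustration only: `ZipNullability` and `zip2` are hypothetical names, `None` stands in for SQL NULL, and a tuple stands in for the output struct. It suggests why `containsNull = false` is safe for the output array: every position gets a struct, and only the struct's fields can be null.

```scala
// Hypothetical two-array model of the test case from the thread (not Spark code).
object ZipNullability {
  // Zips two sequences of nullable elements, padding the shorter one with None.
  def zip2(a: Seq[Option[Any]], b: Seq[Option[Any]]): Seq[(Option[Any], Option[Any])] = {
    val n = math.max(a.length, b.length)
    // lift(i) is None past the end of a sequence; flatten merges that with a NULL element.
    (0 until n).map(i => (a.lift(i).flatten, b.lift(i).flatten))
  }

  // Seq(9001, 9002, 9003, null) zipped with Seq(null, 1L, null, 4L, 11L)
  val rows: Seq[(Option[Any], Option[Any])] = zip2(
    Seq(Some(9001), Some(9002), Some(9003), None),
    Seq(None, Some(1L), None, Some(4L), Some(11L)))
}
```

In this model `rows` has five entries, from `(Some(9001), None)` through `(None, Some(11L))`, and none of the entries is itself null. The only way to get a null at the array level is a null input array, which is already covered by `nullable`.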