-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23935][SQL] Adding map_entries function #21236
Changes from 6 commits
086e223
b9e2409
4739977
d05ad9b
6aa90ef
56ff20a
1bd0d5e
baa61e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder | |
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.unsafe.Platform | ||
import org.apache.spark.unsafe.array.ByteArrayMethods | ||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String} | ||
|
||
|
@@ -118,6 +119,162 @@ case class MapValues(child: Expression) | |
override def prettyName: String = "map_values" | ||
} | ||
|
||
/** | ||
* Returns an unordered array of all entries in the given map. | ||
*/ | ||
@ExpressionDescription( | ||
usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(map(1, 'a', 2, 'b')); | ||
[(1,"a"),(2,"b")] | ||
""", | ||
since = "2.4.0") | ||
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {

  /** The single argument must be a map. */
  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

  /** The child's data type, narrowed to [[MapType]] (guaranteed by `inputTypes`). */
  lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]

  /**
   * The result is an array of `struct<key, value>` elements. The array itself never contains
   * null elements, keys are never null, and values are nullable iff the map's values are.
   */
  override def dataType: DataType = {
    ArrayType(
      StructType(
        StructField("key", childDataType.keyType, nullable = false) ::
        StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
        Nil),
      containsNull = false)
  }

  /**
   * Interpreted evaluation: builds one [[GenericInternalRow]] of (key, value) per map entry
   * and wraps the rows in a [[GenericArrayData]].
   *
   * @param input the non-null evaluated child, expected to be a [[MapData]]
   * @return a [[GenericArrayData]] holding one row per map entry
   */
  override protected def nullSafeEval(input: Any): Any = {
    val childMap = input.asInstanceOf[MapData]
    val keys = childMap.keyArray()
    val values = childMap.valueArray()
    val length = childMap.numElements()
    val resultData = new Array[AnyRef](length)
    // A while loop is kept on purpose: this runs once per input row.
    var i = 0
    while (i < length) {
      val key = keys.get(i, childDataType.keyType)
      val value = values.get(i, childDataType.valueType)
      resultData(i) = new GenericInternalRow(Array[Any](key, value))
      i += 1
    }
    new GenericArrayData(resultData)
  }

  /**
   * Generates Java code for this expression. When both key and value types are primitive the
   * entries are written into an `UnsafeArrayData`; otherwise generic rows are created.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val numElements = ctx.freshName("numElements")
      val keys = ctx.freshName("keys")
      val values = ctx.freshName("values")
      val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
      val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
      val code = if (isKeyPrimitive && isValuePrimitive) {
        genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
      } else {
        genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
      }
      s"""
         |final int $numElements = $c.numElements();
         |final ArrayData $keys = $c.keyArray();
         |final ArrayData $values = $c.valueArray();
         |$code
       """.stripMargin
    })
  }

  /** Java expression reading the key at loop index `z` from the array variable `varName`. */
  private def getKey(varName: String): String = {
    CodeGenerator.getValue(varName, childDataType.keyType, "z")
  }

  /** Java expression reading the value at loop index `z` from the array variable `varName`. */
  private def getValue(varName: String): String = {
    CodeGenerator.getValue(varName, childDataType.valueType, "z")
  }

  /**
   * Generates code that writes every (key, value) pair as a two-field `UnsafeRow` into one
   * byte array backing an `UnsafeArrayData`. Used only when key and value are both primitive.
   *
   * @param ctx         codegen context, used for fresh variable names
   * @param keys        name of the Java variable holding the map's key array
   * @param values      name of the Java variable holding the map's value array
   * @param arrayData   name of the Java variable the resulting array is assigned to
   * @param numElements name of the Java variable holding the number of map entries
   * @return the generated Java source
   */
  private def genCodeForPrimitiveElements(
      ctx: CodegenContext,
      keys: String,
      values: String,
      arrayData: String,
      numElements: String): String = {
    val byteArraySize = ctx.freshName("byteArraySize")
    val data = ctx.freshName("byteArray")
    val unsafeRow = ctx.freshName("unsafeRow")
    val unsafeArrayData = ctx.freshName("unsafeArrayData")
    val structsOffset = ctx.freshName("structsOffset")
    val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
    val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    // Size in bytes of one word of the unsafe format. This is deliberately not
    // LongType.defaultSize: the value is the same, but it is a layout word size rather than
    // the size of a `long` element.
    val wordSize = 8
    // Each struct is its null bit set followed by one word per field (key and value).
    val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
    val structSizeAsLong = structSize + "L"
    val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
    // The setter for field 1 must be chosen from the value type, not the key type; otherwise
    // maps whose key and value primitive types differ get the wrong UnsafeRow setter.
    val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)

    val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
    val valueAssignmentChecked = if (childDataType.valueContainsNull) {
      s"""
         |if ($values.isNullAt(z)) {
         |  $unsafeRow.setNullAt(1);
         |} else {
         |  $valueAssignment
         |}
       """.stripMargin
    } else {
      valueAssignment
    }

    // If the backing byte array would exceed the maximum array length, the generated code
    // falls back to the generic (object-based) representation.
    // NOTE(review): a reviewer asked whether this fallback idiom should also be used for other
    // array functions; the question was left open in the review thread.
    s"""
       |final long $byteArraySize = $calculateArraySize($numElements, ${wordSize + structSize});
       |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
       |  ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
       |} else {
       |  final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
       |  final byte[] $data = new byte[(int)$byteArraySize];
       |  UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
       |  Platform.putLong($data, $baseOffset, $numElements);
       |  $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
       |  UnsafeRow $unsafeRow = new UnsafeRow(2);
       |  for (int z = 0; z < $numElements; z++) {
       |    long offset = $structsOffset + z * $structSizeAsLong;
       |    $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
       |    $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
       |    $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
       |    $valueAssignmentChecked
       |  }
       |  $arrayData = $unsafeArrayData;
       |}
     """.stripMargin
  }

  /**
   * Generates code that creates one `GenericInternalRow` per (key, value) pair and wraps the
   * rows in a `GenericArrayData`. Works for any key/value types.
   *
   * @param ctx         codegen context, used for fresh variable names
   * @param keys        name of the Java variable holding the map's key array
   * @param values      name of the Java variable holding the map's value array
   * @param arrayData   name of the Java variable the resulting array is assigned to
   * @param numElements name of the Java variable holding the number of map entries
   * @return the generated Java source
   */
  private def genCodeForAnyElements(
      ctx: CodegenContext,
      keys: String,
      values: String,
      arrayData: String,
      numElements: String): String = {
    val genericArrayClass = classOf[GenericArrayData].getName
    val rowClass = classOf[GenericInternalRow].getName
    val data = ctx.freshName("internalRowArray")

    val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
    // A primitive getter cannot yield null, so nullable primitive values need an explicit
    // null check and are boxed through an Object cast.
    val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
      s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
    } else {
      getValue(values)
    }

    s"""
       |final Object[] $data = new Object[$numElements];
       |for (int z = 0; z < $numElements; z++) {
       |  $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
       |}
       |$arrayData = new $genericArrayClass($data);
     """.stripMargin
  }

  override def prettyName: String = "map_entries"
}
|
||
/** | ||
* Common base class for [[SortArray]] and [[ArraySort]]. | ||
*/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { | |
if (expected.isNaN) result.isNaN else expected == result | ||
case (result: Float, expected: Float) => | ||
if (expected.isNaN) result.isNaN else expected == result | ||
case (result: UnsafeRow, expected: GenericInternalRow) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mn-mikke I was just looking over compiler warnings, and noticed it claims this case is never triggered. I think it's because it would always first match the (InternalRow, InternalRow) case above. Should it go before that then? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Roger that, looks like Wenchen just did so. Thanks! |
||
val structType = exprDataType.asInstanceOf[StructType] | ||
result.toSeq(structType) == expected.toSeq(structType) | ||
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) | ||
case _ => | ||
result == expected | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering whether it is right to use `longSize` here? I know the value is `8` and is the same as the word size, but I feel like the meaning is different. cc @gatorsmile @cloud-fan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ueshin Really good question. I'm eager to learn about the true purpose of the `DataType.defaultSize` function. Currently, it's used with this meaning in more places (e.g. `GenArrayData.genCodeToCreateArrayData` and `CodeGenerator.createUnsafeArray`). What about using `Long.BYTES` from Java 8 instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO, `8` is the better choice since it is not related to an element size of `long`. To my best guess, it would be best to define a new constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kiszk Thanks for your suggestion, but it seems to me that `LongType.defaultSize` could be used in this case. It seems that the purpose of `defaultSize` is not only the calculation of estimated data size in statistics. `GenerateUnsafeProjection.writeArrayToBuffer`, `InterpretedUnsafeProjection.getElementSize` and other parts utilize `defaultSize` in the same way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not for the element size of arrays. I agree with @kiszk to use `8`. Maybe we need to add a constant to represent the word size in `UnsafeRow` or somewhere in a future PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh OK, I misunderstood the comments. Thanks guys!