diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 29a1411241cf6..469b0e60cc9a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -62,6 +62,8 @@ */ public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { + public static final int WORD_SIZE = 8; + ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dda525294259..d382d9aace109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -764,6 +764,40 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. The generated code executes + * a provided fallback when the size of backing array would exceed the array size limit. + * @param arrayName a name of the array to create + * @param numElements a piece of code representing the number of elements the array should contain + * @param elementSize a size of an element in bytes + * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] + * and getting the backing array as a parameter + * @param fallbackCode a piece of code executed when the array size limit is exceeded + */ + def createUnsafeArrayWithFallback( + arrayName: String, + numElements: String, + elementSize: Int, + bodyCode: String => String, + fallbackCode: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + s""" + |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | $elementSize); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | $fallbackCode + |} else { + | final byte[] $arrayBytes = new byte[(int)$arraySize]; + | UnsafeArrayData $arrayName = new UnsafeArrayData(); + | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + | ${bodyCode(arrayBytes)} + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5ca67e6eb767b..baeefae570997 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -195,17 +195,14 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp values: String, arrayData: String, numElements: String): String = { - val byteArraySize = ctx.freshName("byteArraySize") - val data = ctx.freshName("byteArray") val unsafeRow = ctx.freshName("unsafeRow") val unsafeArrayData = ctx.freshName("unsafeArrayData") val structsOffset = ctx.freshName("structsOffset") - val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray" val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" val baseOffset = Platform.BYTE_ARRAY_OFFSET - val longSize = LongType.defaultSize - val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2 + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 val structSizeAsLong = structSize + "L" val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) @@ -223,27 +220,26 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp valueAssignment } - s""" - |final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize}); - |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)} - |} else { - | final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize; - | final byte[] $data = new byte[(int)$byteArraySize]; - | UnsafeArrayData $unsafeArrayData = new UnsafeArrayData(); - | Platform.putLong($data, $baseOffset, $numElements); - | $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize); - | UnsafeRow $unsafeRow = new UnsafeRow(2); - | for (int z = 0; z < $numElements; z++) { - | long offset = $structsOffset + z * $structSizeAsLong; - | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); - | $unsafeRow.pointTo($data, $baseOffset + offset, $structSize); - | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); - | $valueAssignmentChecked - | } - | $arrayData = $unsafeArrayData; - |} - """.stripMargin + val assignmentLoop = (byteArray: String) => + s""" + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSizeAsLong; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $valueAssignmentChecked + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + + ctx.createUnsafeArrayWithFallback( + unsafeArrayData, + numElements, + structSize + wordSize, + assignmentLoop, + genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) } private def genCodeForAnyElements( @@ -258,10 +254,10 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { - s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" - } else { - getValue(values) - } + s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + } else { + getValue(values) + } s""" |final Object[] $data = new Object[$numElements];