Skip to content

Commit

Permalink
[SPARK-23935][SQL] Addressing review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
mn-mikke committed May 14, 2018
1 parent 56ff20a commit 1bd0d5e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
*/
public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {

public static final int WORD_SIZE = 8;

//////////////////////////////////////////////////////////////////////////////
// Static methods
//////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,40 @@ class CodegenContext {
""".stripMargin
}

/**
* Generates code creating a [[UnsafeArrayData]]. The generated code executes
* a provided fallback when the size of backing array would exceed the array size limit.
* @param arrayName a name of the array to create
* @param numElements a piece of code representing the number of elements the array should contain
* @param elementSize a size of an element in bytes
* @param bodyCode a function generating code that fills up the [[UnsafeArrayData]]
* and getting the backing array as a parameter
* @param fallbackCode a piece of code executed when the array size limit is exceeded
*/
def createUnsafeArrayWithFallback(
arrayName: String,
numElements: String,
elementSize: Int,
bodyCode: String => String,
fallbackCode: String): String = {
val arraySize = freshName("size")
val arrayBytes = freshName("arrayBytes")
s"""
|final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
| $numElements,
| $elementSize);
|if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| $fallbackCode
|} else {
| final byte[] $arrayBytes = new byte[(int)$arraySize];
| UnsafeArrayData $arrayName = new UnsafeArrayData();
| Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
| $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
| ${bodyCode(arrayBytes)}
|}
""".stripMargin
}

/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,14 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
values: String,
arrayData: String,
numElements: String): String = {
val byteArraySize = ctx.freshName("byteArraySize")
val data = ctx.freshName("byteArray")
val unsafeRow = ctx.freshName("unsafeRow")
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val structsOffset = ctx.freshName("structsOffset")
val calculateArraySize = "UnsafeArrayData.calculateSizeOfUnderlyingByteArray"
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val longSize = LongType.defaultSize
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + longSize * 2
val wordSize = UnsafeRow.WORD_SIZE
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
val structSizeAsLong = structSize + "L"
val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
Expand All @@ -223,27 +220,26 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
valueAssignment
}

s"""
|final long $byteArraySize = $calculateArraySize($numElements, ${longSize + structSize});
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| ${genCodeForAnyElements(ctx, keys, values, arrayData, numElements)}
|} else {
| final int $structsOffset = $calculateHeader($numElements) + $numElements * $longSize;
| final byte[] $data = new byte[(int)$byteArraySize];
| UnsafeArrayData $unsafeArrayData = new UnsafeArrayData();
| Platform.putLong($data, $baseOffset, $numElements);
| $unsafeArrayData.pointTo($data, $baseOffset, (int)$byteArraySize);
| UnsafeRow $unsafeRow = new UnsafeRow(2);
| for (int z = 0; z < $numElements; z++) {
| long offset = $structsOffset + z * $structSizeAsLong;
| $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
| $unsafeRow.pointTo($data, $baseOffset + offset, $structSize);
| $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
| $valueAssignmentChecked
| }
| $arrayData = $unsafeArrayData;
|}
""".stripMargin
val assignmentLoop = (byteArray: String) =>
s"""
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
|UnsafeRow $unsafeRow = new UnsafeRow(2);
|for (int z = 0; z < $numElements; z++) {
| long offset = $structsOffset + z * $structSizeAsLong;
| $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
| $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
| $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
| $valueAssignmentChecked
|}
|$arrayData = $unsafeArrayData;
""".stripMargin

ctx.createUnsafeArrayWithFallback(
unsafeArrayData,
numElements,
structSize + wordSize,
assignmentLoop,
genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
}

private def genCodeForAnyElements(
Expand All @@ -258,10 +254,10 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp

val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
} else {
getValue(values)
}
s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
} else {
getValue(values)
}

s"""
|final Object[] $data = new Object[$numElements];
Expand Down

0 comments on commit 1bd0d5e

Please sign in to comment.