diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 96a1b07a41015..796f64c0eb277 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -23,6 +23,7 @@
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.string.UTF8StringMethods;
 
 import scala.collection.Map;
 import scala.collection.Seq;
@@ -62,12 +63,16 @@ private long getFieldOffset(int ordinal) {
     return baseOffset + bitSetWidthInBytes + ordinal * 8;
   }
 
+  public static int calculateBitSetWidthInBytes(int numFields) {
+    return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
+  }
+
   public UnsafeRow() { }
 
   public void set(Object baseObject, long baseOffset, int numFields, StructType schema) {
     assert numFields >= 0 : "numFields should >= 0";
     assert schema == null || schema.fields().length == numFields;
-    this.bitSetWidthInBytes = ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
+    this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
     this.numFields = numFields;
@@ -219,9 +224,12 @@ public double getDouble(int i) {
   @Override
   public String getString(int i) {
     assertIndexIsValid(i);
-    // TODO
-
-    throw new UnsupportedOperationException();
+    final long offsetToStringSize = getLong(i);
+    final long stringSizeInBytes =
+      PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
+    // TODO: ugly cast; figure out whether we'll support mega long strings
+    return UTF8StringMethods.toJavaString(
+      baseObject, baseOffset + offsetToStringSize + 8, (int) stringSizeInBytes);
   }
 
   @Override
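For reference, the row layout these accessors assume is: a null-tracking bit set rounded up to a whole number of 8-byte words, then one 8-byte slot per field, then any variable-length data. A string column's fixed-width slot holds the offset, relative to the start of the row, of an 8-byte length word that is immediately followed by the UTF-8 bytes. A minimal standalone Scala sketch of that offset arithmetic (illustrative only, not part of the patch):

object RowLayoutSketch {
  // Mirrors UnsafeRow.calculateBitSetWidthInBytes: one null bit per field,
  // rounded up to whole 64-bit words.
  def bitSetWidthInBytes(numFields: Int): Int =
    ((numFields / 64) + (if (numFields % 64 == 0) 0 else 1)) * 8

  // Offset of field `ordinal`'s fixed-width 8-byte slot, relative to the row
  // start, matching UnsafeRow.getFieldOffset.
  def fieldOffset(numFields: Int, ordinal: Int): Int =
    bitSetWidthInBytes(numFields) + ordinal * 8

  def main(args: Array[String]): Unit = {
    assert(bitSetWidthInBytes(3) == 8)   // up to 64 fields fit in one word
    assert(bitSetWidthInBytes(65) == 16) // the 65th field needs a second word
    assert(fieldOffset(3, 2) == 24)      // one bit set word + two 8-byte slots
  }
}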
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
new file mode 100644
index 0000000000000..f4d5a5cbd8af4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.array.ByteArrayMethods
+
+/** Write a column into an UnsafeRow */
+private abstract class UnsafeColumnWriter[T] {
+  /**
+   * Write a value into an UnsafeRow.
+   *
+   * @param value the value to write
+   * @param columnNumber what column to write it to
+   * @param row a pointer to the unsafe row
+   * @param baseObject the base object of the memory region that backs the row
+   * @param baseOffset the offset of the row within that memory region
+   * @param appendCursor the offset from the start of the unsafe row to the end of the row;
+   *                     used for calculating where variable-length data should be written
+   * @return the number of variable-length bytes written
+   */
+  def write(
+      value: T,
+      columnNumber: Int,
+      row: UnsafeRow,
+      baseObject: Object,
+      baseOffset: Long,
+      appendCursor: Int): Int
+
+  /**
+   * Return the number of bytes that are needed to write this variable-length value.
+   */
+  def getSize(value: T): Int
+}
+
+private object UnsafeColumnWriter {
+  def forType(dataType: DataType): UnsafeColumnWriter[_] = {
+    dataType match {
+      case IntegerType => IntUnsafeColumnWriter
+      case LongType => LongUnsafeColumnWriter
+      case StringType => StringUnsafeColumnWriter
+      case _ => throw new UnsupportedOperationException()
+    }
+  }
+}
+
+private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
+  def getSize(value: UTF8String): Int = {
+    // 8-byte length word plus the string's bytes, rounded up to a whole word
+    val numBytes = value.getBytes.length + 8
+    ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+  }
+
+  override def write(
+      value: UTF8String,
+      columnNumber: Int,
+      row: UnsafeRow,
+      baseObject: Object,
+      baseOffset: Long,
+      appendCursor: Int): Int = {
+    val numBytes = value.getBytes.length
+    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
+    PlatformDependent.copyMemory(
+      value.getBytes,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      baseObject,
+      baseOffset + appendCursor + 8,
+      numBytes
+    )
+    row.setLong(columnNumber, appendCursor)
+    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+  }
+}
+private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+
+private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
+  def getSize(value: T): Int = 0
+}
+
+private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Int] {
+  override def write(
+      value: Int,
+      columnNumber: Int,
+      row: UnsafeRow,
+      baseObject: Object,
+      baseOffset: Long,
+      appendCursor: Int): Int = {
+    row.setInt(columnNumber, value)
+    0
+  }
+}
+private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
+
+private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] {
+  override def write(
+      value: Long,
+      columnNumber: Int,
+      row: UnsafeRow,
+      baseObject: Object,
+      baseOffset: Long,
+      appendCursor: Int): Int = {
+    row.setLong(columnNumber, value)
+    0
+  }
+}
+private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
+
+
+class UnsafeRowConverter(fieldTypes: Array[DataType]) {
+
+  private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
+    fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])
+  }
+
+  def getSizeRequirement(row: Row): Int = {
+    var fieldNumber = 0
+    var variableLengthFieldSize: Int = 0
+    while (fieldNumber < writers.length) {
+      if (!row.isNullAt(fieldNumber)) {
+        variableLengthFieldSize += writers(fieldNumber).getSize(row.get(fieldNumber))
+      }
+      fieldNumber += 1
+    }
+    (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) +
+      variableLengthFieldSize
+  }
+
+  def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.set(baseObject, baseOffset, writers.length, null) // TODO: schema?
+    var fieldNumber = 0
+    var appendCursor: Int =
+      (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)
+    while (fieldNumber < writers.length) {
+      if (row.isNullAt(fieldNumber)) {
+        unsafeRow.setNullAt(fieldNumber)
+        // TODO: type-specific null value writing?
+      } else {
+        appendCursor += writers(fieldNumber).write(
+          row.get(fieldNumber),
+          fieldNumber,
+          unsafeRow,
+          baseObject,
+          baseOffset,
+          appendCursor)
+      }
+      fieldNumber += 1
+    }
+    appendCursor
+  }
+
+}
\ No newline at end of file
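To make the size and cursor arithmetic concrete: for a hypothetical two-field (LongType, StringType) row holding (42L, "Hello"), the fixed-length section is one bit set word plus two 8-byte slots, and the string is appended after it behind an 8-byte length word. A hand-worked Scala sketch of what getSizeRequirement and writeRow compute (illustrative only; the values are invented for the example):

object ConverterLayoutSketch {
  // Same rounding rule as ByteArrayMethods.roundNumberOfBytesToNearestWord.
  def roundToWord(numBytes: Int): Int =
    if (numBytes % 8 == 0) numBytes else numBytes + (8 - numBytes % 8)

  def main(args: Array[String]): Unit = {
    val bitSetBytes = 8                    // 2 fields -> one 64-bit word
    val fixedBytes = (8 * 2) + bitSetBytes // appendCursor starts at 24
    val helloBytes = 8 + roundToWord(5)    // length word + "Hello" padded: 16
    assert(fixedBytes + helloBytes == 40)  // getSizeRequirement's result
    // writeRow stores 42 inline in slot 0 and, in slot 1, the offset 24 that
    // points at the length word written just past the fixed-length section.
  }
}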
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
new file mode 100644
index 0000000000000..ed1e907286f4b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types.{StringType, DataType, LongType, IntegerType}
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.scalatest.{FunSuite, Matchers}
+
+
+class UnsafeRowConverterSuite extends FunSuite with Matchers {
+
+  test("basic conversion with only primitive types") {
+    val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
+    val row = new SpecificMutableRow(fieldTypes)
+    row.setLong(0, 0)
+    row.setLong(1, 1)
+    row.setInt(2, 2)
+    val converter = new UnsafeRowConverter(fieldTypes)
+    val sizeRequired: Int = converter.getSizeRequirement(row)
+    sizeRequired should be (8 + (3 * 8))
+    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+    numBytesWritten should be (sizeRequired)
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.set(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.getLong(0) should be (0)
+    unsafeRow.getLong(1) should be (1)
+    unsafeRow.getInt(2) should be (2)
+  }
+
+  test("basic conversion with primitive and string types") {
+    val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+    val row = new SpecificMutableRow(fieldTypes)
+    row.setLong(0, 0)
+    row.setString(1, "Hello")
+    row.setString(2, "World")
+    val converter = new UnsafeRowConverter(fieldTypes)
+    val sizeRequired: Int = converter.getSizeRequirement(row)
+    sizeRequired should be (8 + (8 * 3) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+    numBytesWritten should be (sizeRequired)
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.set(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.getLong(0) should be (0)
+    unsafeRow.getString(1) should be ("Hello")
+    unsafeRow.getString(2) should be ("World")
+  }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
index fda4bbb45d420..b037c46a165ad 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
@@ -31,6 +31,15 @@ private ByteArrayMethods() {
     // Private constructor, since this class only contains static methods.
   }
 
+  public static int roundNumberOfBytesToNearestWord(int numBytes) {
+    int remainder = numBytes % 8;
+    if (remainder == 0) {
+      return numBytes;
+    } else {
+      return numBytes + (8 - remainder);
+    }
+  }
+
   public static void zeroBytes(
       Object baseObject,
       long baseOffset,
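The new roundNumberOfBytesToNearestWord helper rounds a byte count up to the next multiple of 8; an equivalent closed form is ((numBytes + 7) / 8) * 8. A few values that follow directly from the definition, as a quick Scala sketch:

object RoundingSketch {
  import org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord

  def main(args: Array[String]): Unit = {
    assert(roundNumberOfBytesToNearestWord(0) == 0)   // already word-aligned
    assert(roundNumberOfBytesToNearestWord(5) == 8)   // pad with 3 bytes
    assert(roundNumberOfBytesToNearestWord(8) == 8)   // already word-aligned
    assert(roundNumberOfBytesToNearestWord(13) == 16) // pad with 3 bytes
  }
}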