Skip to content

Commit

Permalink
Start prototyping Java Row -> UnsafeRow converters
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Apr 22, 2015
1 parent 1ff814d commit 53ba9b7
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.string.UTF8StringMethods;
import scala.collection.Map;
import scala.collection.Seq;

Expand Down Expand Up @@ -62,12 +63,16 @@ private long getFieldOffset(int ordinal) {
return baseOffset + bitSetWidthInBytes + ordinal * 8;
}

public static int calculateBitSetWidthInBytes(int numFields) {
return ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
}

public UnsafeRow() { }

public void set(Object baseObject, long baseOffset, int numFields, StructType schema) {
assert numFields >= 0 : "numFields should >= 0";
assert schema == null || schema.fields().length == numFields;
this.bitSetWidthInBytes = ((numFields / 64) + ((numFields % 64 == 0 ? 0 : 1))) * 8;
this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.numFields = numFields;
Expand Down Expand Up @@ -219,9 +224,11 @@ public double getDouble(int i) {
@Override
public String getString(int i) {
assertIndexIsValid(i);
// TODO

throw new UnsupportedOperationException();
final long offsetToStringSize = getLong(i);
final long stringSizeInBytes =
PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
// TODO: ugly cast; figure out whether we'll support mega long strings
return UTF8StringMethods.toJavaString(baseObject, baseOffset + offsetToStringSize + 8, (int) stringSizeInBytes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods

/** Write a column into an UnsafeRow */
private abstract class UnsafeColumnWriter[T] {
/**
* Write a value into an UnsafeRow.
*
* @param value the value to write
* @param columnNumber what column to write it to
* @param row a pointer to the unsafe row
* @param baseObject
* @param baseOffset
* @param appendCursor the offset from the start of the unsafe row to the end of the row;
* used for calculating where variable-length data should be written
* @return the number of variable-length bytes written
*/
def write(
value: T,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int

/**
* Return the number of bytes that are needed to write this variable-length value.
*/
def getSize(value: T): Int
}

private object UnsafeColumnWriter {
def forType(dataType: DataType): UnsafeColumnWriter[_] = {
dataType match {
case IntegerType => IntUnsafeColumnWriter
case LongType => LongUnsafeColumnWriter
case StringType => StringUnsafeColumnWriter
case _ => throw new UnsupportedOperationException()
}
}
}

private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
def getSize(value: UTF8String): Int = {
// round to nearest word
val numBytes = value.getBytes.length
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}

override def write(
value: UTF8String,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int = {
val numBytes = value.getBytes.length
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
PlatformDependent.copyMemory(
value.getBytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
baseOffset + appendCursor + 8,
numBytes
)
row.setLong(columnNumber, appendCursor)
8 + ((numBytes / 8) + (if (numBytes % 8 == 0) 0 else 1)) * 8
}
}
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter

private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
def getSize(value: T): Int = 0
}

private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Int] {
override def write(
value: Int,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int = {
row.setInt(columnNumber, value)
0
}
}
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter

private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] {
override def write(
value: Long,
columnNumber: Int,
row: UnsafeRow,
baseObject: Object,
baseOffset: Long,
appendCursor: Int): Int = {
row.setLong(columnNumber, value)
0
}
}
private case object LongUnsafeColumnWriter extends LongUnsafeColumnWriter


class UnsafeRowConverter(fieldTypes: Array[DataType]) {

private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])
}

def getSizeRequirement(row: Row): Int = {
var fieldNumber = 0
var variableLengthFieldSize: Int = 0
while (fieldNumber < writers.length) {
if (!row.isNullAt(fieldNumber)) {
variableLengthFieldSize += writers(fieldNumber).getSize(row.get(fieldNumber))

}
fieldNumber += 1
}
(8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + variableLengthFieldSize
}

def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
val unsafeRow = new UnsafeRow()
unsafeRow.set(baseObject, baseOffset, writers.length, null) // TODO: schema?
var fieldNumber = 0
var appendCursor: Int =
(8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)
while (fieldNumber < writers.length) {
if (row.isNullAt(fieldNumber)) {
unsafeRow.setNullAt(fieldNumber)
// TODO: type-specific null value writing?
} else {
appendCursor += writers(fieldNumber).write(
row.get(fieldNumber),
fieldNumber,
unsafeRow,
baseObject,
baseOffset,
appendCursor)
}
fieldNumber += 1
}
appendCursor
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types.{StringType, DataType, LongType, IntegerType}
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.scalatest.{FunSuite, Matchers}


class UnsafeRowConverterSuite extends FunSuite with Matchers {

test("basic conversion with only primitive types") {
val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setLong(1, 1)
row.setInt(2, 2)
val converter = new UnsafeRowConverter(fieldTypes)
val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (3 * 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)
val unsafeRow = new UnsafeRow()
unsafeRow.set(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
unsafeRow.getLong(1) should be (1)
unsafeRow.getInt(2) should be (2)
}

test("basic conversion with primitive and string types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setString(1, "Hello")
row.setString(2, "World")
val converter = new UnsafeRowConverter(fieldTypes)
val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 3) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)
val unsafeRow = new UnsafeRow()
unsafeRow.set(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
unsafeRow.getString(2) should be ("World")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ private ByteArrayMethods() {
// Private constructor, since this class only contains static methods.
}

public static int roundNumberOfBytesToNearestWord(int numBytes) {
int remainder = numBytes % 8;
if (remainder == 0) {
return numBytes;
} else {
return numBytes + (8 - remainder);
}
}

public static void zeroBytes(
Object baseObject,
long baseOffset,
Expand Down

0 comments on commit 53ba9b7

Please sign in to comment.