Skip to content

Commit

Permalink
Implement Arrow column writers
Browse files Browse the repository at this point in the history
Move column writers to Arrow.scala

Add support for more types; Switch to arrow NullableVector

closes apache#16
  • Loading branch information
icexelloss authored and BryanCutler committed Feb 23, 2017
1 parent 5837b38 commit bdba357
Showing 1 changed file with 180 additions and 130 deletions.
310 changes: 180 additions & 130 deletions sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ import scala.collection.JavaConverters._
import scala.language.implicitConversions

import io.netty.buffer.ArrowBuf
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.BitVector
import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
import org.apache.arrow.vector._
import org.apache.arrow.vector.BaseValueVector.BaseMutator
import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
Expand All @@ -32,70 +33,17 @@ import org.apache.spark.sql.types._

object Arrow {

private case class TypeFuncs(getType: () => ArrowType,
fill: ArrowBuf => Unit,
write: (InternalRow, Int, ArrowBuf) => Unit)

private def getTypeFuncs(dataType: DataType): TypeFuncs = {
val err = s"Unsupported data type ${dataType.simpleString}"

private def sparkTypeToArrowType(dataType: DataType): ArrowType = {
dataType match {
case NullType =>
TypeFuncs(
() => ArrowType.Null.INSTANCE,
(buf: ArrowBuf) => (),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => ())
case BooleanType =>
TypeFuncs(
() => ArrowType.Bool.INSTANCE,
(buf: ArrowBuf) => buf.writeBoolean(false),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
buf.writeBoolean(row.getBoolean(ordinal)))
case ShortType =>
TypeFuncs(
() => new ArrowType.Int(8 * ShortType.defaultSize, true),
(buf: ArrowBuf) => buf.writeShort(0),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal)))
case IntegerType =>
TypeFuncs(
() => new ArrowType.Int(8 * IntegerType.defaultSize, true),
(buf: ArrowBuf) => buf.writeInt(0),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal)))
case LongType =>
TypeFuncs(
() => new ArrowType.Int(8 * LongType.defaultSize, true),
(buf: ArrowBuf) => buf.writeLong(0L),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal)))
case FloatType =>
TypeFuncs(
() => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE),
(buf: ArrowBuf) => buf.writeFloat(0f),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal)))
case DoubleType =>
TypeFuncs(
() => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE),
(buf: ArrowBuf) => buf.writeDouble(0d),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
buf.writeDouble(row.getDouble(ordinal)))
case ByteType =>
TypeFuncs(
() => new ArrowType.Int(8, false),
(buf: ArrowBuf) => buf.writeByte(0),
(row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal)))
case StringType =>
TypeFuncs(
() => ArrowType.Utf8.INSTANCE,
(buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
throw new UnsupportedOperationException(err))
case StructType(_) =>
TypeFuncs(
() => ArrowType.Struct.INSTANCE,
(buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO
(row: InternalRow, ordinal: Int, buf: ArrowBuf) =>
throw new UnsupportedOperationException(err))
case _ =>
throw new IllegalArgumentException(err)
case BooleanType => ArrowType.Bool.INSTANCE
case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true)
case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true)
case LongType => new ArrowType.Int(8 * LongType.defaultSize, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case ByteType => new ArrowType.Int(8, false)
case StringType => ArrowType.Utf8.INSTANCE
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
}
}

Expand All @@ -110,8 +58,8 @@ object Arrow {
internalRowToArrowBuf(rows, ordinal, field, allocator)
}

val buffers = bufAndField.flatMap(_._1).toList.asJava
val fieldNodes = bufAndField.flatMap(_._2).toList.asJava
val fieldNodes = bufAndField.flatMap(_._1).toList.asJava
val buffers = bufAndField.flatMap(_._2).toList.asJava

new ArrowRecordBatch(rows.length, fieldNodes, buffers)
}
Expand All @@ -123,67 +71,24 @@ object Arrow {
rows: Array[InternalRow],
ordinal: Int,
field: StructField,
allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = {
allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = {
val numOfRows = rows.length
val columnWriter = ColumnWriter(allocator, field.dataType)
columnWriter.init(numOfRows)
var index = 0

field.dataType match {
case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType =>
val validityVector = new BitVector("validity", allocator)
val validityMutator = validityVector.getMutator
validityVector.allocateNew(numOfRows)
validityMutator.setValueCount(numOfRows)

val buf = allocator.buffer(numOfRows * field.dataType.defaultSize)
val typeFunc = getTypeFuncs(field.dataType)
var nullCount = 0
var index = 0
while (index < rows.length) {
val row = rows(index)
if (row.isNullAt(ordinal)) {
nullCount += 1
validityMutator.set(index, 0)
typeFunc.fill(buf)
} else {
validityMutator.set(index, 1)
typeFunc.write(row, ordinal, buf)
}
index += 1
}

val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

(Array(validityVector.getBuffer, buf), Array(fieldNode))

case StringType =>
val validityVector = new BitVector("validity", allocator)
val validityMutator = validityVector.getMutator()
validityVector.allocateNew(numOfRows)
validityMutator.setValueCount(numOfRows)

val bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize)
var bytesCount = 0
bufOffset.writeInt(bytesCount)
val bufValues = allocator.buffer(1024)
var nullCount = 0
rows.zipWithIndex.foreach { case (row, index) =>
if (row.isNullAt(ordinal)) {
nullCount += 1
validityMutator.set(index, 0)
bufOffset.writeInt(bytesCount)
} else {
validityMutator.set(index, 1)
val bytes = row.getUTF8String(ordinal).getBytes
bytesCount += bytes.length
bufOffset.writeInt(bytesCount)
bufValues.writeBytes(bytes)
}
}

val fieldNode = new ArrowFieldNode(numOfRows, nullCount)

(Array(validityVector.getBuffer, bufOffset, bufValues),
Array(fieldNode))
while(index < numOfRows) {
val row = rows(index)
if (row.isNullAt(ordinal)) {
columnWriter.writeNull()
} else {
columnWriter.write(row, ordinal)
}
index += 1
}

val (arrowFieldNodes, arrowBufs) = columnWriter.finish()
(arrowFieldNodes.toArray, arrowBufs.toArray)
}

private[sql] def schemaToArrowSchema(schema: StructType): Schema = {
Expand All @@ -195,13 +100,158 @@ object Arrow {
val name = sparkField.name
val dataType = sparkField.dataType
val nullable = sparkField.nullable
new Field(name, nullable, sparkTypeToArrowType(dataType), List.empty[Field].asJava)
}
}

object ColumnWriter {
def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
dataType match {
case StructType(fields) =>
val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava
new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields)
case _ =>
new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava)
case BooleanType => new BooleanColumnWriter(allocator)
case ShortType => new ShortColumnWriter(allocator)
case IntegerType => new IntegerColumnWriter(allocator)
case LongType => new LongColumnWriter(allocator)
case FloatType => new FloatColumnWriter(allocator)
case DoubleType => new DoubleColumnWriter(allocator)
case ByteType => new ByteColumnWriter(allocator)
case StringType => new UTF8StringColumnWriter(allocator)
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
}
}
}

/**
 * Writes the values of one column, row by row, and finally yields the Arrow field
 * nodes and buffers describing what was written.
 */
private[sql] trait ColumnWriter {

  /** Prepares the writer before any row is written; callers pass the expected row count. */
  def init(initialSize: Int): Unit

  /** Records a null entry for the current row. */
  def writeNull(): Unit

  /** Records the value stored at position `ordinal` of `row`. */
  def write(row: InternalRow, ordinal: Int): Unit

  /** Ends writing and returns the column's field nodes together with its buffers. */
  def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
}

/**
 * Common base for writers of flat Arrow columns, i.e. columns that have no children.
 *
 * A subclass provides the concrete vector and its mutator and implements `setNull` and
 * `setValue`. This class keeps track of how many entries and how many nulls have been
 * written and assembles the result in `finish`.
 */
private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
  extends ColumnWriter {

  protected val valueVector: BaseDataValueVector
  protected val valueMutator: BaseMutator

  // Number of entries written so far; subclasses use it as the index of the next entry.
  protected var count = 0
  // Number of entries that were written through writeNull().
  protected var nullCount = 0

  /** Marks the entry at index `count` as null. */
  protected def setNull(): Unit

  /** Stores the value at position `ordinal` of `row` into the entry at index `count`. */
  protected def setValue(row: InternalRow, ordinal: Int): Unit

  /** Returns the buffers of the underlying vector. */
  protected def valueBuffers(): Seq[ArrowBuf] = {
    // TODO: check the flag
    valueVector.getBuffers(true)
  }

  /** Calls the vector's no-argument `allocateNew()`; `initialSize` is not used here. */
  override def init(initialSize: Int): Unit = valueVector.allocateNew()

  override def writeNull(): Unit = {
    // setNull() reads `count`, so the counters are only advanced afterwards.
    setNull()
    nullCount = nullCount + 1
    count = count + 1
  }

  override def write(row: InternalRow, ordinal: Int): Unit = {
    // setValue() reads `count`, so it is only advanced afterwards.
    setValue(row, ordinal)
    count = count + 1
  }

  /** Sets the final value count on the mutator and returns one field node plus the buffers. */
  override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = {
    valueMutator.setValueCount(count)
    val nodes = List(new ArrowFieldNode(count, nullCount))
    val buffers = valueBuffers()
    (nodes, buffers)
  }
}

/** Column writer that stores boolean values in a `NullableBitVector`. */
private[sql] class BooleanColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableBitVector =
    new NullableBitVector("BooleanValue", allocator)
  override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    // The value is handed to the mutator as an Int: 1 for true, 0 for false.
    val bit = if (row.getBoolean(ordinal)) 1 else 0
    valueMutator.setSafe(count, bit)
  }
}

/** Column writer that stores short values in a `NullableSmallIntVector`. */
private[sql] class ShortColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableSmallIntVector =
    new NullableSmallIntVector("ShortValue", allocator)
  override protected val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val value = row.getShort(ordinal)
    valueMutator.setSafe(count, value)
  }
}

/** Column writer that stores int values in a `NullableIntVector`. */
private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableIntVector =
    new NullableIntVector("IntValue", allocator)
  override protected val valueMutator: NullableIntVector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val value = row.getInt(ordinal)
    valueMutator.setSafe(count, value)
  }
}

/** Column writer that stores long values in a `NullableBigIntVector`. */
private[sql] class LongColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableBigIntVector =
    new NullableBigIntVector("LongValue", allocator)
  override protected val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val value = row.getLong(ordinal)
    valueMutator.setSafe(count, value)
  }
}

/** Column writer that stores float values in a `NullableFloat4Vector`. */
private[sql] class FloatColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableFloat4Vector =
    new NullableFloat4Vector("FloatValue", allocator)
  override protected val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val value = row.getFloat(ordinal)
    valueMutator.setSafe(count, value)
  }
}

/** Column writer that stores double values in a `NullableFloat8Vector`. */
private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableFloat8Vector =
    new NullableFloat8Vector("DoubleValue", allocator)
  override protected val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val value = row.getDouble(ordinal)
    valueMutator.setSafe(count, value)
  }
}

/**
 * Column writer that stores byte values in a `NullableUInt1Vector`.
 *
 * NOTE(review): `row.getByte` yields a signed Byte while the vector used here is the
 * UInt1 variant -- confirm that this pairing is intended.
 */
private[sql] class ByteColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableUInt1Vector =
    new NullableUInt1Vector("ByteValue", allocator)
  override protected val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val value = row.getByte(ordinal)
    valueMutator.setSafe(count, value)
  }
}

/** Column writer that stores the bytes of UTF8 string values in a `NullableVarBinaryVector`. */
private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  override protected val valueVector: NullableVarBinaryVector =
    new NullableVarBinaryVector("UTF8StringValue", allocator)
  override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator

  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    // The whole byte array of the string is written, starting at offset 0.
    val utf8Bytes = row.getUTF8String(ordinal).getBytes
    valueMutator.setSafe(count, utf8Bytes, 0, utf8Bytes.length)
  }
}

0 comments on commit bdba357

Please sign in to comment.