diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala index 30e79c2e3c6a4..7100a8f035156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Arrow.scala @@ -21,8 +21,9 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.BitVector +import org.apache.arrow.memory.{BaseAllocator, RootAllocator} +import org.apache.arrow.vector._ +import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} import org.apache.arrow.vector.types.FloatingPointPrecision import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema} @@ -32,70 +33,17 @@ import org.apache.spark.sql.types._ object Arrow { - private case class TypeFuncs(getType: () => ArrowType, - fill: ArrowBuf => Unit, - write: (InternalRow, Int, ArrowBuf) => Unit) - - private def getTypeFuncs(dataType: DataType): TypeFuncs = { - val err = s"Unsupported data type ${dataType.simpleString}" - + private def sparkTypeToArrowType(dataType: DataType): ArrowType = { dataType match { - case NullType => - TypeFuncs( - () => ArrowType.Null.INSTANCE, - (buf: ArrowBuf) => (), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => ()) - case BooleanType => - TypeFuncs( - () => ArrowType.Bool.INSTANCE, - (buf: ArrowBuf) => buf.writeBoolean(false), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - buf.writeBoolean(row.getBoolean(ordinal))) - case ShortType => - TypeFuncs( - () => new ArrowType.Int(8 * ShortType.defaultSize, true), - (buf: ArrowBuf) => buf.writeShort(0), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeShort(row.getShort(ordinal))) - case IntegerType => - TypeFuncs( - () => new ArrowType.Int(8 * IntegerType.defaultSize, true), - (buf: ArrowBuf) => 
buf.writeInt(0), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeInt(row.getInt(ordinal))) - case LongType => - TypeFuncs( - () => new ArrowType.Int(8 * LongType.defaultSize, true), - (buf: ArrowBuf) => buf.writeLong(0L), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeLong(row.getLong(ordinal))) - case FloatType => - TypeFuncs( - () => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), - (buf: ArrowBuf) => buf.writeFloat(0f), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeFloat(row.getFloat(ordinal))) - case DoubleType => - TypeFuncs( - () => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), - (buf: ArrowBuf) => buf.writeDouble(0d), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - buf.writeDouble(row.getDouble(ordinal))) - case ByteType => - TypeFuncs( - () => new ArrowType.Int(8, false), - (buf: ArrowBuf) => buf.writeByte(0), - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => buf.writeByte(row.getByte(ordinal))) - case StringType => - TypeFuncs( - () => ArrowType.Utf8.INSTANCE, - (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - throw new UnsupportedOperationException(err)) - case StructType(_) => - TypeFuncs( - () => ArrowType.Struct.INSTANCE, - (buf: ArrowBuf) => throw new UnsupportedOperationException(err), // TODO - (row: InternalRow, ordinal: Int, buf: ArrowBuf) => - throw new UnsupportedOperationException(err)) - case _ => - throw new IllegalArgumentException(err) + case BooleanType => ArrowType.Bool.INSTANCE + case ShortType => new ArrowType.Int(8 * ShortType.defaultSize, true) + case IntegerType => new ArrowType.Int(8 * IntegerType.defaultSize, true) + case LongType => new ArrowType.Int(8 * LongType.defaultSize, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case ByteType => new 
ArrowType.Int(8, false) + case StringType => ArrowType.Utf8.INSTANCE + case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") } } @@ -110,8 +58,8 @@ object Arrow { internalRowToArrowBuf(rows, ordinal, field, allocator) } - val buffers = bufAndField.flatMap(_._1).toList.asJava - val fieldNodes = bufAndField.flatMap(_._2).toList.asJava + val fieldNodes = bufAndField.flatMap(_._1).toList.asJava + val buffers = bufAndField.flatMap(_._2).toList.asJava new ArrowRecordBatch(rows.length, fieldNodes, buffers) } @@ -123,67 +71,24 @@ object Arrow { rows: Array[InternalRow], ordinal: Int, field: StructField, - allocator: RootAllocator): (Array[ArrowBuf], Array[ArrowFieldNode]) = { + allocator: RootAllocator): (Array[ArrowFieldNode], Array[ArrowBuf]) = { val numOfRows = rows.length + val columnWriter = ColumnWriter(allocator, field.dataType) + columnWriter.init(numOfRows) + var index = 0 - field.dataType match { - case ShortType | IntegerType | LongType | DoubleType | FloatType | BooleanType | ByteType => - val validityVector = new BitVector("validity", allocator) - val validityMutator = validityVector.getMutator - validityVector.allocateNew(numOfRows) - validityMutator.setValueCount(numOfRows) - - val buf = allocator.buffer(numOfRows * field.dataType.defaultSize) - val typeFunc = getTypeFuncs(field.dataType) - var nullCount = 0 - var index = 0 - while (index < rows.length) { - val row = rows(index) - if (row.isNullAt(ordinal)) { - nullCount += 1 - validityMutator.set(index, 0) - typeFunc.fill(buf) - } else { - validityMutator.set(index, 1) - typeFunc.write(row, ordinal, buf) - } - index += 1 - } - - val fieldNode = new ArrowFieldNode(numOfRows, nullCount) - - (Array(validityVector.getBuffer, buf), Array(fieldNode)) - - case StringType => - val validityVector = new BitVector("validity", allocator) - val validityMutator = validityVector.getMutator() - validityVector.allocateNew(numOfRows) - validityMutator.setValueCount(numOfRows) - - val 
bufOffset = allocator.buffer((numOfRows + 1) * IntegerType.defaultSize) - var bytesCount = 0 - bufOffset.writeInt(bytesCount) - val bufValues = allocator.buffer(1024) - var nullCount = 0 - rows.zipWithIndex.foreach { case (row, index) => - if (row.isNullAt(ordinal)) { - nullCount += 1 - validityMutator.set(index, 0) - bufOffset.writeInt(bytesCount) - } else { - validityMutator.set(index, 1) - val bytes = row.getUTF8String(ordinal).getBytes - bytesCount += bytes.length - bufOffset.writeInt(bytesCount) - bufValues.writeBytes(bytes) - } - } - - val fieldNode = new ArrowFieldNode(numOfRows, nullCount) - - (Array(validityVector.getBuffer, bufOffset, bufValues), - Array(fieldNode)) + while(index < numOfRows) { + val row = rows(index) + if (row.isNullAt(ordinal)) { + columnWriter.writeNull() + } else { + columnWriter.write(row, ordinal) + } + index += 1 } + + val (arrowFieldNodes, arrowBufs) = columnWriter.finish() + (arrowFieldNodes.toArray, arrowBufs.toArray) } private[sql] def schemaToArrowSchema(schema: StructType): Schema = { @@ -195,13 +100,158 @@ object Arrow { val name = sparkField.name val dataType = sparkField.dataType val nullable = sparkField.nullable + new Field(name, nullable, sparkTypeToArrowType(dataType), List.empty[Field].asJava) + } +} +object ColumnWriter { + def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = { dataType match { - case StructType(fields) => - val childrenFields = fields.map(sparkFieldToArrowField).toList.asJava - new Field(name, nullable, ArrowType.Struct.INSTANCE, childrenFields) - case _ => - new Field(name, nullable, getTypeFuncs(dataType).getType(), List.empty[Field].asJava) + case BooleanType => new BooleanColumnWriter(allocator) + case ShortType => new ShortColumnWriter(allocator) + case IntegerType => new IntegerColumnWriter(allocator) + case LongType => new LongColumnWriter(allocator) + case FloatType => new FloatColumnWriter(allocator) + case DoubleType => new DoubleColumnWriter(allocator) + case ByteType 
=> new ByteColumnWriter(allocator) + case StringType => new UTF8StringColumnWriter(allocator) + case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}") } } } + +private[sql] trait ColumnWriter { + def init(initialSize: Int): Unit + def writeNull(): Unit + def write(row: InternalRow, ordinal: Int): Unit + def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) +} + +/** + * Base class for flat arrow column writer, i.e., column without children. + */ +private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator) + extends ColumnWriter { + protected val valueVector: BaseDataValueVector + protected val valueMutator: BaseMutator + + protected var count = 0 + protected var nullCount = 0 + + protected def setNull(): Unit + protected def setValue(row: InternalRow, ordinal: Int): Unit + protected def valueBuffers(): Seq[ArrowBuf] = valueVector.getBuffers(true) // TODO: check the flag + + override def init(initialSize: Int): Unit = { + valueVector.allocateNew() + } + + override def writeNull(): Unit = { + setNull() + nullCount += 1 + count += 1 + } + + override def write(row: InternalRow, ordinal: Int): Unit = { + setValue(row, ordinal) + count += 1 + } + + override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = { + valueMutator.setValueCount(count) + val fieldNode = new ArrowFieldNode(count, nullCount) + (List(fieldNode), valueBuffers) + } +} + +private[sql] class BooleanColumnWriter(allocator: BaseAllocator) + extends PrimitiveColumnWriter(allocator) { + private def bool2int(b: Boolean): Int = if (b) 1 else 0 + + override protected val valueVector: NullableBitVector + = new NullableBitVector("BooleanValue", allocator) + override protected val valueMutator: NullableBitVector#Mutator = valueVector.getMutator + + override def setNull(): Unit = valueMutator.setNull(count) + override def setValue(row: InternalRow, ordinal: Int): Unit + = valueMutator.setSafe(count, bool2int(row.getBoolean(ordinal))) +} + 
/**
 * Column writer for Spark `ShortType` values, backed by a [[NullableSmallIntVector]].
 */
private[sql] class ShortColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  // Declared before the mutator, because the mutator is obtained from this vector
  // while the instance is being constructed.
  override protected val valueVector: NullableSmallIntVector =
    new NullableSmallIntVector("ShortValue", allocator)

  override protected val valueMutator: NullableSmallIntVector#Mutator =
    valueVector.getMutator

  /** Marks the slot at the current `count` as null. */
  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  /** Copies the short at position `ordinal` of `row` into the slot at the current `count`. */
  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val shortValue = row.getShort(ordinal)
    valueMutator.setSafe(count, shortValue)
  }
}

/**
 * Column writer for Spark `IntegerType` values, backed by a [[NullableIntVector]].
 */
private[sql] class IntegerColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  // Declared before the mutator, because the mutator is obtained from this vector
  // while the instance is being constructed.
  override protected val valueVector: NullableIntVector =
    new NullableIntVector("IntValue", allocator)

  override protected val valueMutator: NullableIntVector#Mutator =
    valueVector.getMutator

  /** Marks the slot at the current `count` as null. */
  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  /** Copies the int at position `ordinal` of `row` into the slot at the current `count`. */
  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val intValue = row.getInt(ordinal)
    valueMutator.setSafe(count, intValue)
  }
}

/**
 * Column writer for Spark `LongType` values, backed by a [[NullableBigIntVector]].
 */
private[sql] class LongColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  // Declared before the mutator, because the mutator is obtained from this vector
  // while the instance is being constructed.
  override protected val valueVector: NullableBigIntVector =
    new NullableBigIntVector("LongValue", allocator)

  override protected val valueMutator: NullableBigIntVector#Mutator =
    valueVector.getMutator

  /** Marks the slot at the current `count` as null. */
  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  /** Copies the long at position `ordinal` of `row` into the slot at the current `count`. */
  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val longValue = row.getLong(ordinal)
    valueMutator.setSafe(count, longValue)
  }
}

/**
 * Column writer for Spark `FloatType` values, backed by a [[NullableFloat4Vector]].
 */
private[sql] class FloatColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  // Declared before the mutator, because the mutator is obtained from this vector
  // while the instance is being constructed.
  override protected val valueVector: NullableFloat4Vector =
    new NullableFloat4Vector("FloatValue", allocator)

  override protected val valueMutator: NullableFloat4Vector#Mutator =
    valueVector.getMutator

  /** Marks the slot at the current `count` as null. */
  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  /** Copies the float at position `ordinal` of `row` into the slot at the current `count`. */
  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val floatValue = row.getFloat(ordinal)
    valueMutator.setSafe(count, floatValue)
  }
}

/**
 * Column writer for Spark `DoubleType` values, backed by a [[NullableFloat8Vector]].
 */
private[sql] class DoubleColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  // Declared before the mutator, because the mutator is obtained from this vector
  // while the instance is being constructed.
  override protected val valueVector: NullableFloat8Vector =
    new NullableFloat8Vector("DoubleValue", allocator)

  override protected val valueMutator: NullableFloat8Vector#Mutator =
    valueVector.getMutator

  /** Marks the slot at the current `count` as null. */
  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  /** Copies the double at position `ordinal` of `row` into the slot at the current `count`. */
  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val doubleValue = row.getDouble(ordinal)
    valueMutator.setSafe(count, doubleValue)
  }
}

/**
 * Column writer for Spark `ByteType` values, backed by a [[NullableUInt1Vector]].
 *
 * NOTE(review): the backing vector is the unsigned 8-bit variant by name, whereas
 * `InternalRow.getByte` yields a JVM (signed) byte. Presumably only the raw byte is stored;
 * confirm that readers interpret the signedness as intended.
 */
private[sql] class ByteColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  // Declared before the mutator, because the mutator is obtained from this vector
  // while the instance is being constructed.
  override protected val valueVector: NullableUInt1Vector =
    new NullableUInt1Vector("ByteValue", allocator)

  override protected val valueMutator: NullableUInt1Vector#Mutator =
    valueVector.getMutator

  /** Marks the slot at the current `count` as null. */
  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  /** Copies the byte at position `ordinal` of `row` into the slot at the current `count`. */
  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val byteValue = row.getByte(ordinal)
    valueMutator.setSafe(count, byteValue)
  }
}

/**
 * Column writer for Spark `StringType` values, backed by a [[NullableVarBinaryVector]].
 *
 * NOTE(review): strings are stored through a binary vector rather than a character vector.
 * Presumably the buffer layout is what matters to consumers; confirm.
 */
private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
  extends PrimitiveColumnWriter(allocator) {

  // Declared before the mutator, because the mutator is obtained from this vector
  // while the instance is being constructed.
  override protected val valueVector: NullableVarBinaryVector =
    new NullableVarBinaryVector("UTF8StringValue", allocator)

  override protected val valueMutator: NullableVarBinaryVector#Mutator =
    valueVector.getMutator

  /** Marks the slot at the current `count` as null. */
  override def setNull(): Unit = {
    valueMutator.setNull(count)
  }

  /**
   * Writes the bytes of the UTF8String at position `ordinal` of `row` into the slot at the
   * current `count`, passing the full range of the byte array.
   */
  override def setValue(row: InternalRow, ordinal: Int): Unit = {
    val utf8Bytes = row.getUTF8String(ordinal).getBytes
    val start = 0
    valueMutator.setSafe(count, utf8Bytes, start, utf8Bytes.length)
  }
}