From b1b1475c7fe69a7ff0c871fd71e39c640983abe1 Mon Sep 17 00:00:00 2001 From: Elbin Pallimalil Date: Fri, 13 Dec 2024 20:50:45 +0530 Subject: [PATCH] Make ArrowPageBuilder singleton formatting changes Rename to block builder --- .../plugin/arrow/AbstractArrowMetadata.java | 84 +------ ...wPageUtils.java => ArrowBlockBuilder.java} | 206 +++++++++--------- .../facebook/plugin/arrow/ArrowConnector.java | 3 +- .../plugin/arrow/ArrowPageSource.java | 12 +- .../plugin/arrow/ArrowPageSourceProvider.java | 9 +- .../plugin/arrow/ArrowTableLayoutHandle.java | 3 +- ...lsTest.java => ArrowBlockBuilderTest.java} | 63 +++--- .../plugin/arrow/TestingArrowMetadata.java | 36 +-- .../plugin/arrow/TestingArrowModule.java | 1 + .../plugin/arrow/TestingArrowPageBuilder.java | 46 ++++ 10 files changed, 218 insertions(+), 245 deletions(-) rename presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/{ArrowPageUtils.java => ArrowBlockBuilder.java} (84%) rename presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/{ArrowPageUtilsTest.java => ArrowBlockBuilderTest.java} (90%) create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPageBuilder.java diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 2951c361eb6b7..6f45d52294b98 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -14,20 +14,7 @@ package com.facebook.plugin.arrow; import com.facebook.airlift.log.Logger; -import com.facebook.presto.common.type.BigintType; -import com.facebook.presto.common.type.BooleanType; -import com.facebook.presto.common.type.DateType; -import com.facebook.presto.common.type.DecimalType; -import com.facebook.presto.common.type.DoubleType; -import com.facebook.presto.common.type.IntegerType; -import com.facebook.presto.common.type.RealType; -import com.facebook.presto.common.type.SmallintType; -import com.facebook.presto.common.type.TimeType; -import com.facebook.presto.common.type.TimestampType; -import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarbinaryType; -import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.spi.ColumnHandle; import com.facebook.presto.spi.ColumnMetadata; import com.facebook.presto.spi.ConnectorSession; @@ -44,7 +31,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -57,7 +43,6 @@ import java.util.Set; import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR; -import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_TYPE_ERROR; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; @@ -67,71 +52,13 @@ public abstract class AbstractArrowMetadata private static final Logger logger = Logger.get(AbstractArrowMetadata.class); private final ArrowFlightConfig config; private final ArrowFlightClientHandler clientHandler; + private final ArrowBlockBuilder arrowBlockBuilder; - public AbstractArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler clientHandler) + public AbstractArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder) { this.config = requireNonNull(config, "config is null"); this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); - } - - private Type getPrestoTypeForArrowFloatingPointType(ArrowType.FloatingPoint floatingPoint) - { - switch (floatingPoint.getPrecision()) { - case SINGLE: - return RealType.REAL; - case DOUBLE: - return DoubleType.DOUBLE; - default: - throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected floating point precision " + floatingPoint.getPrecision()); - } - } - - private Type getPrestoTypeForArrowIntType(ArrowType.Int intType) - { - switch (intType.getBitWidth()) { - case 64: - return BigintType.BIGINT; - case 32: - return IntegerType.INTEGER; - case 16: - return SmallintType.SMALLINT; - case 8: - return TinyintType.TINYINT; - default: - throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected bit width " + intType.getBitWidth()); - } - } - - protected Type getPrestoTypeFromArrowField(Field field) - { - switch (field.getType().getTypeID()) { - case Int: - ArrowType.Int intType = (ArrowType.Int) field.getType(); - return getPrestoTypeForArrowIntType(intType); - case Binary: - case LargeBinary: - case FixedSizeBinary: - return VarbinaryType.VARBINARY; - case Date: - return DateType.DATE; - case Timestamp: - return TimestampType.TIMESTAMP; - case Utf8: - case LargeUtf8: - return VarcharType.VARCHAR; - case FloatingPoint: - ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType(); - return getPrestoTypeForArrowFloatingPointType(floatingPoint); - case Decimal: - ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); - return DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale()); - case Bool: - return BooleanType.BOOLEAN; - case Time: - return TimeType.TIME; - default: - throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported."); - } + this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowPageBuilder is null"); } protected abstract FlightDescriptor getFlightDescriptor(Optional query, String schema, String table); @@ -265,4 +192,9 @@ public Map> listTableColumns(ConnectorSess } return columns.build(); } + + private Type getPrestoTypeFromArrowField(Field field) + { + return arrowBlockBuilder.getPrestoTypeFromArrowField(field); + } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java similarity index 84% rename from presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java rename to presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java index 6777c4380efad..50e8e2ace5523 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java @@ -63,8 +63,8 @@ import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.util.JsonStringArrayList; import java.math.BigDecimal; @@ -74,15 +74,12 @@ import java.util.List; import java.util.concurrent.TimeUnit; +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_TYPE_ERROR; import static java.util.Objects.requireNonNull; -public class ArrowPageUtils +public class ArrowBlockBuilder { - private ArrowPageUtils() - { - } - - public static Block buildBlockFromFieldVector(FieldVector vector, Type type, DictionaryProvider dictionaryProvider) + public Block buildBlockFromFieldVector(FieldVector vector, Type type, DictionaryProvider dictionaryProvider) { if (vector.getField().getDictionary() != null) { Dictionary dictionary = dictionaryProvider.lookup(vector.getField().getDictionary().getId()); @@ -93,13 +90,13 @@ public static Block buildBlockFromFieldVector(FieldVector vector, Type type, Dic } } - public static Block buildBlockFromDictionaryVector(FieldVector fieldVector, FieldVector dictionaryVector) + public Block buildBlockFromDictionaryVector(FieldVector fieldVector, FieldVector dictionaryVector) { // Validate inputs requireNonNull(fieldVector, "encoded vector is null"); requireNonNull(dictionaryVector, "dictionary vector is null"); - Type prestoType = getPrestoTypeFromArrowType(dictionaryVector.getField().getType()); + Type prestoType = getPrestoTypeFromArrowField(dictionaryVector.getField()); Block dictionaryblock = buildBlockFromValueVector(dictionaryVector, prestoType); @@ -107,7 +104,67 @@ public static Block buildBlockFromDictionaryVector(FieldVector fieldVector, Fiel return getDictionaryBlock(fieldVector, dictionaryblock); } - private static DictionaryBlock getDictionaryBlock(FieldVector fieldVector, Block dictionaryblock) + protected Type getPrestoTypeFromArrowField(Field field) + { + switch (field.getType().getTypeID()) { + case Int: + ArrowType.Int intType = (ArrowType.Int) field.getType(); + return getPrestoTypeForArrowIntType(intType); + case Binary: + case LargeBinary: + case FixedSizeBinary: + return VarbinaryType.VARBINARY; + case Date: + return DateType.DATE; + case Timestamp: + return TimestampType.TIMESTAMP; + case Utf8: + case LargeUtf8: + return VarcharType.VARCHAR; + case FloatingPoint: + ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType(); + return getPrestoTypeForArrowFloatingPointType(floatingPoint); + case Decimal: + ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); + return DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale()); + case Bool: + return BooleanType.BOOLEAN; + case Time: + return TimeType.TIME; + default: + throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported."); + } + } + + private Type getPrestoTypeForArrowFloatingPointType(ArrowType.FloatingPoint floatingPoint) + { + switch (floatingPoint.getPrecision()) { + case SINGLE: + return RealType.REAL; + case DOUBLE: + return DoubleType.DOUBLE; + default: + throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected floating point precision " + floatingPoint.getPrecision()); + } + } + + private Type getPrestoTypeForArrowIntType(ArrowType.Int intType) + { + switch (intType.getBitWidth()) { + case 64: + return BigintType.BIGINT; + case 32: + return IntegerType.INTEGER; + case 16: + return SmallintType.SMALLINT; + case 8: + return TinyintType.TINYINT; + default: + throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected bit width " + intType.getBitWidth()); + } + } + + private DictionaryBlock getDictionaryBlock(FieldVector fieldVector, Block dictionaryblock) { if (fieldVector instanceof IntVector) { // Get the Arrow indices vector @@ -142,62 +199,7 @@ else if (fieldVector instanceof TinyIntVector) { } } - private static Type getPrestoTypeFromArrowType(ArrowType arrowType) - { - if (arrowType instanceof ArrowType.Utf8) { - return VarcharType.VARCHAR; - } - else if (arrowType instanceof ArrowType.Int) { - ArrowType.Int intType = (ArrowType.Int) arrowType; - if (intType.getBitWidth() == 8 || intType.getBitWidth() == 16 || intType.getBitWidth() == 32) { - return IntegerType.INTEGER; - } - else if (intType.getBitWidth() == 64) { - return BigintType.BIGINT; - } - else { - throw new UnsupportedOperationException("Unsupported int bit width: " + intType.getBitWidth()); - } - } - else if (arrowType instanceof ArrowType.FloatingPoint) { - ArrowType.FloatingPoint fpType = (ArrowType.FloatingPoint) arrowType; - FloatingPointPrecision precision = fpType.getPrecision(); - - if (precision == FloatingPointPrecision.SINGLE) { // 32-bit float - return RealType.REAL; - } - else if (precision == FloatingPointPrecision.DOUBLE) { // 64-bit float - return DoubleType.DOUBLE; - } - else { - throw new UnsupportedOperationException("Unsupported FloatingPoint precision: " + precision); - } - } - else if (arrowType instanceof ArrowType.Bool) { - return BooleanType.BOOLEAN; - } - else if (arrowType instanceof ArrowType.Binary) { - return VarbinaryType.VARBINARY; - } - else if (arrowType instanceof ArrowType.Decimal) { - return DecimalType.createDecimalType(); - } - else if (arrowType instanceof ArrowType.Timestamp) { - return TimestampType.TIMESTAMP; - } - else if (arrowType instanceof ArrowType.Date) { - return DateType.DATE; - } - else if (arrowType instanceof ArrowType.Time) { - return TimeType.TIME; - } - else if (arrowType instanceof ArrowType.LargeUtf8) { - return VarcharType.VARCHAR; - } - throw new UnsupportedOperationException("Unsupported ArrowType: " + arrowType); - } - - private static Block buildBlockFromValueVector(ValueVector vector, Type type) + private Block buildBlockFromValueVector(ValueVector vector, Type type) { if (vector instanceof BitVector) { return buildBlockFromBitVector((BitVector) vector, type); @@ -275,7 +277,7 @@ else if (vector instanceof ListVector) { } } - public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) + public Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) { if (!(type instanceof TimestampType)) { throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); @@ -294,7 +296,7 @@ public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vecto return builder.build(); } - public static Block buildBlockFromBitVector(BitVector vector, Type type) + public Block buildBlockFromBitVector(BitVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -308,7 +310,7 @@ public static Block buildBlockFromBitVector(BitVector vector, Type type) return builder.build(); } - public static Block buildBlockFromIntVector(IntVector vector, Type type) + public Block buildBlockFromIntVector(IntVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -322,7 +324,7 @@ public static Block buildBlockFromIntVector(IntVector vector, Type type) return builder.build(); } - public static Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) + public Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -336,7 +338,7 @@ public static Block buildBlockFromSmallIntVector(SmallIntVector vector, Type typ return builder.build(); } - public static Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) + public Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -350,7 +352,7 @@ public static Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) return builder.build(); } - public static Block buildBlockFromBigIntVector(BigIntVector vector, Type type) + public Block buildBlockFromBigIntVector(BigIntVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -364,7 +366,7 @@ public static Block buildBlockFromBigIntVector(BigIntVector vector, Type type) return builder.build(); } - public static Block buildBlockFromDecimalVector(DecimalVector vector, Type type) + public Block buildBlockFromDecimalVector(DecimalVector vector, Type type) { if (!(type instanceof DecimalType)) { throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); @@ -391,7 +393,7 @@ public static Block buildBlockFromDecimalVector(DecimalVector vector, Type type) return builder.build(); } - public static Block buildBlockFromNullVector(NullVector vector, Type type) + public Block buildBlockFromNullVector(NullVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -400,7 +402,7 @@ public static Block buildBlockFromNullVector(NullVector vector, Type type) return builder.build(); } - public static Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) + public Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) { if (!(type instanceof TimestampType)) { throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); @@ -420,7 +422,7 @@ public static Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vect return builder.build(); } - public static Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) + public Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) { if (!(type instanceof TimestampType)) { throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); @@ -439,7 +441,7 @@ public static Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vect return builder.build(); } - public static Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) + public Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -453,7 +455,7 @@ public static Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) return builder.build(); } - public static Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) + public Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -468,7 +470,7 @@ public static Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) return builder.build(); } - public static Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) + public Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -483,7 +485,7 @@ public static Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type t return builder.build(); } - public static Block buildBlockFromVarCharVector(VarCharVector vector, Type type) + public Block buildBlockFromVarCharVector(VarCharVector vector, Type type) { if (!(type instanceof VarcharType)) { throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); @@ -502,7 +504,7 @@ public static Block buildBlockFromVarCharVector(VarCharVector vector, Type type) return builder.build(); } - public static Block buildBlockFromDateDayVector(DateDayVector vector, Type type) + public Block buildBlockFromDateDayVector(DateDayVector vector, Type type) { if (!(type instanceof DateType)) { throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); @@ -520,7 +522,7 @@ public static Block buildBlockFromDateDayVector(DateDayVector vector, Type type) return builder.build(); } - public static Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) + public Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) { if (!(type instanceof DateType)) { throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); @@ -540,7 +542,7 @@ public static Block buildBlockFromDateMilliVector(DateMilliVector vector, Type t return builder.build(); } - public static Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) + public Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) { if (!(type instanceof TimeType)) { throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); @@ -560,7 +562,7 @@ public static Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) return builder.build(); } - public static Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) + public Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) { if (!(type instanceof TimeType)) { throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); @@ -579,7 +581,7 @@ public static Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type t return builder.build(); } - public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) + public Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) { if (!(type instanceof TimeType)) { throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); @@ -598,7 +600,7 @@ public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type t return builder.build(); } - public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) + public Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) { if (!(type instanceof TimestampType)) { throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); @@ -618,7 +620,7 @@ public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, return builder.build(); } - public static Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) + public Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -633,7 +635,7 @@ public static Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Ty return builder.build(); } - public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) + public Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) { BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); for (int i = 0; i < vector.getValueCount(); i++) { @@ -650,7 +652,7 @@ public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Ty return builder.build(); } - public static Block buildBlockFromListVector(ListVector vector, Type type) + public Block buildBlockFromListVector(ListVector vector, Type type) { if (!(type instanceof ArrayType)) { throw new IllegalArgumentException("Type must be an ArrayType for ListVector"); @@ -684,7 +686,7 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) return arrayBuilder.build(); } - public static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) + public void appendValueToBuilder(Type type, BlockBuilder builder, Object value) { if (value == null) { builder.appendNull(); @@ -732,13 +734,13 @@ else if (type instanceof TimestampType) { } } - public static void writeVarcharType(Type type, BlockBuilder builder, Object value) + public void writeVarcharType(Type type, BlockBuilder builder, Object value) { Slice slice = Slices.utf8Slice(value.toString()); type.writeSlice(builder, slice); } - public static void writeSmallintType(Type type, BlockBuilder builder, Object value) + public void writeSmallintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Number) { type.writeLong(builder, ((Number) value).shortValue()); @@ -759,7 +761,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void writeTinyintType(Type type, BlockBuilder builder, Object value) + public void writeTinyintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Number) { type.writeLong(builder, ((Number) value).byteValue()); @@ -780,7 +782,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void writeBigintType(Type type, BlockBuilder builder, Object value) + public void writeBigintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Long) { type.writeLong(builder, (Long) value); @@ -804,7 +806,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void writeIntegerType(Type type, BlockBuilder builder, Object value) + public void writeIntegerType(Type type, BlockBuilder builder, Object value) { if (value instanceof Integer) { type.writeLong(builder, (Integer) value); @@ -825,7 +827,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void writeDoubleType(Type type, BlockBuilder builder, Object value) + public void writeDoubleType(Type type, BlockBuilder builder, Object value) { if (value instanceof Double) { type.writeDouble(builder, (Double) value); @@ -849,7 +851,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void writeBooleanType(Type type, BlockBuilder builder, Object value) + public void writeBooleanType(Type type, BlockBuilder builder, Object value) { if (value instanceof Boolean) { type.writeBoolean(builder, (Boolean) value); @@ -859,7 +861,7 @@ public static void writeBooleanType(Type type, BlockBuilder builder, Object valu } } - public static void writeDecimalType(DecimalType type, BlockBuilder builder, Object value) + public void writeDecimalType(DecimalType type, BlockBuilder builder, Object value) { if (value instanceof BigDecimal) { BigDecimal decimalValue = (BigDecimal) value; @@ -888,7 +890,7 @@ else if (value instanceof Long) { } } - public static void writeArrayType(ArrayType type, BlockBuilder builder, Object value) + public void writeArrayType(ArrayType type, BlockBuilder builder, Object value) { Type elementType = type.getElementType(); BlockBuilder arrayBuilder = builder.beginBlockEntry(); @@ -898,7 +900,7 @@ public static void writeArrayType(ArrayType type, BlockBuilder builder, Object v builder.closeEntry(); } - public static void writeRowType(RowType type, BlockBuilder builder, Object value) + public void writeRowType(RowType type, BlockBuilder builder, Object value) { List rowValues = (List) value; BlockBuilder rowBuilder = builder.beginBlockEntry(); @@ -910,7 +912,7 @@ public static void writeRowType(RowType type, BlockBuilder builder, Object value builder.closeEntry(); } - public static void writeDateType(Type type, BlockBuilder builder, Object value) + public void writeDateType(Type type, BlockBuilder builder, Object value) { if (value instanceof java.sql.Date || value instanceof java.time.LocalDate) { int daysSinceEpoch = (int) (value instanceof java.sql.Date @@ -923,7 +925,7 @@ public static void writeDateType(Type type, BlockBuilder builder, Object value) } } - public static void writeTimestampType(Type type, BlockBuilder builder, Object value) + public void writeTimestampType(Type type, BlockBuilder builder, Object value) { if (value instanceof java.sql.Timestamp) { long millis = ((java.sql.Timestamp) value).getTime(); diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java index 1028af2414308..af491e39b3c6a 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -37,7 +37,8 @@ public class ArrowConnector private final ArrowFlightClientHandler arrowFlightClientHandler; @Inject - public ArrowConnector(ConnectorMetadata metadata, + public ArrowConnector( + ConnectorMetadata metadata, ConnectorHandleResolver handleResolver, ConnectorSplitManager splitManager, ConnectorPageSourceProvider pageSourceProvider, diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java index 03f817250e5af..ceabd36b5c35d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -37,17 +37,23 @@ public class ArrowPageSource private static final Logger logger = Logger.get(ArrowPageSource.class); private final ArrowSplit split; private final List columnHandles; + private final ArrowBlockBuilder arrowBlockBuilder; private boolean completed; private int currentPosition; private VectorSchemaRoot vectorSchemaRoot; private ArrowFlightClient flightClient; private FlightStream flightStream; - public ArrowPageSource(ArrowSplit split, List columnHandles, ArrowFlightClientHandler clientHandler, - ConnectorSession connectorSession) + public ArrowPageSource( + ArrowSplit split, + List columnHandles, + ArrowFlightClientHandler clientHandler, + ConnectorSession connectorSession, + ArrowBlockBuilder arrowBlockBuilder) { this.columnHandles = columnHandles; this.split = split; + this.arrowBlockBuilder = arrowBlockBuilder; getFlightStream(clientHandler, split.getTicket(), connectorSession); } @@ -111,7 +117,7 @@ public Page getNextPage() for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) { FieldVector vector = vectorSchemaRoot.getVector(columnIndex); Type type = columnHandles.get(columnIndex).getColumnType(); - Block block = ArrowPageUtils.buildBlockFromFieldVector(vector, type, flightStream.getDictionaryProvider()); + Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStream.getDictionaryProvider()); blocks.add(block); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java index f3bb41c3e35d4..e2815bc563056 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java @@ -31,11 +31,14 @@ public class ArrowPageSourceProvider implements ConnectorPageSourceProvider { private static final Logger logger = Logger.get(ArrowPageSourceProvider.class); - private ArrowFlightClientHandler clientHandler; + private final ArrowFlightClientHandler clientHandler; + private final ArrowBlockBuilder arrowBlockBuilder; + @Inject - public ArrowPageSourceProvider(ArrowFlightClientHandler clientHandler) + public ArrowPageSourceProvider(ArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder) { this.clientHandler = clientHandler; + this.arrowBlockBuilder = arrowBlockBuilder; } @Override @@ -47,6 +50,6 @@ public ConnectorPageSource createPageSource(ConnectorTransactionHandle transacti } ArrowSplit arrowSplit = (ArrowSplit) split; logger.debug("Processing split with flight ticket"); - return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session); + return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session, arrowBlockBuilder); } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java index 46e94a4e1143c..b997f81c13a0d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java @@ -32,7 +32,8 @@ public class ArrowTableLayoutHandle private final TupleDomain tupleDomain; @JsonCreator - public ArrowTableLayoutHandle(@JsonProperty("table") ArrowTableHandle table, + public ArrowTableLayoutHandle( + @JsonProperty("table") ArrowTableHandle table, @JsonProperty("columnHandles") List columnHandles, @JsonProperty("tupleDomain") TupleDomain domain) { diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowBlockBuilderTest.java similarity index 90% rename from presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java rename to presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowBlockBuilderTest.java index 626e7815c13ad..be30980330f84 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowBlockBuilderTest.java @@ -13,6 +13,7 @@ */ package com.facebook.plugin.arrow; +import com.facebook.airlift.log.Logger; import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.block.DictionaryBlock; @@ -59,23 +60,25 @@ import java.util.List; import java.util.Optional; -import static com.facebook.plugin.arrow.ArrowPageUtils.buildBlockFromDictionaryVector; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; -public class ArrowPageUtilsTest +public class ArrowBlockBuilderTest { + private static final Logger logger = Logger.get(ArrowBlockBuilderTest.class); private static final int DICTIONARY_LENGTH = 10; private static final int VECTOR_LENGTH = 50; private BufferAllocator allocator; + private ArrowBlockBuilder arrowBlockBuilder; @BeforeClass public void setUp() { // Initialize the Arrow allocator allocator = new RootAllocator(Integer.MAX_VALUE); - System.out.println("Allocator initialized: " + allocator); + logger.debug("Allocator initialized: %s", allocator); + arrowBlockBuilder = new ArrowBlockBuilder(); } @Test @@ -92,7 +95,7 @@ public void testBuildBlockFromBitVector() bitVector.setValueCount(3); // Build the block from the vector - Block resultBlock = ArrowPageUtils.buildBlockFromBitVector(bitVector, BooleanType.BOOLEAN); + Block resultBlock = arrowBlockBuilder.buildBlockFromBitVector(bitVector, BooleanType.BOOLEAN); // Now verify the result block assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions @@ -112,7 +115,7 @@ public void testBuildBlockFromTinyIntVector() tinyIntVector.setValueCount(3); // Build the block from the vector - Block resultBlock = ArrowPageUtils.buildBlockFromTinyIntVector(tinyIntVector, TinyintType.TINYINT); + Block resultBlock = arrowBlockBuilder.buildBlockFromTinyIntVector(tinyIntVector, TinyintType.TINYINT); // Now verify the result block assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions @@ -132,7 +135,7 @@ public void testBuildBlockFromSmallIntVector() smallIntVector.setValueCount(3); // Build the block from the vector - Block resultBlock = ArrowPageUtils.buildBlockFromSmallIntVector(smallIntVector, SmallintType.SMALLINT); + Block resultBlock = arrowBlockBuilder.buildBlockFromSmallIntVector(smallIntVector, SmallintType.SMALLINT); // Now verify the result block assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions @@ -152,7 +155,7 @@ public void testBuildBlockFromIntVector() intVector.setValueCount(3); // Build the block from the vector - Block resultBlock = ArrowPageUtils.buildBlockFromIntVector(intVector, IntegerType.INTEGER); + Block resultBlock = arrowBlockBuilder.buildBlockFromIntVector(intVector, IntegerType.INTEGER); // Now verify the result block assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions @@ -176,7 +179,7 @@ public void testBuildBlockFromBigIntVector() bigIntVector.setValueCount(3); // Build the block from the vector - Block resultBlock = ArrowPageUtils.buildBlockFromBigIntVector(bigIntVector, BigintType.BIGINT); + Block resultBlock = arrowBlockBuilder.buildBlockFromBigIntVector(bigIntVector, BigintType.BIGINT); // Now verify the result block assertEquals(10L, resultBlock.getInt(0)); // The 1st element should be 10L @@ -195,7 +198,7 @@ public void testBuildBlockFromDecimalVector() decimalVector.setValueCount(2); // Build the block from the vector - Block resultBlock = ArrowPageUtils.buildBlockFromDecimalVector(decimalVector, DecimalType.createDecimalType(10, 2)); + Block resultBlock = arrowBlockBuilder.buildBlockFromDecimalVector(decimalVector, DecimalType.createDecimalType(10, 2)); // Now verify the result block assertEquals(2, resultBlock.getPositionCount()); // Should have 2 positions @@ -215,7 +218,7 @@ public void testBuildBlockFromTimeStampMicroVector() timestampMicroVector.setValueCount(3); // Build the block from the vector - Block resultBlock = ArrowPageUtils.buildBlockFromTimeStampMicroVector(timestampMicroVector, TimestampType.TIMESTAMP); + Block resultBlock = arrowBlockBuilder.buildBlockFromTimeStampMicroVector(timestampMicroVector, TimestampType.TIMESTAMP); // Now verify the result block assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions @@ -253,7 +256,7 @@ public void testBuildBlockFromListVector() ArrayType arrayType = new ArrayType(IntegerType.INTEGER); // Call the method to test - Block block = ArrowPageUtils.buildBlockFromListVector(listVector, arrayType); + Block block = arrowBlockBuilder.buildBlockFromListVector(listVector, arrayType); // Validate the result assertEquals(block.getPositionCount(), 4); // 4 lists in the block @@ -286,7 +289,7 @@ public void testProcessDictionaryVector() BaseIntVector encodedVector = (BaseIntVector) DictionaryEncoder.encode(rawVector, dictionary); // Process the dictionary vector - Block result = buildBlockFromDictionaryVector(encodedVector, dictionary.getVector()); + Block result = arrowBlockBuilder.buildBlockFromDictionaryVector(encodedVector, dictionary.getVector()); // Verify the result assertNotNull(result, "The BlockBuilder should not be null."); @@ -317,7 +320,7 @@ public void testBuildBlockFromDictionaryVector() indicesVector.set(3, 2); // Third index points to "cherry" indicesVector.setValueCount(4); // Call the method under test - Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + Block block = arrowBlockBuilder.buildBlockFromDictionaryVector(indicesVector, dictionaryVector); // Assertions to check the dictionary block's behavior assertNotNull(block); @@ -370,7 +373,7 @@ public void testBuildBlockFromDictionaryVectorBigInt() dictionaryVector.setValueCount(3); // Call the method under test - Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + Block block = arrowBlockBuilder.buildBlockFromDictionaryVector(indicesVector, dictionaryVector); // Assertions to check the dictionary block's behavior assertNotNull(block); @@ -420,7 +423,7 @@ public void testBuildBlockFromDictionaryVectorSmallInt() dictionaryVector.setValueCount(3); // Call the method under test - Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + Block block = arrowBlockBuilder.buildBlockFromDictionaryVector(indicesVector, dictionaryVector); // Assertions to check the dictionary block's behavior assertNotNull(block); @@ -470,7 +473,7 @@ public void testBuildBlockFromDictionaryVectorTinyInt() dictionaryVector.setValueCount(3); // Call the method under test - Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + Block block = arrowBlockBuilder.buildBlockFromDictionaryVector(indicesVector, dictionaryVector); // Assertions to check the dictionary block's behavior assertNotNull(block); @@ -504,7 +507,7 @@ public void testWriteVarcharType() BlockBuilder builder = varcharType.createBlockBuilder(null, 1); String value = "test_string"; - ArrowPageUtils.writeVarcharType(varcharType, builder, value); + arrowBlockBuilder.writeVarcharType(varcharType, builder, value); Block block = builder.build(); Slice result = varcharType.getSlice(block, 0); @@ -518,7 +521,7 @@ public void testWriteSmallintType() BlockBuilder builder = smallintType.createBlockBuilder(null, 1); short value = 42; - ArrowPageUtils.writeSmallintType(smallintType, builder, value); + arrowBlockBuilder.writeSmallintType(smallintType, builder, value); Block block = builder.build(); long result = smallintType.getLong(block, 0); @@ -532,7 +535,7 @@ public void testWriteTinyintType() BlockBuilder builder = tinyintType.createBlockBuilder(null, 1); byte value = 7; - ArrowPageUtils.writeTinyintType(tinyintType, builder, value); + arrowBlockBuilder.writeTinyintType(tinyintType, builder, value); Block block = builder.build(); long result = tinyintType.getLong(block, 0); @@ -546,7 +549,7 @@ public void testWriteBigintType() BlockBuilder builder = bigintType.createBlockBuilder(null, 1); long value = 123456789L; - ArrowPageUtils.writeBigintType(bigintType, builder, value); + arrowBlockBuilder.writeBigintType(bigintType, builder, value); Block block = builder.build(); long result = bigintType.getLong(block, 0); @@ -560,7 +563,7 @@ public void testWriteIntegerType() BlockBuilder builder = integerType.createBlockBuilder(null, 1); int value = 42; - ArrowPageUtils.writeIntegerType(integerType, builder, value); + arrowBlockBuilder.writeIntegerType(integerType, builder, value); Block block = builder.build(); long result = integerType.getLong(block, 0); @@ -574,7 +577,7 @@ public void testWriteDoubleType() BlockBuilder builder = doubleType.createBlockBuilder(null, 1); double value = 42.42; - ArrowPageUtils.writeDoubleType(doubleType, builder, value); + arrowBlockBuilder.writeDoubleType(doubleType, builder, value); Block block = builder.build(); double result = doubleType.getDouble(block, 0); @@ -588,7 +591,7 @@ public void testWriteBooleanType() BlockBuilder builder = booleanType.createBlockBuilder(null, 1); boolean value = true; - ArrowPageUtils.writeBooleanType(booleanType, builder, value); + arrowBlockBuilder.writeBooleanType(booleanType, builder, value); Block block = builder.build(); boolean result = booleanType.getBoolean(block, 0); @@ -603,7 +606,7 @@ public void testWriteArrayType() BlockBuilder builder = arrayType.createBlockBuilder(null, 1); List values = Arrays.asList(1, 2, 3); - ArrowPageUtils.writeArrayType(arrayType, builder, values); + arrowBlockBuilder.writeArrayType(arrayType, builder, values); Block block = builder.build(); Block arrayBlock = arrayType.getObject(block, 0); @@ -622,7 +625,7 @@ public void testWriteRowType() BlockBuilder builder = rowType.createBlockBuilder(null, 1); List rowValues = Arrays.asList(42, "test"); - ArrowPageUtils.writeRowType(rowType, builder, rowValues); + arrowBlockBuilder.writeRowType(rowType, builder, rowValues); Block block = builder.build(); Block rowBlock = rowType.getObject(block, 0); @@ -637,7 +640,7 @@ public void testWriteDateType() BlockBuilder builder = dateType.createBlockBuilder(null, 1); LocalDate value = LocalDate.of(2020, 1, 1); - ArrowPageUtils.writeDateType(dateType, builder, value); + arrowBlockBuilder.writeDateType(dateType, builder, value); Block block = builder.build(); long result = dateType.getLong(block, 0); @@ -651,7 +654,7 @@ public void testWriteTimestampType() BlockBuilder builder = timestampType.createBlockBuilder(null, 1); long value = 1609459200000L; // Jan 1, 2021, 00:00:00 UTC - ArrowPageUtils.writeTimestampType(timestampType, builder, value); + arrowBlockBuilder.writeTimestampType(timestampType, builder, value); Block block = builder.build(); long result = timestampType.getLong(block, 0); @@ -666,7 +669,7 @@ public void testWriteTimestampTypeWithSqlTimestamp() java.sql.Timestamp timestamp = java.sql.Timestamp.valueOf("2021-01-01 00:00:00"); long expectedMillis = timestamp.getTime(); - ArrowPageUtils.writeTimestampType(timestampType, builder, timestamp); + arrowBlockBuilder.writeTimestampType(timestampType, builder, timestamp); Block block = builder.build(); long result = timestampType.getLong(block, 0); @@ -680,7 +683,7 @@ public void testShortDecimalRetrieval() BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 1); BigDecimal decimalValue = new BigDecimal("12345.67"); - ArrowPageUtils.writeDecimalType(shortDecimalType, builder, decimalValue); + arrowBlockBuilder.writeDecimalType(shortDecimalType, builder, decimalValue); Block block = builder.build(); long unscaledValue = shortDecimalType.getLong(block, 0); // Unscaled value: 1234567 @@ -695,7 +698,7 @@ public void testLongDecimalRetrieval() DecimalType longDecimalType = DecimalType.createDecimalType(38, 10); BlockBuilder builder = longDecimalType.createBlockBuilder(null, 1); BigDecimal decimalValue = new BigDecimal("1234567890.1234567890"); - ArrowPageUtils.writeDecimalType(longDecimalType, builder, decimalValue); + arrowBlockBuilder.writeDecimalType(longDecimalType, builder, decimalValue); // Build the block after inserting the decimal value Block block = builder.build(); Slice unscaledSlice = longDecimalType.getSlice(block, 0); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java index 3fad658bbbb58..5665553e0edd9 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -14,10 +14,6 @@ package com.facebook.plugin.arrow; import com.facebook.airlift.log.Logger; -import com.facebook.presto.common.type.CharType; -import com.facebook.presto.common.type.TimeType; -import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.SchemaTableName; @@ -28,7 +24,6 @@ import org.apache.arrow.flight.Action; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.Result; -import org.apache.arrow.vector.types.pojo.Field; import javax.inject.Inject; @@ -52,9 +47,14 @@ public class TestingArrowMetadata private final ArrowFlightConfig config; @Inject - public TestingArrowMetadata(ArrowFlightClientHandler clientHandler, NodeManager nodeManager, TestingArrowFlightConfig testConfig, ArrowFlightConfig config) + public TestingArrowMetadata( + ArrowFlightClientHandler clientHandler, + NodeManager nodeManager, + TestingArrowFlightConfig testConfig, + ArrowFlightConfig config, + ArrowBlockBuilder arrowBlockBuilder) { - super(config, clientHandler); + super(config, clientHandler, arrowBlockBuilder); this.nodeManager = nodeManager; this.testConfig = testConfig; this.clientHandler = clientHandler; @@ -110,28 +110,6 @@ public List extractSchemaAndTableData(Optional schema, Connector } } - @Override - protected Type getPrestoTypeFromArrowField(Field field) - { - String columnLength = field.getMetadata().get("columnLength"); - int length = columnLength != null ? Integer.parseInt(columnLength) : 0; - - String nativeType = field.getMetadata().get("columnNativeType"); - - if ("CHAR".equals(nativeType) || "CHARACTER".equals(nativeType)) { - return CharType.createCharType(length); - } - else if ("VARCHAR".equals(nativeType)) { - return VarcharType.createVarcharType(length); - } - else if ("TIME".equals(nativeType)) { - return TimeType.TIME; - } - else { - return super.getPrestoTypeFromArrowField(field); - } - } - @Override protected String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName) { diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java index cab5872a5507a..67657d78633cd 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowModule.java @@ -31,5 +31,6 @@ public void configure(Binder binder) binder.bind(ConnectorSplitManager.class).to(TestingArrowSplitManager.class).in(Scopes.SINGLETON); binder.bind(ArrowFlightClientHandler.class).to(TestingArrowFlightClientHandler.class).in(Scopes.SINGLETON); binder.bind(ConnectorMetadata.class).to(TestingArrowMetadata.class).in(Scopes.SINGLETON); + binder.bind(ArrowBlockBuilder.class).to(TestingArrowPageBuilder.class).in(Scopes.SINGLETON); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPageBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPageBuilder.java new file mode 100644 index 0000000000000..e7ab908fcc52b --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowPageBuilder.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import org.apache.arrow.vector.types.pojo.Field; + +public class TestingArrowPageBuilder + extends ArrowBlockBuilder +{ + @Override + protected Type getPrestoTypeFromArrowField(Field field) + { + String columnLength = field.getMetadata().get("columnLength"); + int length = columnLength != null ? Integer.parseInt(columnLength) : 0; + + String nativeType = field.getMetadata().get("columnNativeType"); + + if ("CHAR".equals(nativeType) || "CHARACTER".equals(nativeType)) { + return CharType.createCharType(length); + } + else if ("VARCHAR".equals(nativeType)) { + return VarcharType.createVarcharType(length); + } + else if ("TIME".equals(nativeType)) { + return TimeType.TIME; + } + else { + return super.getPrestoTypeFromArrowField(field); + } + } +}