diff --git a/.github/workflows/arrow-flight-tests.yml b/.github/workflows/arrow-flight-tests.yml new file mode 100644 index 0000000000000..ee77c122536e1 --- /dev/null +++ b/.github/workflows/arrow-flight-tests.yml @@ -0,0 +1,82 @@ +name: arrow flight tests + +on: + pull_request: + +env: + CONTINUOUS_INTEGRATION: true + MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError" + MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" + MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true" + MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end" + RETRY: .github/bin/retry + +jobs: + changes: + runs-on: ubuntu-latest + permissions: + pull-requests: read + outputs: + codechange: ${{ steps.filter.outputs.codechange }} + steps: + - uses: dorny/paths-filter@v2 + id: filter + with: + filters: | + codechange: + - '!presto-docs/**' + test: + runs-on: ubuntu-latest + needs: changes + strategy: + fail-fast: false + matrix: + modules: + - ":presto-base-arrow-flight" # Only run tests for the `presto-base-arrow-flight` module + + timeout-minutes: 80 + concurrency: + group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + + steps: + # Checkout the code only if there are changes in the relevant files + - uses: actions/checkout@v4 + if: needs.changes.outputs.codechange == 'true' + with: + show-progress: false + + # Set up Java for the build environment + - uses: actions/setup-java@v2 + if: needs.changes.outputs.codechange == 'true' + with: + distribution: 'temurin' + java-version: 8 + + # Cache Maven dependencies to speed up the build + - name: Cache local Maven repository + if: needs.changes.outputs.codechange == 'true' + id: cache-maven + uses: actions/cache@v2 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven-2- + + # Resolve Maven dependencies (if cache is not found) + - name: Populate Maven cache + if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true' + run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs + + # Install dependencies for the target module + - name: Maven Install + if: needs.changes.outputs.codechange == 'true' + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} + + # Run Maven tests for the target module + - name: Maven Tests + if: needs.changes.outputs.codechange == 'true' + run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml index fd9871e6b2490..b2d1ff04a964e 100644 --- a/presto-base-arrow-flight/pom.xml +++ b/presto-base-arrow-flight/pom.xml @@ -16,18 +16,14 @@ 4.10.0 17.0.0 4.1.110.Final + 1.6.20 + 2.23.0 com.facebook.airlift bootstrap - - - ch.qos.logback - logback-core - - @@ -39,14 +35,7 @@ com.google.guava guava - - org.checkerframework - checker-qual - - - com.google.errorprone - error_prone_annotations - + com.google.j2objc j2objc-annotations @@ -99,6 +88,12 @@ org.apache.arrow arrow-memory-core ${arrow.version} + + + org.slf4j + slf4j-api + + @@ -115,37 +110,29 @@ org.apache.arrow arrow-jdbc ${arrow.version} - - - - io.netty - netty-codec-http2 - ${netty.version} - - - - io.netty - netty-handler-proxy - ${netty.version} - io.netty - netty-codec-http + org.slf4j + slf4j-api - - io.netty - netty-tcnative-boringssl-static - 2.0.65.Final - - org.apache.arrow arrow-vector ${arrow.version} + + org.slf4j + slf4j-api + + + + commons-codec + commons-codec + + com.fasterxml.jackson.datatype jackson-datatype-jsr310 @@ -188,13 +175,6 @@ test - - org.checkerframework - checker-qual - 3.4.1 - test - - com.facebook.presto presto-testng-services @@ -225,6 +205,11 @@ flight-core ${arrow.version} + + org.slf4j + slf4j-api + + com.google.j2objc j2objc-annotations @@ -249,41 +234,10 @@ h2 test - - - org.codehaus.mojo - animal-sniffer-annotations - 1.23 - - - - com.google.j2objc - j2objc-annotations - 1.3 - - - - com.google.errorprone - error_prone_annotations - 2.14.0 - - - - com.google.protobuf - protobuf-java - 3.25.5 - - - - commons-codec - commons-codec - 1.17.0 - - io.netty netty-transport-native-unix-common @@ -321,22 +275,36 @@ - org.objenesis - objenesis - 3.3 + io.netty + netty-handler-proxy + ${netty.version} + + + + io.netty + netty-codec-http + ${netty.version} org.jetbrains.kotlin kotlin-stdlib-common - 1.6.20 + ${kotlin.version} + + + + com.google.errorprone + error_prone_annotations + ${error_prone_annotations} - org.slf4j - slf4j-api - 2.0.13 + org.apache.arrow + arrow-algorithm + ${arrow.version} + compile + @@ -352,26 +320,10 @@ org.apache.maven.plugins maven-enforcer-plugin - - - - - com.google.errorprone:error_prone_annotations - - - - org.apache.maven.plugins maven-dependency-plugin - - - io.netty:netty-codec-http2 - io.netty:netty-handler-proxy - io.netty:netty-tcnative-boringssl-static - - org.basepom.maven diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 0507a76749977..c991d3c4c1f5d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -38,16 +38,20 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.NotFoundException; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -70,76 +74,67 @@ public AbstractArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler this.clientHandler = requireNonNull(clientHandler); } - private ArrowColumnHandle createArrowColumnHandleForFloatingPointType(String columnName, ArrowType.FloatingPoint floatingPoint) + private Type getPrestoTypeForArrowFloatingPointType(ArrowType.FloatingPoint floatingPoint) { switch (floatingPoint.getPrecision()) { case SINGLE: - return new ArrowColumnHandle(columnName, RealType.REAL); + return RealType.REAL; case DOUBLE: - return new ArrowColumnHandle(columnName, DoubleType.DOUBLE); + return DoubleType.DOUBLE; default: throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPoint.getPrecision()); } } - private ArrowColumnHandle createArrowColumnHandleForIntType(String columnName, ArrowType.Int intType) + private Type getPrestoTypeForArrowIntType(ArrowType.Int intType) { switch (intType.getBitWidth()) { case 64: - return new ArrowColumnHandle(columnName, BigintType.BIGINT); + return BigintType.BIGINT; case 32: - return new ArrowColumnHandle(columnName, IntegerType.INTEGER); + return IntegerType.INTEGER; case 16: - return new ArrowColumnHandle(columnName, SmallintType.SMALLINT); + return SmallintType.SMALLINT; case 8: - return new ArrowColumnHandle(columnName, TinyintType.TINYINT); + return TinyintType.TINYINT; default: throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth()); } } - private ColumnMetadata createIntColumnMetadata(String columnName, ArrowType.Int intType) + protected Type getPrestoTypeFromArrowField(Field field) { - switch (intType.getBitWidth()) { - case 64: - return new ColumnMetadata(columnName, BigintType.BIGINT); - case 32: - return new ColumnMetadata(columnName, IntegerType.INTEGER); - case 16: - return new ColumnMetadata(columnName, SmallintType.SMALLINT); - case 8: - return new ColumnMetadata(columnName, TinyintType.TINYINT); - default: - throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid bit width " + intType.getBitWidth()); - } - } - - private ColumnMetadata createFloatingPointColumnMetadata(String columnName, ArrowType.FloatingPoint floatingPointType) - { - switch (floatingPointType.getPrecision()) { - case SINGLE: - return new ColumnMetadata(columnName, RealType.REAL); - case DOUBLE: - return new ColumnMetadata(columnName, DoubleType.DOUBLE); + switch (field.getType().getTypeID()) { + case Int: + ArrowType.Int intType = (ArrowType.Int) field.getType(); + return getPrestoTypeForArrowIntType(intType); + case Binary: + case LargeBinary: + case FixedSizeBinary: + return VarbinaryType.VARBINARY; + case Date: + return DateType.DATE; + case Timestamp: + return TimestampType.TIMESTAMP; + case Utf8: + case LargeUtf8: + return VarcharType.VARCHAR; + case FloatingPoint: + ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType(); + return getPrestoTypeForArrowFloatingPointType(floatingPoint); + case Decimal: + ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); + return DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale()); + case Bool: + return BooleanType.BOOLEAN; + case Time: + return TimeType.TIME; default: - throw new ArrowException(ARROW_FLIGHT_ERROR, "Invalid floating point precision " + floatingPointType.getPrecision()); + throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported."); } } - /** - * Provides the field type, which can be overridden by concrete implementations - * with their own custom type. - * - * @return the field type - */ - protected Type overrideFieldType(Field field, Type type) - { - return type; - } - - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table); - - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema); + protected abstract FlightDescriptor getFlightDescriptor(Optional query, String schema, String table); protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName); @@ -163,11 +158,11 @@ public List getColumnsList(String schema, String table, ConnectorSession try { String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schema); String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, table); - ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), Optional.empty(), + FlightDescriptor flightDescriptor = getFlightDescriptor(Optional.empty(), dataSourceSpecificSchemaName, dataSourceSpecificTableName); - FlightInfo flightInfo = clientHandler.getFlightInfo(request, connectorSession); - List fields = flightInfo.getSchema().getFields(); + Optional flightschema = clientHandler.getSchema(flightDescriptor, connectorSession); + List fields = flightschema.map(Schema::getFields).orElse(Collections.emptyList()); return fields; } catch (Exception e) { @@ -178,7 +173,7 @@ public List getColumnsList(String schema, String table, ConnectorSession @Override public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) { - Map column = new HashMap<>(); + Map columnHandles = new HashMap<>(); String schemaValue = ((ArrowTableHandle) tableHandle).getSchema(); String tableValue = ((ArrowTableHandle) tableHandle).getTable(); @@ -190,56 +185,22 @@ public Map getColumnHandles(ConnectorSession session, Conn String columnName = field.getName(); logger.debug("The value of the flight columnName is:- %s", columnName); - ArrowColumnHandle handle; - switch (field.getType().getTypeID()) { - case Int: - ArrowType.Int intType = (ArrowType.Int) field.getType(); - handle = createArrowColumnHandleForIntType(columnName, intType); - break; - case Binary: - case LargeBinary: - case FixedSizeBinary: - handle = new ArrowColumnHandle(columnName, VarbinaryType.VARBINARY); - break; - case Date: - handle = new ArrowColumnHandle(columnName, DateType.DATE); - break; - case Timestamp: - handle = new ArrowColumnHandle(columnName, TimestampType.TIMESTAMP); - break; - case Utf8: - case LargeUtf8: - handle = new ArrowColumnHandle(columnName, VarcharType.VARCHAR); - break; - case FloatingPoint: - ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType(); - handle = createArrowColumnHandleForFloatingPointType(columnName, floatingPoint); - break; - case Decimal: - ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); - handle = new ArrowColumnHandle(columnName, DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale())); - break; - case Bool: - handle = new ArrowColumnHandle(columnName, BooleanType.BOOLEAN); - break; - case Time: - handle = new ArrowColumnHandle(columnName, TimeType.TIME); - break; - default: - throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported."); - } - Type type = overrideFieldType(field, handle.getColumnType()); - if (!type.equals(handle.getColumnType())) { - handle = new ArrowColumnHandle(columnName, type); - } - column.put(columnName, handle); + Type type = getPrestoTypeFromArrowField(field); + columnHandles.put(columnName, new ArrowColumnHandle(columnName, type)); } - return column; + return columnHandles; } @Override public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { + if (!(table instanceof ArrowTableHandle)) { + throw new PrestoException( + StandardErrorCode.INVALID_CAST_ARGUMENT, + "Invalid table handle: Expected an instance of ArrowTableHandle but received " + + table.getClass().getSimpleName() + ""); + } + ArrowTableHandle tableHandle = (ArrowTableHandle) table; List columns = new ArrayList<>(); @@ -266,53 +227,8 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect for (Field field : columnList) { String columnName = field.getName(); - ArrowType type = field.getType(); - - ColumnMetadata columnMetadata; - - switch (type.getTypeID()) { - case Int: - ArrowType.Int intType = (ArrowType.Int) type; - columnMetadata = createIntColumnMetadata(columnName, intType); - break; - case Binary: - case LargeBinary: - case FixedSizeBinary: - columnMetadata = new ColumnMetadata(columnName, VarbinaryType.VARBINARY); - break; - case Date: - columnMetadata = new ColumnMetadata(columnName, DateType.DATE); - break; - case Timestamp: - columnMetadata = new ColumnMetadata(columnName, TimestampType.TIMESTAMP); - break; - case Utf8: - case LargeUtf8: - columnMetadata = new ColumnMetadata(columnName, VarcharType.VARCHAR); - break; - case FloatingPoint: - ArrowType.FloatingPoint floatingPointType = (ArrowType.FloatingPoint) type; - columnMetadata = createFloatingPointColumnMetadata(columnName, floatingPointType); - break; - case Decimal: - ArrowType.Decimal decimalType = (ArrowType.Decimal) type; - columnMetadata = new ColumnMetadata(columnName, DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale())); - break; - case Time: - columnMetadata = new ColumnMetadata(columnName, TimeType.TIME); - break; - case Bool: - columnMetadata = new ColumnMetadata(columnName, BooleanType.BOOLEAN); - break; - default: - throw new UnsupportedOperationException("The data type " + type.getTypeID() + " is not supported."); - } - - Type fieldType = overrideFieldType(field, columnMetadata.getType()); - if (!fieldType.equals(columnMetadata.getType())) { - columnMetadata = new ColumnMetadata(columnName, fieldType); - } - meta.add(columnMetadata); + Type fieldType = getPrestoTypeFromArrowField(field); + meta.add(new ColumnMetadata(columnName, fieldType)); } return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java index 49d4fe91fc171..e92afe8b0a4b8 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import java.util.List; @@ -37,17 +38,17 @@ public AbstractArrowSplitManager(ArrowFlightClientHandler client) this.clientHandler = client; } - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); + protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); @Override public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) { ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); - ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), + FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), tableLayoutHandle); - FlightInfo flightInfo = clientHandler.getFlightInfo(request, session); + FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, session); List splits = flightInfo.getEndpoints() .stream() .map(info -> new ArrowSplit( diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java index 9523bd210faa5..1028af2414308 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -34,16 +34,20 @@ public class ArrowConnector private final ConnectorPageSourceProvider pageSourceProvider; private final ConnectorHandleResolver handleResolver; + private final ArrowFlightClientHandler arrowFlightClientHandler; + @Inject public ArrowConnector(ConnectorMetadata metadata, - ConnectorHandleResolver handleResolver, - ConnectorSplitManager splitManager, - ConnectorPageSourceProvider pageSourceProvider) + ConnectorHandleResolver handleResolver, + ConnectorSplitManager splitManager, + ConnectorPageSourceProvider pageSourceProvider, + ArrowFlightClientHandler arrowFlightClientHandler) { this.metadata = requireNonNull(metadata, "Metadata is null"); - this.handleResolver = requireNonNull(handleResolver, "Metadata is null"); + this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); this.splitManager = requireNonNull(splitManager, "SplitManager is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "PageSinkProvider is null"); + this.arrowFlightClientHandler = requireNonNull(arrowFlightClientHandler, "arrow flight handler is null"); } public Optional getHandleResolver() @@ -74,4 +78,10 @@ public ConnectorPageSourceProvider getPageSourceProvider() { return pageSourceProvider; } + + @Override + public void shutdown() + { + arrowFlightClientHandler.closeRootallocator(); + } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java index f17317dd42b3a..e070b2c4624d6 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java @@ -46,7 +46,7 @@ public class ArrowConnectorFactory public ArrowConnectorFactory(String name, Module module, ClassLoader classLoader) { checkArgument(!isNullOrEmpty(name), "name is null or empty"); - this.name = name; + this.name = requireNonNull(name, "name is null"); this.module = requireNonNull(module, "module is null"); this.classLoader = requireNonNull(classLoader, "classLoader is null"); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java index 1a9d964d7458d..3d12617a0839d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java @@ -14,7 +14,6 @@ package com.facebook.plugin.arrow; import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.memory.RootAllocator; import java.io.IOException; import java.io.InputStream; @@ -27,13 +26,11 @@ public class ArrowFlightClient { private final FlightClient flightClient; private final Optional trustedCertificate; - private RootAllocator allocator; - public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate, RootAllocator allocator) + public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate) { this.flightClient = requireNonNull(flightClient, "flightClient cannot be null"); - this.trustedCertificate = trustedCertificate; - this.allocator = allocator; + this.trustedCertificate = requireNonNull(trustedCertificate, "trustedCertificate is null"); } public FlightClient getFlightClient() @@ -53,9 +50,5 @@ public void close() throws InterruptedException, IOException if (trustedCertificate.isPresent()) { trustedCertificate.get().close(); } - if (allocator != null) { - allocator.close(); - allocator = null; - } } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index ba75cbffc088c..a85edf42358bb 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -21,6 +21,7 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; import java.io.FileInputStream; import java.io.InputStream; @@ -33,6 +34,8 @@ public abstract class ArrowFlightClientHandler private static final Logger logger = Logger.get(ArrowFlightClientHandler.class); private final ArrowFlightConfig config; + private RootAllocator allocator; + public ArrowFlightClientHandler(ArrowFlightConfig config) { this.config = config; @@ -41,7 +44,6 @@ public ArrowFlightClientHandler(ArrowFlightConfig config) private ArrowFlightClient initializeClient(Optional uri) { try { - RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); Optional trustedCertificate = Optional.empty(); Location location; @@ -57,6 +59,10 @@ private ArrowFlightClient initializeClient(Optional uri) } } + if (null == allocator) { + initializeAllocator(); + } + FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); if (config.getVerifyServer() != null && !config.getVerifyServer()) { flightClientBuilder.verifyServer(false); @@ -67,26 +73,22 @@ else if (config.getFlightServerSSLCertificate() != null) { } FlightClient flightClient = flightClientBuilder.build(); - return new ArrowFlightClient(flightClient, trustedCertificate, allocator); + return new ArrowFlightClient(flightClient, trustedCertificate); } catch (Exception ex) { throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight client could not be obtained." + ex.getMessage(), ex); } } - protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); - - /** - * Connector implementations can override this method to get a FlightDescriptor - * from command or path. - * @param flightRequest - * @return - */ - protected FlightDescriptor getFlightDescriptor(ArrowFlightRequest flightRequest) + private synchronized void initializeAllocator() { - return FlightDescriptor.command(flightRequest.getCommand()); + if (allocator == null) { + allocator = new RootAllocator(Long.MAX_VALUE); + } } + protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); + public ArrowFlightConfig getConfig() { return config; @@ -97,13 +99,12 @@ public ArrowFlightClient getClient(Optional uri) return initializeClient(uri); } - public FlightInfo getFlightInfo(ArrowFlightRequest request, ConnectorSession connectorSession) + public FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) { try (ArrowFlightClient client = getClient(Optional.empty())) { CredentialCallOption auth = this.getCallOptions(connectorSession); - FlightDescriptor descriptor = getFlightDescriptor(request); logger.debug("Fetching flight info"); - FlightInfo flightInfo = client.getFlightClient().getInfo(descriptor, auth); + FlightInfo flightInfo = client.getFlightClient().getInfo(flightDescriptor, auth); logger.debug("got flight info"); return flightInfo; } @@ -111,4 +112,16 @@ public FlightInfo getFlightInfo(ArrowFlightRequest request, ConnectorSession con throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight information could not be obtained from the flight server." + e.getMessage(), e); } } + + public Optional getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + return getFlightInfo(flightDescriptor, connectorSession).getSchemaOptional(); + } + + public void closeRootallocator() + { + if (null != allocator) { + allocator.close(); + } + } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java deleted file mode 100644 index 7e04e0a6066e3..0000000000000 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.plugin.arrow; - -public interface ArrowFlightRequest -{ - byte[] getCommand(); -} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java index b83e791f2163e..ec51125a34848 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -16,54 +16,19 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; -import com.facebook.presto.common.block.BlockBuilder; -import com.facebook.presto.common.type.CharType; -import com.facebook.presto.common.type.DateType; -import com.facebook.presto.common.type.DecimalType; -import com.facebook.presto.common.type.Decimals; -import com.facebook.presto.common.type.TimeType; -import com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.ConnectorSession; -import com.google.common.base.CharMatcher; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Ticket; -import org.apache.arrow.vector.BigIntVector; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.DateDayVector; -import org.apache.arrow.vector.DateMilliVector; -import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.Float4Vector; -import org.apache.arrow.vector.Float8Vector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.NullVector; -import org.apache.arrow.vector.SmallIntVector; -import org.apache.arrow.vector.TimeMicroVector; -import org.apache.arrow.vector.TimeMilliVector; -import org.apache.arrow.vector.TimeSecVector; -import org.apache.arrow.vector.TimeStampMicroVector; -import org.apache.arrow.vector.TimeStampMilliTZVector; -import org.apache.arrow.vector.TimeStampMilliVector; -import org.apache.arrow.vector.TimeStampSecVector; -import org.apache.arrow.vector.TinyIntVector; -import org.apache.arrow.vector.VarBinaryVector; -import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; -import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.LocalTime; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.concurrent.TimeUnit; import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; @@ -100,454 +65,6 @@ private void getFlightStream(ArrowFlightClientHandler clientHandler, byte[] tick } } - private Block buildBlockFromVector(FieldVector vector, Type type) - { - if (vector instanceof BitVector) { - return buildBlockFromBitVector((BitVector) vector, type); - } - else if (vector instanceof TinyIntVector) { - return buildBlockFromTinyIntVector((TinyIntVector) vector, type); - } - else if (vector instanceof IntVector) { - return buildBlockFromIntVector((IntVector) vector, type); - } - else if (vector instanceof SmallIntVector) { - return buildBlockFromSmallIntVector((SmallIntVector) vector, type); - } - else if (vector instanceof BigIntVector) { - return buildBlockFromBigIntVector((BigIntVector) vector, type); - } - else if (vector instanceof DecimalVector) { - return buildBlockFromDecimalVector((DecimalVector) vector, type); - } - else if (vector instanceof NullVector) { - return buildBlockFromNullVector((NullVector) vector, type); - } - else if (vector instanceof TimeStampMicroVector) { - return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); - } - else if (vector instanceof TimeStampMilliVector) { - return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); - } - else if (vector instanceof Float4Vector) { - return buildBlockFromFloat4Vector((Float4Vector) vector, type); - } - else if (vector instanceof Float8Vector) { - return buildBlockFromFloat8Vector((Float8Vector) vector, type); - } - else if (vector instanceof VarCharVector) { - if (type instanceof CharType) { - return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); - } - else if (type instanceof TimeType) { - return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); - } - else { - return buildBlockFromVarCharVector((VarCharVector) vector, type); - } - } - else if (vector instanceof VarBinaryVector) { - return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); - } - else if (vector instanceof DateDayVector) { - return buildBlockFromDateDayVector((DateDayVector) vector, type); - } - else if (vector instanceof DateMilliVector) { - return buildBlockFromDateMilliVector((DateMilliVector) vector, type); - } - else if (vector instanceof TimeMilliVector) { - return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); - } - else if (vector instanceof TimeSecVector) { - return buildBlockFromTimeSecVector((TimeSecVector) vector, type); - } - else if (vector instanceof TimeStampSecVector) { - return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); - } - else if (vector instanceof TimeMicroVector) { - return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); - } - else if (vector instanceof TimeStampMilliTZVector) { - return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); - } - throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); - } - - private Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromBitVector(BitVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeBoolean(builder, vector.get(i) == 1); - } - } - return builder.build(); - } - - private Block buildBlockFromIntVector(IntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromBigIntVector(BigIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromDecimalVector(DecimalVector vector, Type type) - { - if (!(type instanceof DecimalType)) { - throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); - } - - DecimalType decimalType = (DecimalType) type; - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value - if (decimalType.isShort()) { - builder.writeLong(decimal.unscaledValue().longValue()); - } - else { - Slice slice = Decimals.encodeScaledValue(decimal); - decimalType.writeSlice(builder, slice, 0, slice.length()); - } - } - } - return builder.build(); - } - - private Block buildBlockFromNullVector(NullVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - builder.appendNull(); - } - return builder.build(); - } - - private Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long micros = vector.get(i); - long millis = TimeUnit.MICROSECONDS.toMillis(micros); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeDouble(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - int intBits = Float.floatToIntBits(vector.get(i)); - type.writeLong(builder, intBits); - } - } - return builder.build(); - } - - private Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - byte[] value = vector.get(i); - type.writeSlice(builder, Slices.wrappedBuffer(value)); - } - } - return builder.build(); - } - - private Block buildBlockFromVarCharVector(VarCharVector vector, Type type) - { - if (!(type instanceof VarcharType)) { - throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String value = new String(vector.get(i), StandardCharsets.UTF_8); - type.writeSlice(builder, Slices.utf8Slice(value)); - } - } - return builder.build(); - } - - private Block buildBlockFromDateDayVector(DateDayVector vector, Type type) - { - if (!(type instanceof DateType)) { - throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) - { - if (!(type instanceof DateType)) { - throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - DateType dateType = (DateType) type; - long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); - dateType.writeLong(builder, days); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - int value = vector.get(i); - long millis = TimeUnit.SECONDS.toMillis(value); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); - } - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long value = vector.get(i); - long micro = TimeUnit.MICROSECONDS.toMillis(value); - type.writeLong(builder, micro); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long value = vector.get(i); - long millis = TimeUnit.SECONDS.toMillis(value); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String value = new String(vector.get(i), StandardCharsets.UTF_8); - type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); - } - } - return builder.build(); - } - - private Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String timeString = new String(vector.get(i), StandardCharsets.UTF_8); - LocalTime time = LocalTime.parse(timeString); - long millis = Duration.between(LocalTime.MIN, time).toMillis(); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - @Override public long getCompletedBytes() { @@ -605,7 +122,13 @@ public Page getNextPage() FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex); Type type = columnHandles.get(columnIndex).getColumnType(); - Block block = buildBlockFromVector(vector, type); + boolean isDictionaryBlock = vector.getField().getDictionary() != null; + Dictionary dictionary = null; + if (isDictionaryBlock) { + dictionary = flightStream.getDictionaryProvider().lookup(vector.getField().getDictionary().getId()); + } + Block block = null != dictionary ? ArrowPageUtils.buildBlockFromVector(vector, type, dictionary.getVector(), isDictionaryBlock) : + ArrowPageUtils.buildBlockFromVector(vector, type, null, false); blocks.add(block); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java new file mode 100644 index 0000000000000..04b20d00c24a8 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -0,0 +1,968 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.common.type.VarcharType; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListReader; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.JsonStringArrayList; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.LocalTime; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static java.util.Objects.requireNonNull; + +public class ArrowPageUtils +{ + private ArrowPageUtils() + { + } + + public static Block buildBlockFromVector(FieldVector vector, Type type, FieldVector dictionary, boolean isDictionaryVector) + { + if (isDictionaryVector) { + return buildBlockFromDictionaryVector(vector, dictionary); + } + else if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, type); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, type); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, type); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, type); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, type); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, type); + } + else if (vector instanceof NullVector) { + return buildBlockFromNullVector((NullVector) vector, type); + } + else if (vector instanceof TimeStampMicroVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliVector) { + return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, type); + } + else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, type); + } + else if (vector instanceof VarCharVector) { + if (type instanceof CharType) { + return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else if (type instanceof TimeType) { + return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else { + return buildBlockFromVarCharVector((VarCharVector) vector, type); + } + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, type); + } + else if (vector instanceof DateMilliVector) { + return buildBlockFromDateMilliVector((DateMilliVector) vector, type); + } + else if (vector instanceof TimeMilliVector) { + return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); + } + else if (vector instanceof TimeSecVector) { + return buildBlockFromTimeSecVector((TimeSecVector) vector, type); + } + else if (vector instanceof TimeStampSecVector) { + return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); + } + else if (vector instanceof TimeMicroVector) { + return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); + } + else if (vector instanceof ListVector) { + return buildBlockFromListVector((ListVector) vector, type); + } + + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); + } + + public static Block buildBlockFromDictionaryVector(FieldVector fieldVector, FieldVector dictionaryVector) + { + // Validate inputs + requireNonNull(fieldVector, "encoded vector is null"); + requireNonNull(dictionaryVector, "dictionary vector is null"); + + // Create a BlockBuilder for the decoded vector's data type + Type prestoType = getPrestoTypeFromArrowType(dictionaryVector.getField().getType()); + + Block dictionaryblock = null; + // Populate the block dynamically based on vector type + for (int i = 0; i < dictionaryVector.getValueCount(); i++) { + if (!dictionaryVector.isNull(i)) { + dictionaryblock = appendValueToBlock(dictionaryVector, prestoType); + } + } + + return getDictionaryBlock(fieldVector, dictionaryblock); + + // Create the Presto DictionaryBlock + } + + private static DictionaryBlock getDictionaryBlock(FieldVector fieldVector, Block dictionaryblock) + { + if (fieldVector instanceof IntVector) { + // Get the Arrow indices vector + IntVector indicesVector = (IntVector) fieldVector; + int[] ids = new int[indicesVector.getValueCount()]; + for (int i = 0; i < indicesVector.getValueCount(); i++) { + ids[i] = indicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof SmallIntVector) { + // Get the SmallInt indices vector + SmallIntVector smallIntIndicesVector = (SmallIntVector) fieldVector; + int[] ids = new int[smallIntIndicesVector.getValueCount()]; + for (int i = 0; i < smallIntIndicesVector.getValueCount(); i++) { + ids[i] = smallIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof TinyIntVector) { + // Get the TinyInt indices vector + TinyIntVector tinyIntIndicesVector = (TinyIntVector) fieldVector; + int[] ids = new int[tinyIntIndicesVector.getValueCount()]; + for (int i = 0; i < tinyIntIndicesVector.getValueCount(); i++) { + ids[i] = tinyIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else { + // Handle the case where the FieldVector is of an unsupported type + throw new IllegalArgumentException("Unsupported FieldVector type: " + fieldVector.getClass()); + } + } + + private static Type getPrestoTypeFromArrowType(ArrowType arrowType) + { + if (arrowType instanceof ArrowType.Utf8) { + return VarcharType.VARCHAR; + } + else if (arrowType instanceof ArrowType.Int) { + ArrowType.Int intType = (ArrowType.Int) arrowType; + if (intType.getBitWidth() == 8 || intType.getBitWidth() == 16 || intType.getBitWidth() == 32) { + return IntegerType.INTEGER; + } + else if (intType.getBitWidth() == 64) { + return BigintType.BIGINT; + } + } + else if (arrowType instanceof ArrowType.FloatingPoint) { + ArrowType.FloatingPoint fpType = (ArrowType.FloatingPoint) arrowType; + FloatingPointPrecision precision = fpType.getPrecision(); + + if (precision == FloatingPointPrecision.SINGLE) { // 32-bit float + return RealType.REAL; + } + else if (precision == FloatingPointPrecision.DOUBLE) { // 64-bit float + return DoubleType.DOUBLE; + } + else { + throw new UnsupportedOperationException("Unsupported FloatingPoint precision: " + precision); + } + } + else if (arrowType instanceof ArrowType.Bool) { + return BooleanType.BOOLEAN; + } + else if (arrowType instanceof ArrowType.Binary) { + return VarbinaryType.VARBINARY; + } + else if (arrowType instanceof ArrowType.Decimal) { + return DecimalType.createDecimalType(); + } + throw new UnsupportedOperationException("Unsupported ArrowType: " + arrowType); + } + + private static Block appendValueToBlock(ValueVector vector, Type prestoType) + { + if (vector instanceof VarCharVector) { + return buildBlockFromVarCharVector((VarCharVector) vector, prestoType); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, prestoType); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, prestoType); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, prestoType); + } + else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, prestoType); + } + else if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, prestoType); + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, prestoType); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, prestoType); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, prestoType); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, prestoType); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, prestoType); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, prestoType); + } + else { + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass()); + } + } + + public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromBitVector(BitVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeBoolean(builder, vector.get(i) == 1); + } + } + return builder.build(); + } + + public static Block buildBlockFromIntVector(IntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromBigIntVector(BigIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDecimalVector(DecimalVector vector, Type type) + { + if (!(type instanceof DecimalType)) { + throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); + } + + DecimalType decimalType = (DecimalType) type; + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value + if (decimalType.isShort()) { + builder.writeLong(decimal.unscaledValue().longValue()); + } + else { + Slice slice = Decimals.encodeScaledValue(decimal); + decimalType.writeSlice(builder, slice, 0, slice.length()); + } + } + } + return builder.build(); + } + + public static Block buildBlockFromNullVector(NullVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + builder.appendNull(); + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long micros = vector.get(i); + long millis = TimeUnit.MICROSECONDS.toMillis(micros); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeDouble(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int intBits = Float.floatToIntBits(vector.get(i)); + type.writeLong(builder, intBits); + } + } + return builder.build(); + } + + public static Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + byte[] value = vector.get(i); + type.writeSlice(builder, Slices.wrappedBuffer(value)); + } + } + return builder.build(); + } + + public static Block buildBlockFromVarCharVector(VarCharVector vector, Type type) + { + if (!(type instanceof VarcharType)) { + throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(value)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDateDayVector(DateDayVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + DateType dateType = (DateType) type; + long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); + dateType.writeLong(builder, days); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); + } + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long micro = TimeUnit.MICROSECONDS.toMillis(value); + type.writeLong(builder, micro); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); + } + } + return builder.build(); + } + + public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String timeString = new String(vector.get(i), StandardCharsets.UTF_8); + LocalTime time = LocalTime.parse(timeString); + long millis = Duration.between(LocalTime.MIN, time).toMillis(); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromListVector(ListVector vector, Type type) + { + if (!(type instanceof ArrayType)) { + throw new IllegalArgumentException("Type must be an ArrayType for ListVector"); + } + + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + BlockBuilder arrayBuilder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + arrayBuilder.appendNull(); + } + else { + BlockBuilder elementBuilder = arrayBuilder.beginBlockEntry(); + UnionListReader reader = vector.getReader(); + reader.setPosition(i); + + while (reader.next()) { + Object value = reader.readObject(); + if (value == null) { + elementBuilder.appendNull(); + } + else { + appendValueToBuilder(elementType, elementBuilder, value); + } + } + arrayBuilder.closeEntry(); + } + } + return arrayBuilder.build(); + } + + public static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) + { + if (value == null) { + builder.appendNull(); + return; + } + + if (type instanceof VarcharType) { + writeVarcharType(type, builder, value); + } + else if (type instanceof SmallintType) { + writeSmallintType(type, builder, value); + } + else if (type instanceof TinyintType) { + writeTinyintType(type, builder, value); + } + else if (type instanceof BigintType) { + writeBigintType(type, builder, value); + } + else if (type instanceof IntegerType) { + writeIntegerType(type, builder, value); + } + else if (type instanceof DoubleType) { + writeDoubleType(type, builder, value); + } + else if (type instanceof BooleanType) { + writeBooleanType(type, builder, value); + } + else if (type instanceof DecimalType) { + writeDecimalType((DecimalType) type, builder, value); + } + else if (type instanceof ArrayType) { + writeArrayType((ArrayType) type, builder, value); + } + else if (type instanceof RowType) { + writeRowType((RowType) type, builder, value); + } + else if (type instanceof DateType) { + writeDateType(type, builder, value); + } + else if (type instanceof TimestampType) { + writeTimestampType(type, builder, value); + } + else { + throw new IllegalArgumentException("Unsupported type: " + type); + } + } + + public static void writeVarcharType(Type type, BlockBuilder builder, Object value) + { + Slice slice = Slices.utf8Slice(value.toString()); + type.writeSlice(builder, slice); + } + + public static void writeSmallintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Number) { + type.writeLong(builder, ((Number) value).shortValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + short shortValue = Short.parseShort(obj.toString()); + type.writeLong(builder, shortValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for SmallintType: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for SmallintType: " + value.getClass()); + } + } + + public static void writeTinyintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Number) { + type.writeLong(builder, ((Number) value).byteValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + byte byteValue = Byte.parseByte(obj.toString()); + type.writeLong(builder, byteValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for TinyintType: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for TinyintType: " + value.getClass()); + } + } + + public static void writeBigintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Long) { + type.writeLong(builder, (Long) value); + } + else if (value instanceof Integer) { + type.writeLong(builder, ((Integer) value).longValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + long longValue = Long.parseLong(obj.toString()); + type.writeLong(builder, longValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); + } + } + + public static void writeIntegerType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Integer) { + type.writeLong(builder, (Integer) value); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + int intValue = Integer.parseInt(obj.toString()); + type.writeLong(builder, intValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass()); + } + } + + public static void writeDoubleType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Double) { + type.writeDouble(builder, (Double) value); + } + else if (value instanceof Float) { + type.writeDouble(builder, ((Float) value).doubleValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + double doubleValue = Double.parseDouble(obj.toString()); + type.writeDouble(builder, doubleValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } + else { + throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); + } + } + + public static void writeBooleanType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Boolean) { + type.writeBoolean(builder, (Boolean) value); + } + else { + throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); + } + } + + public static void writeDecimalType(DecimalType type, BlockBuilder builder, Object value) + { + if (value instanceof BigDecimal) { + BigDecimal decimalValue = (BigDecimal) value; + if (type.isShort()) { + // write ShortDecimalType + long unscaledValue = decimalValue.unscaledValue().longValue(); + type.writeLong(builder, unscaledValue); + } + else { + // write LongDecimalType + Slice slice = Decimals.encodeScaledValue(decimalValue); + type.writeSlice(builder, slice); + } + } + else if (value instanceof Long) { + // Direct handling for ShortDecimalType using long + if (type.isShort()) { + type.writeLong(builder, (Long) value); + } + else { + throw new IllegalArgumentException("Long value is not supported for LongDecimalType."); + } + } + else { + throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); + } + } + + public static void writeArrayType(ArrayType type, BlockBuilder builder, Object value) + { + Type elementType = type.getElementType(); + BlockBuilder arrayBuilder = builder.beginBlockEntry(); + for (Object element : (Iterable) value) { + appendValueToBuilder(elementType, arrayBuilder, element); + } + builder.closeEntry(); + } + + public static void writeRowType(RowType type, BlockBuilder builder, Object value) + { + List rowValues = (List) value; + BlockBuilder rowBuilder = builder.beginBlockEntry(); + List fields = type.getFields(); + for (int i = 0; i < fields.size(); i++) { + Type fieldType = fields.get(i).getType(); + appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); + } + builder.closeEntry(); + } + + public static void writeDateType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof java.sql.Date || value instanceof java.time.LocalDate) { + int daysSinceEpoch = (int) (value instanceof java.sql.Date + ? ((java.sql.Date) value).toLocalDate().toEpochDay() + : ((java.time.LocalDate) value).toEpochDay()); + type.writeLong(builder, daysSinceEpoch); + } + else { + throw new IllegalArgumentException("Unsupported type for DateType: " + value.getClass()); + } + } + + public static void writeTimestampType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof java.sql.Timestamp) { + long millis = ((java.sql.Timestamp) value).getTime(); + type.writeLong(builder, millis); + } + else if (value instanceof java.time.Instant) { + long millis = ((java.time.Instant) value).toEpochMilli(); + type.writeLong(builder, millis); + } + else if (value instanceof Long) { // write long epoch milliseconds directly + type.writeLong(builder, (Long) value); + } + else { + throw new IllegalArgumentException("Unsupported type for TimestampType: " + value.getClass()); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java index e0f8c6586791d..cef04e4372a5a 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.Objects; + public class ArrowTableHandle implements ConnectorTableHandle { @@ -49,4 +51,23 @@ public String toString() { return schema + ":" + table; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableHandle that = (ArrowTableHandle) o; + return Objects.equals(schema, that.schema) && Objects.equals(table, that.table); + } + + @Override + public int hashCode() + { + return Objects.hash(schema, table); + } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java index 512246683562f..46e94a4e1143c 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java @@ -20,6 +20,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; +import java.util.Objects; import static java.util.Objects.requireNonNull; @@ -32,8 +33,8 @@ public class ArrowTableLayoutHandle @JsonCreator public ArrowTableLayoutHandle(@JsonProperty("table") ArrowTableHandle table, - @JsonProperty("columnHandles") List columnHandles, - @JsonProperty("tupleDomain") TupleDomain domain) + @JsonProperty("columnHandles") List columnHandles, + @JsonProperty("tupleDomain") TupleDomain domain) { this.tableHandle = requireNonNull(table, "table is null"); this.columnHandles = requireNonNull(columnHandles, "columns are null"); @@ -63,4 +64,23 @@ public String toString() { return "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableLayoutHandle arrowTableLayoutHandle = (ArrowTableLayoutHandle) o; + return Objects.equals(tableHandle, arrowTableLayoutHandle.tableHandle) && Objects.equals(columnHandles, arrowTableLayoutHandle.columnHandles) && Objects.equals(tupleDomain, arrowTableLayoutHandle.tupleDomain); + } + + @Override + public int hashCode() + { + return Objects.hash(tableHandle, columnHandles, tupleDomain); + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java index fedf60137e496..2795a33668f6f 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java @@ -25,7 +25,6 @@ public class ArrowFlightQueryRunner { - private static DistributedQueryRunner queryRunner; private static final Logger logger = Logger.get(ArrowFlightQueryRunner.class); private ArrowFlightQueryRunner() { @@ -34,10 +33,7 @@ private ArrowFlightQueryRunner() public static DistributedQueryRunner createQueryRunner() throws Exception { - if (queryRunner == null) { - queryRunner = createQueryRunner(ImmutableMap.of(), TestingArrowFactory.class); - } - return queryRunner; + return createQueryRunner(ImmutableMap.of(), TestingArrowFactory.class); } private static DistributedQueryRunner createQueryRunner(Map catalogProperties, Class factoryClass) throws Exception @@ -47,9 +43,7 @@ private static DistributedQueryRunner createQueryRunner(Map cata .setSchema("testdb") .build(); - if (queryRunner == null) { - queryRunner = DistributedQueryRunner.builder(session).build(); - } + DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).build(); try { String connectorName = "arrow"; diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java new file mode 100644 index 0000000000000..c4ab656d41bb3 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Locale.ENGLISH; +import static org.testng.Assert.assertEquals; + +final class ArrowMetadataUtil +{ + private ArrowMetadataUtil() {} + + public static final JsonCodec COLUMN_CODEC; + public static final JsonCodec TABLE_CODEC; + + static { + JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); + provider.setJsonDeserializers(ImmutableMap.of(Type.class, new TestingTypeDeserializer())); + JsonCodecFactory codecFactory = new JsonCodecFactory(provider); + COLUMN_CODEC = codecFactory.jsonCodec(ArrowColumnHandle.class); + TABLE_CODEC = codecFactory.jsonCodec(ArrowTableHandle.class); + } + + public static final class TestingTypeDeserializer + extends FromStringDeserializer + { + private final Map types = ImmutableMap.of( + StandardTypes.BIGINT, BIGINT, + StandardTypes.VARCHAR, VARCHAR); + + public TestingTypeDeserializer() + { + super(Type.class); + } + + @Override + protected Type _deserialize(String value, DeserializationContext context) + { + Type type = types.get(value.toLowerCase(ENGLISH)); + checkArgument(type != null, "Unknown type %s", value); + return type; + } + } + + public static void assertJsonRoundTrip(JsonCodec codec, T object) + { + String json = codec.toJson(object); + T copy = codec.fromJson(json); + assertEquals(copy, object); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java new file mode 100644 index 0000000000000..432b82b69fb2d --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -0,0 +1,707 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import io.airlift.slice.Slice; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static com.facebook.plugin.arrow.ArrowPageUtils.buildBlockFromDictionaryVector; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +public class ArrowPageUtilsTest +{ + private static final int DICTIONARY_LENGTH = 10; + private static final int VECTOR_LENGTH = 50; + private BufferAllocator allocator; + + @BeforeClass + public void setUp() + { + // Initialize the Arrow allocator + allocator = new RootAllocator(Integer.MAX_VALUE); + System.out.println("Allocator initialized: " + allocator); + } + + @Test + public void testBuildBlockFromBitVector() + { + // Create a BitVector and populate it with values + BitVector bitVector = new BitVector("bitVector", allocator); + bitVector.allocateNew(3); // Allocating space for 3 elements + + bitVector.set(0, 1); // Set value to 1 (true) + bitVector.set(1, 0); // Set value to 0 (false) + bitVector.setNull(2); // Set null value + + bitVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromBitVector(bitVector, BooleanType.BOOLEAN); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromTinyIntVector() + { + // Create a TinyIntVector and populate it with values + TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator); + tinyIntVector.allocateNew(3); // Allocating space for 3 elements + tinyIntVector.set(0, 10); + tinyIntVector.set(1, 20); + tinyIntVector.setNull(2); // Set null value + + tinyIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromTinyIntVector(tinyIntVector, TinyintType.TINYINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromSmallIntVector() + { + // Create a SmallIntVector and populate it with values + SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator); + smallIntVector.allocateNew(3); // Allocating space for 3 elements + smallIntVector.set(0, 10); + smallIntVector.set(1, 20); + smallIntVector.setNull(2); // Set null value + + smallIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromSmallIntVector(smallIntVector, SmallintType.SMALLINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromIntVector() + { + // Create an IntVector and populate it with values + IntVector intVector = new IntVector("intVector", allocator); + intVector.allocateNew(3); // Allocating space for 3 elements + intVector.set(0, 10); + intVector.set(1, 20); + intVector.set(2, 30); + + intVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromIntVector(intVector, IntegerType.INTEGER); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertEquals(10, resultBlock.getInt(0)); // The 1st element should be 10 + assertEquals(20, resultBlock.getInt(1)); // The 2nd element should be 20 + assertEquals(30, resultBlock.getInt(2)); // The 3rd element should be 30 + } + + @Test + public void testBuildBlockFromBigIntVector() + throws InstantiationException, IllegalAccessException + { + // Create a BigIntVector and populate it with values + BigIntVector bigIntVector = new BigIntVector("bigIntVector", allocator); + bigIntVector.allocateNew(3); // Allocating space for 3 elements + + bigIntVector.set(0, 10L); + bigIntVector.set(1, 20L); + bigIntVector.set(2, 30L); + + bigIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromBigIntVector(bigIntVector, BigintType.BIGINT); + + // Now verify the result block + assertEquals(10L, resultBlock.getInt(0)); // The 1st element should be 10L + assertEquals(20L, resultBlock.getInt(1)); // The 2nd element should be 20L + assertEquals(30L, resultBlock.getInt(2)); // The 3rd element should be 30L + } + + @Test + public void testBuildBlockFromDecimalVector() + { + // Create a DecimalVector and populate it with values + DecimalVector decimalVector = new DecimalVector("decimalVector", allocator, 10, 2); // Precision = 10, Scale = 2 + decimalVector.allocateNew(2); // Allocating space for 2 elements + decimalVector.set(0, new BigDecimal("123.45")); + + decimalVector.setValueCount(2); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromDecimalVector(decimalVector, DecimalType.createDecimalType(10, 2)); + + // Now verify the result block + assertEquals(2, resultBlock.getPositionCount()); // Should have 2 positions + assertTrue(resultBlock.isNull(1)); // The 2nd element should be null + } + + @Test + public void testBuildBlockFromTimeStampMicroVector() + { + // Create a TimeStampMicroVector and populate it with values + TimeStampMicroVector timestampMicroVector = new TimeStampMicroVector("timestampMicroVector", allocator); + timestampMicroVector.allocateNew(3); // Allocating space for 3 elements + timestampMicroVector.set(0, 1000000L); // 1 second in microseconds + timestampMicroVector.set(1, 2000000L); // 2 seconds in microseconds + timestampMicroVector.setNull(2); // Set null value + + timestampMicroVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromTimeStampMicroVector(timestampMicroVector, TimestampType.TIMESTAMP); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + assertEquals(1000L, resultBlock.getLong(0)); // The 1st element should be 1000ms (1 second) + assertEquals(2000L, resultBlock.getLong(1)); // The 2nd element should be 2000ms (2 seconds) + } + + @Test + public void testBuildBlockFromListVector() + { + // Create a root allocator for Arrow vectors + try (BufferAllocator allocator = new RootAllocator(); + ListVector listVector = ListVector.empty("listVector", allocator)) { + // Allocate the vector and get the writer + listVector.allocateNew(); + UnionListWriter listWriter = listVector.getWriter(); + + int[] data = new int[] {1, 2, 3, 10, 20, 30, 100, 200, 300, 1000, 2000, 3000}; + int tmpIndex = 0; + + for (int i = 0; i < 4; i++) { // 4 lists to be added + listWriter.startList(); + for (int j = 0; j < 3; j++) { // Each list has 3 integers + listWriter.writeInt(data[tmpIndex]); + tmpIndex++; + } + listWriter.endList(); + } + + // Set the number of lists + listVector.setValueCount(4); + + // Create Presto ArrayType for Integer + ArrayType arrayType = new ArrayType(IntegerType.INTEGER); + + // Call the method to test + Block block = ArrowPageUtils.buildBlockFromListVector(listVector, arrayType); + + // Validate the result + assertEquals(block.getPositionCount(), 4); // 4 lists in the block + } + } + + @Test + public void testProcessDictionaryVector() + { + // Create dictionary vector + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(DICTIONARY_LENGTH); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + dictionaryVector.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); + } + dictionaryVector.setValueCount(DICTIONARY_LENGTH); + + // Create raw vector + VarCharVector rawVector = new VarCharVector("raw", allocator); + rawVector.allocateNew(VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int value = i % DICTIONARY_LENGTH; + rawVector.setSafe(i, String.valueOf(value).getBytes(StandardCharsets.UTF_8)); + } + rawVector.setValueCount(VECTOR_LENGTH); + + // Encode using dictionary + ArrowType.Int index = new ArrowType.Int(16,true); + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, index)); + BaseIntVector encodedVector = (BaseIntVector) DictionaryEncoder.encode(rawVector, dictionary); + + // Process the dictionary vector + Block result = buildBlockFromDictionaryVector(encodedVector, dictionary.getVector()); + + // Verify the result + assertNotNull(result, "The BlockBuilder should not be null."); + assertEquals(result.getPositionCount(), 50); + } + + @Test + public void testBuildBlockFromDictionaryVector() + { + IntVector indicesVector = new IntVector("indices", allocator); + indicesVector.allocateNew(3); // allocating space for 3 values + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Set up index values (this would reference the dictionary) + indicesVector.set(0, 0); // First index points to "apple" + indicesVector.set(1, 1); // Second index points to "banana" + indicesVector.set(2, 2); + indicesVector.set(3, 2); // Third index points to "cherry" + indicesVector.setValueCount(4); + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + else if (i == 3) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorBigInt() + { + BigIntVector indicesVector = new BigIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, 0L); + indicesVector.set(1, 1L); + indicesVector.set(2, 2L); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorSmallInt() + { + SmallIntVector indicesVector = new SmallIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (short) 0); + indicesVector.set(1, (short) 1); + indicesVector.set(2, (short) 2); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorTinyInt() + { + TinyIntVector indicesVector = new TinyIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (byte) 0); + indicesVector.set(1, (byte) 1); + indicesVector.set(2, (byte) 2); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testWriteVarcharType() + { + Type varcharType = VarcharType.createUnboundedVarcharType(); + BlockBuilder builder = varcharType.createBlockBuilder(null, 1); + + String value = "test_string"; + ArrowPageUtils.writeVarcharType(varcharType, builder, value); + + Block block = builder.build(); + Slice result = varcharType.getSlice(block, 0); + assertEquals(result.toStringUtf8(), value); + } + + @Test + public void testWriteSmallintType() + { + Type smallintType = SmallintType.SMALLINT; + BlockBuilder builder = smallintType.createBlockBuilder(null, 1); + + short value = 42; + ArrowPageUtils.writeSmallintType(smallintType, builder, value); + + Block block = builder.build(); + long result = smallintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteTinyintType() + { + Type tinyintType = TinyintType.TINYINT; + BlockBuilder builder = tinyintType.createBlockBuilder(null, 1); + + byte value = 7; + ArrowPageUtils.writeTinyintType(tinyintType, builder, value); + + Block block = builder.build(); + long result = tinyintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteBigintType() + { + Type bigintType = BigintType.BIGINT; + BlockBuilder builder = bigintType.createBlockBuilder(null, 1); + + long value = 123456789L; + ArrowPageUtils.writeBigintType(bigintType, builder, value); + + Block block = builder.build(); + long result = bigintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteIntegerType() + { + Type integerType = IntegerType.INTEGER; + BlockBuilder builder = integerType.createBlockBuilder(null, 1); + + int value = 42; + ArrowPageUtils.writeIntegerType(integerType, builder, value); + + Block block = builder.build(); + long result = integerType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteDoubleType() + { + Type doubleType = DoubleType.DOUBLE; + BlockBuilder builder = doubleType.createBlockBuilder(null, 1); + + double value = 42.42; + ArrowPageUtils.writeDoubleType(doubleType, builder, value); + + Block block = builder.build(); + double result = doubleType.getDouble(block, 0); + assertEquals(result, value, 0.001); + } + + @Test + public void testWriteBooleanType() + { + Type booleanType = BooleanType.BOOLEAN; + BlockBuilder builder = booleanType.createBlockBuilder(null, 1); + + boolean value = true; + ArrowPageUtils.writeBooleanType(booleanType, builder, value); + + Block block = builder.build(); + boolean result = booleanType.getBoolean(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteArrayType() + { + Type elementType = IntegerType.INTEGER; + ArrayType arrayType = new ArrayType(elementType); + BlockBuilder builder = arrayType.createBlockBuilder(null, 1); + + List values = Arrays.asList(1, 2, 3); + ArrowPageUtils.writeArrayType(arrayType, builder, values); + + Block block = builder.build(); + Block arrayBlock = arrayType.getObject(block, 0); + assertEquals(arrayBlock.getPositionCount(), values.size()); + for (int i = 0; i < values.size(); i++) { + assertEquals(elementType.getLong(arrayBlock, i), values.get(i).longValue()); + } + } + + @Test + public void testWriteRowType() + { + RowType.Field field1 = new RowType.Field(Optional.of("field1"), IntegerType.INTEGER); + RowType.Field field2 = new RowType.Field(Optional.of("field2"), VarcharType.createUnboundedVarcharType()); + RowType rowType = RowType.from(Arrays.asList(field1, field2)); + BlockBuilder builder = rowType.createBlockBuilder(null, 1); + + List rowValues = Arrays.asList(42, "test"); + ArrowPageUtils.writeRowType(rowType, builder, rowValues); + + Block block = builder.build(); + Block rowBlock = rowType.getObject(block, 0); + assertEquals(IntegerType.INTEGER.getLong(rowBlock, 0), 42); + assertEquals(VarcharType.createUnboundedVarcharType().getSlice(rowBlock, 1).toStringUtf8(), "test"); + } + + @Test + public void testWriteDateType() + { + Type dateType = DateType.DATE; + BlockBuilder builder = dateType.createBlockBuilder(null, 1); + + LocalDate value = LocalDate.of(2020, 1, 1); + ArrowPageUtils.writeDateType(dateType, builder, value); + + Block block = builder.build(); + long result = dateType.getLong(block, 0); + assertEquals(result, value.toEpochDay()); + } + + @Test + public void testWriteTimestampType() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + long value = 1609459200000L; // Jan 1, 2021, 00:00:00 UTC + ArrowPageUtils.writeTimestampType(timestampType, builder, value); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testWriteTimestampTypeWithSqlTimestamp() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + java.sql.Timestamp timestamp = java.sql.Timestamp.valueOf("2021-01-01 00:00:00"); + long expectedMillis = timestamp.getTime(); + ArrowPageUtils.writeTimestampType(timestampType, builder, timestamp); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + assertEquals(result, expectedMillis); + } + + @Test + public void testShortDecimalRetrieval() + { + DecimalType shortDecimalType = DecimalType.createDecimalType(10, 2); // Precision: 10, Scale: 2 + BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 1); + + BigDecimal decimalValue = new BigDecimal("12345.67"); + ArrowPageUtils.writeDecimalType(shortDecimalType, builder, decimalValue); + + Block block = builder.build(); + long unscaledValue = shortDecimalType.getLong(block, 0); // Unscaled value: 1234567 + BigDecimal result = BigDecimal.valueOf(unscaledValue).movePointLeft(shortDecimalType.getScale()); + assertEquals(result, decimalValue); + } + + @Test + public void testLongDecimalRetrieval() + { + // Create a DecimalType with precision 38 and scale 10 + DecimalType longDecimalType = DecimalType.createDecimalType(38, 10); + BlockBuilder builder = longDecimalType.createBlockBuilder(null, 1); + BigDecimal decimalValue = new BigDecimal("1234567890.1234567890"); + ArrowPageUtils.writeDecimalType(longDecimalType, builder, decimalValue); + // Build the block after inserting the decimal value + Block block = builder.build(); + Slice unscaledSlice = longDecimalType.getSlice(block, 0); + BigInteger unscaledValue = Decimals.decodeUnscaledValue(unscaledSlice); + BigDecimal result = new BigDecimal(unscaledValue).movePointLeft(longDecimalType.getScale()); + // Assert the decoded result is equal to the original decimal value + assertEquals(result, decimalValue); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java new file mode 100644 index 0000000000000..1d9c490180abc --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.spi.ColumnMetadata; +import org.testng.annotations.Test; + +import java.util.Locale; + +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestArrowColumnHandle +{ + @Test + public void testConstructorAndGetters() + { + // Given + String columnName = "testColumn"; + // When + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // Then + assertEquals(columnHandle.getColumnName(), columnName, "Column name should match the input"); + assertEquals(columnHandle.getColumnType(), IntegerType.INTEGER, "Column type should match the input"); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnName() + { + // Given + // When + new ArrowColumnHandle(null, IntegerType.INTEGER); // Should throw NullPointerException + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnType() + { + // Given + String columnName = "testColumn"; + + // When + new ArrowColumnHandle(columnName, null); // Should throw NullPointerException + } + + @Test + public void testGetColumnMetadata() + { + // Given + String columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // When + ColumnMetadata columnMetadata = columnHandle.getColumnMetadata(); + + // Then + assertNotNull(columnMetadata, "ColumnMetadata should not be null"); + assertEquals(columnMetadata.getName(), columnName.toLowerCase(Locale.ENGLISH), "ColumnMetadata name should match the column name"); + assertEquals(columnMetadata.getType(), IntegerType.INTEGER, "ColumnMetadata type should match the column type"); + } + + @Test + public void testToString() + { + String columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + String result = columnHandle.toString(); + String expected = columnName + ":" + IntegerType.INTEGER; + assertEquals(result, expected, "toString() should return the correct string representation"); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java new file mode 100644 index 0000000000000..1bc50b08cbd37 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; + +import java.io.File; + +public class TestArrowFlightIntegrationSmokeTest + extends AbstractTestIntegrationSmokeTest +{ + private static final Logger logger = Logger.get(TestArrowFlightIntegrationSmokeTest.class); + private static RootAllocator allocator; + private static FlightServer server; + private static Location serverLocation; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + serverLocation = Location.forGrpcTls("127.0.0.1", 9443); + server = FlightServer.builder(allocator, serverLocation, new TestingArrowServer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port " + server.getPort()); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return ArrowFlightQueryRunner.createQueryRunner(); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + server.close(); + allocator.close(); + arrowFlightQueryRunner.close(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java index 147f66718121a..21b28745cd1d3 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java @@ -19,6 +19,7 @@ import com.facebook.presto.testing.MaterializedResult; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueries; +import com.facebook.presto.tests.DistributedQueryRunner; import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; import org.apache.arrow.memory.RootAllocator; @@ -52,11 +53,13 @@ public class TestArrowFlightQueries private static RootAllocator allocator; private static FlightServer server; private static Location serverLocation; + private DistributedQueryRunner arrowFlightQueryRunner; @BeforeClass public void setup() throws Exception { + arrowFlightQueryRunner = getDistributedQueryRunner(); File certChainFile = new File("src/test/resources/server.crt"); File privateKeyFile = new File("src/test/resources/server.key"); @@ -83,6 +86,7 @@ public void close() { server.close(); allocator.close(); + arrowFlightQueryRunner.close(); } @Test diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java new file mode 100644 index 0000000000000..ea95e9fec01b0 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +public class TestArrowHandleResolver +{ + @Test + public void testGetTableHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableHandleClass(), + ArrowTableHandle.class, + "getTableHandleClass should return ArrowTableHandle class."); + } + @Test + public void testGetTableLayoutHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableLayoutHandleClass(), + ArrowTableLayoutHandle.class, + "getTableLayoutHandleClass should return ArrowTableLayoutHandle class."); + } + @Test + public void testGetColumnHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getColumnHandleClass(), + ArrowColumnHandle.class, + "getColumnHandleClass should return ArrowColumnHandle class."); + } + @Test + public void testGetSplitClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getSplitClass(), + ArrowSplit.class, + "getSplitClass should return ArrowSplit class."); + } + @Test + public void testGetTransactionHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTransactionHandleClass(), + ArrowTransactionHandle.class, + "getTransactionHandleClass should return ArrowTransactionHandle class."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java new file mode 100644 index 0000000000000..65da26254bd34 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestArrowSplit +{ + private ArrowSplit arrowSplit; + private String schemaName; + private String tableName; + private byte[] ticket; + private List locationUrls; + + @BeforeMethod + public void setUp() + { + schemaName = "testSchema"; + tableName = "testTable"; + ticket = new byte[] {1, 2, 3, 4}; + locationUrls = Arrays.asList("http://localhost:8080", "http://localhost:8081"); + + // Instantiate ArrowSplit with mock data + arrowSplit = new ArrowSplit(schemaName, tableName, ticket, locationUrls); + } + + @Test + public void testConstructorAndGetters() + { + // Test that the constructor correctly initializes fields + assertEquals(arrowSplit.getSchemaName(), schemaName, "Schema name should match."); + assertEquals(arrowSplit.getTableName(), tableName, "Table name should match."); + assertEquals(arrowSplit.getTicket(), ticket, "Ticket byte array should match."); + assertEquals(arrowSplit.getLocationUrls(), locationUrls, "Location URLs list should match."); + } + + @Test + public void testNodeSelectionStrategy() + { + // Test that the node selection strategy is NO_PREFERENCE + assertEquals(arrowSplit.getNodeSelectionStrategy(), NodeSelectionStrategy.NO_PREFERENCE, "Node selection strategy should be NO_PREFERENCE."); + } + + @Test + public void testGetPreferredNodes() + { + // Test that the preferred nodes list is empty + List preferredNodes = arrowSplit.getPreferredNodes(null); + assertNotNull(preferredNodes, "Preferred nodes list should not be null."); + assertTrue(preferredNodes.isEmpty(), "Preferred nodes list should be empty."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java new file mode 100644 index 0000000000000..2061fe5036534 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.testing.EquivalenceTester; +import org.testng.annotations.Test; + +import static com.facebook.plugin.arrow.ArrowMetadataUtil.TABLE_CODEC; +import static com.facebook.plugin.arrow.ArrowMetadataUtil.assertJsonRoundTrip; + +public class TestArrowTableHandle +{ + @Test + public void testJsonRoundTrip() + { + assertJsonRoundTrip(TABLE_CODEC, new ArrowTableHandle("schema", "table")); + } + + @Test + public void testEquivalence() + { + EquivalenceTester.equivalenceTester() + .addEquivalentGroup( + new ArrowTableHandle("tm_engine", "employees")).check(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java new file mode 100644 index 0000000000000..0ff7301c7e0ff --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +public class TestArrowTableLayoutHandle +{ + @Test + public void testConstructorAndGetters() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + assertEquals(layoutHandle.getTableHandle(), tableHandle, "Table handle mismatch."); + assertEquals(layoutHandle.getColumnHandles(), columnHandles, "Column handles mismatch."); + assertEquals(layoutHandle.getTupleDomain(), tupleDomain, "Tuple domain mismatch."); + } + + @Test + public void testToString() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", BigintType.BIGINT)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + String expectedString = "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + assertEquals(layoutHandle.toString(), expectedString, "toString output mismatch."); + } + + @Test + public void testEqualsAndHashCode() + { + ArrowTableHandle tableHandle1 = new ArrowTableHandle("schema", "table"); + ArrowTableHandle tableHandle2 = new ArrowTableHandle("schema", "different_table"); + + List columnHandles1 = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + List columnHandles2 = Collections.singletonList( + new ArrowColumnHandle("column1", IntegerType.INTEGER)); + + TupleDomain tupleDomain1 = TupleDomain.all(); + TupleDomain tupleDomain2 = TupleDomain.none(); + + ArrowTableLayoutHandle layoutHandle1 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle2 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle3 = new ArrowTableLayoutHandle(tableHandle2, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle4 = new ArrowTableLayoutHandle(tableHandle1, columnHandles2, tupleDomain1); + ArrowTableLayoutHandle layoutHandle5 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain2); + + // Test equality + assertEquals(layoutHandle1, layoutHandle2, "Handles with same attributes should be equal."); + assertNotEquals(layoutHandle1, layoutHandle3, "Handles with different tableHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle4, "Handles with different columnHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle5, "Handles with different tupleDomains should not be equal."); + assertNotEquals(layoutHandle1, null, "Handle should not be equal to null."); + assertNotEquals(layoutHandle1, new Object(), "Handle should not be equal to an object of another class."); + + // Test hash codes + assertEquals(layoutHandle1.hashCode(), layoutHandle2.hashCode(), "Equal handles should have same hash code."); + assertNotEquals(layoutHandle1.hashCode(), layoutHandle3.hashCode(), "Handles with different tableHandles should have different hash codes."); + assertNotEquals(layoutHandle1.hashCode(), layoutHandle4.hashCode(), "Handles with different columnHandles should have different hash codes."); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "table is null") + public void testConstructorNullTableHandle() + { + new ArrowTableLayoutHandle(null, Collections.emptyList(), TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "columns are null") + public void testConstructorNullColumnHandles() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), null, TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "domain is null") + public void testConstructorNullTupleDomain() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), Collections.emptyList(), null); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java index 2497a7c68ae3f..dd019a7689cec 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java @@ -21,7 +21,6 @@ import java.util.Optional; public class TestingArrowFlightRequest - implements ArrowFlightRequest { private final String schema; private final String table; @@ -74,7 +73,6 @@ public TestingRequestData build() return requestData; } - @Override public byte[] getCommand() { ObjectMapper objectMapper = new ObjectMapper(); @@ -90,25 +88,11 @@ public byte[] getCommand() private TestingConnectionProperties getConnectionProperties() { - TestingConnectionProperties properties = new TestingConnectionProperties(); - properties.database = testconfig.getDataSourceDatabase(); - properties.host = testconfig.getDataSourceHost(); - properties.port = testconfig.getDataSourcePort(); - properties.username = testconfig.getDataSourceUsername(); - properties.password = testconfig.getDataSourcePassword(); - return properties; + return new TestingConnectionProperties(testconfig.getDataSourceDatabase(), testconfig.getDataSourcePassword(), testconfig.getDataSourceHost(), testconfig.getDataSourceSSL(), testconfig.getDataSourceUsername()); } private TestingInteractionProperties createInteractionProperties() { - TestingInteractionProperties interactionProperties = new TestingInteractionProperties(); - if (getQuery().isPresent()) { - interactionProperties.setSelectStatement(getQuery().get()); - } - else { - interactionProperties.setSchema(getSchema()); - interactionProperties.setTable(getTable()); - } - return interactionProperties; + return getQuery().isPresent() ? new TestingInteractionProperties(getQuery().get(), getSchema(), getTable()) : new TestingInteractionProperties(null, getSchema(), getTable()); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java index d7c101945fa34..3fad658bbbb58 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -26,6 +26,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.collect.ImmutableList; import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.Result; import org.apache.arrow.vector.types.pojo.Field; @@ -46,26 +47,20 @@ public class TestingArrowMetadata private static final Logger logger = Logger.get(TestingArrowMetadata.class); private static final ObjectMapper objectMapper = new ObjectMapper(); private final NodeManager nodeManager; - private final TestingArrowFlightConfig testconfig; + private final TestingArrowFlightConfig testConfig; private final ArrowFlightClientHandler clientHandler; private final ArrowFlightConfig config; @Inject - public TestingArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler clientHandler, NodeManager nodeManager, TestingArrowFlightConfig testconfig) + public TestingArrowMetadata(ArrowFlightClientHandler clientHandler, NodeManager nodeManager, TestingArrowFlightConfig testConfig, ArrowFlightConfig config) { super(config, clientHandler); this.nodeManager = nodeManager; - this.testconfig = testconfig; + this.testConfig = testConfig; this.clientHandler = clientHandler; this.config = config; } - @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema) - { - return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testconfig); - } - @Override public List listSchemaNames(ConnectorSession session) { @@ -95,7 +90,7 @@ public List extractSchemaAndTableData(Optional schema, Connector { try (ArrowFlightClient client = clientHandler.getClient(Optional.empty())) { List names = new ArrayList<>(); - ArrowFlightRequest request = getArrowFlightRequest(config, schema.orElse(null)); + TestingArrowFlightRequest request = getArrowFlightRequest(schema.orElse(null)); ObjectNode rootNode = (ObjectNode) objectMapper.readTree(request.getCommand()); String modifiedQueryJson = objectMapper.writeValueAsString(rootNode); @@ -116,7 +111,7 @@ public List extractSchemaAndTableData(Optional schema, Connector } @Override - protected Type overrideFieldType(Field field, Type type) + protected Type getPrestoTypeFromArrowField(Field field) { String columnLength = field.getMetadata().get("columnLength"); int length = columnLength != null ? Integer.parseInt(columnLength) : 0; @@ -133,7 +128,7 @@ else if ("TIME".equals(nativeType)) { return TimeType.TIME; } else { - return type; + return super.getPrestoTypeFromArrowField(field); } } @@ -150,8 +145,14 @@ protected String getDataSourceSpecificTableName(ArrowFlightConfig config, String } @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table) + protected FlightDescriptor getFlightDescriptor(Optional query, String schema, String table) + { + TestingArrowFlightRequest request = new TestingArrowFlightRequest(this.config, testConfig, schema, table, query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); + } + + private TestingArrowFlightRequest getArrowFlightRequest(String schema) { - return new TestingArrowFlightRequest(config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testConfig); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java new file mode 100644 index 0000000000000..aeed01e6ca473 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowQueryRunner.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.Session; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestingArrowQueryRunner +{ + private static DistributedQueryRunner queryRunner; + private static final Logger logger = Logger.get(TestingArrowQueryRunner.class); + private TestingArrowQueryRunner() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static DistributedQueryRunner createQueryRunner() throws Exception + { + if (queryRunner == null) { + queryRunner = createQueryRunner(ImmutableMap.of(), TestingArrowFactory.class); + } + return queryRunner; + } + + private static DistributedQueryRunner createQueryRunner(Map catalogProperties, Class factoryClass) throws Exception + { + Session session = testSessionBuilder() + .setCatalog("arrow") + .setSchema("testdb") + .build(); + + if (queryRunner == null) { + queryRunner = DistributedQueryRunner.builder(session).build(); + } + + try { + String connectorName = "arrow"; + queryRunner.installPlugin(new ArrowPlugin(connectorName, new TestingArrowModule())); + + ImmutableMap.Builder properties = ImmutableMap.builder() + .putAll(catalogProperties) + .put("arrow-flight.server", "127.0.0.1") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server.port", "9443") + .put("arrow-flight.server.verify", "false"); + + queryRunner.createCatalog(connectorName, connectorName, properties.build()); + + return queryRunner; + } + catch (Exception e) { + logger.error(e); + throw new RuntimeException("Failed to create ArrowQueryRunner", e); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java index b3124d0425913..141ab4a9d24cd 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java @@ -147,6 +147,7 @@ public void getStream(CallContext callContext, Ticket ticket, ServerStreamListen @Override public void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) { + throw new UnsupportedOperationException("This operation is not supported"); } @Override @@ -222,7 +223,7 @@ else if (selectStatement != null) { @Override public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) { - return null; + throw new UnsupportedOperationException("This operation is not supported"); } @Override diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java index 5206c65717933..34694863bc277 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.NodeManager; import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightDescriptor; import javax.inject.Inject; @@ -28,7 +29,7 @@ public class TestingArrowSplitManager private final NodeManager nodeManager; @Inject - public TestingArrowSplitManager(ArrowFlightClientHandler client, TestingArrowFlightConfig testconfig, NodeManager nodeManager) + public TestingArrowSplitManager(ArrowFlightConfig config, ArrowFlightClientHandler client, TestingArrowFlightConfig testconfig, NodeManager nodeManager) { super(client); this.testconfig = testconfig; @@ -36,13 +37,14 @@ public TestingArrowSplitManager(ArrowFlightClientHandler client, TestingArrowFli } @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) + protected FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) { ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); Optional query = Optional.of(new TestingArrowQueryBuilder().buildSql(tableHandle.getSchema(), tableHandle.getTable(), tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), tableLayoutHandle.getTupleDomain())); - return new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java index 7f265036eeac0..e158fdc67707e 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java @@ -13,12 +13,23 @@ */ package com.facebook.plugin.arrow; +import javax.annotation.concurrent.Immutable; + +@Immutable public class TestingConnectionProperties { - public String database; - public String password; - public Integer port; - public String host; - public Boolean ssl; - public String username; + private final String database; + private final String password; + private final String host; + private final Boolean ssl; + private final String username; + + public TestingConnectionProperties(String database, String password, String host, Boolean ssl, String username) + { + this.database = database; + this.password = password; + this.host = host; + this.ssl = ssl; + this.username = username; + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java index 81b4ee5f7e811..8fd06660a7515 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java @@ -15,15 +15,26 @@ import com.fasterxml.jackson.annotation.JsonProperty; -public class TestingInteractionProperties +public final class TestingInteractionProperties { @JsonProperty("select_statement") - private String selectStatement; + private final String selectStatement; @JsonProperty("schema_name") - private String schema; + private final String schema; + @JsonProperty("table_name") - private String table; + private final String table; + + // Constructor to initialize the fields + public TestingInteractionProperties(String selectStatement, String schema, String table) + { + this.selectStatement = selectStatement; + this.schema = schema; + this.table = table; + } + + // Getters (no setters as the fields are final and immutable) public String getSelectStatement() { return selectStatement; @@ -39,18 +50,5 @@ public String getTable() return table; } - public void setSchema(String schema) - { - this.schema = schema; - } - - public void setSelectStatement(String selectStatement) - { - this.selectStatement = selectStatement; - } - - public void setTable(String table) - { - this.table = table; - } + // No setters as the class is immutable } diff --git a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst index eb9a62798e9ef..9f5635b9af550 100644 --- a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst +++ b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst @@ -3,7 +3,6 @@ Arrow Flight Connector ====================== This connector allows querying multiple data sources that are supported by an Arrow Flight server. -Apache Arrow enhances performance and efficiency in data-intensive applications through its columnar memory layout, zero-copy reads, vectorized execution, cross-language interoperability, rich data type support, and optimization for modern hardware. These features collectively reduce overhead, improve data processing speeds, and facilitate seamless data exchange between different systems and languages. Getting Started with base-arrow-module: Essential Abstract Methods for Developers ---------------------------------------------------------------------------------