From d6c9e0d270acd4f4f1745fef32d8d966cb644515 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 13:55:08 +0530 Subject: [PATCH 01/39] Arrow CI job --- .github/workflows/arrow-tests.yml | 79 +++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 .github/workflows/arrow-tests.yml diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml new file mode 100644 index 0000000000000..31d8f10508320 --- /dev/null +++ b/.github/workflows/arrow-tests.yml @@ -0,0 +1,79 @@ +name: arrow-tests + +on: + pull_request: + paths: + - 'presto-base-arrow-flight/**' # Trigger this workflow only when there are changes in the presto-base-arrow-flight module + +env: + CONTINUOUS_INTEGRATION: true + MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError" + MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" + MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true" + MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end" + RETRY: .github/bin/retry + +jobs: + changes: + runs-on: ubuntu-latest + permissions: + pull-requests: read + outputs: + codechange: ${{ steps.filter.outputs.codechange }} + steps: + - uses: dorny/paths-filter@v2 + id: filter + with: + filters: | + codechange: + - 'presto-base-arrow-flight/**' + + test: + runs-on: ubuntu-latest + needs: changes + strategy: + fail-fast: false + matrix: + modules: + - ":presto-base-arrow-flight" # Only run tests for the `presto-base-arrow-flight` module + + timeout-minutes: 80 + concurrency: + group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + + steps: + - uses: actions/checkout@v4 + if: needs.changes.outputs.codechange == 'true' + with: + show-progress: false + + - uses: actions/setup-java@v2 + if: needs.changes.outputs.codechange == 'true' + 
with: + distribution: 'temurin' + java-version: 8 + + - name: Cache local Maven repository + if: needs.changes.outputs.codechange == 'true' + id: cache-maven + uses: actions/cache@v2 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven-2- + + - name: Populate maven cache + if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true' + run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs + + - name: Maven Install + if: needs.changes.outputs.codechange == 'true' + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl $(echo '${{ matrix.modules }}' | cut -d' ' -f1) + + - name: Maven Tests + if: needs.changes.outputs.codechange == 'true' + run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} From 2b758d8c431c64cb8753726e75ef4101a61bf7ba Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 14:19:02 +0530 Subject: [PATCH 02/39] added verison in property file --- presto-base-arrow-flight/pom.xml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml index 47f83962cf255..ea4d3bef2dbee 100644 --- a/presto-base-arrow-flight/pom.xml +++ b/presto-base-arrow-flight/pom.xml @@ -16,6 +16,8 @@ 4.10.0 17.0.0 4.1.110.Final + 1.6.20 + 2.23.0 @@ -288,13 +290,13 @@ org.jetbrains.kotlin kotlin-stdlib-common - 1.6.20 + ${kotlin.version} com.google.errorprone error_prone_annotations - 2.23.0 + ${error_prone_annotations} From 84d4d89fd8d5ee51c559b4827054772f07f24c00 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 15:14:34 +0530 Subject: [PATCH 03/39] testing after removing hyphen --- .github/workflows/arrow-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/arrow-tests.yml 
b/.github/workflows/arrow-tests.yml index 31d8f10508320..c0ddb0049bcd3 100644 --- a/.github/workflows/arrow-tests.yml +++ b/.github/workflows/arrow-tests.yml @@ -1,4 +1,4 @@ -name: arrow-tests +name: arrow tests on: pull_request: From 153312e1657bfb55501582ec7b404881f74f321f Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 16:33:47 +0530 Subject: [PATCH 04/39] remove path to make sure CI is run during every build --- .github/workflows/arrow-tests.yml | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml index c0ddb0049bcd3..b695e808d9d13 100644 --- a/.github/workflows/arrow-tests.yml +++ b/.github/workflows/arrow-tests.yml @@ -2,8 +2,6 @@ name: arrow tests on: pull_request: - paths: - - 'presto-base-arrow-flight/**' # Trigger this workflow only when there are changes in the presto-base-arrow-flight module env: CONTINUOUS_INTEGRATION: true @@ -26,7 +24,7 @@ jobs: with: filters: | codechange: - - 'presto-base-arrow-flight/**' + - 'presto-base-arrow-flight/**' # Only trigger if changes are in `presto-base-arrow-flight` test: runs-on: ubuntu-latest @@ -43,17 +41,20 @@ jobs: cancel-in-progress: true steps: + # Checkout the code only if there are changes in the relevant files - uses: actions/checkout@v4 if: needs.changes.outputs.codechange == 'true' with: show-progress: false + # Set up Java for the build environment - uses: actions/setup-java@v2 if: needs.changes.outputs.codechange == 'true' with: distribution: 'temurin' java-version: 8 + # Cache Maven dependencies to speed up the build - name: Cache local Maven repository if: needs.changes.outputs.codechange == 'true' id: cache-maven @@ -64,16 +65,19 @@ jobs: restore-keys: | ${{ runner.os }}-maven-2- - - name: Populate maven cache + # Resolve Maven dependencies (if cache is not found) + - name: Populate Maven cache if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange 
== 'true' run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs + # Install dependencies for the target module - name: Maven Install if: needs.changes.outputs.codechange == 'true' run: | export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" - ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl $(echo '${{ matrix.modules }}' | cut -d' ' -f1) + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} + # Run Maven tests for the target module - name: Maven Tests if: needs.changes.outputs.codechange == 'true' run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} From a5979135d56502b43e68ecaaed6e93fe0aa9e716 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 16:52:47 +0530 Subject: [PATCH 05/39] changed yaml to run CI on evry pull requests and every update except doc update --- .github/workflows/arrow-tests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml index b695e808d9d13..83ad7c940122b 100644 --- a/.github/workflows/arrow-tests.yml +++ b/.github/workflows/arrow-tests.yml @@ -24,8 +24,7 @@ jobs: with: filters: | codechange: - - 'presto-base-arrow-flight/**' # Only trigger if changes are in `presto-base-arrow-flight` - + - '!presto-docs/**' test: runs-on: ubuntu-latest needs: changes From d466a13cbc7d83a74ce5aa163014276e4a1df6fe Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Wed, 20 Nov 2024 12:15:54 +0530 Subject: [PATCH 06/39] review comment fixes --- .../plugin/arrow/ArrowConnectorFactory.java | 2 +- .../plugin/arrow/ArrowPageSource.java | 486 +---------------- .../facebook/plugin/arrow/ArrowPageUtils.java | 509 ++++++++++++++++++ .../plugin/arrow/ArrowPageUtilsTest.java | 198 +++++++ .../arrow/TestingArrowFlightRequest.java | 18 +- .../plugin/arrow/TestingArrowServer.java | 3 +- .../arrow/TestingConnectionProperties.java | 21 +- .../arrow/TestingInteractionProperties.java | 
34 +- 8 files changed, 744 insertions(+), 527 deletions(-) create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java index f17317dd42b3a..e070b2c4624d6 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java @@ -46,7 +46,7 @@ public class ArrowConnectorFactory public ArrowConnectorFactory(String name, Module module, ClassLoader classLoader) { checkArgument(!isNullOrEmpty(name), "name is null or empty"); - this.name = name; + this.name = requireNonNull(name, "name is null"); this.module = requireNonNull(module, "module is null"); this.classLoader = requireNonNull(classLoader, "classLoader is null"); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java index b83e791f2163e..13ca93f4599a6 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -16,54 +16,18 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; -import com.facebook.presto.common.block.BlockBuilder; -import com.facebook.presto.common.type.CharType; -import com.facebook.presto.common.type.DateType; -import com.facebook.presto.common.type.DecimalType; -import com.facebook.presto.common.type.Decimals; -import com.facebook.presto.common.type.TimeType; -import 
com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.spi.ConnectorPageSource; import com.facebook.presto.spi.ConnectorSession; -import com.google.common.base.CharMatcher; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Ticket; -import org.apache.arrow.vector.BigIntVector; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.DateDayVector; -import org.apache.arrow.vector.DateMilliVector; -import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.Float4Vector; -import org.apache.arrow.vector.Float8Vector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.NullVector; -import org.apache.arrow.vector.SmallIntVector; -import org.apache.arrow.vector.TimeMicroVector; -import org.apache.arrow.vector.TimeMilliVector; -import org.apache.arrow.vector.TimeSecVector; -import org.apache.arrow.vector.TimeStampMicroVector; -import org.apache.arrow.vector.TimeStampMilliTZVector; -import org.apache.arrow.vector.TimeStampMilliVector; -import org.apache.arrow.vector.TimeStampSecVector; -import org.apache.arrow.vector.TinyIntVector; -import org.apache.arrow.vector.VarBinaryVector; -import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.LocalTime; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.concurrent.TimeUnit; import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; @@ -100,454 +64,6 @@ private void getFlightStream(ArrowFlightClientHandler clientHandler, byte[] tick } } - private Block 
buildBlockFromVector(FieldVector vector, Type type) - { - if (vector instanceof BitVector) { - return buildBlockFromBitVector((BitVector) vector, type); - } - else if (vector instanceof TinyIntVector) { - return buildBlockFromTinyIntVector((TinyIntVector) vector, type); - } - else if (vector instanceof IntVector) { - return buildBlockFromIntVector((IntVector) vector, type); - } - else if (vector instanceof SmallIntVector) { - return buildBlockFromSmallIntVector((SmallIntVector) vector, type); - } - else if (vector instanceof BigIntVector) { - return buildBlockFromBigIntVector((BigIntVector) vector, type); - } - else if (vector instanceof DecimalVector) { - return buildBlockFromDecimalVector((DecimalVector) vector, type); - } - else if (vector instanceof NullVector) { - return buildBlockFromNullVector((NullVector) vector, type); - } - else if (vector instanceof TimeStampMicroVector) { - return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); - } - else if (vector instanceof TimeStampMilliVector) { - return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); - } - else if (vector instanceof Float4Vector) { - return buildBlockFromFloat4Vector((Float4Vector) vector, type); - } - else if (vector instanceof Float8Vector) { - return buildBlockFromFloat8Vector((Float8Vector) vector, type); - } - else if (vector instanceof VarCharVector) { - if (type instanceof CharType) { - return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); - } - else if (type instanceof TimeType) { - return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); - } - else { - return buildBlockFromVarCharVector((VarCharVector) vector, type); - } - } - else if (vector instanceof VarBinaryVector) { - return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); - } - else if (vector instanceof DateDayVector) { - return buildBlockFromDateDayVector((DateDayVector) vector, type); - } - else if (vector instanceof 
DateMilliVector) { - return buildBlockFromDateMilliVector((DateMilliVector) vector, type); - } - else if (vector instanceof TimeMilliVector) { - return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); - } - else if (vector instanceof TimeSecVector) { - return buildBlockFromTimeSecVector((TimeSecVector) vector, type); - } - else if (vector instanceof TimeStampSecVector) { - return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); - } - else if (vector instanceof TimeMicroVector) { - return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); - } - else if (vector instanceof TimeStampMilliTZVector) { - return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); - } - throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); - } - - private Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromBitVector(BitVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeBoolean(builder, vector.get(i) == 1); - } - } - return builder.build(); - } - - private Block buildBlockFromIntVector(IntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - 
builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromBigIntVector(BigIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromDecimalVector(DecimalVector vector, Type type) - { - if (!(type instanceof DecimalType)) { - throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); - } - - DecimalType decimalType = (DecimalType) type; - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value - if (decimalType.isShort()) { - builder.writeLong(decimal.unscaledValue().longValue()); - } - else { - Slice slice = Decimals.encodeScaledValue(decimal); - decimalType.writeSlice(builder, slice, 0, slice.length()); - } - } - } - return builder.build(); - } - - private 
Block buildBlockFromNullVector(NullVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - builder.appendNull(); - } - return builder.build(); - } - - private Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long micros = vector.get(i); - long millis = TimeUnit.MICROSECONDS.toMillis(micros); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeDouble(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if 
(vector.isNull(i)) { - builder.appendNull(); - } - else { - int intBits = Float.floatToIntBits(vector.get(i)); - type.writeLong(builder, intBits); - } - } - return builder.build(); - } - - private Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - byte[] value = vector.get(i); - type.writeSlice(builder, Slices.wrappedBuffer(value)); - } - } - return builder.build(); - } - - private Block buildBlockFromVarCharVector(VarCharVector vector, Type type) - { - if (!(type instanceof VarcharType)) { - throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String value = new String(vector.get(i), StandardCharsets.UTF_8); - type.writeSlice(builder, Slices.utf8Slice(value)); - } - } - return builder.build(); - } - - private Block buildBlockFromDateDayVector(DateDayVector vector, Type type) - { - if (!(type instanceof DateType)) { - throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) - { - if (!(type instanceof DateType)) { - throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; 
i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - DateType dateType = (DateType) type; - long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); - dateType.writeLong(builder, days); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - int value = vector.get(i); - long millis = TimeUnit.SECONDS.toMillis(value); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); - } - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long value = vector.get(i); - long micro = TimeUnit.MICROSECONDS.toMillis(value); - type.writeLong(builder, micro); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) - { 
- if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long value = vector.get(i); - long millis = TimeUnit.SECONDS.toMillis(value); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String value = new String(vector.get(i), StandardCharsets.UTF_8); - type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); - } - } - return builder.build(); - } - - private Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String timeString = new String(vector.get(i), StandardCharsets.UTF_8); - LocalTime time = LocalTime.parse(timeString); - long millis = Duration.between(LocalTime.MIN, time).toMillis(); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - @Override public long getCompletedBytes() { @@ -605,7 +121,7 @@ public Page getNextPage() FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex); Type type = columnHandles.get(columnIndex).getColumnType(); - Block block = buildBlockFromVector(vector, type); + Block block = ArrowPageUtils.buildBlockFromVector(vector, type); blocks.add(block); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java new file mode 100644 index 0000000000000..071a3a302c491 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -0,0 +1,509 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; 
+import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.LocalTime; +import java.util.concurrent.TimeUnit; + +public class ArrowPageUtils +{ + private ArrowPageUtils() + { + } + static Block buildBlockFromVector(FieldVector vector, Type type) + { + if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, type); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, type); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, type); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, type); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, type); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, type); + } + else if (vector instanceof NullVector) { + return buildBlockFromNullVector((NullVector) vector, type); + } + else if (vector instanceof TimeStampMicroVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliVector) { + return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, type); + } + 
else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, type); + } + else if (vector instanceof VarCharVector) { + if (type instanceof CharType) { + return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else if (type instanceof TimeType) { + return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else { + return buildBlockFromVarCharVector((VarCharVector) vector, type); + } + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, type); + } + else if (vector instanceof DateMilliVector) { + return buildBlockFromDateMilliVector((DateMilliVector) vector, type); + } + else if (vector instanceof TimeMilliVector) { + return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); + } + else if (vector instanceof TimeSecVector) { + return buildBlockFromTimeSecVector((TimeSecVector) vector, type); + } + else if (vector instanceof TimeStampSecVector) { + return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); + } + else if (vector instanceof TimeMicroVector) { + return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); + } + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); + } + + public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { 
+ builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromBitVector(BitVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeBoolean(builder, vector.get(i) == 1); + } + } + return builder.build(); + } + + public static Block buildBlockFromIntVector(IntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromBigIntVector(BigIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block 
buildBlockFromDecimalVector(DecimalVector vector, Type type)
+    {
+        // Converts Arrow decimals into a Presto decimal block: short decimals are written as the
+        // unscaled long value, long decimals as the slice produced by Decimals.encodeScaledValue.
+        if (!(type instanceof DecimalType)) {
+            throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector");
+        }
+
+        DecimalType decimalType = (DecimalType) type;
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                BigDecimal decimal = vector.getObject(i);
+                if (decimalType.isShort()) {
+                    // Write through the type, as every other builder in this class does, instead of
+                    // calling the raw BlockBuilder method and bypassing the type's own write logic.
+                    decimalType.writeLong(builder, decimal.unscaledValue().longValue());
+                }
+                else {
+                    Slice slice = Decimals.encodeScaledValue(decimal);
+                    decimalType.writeSlice(builder, slice, 0, slice.length());
+                }
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a block containing one null per position of the given Arrow null vector.
+     */
+    public static Block buildBlockFromNullVector(NullVector vector, Type type)
+    {
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            builder.appendNull();
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a timestamp block from an Arrow microsecond timestamp vector, converting each value
+     * to milliseconds.
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link TimestampType}
+     */
+    public static Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type)
+    {
+        if (!(type instanceof TimestampType)) {
+            throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName());
+        }
+
+        long microsPerMilli = TimeUnit.MILLISECONDS.toMicros(1);
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long micros = vector.get(i);
+                // Floor instead of truncating toward zero, so a negative (pre-epoch) value with a
+                // sub-millisecond part maps to the millisecond it falls in, not the following one.
+                long millis = Math.floorDiv(micros, microsPerMilli);
+                type.writeLong(builder, millis);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a timestamp block from an Arrow millisecond timestamp vector; values are copied
+     * without conversion.
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link TimestampType}
+     */
+    public static Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type)
+    {
+        if (!(type instanceof TimestampType)) {
+            throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName());
+        }
+
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i
= 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long millis = vector.get(i);
+                type.writeLong(builder, millis);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a block from an Arrow double-precision vector, writing each value through
+     * {@code type.writeDouble}; null slots stay null.
+     */
+    public static Block buildBlockFromFloat8Vector(Float8Vector vector, Type type)
+    {
+        int valueCount = vector.getValueCount();
+        BlockBuilder output = type.createBlockBuilder(null, valueCount);
+        for (int position = 0; position < valueCount; position++) {
+            if (!vector.isNull(position)) {
+                type.writeDouble(output, vector.get(position));
+            }
+            else {
+                output.appendNull();
+            }
+        }
+        return output.build();
+    }
+
+    /**
+     * Builds a block from an Arrow single-precision vector. Each float is written as the int bit
+     * pattern returned by {@link Float#floatToIntBits(float)}, widened to long.
+     */
+    public static Block buildBlockFromFloat4Vector(Float4Vector vector, Type type)
+    {
+        int valueCount = vector.getValueCount();
+        BlockBuilder output = type.createBlockBuilder(null, valueCount);
+        for (int position = 0; position < valueCount; position++) {
+            if (!vector.isNull(position)) {
+                type.writeLong(output, Float.floatToIntBits(vector.get(position)));
+            }
+            else {
+                output.appendNull();
+            }
+        }
+        return output.build();
+    }
+
+    /**
+     * Builds a block from an Arrow variable-width binary vector, wrapping each byte array in a
+     * slice; null slots stay null.
+     */
+    public static Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type)
+    {
+        int valueCount = vector.getValueCount();
+        BlockBuilder output = type.createBlockBuilder(null, valueCount);
+        for (int position = 0; position < valueCount; position++) {
+            if (!vector.isNull(position)) {
+                type.writeSlice(output, Slices.wrappedBuffer(vector.get(position)));
+            }
+            else {
+                output.appendNull();
+            }
+        }
+        return output.build();
+    }
+
+    /**
+     * Builds a varchar block from an Arrow varchar vector. Bytes are decoded as UTF-8 into a
+     * String and re-encoded into a slice.
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link VarcharType}
+     */
+    public static Block buildBlockFromVarCharVector(VarCharVector vector, Type type)
+    {
+        boolean isVarchar = type instanceof VarcharType;
+        if (!isVarchar) {
+            throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName());
+        }
+
+        int valueCount = vector.getValueCount();
+        BlockBuilder output = type.createBlockBuilder(null, valueCount);
+        for (int position = 0; position < valueCount; position++) {
+            if (!vector.isNull(position)) {
+                String text = new String(vector.get(position), StandardCharsets.UTF_8);
+                type.writeSlice(output, Slices.utf8Slice(text));
+            }
+            else {
+                output.appendNull();
+            }
+        }
+        return output.build();
+    }
+
+    public
static Block buildBlockFromDateDayVector(DateDayVector vector, Type type)
+    {
+        // Copies the vector's day values into a date block unchanged; only DateType is accepted.
+        if (!(type instanceof DateType)) {
+            throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName());
+        }
+
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeLong(builder, vector.get(i));
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a date block from an Arrow millisecond date vector, converting each value to a
+     * whole number of days.
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link DateType}
+     */
+    public static Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type)
+    {
+        if (!(type instanceof DateType)) {
+            throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName());
+        }
+
+        DateType dateType = (DateType) type;
+        long millisPerDay = TimeUnit.DAYS.toMillis(1);
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                // Floor instead of truncating toward zero: a negative value that is not an exact
+                // multiple of one day belongs to the earlier day, not the later one.
+                long days = Math.floorDiv(vector.get(i), millisPerDay);
+                dateType.writeLong(builder, days);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a time block from an Arrow second-resolution time vector, converting each value to
+     * milliseconds.
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link TimeType}
+     */
+    public static Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type)
+    {
+        if (!(type instanceof TimeType)) {
+            throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector");
+        }
+
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                int value = vector.get(i);
+                long millis = TimeUnit.SECONDS.toMillis(value);
+                type.writeLong(builder, millis);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a time block from an Arrow millisecond-resolution time vector; values are copied
+     * without conversion.
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link TimeType}
+     */
+    public static Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type)
+    {
+        if (!(type instanceof TimeType)) {
+            throw new IllegalArgumentException("Type must be a TimeType for TimeMilliVector");
+        }
+
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i <
vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long millis = vector.get(i);
+                type.writeLong(builder, millis);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a time block from an Arrow microsecond-resolution time vector, converting each
+     * value to milliseconds (sub-millisecond precision is dropped).
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link TimeType}
+     */
+    public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type)
+    {
+        if (!(type instanceof TimeType)) {
+            throw new IllegalArgumentException("Type must be a TimeType for TimeMicroVector");
+        }
+
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long micros = vector.get(i);
+                long millis = TimeUnit.MICROSECONDS.toMillis(micros);
+                type.writeLong(builder, millis);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a timestamp block from an Arrow second-resolution timestamp vector, converting
+     * each value to milliseconds.
+     *
+     * @throws IllegalArgumentException if {@code type} is not a {@link TimestampType}
+     */
+    public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type)
+    {
+        if (!(type instanceof TimestampType)) {
+            throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector");
+        }
+
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long value = vector.get(i);
+                long millis = TimeUnit.SECONDS.toMillis(value);
+                type.writeLong(builder, millis);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a block for a CHAR column from an Arrow varchar vector. Each value is decoded as
+     * UTF-8 and its trailing space characters are removed before it is written.
+     */
+    public static Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type)
+    {
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        for (int i = 0; i < vector.getValueCount(); i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                String value = new String(vector.get(i), StandardCharsets.UTF_8);
+                type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value)));
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds a block for a TIME column from an Arrow varchar vector whose values are time
+     * strings parseable by {@code LocalTime.parse}; each is written as milliseconds since
+     * {@code LocalTime.MIN}.
+     */
+    public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type)
+    {
+        BlockBuilder builder = type.createBlockBuilder(null,
vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String timeString = new String(vector.get(i), StandardCharsets.UTF_8); + LocalTime time = LocalTime.parse(timeString); + long millis = Duration.between(LocalTime.MIN, time).toMillis(); + type.writeLong(builder, millis); + } + } + return builder.build(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java new file mode 100644 index 0000000000000..666f5fc2641cd --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.BooleanType;
+import com.facebook.presto.common.type.DecimalType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.SmallintType;
+import com.facebook.presto.common.type.TimestampType;
+import com.facebook.presto.common.type.TinyintType;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.math.BigDecimal;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Unit tests for the Arrow-vector-to-Presto-block builders in {@link ArrowPageUtils}.
+ * Every vector is created in a try-with-resources block so its Arrow buffers are released
+ * before the shared allocator is closed in {@link #tearDown()}.
+ */
+public class ArrowPageUtilsTest
+{
+    private BufferAllocator allocator;
+
+    @BeforeClass
+    public void setUp()
+    {
+        allocator = new RootAllocator(Integer.MAX_VALUE);
+    }
+
+    @AfterClass(alwaysRun = true)
+    public void tearDown()
+    {
+        if (allocator != null) {
+            allocator.close();
+            allocator = null;
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromBitVector()
+    {
+        try (BitVector bitVector = new BitVector("bitVector", allocator)) {
+            bitVector.allocateNew(3);
+            bitVector.set(0, 1); // true
+            bitVector.set(1, 0); // false
+            bitVector.setNull(2);
+            bitVector.setValueCount(3);
+
+            Block resultBlock = ArrowPageUtils.buildBlockFromBitVector(bitVector, BooleanType.BOOLEAN);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertTrue(BooleanType.BOOLEAN.getBoolean(resultBlock, 0));
+            assertFalse(BooleanType.BOOLEAN.getBoolean(resultBlock, 1));
+            assertTrue(resultBlock.isNull(2));
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromTinyIntVector()
+    {
+        try (TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator)) {
+            tinyIntVector.allocateNew(3);
+            tinyIntVector.set(0, 10);
+            tinyIntVector.set(1, 20);
+            tinyIntVector.setNull(2);
+            tinyIntVector.setValueCount(3);
+
+            Block resultBlock = ArrowPageUtils.buildBlockFromTinyIntVector(tinyIntVector, TinyintType.TINYINT);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertEquals(TinyintType.TINYINT.getLong(resultBlock, 0), 10L);
+            assertEquals(TinyintType.TINYINT.getLong(resultBlock, 1), 20L);
+            assertTrue(resultBlock.isNull(2));
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromSmallIntVector()
+    {
+        try (SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator)) {
+            smallIntVector.allocateNew(3);
+            smallIntVector.set(0, 10);
+            smallIntVector.set(1, 20);
+            smallIntVector.setNull(2);
+            smallIntVector.setValueCount(3);
+
+            Block resultBlock = ArrowPageUtils.buildBlockFromSmallIntVector(smallIntVector, SmallintType.SMALLINT);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertEquals(SmallintType.SMALLINT.getLong(resultBlock, 0), 10L);
+            assertEquals(SmallintType.SMALLINT.getLong(resultBlock, 1), 20L);
+            assertTrue(resultBlock.isNull(2));
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromIntVector()
+    {
+        try (IntVector intVector = new IntVector("intVector", allocator)) {
+            intVector.allocateNew(3);
+            intVector.set(0, 10);
+            intVector.set(1, 20);
+            intVector.set(2, 30);
+            intVector.setValueCount(3);
+
+            Block resultBlock = ArrowPageUtils.buildBlockFromIntVector(intVector, IntegerType.INTEGER);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertEquals(resultBlock.getInt(0), 10);
+            assertEquals(resultBlock.getInt(1), 20);
+            assertEquals(resultBlock.getInt(2), 30);
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromBigIntVector()
+    {
+        try (BigIntVector bigIntVector = new BigIntVector("bigIntVector", allocator)) {
+            bigIntVector.allocateNew(3);
+            bigIntVector.set(0, 10L);
+            bigIntVector.set(1, 20L);
+            bigIntVector.set(2, 30L);
+            bigIntVector.setValueCount(3);
+
+            Block resultBlock = ArrowPageUtils.buildBlockFromBigIntVector(bigIntVector, BigintType.BIGINT);
+
+            // BIGINT values are 64-bit, so read them back with getLong rather than getInt
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertEquals(resultBlock.getLong(0), 10L);
+            assertEquals(resultBlock.getLong(1), 20L);
+            assertEquals(resultBlock.getLong(2), 30L);
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromDecimalVector()
+    {
+        DecimalType decimalType = DecimalType.createDecimalType(10, 2);
+        try (DecimalVector decimalVector = new DecimalVector("decimalVector", allocator, 10, 2)) {
+            decimalVector.allocateNew(2);
+            decimalVector.set(0, new BigDecimal("123.45"));
+            // Position 1 is never set, so it is expected to read back as null
+            decimalVector.setValueCount(2);
+
+            Block resultBlock = ArrowPageUtils.buildBlockFromDecimalVector(decimalVector, decimalType);
+
+            assertEquals(resultBlock.getPositionCount(), 2);
+            assertEquals(decimalType.getLong(resultBlock, 0), 12345L); // unscaled value of 123.45
+            assertTrue(resultBlock.isNull(1));
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromTimeStampMicroVector()
+    {
+        try (TimeStampMicroVector timestampMicroVector = new TimeStampMicroVector("timestampMicroVector", allocator)) {
+            timestampMicroVector.allocateNew(3);
+            timestampMicroVector.set(0, 1000000L); // 1 second in microseconds
+            timestampMicroVector.set(1, 2000000L); // 2 seconds in microseconds
+            timestampMicroVector.setNull(2);
+            timestampMicroVector.setValueCount(3);
+
+            Block resultBlock = ArrowPageUtils.buildBlockFromTimeStampMicroVector(timestampMicroVector, TimestampType.TIMESTAMP);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertEquals(resultBlock.getLong(0), 1000L); // 1 second in milliseconds
+            assertEquals(resultBlock.getLong(1), 2000L); // 2 seconds in milliseconds
+            assertTrue(resultBlock.isNull(2));
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java
index 2497a7c68ae3f..979383fa1c854 100644
--- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java
@@ -90,25 +90,11 @@ public byte[] getCommand()
 
     private TestingConnectionProperties getConnectionProperties()
     {
-        TestingConnectionProperties properties = new TestingConnectionProperties();
-        properties.database = testconfig.getDataSourceDatabase();
-        properties.host = testconfig.getDataSourceHost();
-        properties.port = testconfig.getDataSourcePort();
-        properties.username = testconfig.getDataSourceUsername();
-        properties.password =
testconfig.getDataSourcePassword(); - return properties; + return new TestingConnectionProperties(testconfig.getDataSourceDatabase(), testconfig.getDataSourcePassword(), testconfig.getDataSourceHost(), testconfig.getDataSourceSSL(), testconfig.getDataSourceUsername()); } private TestingInteractionProperties createInteractionProperties() { - TestingInteractionProperties interactionProperties = new TestingInteractionProperties(); - if (getQuery().isPresent()) { - interactionProperties.setSelectStatement(getQuery().get()); - } - else { - interactionProperties.setSchema(getSchema()); - interactionProperties.setTable(getTable()); - } - return interactionProperties; + return getQuery().isPresent() ? new TestingInteractionProperties(getQuery().get(), getSchema(), getTable()) : new TestingInteractionProperties(null, getSchema(), getTable()); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java index b3124d0425913..141ab4a9d24cd 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java @@ -147,6 +147,7 @@ public void getStream(CallContext callContext, Ticket ticket, ServerStreamListen @Override public void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) { + throw new UnsupportedOperationException("This operation is not supported"); } @Override @@ -222,7 +223,7 @@ else if (selectStatement != null) { @Override public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) { - return null; + throw new UnsupportedOperationException("This operation is not supported"); } @Override diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java index 7f265036eeac0..707283ba5bf55 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java @@ -13,12 +13,21 @@ */ package com.facebook.plugin.arrow; -public class TestingConnectionProperties +public final class TestingConnectionProperties { - public String database; - public String password; + public final String database; + public final String password; public Integer port; - public String host; - public Boolean ssl; - public String username; + public final String host; + public final Boolean ssl; + public final String username; + + public TestingConnectionProperties(String database, String password, String host, Boolean ssl, String username) + { + this.database = database; + this.password = password; + this.host = host; + this.ssl = ssl; + this.username = username; + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java index 81b4ee5f7e811..8fd06660a7515 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java @@ -15,15 +15,26 @@ import com.fasterxml.jackson.annotation.JsonProperty; -public class TestingInteractionProperties +public final class TestingInteractionProperties { @JsonProperty("select_statement") - private String selectStatement; + private final String selectStatement; @JsonProperty("schema_name") - private String schema; + private final String schema; + @JsonProperty("table_name") - private String table; + private final String table; + + // Constructor to initialize the fields + 
public TestingInteractionProperties(String selectStatement, String schema, String table) + { + this.selectStatement = selectStatement; + this.schema = schema; + this.table = table; + } + + // Getters (no setters as the fields are final and immutable) public String getSelectStatement() { return selectStatement; @@ -39,18 +50,5 @@ public String getTable() return table; } - public void setSchema(String schema) - { - this.schema = schema; - } - - public void setSelectStatement(String selectStatement) - { - this.selectStatement = selectStatement; - } - - public void setTable(String table) - { - this.table = table; - } + // No setters as the class is immutable } From ada5038bed5d8d5d6479807db585599d80e45127 Mon Sep 17 00:00:00 2001 From: Elbin Pallimalil Date: Thu, 21 Nov 2024 10:00:52 +0530 Subject: [PATCH 07/39] Use flight descriptor instead of ArrowFlightRequest --- .../plugin/arrow/AbstractArrowMetadata.java | 9 ++++----- .../arrow/AbstractArrowSplitManager.java | 7 ++++--- .../arrow/ArrowFlightClientHandler.java | 16 ++-------------- .../plugin/arrow/ArrowFlightRequest.java | 19 ------------------- .../arrow/TestingArrowFlightRequest.java | 2 -- .../plugin/arrow/TestingArrowMetadata.java | 19 ++++++++++--------- .../arrow/TestingArrowSplitManager.java | 6 ++++-- 7 files changed, 24 insertions(+), 54 deletions(-) delete mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 05c2cc604104f..19ab1cf4ae9e2 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -130,9 +131,7 @@ protected Type getPrestoTypeFromArrowField(Field field) } } - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table); - - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema); + protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, Optional query, String schema, String table); protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName); @@ -156,10 +155,10 @@ public List getColumnsList(String schema, String table, ConnectorSession try { String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schema); String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, table); - ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), Optional.empty(), + FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), Optional.empty(), dataSourceSpecificSchemaName, dataSourceSpecificTableName); - FlightInfo flightInfo = clientHandler.getFlightInfo(request, connectorSession); + FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, connectorSession); List fields = flightInfo.getSchema().getFields(); return fields; } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java index 49d4fe91fc171..e92afe8b0a4b8 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java +++ 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import java.util.List; @@ -37,17 +38,17 @@ public AbstractArrowSplitManager(ArrowFlightClientHandler client) this.clientHandler = client; } - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); + protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); @Override public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) { ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); - ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), + FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), tableLayoutHandle); - FlightInfo flightInfo = clientHandler.getFlightInfo(request, session); + FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, session); List splits = flightInfo.getEndpoints() .stream() .map(info -> new ArrowSplit( diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index ba75cbffc088c..e38f87e05de8c 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -76,17 +76,6 @@ else if (config.getFlightServerSSLCertificate() != null) { protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); - /** - * Connector implementations can override this method to get a FlightDescriptor - * from command or path. - * @param flightRequest - * @return - */ - protected FlightDescriptor getFlightDescriptor(ArrowFlightRequest flightRequest) - { - return FlightDescriptor.command(flightRequest.getCommand()); - } - public ArrowFlightConfig getConfig() { return config; @@ -97,13 +86,12 @@ public ArrowFlightClient getClient(Optional uri) return initializeClient(uri); } - public FlightInfo getFlightInfo(ArrowFlightRequest request, ConnectorSession connectorSession) + public FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) { try (ArrowFlightClient client = getClient(Optional.empty())) { CredentialCallOption auth = this.getCallOptions(connectorSession); - FlightDescriptor descriptor = getFlightDescriptor(request); logger.debug("Fetching flight info"); - FlightInfo flightInfo = client.getFlightClient().getInfo(descriptor, auth); + FlightInfo flightInfo = client.getFlightClient().getInfo(flightDescriptor, auth); logger.debug("got flight info"); return flightInfo; } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java deleted file mode 100644 index 7e04e0a6066e3..0000000000000 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.plugin.arrow; - -public interface ArrowFlightRequest -{ - byte[] getCommand(); -} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java index 2497a7c68ae3f..70314bede4714 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java @@ -21,7 +21,6 @@ import java.util.Optional; public class TestingArrowFlightRequest - implements ArrowFlightRequest { private final String schema; private final String table; @@ -74,7 +73,6 @@ public TestingRequestData build() return requestData; } - @Override public byte[] getCommand() { ObjectMapper objectMapper = new ObjectMapper(); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java index 3a649808186e0..0425981cded2a 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -26,6 +26,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.collect.ImmutableList; import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.FlightDescriptor; import 
org.apache.arrow.flight.Result; import org.apache.arrow.vector.types.pojo.Field; @@ -60,12 +61,6 @@ public TestingArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler c this.config = config; } - @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema) - { - return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testconfig); - } - @Override public List listSchemaNames(ConnectorSession session) { @@ -95,7 +90,7 @@ public List extractSchemaAndTableData(Optional schema, Connector { try (ArrowFlightClient client = clientHandler.getClient(Optional.empty())) { List names = new ArrayList<>(); - ArrowFlightRequest request = getArrowFlightRequest(config, schema.orElse(null)); + TestingArrowFlightRequest request = getArrowFlightRequest(config, schema.orElse(null)); ObjectNode rootNode = (ObjectNode) objectMapper.readTree(request.getCommand()); String modifiedQueryJson = objectMapper.writeValueAsString(rootNode); @@ -150,8 +145,14 @@ protected String getDataSourceSpecificTableName(ArrowFlightConfig config, String } @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table) + protected FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, Optional query, String schema, String table) { - return new TestingArrowFlightRequest(config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); + } + + private TestingArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema) + { + return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testconfig); } } diff --git 
a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java index 5206c65717933..876c8c67ae032 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.NodeManager; import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightDescriptor; import javax.inject.Inject; @@ -36,13 +37,14 @@ public TestingArrowSplitManager(ArrowFlightClientHandler client, TestingArrowFli } @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) + protected FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) { ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); Optional query = Optional.of(new TestingArrowQueryBuilder().buildSql(tableHandle.getSchema(), tableHandle.getTable(), tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), tableLayoutHandle.getTupleDomain())); - return new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); } } From d3b8063dc2e9aae27f9adecef86cc363cc582e7a Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 12:14:24 +0530 Subject: [PATCH 08/39] Arrow page utils changes --- .../facebook/plugin/arrow/ArrowPageUtils.java | 154 ++++++++++++++++++ .../plugin/arrow/ArrowPageUtilsTest.java | 42 +++++ 2 files changed, 196 insertions(+) 
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 071a3a302c491..916e7cadb43fa 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -15,12 +15,21 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.common.type.CharType; import com.facebook.presto.common.type.DateType; import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; import com.google.common.base.CharMatcher; @@ -47,11 +56,15 @@ import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListReader; +import org.apache.arrow.vector.util.JsonStringArrayList; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.LocalTime; +import java.util.List; import java.util.concurrent.TimeUnit; public class ArrowPageUtils @@ -506,4 
+519,145 @@ public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Ty } return builder.build(); } + + public static Block buildBlockFromListVector(ListVector vector, Type type) + { + if (!(type instanceof ArrayType)) { + throw new IllegalArgumentException("Type must be an ArrayType for ListVector"); + } + + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + BlockBuilder arrayBuilder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + arrayBuilder.appendNull(); + } else { + BlockBuilder elementBuilder = arrayBuilder.beginBlockEntry(); + UnionListReader reader = vector.getReader(); + reader.setPosition(i); + + while (reader.next()) { + Object value = reader.readObject(); + if (value == null) { + elementBuilder.appendNull(); + } else { + appendValueToBuilder(elementType, elementBuilder, value); + } + } + arrayBuilder.closeEntry(); + } + } + return arrayBuilder.build(); + } + + private static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) { + if (value == null) { + builder.appendNull(); + return; + } + + if (type instanceof VarcharType) { + // Convert value to string and write as Varchar + Slice slice = Slices.utf8Slice(value.toString()); + type.writeSlice(builder, slice); + } else if (type instanceof BigintType) { + if (value instanceof Long) { + type.writeLong(builder, (Long) value); + } else if (value instanceof Integer) { + type.writeLong(builder, ((Integer) value).longValue()); + } else if (value instanceof JsonStringArrayList) { + JsonStringArrayList list = (JsonStringArrayList) value; + for (Object obj : list) { + try { + long longValue = Long.parseLong(obj.toString()); + type.writeLong(builder, longValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } else { + throw new 
IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); + } + } else if (type instanceof IntegerType) { + if (value instanceof Integer) { + type.writeLong(builder, (Integer) value); + } else if (value instanceof JsonStringArrayList) { + JsonStringArrayList list = (JsonStringArrayList) value; + for (Object obj : list) { + try { + int intValue = Integer.parseInt(obj.toString()); + type.writeLong(builder, intValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } else { + throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass()); + } + } else if (type instanceof DoubleType) { + if (value instanceof Double) { + type.writeDouble(builder, (Double) value); + } else if (value instanceof Float) { + type.writeDouble(builder, ((Float) value).doubleValue()); + } else if (value instanceof JsonStringArrayList) { + JsonStringArrayList list = (JsonStringArrayList) value; + for (Object obj : list) { + try { + double doubleValue = Double.parseDouble(obj.toString()); + type.writeDouble(builder, doubleValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } else { + throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); + } + } else if (type instanceof BooleanType) { + if (value instanceof Boolean) { + type.writeBoolean(builder, (Boolean) value); + } else { + throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType) type; + if (value instanceof BigDecimal) { + BigDecimal decimalValue = (BigDecimal) value; + if (decimalType.isShort()) { + builder.writeLong(decimalValue.unscaledValue().longValue()); + } else { + Slice slice = Decimals.encodeScaledValue(decimalValue); + 
decimalType.writeSlice(builder, slice); + } + } else { + throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); + } + } else if (type instanceof ArrayType) { + // Handling array types (lists) + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + BlockBuilder arrayBuilder = builder.beginBlockEntry(); + for (Object element : (Iterable) value) { + appendValueToBuilder(elementType, arrayBuilder, element); + } + builder.closeEntry(); + } else if (type instanceof RowType) { + // Handling row types (structs) + RowType rowType = (RowType) type; + List rowValues = (List) value; + BlockBuilder rowBuilder = builder.beginBlockEntry(); + List fields = rowType.getFields(); + for (int i = 0; i < fields.size(); i++) { + Type fieldType = fields.get(i).getType(); + appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); + } + builder.closeEntry(); + } else { + throw new IllegalArgumentException("Unsupported type: " + type); + } + } + + + } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 666f5fc2641cd..d41e95d372eae 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -14,6 +14,7 @@ package com.facebook.plugin.arrow; import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.common.type.DecimalType; @@ -30,6 +31,9 @@ import org.apache.arrow.vector.SmallIntVector; import org.apache.arrow.vector.TimeStampMicroVector; import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.complex.ListVector; +import 
org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -195,4 +199,42 @@ public void testBuildBlockFromTimeStampMicroVector() assertEquals(1000L, resultBlock.getLong(0)); // The 1st element should be 1000ms (1 second) assertEquals(2000L, resultBlock.getLong(1)); // The 2nd element should be 2000ms (2 seconds) } + + @Test + public void testBuildBlockFromListVector() { + // Create a root allocator for Arrow vectors + try (BufferAllocator allocator = new RootAllocator(); + ListVector listVector = ListVector.empty("listVector", allocator)) { + + // Allocate the vector and get the writer + listVector.allocateNew(); + UnionListWriter listWriter = listVector.getWriter(); + + int[] data = new int[] {1, 2, 3, 10, 20, 30, 100, 200, 300, 1000, 2000, 3000}; + int tmpIndex = 0; + + for (int i = 0; i < 4; i++) { // 4 lists to be added + listWriter.startList(); + for (int j = 0; j < 3; j++) { // Each list has 3 integers + listWriter.writeInt(data[tmpIndex]); + tmpIndex++; + } + listWriter.endList(); + } + + // Set the number of lists + listVector.setValueCount(4); + + // Create Presto ArrayType for Integer + ArrayType arrayType = new ArrayType(IntegerType.INTEGER); + + // Call the method to test + Block block = ArrowPageUtils.buildBlockFromListVector(listVector, arrayType); + + // Validate the result + Assert.assertEquals(block.getPositionCount(), 4); // 4 lists in the block + + } + } + } From 72e2f07bddaa42976fc00c445a94d95e2d6f75fc Mon Sep 17 00:00:00 2001 From: Elbin Pallimalil Date: Thu, 21 Nov 2024 10:00:52 +0530 Subject: [PATCH 09/39] Use flight descriptor instead of ArrowFlightRequest --- .../plugin/arrow/AbstractArrowMetadata.java | 9 ++++----- .../arrow/AbstractArrowSplitManager.java | 7 ++++--- .../arrow/ArrowFlightClientHandler.java | 16 ++-------------- .../plugin/arrow/ArrowFlightRequest.java | 19 ------------------- .../arrow/TestingArrowFlightRequest.java | 2 
-- .../plugin/arrow/TestingArrowMetadata.java | 19 ++++++++++--------- .../arrow/TestingArrowSplitManager.java | 6 ++++-- 7 files changed, 24 insertions(+), 54 deletions(-) delete mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 05c2cc604104f..19ab1cf4ae9e2 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -130,9 +131,7 @@ protected Type getPrestoTypeFromArrowField(Field field) } } - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table); - - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema); + protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, Optional query, String schema, String table); protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName); @@ -156,10 +155,10 @@ public List getColumnsList(String schema, String table, ConnectorSession try { String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schema); String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, table); - ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), 
Optional.empty(), + FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), Optional.empty(), dataSourceSpecificSchemaName, dataSourceSpecificTableName); - FlightInfo flightInfo = clientHandler.getFlightInfo(request, connectorSession); + FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, connectorSession); List fields = flightInfo.getSchema().getFields(); return fields; } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java index 49d4fe91fc171..e92afe8b0a4b8 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.FixedSplitSource; import com.facebook.presto.spi.connector.ConnectorSplitManager; import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import java.util.List; @@ -37,17 +38,17 @@ public AbstractArrowSplitManager(ArrowFlightClientHandler client) this.clientHandler = client; } - protected abstract ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); + protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); @Override public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) { ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); - ArrowFlightRequest request = getArrowFlightRequest(clientHandler.getConfig(), + 
FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), tableLayoutHandle); - FlightInfo flightInfo = clientHandler.getFlightInfo(request, session); + FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, session); List splits = flightInfo.getEndpoints() .stream() .map(info -> new ArrowSplit( diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index ba75cbffc088c..e38f87e05de8c 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -76,17 +76,6 @@ else if (config.getFlightServerSSLCertificate() != null) { protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); - /** - * Connector implementations can override this method to get a FlightDescriptor - * from command or path. 
- * @param flightRequest - * @return - */ - protected FlightDescriptor getFlightDescriptor(ArrowFlightRequest flightRequest) - { - return FlightDescriptor.command(flightRequest.getCommand()); - } - public ArrowFlightConfig getConfig() { return config; @@ -97,13 +86,12 @@ public ArrowFlightClient getClient(Optional uri) return initializeClient(uri); } - public FlightInfo getFlightInfo(ArrowFlightRequest request, ConnectorSession connectorSession) + public FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) { try (ArrowFlightClient client = getClient(Optional.empty())) { CredentialCallOption auth = this.getCallOptions(connectorSession); - FlightDescriptor descriptor = getFlightDescriptor(request); logger.debug("Fetching flight info"); - FlightInfo flightInfo = client.getFlightClient().getInfo(descriptor, auth); + FlightInfo flightInfo = client.getFlightClient().getInfo(flightDescriptor, auth); logger.debug("got flight info"); return flightInfo; } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java deleted file mode 100644 index 7e04e0a6066e3..0000000000000 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightRequest.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.plugin.arrow; - -public interface ArrowFlightRequest -{ - byte[] getCommand(); -} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java index 979383fa1c854..dd019a7689cec 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java @@ -21,7 +21,6 @@ import java.util.Optional; public class TestingArrowFlightRequest - implements ArrowFlightRequest { private final String schema; private final String table; @@ -74,7 +73,6 @@ public TestingRequestData build() return requestData; } - @Override public byte[] getCommand() { ObjectMapper objectMapper = new ObjectMapper(); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java index 3a649808186e0..0425981cded2a 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -26,6 +26,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.collect.ImmutableList; import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.Result; import org.apache.arrow.vector.types.pojo.Field; @@ -60,12 +61,6 @@ public TestingArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler c this.config = config; } - @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema) - { - return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testconfig); - } - @Override public List 
listSchemaNames(ConnectorSession session) { @@ -95,7 +90,7 @@ public List extractSchemaAndTableData(Optional schema, Connector { try (ArrowFlightClient client = clientHandler.getClient(Optional.empty())) { List names = new ArrayList<>(); - ArrowFlightRequest request = getArrowFlightRequest(config, schema.orElse(null)); + TestingArrowFlightRequest request = getArrowFlightRequest(config, schema.orElse(null)); ObjectNode rootNode = (ObjectNode) objectMapper.readTree(request.getCommand()); String modifiedQueryJson = objectMapper.writeValueAsString(rootNode); @@ -150,8 +145,14 @@ protected String getDataSourceSpecificTableName(ArrowFlightConfig config, String } @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, Optional query, String schema, String table) + protected FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, Optional query, String schema, String table) { - return new TestingArrowFlightRequest(config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); + } + + private TestingArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema) + { + return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testconfig); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java index 5206c65717933..876c8c67ae032 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java @@ -15,6 +15,7 @@ import com.facebook.presto.spi.NodeManager; import 
com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightDescriptor; import javax.inject.Inject; @@ -36,13 +37,14 @@ public TestingArrowSplitManager(ArrowFlightClientHandler client, TestingArrowFli } @Override - protected ArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) + protected FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle) { ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); Optional query = Optional.of(new TestingArrowQueryBuilder().buildSql(tableHandle.getSchema(), tableHandle.getTable(), tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), tableLayoutHandle.getTupleDomain())); - return new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(config, testconfig, tableHandle.getSchema(), tableHandle.getTable(), query, nodeManager.getWorkerNodes().size()); + return FlightDescriptor.command(request.getCommand()); } } From 780f0a32dd73609f9c2220572f09190a84798f2e Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 12:32:36 +0530 Subject: [PATCH 10/39] Arrow page utils changes - fixed checkstyle issues --- .../facebook/plugin/arrow/ArrowPageUtils.java | 85 ++++++++++++------- .../plugin/arrow/ArrowPageUtilsTest.java | 9 +- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 916e7cadb43fa..b367e21bce160 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -24,12 +24,9 @@ import 
com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.DoubleType; import com.facebook.presto.common.type.IntegerType; -import com.facebook.presto.common.type.RealType; import com.facebook.presto.common.type.RowType; -import com.facebook.presto.common.type.SmallintType; import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; -import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; import com.google.common.base.CharMatcher; @@ -72,6 +69,7 @@ public class ArrowPageUtils private ArrowPageUtils() { } + static Block buildBlockFromVector(FieldVector vector, Type type) { if (vector instanceof BitVector) { @@ -142,6 +140,9 @@ else if (vector instanceof TimeMicroVector) { else if (vector instanceof TimeStampMilliTZVector) { return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); } + else if (vector instanceof ListVector) { + return buildBlockFromListVector((ListVector) vector, type); + } throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); } @@ -533,7 +534,8 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) for (int i = 0; i < vector.getValueCount(); i++) { if (vector.isNull(i)) { arrayBuilder.appendNull(); - } else { + } + else { BlockBuilder elementBuilder = arrayBuilder.beginBlockEntry(); UnionListReader reader = vector.getReader(); reader.setPosition(i); @@ -542,7 +544,8 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) Object value = reader.readObject(); if (value == null) { elementBuilder.appendNull(); - } else { + } + else { appendValueToBuilder(elementType, elementBuilder, value); } } @@ -552,7 +555,8 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) return arrayBuilder.build(); } - private static void appendValueToBuilder(Type type, BlockBuilder 
builder, Object value) { + private static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) + { if (value == null) { builder.appendNull(); return; @@ -562,78 +566,98 @@ private static void appendValueToBuilder(Type type, BlockBuilder builder, Object // Convert value to string and write as Varchar Slice slice = Slices.utf8Slice(value.toString()); type.writeSlice(builder, slice); - } else if (type instanceof BigintType) { + } + else if (type instanceof BigintType) { if (value instanceof Long) { type.writeLong(builder, (Long) value); - } else if (value instanceof Integer) { + } + else if (value instanceof Integer) { type.writeLong(builder, ((Integer) value).longValue()); - } else if (value instanceof JsonStringArrayList) { + } + else if (value instanceof JsonStringArrayList) { JsonStringArrayList list = (JsonStringArrayList) value; for (Object obj : list) { try { long longValue = Long.parseLong(obj.toString()); type.writeLong(builder, longValue); - } catch (NumberFormatException e) { + } + catch (NumberFormatException e) { throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); } } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); } - } else if (type instanceof IntegerType) { + } + else if (type instanceof IntegerType) { if (value instanceof Integer) { type.writeLong(builder, (Integer) value); - } else if (value instanceof JsonStringArrayList) { + } + else if (value instanceof JsonStringArrayList) { JsonStringArrayList list = (JsonStringArrayList) value; for (Object obj : list) { try { int intValue = Integer.parseInt(obj.toString()); type.writeLong(builder, intValue); - } catch (NumberFormatException e) { + } + catch (NumberFormatException e) { throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); } } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for IntegerType: " + 
value.getClass()); } - } else if (type instanceof DoubleType) { + } + else if (type instanceof DoubleType) { if (value instanceof Double) { type.writeDouble(builder, (Double) value); - } else if (value instanceof Float) { + } + else if (value instanceof Float) { type.writeDouble(builder, ((Float) value).doubleValue()); - } else if (value instanceof JsonStringArrayList) { + } + else if (value instanceof JsonStringArrayList) { JsonStringArrayList list = (JsonStringArrayList) value; for (Object obj : list) { try { double doubleValue = Double.parseDouble(obj.toString()); type.writeDouble(builder, doubleValue); - } catch (NumberFormatException e) { + } + catch (NumberFormatException e) { throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); } } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); } - } else if (type instanceof BooleanType) { + } + else if (type instanceof BooleanType) { if (value instanceof Boolean) { type.writeBoolean(builder, (Boolean) value); - } else { + } + else { throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); } - } else if (type instanceof DecimalType) { + } + else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; if (value instanceof BigDecimal) { BigDecimal decimalValue = (BigDecimal) value; if (decimalType.isShort()) { builder.writeLong(decimalValue.unscaledValue().longValue()); - } else { + } + else { Slice slice = Decimals.encodeScaledValue(decimalValue); decimalType.writeSlice(builder, slice); } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); } - } else if (type instanceof ArrayType) { + } + else if (type instanceof ArrayType) { // Handling array types (lists) ArrayType arrayType = (ArrayType) type; Type elementType = arrayType.getElementType(); @@ -642,7 +666,8 @@ private static void 
appendValueToBuilder(Type type, BlockBuilder builder, Object appendValueToBuilder(elementType, arrayBuilder, element); } builder.closeEntry(); - } else if (type instanceof RowType) { + } + else if (type instanceof RowType) { // Handling row types (structs) RowType rowType = (RowType) type; List rowValues = (List) value; @@ -653,11 +678,9 @@ private static void appendValueToBuilder(Type type, BlockBuilder builder, Object appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); } builder.closeEntry(); - } else { + } + else { throw new IllegalArgumentException("Unsupported type: " + type); } } - - - } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index d41e95d372eae..9fbc1b4255508 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -33,7 +33,6 @@ import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListWriter; -import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -201,11 +200,11 @@ public void testBuildBlockFromTimeStampMicroVector() } @Test - public void testBuildBlockFromListVector() { + public void testBuildBlockFromListVector() + { // Create a root allocator for Arrow vectors try (BufferAllocator allocator = new RootAllocator(); ListVector listVector = ListVector.empty("listVector", allocator)) { - // Allocate the vector and get the writer listVector.allocateNew(); UnionListWriter listWriter = listVector.getWriter(); @@ -232,9 +231,7 @@ public void testBuildBlockFromListVector() { Block block = ArrowPageUtils.buildBlockFromListVector(listVector, arrayType); // Validate the result - 
Assert.assertEquals(block.getPositionCount(), 4); // 4 lists in the block - + assertEquals(block.getPositionCount(), 4); // 4 lists in the block } } - } From a2cb5c5e41f52b0e947ebda1c8133cc811bbb417 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 13:26:55 +0530 Subject: [PATCH 11/39] Review comment fixes - Root allocator and typos --- .../plugin/arrow/AbstractArrowMetadata.java | 13 ++++++------ .../facebook/plugin/arrow/ArrowConnector.java | 17 +++++++++++++--- .../plugin/arrow/ArrowFlightClient.java | 6 +----- .../arrow/ArrowFlightClientHandler.java | 20 +++++++++++++++++-- 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 19ab1cf4ae9e2..8cb86394f5bac 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -44,11 +44,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -158,8 +159,8 @@ public List getColumnsList(String schema, String table, ConnectorSession FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), Optional.empty(), dataSourceSpecificSchemaName, dataSourceSpecificTableName); - FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, connectorSession); - List fields = flightInfo.getSchema().getFields(); + Optional 
flightschema = clientHandler.getSchema(flightDescriptor, connectorSession); + List fields = flightschema.map(schema1 -> schema1.getFields()).orElse(Collections.emptyList()); return fields; } catch (Exception e) { @@ -170,7 +171,7 @@ public List getColumnsList(String schema, String table, ConnectorSession @Override public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) { - Map columns = new HashMap<>(); + Map columnHandles = new HashMap<>(); String schemaValue = ((ArrowTableHandle) tableHandle).getSchema(); String tableValue = ((ArrowTableHandle) tableHandle).getTable(); @@ -183,9 +184,9 @@ public Map getColumnHandles(ConnectorSession session, Conn logger.debug("The value of the flight columnName is:- %s", columnName); Type type = getPrestoTypeFromArrowField(field); - columns.put(columnName, new ArrowColumnHandle(columnName, type)); + columnHandles.put(columnName, new ArrowColumnHandle(columnName, type)); } - return columns; + return columnHandles; } @Override diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java index 9523bd210faa5..cf25b90b9d670 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -34,16 +34,21 @@ public class ArrowConnector private final ConnectorPageSourceProvider pageSourceProvider; private final ConnectorHandleResolver handleResolver; + private final ArrowFlightClientHandler arrowFlightClientHandler; + @Inject public ArrowConnector(ConnectorMetadata metadata, - ConnectorHandleResolver handleResolver, - ConnectorSplitManager splitManager, - ConnectorPageSourceProvider pageSourceProvider) + ConnectorHandleResolver handleResolver, + ConnectorSplitManager splitManager, + ConnectorPageSourceProvider pageSourceProvider, + + ArrowFlightClientHandler 
arrowFlightClientHandler) { this.metadata = requireNonNull(metadata, "Metadata is null"); this.handleResolver = requireNonNull(handleResolver, "Metadata is null"); this.splitManager = requireNonNull(splitManager, "SplitManager is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "PageSinkProvider is null"); + this.arrowFlightClientHandler = requireNonNull(arrowFlightClientHandler, "arrow flight handler is null"); } public Optional getHandleResolver() @@ -74,4 +79,10 @@ public ConnectorPageSourceProvider getPageSourceProvider() { return pageSourceProvider; } + + @Override + public void shutdown() + { + arrowFlightClientHandler.closeRootallocator(); + } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java index fa3fabf7b7fb2..3d12617a0839d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java @@ -14,7 +14,6 @@ package com.facebook.plugin.arrow; import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.memory.RootAllocator; import java.io.IOException; import java.io.InputStream; @@ -27,13 +26,11 @@ public class ArrowFlightClient { private final FlightClient flightClient; private final Optional trustedCertificate; - private RootAllocator allocator; - public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate, RootAllocator allocator) + public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate) { this.flightClient = requireNonNull(flightClient, "flightClient cannot be null"); this.trustedCertificate = requireNonNull(trustedCertificate, "trustedCertificate is null"); - this.allocator = requireNonNull(allocator, "allocator is null"); } public FlightClient getFlightClient() @@ -53,6 +50,5 @@ public void close() throws 
InterruptedException, IOException if (trustedCertificate.isPresent()) { trustedCertificate.get().close(); } - allocator.close(); } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index e38f87e05de8c..51704748d0677 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -21,6 +21,7 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; import java.io.FileInputStream; import java.io.InputStream; @@ -33,6 +34,8 @@ public abstract class ArrowFlightClientHandler private static final Logger logger = Logger.get(ArrowFlightClientHandler.class); private final ArrowFlightConfig config; + private RootAllocator allocator; + public ArrowFlightClientHandler(ArrowFlightConfig config) { this.config = config; @@ -41,7 +44,6 @@ public ArrowFlightClientHandler(ArrowFlightConfig config) private ArrowFlightClient initializeClient(Optional uri) { try { - RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); Optional trustedCertificate = Optional.empty(); Location location; @@ -57,6 +59,10 @@ private ArrowFlightClient initializeClient(Optional uri) } } + if (null == allocator) { + allocator = new RootAllocator(Long.MAX_VALUE); + } + FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); if (config.getVerifyServer() != null && !config.getVerifyServer()) { flightClientBuilder.verifyServer(false); @@ -67,7 +73,7 @@ else if (config.getFlightServerSSLCertificate() != null) { } FlightClient flightClient = flightClientBuilder.build(); - return new ArrowFlightClient(flightClient, trustedCertificate, 
allocator); + return new ArrowFlightClient(flightClient, trustedCertificate); } catch (Exception ex) { throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight client could not be obtained." + ex.getMessage(), ex); @@ -99,4 +105,14 @@ public FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSess throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight information could not be obtained from the flight server." + e.getMessage(), e); } } + + public Optional getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + return getFlightInfo(flightDescriptor, connectorSession).getSchemaOptional(); + } + + public void closeRootallocator() + { + allocator.close(); + } } From bc25a2e8b823a65d3434408b3b44d0f3e3cedc4b Mon Sep 17 00:00:00 2001 From: Elbin Pallimalil Date: Fri, 22 Nov 2024 16:12:05 +0530 Subject: [PATCH 12/39] Remove config getter from flight client handler --- .../com/facebook/plugin/arrow/AbstractArrowMetadata.java | 2 +- .../facebook/plugin/arrow/AbstractArrowSplitManager.java | 7 ++++--- .../facebook/plugin/arrow/ArrowFlightClientHandler.java | 5 ----- .../facebook/plugin/arrow/TestingArrowSplitManager.java | 4 ++-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 19ab1cf4ae9e2..a98a627f4c3bb 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -155,7 +155,7 @@ public List getColumnsList(String schema, String table, ConnectorSession try { String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schema); String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, table); - FlightDescriptor flightDescriptor = 
getFlightDescriptor(clientHandler.getConfig(), Optional.empty(), + FlightDescriptor flightDescriptor = getFlightDescriptor(config, Optional.empty(), dataSourceSpecificSchemaName, dataSourceSpecificTableName); FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, connectorSession); diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java index e92afe8b0a4b8..b07adcf6315ce 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowSplitManager.java @@ -32,10 +32,12 @@ public abstract class AbstractArrowSplitManager { private static final Logger logger = Logger.get(AbstractArrowSplitManager.class); private final ArrowFlightClientHandler clientHandler; + private final ArrowFlightConfig config; - public AbstractArrowSplitManager(ArrowFlightClientHandler client) + public AbstractArrowSplitManager(ArrowFlightClientHandler client, ArrowFlightConfig config) { this.clientHandler = client; + this.config = config; } protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, ArrowTableLayoutHandle tableLayoutHandle); @@ -45,8 +47,7 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand { ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; ArrowTableHandle tableHandle = tableLayoutHandle.getTableHandle(); - FlightDescriptor flightDescriptor = getFlightDescriptor(clientHandler.getConfig(), - tableLayoutHandle); + FlightDescriptor flightDescriptor = getFlightDescriptor(config, tableLayoutHandle); FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, session); List splits = flightInfo.getEndpoints() diff --git 
a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index e38f87e05de8c..7db878774065f 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -76,11 +76,6 @@ else if (config.getFlightServerSSLCertificate() != null) { protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); - public ArrowFlightConfig getConfig() - { - return config; - } - public ArrowFlightClient getClient(Optional uri) { return initializeClient(uri); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java index 876c8c67ae032..ec8f7cd2e9da1 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java @@ -29,9 +29,9 @@ public class TestingArrowSplitManager private final NodeManager nodeManager; @Inject - public TestingArrowSplitManager(ArrowFlightClientHandler client, TestingArrowFlightConfig testconfig, NodeManager nodeManager) + public TestingArrowSplitManager(ArrowFlightConfig config, ArrowFlightClientHandler client, TestingArrowFlightConfig testconfig, NodeManager nodeManager) { - super(client); + super(client, config); this.testconfig = testconfig; this.nodeManager = nodeManager; } From 0b065771618c45483a23b12ca885b3d17449ca88 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 13:55:08 +0530 Subject: [PATCH 13/39] Arrow CI job --- .github/workflows/arrow-tests.yml | 79 +++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 
.github/workflows/arrow-tests.yml diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml new file mode 100644 index 0000000000000..31d8f10508320 --- /dev/null +++ b/.github/workflows/arrow-tests.yml @@ -0,0 +1,79 @@ +name: arrow-tests + +on: + pull_request: + paths: + - 'presto-base-arrow-flight/**' # Trigger this workflow only when there are changes in the presto-base-arrow-flight module + +env: + CONTINUOUS_INTEGRATION: true + MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError" + MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" + MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true" + MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end" + RETRY: .github/bin/retry + +jobs: + changes: + runs-on: ubuntu-latest + permissions: + pull-requests: read + outputs: + codechange: ${{ steps.filter.outputs.codechange }} + steps: + - uses: dorny/paths-filter@v2 + id: filter + with: + filters: | + codechange: + - 'presto-base-arrow-flight/**' + + test: + runs-on: ubuntu-latest + needs: changes + strategy: + fail-fast: false + matrix: + modules: + - ":presto-base-arrow-flight" # Only run tests for the `presto-base-arrow-flight` module + + timeout-minutes: 80 + concurrency: + group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }} + cancel-in-progress: true + + steps: + - uses: actions/checkout@v4 + if: needs.changes.outputs.codechange == 'true' + with: + show-progress: false + + - uses: actions/setup-java@v2 + if: needs.changes.outputs.codechange == 'true' + with: + distribution: 'temurin' + java-version: 8 + + - name: Cache local Maven repository + if: needs.changes.outputs.codechange == 'true' + id: cache-maven + uses: actions/cache@v2 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }} + restore-keys: | 
+ ${{ runner.os }}-maven-2- + + - name: Populate maven cache + if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true' + run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs + + - name: Maven Install + if: needs.changes.outputs.codechange == 'true' + run: | + export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl $(echo '${{ matrix.modules }}' | cut -d' ' -f1) + + - name: Maven Tests + if: needs.changes.outputs.codechange == 'true' + run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} From bcc101b04b70d478ba4d100549b235982d402128 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 14:19:02 +0530 Subject: [PATCH 14/39] added verison in property file --- presto-base-arrow-flight/pom.xml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml index 47f83962cf255..ea4d3bef2dbee 100644 --- a/presto-base-arrow-flight/pom.xml +++ b/presto-base-arrow-flight/pom.xml @@ -16,6 +16,8 @@ 4.10.0 17.0.0 4.1.110.Final + 1.6.20 + 2.23.0 @@ -288,13 +290,13 @@ org.jetbrains.kotlin kotlin-stdlib-common - 1.6.20 + ${kotlin.version} com.google.errorprone error_prone_annotations - 2.23.0 + ${error_prone_annotations} From 9615f4c7f6590d72cd5b96773f22e4f14f7b1cfb Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 15:14:34 +0530 Subject: [PATCH 15/39] testing after removing hyphen --- .github/workflows/arrow-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml index 31d8f10508320..c0ddb0049bcd3 100644 --- a/.github/workflows/arrow-tests.yml +++ b/.github/workflows/arrow-tests.yml @@ -1,4 +1,4 @@ -name: arrow-tests +name: arrow tests on: pull_request: From e8250f0da066a4443b69851bb63d3d620e9afde2 Mon Sep 17 00:00:00 2001 
From: lithinpurushothaman Date: Mon, 18 Nov 2024 16:33:47 +0530 Subject: [PATCH 16/39] remove path to make sure CI is run during every build --- .github/workflows/arrow-tests.yml | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml index c0ddb0049bcd3..b695e808d9d13 100644 --- a/.github/workflows/arrow-tests.yml +++ b/.github/workflows/arrow-tests.yml @@ -2,8 +2,6 @@ name: arrow tests on: pull_request: - paths: - - 'presto-base-arrow-flight/**' # Trigger this workflow only when there are changes in the presto-base-arrow-flight module env: CONTINUOUS_INTEGRATION: true @@ -26,7 +24,7 @@ jobs: with: filters: | codechange: - - 'presto-base-arrow-flight/**' + - 'presto-base-arrow-flight/**' # Only trigger if changes are in `presto-base-arrow-flight` test: runs-on: ubuntu-latest @@ -43,17 +41,20 @@ jobs: cancel-in-progress: true steps: + # Checkout the code only if there are changes in the relevant files - uses: actions/checkout@v4 if: needs.changes.outputs.codechange == 'true' with: show-progress: false + # Set up Java for the build environment - uses: actions/setup-java@v2 if: needs.changes.outputs.codechange == 'true' with: distribution: 'temurin' java-version: 8 + # Cache Maven dependencies to speed up the build - name: Cache local Maven repository if: needs.changes.outputs.codechange == 'true' id: cache-maven @@ -64,16 +65,19 @@ jobs: restore-keys: | ${{ runner.os }}-maven-2- - - name: Populate maven cache + # Resolve Maven dependencies (if cache is not found) + - name: Populate Maven cache if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true' run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs + # Install dependencies for the target module - name: Maven Install if: needs.changes.outputs.codechange == 'true' run: | export 
MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" - ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl $(echo '${{ matrix.modules }}' | cut -d' ' -f1) + ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} + # Run Maven tests for the target module - name: Maven Tests if: needs.changes.outputs.codechange == 'true' run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} From 064b047dbc0e015348f82d60532b46ffae0c021b Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 18 Nov 2024 16:52:47 +0530 Subject: [PATCH 17/39] changed yaml to run CI on evry pull requests and every update except doc update --- .github/workflows/arrow-tests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml index b695e808d9d13..83ad7c940122b 100644 --- a/.github/workflows/arrow-tests.yml +++ b/.github/workflows/arrow-tests.yml @@ -24,8 +24,7 @@ jobs: with: filters: | codechange: - - 'presto-base-arrow-flight/**' # Only trigger if changes are in `presto-base-arrow-flight` - + - '!presto-docs/**' test: runs-on: ubuntu-latest needs: changes From 99d6a8fe86fa875db3deab6af7209a6ea90a2025 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Wed, 20 Nov 2024 12:15:54 +0530 Subject: [PATCH 18/39] review comment fixes --- .../plugin/arrow/ArrowConnectorFactory.java | 2 +- .../plugin/arrow/ArrowPageSource.java | 486 +---------------- .../facebook/plugin/arrow/ArrowPageUtils.java | 509 ++++++++++++++++++ .../plugin/arrow/ArrowPageUtilsTest.java | 198 +++++++ .../arrow/TestingArrowFlightRequest.java | 18 +- .../plugin/arrow/TestingArrowServer.java | 3 +- .../arrow/TestingConnectionProperties.java | 21 +- .../arrow/TestingInteractionProperties.java | 34 +- 8 files changed, 744 insertions(+), 527 deletions(-) create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java create mode 100644 
presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java index f17317dd42b3a..e070b2c4624d6 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java @@ -46,7 +46,7 @@ public class ArrowConnectorFactory public ArrowConnectorFactory(String name, Module module, ClassLoader classLoader) { checkArgument(!isNullOrEmpty(name), "name is null or empty"); - this.name = name; + this.name = requireNonNull(name, "name is null"); this.module = requireNonNull(module, "module is null"); this.classLoader = requireNonNull(classLoader, "classLoader is null"); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java index b83e791f2163e..13ca93f4599a6 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -16,54 +16,18 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.common.Page; import com.facebook.presto.common.block.Block; -import com.facebook.presto.common.block.BlockBuilder; -import com.facebook.presto.common.type.CharType; -import com.facebook.presto.common.type.DateType; -import com.facebook.presto.common.type.DecimalType; -import com.facebook.presto.common.type.Decimals; -import com.facebook.presto.common.type.TimeType; -import com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarcharType; import com.facebook.presto.spi.ConnectorPageSource; 
import com.facebook.presto.spi.ConnectorSession; -import com.google.common.base.CharMatcher; -import io.airlift.slice.Slice; -import io.airlift.slice.Slices; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Ticket; -import org.apache.arrow.vector.BigIntVector; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.DateDayVector; -import org.apache.arrow.vector.DateMilliVector; -import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.Float4Vector; -import org.apache.arrow.vector.Float8Vector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.NullVector; -import org.apache.arrow.vector.SmallIntVector; -import org.apache.arrow.vector.TimeMicroVector; -import org.apache.arrow.vector.TimeMilliVector; -import org.apache.arrow.vector.TimeSecVector; -import org.apache.arrow.vector.TimeStampMicroVector; -import org.apache.arrow.vector.TimeStampMilliTZVector; -import org.apache.arrow.vector.TimeStampMilliVector; -import org.apache.arrow.vector.TimeStampSecVector; -import org.apache.arrow.vector.TinyIntVector; -import org.apache.arrow.vector.VarBinaryVector; -import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.LocalTime; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.concurrent.TimeUnit; import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_ERROR; @@ -100,454 +64,6 @@ private void getFlightStream(ArrowFlightClientHandler clientHandler, byte[] tick } } - private Block buildBlockFromVector(FieldVector vector, Type type) - { - if (vector instanceof BitVector) { - return buildBlockFromBitVector((BitVector) vector, type); - } - else if (vector instanceof TinyIntVector) { 
- return buildBlockFromTinyIntVector((TinyIntVector) vector, type); - } - else if (vector instanceof IntVector) { - return buildBlockFromIntVector((IntVector) vector, type); - } - else if (vector instanceof SmallIntVector) { - return buildBlockFromSmallIntVector((SmallIntVector) vector, type); - } - else if (vector instanceof BigIntVector) { - return buildBlockFromBigIntVector((BigIntVector) vector, type); - } - else if (vector instanceof DecimalVector) { - return buildBlockFromDecimalVector((DecimalVector) vector, type); - } - else if (vector instanceof NullVector) { - return buildBlockFromNullVector((NullVector) vector, type); - } - else if (vector instanceof TimeStampMicroVector) { - return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); - } - else if (vector instanceof TimeStampMilliVector) { - return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); - } - else if (vector instanceof Float4Vector) { - return buildBlockFromFloat4Vector((Float4Vector) vector, type); - } - else if (vector instanceof Float8Vector) { - return buildBlockFromFloat8Vector((Float8Vector) vector, type); - } - else if (vector instanceof VarCharVector) { - if (type instanceof CharType) { - return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); - } - else if (type instanceof TimeType) { - return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); - } - else { - return buildBlockFromVarCharVector((VarCharVector) vector, type); - } - } - else if (vector instanceof VarBinaryVector) { - return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); - } - else if (vector instanceof DateDayVector) { - return buildBlockFromDateDayVector((DateDayVector) vector, type); - } - else if (vector instanceof DateMilliVector) { - return buildBlockFromDateMilliVector((DateMilliVector) vector, type); - } - else if (vector instanceof TimeMilliVector) { - return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); 
- } - else if (vector instanceof TimeSecVector) { - return buildBlockFromTimeSecVector((TimeSecVector) vector, type); - } - else if (vector instanceof TimeStampSecVector) { - return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); - } - else if (vector instanceof TimeMicroVector) { - return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); - } - else if (vector instanceof TimeStampMilliTZVector) { - return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); - } - throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); - } - - private Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromBitVector(BitVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeBoolean(builder, vector.get(i) == 1); - } - } - return builder.build(); - } - - private Block buildBlockFromIntVector(IntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) - { - BlockBuilder builder = 
type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromBigIntVector(BigIntVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromDecimalVector(DecimalVector vector, Type type) - { - if (!(type instanceof DecimalType)) { - throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector"); - } - - DecimalType decimalType = (DecimalType) type; - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value - if (decimalType.isShort()) { - builder.writeLong(decimal.unscaledValue().longValue()); - } - else { - Slice slice = Decimals.encodeScaledValue(decimal); - decimalType.writeSlice(builder, slice, 0, slice.length()); - } - } - } - return builder.build(); - } - - private Block buildBlockFromNullVector(NullVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - builder.appendNull(); - } - 
return builder.build(); - } - - private Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long micros = vector.get(i); - long millis = TimeUnit.MICROSECONDS.toMillis(micros); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeDouble(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - int intBits = Float.floatToIntBits(vector.get(i)); - type.writeLong(builder, intBits); - } - } - return builder.build(); - } - - private Block 
buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - byte[] value = vector.get(i); - type.writeSlice(builder, Slices.wrappedBuffer(value)); - } - } - return builder.build(); - } - - private Block buildBlockFromVarCharVector(VarCharVector vector, Type type) - { - if (!(type instanceof VarcharType)) { - throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String value = new String(vector.get(i), StandardCharsets.UTF_8); - type.writeSlice(builder, Slices.utf8Slice(value)); - } - } - return builder.build(); - } - - private Block buildBlockFromDateDayVector(DateDayVector vector, Type type) - { - if (!(type instanceof DateType)) { - throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - type.writeLong(builder, vector.get(i)); - } - } - return builder.build(); - } - - private Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) - { - if (!(type instanceof DateType)) { - throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - DateType dateType = (DateType) type; - long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); - 
dateType.writeLong(builder, days); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - int value = vector.get(i); - long millis = TimeUnit.SECONDS.toMillis(value); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long millis = vector.get(i); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) - { - if (!(type instanceof TimeType)) { - throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); - } - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long value = vector.get(i); - long micro = TimeUnit.MICROSECONDS.toMillis(value); - type.writeLong(builder, micro); - } - } - return builder.build(); - } - - private Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) - { - if (!(type instanceof TimestampType)) { - throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); - } - - BlockBuilder builder = type.createBlockBuilder(null, 
vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - long value = vector.get(i); - long millis = TimeUnit.SECONDS.toMillis(value); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - - private Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String value = new String(vector.get(i), StandardCharsets.UTF_8); - type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); - } - } - return builder.build(); - } - - private Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) - { - BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); - for (int i = 0; i < vector.getValueCount(); i++) { - if (vector.isNull(i)) { - builder.appendNull(); - } - else { - String timeString = new String(vector.get(i), StandardCharsets.UTF_8); - LocalTime time = LocalTime.parse(timeString); - long millis = Duration.between(LocalTime.MIN, time).toMillis(); - type.writeLong(builder, millis); - } - } - return builder.build(); - } - @Override public long getCompletedBytes() { @@ -605,7 +121,7 @@ public Page getNextPage() FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex); Type type = columnHandles.get(columnIndex).getColumnType(); - Block block = buildBlockFromVector(vector, type); + Block block = ArrowPageUtils.buildBlockFromVector(vector, type); blocks.add(block); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java new file mode 100644 index 0000000000000..071a3a302c491 --- /dev/null +++ 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -0,0 +1,509 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeSecVector; +import 
org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.LocalTime; +import java.util.concurrent.TimeUnit; + +public class ArrowPageUtils +{ + private ArrowPageUtils() + { + } + static Block buildBlockFromVector(FieldVector vector, Type type) + { + if (vector instanceof BitVector) { + return buildBlockFromBitVector((BitVector) vector, type); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, type); + } + else if (vector instanceof IntVector) { + return buildBlockFromIntVector((IntVector) vector, type); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, type); + } + else if (vector instanceof BigIntVector) { + return buildBlockFromBigIntVector((BigIntVector) vector, type); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, type); + } + else if (vector instanceof NullVector) { + return buildBlockFromNullVector((NullVector) vector, type); + } + else if (vector instanceof TimeStampMicroVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliVector) { + return buildBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type); + } + else if (vector instanceof Float4Vector) { + return buildBlockFromFloat4Vector((Float4Vector) vector, type); + } + else if (vector instanceof Float8Vector) { + return buildBlockFromFloat8Vector((Float8Vector) vector, type); + } + else if (vector instanceof 
VarCharVector) { + if (type instanceof CharType) { + return buildCharTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else if (type instanceof TimeType) { + return buildTimeTypeBlockFromVarcharVector((VarCharVector) vector, type); + } + else { + return buildBlockFromVarCharVector((VarCharVector) vector, type); + } + } + else if (vector instanceof VarBinaryVector) { + return buildBlockFromVarBinaryVector((VarBinaryVector) vector, type); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, type); + } + else if (vector instanceof DateMilliVector) { + return buildBlockFromDateMilliVector((DateMilliVector) vector, type); + } + else if (vector instanceof TimeMilliVector) { + return buildBlockFromTimeMilliVector((TimeMilliVector) vector, type); + } + else if (vector instanceof TimeSecVector) { + return buildBlockFromTimeSecVector((TimeSecVector) vector, type); + } + else if (vector instanceof TimeStampSecVector) { + return buildBlockFromTimeStampSecVector((TimeStampSecVector) vector, type); + } + else if (vector instanceof TimeMicroVector) { + return buildBlockFromTimeMicroVector((TimeMicroVector) vector, type); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); + } + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); + } + + public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + type.writeLong(builder, millis); + } + } + return builder.build(); + } 
+ + public static Block buildBlockFromBitVector(BitVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeBoolean(builder, vector.get(i) == 1); + } + } + return builder.build(); + } + + public static Block buildBlockFromIntVector(IntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromSmallIntVector(SmallIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromTinyIntVector(TinyIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromBigIntVector(BigIntVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDecimalVector(DecimalVector vector, Type type) + { + if (!(type instanceof DecimalType)) { + throw new IllegalArgumentException("Type must be a 
DecimalType for DecimalVector"); + } + + DecimalType decimalType = (DecimalType) type; + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value + if (decimalType.isShort()) { + builder.writeLong(decimal.unscaledValue().longValue()); + } + else { + Slice slice = Decimals.encodeScaledValue(decimal); + decimalType.writeSlice(builder, slice, 0, slice.length()); + } + } + } + return builder.build(); + } + + public static Block buildBlockFromNullVector(NullVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + builder.appendNull(); + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long micros = vector.get(i); + long millis = TimeUnit.MICROSECONDS.toMillis(micros); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + 
type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromFloat8Vector(Float8Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeDouble(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromFloat4Vector(Float4Vector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int intBits = Float.floatToIntBits(vector.get(i)); + type.writeLong(builder, intBits); + } + } + return builder.build(); + } + + public static Block buildBlockFromVarBinaryVector(VarBinaryVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + byte[] value = vector.get(i); + type.writeSlice(builder, Slices.wrappedBuffer(value)); + } + } + return builder.build(); + } + + public static Block buildBlockFromVarCharVector(VarCharVector vector, Type type) + { + if (!(type instanceof VarcharType)) { + throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(value)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDateDayVector(DateDayVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new 
IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + type.writeLong(builder, vector.get(i)); + } + } + return builder.build(); + } + + public static Block buildBlockFromDateMilliVector(DateMilliVector vector, Type type) + { + if (!(type instanceof DateType)) { + throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName()); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + DateType dateType = (DateType) type; + long days = TimeUnit.MILLISECONDS.toDays(vector.get(i)); + dateType.writeLong(builder, days); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeSecVector(TimeSecVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + int value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeMilliVector(TimeMilliVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long millis = vector.get(i); + 
type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type type) + { + if (!(type instanceof TimeType)) { + throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector"); + } + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long micro = TimeUnit.MICROSECONDS.toMillis(value); + type.writeLong(builder, micro); + } + } + return builder.build(); + } + + public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) + { + if (!(type instanceof TimestampType)) { + throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector"); + } + + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + long value = vector.get(i); + long millis = TimeUnit.SECONDS.toMillis(value); + type.writeLong(builder, millis); + } + } + return builder.build(); + } + + public static Block buildCharTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + builder.appendNull(); + } + else { + String value = new String(vector.get(i), StandardCharsets.UTF_8); + type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value))); + } + } + return builder.build(); + } + + public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type) + { + BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount()); + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + 
builder.appendNull(); + } + else { + String timeString = new String(vector.get(i), StandardCharsets.UTF_8); + LocalTime time = LocalTime.parse(timeString); + long millis = Duration.between(LocalTime.MIN, time).toMillis(); + type.writeLong(builder, millis); + } + } + return builder.build(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java new file mode 100644 index 0000000000000..666f5fc2641cd --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TinyIntVector; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.math.BigDecimal; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class ArrowPageUtilsTest +{ + private BufferAllocator allocator; + + @BeforeClass + public void setUp() + { + // Initialize the Arrow allocator + allocator = new RootAllocator(Integer.MAX_VALUE); + System.out.println("Allocator initialized: " + allocator); + } + + @Test + public void testBuildBlockFromBitVector() + { + // Create a BitVector and populate it with values + BitVector bitVector = new BitVector("bitVector", allocator); + bitVector.allocateNew(3); // Allocating space for 3 elements + + bitVector.set(0, 1); // Set value to 1 (true) + bitVector.set(1, 0); // Set value to 0 (false) + bitVector.setNull(2); // Set null value + + bitVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromBitVector(bitVector, BooleanType.BOOLEAN); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // 
Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromTinyIntVector() + { + // Create a TinyIntVector and populate it with values + TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator); + tinyIntVector.allocateNew(3); // Allocating space for 3 elements + tinyIntVector.set(0, 10); + tinyIntVector.set(1, 20); + tinyIntVector.setNull(2); // Set null value + + tinyIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromTinyIntVector(tinyIntVector, TinyintType.TINYINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromSmallIntVector() + { + // Create a SmallIntVector and populate it with values + SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator); + smallIntVector.allocateNew(3); // Allocating space for 3 elements + smallIntVector.set(0, 10); + smallIntVector.set(1, 20); + smallIntVector.setNull(2); // Set null value + + smallIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromSmallIntVector(smallIntVector, SmallintType.SMALLINT); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + } + + @Test + public void testBuildBlockFromIntVector() + { + // Create an IntVector and populate it with values + IntVector intVector = new IntVector("intVector", allocator); + intVector.allocateNew(3); // Allocating space for 3 elements + intVector.set(0, 10); + intVector.set(1, 20); + intVector.set(2, 30); + + intVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = 
ArrowPageUtils.buildBlockFromIntVector(intVector, IntegerType.INTEGER); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertEquals(10, resultBlock.getInt(0)); // The 1st element should be 10 + assertEquals(20, resultBlock.getInt(1)); // The 2nd element should be 20 + assertEquals(30, resultBlock.getInt(2)); // The 3rd element should be 30 + } + + @Test + public void testBuildBlockFromBigIntVector() + throws InstantiationException, IllegalAccessException + { + // Create a BigIntVector and populate it with values + BigIntVector bigIntVector = new BigIntVector("bigIntVector", allocator); + bigIntVector.allocateNew(3); // Allocating space for 3 elements + + bigIntVector.set(0, 10L); + bigIntVector.set(1, 20L); + bigIntVector.set(2, 30L); + + bigIntVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromBigIntVector(bigIntVector, BigintType.BIGINT); + + // Now verify the result block + assertEquals(10L, resultBlock.getInt(0)); // The 1st element should be 10L + assertEquals(20L, resultBlock.getInt(1)); // The 2nd element should be 20L + assertEquals(30L, resultBlock.getInt(2)); // The 3rd element should be 30L + } + + @Test + public void testBuildBlockFromDecimalVector() + { + // Create a DecimalVector and populate it with values + DecimalVector decimalVector = new DecimalVector("decimalVector", allocator, 10, 2); // Precision = 10, Scale = 2 + decimalVector.allocateNew(2); // Allocating space for 2 elements + decimalVector.set(0, new BigDecimal("123.45")); + + decimalVector.setValueCount(2); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromDecimalVector(decimalVector, DecimalType.createDecimalType(10, 2)); + + // Now verify the result block + assertEquals(2, resultBlock.getPositionCount()); // Should have 2 positions + assertTrue(resultBlock.isNull(1)); // The 2nd element should be null + } + + @Test + 
public void testBuildBlockFromTimeStampMicroVector() + { + // Create a TimeStampMicroVector and populate it with values + TimeStampMicroVector timestampMicroVector = new TimeStampMicroVector("timestampMicroVector", allocator); + timestampMicroVector.allocateNew(3); // Allocating space for 3 elements + timestampMicroVector.set(0, 1000000L); // 1 second in microseconds + timestampMicroVector.set(1, 2000000L); // 2 seconds in microseconds + timestampMicroVector.setNull(2); // Set null value + + timestampMicroVector.setValueCount(3); + + // Build the block from the vector + Block resultBlock = ArrowPageUtils.buildBlockFromTimeStampMicroVector(timestampMicroVector, TimestampType.TIMESTAMP); + + // Now verify the result block + assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions + assertTrue(resultBlock.isNull(2)); // The 3rd element should be null + assertEquals(1000L, resultBlock.getLong(0)); // The 1st element should be 1000ms (1 second) + assertEquals(2000L, resultBlock.getLong(1)); // The 2nd element should be 2000ms (2 seconds) + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java index 70314bede4714..dd019a7689cec 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowFlightRequest.java @@ -88,25 +88,11 @@ public byte[] getCommand() private TestingConnectionProperties getConnectionProperties() { - TestingConnectionProperties properties = new TestingConnectionProperties(); - properties.database = testconfig.getDataSourceDatabase(); - properties.host = testconfig.getDataSourceHost(); - properties.port = testconfig.getDataSourcePort(); - properties.username = testconfig.getDataSourceUsername(); - properties.password = 
testconfig.getDataSourcePassword(); - return properties; + return new TestingConnectionProperties(testconfig.getDataSourceDatabase(), testconfig.getDataSourcePassword(), testconfig.getDataSourceHost(), testconfig.getDataSourceSSL(), testconfig.getDataSourceUsername()); } private TestingInteractionProperties createInteractionProperties() { - TestingInteractionProperties interactionProperties = new TestingInteractionProperties(); - if (getQuery().isPresent()) { - interactionProperties.setSelectStatement(getQuery().get()); - } - else { - interactionProperties.setSchema(getSchema()); - interactionProperties.setTable(getTable()); - } - return interactionProperties; + return getQuery().isPresent() ? new TestingInteractionProperties(getQuery().get(), getSchema(), getTable()) : new TestingInteractionProperties(null, getSchema(), getTable()); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java index b3124d0425913..141ab4a9d24cd 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowServer.java @@ -147,6 +147,7 @@ public void getStream(CallContext callContext, Ticket ticket, ServerStreamListen @Override public void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) { + throw new UnsupportedOperationException("This operation is not supported"); } @Override @@ -222,7 +223,7 @@ else if (selectStatement != null) { @Override public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) { - return null; + throw new UnsupportedOperationException("This operation is not supported"); } @Override diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java index 7f265036eeac0..707283ba5bf55 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java @@ -13,12 +13,21 @@ */ package com.facebook.plugin.arrow; -public class TestingConnectionProperties +public final class TestingConnectionProperties { - public String database; - public String password; + public final String database; + public final String password; public Integer port; - public String host; - public Boolean ssl; - public String username; + public final String host; + public final Boolean ssl; + public final String username; + + public TestingConnectionProperties(String database, String password, String host, Boolean ssl, String username) + { + this.database = database; + this.password = password; + this.host = host; + this.ssl = ssl; + this.username = username; + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java index 81b4ee5f7e811..8fd06660a7515 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingInteractionProperties.java @@ -15,15 +15,26 @@ import com.fasterxml.jackson.annotation.JsonProperty; -public class TestingInteractionProperties +public final class TestingInteractionProperties { @JsonProperty("select_statement") - private String selectStatement; + private final String selectStatement; @JsonProperty("schema_name") - private String schema; + private final String schema; + @JsonProperty("table_name") - private String table; + private final String table; + + // Constructor to initialize the fields + 
public TestingInteractionProperties(String selectStatement, String schema, String table) + { + this.selectStatement = selectStatement; + this.schema = schema; + this.table = table; + } + + // Getters (no setters as the fields are final and immutable) public String getSelectStatement() { return selectStatement; @@ -39,18 +50,5 @@ public String getTable() return table; } - public void setSchema(String schema) - { - this.schema = schema; - } - - public void setSelectStatement(String selectStatement) - { - this.selectStatement = selectStatement; - } - - public void setTable(String table) - { - this.table = table; - } + // No setters as the class is immutable } From 45f069a74980aa6a108dcc98d22a435351609905 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 12:14:24 +0530 Subject: [PATCH 19/39] Arrow page utils changes --- .../facebook/plugin/arrow/ArrowPageUtils.java | 154 ++++++++++++++++++ .../plugin/arrow/ArrowPageUtilsTest.java | 42 +++++ 2 files changed, 196 insertions(+) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 071a3a302c491..916e7cadb43fa 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -15,12 +15,21 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.common.type.CharType; import com.facebook.presto.common.type.DateType; import com.facebook.presto.common.type.DecimalType; import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; 
+import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; import com.google.common.base.CharMatcher; @@ -47,11 +56,15 @@ import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListReader; +import org.apache.arrow.vector.util.JsonStringArrayList; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.LocalTime; +import java.util.List; import java.util.concurrent.TimeUnit; public class ArrowPageUtils @@ -506,4 +519,145 @@ public static Block buildTimeTypeBlockFromVarcharVector(VarCharVector vector, Ty } return builder.build(); } + + public static Block buildBlockFromListVector(ListVector vector, Type type) + { + if (!(type instanceof ArrayType)) { + throw new IllegalArgumentException("Type must be an ArrayType for ListVector"); + } + + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + BlockBuilder arrayBuilder = type.createBlockBuilder(null, vector.getValueCount()); + + for (int i = 0; i < vector.getValueCount(); i++) { + if (vector.isNull(i)) { + arrayBuilder.appendNull(); + } else { + BlockBuilder elementBuilder = arrayBuilder.beginBlockEntry(); + UnionListReader reader = vector.getReader(); + reader.setPosition(i); + + while (reader.next()) { + Object value = reader.readObject(); + if (value == null) { + elementBuilder.appendNull(); + } else { + appendValueToBuilder(elementType, elementBuilder, value); + } + } + arrayBuilder.closeEntry(); + 
} + } + return arrayBuilder.build(); + } + + private static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) { + if (value == null) { + builder.appendNull(); + return; + } + + if (type instanceof VarcharType) { + // Convert value to string and write as Varchar + Slice slice = Slices.utf8Slice(value.toString()); + type.writeSlice(builder, slice); + } else if (type instanceof BigintType) { + if (value instanceof Long) { + type.writeLong(builder, (Long) value); + } else if (value instanceof Integer) { + type.writeLong(builder, ((Integer) value).longValue()); + } else if (value instanceof JsonStringArrayList) { + JsonStringArrayList list = (JsonStringArrayList) value; + for (Object obj : list) { + try { + long longValue = Long.parseLong(obj.toString()); + type.writeLong(builder, longValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } else { + throw new IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); + } + } else if (type instanceof IntegerType) { + if (value instanceof Integer) { + type.writeLong(builder, (Integer) value); + } else if (value instanceof JsonStringArrayList) { + JsonStringArrayList list = (JsonStringArrayList) value; + for (Object obj : list) { + try { + int intValue = Integer.parseInt(obj.toString()); + type.writeLong(builder, intValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } else { + throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass()); + } + } else if (type instanceof DoubleType) { + if (value instanceof Double) { + type.writeDouble(builder, (Double) value); + } else if (value instanceof Float) { + type.writeDouble(builder, ((Float) value).doubleValue()); + } else if (value instanceof JsonStringArrayList) { + JsonStringArrayList list = 
(JsonStringArrayList) value; + for (Object obj : list) { + try { + double doubleValue = Double.parseDouble(obj.toString()); + type.writeDouble(builder, doubleValue); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } + } + } else { + throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); + } + } else if (type instanceof BooleanType) { + if (value instanceof Boolean) { + type.writeBoolean(builder, (Boolean) value); + } else { + throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); + } + } else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType) type; + if (value instanceof BigDecimal) { + BigDecimal decimalValue = (BigDecimal) value; + if (decimalType.isShort()) { + builder.writeLong(decimalValue.unscaledValue().longValue()); + } else { + Slice slice = Decimals.encodeScaledValue(decimalValue); + decimalType.writeSlice(builder, slice); + } + } else { + throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); + } + } else if (type instanceof ArrayType) { + // Handling array types (lists) + ArrayType arrayType = (ArrayType) type; + Type elementType = arrayType.getElementType(); + BlockBuilder arrayBuilder = builder.beginBlockEntry(); + for (Object element : (Iterable) value) { + appendValueToBuilder(elementType, arrayBuilder, element); + } + builder.closeEntry(); + } else if (type instanceof RowType) { + // Handling row types (structs) + RowType rowType = (RowType) type; + List rowValues = (List) value; + BlockBuilder rowBuilder = builder.beginBlockEntry(); + List fields = rowType.getFields(); + for (int i = 0; i < fields.size(); i++) { + Type fieldType = fields.get(i).getType(); + appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); + } + builder.closeEntry(); + } else { + throw new IllegalArgumentException("Unsupported type: " + type); 
+ } + } + + + } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 666f5fc2641cd..d41e95d372eae 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -14,6 +14,7 @@ package com.facebook.plugin.arrow; import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; import com.facebook.presto.common.type.DecimalType; @@ -30,6 +31,9 @@ import org.apache.arrow.vector.SmallIntVector; import org.apache.arrow.vector.TimeStampMicroVector; import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -195,4 +199,42 @@ public void testBuildBlockFromTimeStampMicroVector() assertEquals(1000L, resultBlock.getLong(0)); // The 1st element should be 1000ms (1 second) assertEquals(2000L, resultBlock.getLong(1)); // The 2nd element should be 2000ms (2 seconds) } + + @Test + public void testBuildBlockFromListVector() { + // Create a root allocator for Arrow vectors + try (BufferAllocator allocator = new RootAllocator(); + ListVector listVector = ListVector.empty("listVector", allocator)) { + + // Allocate the vector and get the writer + listVector.allocateNew(); + UnionListWriter listWriter = listVector.getWriter(); + + int[] data = new int[] {1, 2, 3, 10, 20, 30, 100, 200, 300, 1000, 2000, 3000}; + int tmpIndex = 0; + + for (int i = 0; i < 4; i++) { // 4 lists to be added + listWriter.startList(); + for (int j = 0; j < 3; j++) { // Each list has 3 integers + 
listWriter.writeInt(data[tmpIndex]); + tmpIndex++; + } + listWriter.endList(); + } + + // Set the number of lists + listVector.setValueCount(4); + + // Create Presto ArrayType for Integer + ArrayType arrayType = new ArrayType(IntegerType.INTEGER); + + // Call the method to test + Block block = ArrowPageUtils.buildBlockFromListVector(listVector, arrayType); + + // Validate the result + Assert.assertEquals(block.getPositionCount(), 4); // 4 lists in the block + + } + } + } From ef7b73ab2e65ebc792d4c3dce5d925baaa212fa3 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 12:32:36 +0530 Subject: [PATCH 20/39] Arrow page utils changes - fixed checkstyle issues --- .../facebook/plugin/arrow/ArrowPageUtils.java | 85 ++++++++++++------- .../plugin/arrow/ArrowPageUtilsTest.java | 9 +- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 916e7cadb43fa..b367e21bce160 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -24,12 +24,9 @@ import com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.DoubleType; import com.facebook.presto.common.type.IntegerType; -import com.facebook.presto.common.type.RealType; import com.facebook.presto.common.type.RowType; -import com.facebook.presto.common.type.SmallintType; import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; -import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; import com.google.common.base.CharMatcher; @@ -72,6 +69,7 @@ public class ArrowPageUtils private ArrowPageUtils() { } + static Block 
buildBlockFromVector(FieldVector vector, Type type) { if (vector instanceof BitVector) { @@ -142,6 +140,9 @@ else if (vector instanceof TimeMicroVector) { else if (vector instanceof TimeStampMilliTZVector) { return buildBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type); } + else if (vector instanceof ListVector) { + return buildBlockFromListVector((ListVector) vector, type); + } throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); } @@ -533,7 +534,8 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) for (int i = 0; i < vector.getValueCount(); i++) { if (vector.isNull(i)) { arrayBuilder.appendNull(); - } else { + } + else { BlockBuilder elementBuilder = arrayBuilder.beginBlockEntry(); UnionListReader reader = vector.getReader(); reader.setPosition(i); @@ -542,7 +544,8 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) Object value = reader.readObject(); if (value == null) { elementBuilder.appendNull(); - } else { + } + else { appendValueToBuilder(elementType, elementBuilder, value); } } @@ -552,7 +555,8 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) return arrayBuilder.build(); } - private static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) { + private static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) + { if (value == null) { builder.appendNull(); return; @@ -562,78 +566,98 @@ private static void appendValueToBuilder(Type type, BlockBuilder builder, Object // Convert value to string and write as Varchar Slice slice = Slices.utf8Slice(value.toString()); type.writeSlice(builder, slice); - } else if (type instanceof BigintType) { + } + else if (type instanceof BigintType) { if (value instanceof Long) { type.writeLong(builder, (Long) value); - } else if (value instanceof Integer) { + } + else if (value instanceof Integer) { type.writeLong(builder, ((Integer) 
value).longValue()); - } else if (value instanceof JsonStringArrayList) { + } + else if (value instanceof JsonStringArrayList) { JsonStringArrayList list = (JsonStringArrayList) value; for (Object obj : list) { try { long longValue = Long.parseLong(obj.toString()); type.writeLong(builder, longValue); - } catch (NumberFormatException e) { + } + catch (NumberFormatException e) { throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); } } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); } - } else if (type instanceof IntegerType) { + } + else if (type instanceof IntegerType) { if (value instanceof Integer) { type.writeLong(builder, (Integer) value); - } else if (value instanceof JsonStringArrayList) { + } + else if (value instanceof JsonStringArrayList) { JsonStringArrayList list = (JsonStringArrayList) value; for (Object obj : list) { try { int intValue = Integer.parseInt(obj.toString()); type.writeLong(builder, intValue); - } catch (NumberFormatException e) { + } + catch (NumberFormatException e) { throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); } } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass()); } - } else if (type instanceof DoubleType) { + } + else if (type instanceof DoubleType) { if (value instanceof Double) { type.writeDouble(builder, (Double) value); - } else if (value instanceof Float) { + } + else if (value instanceof Float) { type.writeDouble(builder, ((Float) value).doubleValue()); - } else if (value instanceof JsonStringArrayList) { + } + else if (value instanceof JsonStringArrayList) { JsonStringArrayList list = (JsonStringArrayList) value; for (Object obj : list) { try { double doubleValue = Double.parseDouble(obj.toString()); type.writeDouble(builder, doubleValue); - } catch (NumberFormatException e) { + } + catch 
(NumberFormatException e) { throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); } } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); } - } else if (type instanceof BooleanType) { + } + else if (type instanceof BooleanType) { if (value instanceof Boolean) { type.writeBoolean(builder, (Boolean) value); - } else { + } + else { throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); } - } else if (type instanceof DecimalType) { + } + else if (type instanceof DecimalType) { DecimalType decimalType = (DecimalType) type; if (value instanceof BigDecimal) { BigDecimal decimalValue = (BigDecimal) value; if (decimalType.isShort()) { builder.writeLong(decimalValue.unscaledValue().longValue()); - } else { + } + else { Slice slice = Decimals.encodeScaledValue(decimalValue); decimalType.writeSlice(builder, slice); } - } else { + } + else { throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); } - } else if (type instanceof ArrayType) { + } + else if (type instanceof ArrayType) { // Handling array types (lists) ArrayType arrayType = (ArrayType) type; Type elementType = arrayType.getElementType(); @@ -642,7 +666,8 @@ private static void appendValueToBuilder(Type type, BlockBuilder builder, Object appendValueToBuilder(elementType, arrayBuilder, element); } builder.closeEntry(); - } else if (type instanceof RowType) { + } + else if (type instanceof RowType) { // Handling row types (structs) RowType rowType = (RowType) type; List rowValues = (List) value; @@ -653,11 +678,9 @@ private static void appendValueToBuilder(Type type, BlockBuilder builder, Object appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); } builder.closeEntry(); - } else { + } + else { throw new IllegalArgumentException("Unsupported type: " + type); } } - - - } diff --git 
a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index d41e95d372eae..9fbc1b4255508 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -33,7 +33,6 @@ import org.apache.arrow.vector.TinyIntVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListWriter; -import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -201,11 +200,11 @@ public void testBuildBlockFromTimeStampMicroVector() } @Test - public void testBuildBlockFromListVector() { + public void testBuildBlockFromListVector() + { // Create a root allocator for Arrow vectors try (BufferAllocator allocator = new RootAllocator(); ListVector listVector = ListVector.empty("listVector", allocator)) { - // Allocate the vector and get the writer listVector.allocateNew(); UnionListWriter listWriter = listVector.getWriter(); @@ -232,9 +231,7 @@ public void testBuildBlockFromListVector() { Block block = ArrowPageUtils.buildBlockFromListVector(listVector, arrayType); // Validate the result - Assert.assertEquals(block.getPositionCount(), 4); // 4 lists in the block - + assertEquals(block.getPositionCount(), 4); // 4 lists in the block } } - } From 498631957c5866ef121f3002710c169e75cd79d7 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 13:26:55 +0530 Subject: [PATCH 21/39] Review comment fixes - Root allocator and typos --- .../plugin/arrow/AbstractArrowMetadata.java | 13 ++++++------ .../facebook/plugin/arrow/ArrowConnector.java | 17 +++++++++++++--- .../plugin/arrow/ArrowFlightClient.java | 6 +----- .../arrow/ArrowFlightClientHandler.java | 20 +++++++++++++++++-- 4 files changed, 40 insertions(+), 16 deletions(-) diff 
--git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index a98a627f4c3bb..5d9a0763f34e6 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -44,11 +44,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -158,8 +159,8 @@ public List getColumnsList(String schema, String table, ConnectorSession FlightDescriptor flightDescriptor = getFlightDescriptor(config, Optional.empty(), dataSourceSpecificSchemaName, dataSourceSpecificTableName); - FlightInfo flightInfo = clientHandler.getFlightInfo(flightDescriptor, connectorSession); - List fields = flightInfo.getSchema().getFields(); + Optional flightschema = clientHandler.getSchema(flightDescriptor, connectorSession); + List fields = flightschema.map(schema1 -> schema1.getFields()).orElse(Collections.emptyList()); return fields; } catch (Exception e) { @@ -170,7 +171,7 @@ public List getColumnsList(String schema, String table, ConnectorSession @Override public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) { - Map columns = new HashMap<>(); + Map columnHandles = new HashMap<>(); String schemaValue = ((ArrowTableHandle) tableHandle).getSchema(); String tableValue = ((ArrowTableHandle) tableHandle).getTable(); @@ -183,9 +184,9 @@ public Map getColumnHandles(ConnectorSession session, 
Conn logger.debug("The value of the flight columnName is:- %s", columnName); Type type = getPrestoTypeFromArrowField(field); - columns.put(columnName, new ArrowColumnHandle(columnName, type)); + columnHandles.put(columnName, new ArrowColumnHandle(columnName, type)); } - return columns; + return columnHandles; } @Override diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java index 9523bd210faa5..cf25b90b9d670 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -34,16 +34,21 @@ public class ArrowConnector private final ConnectorPageSourceProvider pageSourceProvider; private final ConnectorHandleResolver handleResolver; + private final ArrowFlightClientHandler arrowFlightClientHandler; + @Inject public ArrowConnector(ConnectorMetadata metadata, - ConnectorHandleResolver handleResolver, - ConnectorSplitManager splitManager, - ConnectorPageSourceProvider pageSourceProvider) + ConnectorHandleResolver handleResolver, + ConnectorSplitManager splitManager, + ConnectorPageSourceProvider pageSourceProvider, + + ArrowFlightClientHandler arrowFlightClientHandler) { this.metadata = requireNonNull(metadata, "Metadata is null"); this.handleResolver = requireNonNull(handleResolver, "Metadata is null"); this.splitManager = requireNonNull(splitManager, "SplitManager is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "PageSinkProvider is null"); + this.arrowFlightClientHandler = requireNonNull(arrowFlightClientHandler, "arrow flight handler is null"); } public Optional getHandleResolver() @@ -74,4 +79,10 @@ public ConnectorPageSourceProvider getPageSourceProvider() { return pageSourceProvider; } + + @Override + public void shutdown() + { + arrowFlightClientHandler.closeRootallocator(); + } } diff 
--git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java index fa3fabf7b7fb2..3d12617a0839d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClient.java @@ -14,7 +14,6 @@ package com.facebook.plugin.arrow; import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.memory.RootAllocator; import java.io.IOException; import java.io.InputStream; @@ -27,13 +26,11 @@ public class ArrowFlightClient { private final FlightClient flightClient; private final Optional trustedCertificate; - private RootAllocator allocator; - public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate, RootAllocator allocator) + public ArrowFlightClient(FlightClient flightClient, Optional trustedCertificate) { this.flightClient = requireNonNull(flightClient, "flightClient cannot be null"); this.trustedCertificate = requireNonNull(trustedCertificate, "trustedCertificate is null"); - this.allocator = requireNonNull(allocator, "allocator is null"); } public FlightClient getFlightClient() @@ -53,6 +50,5 @@ public void close() throws InterruptedException, IOException if (trustedCertificate.isPresent()) { trustedCertificate.get().close(); } - allocator.close(); } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index 7db878774065f..42993ff8ba4b1 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -21,6 +21,7 @@ import org.apache.arrow.flight.Location; import 
org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; import java.io.FileInputStream; import java.io.InputStream; @@ -33,6 +34,8 @@ public abstract class ArrowFlightClientHandler private static final Logger logger = Logger.get(ArrowFlightClientHandler.class); private final ArrowFlightConfig config; + private RootAllocator allocator; + public ArrowFlightClientHandler(ArrowFlightConfig config) { this.config = config; @@ -41,7 +44,6 @@ public ArrowFlightClientHandler(ArrowFlightConfig config) private ArrowFlightClient initializeClient(Optional uri) { try { - RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); Optional trustedCertificate = Optional.empty(); Location location; @@ -57,6 +59,10 @@ private ArrowFlightClient initializeClient(Optional uri) } } + if (null == allocator) { + allocator = new RootAllocator(Long.MAX_VALUE); + } + FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); if (config.getVerifyServer() != null && !config.getVerifyServer()) { flightClientBuilder.verifyServer(false); @@ -67,7 +73,7 @@ else if (config.getFlightServerSSLCertificate() != null) { } FlightClient flightClient = flightClientBuilder.build(); - return new ArrowFlightClient(flightClient, trustedCertificate, allocator); + return new ArrowFlightClient(flightClient, trustedCertificate); } catch (Exception ex) { throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight client could not be obtained." + ex.getMessage(), ex); @@ -94,4 +100,14 @@ public FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSess throw new ArrowException(ARROW_FLIGHT_ERROR, "The flight information could not be obtained from the flight server." 
+ e.getMessage(), e); } } + + public Optional getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + return getFlightInfo(flightDescriptor, connectorSession).getSchemaOptional(); + } + + public void closeRootallocator() + { + allocator.close(); + } } From 82b64d50477fff0fd4e52ef1842b82013293afad Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 17:57:30 +0530 Subject: [PATCH 22/39] Review comment fixes --- .../{arrow-tests.yml => arrow-flight-tests.yml} | 2 +- .../plugin/arrow/AbstractArrowMetadata.java | 2 +- .../facebook/plugin/arrow/ArrowConnector.java | 1 - .../facebook/plugin/arrow/ArrowPageUtils.java | 2 +- .../arrow/TestingConnectionProperties.java | 17 ++++++++++------- 5 files changed, 13 insertions(+), 11 deletions(-) rename .github/workflows/{arrow-tests.yml => arrow-flight-tests.yml} (99%) diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-flight-tests.yml similarity index 99% rename from .github/workflows/arrow-tests.yml rename to .github/workflows/arrow-flight-tests.yml index 83ad7c940122b..ee77c122536e1 100644 --- a/.github/workflows/arrow-tests.yml +++ b/.github/workflows/arrow-flight-tests.yml @@ -1,4 +1,4 @@ -name: arrow tests +name: arrow flight tests on: pull_request: diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 5d9a0763f34e6..8928bac9676f0 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -160,7 +160,7 @@ public List getColumnsList(String schema, String table, ConnectorSession dataSourceSpecificSchemaName, dataSourceSpecificTableName); Optional flightschema = clientHandler.getSchema(flightDescriptor, connectorSession); - List fields = flightschema.map(schema1 
-> schema1.getFields()).orElse(Collections.emptyList()); + List fields = flightschema.map(Schema :: getFields).orElse(Collections.emptyList()); return fields; } catch (Exception e) { diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java index cf25b90b9d670..d6221625df615 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -41,7 +41,6 @@ public ArrowConnector(ConnectorMetadata metadata, ConnectorHandleResolver handleResolver, ConnectorSplitManager splitManager, ConnectorPageSourceProvider pageSourceProvider, - ArrowFlightClientHandler arrowFlightClientHandler) { this.metadata = requireNonNull(metadata, "Metadata is null"); diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index b367e21bce160..8b85d48fdde2a 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -70,7 +70,7 @@ private ArrowPageUtils() { } - static Block buildBlockFromVector(FieldVector vector, Type type) + public static Block buildBlockFromVector(FieldVector vector, Type type) { if (vector instanceof BitVector) { return buildBlockFromBitVector((BitVector) vector, type); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java index 707283ba5bf55..0cccfece3866e 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java @@ -13,14 +13,17 @@ */ package com.facebook.plugin.arrow; -public final class TestingConnectionProperties +import javax.annotation.concurrent.Immutable; + +@Immutable +public class TestingConnectionProperties { - public final String database; - public final String password; - public Integer port; - public final String host; - public final Boolean ssl; - public final String username; + private final String database; + private final String password; + private Integer port; + private final String host; + private final Boolean ssl; + private final String username; public TestingConnectionProperties(String database, String password, String host, Boolean ssl, String username) { From d9db76ef5f12395a5c8fb44df8229202193d449b Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 22 Nov 2024 18:29:46 +0530 Subject: [PATCH 23/39] Review comment fixes --- .../com/facebook/plugin/arrow/TestingArrowSplitManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java index ec8f7cd2e9da1..34694863bc277 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowSplitManager.java @@ -31,7 +31,7 @@ public class TestingArrowSplitManager @Inject public TestingArrowSplitManager(ArrowFlightConfig config, ArrowFlightClientHandler client, TestingArrowFlightConfig testconfig, NodeManager nodeManager) { - super(client, config); + super(client); this.testconfig = testconfig; this.nodeManager = nodeManager; } From 5e99355a0bbf81ed7fb87f120d9f5a6353144050 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Tue, 26 Nov 2024 14:00:40 +0530 Subject: 
[PATCH 24/39] Review comment fixes - Changed config --- .../com/facebook/plugin/arrow/AbstractArrowMetadata.java | 6 +++--- .../com/facebook/plugin/arrow/TestingArrowMetadata.java | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 8928bac9676f0..611fb38481a62 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -132,7 +132,7 @@ protected Type getPrestoTypeFromArrowField(Field field) } } - protected abstract FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, Optional query, String schema, String table); + protected abstract FlightDescriptor getFlightDescriptor(Optional query, String schema, String table); protected abstract String getDataSourceSpecificSchemaName(ArrowFlightConfig config, String schemaName); @@ -156,11 +156,11 @@ public List getColumnsList(String schema, String table, ConnectorSession try { String dataSourceSpecificSchemaName = getDataSourceSpecificSchemaName(config, schema); String dataSourceSpecificTableName = getDataSourceSpecificTableName(config, table); - FlightDescriptor flightDescriptor = getFlightDescriptor(config, Optional.empty(), + FlightDescriptor flightDescriptor = getFlightDescriptor(Optional.empty(), dataSourceSpecificSchemaName, dataSourceSpecificTableName); Optional flightschema = clientHandler.getSchema(flightDescriptor, connectorSession); - List fields = flightschema.map(Schema :: getFields).orElse(Collections.emptyList()); + List fields = flightschema.map(Schema::getFields).orElse(Collections.emptyList()); return fields; } catch (Exception e) { diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java index 0425981cded2a..a1cd0a4dffb76 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -145,9 +145,9 @@ protected String getDataSourceSpecificTableName(ArrowFlightConfig config, String } @Override - protected FlightDescriptor getFlightDescriptor(ArrowFlightConfig config, Optional query, String schema, String table) + protected FlightDescriptor getFlightDescriptor(Optional query, String schema, String table) { - TestingArrowFlightRequest request = new TestingArrowFlightRequest(config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(this.config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); return FlightDescriptor.command(request.getCommand()); } From a4ba35c6008c0e2ed0409ec717d27ed0d9c4b438 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Tue, 26 Nov 2024 14:44:19 +0530 Subject: [PATCH 25/39] Added support for small int tiny int date and timestamp --- .../facebook/plugin/arrow/ArrowPageUtils.java | 305 ++++++++++++------ 1 file changed, 209 insertions(+), 96 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 8b85d48fdde2a..8750028bfe060 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -25,8 +25,10 @@ import com.facebook.presto.common.type.DoubleType; import com.facebook.presto.common.type.IntegerType; import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; import 
com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.VarcharType; import com.google.common.base.CharMatcher; @@ -563,124 +565,235 @@ private static void appendValueToBuilder(Type type, BlockBuilder builder, Object } if (type instanceof VarcharType) { - // Convert value to string and write as Varchar - Slice slice = Slices.utf8Slice(value.toString()); - type.writeSlice(builder, slice); + handleVarcharType(type, builder, value); + } + else if (type instanceof SmallintType) { + handleSmallintType(type, builder, value); + } + else if (type instanceof TinyintType) { + handleTinyintType(type, builder, value); } else if (type instanceof BigintType) { - if (value instanceof Long) { - type.writeLong(builder, (Long) value); - } - else if (value instanceof Integer) { - type.writeLong(builder, ((Integer) value).longValue()); - } - else if (value instanceof JsonStringArrayList) { - JsonStringArrayList list = (JsonStringArrayList) value; - for (Object obj : list) { - try { - long longValue = Long.parseLong(obj.toString()); - type.writeLong(builder, longValue); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); - } - } - } - else { - throw new IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); - } + handleBigintType(type, builder, value); } else if (type instanceof IntegerType) { - if (value instanceof Integer) { - type.writeLong(builder, (Integer) value); - } - else if (value instanceof JsonStringArrayList) { - JsonStringArrayList list = (JsonStringArrayList) value; - for (Object obj : list) { - try { - int intValue = Integer.parseInt(obj.toString()); - type.writeLong(builder, intValue); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Invalid number format 
in JsonStringArrayList: " + obj, e); - } - } - } - else { - throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass()); - } + handleIntegerType(type, builder, value); } else if (type instanceof DoubleType) { - if (value instanceof Double) { - type.writeDouble(builder, (Double) value); - } - else if (value instanceof Float) { - type.writeDouble(builder, ((Float) value).doubleValue()); - } - else if (value instanceof JsonStringArrayList) { - JsonStringArrayList list = (JsonStringArrayList) value; - for (Object obj : list) { - try { - double doubleValue = Double.parseDouble(obj.toString()); - type.writeDouble(builder, doubleValue); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); - } + handleDoubleType(type, builder, value); + } + else if (type instanceof BooleanType) { + handleBooleanType(type, builder, value); + } + else if (type instanceof DecimalType) { + handleDecimalType((DecimalType) type, builder, value); + } + else if (type instanceof ArrayType) { + handleArrayType((ArrayType) type, builder, value); + } + else if (type instanceof RowType) { + handleRowType((RowType) type, builder, value); + } + else if (type instanceof DateType) { + handleDateType(type, builder, value); + } + else if (type instanceof TimestampType) { + handleTimestampType(type, builder, value); + } + else { + throw new IllegalArgumentException("Unsupported type: " + type); + } + } + + private static void handleVarcharType(Type type, BlockBuilder builder, Object value) + { + Slice slice = Slices.utf8Slice(value.toString()); + type.writeSlice(builder, slice); + } + + private static void handleSmallintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Number) { + type.writeLong(builder, ((Number) value).shortValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + short shortValue 
= Short.parseShort(obj.toString()); + type.writeLong(builder, shortValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for SmallintType: " + obj, e); } } - else { - throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); + } + else { + throw new IllegalArgumentException("Unsupported type for SmallintType: " + value.getClass()); + } + } + + private static void handleTinyintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Number) { + type.writeLong(builder, ((Number) value).byteValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + byte byteValue = Byte.parseByte(obj.toString()); + type.writeLong(builder, byteValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList for TinyintType: " + obj, e); + } } } - else if (type instanceof BooleanType) { - if (value instanceof Boolean) { - type.writeBoolean(builder, (Boolean) value); + else { + throw new IllegalArgumentException("Unsupported type for TinyintType: " + value.getClass()); + } + } + + private static void handleBigintType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Long) { + type.writeLong(builder, (Long) value); + } + else if (value instanceof Integer) { + type.writeLong(builder, ((Integer) value).longValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + long longValue = Long.parseLong(obj.toString()); + type.writeLong(builder, longValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } } - else { - throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); + } + else { + throw new 
IllegalArgumentException("Unsupported type for BigintType: " + value.getClass()); + } + } + + private static void handleIntegerType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Integer) { + type.writeLong(builder, (Integer) value); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + int intValue = Integer.parseInt(obj.toString()); + type.writeLong(builder, intValue); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); + } } } - else if (type instanceof DecimalType) { - DecimalType decimalType = (DecimalType) type; - if (value instanceof BigDecimal) { - BigDecimal decimalValue = (BigDecimal) value; - if (decimalType.isShort()) { - builder.writeLong(decimalValue.unscaledValue().longValue()); + else { + throw new IllegalArgumentException("Unsupported type for IntegerType: " + value.getClass()); + } + } + + private static void handleDoubleType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Double) { + type.writeDouble(builder, (Double) value); + } + else if (value instanceof Float) { + type.writeDouble(builder, ((Float) value).doubleValue()); + } + else if (value instanceof JsonStringArrayList) { + for (Object obj : (JsonStringArrayList) value) { + try { + double doubleValue = Double.parseDouble(obj.toString()); + type.writeDouble(builder, doubleValue); } - else { - Slice slice = Decimals.encodeScaledValue(decimalValue); - decimalType.writeSlice(builder, slice); + catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid number format in JsonStringArrayList: " + obj, e); } } + } + else { + throw new IllegalArgumentException("Unsupported type for DoubleType: " + value.getClass()); + } + } + + private static void handleBooleanType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof Boolean) { + type.writeBoolean(builder, 
(Boolean) value); + } + else { + throw new IllegalArgumentException("Unsupported type for BooleanType: " + value.getClass()); + } + } + + private static void handleDecimalType(DecimalType type, BlockBuilder builder, Object value) + { + if (value instanceof BigDecimal) { + BigDecimal decimalValue = (BigDecimal) value; + if (type.isShort()) { + builder.writeLong(decimalValue.unscaledValue().longValue()); + } else { - throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); + Slice slice = Decimals.encodeScaledValue(decimalValue); + type.writeSlice(builder, slice); } } - else if (type instanceof ArrayType) { - // Handling array types (lists) - ArrayType arrayType = (ArrayType) type; - Type elementType = arrayType.getElementType(); - BlockBuilder arrayBuilder = builder.beginBlockEntry(); - for (Object element : (Iterable) value) { - appendValueToBuilder(elementType, arrayBuilder, element); - } - builder.closeEntry(); + else { + throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); } - else if (type instanceof RowType) { - // Handling row types (structs) - RowType rowType = (RowType) type; - List rowValues = (List) value; - BlockBuilder rowBuilder = builder.beginBlockEntry(); - List fields = rowType.getFields(); - for (int i = 0; i < fields.size(); i++) { - Type fieldType = fields.get(i).getType(); - appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); - } - builder.closeEntry(); + } + + private static void handleArrayType(ArrayType type, BlockBuilder builder, Object value) + { + Type elementType = type.getElementType(); + BlockBuilder arrayBuilder = builder.beginBlockEntry(); + for (Object element : (Iterable) value) { + appendValueToBuilder(elementType, arrayBuilder, element); + } + builder.closeEntry(); + } + + private static void handleRowType(RowType type, BlockBuilder builder, Object value) + { + List rowValues = (List) value; + BlockBuilder rowBuilder = builder.beginBlockEntry(); 
+ List fields = type.getFields(); + for (int i = 0; i < fields.size(); i++) { + Type fieldType = fields.get(i).getType(); + appendValueToBuilder(fieldType, rowBuilder, rowValues.get(i)); + } + builder.closeEntry(); + } + + private static void handleDateType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof java.sql.Date || value instanceof java.time.LocalDate) { + int daysSinceEpoch = (int) (value instanceof java.sql.Date + ? ((java.sql.Date) value).toLocalDate().toEpochDay() + : ((java.time.LocalDate) value).toEpochDay()); + type.writeLong(builder, daysSinceEpoch); } else { - throw new IllegalArgumentException("Unsupported type: " + type); + throw new IllegalArgumentException("Unsupported type for DateType: " + value.getClass()); + } + } + + private static void handleTimestampType(Type type, BlockBuilder builder, Object value) + { + if (value instanceof java.sql.Timestamp || value instanceof java.time.Instant) { + long millis = value instanceof java.sql.Timestamp + ? 
((java.sql.Timestamp) value).getTime() + : ((java.time.Instant) value).toEpochMilli(); + type.writeLong(builder, millis); + } + else { + throw new IllegalArgumentException("Unsupported type for TimestampType: " + value.getClass()); } } } From 7170a7f576e4991d3de6d5ccb83e2d51dd20c9ed Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Wed, 27 Nov 2024 16:36:53 +0530 Subject: [PATCH 26/39] Review comment fixes - Dictionary encoding and other tests --- presto-base-arrow-flight/pom.xml | 9 +- .../facebook/plugin/arrow/ArrowPageUtils.java | 32 +++++++ .../plugin/arrow/ArrowTableHandle.java | 21 +++++ .../plugin/arrow/ArrowTableLayoutHandle.java | 24 ++++- .../plugin/arrow/ArrowPageUtilsTest.java | 43 +++++++++ .../facebook/plugin/arrow/TestArrowSplit.java | 88 ++++++++++++++++++ .../plugin/arrow/TestArrowTableHandle.java | 92 +++++++++++++++++++ 7 files changed, 306 insertions(+), 3 deletions(-) create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml index ea4d3bef2dbee..43f50779cf111 100644 --- a/presto-base-arrow-flight/pom.xml +++ b/presto-base-arrow-flight/pom.xml @@ -234,7 +234,6 @@ h2 test - @@ -298,6 +297,14 @@ error_prone_annotations ${error_prone_annotations} + + + org.apache.arrow + arrow-algorithm + 18.0.0 + compile + + diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 8750028bfe060..51725d95a1f6d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -145,9 +145,41 @@ else if (vector instanceof TimeStampMilliTZVector) { else if (vector instanceof 
ListVector) { return buildBlockFromListVector((ListVector) vector, type); } + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); } + public static Block buildBlockFromEncodedVector(IntVector encodedVector, VarCharVector dictionary) + { + // Ensure the dictionary vector is valid + if (dictionary == null || encodedVector == null) { + throw new IllegalArgumentException("Both encodedVector and dictionary must be non-null."); + } + + // Create a BlockBuilder for VARCHAR + BlockBuilder builder = VarcharType.VARCHAR.createBlockBuilder(null, encodedVector.getValueCount()); + + // Iterate through the encoded vector and retrieve values from the dictionary + for (int i = 0; i < encodedVector.getValueCount(); i++) { + if (encodedVector.isNull(i)) { + builder.appendNull(); // Append null if the index is null + } + else { + int dictionaryIndex = encodedVector.get(i); + if (dictionary.isNull(dictionaryIndex)) { + builder.appendNull(); // Append null if the dictionary value is null + } + else { + byte[] valueBytes = dictionary.get(dictionaryIndex); + String value = new String(valueBytes, StandardCharsets.UTF_8); + VarcharType.VARCHAR.writeSlice(builder, Slices.utf8Slice(value)); // Append the dictionary value + } + } + } + + return builder.build(); + } + public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) { if (!(type instanceof TimestampType)) { diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java index e0f8c6586791d..cef04e4372a5a 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import 
java.util.Objects; + public class ArrowTableHandle implements ConnectorTableHandle { @@ -49,4 +51,23 @@ public String toString() { return schema + ":" + table; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableHandle that = (ArrowTableHandle) o; + return Objects.equals(schema, that.schema) && Objects.equals(table, that.table); + } + + @Override + public int hashCode() + { + return Objects.hash(schema, table); + } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java index 512246683562f..46e94a4e1143c 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java @@ -20,6 +20,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; +import java.util.Objects; import static java.util.Objects.requireNonNull; @@ -32,8 +33,8 @@ public class ArrowTableLayoutHandle @JsonCreator public ArrowTableLayoutHandle(@JsonProperty("table") ArrowTableHandle table, - @JsonProperty("columnHandles") List columnHandles, - @JsonProperty("tupleDomain") TupleDomain domain) + @JsonProperty("columnHandles") List columnHandles, + @JsonProperty("tupleDomain") TupleDomain domain) { this.tableHandle = requireNonNull(table, "table is null"); this.columnHandles = requireNonNull(columnHandles, "columns are null"); @@ -63,4 +64,23 @@ public String toString() { return "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableLayoutHandle arrowTableLayoutHandle 
= (ArrowTableLayoutHandle) o; + return Objects.equals(tableHandle, arrowTableLayoutHandle.tableHandle) && Objects.equals(columnHandles, arrowTableLayoutHandle.columnHandles) && Objects.equals(tupleDomain, arrowTableLayoutHandle.tupleDomain); + } + + @Override + public int hashCode() + { + return Objects.hash(tableHandle, columnHandles, tupleDomain); + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 9fbc1b4255508..84900918e0ae3 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -31,18 +31,26 @@ import org.apache.arrow.vector.SmallIntVector; import org.apache.arrow.vector.TimeStampMicroVector; import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; public class ArrowPageUtilsTest { + private static final int DICTIONARY_LENGTH = 10; + private static final int VECTOR_LENGTH = 50; private BufferAllocator allocator; @BeforeClass @@ -234,4 +242,39 @@ public void testBuildBlockFromListVector() assertEquals(block.getPositionCount(), 4); // 4 lists in the block } } + + @Test + public void testProcessDictionaryVector() + { + // Create dictionary vector + VarCharVector 
dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(DICTIONARY_LENGTH); + for (int i = 0; i < DICTIONARY_LENGTH; i++) { + dictionaryVector.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8)); + } + dictionaryVector.setValueCount(DICTIONARY_LENGTH); + + // Create raw vector + VarCharVector rawVector = new VarCharVector("raw", allocator); + rawVector.allocateNew(VECTOR_LENGTH); + for (int i = 0; i < VECTOR_LENGTH; i++) { + int value = i % DICTIONARY_LENGTH; + rawVector.setSafe(i, String.valueOf(value).getBytes(StandardCharsets.UTF_8)); + } + rawVector.setValueCount(VECTOR_LENGTH); + + // Encode using dictionary + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + IntVector encodedVector = (IntVector) DictionaryEncoder.encode(rawVector, dictionary); + + // Decode back to original + VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dictionary); + + // Process the dictionary vector + Block result = ArrowPageUtils.buildBlockFromEncodedVector(encodedVector, dictionaryVector); + + // Verify the result + assertNotNull(result, "The BlockBuilder should not be null."); + assertEquals(result.getPositionCount(), 50); + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java new file mode 100644 index 0000000000000..e6dceda59ef85 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +@Test(singleThreaded = true) +public class TestArrowSplit +{ + private ArrowSplit arrowSplit; + private String schemaName; + private String tableName; + private byte[] ticket; + private List locationUrls; + + @BeforeMethod + public void setUp() + { + schemaName = "testSchema"; + tableName = "testTable"; + ticket = new byte[] {1, 2, 3, 4}; + locationUrls = Arrays.asList("http://localhost:8080", "http://localhost:8081"); + + // Instantiate ArrowSplit with mock data + arrowSplit = new ArrowSplit(schemaName, tableName, ticket, locationUrls); + } + + @Test + public void testConstructorAndGetters() + { + // Test that the constructor correctly initializes fields + assertEquals(arrowSplit.getSchemaName(), schemaName, "Schema name should match."); + assertEquals(arrowSplit.getTableName(), tableName, "Table name should match."); + assertEquals(arrowSplit.getTicket(), ticket, "Ticket byte array should match."); + assertEquals(arrowSplit.getLocationUrls(), locationUrls, "Location URLs list should match."); + } + + @Test + public void testNodeSelectionStrategy() + { + // Test that the node selection strategy is NO_PREFERENCE + assertEquals(arrowSplit.getNodeSelectionStrategy(), NodeSelectionStrategy.NO_PREFERENCE, "Node selection strategy should be NO_PREFERENCE."); + } + + @Test + public void testGetPreferredNodes() + { + // Test that the preferred nodes list is empty + List preferredNodes = arrowSplit.getPreferredNodes(null); + assertNotNull(preferredNodes, "Preferred nodes list should not be null."); + assertTrue(preferredNodes.isEmpty(), "Preferred nodes list should be empty."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java new file mode 100644 index 0000000000000..c6b66751c3438 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java @@ -0,0 +1,92 @@ +/* + 
* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import org.testng.annotations.Test; + +import static com.facebook.airlift.testing.Assertions.assertNotEquals; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +public class TestArrowTableHandle +{ + @Test + public void testConstructorAndGetters() + { + String schema = "test_schema"; + String table = "test_table"; + + // Create an instance of ArrowTableHandle + ArrowTableHandle tableHandle = new ArrowTableHandle(schema, table); + + // Verify that the schema and table are correctly set + assertEquals(tableHandle.getSchema(), schema, "Schema should match the input value."); + assertEquals(tableHandle.getTable(), table, "Table should match the input value."); + } + + @Test + public void testToString() + { + String schema = "test_schema"; + String table = "test_table"; + + // Create an instance of ArrowTableHandle + ArrowTableHandle tableHandle = new ArrowTableHandle(schema, table); + + // Verify the toString() output + String expectedToString = "test_schema:test_table"; + assertEquals(tableHandle.toString(), expectedToString, "toString() should return the correct string representation."); + } + + @Test + public void testEqualityAndHashCode() + { + String schema = "schema"; + String table = "table"; + + String schema2 = "schema2"; + String table2 = "table2"; + + // Create multiple instances + ArrowTableHandle handle1 = new 
ArrowTableHandle(schema, table); + ArrowTableHandle handle2 = new ArrowTableHandle(schema, table); + ArrowTableHandle handle3 = new ArrowTableHandle(schema2, table2); + + // Verify equality + assertEquals(handle1, handle2, "Handles with the same schema and table should be equal."); + assertNotEquals(handle1, handle3, "Handles with different schema or table should not be equal."); + + // Verify hashCode + assertEquals(handle1.hashCode(), handle2.hashCode(), "Equal handles should have the same hashCode."); + assertNotEquals(handle1.hashCode(), handle3.hashCode(), "Different handles should have different hashCodes."); + } + + @Test + public void testNullValues() + { + String schema = null; + String table = null; + + // Create an instance of ArrowTableHandle with null values + ArrowTableHandle tableHandle = new ArrowTableHandle(schema, table); + + // Verify that the schema and table are null + assertNull(tableHandle.getSchema(), "Schema should be null."); + assertNull(tableHandle.getTable(), "Table should be null."); + + // Verify the toString() output + String expectedToString = "null:null"; + assertEquals(tableHandle.toString(), expectedToString, "toString() should handle null values correctly."); + } +} From 2a1071e7004cc692ecdf03991607022b31309c0d Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Thu, 28 Nov 2024 15:41:20 +0530 Subject: [PATCH 27/39] Review comment fixes - Added Tests as per comments --- .../facebook/plugin/arrow/ArrowPageUtils.java | 54 +++-- .../plugin/arrow/ArrowPageUtilsTest.java | 221 ++++++++++++++++++ .../plugin/arrow/TestArrowColumnHandle.java | 83 +++++++ .../plugin/arrow/TestArrowHandleResolver.java | 67 ++++++ .../facebook/plugin/arrow/TestArrowSplit.java | 14 -- .../plugin/arrow/TestArrowTableHandle.java | 13 ++ .../arrow/TestArrowTableLayoutHandle.java | 116 +++++++++ 7 files changed, 535 insertions(+), 33 deletions(-) create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java 
create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 51725d95a1f6d..0dfa6c557c6d2 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -502,7 +502,6 @@ public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type t } return builder.build(); } - public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) { if (!(type instanceof TimestampType)) { @@ -589,7 +588,7 @@ public static Block buildBlockFromListVector(ListVector vector, Type type) return arrayBuilder.build(); } - private static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) + public static void appendValueToBuilder(Type type, BlockBuilder builder, Object value) { if (value == null) { builder.appendNull(); @@ -637,13 +636,13 @@ else if (type instanceof TimestampType) { } } - private static void handleVarcharType(Type type, BlockBuilder builder, Object value) + public static void handleVarcharType(Type type, BlockBuilder builder, Object value) { Slice slice = Slices.utf8Slice(value.toString()); type.writeSlice(builder, slice); } - private static void handleSmallintType(Type type, BlockBuilder builder, Object value) + public static void handleSmallintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Number) { type.writeLong(builder, ((Number) value).shortValue()); @@ -664,7 +663,7 @@ else if (value instanceof JsonStringArrayList) { } } - private static void handleTinyintType(Type type, BlockBuilder builder, Object 
value) + public static void handleTinyintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Number) { type.writeLong(builder, ((Number) value).byteValue()); @@ -685,7 +684,7 @@ else if (value instanceof JsonStringArrayList) { } } - private static void handleBigintType(Type type, BlockBuilder builder, Object value) + public static void handleBigintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Long) { type.writeLong(builder, (Long) value); @@ -709,7 +708,7 @@ else if (value instanceof JsonStringArrayList) { } } - private static void handleIntegerType(Type type, BlockBuilder builder, Object value) + public static void handleIntegerType(Type type, BlockBuilder builder, Object value) { if (value instanceof Integer) { type.writeLong(builder, (Integer) value); @@ -730,7 +729,7 @@ else if (value instanceof JsonStringArrayList) { } } - private static void handleDoubleType(Type type, BlockBuilder builder, Object value) + public static void handleDoubleType(Type type, BlockBuilder builder, Object value) { if (value instanceof Double) { type.writeDouble(builder, (Double) value); @@ -754,7 +753,7 @@ else if (value instanceof JsonStringArrayList) { } } - private static void handleBooleanType(Type type, BlockBuilder builder, Object value) + public static void handleBooleanType(Type type, BlockBuilder builder, Object value) { if (value instanceof Boolean) { type.writeBoolean(builder, (Boolean) value); @@ -764,24 +763,36 @@ private static void handleBooleanType(Type type, BlockBuilder builder, Object va } } - private static void handleDecimalType(DecimalType type, BlockBuilder builder, Object value) + public static void handleDecimalType(DecimalType type, BlockBuilder builder, Object value) { if (value instanceof BigDecimal) { BigDecimal decimalValue = (BigDecimal) value; if (type.isShort()) { - builder.writeLong(decimalValue.unscaledValue().longValue()); + // Handle ShortDecimalType + long unscaledValue = 
decimalValue.unscaledValue().longValue(); + type.writeLong(builder, unscaledValue); } else { + // Handle LongDecimalType Slice slice = Decimals.encodeScaledValue(decimalValue); type.writeSlice(builder, slice); } } + else if (value instanceof Long) { + // Direct handling for ShortDecimalType using long + if (type.isShort()) { + type.writeLong(builder, (Long) value); + } + else { + throw new IllegalArgumentException("Long value is not supported for LongDecimalType."); + } + } else { throw new IllegalArgumentException("Unsupported type for DecimalType: " + value.getClass()); } } - private static void handleArrayType(ArrayType type, BlockBuilder builder, Object value) + public static void handleArrayType(ArrayType type, BlockBuilder builder, Object value) { Type elementType = type.getElementType(); BlockBuilder arrayBuilder = builder.beginBlockEntry(); @@ -791,7 +802,7 @@ private static void handleArrayType(ArrayType type, BlockBuilder builder, Object builder.closeEntry(); } - private static void handleRowType(RowType type, BlockBuilder builder, Object value) + public static void handleRowType(RowType type, BlockBuilder builder, Object value) { List rowValues = (List) value; BlockBuilder rowBuilder = builder.beginBlockEntry(); @@ -803,7 +814,7 @@ private static void handleRowType(RowType type, BlockBuilder builder, Object val builder.closeEntry(); } - private static void handleDateType(Type type, BlockBuilder builder, Object value) + public static void handleDateType(Type type, BlockBuilder builder, Object value) { if (value instanceof java.sql.Date || value instanceof java.time.LocalDate) { int daysSinceEpoch = (int) (value instanceof java.sql.Date @@ -816,14 +827,19 @@ private static void handleDateType(Type type, BlockBuilder builder, Object value } } - private static void handleTimestampType(Type type, BlockBuilder builder, Object value) + public static void handleTimestampType(Type type, BlockBuilder builder, Object value) { - if (value instanceof 
java.sql.Timestamp || value instanceof java.time.Instant) { - long millis = value instanceof java.sql.Timestamp - ? ((java.sql.Timestamp) value).getTime() - : ((java.time.Instant) value).toEpochMilli(); + if (value instanceof java.sql.Timestamp) { + long millis = ((java.sql.Timestamp) value).getTime(); + type.writeLong(builder, millis); + } + else if (value instanceof java.time.Instant) { + long millis = ((java.time.Instant) value).toEpochMilli(); type.writeLong(builder, millis); } + else if (value instanceof Long) { // Handle long epoch milliseconds directly + type.writeLong(builder, (Long) value); + } else { throw new IllegalArgumentException("Unsupported type for TimestampType: " + value.getClass()); } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 84900918e0ae3..356f5a7ad3428 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -14,14 +14,22 @@ package com.facebook.plugin.arrow; import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.DateType; import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.DoubleType; import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.SmallintType; import com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import 
com.facebook.presto.common.type.VarcharType; +import io.airlift.slice.Slice; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.BigIntVector; @@ -41,7 +49,12 @@ import org.testng.annotations.Test; import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -277,4 +290,212 @@ public void testProcessDictionaryVector() assertNotNull(result, "The BlockBuilder should not be null."); assertEquals(result.getPositionCount(), 50); } + + @Test + public void testHandleVarcharType() + { + Type varcharType = VarcharType.createUnboundedVarcharType(); + BlockBuilder builder = varcharType.createBlockBuilder(null, 1); + + String value = "test_string"; + ArrowPageUtils.handleVarcharType(varcharType, builder, value); + + Block block = builder.build(); + Slice result = varcharType.getSlice(block, 0); + assertEquals(result.toStringUtf8(), value); + } + + @Test + public void testHandleSmallintType() + { + Type smallintType = SmallintType.SMALLINT; + BlockBuilder builder = smallintType.createBlockBuilder(null, 1); + + short value = 42; + ArrowPageUtils.handleSmallintType(smallintType, builder, value); + + Block block = builder.build(); + long result = smallintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testHandleTinyintType() + { + Type tinyintType = TinyintType.TINYINT; + BlockBuilder builder = tinyintType.createBlockBuilder(null, 1); + + byte value = 7; + ArrowPageUtils.handleTinyintType(tinyintType, builder, value); + + Block block = builder.build(); + long result = tinyintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testHandleBigintType() + { + Type bigintType = BigintType.BIGINT; + BlockBuilder 
builder = bigintType.createBlockBuilder(null, 1); + + long value = 123456789L; + ArrowPageUtils.handleBigintType(bigintType, builder, value); + + Block block = builder.build(); + long result = bigintType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testHandleIntegerType() + { + Type integerType = IntegerType.INTEGER; + BlockBuilder builder = integerType.createBlockBuilder(null, 1); + + int value = 42; + ArrowPageUtils.handleIntegerType(integerType, builder, value); + + Block block = builder.build(); + long result = integerType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testHandleDoubleType() + { + Type doubleType = DoubleType.DOUBLE; + BlockBuilder builder = doubleType.createBlockBuilder(null, 1); + + double value = 42.42; + ArrowPageUtils.handleDoubleType(doubleType, builder, value); + + Block block = builder.build(); + double result = doubleType.getDouble(block, 0); + assertEquals(result, value, 0.001); + } + + @Test + public void testHandleBooleanType() + { + Type booleanType = BooleanType.BOOLEAN; + BlockBuilder builder = booleanType.createBlockBuilder(null, 1); + + boolean value = true; + ArrowPageUtils.handleBooleanType(booleanType, builder, value); + + Block block = builder.build(); + boolean result = booleanType.getBoolean(block, 0); + assertEquals(result, value); + } + + @Test + public void testHandleArrayType() + { + Type elementType = IntegerType.INTEGER; + ArrayType arrayType = new ArrayType(elementType); + BlockBuilder builder = arrayType.createBlockBuilder(null, 1); + + List values = Arrays.asList(1, 2, 3); + ArrowPageUtils.handleArrayType(arrayType, builder, values); + + Block block = builder.build(); + Block arrayBlock = arrayType.getObject(block, 0); + assertEquals(arrayBlock.getPositionCount(), values.size()); + for (int i = 0; i < values.size(); i++) { + assertEquals(elementType.getLong(arrayBlock, i), values.get(i).longValue()); + } + } + + @Test + public void 
testHandleRowType() + { + RowType.Field field1 = new RowType.Field(Optional.of("field1"), IntegerType.INTEGER); + RowType.Field field2 = new RowType.Field(Optional.of("field2"), VarcharType.createUnboundedVarcharType()); + RowType rowType = RowType.from(Arrays.asList(field1, field2)); + BlockBuilder builder = rowType.createBlockBuilder(null, 1); + + List rowValues = Arrays.asList(42, "test"); + ArrowPageUtils.handleRowType(rowType, builder, rowValues); + + Block block = builder.build(); + Block rowBlock = rowType.getObject(block, 0); + assertEquals(IntegerType.INTEGER.getLong(rowBlock, 0), 42); + assertEquals(VarcharType.createUnboundedVarcharType().getSlice(rowBlock, 1).toStringUtf8(), "test"); + } + + @Test + public void testHandleDateType() + { + Type dateType = DateType.DATE; + BlockBuilder builder = dateType.createBlockBuilder(null, 1); + + LocalDate value = LocalDate.of(2020, 1, 1); + ArrowPageUtils.handleDateType(dateType, builder, value); + + Block block = builder.build(); + long result = dateType.getLong(block, 0); + assertEquals(result, value.toEpochDay()); + } + + @Test + public void testHandleTimestampType() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + long value = 1609459200000L; // Jan 1, 2021, 00:00:00 UTC + ArrowPageUtils.handleTimestampType(timestampType, builder, value); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + assertEquals(result, value); + } + + @Test + public void testHandleTimestampTypeWithSqlTimestamp() + { + Type timestampType = TimestampType.TIMESTAMP; + BlockBuilder builder = timestampType.createBlockBuilder(null, 1); + + java.sql.Timestamp timestamp = java.sql.Timestamp.valueOf("2021-01-01 00:00:00"); + long expectedMillis = timestamp.getTime(); + ArrowPageUtils.handleTimestampType(timestampType, builder, timestamp); + + Block block = builder.build(); + long result = timestampType.getLong(block, 0); + 
assertEquals(result, expectedMillis); + } + + @Test + public void testShortDecimalRetrieval() + { + DecimalType shortDecimalType = DecimalType.createDecimalType(10, 2); // Precision: 10, Scale: 2 + BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 1); + + BigDecimal decimalValue = new BigDecimal("12345.67"); + ArrowPageUtils.handleDecimalType(shortDecimalType, builder, decimalValue); + + Block block = builder.build(); + long unscaledValue = shortDecimalType.getLong(block, 0); // Unscaled value: 1234567 + BigDecimal result = BigDecimal.valueOf(unscaledValue).movePointLeft(shortDecimalType.getScale()); + assertEquals(result, decimalValue); + } + + @Test + public void testLongDecimalRetrieval() + { + // Create a DecimalType with precision 38 and scale 10 + DecimalType longDecimalType = DecimalType.createDecimalType(38, 10); + BlockBuilder builder = longDecimalType.createBlockBuilder(null, 1); + BigDecimal decimalValue = new BigDecimal("1234567890.1234567890"); + ArrowPageUtils.handleDecimalType(longDecimalType, builder, decimalValue); + // Build the block after inserting the decimal value + Block block = builder.build(); + Slice unscaledSlice = longDecimalType.getSlice(block, 0); + BigInteger unscaledValue = Decimals.decodeUnscaledValue(unscaledSlice); + BigDecimal result = new BigDecimal(unscaledValue).movePointLeft(longDecimalType.getScale()); + // Assert the decoded result is equal to the original decimal value + assertEquals(result, decimalValue); + } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java new file mode 100644 index 0000000000000..1d9c490180abc --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.spi.ColumnMetadata; +import org.testng.annotations.Test; + +import java.util.Locale; + +import static com.facebook.presto.testing.assertions.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestArrowColumnHandle +{ + @Test + public void testConstructorAndGetters() + { + // Given + String columnName = "testColumn"; + // When + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // Then + assertEquals(columnHandle.getColumnName(), columnName, "Column name should match the input"); + assertEquals(columnHandle.getColumnType(), IntegerType.INTEGER, "Column type should match the input"); + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnName() + { + // Given + // When + new ArrowColumnHandle(null, IntegerType.INTEGER); // Should throw NullPointerException + } + + @Test(expectedExceptions = NullPointerException.class) + public void testConstructorWithNullColumnType() + { + // Given + String columnName = "testColumn"; + + // When + new ArrowColumnHandle(columnName, null); // Should throw NullPointerException + } + + @Test + public void testGetColumnMetadata() + { + // Given + String columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + + // When + ColumnMetadata columnMetadata = 
columnHandle.getColumnMetadata(); + + // Then + assertNotNull(columnMetadata, "ColumnMetadata should not be null"); + assertEquals(columnMetadata.getName(), columnName.toLowerCase(Locale.ENGLISH), "ColumnMetadata name should match the column name"); + assertEquals(columnMetadata.getType(), IntegerType.INTEGER, "ColumnMetadata type should match the column type"); + } + + @Test + public void testToString() + { + String columnName = "testColumn"; + ArrowColumnHandle columnHandle = new ArrowColumnHandle(columnName, IntegerType.INTEGER); + String result = columnHandle.toString(); + String expected = columnName + ":" + IntegerType.INTEGER; + assertEquals(result, expected, "toString() should return the correct string representation"); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java new file mode 100644 index 0000000000000..ea95e9fec01b0 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +public class TestArrowHandleResolver +{ + @Test + public void testGetTableHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableHandleClass(), + ArrowTableHandle.class, + "getTableHandleClass should return ArrowTableHandle class."); + } + @Test + public void testGetTableLayoutHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableLayoutHandleClass(), + ArrowTableLayoutHandle.class, + "getTableLayoutHandleClass should return ArrowTableLayoutHandle class."); + } + @Test + public void testGetColumnHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getColumnHandleClass(), + ArrowColumnHandle.class, + "getColumnHandleClass should return ArrowColumnHandle class."); + } + @Test + public void testGetSplitClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getSplitClass(), + ArrowSplit.class, + "getSplitClass should return ArrowSplit class."); + } + @Test + public void testGetTransactionHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTransactionHandleClass(), + ArrowTransactionHandle.class, + "getTransactionHandleClass should return ArrowTransactionHandle class."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java index e6dceda59ef85..65da26254bd34 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java @@ -25,20 +25,6 @@ import static org.testng.Assert.assertNotNull; import static 
org.testng.Assert.assertTrue; -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - @Test(singleThreaded = true) public class TestArrowSplit { diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java index c6b66751c3438..dd1696eacabd6 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java @@ -19,6 +19,19 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ public class TestArrowTableHandle { @Test diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java new file mode 100644 index 0000000000000..0ff7301c7e0ff --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +public class TestArrowTableLayoutHandle +{ + @Test + public void testConstructorAndGetters() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + assertEquals(layoutHandle.getTableHandle(), tableHandle, "Table handle mismatch."); + assertEquals(layoutHandle.getColumnHandles(), columnHandles, "Column handles mismatch."); + assertEquals(layoutHandle.getTupleDomain(), tupleDomain, "Tuple domain mismatch."); + } + + @Test + public void testToString() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", BigintType.BIGINT)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + String expectedString = "tableHandle:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + assertEquals(layoutHandle.toString(), expectedString, "toString output mismatch."); + } + + @Test + public void testEqualsAndHashCode() + { + ArrowTableHandle 
tableHandle1 = new ArrowTableHandle("schema", "table"); + ArrowTableHandle tableHandle2 = new ArrowTableHandle("schema", "different_table"); + + List columnHandles1 = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + List columnHandles2 = Collections.singletonList( + new ArrowColumnHandle("column1", IntegerType.INTEGER)); + + TupleDomain tupleDomain1 = TupleDomain.all(); + TupleDomain tupleDomain2 = TupleDomain.none(); + + ArrowTableLayoutHandle layoutHandle1 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle2 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle3 = new ArrowTableLayoutHandle(tableHandle2, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle4 = new ArrowTableLayoutHandle(tableHandle1, columnHandles2, tupleDomain1); + ArrowTableLayoutHandle layoutHandle5 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain2); + + // Test equality + assertEquals(layoutHandle1, layoutHandle2, "Handles with same attributes should be equal."); + assertNotEquals(layoutHandle1, layoutHandle3, "Handles with different tableHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle4, "Handles with different columnHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle5, "Handles with different tupleDomains should not be equal."); + assertNotEquals(layoutHandle1, null, "Handle should not be equal to null."); + assertNotEquals(layoutHandle1, new Object(), "Handle should not be equal to an object of another class."); + + // Test hash codes + assertEquals(layoutHandle1.hashCode(), layoutHandle2.hashCode(), "Equal handles should have same hash code."); + assertNotEquals(layoutHandle1.hashCode(), layoutHandle3.hashCode(), "Handles with different tableHandles should have different hash codes."); + 
assertNotEquals(layoutHandle1.hashCode(), layoutHandle4.hashCode(), "Handles with different columnHandles should have different hash codes."); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "table is null") + public void testConstructorNullTableHandle() + { + new ArrowTableLayoutHandle(null, Collections.emptyList(), TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "columns are null") + public void testConstructorNullColumnHandles() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), null, TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "domain is null") + public void testConstructorNullTupleDomain() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), Collections.emptyList(), null); + } +} From c05c392a61e30745307985e99356a079fcecf520 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Thu, 28 Nov 2024 17:34:30 +0530 Subject: [PATCH 28/39] Removed license header unwanted place --- .../facebook/plugin/arrow/TestArrowTableHandle.java | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java index dd1696eacabd6..c6b66751c3438 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java @@ -19,19 +19,6 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNull; -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ public class TestArrowTableHandle { @Test From 266032212603830c25fb4f9cab2936ed19bbc216 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 29 Nov 2024 10:37:57 +0530 Subject: [PATCH 29/39] Removed duplicate CI job --- .github/workflows/arrow-tests.yml | 82 ------------------------------- 1 file changed, 82 deletions(-) delete mode 100644 .github/workflows/arrow-tests.yml diff --git a/.github/workflows/arrow-tests.yml b/.github/workflows/arrow-tests.yml deleted file mode 100644 index 83ad7c940122b..0000000000000 --- a/.github/workflows/arrow-tests.yml +++ /dev/null @@ -1,82 +0,0 @@ -name: arrow tests - -on: - pull_request: - -env: - CONTINUOUS_INTEGRATION: true - MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError" - MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError" - MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true" - MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end" - RETRY: .github/bin/retry - -jobs: - changes: - runs-on: ubuntu-latest - permissions: - pull-requests: read - outputs: - codechange: ${{ steps.filter.outputs.codechange }} - steps: - - uses: dorny/paths-filter@v2 - id: filter - with: - filters: | - codechange: - - '!presto-docs/**' - test: - runs-on: ubuntu-latest - needs: changes - strategy: - fail-fast: false - matrix: - modules: - - ":presto-base-arrow-flight" # Only run tests for the `presto-base-arrow-flight` module - - timeout-minutes: 80 - 
concurrency: - group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }} - cancel-in-progress: true - - steps: - # Checkout the code only if there are changes in the relevant files - - uses: actions/checkout@v4 - if: needs.changes.outputs.codechange == 'true' - with: - show-progress: false - - # Set up Java for the build environment - - uses: actions/setup-java@v2 - if: needs.changes.outputs.codechange == 'true' - with: - distribution: 'temurin' - java-version: 8 - - # Cache Maven dependencies to speed up the build - - name: Cache local Maven repository - if: needs.changes.outputs.codechange == 'true' - id: cache-maven - uses: actions/cache@v2 - with: - path: ~/.m2/repository - key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }} - restore-keys: | - ${{ runner.os }}-maven-2- - - # Resolve Maven dependencies (if cache is not found) - - name: Populate Maven cache - if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true' - run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs - - # Install dependencies for the target module - - name: Maven Install - if: needs.changes.outputs.codechange == 'true' - run: | - export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}" - ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }} - - # Run Maven tests for the target module - - name: Maven Tests - if: needs.changes.outputs.codechange == 'true' - run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }} From 03f1bcdb0ee737b47b620c9593ced545224d151c Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 29 Nov 2024 12:46:58 +0530 Subject: [PATCH 30/39] Review comment fixes --- .../plugin/arrow/AbstractArrowMetadata.java | 9 +++++++++ .../com/facebook/plugin/arrow/ArrowConnector.java | 2 +- .../plugin/arrow/ArrowFlightClientHandler.java | 9 ++++++++- .../plugin/arrow/TestingArrowMetadata.java | 14 +++++++------- 
.../plugin/arrow/TestingConnectionProperties.java | 1 - 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index 611fb38481a62..f086ce3c6578e 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -38,8 +38,10 @@ import com.facebook.presto.spi.ConnectorTableMetadata; import com.facebook.presto.spi.Constraint; import com.facebook.presto.spi.NotFoundException; +import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.SchemaTableName; import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.connector.ConnectorMetadata; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -192,6 +194,13 @@ public Map getColumnHandles(ConnectorSession session, Conn @Override public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) { + if (!(table instanceof ArrowTableHandle)) { + throw new PrestoException( + StandardErrorCode.INVALID_CAST_ARGUMENT, + "Invalid table handle: Expected an instance of ArrowTableHandle but received " + + table.getClass().getSimpleName() + ". 
Please check the table type being used."); + } + ArrowTableHandle tableHandle = (ArrowTableHandle) table; List columns = new ArrayList<>(); diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java index d6221625df615..1028af2414308 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -44,7 +44,7 @@ public ArrowConnector(ConnectorMetadata metadata, ArrowFlightClientHandler arrowFlightClientHandler) { this.metadata = requireNonNull(metadata, "Metadata is null"); - this.handleResolver = requireNonNull(handleResolver, "Metadata is null"); + this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); this.splitManager = requireNonNull(splitManager, "SplitManager is null"); this.pageSourceProvider = requireNonNull(pageSourceProvider, "PageSinkProvider is null"); this.arrowFlightClientHandler = requireNonNull(arrowFlightClientHandler, "arrow flight handler is null"); diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index 51704748d0677..5c3290588ceb4 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -60,7 +60,7 @@ private ArrowFlightClient initializeClient(Optional uri) } if (null == allocator) { - allocator = new RootAllocator(Long.MAX_VALUE); + initializeAllocator(); } FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); @@ -80,6 +80,13 @@ else if (config.getFlightServerSSLCertificate() != null) { } } + private synchronized void 
initializeAllocator() + { + if (allocator == null) { + allocator = new RootAllocator(Long.MAX_VALUE); + } + } + protected abstract CredentialCallOption getCallOptions(ConnectorSession connectorSession); public ArrowFlightConfig getConfig() diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java index a1cd0a4dffb76..3fad658bbbb58 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingArrowMetadata.java @@ -47,16 +47,16 @@ public class TestingArrowMetadata private static final Logger logger = Logger.get(TestingArrowMetadata.class); private static final ObjectMapper objectMapper = new ObjectMapper(); private final NodeManager nodeManager; - private final TestingArrowFlightConfig testconfig; + private final TestingArrowFlightConfig testConfig; private final ArrowFlightClientHandler clientHandler; private final ArrowFlightConfig config; @Inject - public TestingArrowMetadata(ArrowFlightConfig config, ArrowFlightClientHandler clientHandler, NodeManager nodeManager, TestingArrowFlightConfig testconfig) + public TestingArrowMetadata(ArrowFlightClientHandler clientHandler, NodeManager nodeManager, TestingArrowFlightConfig testConfig, ArrowFlightConfig config) { super(config, clientHandler); this.nodeManager = nodeManager; - this.testconfig = testconfig; + this.testConfig = testConfig; this.clientHandler = clientHandler; this.config = config; } @@ -90,7 +90,7 @@ public List extractSchemaAndTableData(Optional schema, Connector { try (ArrowFlightClient client = clientHandler.getClient(Optional.empty())) { List names = new ArrayList<>(); - TestingArrowFlightRequest request = getArrowFlightRequest(config, schema.orElse(null)); + TestingArrowFlightRequest request = getArrowFlightRequest(schema.orElse(null)); 
ObjectNode rootNode = (ObjectNode) objectMapper.readTree(request.getCommand()); String modifiedQueryJson = objectMapper.writeValueAsString(rootNode); @@ -147,12 +147,12 @@ protected String getDataSourceSpecificTableName(ArrowFlightConfig config, String @Override protected FlightDescriptor getFlightDescriptor(Optional query, String schema, String table) { - TestingArrowFlightRequest request = new TestingArrowFlightRequest(this.config, testconfig, schema, table, query, nodeManager.getWorkerNodes().size()); + TestingArrowFlightRequest request = new TestingArrowFlightRequest(this.config, testConfig, schema, table, query, nodeManager.getWorkerNodes().size()); return FlightDescriptor.command(request.getCommand()); } - private TestingArrowFlightRequest getArrowFlightRequest(ArrowFlightConfig config, String schema) + private TestingArrowFlightRequest getArrowFlightRequest(String schema) { - return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testconfig); + return new TestingArrowFlightRequest(config, schema, nodeManager.getWorkerNodes().size(), testConfig); } } diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java index 0cccfece3866e..e158fdc67707e 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestingConnectionProperties.java @@ -20,7 +20,6 @@ public class TestingConnectionProperties { private final String database; private final String password; - private Integer port; private final String host; private final Boolean ssl; private final String username; From fa04dde6a59acbee7b296111d3a8036f4423fbac Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 29 Nov 2024 14:51:05 +0530 Subject: [PATCH 31/39] Added more testcases --- 
.../plugin/arrow/ArrowOutputTableHandle.java | 147 ++++++++++++++++++ .../plugin/arrow/ArrowMetadataUtil.java | 77 +++++++++ .../plugin/arrow/ArrowPageUtilsTest.java | 3 - .../plugin/arrow/TestArrowTableHandle.java | 73 ++------- 4 files changed, 233 insertions(+), 67 deletions(-) create mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java create mode 100644 presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java new file mode 100644 index 0000000000000..128db301408f7 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java @@ -0,0 +1,147 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ConnectorInsertTableHandle; +import com.facebook.presto.spi.ConnectorOutputTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import javax.annotation.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ArrowOutputTableHandle + implements ConnectorOutputTableHandle, ConnectorInsertTableHandle +{ + private final String connectorId; + private final String catalogName; + private final String schemaName; + private final String tableName; + private final List columnNames; + private final List columnTypes; + private final String temporaryTableName; + + @JsonCreator + public ArrowOutputTableHandle( + @JsonProperty("connectorId") String connectorId, + @JsonProperty("catalogName") @Nullable String catalogName, + @JsonProperty("schemaName") @Nullable String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("columnNames") List columnNames, + @JsonProperty("columnTypes") List columnTypes, + @JsonProperty("temporaryTableName") String temporaryTableName) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.catalogName = catalogName; + this.schemaName = schemaName; + this.tableName = requireNonNull(tableName, "tableName is null"); + this.temporaryTableName = requireNonNull(temporaryTableName, "temporaryTableName is null"); + + requireNonNull(columnNames, "columnNames is null"); + requireNonNull(columnTypes, "columnTypes is null"); + checkArgument(columnNames.size() == columnTypes.size(), "columnNames and columnTypes sizes don't match"); + this.columnNames = ImmutableList.copyOf(columnNames); + 
this.columnTypes = ImmutableList.copyOf(columnTypes); + } + + @JsonProperty + public String getConnectorId() + { + return connectorId; + } + + @JsonProperty + @Nullable + public String getCatalogName() + { + return catalogName; + } + + @JsonProperty + @Nullable + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public List getColumnNames() + { + return columnNames; + } + + @JsonProperty + public List getColumnTypes() + { + return columnTypes; + } + + @JsonProperty + public String getTemporaryTableName() + { + return temporaryTableName; + } + + @Override + public String toString() + { + return format("jdbc:%s.%s.%s", catalogName, schemaName, tableName); + } + + @Override + public int hashCode() + { + return Objects.hash( + connectorId, + catalogName, + schemaName, + tableName, + columnNames, + columnTypes, + temporaryTableName); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + ArrowOutputTableHandle other = (ArrowOutputTableHandle) obj; + return Objects.equals(this.connectorId, other.connectorId) && + Objects.equals(this.catalogName, other.catalogName) && + Objects.equals(this.schemaName, other.schemaName) && + Objects.equals(this.tableName, other.tableName) && + Objects.equals(this.columnNames, other.columnNames) && + Objects.equals(this.columnTypes, other.columnTypes) && + Objects.equals(this.temporaryTableName, other.temporaryTableName); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java new file mode 100644 index 0000000000000..1366c123d9c93 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java @@ -0,0 +1,77 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Locale.ENGLISH; +import static org.testng.Assert.assertEquals; + +final class ArrowMetadataUtil +{ + private ArrowMetadataUtil() {} + + public static final JsonCodec COLUMN_CODEC; + public static final JsonCodec TABLE_CODEC; + public static final JsonCodec OUTPUT_TABLE_CODEC; + + static { + JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); + provider.setJsonDeserializers(ImmutableMap.of(Type.class, new TestingTypeDeserializer())); + JsonCodecFactory codecFactory = new JsonCodecFactory(provider); + COLUMN_CODEC = codecFactory.jsonCodec(ArrowColumnHandle.class); + TABLE_CODEC = codecFactory.jsonCodec(ArrowTableHandle.class); + OUTPUT_TABLE_CODEC = 
codecFactory.jsonCodec(ArrowOutputTableHandle.class); + } + + public static final class TestingTypeDeserializer + extends FromStringDeserializer + { + private final Map types = ImmutableMap.of( + StandardTypes.BIGINT, BIGINT, + StandardTypes.VARCHAR, VARCHAR); + + public TestingTypeDeserializer() + { + super(Type.class); + } + + @Override + protected Type _deserialize(String value, DeserializationContext context) + { + Type type = types.get(value.toLowerCase(ENGLISH)); + checkArgument(type != null, "Unknown type %s", value); + return type; + } + } + + public static void assertJsonRoundTrip(JsonCodec codec, T object) + { + String json = codec.toJson(object); + T copy = codec.fromJson(json); + assertEquals(copy, object); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 356f5a7ad3428..128205fa7ce48 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -280,9 +280,6 @@ public void testProcessDictionaryVector() Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); IntVector encodedVector = (IntVector) DictionaryEncoder.encode(rawVector, dictionary); - // Decode back to original - VarCharVector decodedVector = (VarCharVector) DictionaryEncoder.decode(encodedVector, dictionary); - // Process the dictionary vector Block result = ArrowPageUtils.buildBlockFromEncodedVector(encodedVector, dictionaryVector); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java index c6b66751c3438..2061fe5036534 100644 --- 
a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java @@ -13,80 +13,25 @@ */ package com.facebook.plugin.arrow; +import com.facebook.airlift.testing.EquivalenceTester; import org.testng.annotations.Test; -import static com.facebook.airlift.testing.Assertions.assertNotEquals; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertNull; +import static com.facebook.plugin.arrow.ArrowMetadataUtil.TABLE_CODEC; +import static com.facebook.plugin.arrow.ArrowMetadataUtil.assertJsonRoundTrip; public class TestArrowTableHandle { @Test - public void testConstructorAndGetters() + public void testJsonRoundTrip() { - String schema = "test_schema"; - String table = "test_table"; - - // Create an instance of ArrowTableHandle - ArrowTableHandle tableHandle = new ArrowTableHandle(schema, table); - - // Verify that the schema and table are correctly set - assertEquals(tableHandle.getSchema(), schema, "Schema should match the input value."); - assertEquals(tableHandle.getTable(), table, "Table should match the input value."); - } - - @Test - public void testToString() - { - String schema = "test_schema"; - String table = "test_table"; - - // Create an instance of ArrowTableHandle - ArrowTableHandle tableHandle = new ArrowTableHandle(schema, table); - - // Verify the toString() output - String expectedToString = "test_schema:test_table"; - assertEquals(tableHandle.toString(), expectedToString, "toString() should return the correct string representation."); + assertJsonRoundTrip(TABLE_CODEC, new ArrowTableHandle("schema", "table")); } @Test - public void testEqualityAndHashCode() + public void testEquivalence() { - String schema = "schema"; - String table = "table"; - - String schema2 = "schema2"; - String table2 = "table2"; - - // Create multiple instances - ArrowTableHandle handle1 = new ArrowTableHandle(schema, table); 
- ArrowTableHandle handle2 = new ArrowTableHandle(schema, table); - ArrowTableHandle handle3 = new ArrowTableHandle(schema2, table2); - - // Verify equality - assertEquals(handle1, handle2, "Handles with the same schema and table should be equal."); - assertNotEquals(handle1, handle3, "Handles with different schema or table should not be equal."); - - // Verify hashCode - assertEquals(handle1.hashCode(), handle2.hashCode(), "Equal handles should have the same hashCode."); - assertNotEquals(handle1.hashCode(), handle3.hashCode(), "Different handles should have different hashCodes."); - } - - @Test - public void testNullValues() - { - String schema = null; - String table = null; - - // Create an instance of ArrowTableHandle with null values - ArrowTableHandle tableHandle = new ArrowTableHandle(schema, table); - - // Verify that the schema and table are null - assertNull(tableHandle.getSchema(), "Schema should be null."); - assertNull(tableHandle.getTable(), "Table should be null."); - - // Verify the toString() output - String expectedToString = "null:null"; - assertEquals(tableHandle.toString(), expectedToString, "toString() should handle null values correctly."); + EquivalenceTester.equivalenceTester() + .addEquivalentGroup( + new ArrowTableHandle("tm_engine", "employees")).check(); } } From a32e96792b89972a54f8515b5be70e3fc16d497e Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 29 Nov 2024 17:38:21 +0530 Subject: [PATCH 32/39] Fixed review comments and added support for other datatypes --- .../facebook/plugin/arrow/ArrowPageUtils.java | 122 +++++++++++++++--- 1 file changed, 104 insertions(+), 18 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 0dfa6c557c6d2..6ebd70e6138cc 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -24,12 +24,14 @@ import com.facebook.presto.common.type.Decimals; import com.facebook.presto.common.type.DoubleType; import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; import com.facebook.presto.common.type.RowType; import com.facebook.presto.common.type.SmallintType; import com.facebook.presto.common.type.TimeType; import com.facebook.presto.common.type.TimestampType; import com.facebook.presto.common.type.TinyintType; import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; import com.facebook.presto.common.type.VarcharType; import com.google.common.base.CharMatcher; import io.airlift.slice.Slice; @@ -53,10 +55,16 @@ import org.apache.arrow.vector.TimeStampMilliVector; import org.apache.arrow.vector.TimeStampSecVector; import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListReader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.util.JsonStringArrayList; import java.math.BigDecimal; @@ -149,37 +157,114 @@ else if (vector instanceof ListVector) { throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); } - public static Block buildBlockFromEncodedVector(IntVector encodedVector, VarCharVector dictionary) + public static Block buildBlockFromEncodedVector(FieldVector encodedVector, FieldVector dictionary) { - // Ensure the dictionary 
vector is valid - if (dictionary == null || encodedVector == null) { + // Validate inputs + if (encodedVector == null || dictionary == null) { throw new IllegalArgumentException("Both encodedVector and dictionary must be non-null."); } - // Create a BlockBuilder for VARCHAR - BlockBuilder builder = VarcharType.VARCHAR.createBlockBuilder(null, encodedVector.getValueCount()); + // Create a Dictionary object + Dictionary arrowDictionary = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); - // Iterate through the encoded vector and retrieve values from the dictionary - for (int i = 0; i < encodedVector.getValueCount(); i++) { - if (encodedVector.isNull(i)) { - builder.appendNull(); // Append null if the index is null + // Decode the encoded vector using the dictionary + ValueVector decodedVector = DictionaryEncoder.decode(encodedVector, arrowDictionary); + + // Create a BlockBuilder for the decoded vector's data type + Type prestoType = getPrestoTypeFromArrowType(decodedVector.getField().getType()); + BlockBuilder builder = prestoType.createBlockBuilder(null, decodedVector.getValueCount()); + + // Populate the block dynamically based on vector type + for (int i = 0; i < decodedVector.getValueCount(); i++) { + if (decodedVector.isNull(i)) { + builder.appendNull(); // Append null for null values } else { - int dictionaryIndex = encodedVector.get(i); - if (dictionary.isNull(dictionaryIndex)) { - builder.appendNull(); // Append null if the dictionary value is null - } - else { - byte[] valueBytes = dictionary.get(dictionaryIndex); - String value = new String(valueBytes, StandardCharsets.UTF_8); - VarcharType.VARCHAR.writeSlice(builder, Slices.utf8Slice(value)); // Append the dictionary value - } + // Handle based on vector type + appendValueToBlock(decodedVector, i, prestoType, builder); } } return builder.build(); } + private static Type getPrestoTypeFromArrowType(ArrowType arrowType) + { + if (arrowType instanceof ArrowType.Utf8) { + return 
VarcharType.VARCHAR; + } + else if (arrowType instanceof ArrowType.Int) { + ArrowType.Int intType = (ArrowType.Int) arrowType; + if (intType.getBitWidth() == 8 || intType.getBitWidth() == 16 || intType.getBitWidth() == 32) { + return IntegerType.INTEGER; + } + else if (intType.getBitWidth() == 64) { + return BigintType.BIGINT; + } + } + else if (arrowType instanceof ArrowType.FloatingPoint) { + ArrowType.FloatingPoint fpType = (ArrowType.FloatingPoint) arrowType; + FloatingPointPrecision precision = fpType.getPrecision(); + + if (precision == FloatingPointPrecision.SINGLE) { // 32-bit float + return RealType.REAL; + } + else if (precision == FloatingPointPrecision.DOUBLE) { // 64-bit float + return DoubleType.DOUBLE; + } + else { + throw new UnsupportedOperationException("Unsupported FloatingPoint precision: " + precision); + } + } + else if (arrowType instanceof ArrowType.Bool) { + return BooleanType.BOOLEAN; + } + else if (arrowType instanceof ArrowType.Binary) { + return VarbinaryType.VARBINARY; + } + else if (arrowType instanceof ArrowType.Decimal) { + return DecimalType.createDecimalType(); + } + throw new UnsupportedOperationException("Unsupported ArrowType: " + arrowType); + } + + private static void appendValueToBlock(ValueVector vector, int index, Type prestoType, BlockBuilder builder) + { + if (vector instanceof VarCharVector) { + VarCharVector varCharVector = (VarCharVector) vector; + byte[] valueBytes = varCharVector.get(index); + prestoType.writeSlice(builder, Slices.utf8Slice(new String(valueBytes, StandardCharsets.UTF_8))); + } + else if (vector instanceof IntVector) { + IntVector intVector = (IntVector) vector; + prestoType.writeLong(builder, intVector.get(index)); + } + else if (vector instanceof BigIntVector) { + BigIntVector bigIntVector = (BigIntVector) vector; + prestoType.writeLong(builder, bigIntVector.get(index)); + } + else if (vector instanceof Float4Vector) { + Float4Vector floatVector = (Float4Vector) vector; + 
prestoType.writeLong(builder, Float.floatToRawIntBits(floatVector.get(index))); + } + else if (vector instanceof Float8Vector) { + Float8Vector doubleVector = (Float8Vector) vector; + prestoType.writeDouble(builder, doubleVector.get(index)); + } + else if (vector instanceof BitVector) { + BitVector bitVector = (BitVector) vector; + prestoType.writeBoolean(builder, bitVector.get(index) == 1); + } + else if (vector instanceof VarBinaryVector) { + VarBinaryVector binaryVector = (VarBinaryVector) vector; + byte[] valueBytes = binaryVector.get(index); + prestoType.writeSlice(builder, Slices.wrappedBuffer(valueBytes)); + } + else { + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass()); + } + } + public static Block buildBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type) { if (!(type instanceof TimestampType)) { @@ -502,6 +587,7 @@ public static Block buildBlockFromTimeMicroVector(TimeMicroVector vector, Type t } return builder.build(); } + public static Block buildBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type) { if (!(type instanceof TimestampType)) { From d42cf784522b98cafb26bda853ca0827419401c0 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 29 Nov 2024 18:07:48 +0530 Subject: [PATCH 33/39] Minor fixes on method argument --- .../java/com/facebook/plugin/arrow/ArrowPageUtils.java | 8 ++------ .../com/facebook/plugin/arrow/ArrowPageUtilsTest.java | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 6ebd70e6138cc..95cea5c91fd96 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -64,7 +64,6 @@ import org.apache.arrow.vector.dictionary.DictionaryEncoder; 
import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.util.JsonStringArrayList; import java.math.BigDecimal; @@ -157,18 +156,15 @@ else if (vector instanceof ListVector) { throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); } - public static Block buildBlockFromEncodedVector(FieldVector encodedVector, FieldVector dictionary) + public static Block buildBlockFromEncodedVector(FieldVector encodedVector, Dictionary dictionary) { // Validate inputs if (encodedVector == null || dictionary == null) { throw new IllegalArgumentException("Both encodedVector and dictionary must be non-null."); } - // Create a Dictionary object - Dictionary arrowDictionary = new Dictionary(dictionary, new DictionaryEncoding(1L, false, null)); - // Decode the encoded vector using the dictionary - ValueVector decodedVector = DictionaryEncoder.decode(encodedVector, arrowDictionary); + ValueVector decodedVector = DictionaryEncoder.decode(encodedVector, dictionary); // Create a BlockBuilder for the decoded vector's data type Type prestoType = getPrestoTypeFromArrowType(decodedVector.getField().getType()); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 128205fa7ce48..44fab56e2a95a 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -281,7 +281,7 @@ public void testProcessDictionaryVector() IntVector encodedVector = (IntVector) DictionaryEncoder.encode(rawVector, dictionary); // Process the dictionary vector - Block result = ArrowPageUtils.buildBlockFromEncodedVector(encodedVector, dictionaryVector); + 
Block result = ArrowPageUtils.buildBlockFromEncodedVector(encodedVector, dictionary); // Verify the result assertNotNull(result, "The BlockBuilder should not be null."); From a27652a11ab661ec817465bb55de4a20d959616a Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 29 Nov 2024 20:17:37 +0530 Subject: [PATCH 34/39] Fixed review comments - cosmetic changes and unneccessary class removal --- presto-base-arrow-flight/pom.xml | 2 +- .../plugin/arrow/AbstractArrowMetadata.java | 2 +- .../arrow/ArrowFlightClientHandler.java | 4 +- .../plugin/arrow/ArrowOutputTableHandle.java | 147 ------------------ .../facebook/plugin/arrow/ArrowPageUtils.java | 63 ++++---- .../plugin/arrow/ArrowMetadataUtil.java | 2 - .../plugin/arrow/ArrowPageUtilsTest.java | 52 +++---- 7 files changed, 63 insertions(+), 209 deletions(-) delete mode 100644 presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java diff --git a/presto-base-arrow-flight/pom.xml b/presto-base-arrow-flight/pom.xml index 43f50779cf111..b2d1ff04a964e 100644 --- a/presto-base-arrow-flight/pom.xml +++ b/presto-base-arrow-flight/pom.xml @@ -301,7 +301,7 @@ org.apache.arrow arrow-algorithm - 18.0.0 + ${arrow.version} compile diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java index f086ce3c6578e..c991d3c4c1f5d 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/AbstractArrowMetadata.java @@ -198,7 +198,7 @@ public List getTableLayouts(ConnectorSession session throw new PrestoException( StandardErrorCode.INVALID_CAST_ARGUMENT, "Invalid table handle: Expected an instance of ArrowTableHandle but received " - + table.getClass().getSimpleName() + ". 
Please check the table type being used."); + + table.getClass().getSimpleName() + ""); } ArrowTableHandle tableHandle = (ArrowTableHandle) table; diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java index 5c3290588ceb4..a85edf42358bb 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightClientHandler.java @@ -120,6 +120,8 @@ public Optional getSchema(FlightDescriptor flightDescriptor, ConnectorSe public void closeRootallocator() { - allocator.close(); + if (null != allocator) { + allocator.close(); + } } } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java deleted file mode 100644 index 128db301408f7..0000000000000 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowOutputTableHandle.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package com.facebook.plugin.arrow; - -import com.facebook.presto.common.type.Type; -import com.facebook.presto.spi.ConnectorInsertTableHandle; -import com.facebook.presto.spi.ConnectorOutputTableHandle; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.collect.ImmutableList; - -import javax.annotation.Nullable; - -import java.util.List; -import java.util.Objects; - -import static com.google.common.base.Preconditions.checkArgument; -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; - -public class ArrowOutputTableHandle - implements ConnectorOutputTableHandle, ConnectorInsertTableHandle -{ - private final String connectorId; - private final String catalogName; - private final String schemaName; - private final String tableName; - private final List columnNames; - private final List columnTypes; - private final String temporaryTableName; - - @JsonCreator - public ArrowOutputTableHandle( - @JsonProperty("connectorId") String connectorId, - @JsonProperty("catalogName") @Nullable String catalogName, - @JsonProperty("schemaName") @Nullable String schemaName, - @JsonProperty("tableName") String tableName, - @JsonProperty("columnNames") List columnNames, - @JsonProperty("columnTypes") List columnTypes, - @JsonProperty("temporaryTableName") String temporaryTableName) - { - this.connectorId = requireNonNull(connectorId, "connectorId is null"); - this.catalogName = catalogName; - this.schemaName = schemaName; - this.tableName = requireNonNull(tableName, "tableName is null"); - this.temporaryTableName = requireNonNull(temporaryTableName, "temporaryTableName is null"); - - requireNonNull(columnNames, "columnNames is null"); - requireNonNull(columnTypes, "columnTypes is null"); - checkArgument(columnNames.size() == columnTypes.size(), "columnNames and columnTypes sizes don't match"); - this.columnNames = ImmutableList.copyOf(columnNames); - 
this.columnTypes = ImmutableList.copyOf(columnTypes); - } - - @JsonProperty - public String getConnectorId() - { - return connectorId; - } - - @JsonProperty - @Nullable - public String getCatalogName() - { - return catalogName; - } - - @JsonProperty - @Nullable - public String getSchemaName() - { - return schemaName; - } - - @JsonProperty - public String getTableName() - { - return tableName; - } - - @JsonProperty - public List getColumnNames() - { - return columnNames; - } - - @JsonProperty - public List getColumnTypes() - { - return columnTypes; - } - - @JsonProperty - public String getTemporaryTableName() - { - return temporaryTableName; - } - - @Override - public String toString() - { - return format("jdbc:%s.%s.%s", catalogName, schemaName, tableName); - } - - @Override - public int hashCode() - { - return Objects.hash( - connectorId, - catalogName, - schemaName, - tableName, - columnNames, - columnTypes, - temporaryTableName); - } - - @Override - public boolean equals(Object obj) - { - if (this == obj) { - return true; - } - if (obj == null || getClass() != obj.getClass()) { - return false; - } - ArrowOutputTableHandle other = (ArrowOutputTableHandle) obj; - return Objects.equals(this.connectorId, other.connectorId) && - Objects.equals(this.catalogName, other.catalogName) && - Objects.equals(this.schemaName, other.schemaName) && - Objects.equals(this.tableName, other.tableName) && - Objects.equals(this.columnNames, other.columnNames) && - Objects.equals(this.columnTypes, other.columnTypes) && - Objects.equals(this.temporaryTableName, other.temporaryTableName); - } -} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 95cea5c91fd96..ca78ed11b5447 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -73,6 +73,8 @@ import java.util.List; import java.util.concurrent.TimeUnit; +import static java.util.Objects.requireNonNull; + public class ArrowPageUtils { private ArrowPageUtils() @@ -159,9 +161,8 @@ else if (vector instanceof ListVector) { public static Block buildBlockFromEncodedVector(FieldVector encodedVector, Dictionary dictionary) { // Validate inputs - if (encodedVector == null || dictionary == null) { - throw new IllegalArgumentException("Both encodedVector and dictionary must be non-null."); - } + requireNonNull(encodedVector, "encoded vector is null"); + requireNonNull(dictionary, "dictionary is null"); // Decode the encoded vector using the dictionary ValueVector decodedVector = DictionaryEncoder.decode(encodedVector, dictionary); @@ -176,7 +177,7 @@ public static Block buildBlockFromEncodedVector(FieldVector encodedVector, Dicti builder.appendNull(); // Append null for null values } else { - // Handle based on vector type + // write based on vector type appendValueToBlock(decodedVector, i, prestoType, builder); } } @@ -678,53 +679,53 @@ public static void appendValueToBuilder(Type type, BlockBuilder builder, Object } if (type instanceof VarcharType) { - handleVarcharType(type, builder, value); + writeVarcharType(type, builder, value); } else if (type instanceof SmallintType) { - handleSmallintType(type, builder, value); + writeSmallintType(type, builder, value); } else if (type instanceof TinyintType) { - handleTinyintType(type, builder, value); + writeTinyintType(type, builder, value); } else if (type instanceof BigintType) { - handleBigintType(type, builder, value); + writeBigintType(type, builder, value); } else if (type instanceof IntegerType) { - handleIntegerType(type, builder, value); + writeIntegerType(type, builder, value); } else if (type instanceof DoubleType) { - handleDoubleType(type, builder, value); + writeDoubleType(type, builder, value); } else 
if (type instanceof BooleanType) { - handleBooleanType(type, builder, value); + writeBooleanType(type, builder, value); } else if (type instanceof DecimalType) { - handleDecimalType((DecimalType) type, builder, value); + writeDecimalType((DecimalType) type, builder, value); } else if (type instanceof ArrayType) { - handleArrayType((ArrayType) type, builder, value); + writeArrayType((ArrayType) type, builder, value); } else if (type instanceof RowType) { - handleRowType((RowType) type, builder, value); + writeRowType((RowType) type, builder, value); } else if (type instanceof DateType) { - handleDateType(type, builder, value); + writeDateType(type, builder, value); } else if (type instanceof TimestampType) { - handleTimestampType(type, builder, value); + writeTimestampType(type, builder, value); } else { throw new IllegalArgumentException("Unsupported type: " + type); } } - public static void handleVarcharType(Type type, BlockBuilder builder, Object value) + public static void writeVarcharType(Type type, BlockBuilder builder, Object value) { Slice slice = Slices.utf8Slice(value.toString()); type.writeSlice(builder, slice); } - public static void handleSmallintType(Type type, BlockBuilder builder, Object value) + public static void writeSmallintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Number) { type.writeLong(builder, ((Number) value).shortValue()); @@ -745,7 +746,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void handleTinyintType(Type type, BlockBuilder builder, Object value) + public static void writeTinyintType(Type type, BlockBuilder builder, Object value) { if (value instanceof Number) { type.writeLong(builder, ((Number) value).byteValue()); @@ -766,7 +767,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void handleBigintType(Type type, BlockBuilder builder, Object value) + public static void writeBigintType(Type type, BlockBuilder builder, Object value) { if (value 
instanceof Long) { type.writeLong(builder, (Long) value); @@ -790,7 +791,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void handleIntegerType(Type type, BlockBuilder builder, Object value) + public static void writeIntegerType(Type type, BlockBuilder builder, Object value) { if (value instanceof Integer) { type.writeLong(builder, (Integer) value); @@ -811,7 +812,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void handleDoubleType(Type type, BlockBuilder builder, Object value) + public static void writeDoubleType(Type type, BlockBuilder builder, Object value) { if (value instanceof Double) { type.writeDouble(builder, (Double) value); @@ -835,7 +836,7 @@ else if (value instanceof JsonStringArrayList) { } } - public static void handleBooleanType(Type type, BlockBuilder builder, Object value) + public static void writeBooleanType(Type type, BlockBuilder builder, Object value) { if (value instanceof Boolean) { type.writeBoolean(builder, (Boolean) value); @@ -845,17 +846,17 @@ public static void handleBooleanType(Type type, BlockBuilder builder, Object val } } - public static void handleDecimalType(DecimalType type, BlockBuilder builder, Object value) + public static void writeDecimalType(DecimalType type, BlockBuilder builder, Object value) { if (value instanceof BigDecimal) { BigDecimal decimalValue = (BigDecimal) value; if (type.isShort()) { - // Handle ShortDecimalType + // write ShortDecimalType long unscaledValue = decimalValue.unscaledValue().longValue(); type.writeLong(builder, unscaledValue); } else { - // Handle LongDecimalType + // write LongDecimalType Slice slice = Decimals.encodeScaledValue(decimalValue); type.writeSlice(builder, slice); } @@ -874,7 +875,7 @@ else if (value instanceof Long) { } } - public static void handleArrayType(ArrayType type, BlockBuilder builder, Object value) + public static void writeArrayType(ArrayType type, BlockBuilder builder, Object value) { Type elementType = 
type.getElementType(); BlockBuilder arrayBuilder = builder.beginBlockEntry(); @@ -884,7 +885,7 @@ public static void handleArrayType(ArrayType type, BlockBuilder builder, Object builder.closeEntry(); } - public static void handleRowType(RowType type, BlockBuilder builder, Object value) + public static void writeRowType(RowType type, BlockBuilder builder, Object value) { List rowValues = (List) value; BlockBuilder rowBuilder = builder.beginBlockEntry(); @@ -896,7 +897,7 @@ public static void handleRowType(RowType type, BlockBuilder builder, Object valu builder.closeEntry(); } - public static void handleDateType(Type type, BlockBuilder builder, Object value) + public static void writeDateType(Type type, BlockBuilder builder, Object value) { if (value instanceof java.sql.Date || value instanceof java.time.LocalDate) { int daysSinceEpoch = (int) (value instanceof java.sql.Date @@ -909,7 +910,7 @@ public static void handleDateType(Type type, BlockBuilder builder, Object value) } } - public static void handleTimestampType(Type type, BlockBuilder builder, Object value) + public static void writeTimestampType(Type type, BlockBuilder builder, Object value) { if (value instanceof java.sql.Timestamp) { long millis = ((java.sql.Timestamp) value).getTime(); @@ -919,7 +920,7 @@ else if (value instanceof java.time.Instant) { long millis = ((java.time.Instant) value).toEpochMilli(); type.writeLong(builder, millis); } - else if (value instanceof Long) { // Handle long epoch milliseconds directly + else if (value instanceof Long) { // write long epoch milliseconds directly type.writeLong(builder, (Long) value); } else { diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java index 1366c123d9c93..c4ab656d41bb3 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java @@ -36,7 +36,6 @@ private ArrowMetadataUtil() {} public static final JsonCodec COLUMN_CODEC; public static final JsonCodec TABLE_CODEC; - public static final JsonCodec OUTPUT_TABLE_CODEC; static { JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); @@ -44,7 +43,6 @@ private ArrowMetadataUtil() {} JsonCodecFactory codecFactory = new JsonCodecFactory(provider); COLUMN_CODEC = codecFactory.jsonCodec(ArrowColumnHandle.class); TABLE_CODEC = codecFactory.jsonCodec(ArrowTableHandle.class); - OUTPUT_TABLE_CODEC = codecFactory.jsonCodec(ArrowOutputTableHandle.class); } public static final class TestingTypeDeserializer diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 44fab56e2a95a..1523f64b475bc 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -289,13 +289,13 @@ public void testProcessDictionaryVector() } @Test - public void testHandleVarcharType() + public void testWriteVarcharType() { Type varcharType = VarcharType.createUnboundedVarcharType(); BlockBuilder builder = varcharType.createBlockBuilder(null, 1); String value = "test_string"; - ArrowPageUtils.handleVarcharType(varcharType, builder, value); + ArrowPageUtils.writeVarcharType(varcharType, builder, value); Block block = builder.build(); Slice result = varcharType.getSlice(block, 0); @@ -303,13 +303,13 @@ public void testHandleVarcharType() } @Test - public void testHandleSmallintType() + public void testWriteSmallintType() { Type smallintType = SmallintType.SMALLINT; BlockBuilder builder = smallintType.createBlockBuilder(null, 1); short value = 42; - ArrowPageUtils.handleSmallintType(smallintType, builder, value); + 
ArrowPageUtils.writeSmallintType(smallintType, builder, value); Block block = builder.build(); long result = smallintType.getLong(block, 0); @@ -317,13 +317,13 @@ public void testHandleSmallintType() } @Test - public void testHandleTinyintType() + public void testWriteTinyintType() { Type tinyintType = TinyintType.TINYINT; BlockBuilder builder = tinyintType.createBlockBuilder(null, 1); byte value = 7; - ArrowPageUtils.handleTinyintType(tinyintType, builder, value); + ArrowPageUtils.writeTinyintType(tinyintType, builder, value); Block block = builder.build(); long result = tinyintType.getLong(block, 0); @@ -331,13 +331,13 @@ public void testHandleTinyintType() } @Test - public void testHandleBigintType() + public void testWriteBigintType() { Type bigintType = BigintType.BIGINT; BlockBuilder builder = bigintType.createBlockBuilder(null, 1); long value = 123456789L; - ArrowPageUtils.handleBigintType(bigintType, builder, value); + ArrowPageUtils.writeBigintType(bigintType, builder, value); Block block = builder.build(); long result = bigintType.getLong(block, 0); @@ -345,13 +345,13 @@ public void testHandleBigintType() } @Test - public void testHandleIntegerType() + public void testWriteIntegerType() { Type integerType = IntegerType.INTEGER; BlockBuilder builder = integerType.createBlockBuilder(null, 1); int value = 42; - ArrowPageUtils.handleIntegerType(integerType, builder, value); + ArrowPageUtils.writeIntegerType(integerType, builder, value); Block block = builder.build(); long result = integerType.getLong(block, 0); @@ -359,13 +359,13 @@ public void testHandleIntegerType() } @Test - public void testHandleDoubleType() + public void testWriteDoubleType() { Type doubleType = DoubleType.DOUBLE; BlockBuilder builder = doubleType.createBlockBuilder(null, 1); double value = 42.42; - ArrowPageUtils.handleDoubleType(doubleType, builder, value); + ArrowPageUtils.writeDoubleType(doubleType, builder, value); Block block = builder.build(); double result = 
doubleType.getDouble(block, 0); @@ -373,13 +373,13 @@ public void testHandleDoubleType() } @Test - public void testHandleBooleanType() + public void testWriteBooleanType() { Type booleanType = BooleanType.BOOLEAN; BlockBuilder builder = booleanType.createBlockBuilder(null, 1); boolean value = true; - ArrowPageUtils.handleBooleanType(booleanType, builder, value); + ArrowPageUtils.writeBooleanType(booleanType, builder, value); Block block = builder.build(); boolean result = booleanType.getBoolean(block, 0); @@ -387,14 +387,14 @@ public void testHandleBooleanType() } @Test - public void testHandleArrayType() + public void testWriteArrayType() { Type elementType = IntegerType.INTEGER; ArrayType arrayType = new ArrayType(elementType); BlockBuilder builder = arrayType.createBlockBuilder(null, 1); List values = Arrays.asList(1, 2, 3); - ArrowPageUtils.handleArrayType(arrayType, builder, values); + ArrowPageUtils.writeArrayType(arrayType, builder, values); Block block = builder.build(); Block arrayBlock = arrayType.getObject(block, 0); @@ -405,7 +405,7 @@ public void testHandleArrayType() } @Test - public void testHandleRowType() + public void testWriteRowType() { RowType.Field field1 = new RowType.Field(Optional.of("field1"), IntegerType.INTEGER); RowType.Field field2 = new RowType.Field(Optional.of("field2"), VarcharType.createUnboundedVarcharType()); @@ -413,7 +413,7 @@ public void testHandleRowType() BlockBuilder builder = rowType.createBlockBuilder(null, 1); List rowValues = Arrays.asList(42, "test"); - ArrowPageUtils.handleRowType(rowType, builder, rowValues); + ArrowPageUtils.writeRowType(rowType, builder, rowValues); Block block = builder.build(); Block rowBlock = rowType.getObject(block, 0); @@ -422,13 +422,13 @@ public void testHandleRowType() } @Test - public void testHandleDateType() + public void testWriteDateType() { Type dateType = DateType.DATE; BlockBuilder builder = dateType.createBlockBuilder(null, 1); LocalDate value = LocalDate.of(2020, 1, 1); - 
ArrowPageUtils.handleDateType(dateType, builder, value); + ArrowPageUtils.writeDateType(dateType, builder, value); Block block = builder.build(); long result = dateType.getLong(block, 0); @@ -436,13 +436,13 @@ public void testHandleDateType() } @Test - public void testHandleTimestampType() + public void testWriteTimestampType() { Type timestampType = TimestampType.TIMESTAMP; BlockBuilder builder = timestampType.createBlockBuilder(null, 1); long value = 1609459200000L; // Jan 1, 2021, 00:00:00 UTC - ArrowPageUtils.handleTimestampType(timestampType, builder, value); + ArrowPageUtils.writeTimestampType(timestampType, builder, value); Block block = builder.build(); long result = timestampType.getLong(block, 0); @@ -450,14 +450,14 @@ public void testHandleTimestampType() } @Test - public void testHandleTimestampTypeWithSqlTimestamp() + public void testWriteTimestampTypeWithSqlTimestamp() { Type timestampType = TimestampType.TIMESTAMP; BlockBuilder builder = timestampType.createBlockBuilder(null, 1); java.sql.Timestamp timestamp = java.sql.Timestamp.valueOf("2021-01-01 00:00:00"); long expectedMillis = timestamp.getTime(); - ArrowPageUtils.handleTimestampType(timestampType, builder, timestamp); + ArrowPageUtils.writeTimestampType(timestampType, builder, timestamp); Block block = builder.build(); long result = timestampType.getLong(block, 0); @@ -471,7 +471,7 @@ public void testShortDecimalRetrieval() BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 1); BigDecimal decimalValue = new BigDecimal("12345.67"); - ArrowPageUtils.handleDecimalType(shortDecimalType, builder, decimalValue); + ArrowPageUtils.writeDecimalType(shortDecimalType, builder, decimalValue); Block block = builder.build(); long unscaledValue = shortDecimalType.getLong(block, 0); // Unscaled value: 1234567 @@ -486,7 +486,7 @@ public void testLongDecimalRetrieval() DecimalType longDecimalType = DecimalType.createDecimalType(38, 10); BlockBuilder builder = 
longDecimalType.createBlockBuilder(null, 1); BigDecimal decimalValue = new BigDecimal("1234567890.1234567890"); - ArrowPageUtils.handleDecimalType(longDecimalType, builder, decimalValue); + ArrowPageUtils.writeDecimalType(longDecimalType, builder, decimalValue); // Build the block after inserting the decimal value Block block = builder.build(); Slice unscaledSlice = longDecimalType.getSlice(block, 0); From 4fb97d7fa0cd3ef30effb0ef994338eac9b4ad2f Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 6 Dec 2024 13:21:50 +0530 Subject: [PATCH 35/39] Arrow - DictionaryEncoding usecase --- .../plugin/arrow/ArrowPageSource.java | 9 +- .../facebook/plugin/arrow/ArrowPageUtils.java | 83 ++++++++++--------- .../plugin/arrow/ArrowPageUtilsTest.java | 54 +++++++++++- 3 files changed, 107 insertions(+), 39 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java index 13ca93f4599a6..ec51125a34848 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -24,6 +24,7 @@ import org.apache.arrow.flight.Ticket; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; import java.util.ArrayList; import java.util.List; @@ -121,7 +122,13 @@ public Page getNextPage() FieldVector vector = vectorSchemaRoot.get().getVector(columnIndex); Type type = columnHandles.get(columnIndex).getColumnType(); - Block block = ArrowPageUtils.buildBlockFromVector(vector, type); + boolean isDictionaryBlock = vector.getField().getDictionary() != null; + Dictionary dictionary = null; + if (isDictionaryBlock) { + dictionary = flightStream.getDictionaryProvider().lookup(vector.getField().getDictionary().getId()); + } + Block block = null 
!= dictionary ? ArrowPageUtils.buildBlockFromVector(vector, type, dictionary.getVector(), isDictionaryBlock) : + ArrowPageUtils.buildBlockFromVector(vector, type, null, false); blocks.add(block); } diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index ca78ed11b5447..93286f38db789 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; @@ -60,8 +61,6 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListReader; -import org.apache.arrow.vector.dictionary.Dictionary; -import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.util.JsonStringArrayList; @@ -81,9 +80,12 @@ private ArrowPageUtils() { } - public static Block buildBlockFromVector(FieldVector vector, Type type) + public static Block buildBlockFromVector(FieldVector vector, Type type, FieldVector dictionary, boolean isDictionaryVector) { - if (vector instanceof BitVector) { + if (isDictionaryVector) { + return buildBlockFromDictionaryVector(vector, dictionary); + } + else if (vector instanceof BitVector) { return buildBlockFromBitVector((BitVector) vector, type); } else if (vector instanceof TinyIntVector) { @@ -158,31 +160,32 @@ else if (vector instanceof ListVector) { throw new 
UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getSimpleName()); } - public static Block buildBlockFromEncodedVector(FieldVector encodedVector, Dictionary dictionary) + public static Block buildBlockFromDictionaryVector(FieldVector fieldVector, FieldVector dictionaryVector) { // Validate inputs - requireNonNull(encodedVector, "encoded vector is null"); - requireNonNull(dictionary, "dictionary is null"); - - // Decode the encoded vector using the dictionary - ValueVector decodedVector = DictionaryEncoder.decode(encodedVector, dictionary); + requireNonNull(fieldVector, "encoded vector is null"); + requireNonNull(dictionaryVector, "dictionary vector is null"); // Create a BlockBuilder for the decoded vector's data type - Type prestoType = getPrestoTypeFromArrowType(decodedVector.getField().getType()); - BlockBuilder builder = prestoType.createBlockBuilder(null, decodedVector.getValueCount()); + Type prestoType = getPrestoTypeFromArrowType(dictionaryVector.getField().getType()); + Block dictionaryblock = null; // Populate the block dynamically based on vector type - for (int i = 0; i < decodedVector.getValueCount(); i++) { - if (decodedVector.isNull(i)) { - builder.appendNull(); // Append null for null values - } - else { - // write based on vector type - appendValueToBlock(decodedVector, i, prestoType, builder); + for (int i = 0; i < dictionaryVector.getValueCount(); i++) { + if (!dictionaryVector.isNull(i)) { + dictionaryblock = appendValueToBlock(dictionaryVector, prestoType); } } - return builder.build(); + // Get the Arrow indices vector + IntVector indicesVector = (IntVector) fieldVector; + int[] ids = new int[indicesVector.getValueCount()]; + for (int i = 0; i < indicesVector.getValueCount(); i++) { + ids[i] = indicesVector.get(i); + } + + // Create the Presto DictionaryBlock + return new DictionaryBlock(ids.length, dictionaryblock, ids); } private static Type getPrestoTypeFromArrowType(ArrowType arrowType) @@ -225,37 +228,43 @@ 
else if (arrowType instanceof ArrowType.Decimal) { throw new UnsupportedOperationException("Unsupported ArrowType: " + arrowType); } - private static void appendValueToBlock(ValueVector vector, int index, Type prestoType, BlockBuilder builder) + private static Block appendValueToBlock(ValueVector vector, Type prestoType) { if (vector instanceof VarCharVector) { - VarCharVector varCharVector = (VarCharVector) vector; - byte[] valueBytes = varCharVector.get(index); - prestoType.writeSlice(builder, Slices.utf8Slice(new String(valueBytes, StandardCharsets.UTF_8))); + return buildBlockFromVarCharVector((VarCharVector) vector, prestoType); } else if (vector instanceof IntVector) { - IntVector intVector = (IntVector) vector; - prestoType.writeLong(builder, intVector.get(index)); + return buildBlockFromIntVector((IntVector) vector, prestoType); } else if (vector instanceof BigIntVector) { - BigIntVector bigIntVector = (BigIntVector) vector; - prestoType.writeLong(builder, bigIntVector.get(index)); + return buildBlockFromBigIntVector((BigIntVector) vector, prestoType); } else if (vector instanceof Float4Vector) { - Float4Vector floatVector = (Float4Vector) vector; - prestoType.writeLong(builder, Float.floatToRawIntBits(floatVector.get(index))); + return buildBlockFromFloat4Vector((Float4Vector) vector, prestoType); } else if (vector instanceof Float8Vector) { - Float8Vector doubleVector = (Float8Vector) vector; - prestoType.writeDouble(builder, doubleVector.get(index)); + return buildBlockFromFloat8Vector((Float8Vector) vector, prestoType); } else if (vector instanceof BitVector) { - BitVector bitVector = (BitVector) vector; - prestoType.writeBoolean(builder, bitVector.get(index) == 1); + return buildBlockFromBitVector((BitVector) vector, prestoType); } else if (vector instanceof VarBinaryVector) { - VarBinaryVector binaryVector = (VarBinaryVector) vector; - byte[] valueBytes = binaryVector.get(index); - prestoType.writeSlice(builder, Slices.wrappedBuffer(valueBytes)); + 
return buildBlockFromVarBinaryVector((VarBinaryVector) vector, prestoType); + } + else if (vector instanceof DecimalVector) { + return buildBlockFromDecimalVector((DecimalVector) vector, prestoType); + } + else if (vector instanceof TinyIntVector) { + return buildBlockFromTinyIntVector((TinyIntVector) vector, prestoType); + } + else if (vector instanceof SmallIntVector) { + return buildBlockFromSmallIntVector((SmallIntVector) vector, prestoType); + } + else if (vector instanceof DateDayVector) { + return buildBlockFromDateDayVector((DateDayVector) vector, prestoType); + } + else if (vector instanceof TimeStampMilliTZVector) { + return buildBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, prestoType); } else { throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass()); diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 1523f64b475bc..26bb1dcfd8c02 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.BlockBuilder; +import com.facebook.presto.common.block.DictionaryBlock; import com.facebook.presto.common.type.ArrayType; import com.facebook.presto.common.type.BigintType; import com.facebook.presto.common.type.BooleanType; @@ -56,6 +57,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.plugin.arrow.ArrowPageUtils.buildBlockFromDictionaryVector; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; @@ -281,13 +283,63 @@ public void testProcessDictionaryVector() IntVector encodedVector = (IntVector) 
DictionaryEncoder.encode(rawVector, dictionary); // Process the dictionary vector - Block result = ArrowPageUtils.buildBlockFromEncodedVector(encodedVector, dictionary); + Block result = buildBlockFromDictionaryVector(encodedVector, dictionary.getVector()); // Verify the result assertNotNull(result, "The BlockBuilder should not be null."); assertEquals(result.getPositionCount(), 50); } + @Test + public void testBuildBlockFromDictionaryVector() + { + IntVector indicesVector = new IntVector("indices", allocator); + indicesVector.allocateNew(3); // allocating space for 3 values + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Set up index values (this would reference the dictionary) + indicesVector.set(0, 0); // First index points to "apple" + indicesVector.set(1, 1); // Second index points to "banana" + indicesVector.set(2, 2); // Third index points to "cherry" + indicesVector.setValueCount(3); + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + 
assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + @Test public void testWriteVarcharType() { From 0138aa77904b5c2bda2a994908333e7de044455e Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 6 Dec 2024 13:38:33 +0530 Subject: [PATCH 36/39] Removed description which looks like generated --- presto-docs/src/main/sphinx/connector/base-arrow-flight.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst index eb9a62798e9ef..9f5635b9af550 100644 --- a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst +++ b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst @@ -3,7 +3,6 @@ Arrow Flight Connector ====================== This connector allows querying multiple data sources that are supported by an Arrow Flight server. -Apache Arrow enhances performance and efficiency in data-intensive applications through its columnar memory layout, zero-copy reads, vectorized execution, cross-language interoperability, rich data type support, and optimization for modern hardware. These features collectively reduce overhead, improve data processing speeds, and facilitate seamless data exchange between different systems and languages. 
Getting Started with base-arrow-module: Essential Abstract Methods for Developers --------------------------------------------------------------------------------- From b13eeb3b9eb45db31a1b7e3253019bb6877d18c4 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Fri, 6 Dec 2024 18:24:53 +0530 Subject: [PATCH 37/39] Added support for dictionary encoding --- .../facebook/plugin/arrow/ArrowPageUtils.java | 52 +++++- .../plugin/arrow/ArrowPageUtilsTest.java | 156 +++++++++++++++++- 2 files changed, 200 insertions(+), 8 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 93286f38db789..3bef784fe1a1a 100644 --- a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -177,15 +177,53 @@ public static Block buildBlockFromDictionaryVector(FieldVector fieldVector, Fiel } } - // Get the Arrow indices vector - IntVector indicesVector = (IntVector) fieldVector; - int[] ids = new int[indicesVector.getValueCount()]; - for (int i = 0; i < indicesVector.getValueCount(); i++) { - ids[i] = indicesVector.get(i); - } + return getDictionaryBlock(fieldVector, dictionaryblock); // Create the Presto DictionaryBlock - return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + + private static DictionaryBlock getDictionaryBlock(FieldVector fieldVector, Block dictionaryblock) + { + if (fieldVector instanceof IntVector) { + // Get the Arrow indices vector + IntVector indicesVector = (IntVector) fieldVector; + int[] ids = new int[indicesVector.getValueCount()]; + for (int i = 0; i < indicesVector.getValueCount(); i++) { + ids[i] = indicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof BigIntVector) { + // Get the BigInt indices vector + 
BigIntVector bigIntIndicesVector = (BigIntVector) fieldVector; + int[] ids = new int[bigIntIndicesVector.getValueCount()]; + for (int i = 0; i < bigIntIndicesVector.getValueCount(); i++) { + ids[i] = (int) bigIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof SmallIntVector) { + // Get the SmallInt indices vector + SmallIntVector smallIntIndicesVector = (SmallIntVector) fieldVector; + int[] ids = new int[smallIntIndicesVector.getValueCount()]; + for (int i = 0; i < smallIntIndicesVector.getValueCount(); i++) { + ids[i] = smallIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else if (fieldVector instanceof TinyIntVector) { + // Get the TinyInt indices vector + TinyIntVector tinyIntIndicesVector = (TinyIntVector) fieldVector; + int[] ids = new int[tinyIntIndicesVector.getValueCount()]; + for (int i = 0; i < tinyIntIndicesVector.getValueCount(); i++) { + ids[i] = tinyIntIndicesVector.get(i); + } + return new DictionaryBlock(ids.length, dictionaryblock, ids); + } + else { + // Handle the case where the FieldVector is of an unsupported type + throw new IllegalArgumentException("Unsupported FieldVector type: " + fieldVector.getClass()); + } } private static Type getPrestoTypeFromArrowType(ArrowType arrowType) diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 26bb1dcfd8c02..a4db38e906f23 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -310,8 +310,162 @@ public void testBuildBlockFromDictionaryVector() // Set up index values (this would reference the dictionary) indicesVector.set(0, 0); // First index points to "apple" indicesVector.set(1, 1); // 
Second index points to "banana" - indicesVector.set(2, 2); // Third index points to "cherry" + indicesVector.set(2, 2); + indicesVector.set(3, 2); // Third index points to "cherry" + indicesVector.setValueCount(4); + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + else if (i == 3) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorBigInt() + { + BigIntVector indicesVector = new BigIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, 0L); + indicesVector.set(1, 1L); + indicesVector.set(2, 2L); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = 
buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorSmallInt() + { + SmallIntVector indicesVector = new SmallIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (short) 0); + indicesVector.set(1, (short) 1); + indicesVector.set(2, (short) 2); + indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + + // Call the method under test + Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); + + // Assertions to check the dictionary block's behavior + assertNotNull(block); + assertTrue(block instanceof DictionaryBlock); + + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + + // Verify the dictionary block contains the right dictionary + + 
for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) { + // Get the slice (string value) at the given position + Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i)); + + // Assert based on the expected values + if (i == 0) { + assertEquals(slice.toStringUtf8(), "apple"); + } + else if (i == 1) { + assertEquals(slice.toStringUtf8(), "banana"); + } + else if (i == 2) { + assertEquals(slice.toStringUtf8(), "cherry"); + } + } + } + + @Test + public void testBuildBlockFromDictionaryVectorTinyInt() + { + TinyIntVector indicesVector = new TinyIntVector("indices", allocator); + + indicesVector.allocateNew(3); // allocating space for 3 values + indicesVector.set(0, (byte) 0); + indicesVector.set(1, (byte) 1); + indicesVector.set(2, (byte) 2); indicesVector.setValueCount(3); + + // Initialize a dummy dictionary vector + // Example: dictionary contains 3 string values + VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator); + dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary + + // Fill dictionaryVector with some values + dictionaryVector.set(0, "apple".getBytes()); + dictionaryVector.set(1, "banana".getBytes()); + dictionaryVector.set(2, "cherry".getBytes()); + dictionaryVector.setValueCount(3); + // Call the method under test Block block = buildBlockFromDictionaryVector(indicesVector, dictionaryVector); From e2463e7c3346fd6e200da956ad5bcd1b4c214582 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 9 Dec 2024 10:53:21 +0530 Subject: [PATCH 38/39] Fixed review comments --- .../java/com/facebook/plugin/arrow/ArrowPageUtils.java | 9 --------- .../com/facebook/plugin/arrow/ArrowPageUtilsTest.java | 3 ++- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java index 3bef784fe1a1a..04b20d00c24a8 100644 --- 
a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageUtils.java @@ -193,15 +193,6 @@ private static DictionaryBlock getDictionaryBlock(FieldVector fieldVector, Block } return new DictionaryBlock(ids.length, dictionaryblock, ids); } - else if (fieldVector instanceof BigIntVector) { - // Get the BigInt indices vector - BigIntVector bigIntIndicesVector = (BigIntVector) fieldVector; - int[] ids = new int[bigIntIndicesVector.getValueCount()]; - for (int i = 0; i < bigIntIndicesVector.getValueCount(); i++) { - ids[i] = (int) bigIntIndicesVector.get(i); - } - return new DictionaryBlock(ids.length, dictionaryblock, ids); - } else if (fieldVector instanceof SmallIntVector) { // Get the SmallInt indices vector SmallIntVector smallIntIndicesVector = (SmallIntVector) fieldVector; diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index a4db38e906f23..998f1deb0a5e5 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -33,6 +33,7 @@ import io.airlift.slice.Slice; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BaseIntVector; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.DecimalVector; @@ -280,7 +281,7 @@ public void testProcessDictionaryVector() // Encode using dictionary Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); - IntVector encodedVector = (IntVector) DictionaryEncoder.encode(rawVector, dictionary); + BaseIntVector encodedVector = (BaseIntVector) 
DictionaryEncoder.encode(rawVector, dictionary); // Process the dictionary vector Block result = buildBlockFromDictionaryVector(encodedVector, dictionary.getVector()); From d391906127480309f6d51d81bebe178f5800e001 Mon Sep 17 00:00:00 2001 From: lithinpurushothaman Date: Mon, 9 Dec 2024 13:58:15 +0530 Subject: [PATCH 39/39] Fixed review comments -added index type --- .../java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java index 998f1deb0a5e5..432b82b69fb2d 100644 --- a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowPageUtilsTest.java @@ -46,6 +46,7 @@ import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -280,7 +281,8 @@ public void testProcessDictionaryVector() rawVector.setValueCount(VECTOR_LENGTH); // Encode using dictionary - Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + ArrowType.Int index = new ArrowType.Int(16,true); + Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, index)); BaseIntVector encodedVector = (BaseIntVector) DictionaryEncoder.encode(rawVector, dictionary); // Process the dictionary vector