From 6a7bfefb88c5eab1d1c13b6f3f79a89a17776a70 Mon Sep 17 00:00:00 2001 From: sfc-gh-mknister <66097847+sfc-gh-mknister@users.noreply.github.com> Date: Tue, 14 Mar 2023 08:40:34 -0700 Subject: [PATCH] Add arrow tests (#1291) * SNOW-748294 --- .../client/jdbc/ArrowResultChunk.java | 2 +- .../SnowflakeResultSetSerializableV1.java | 2 +- .../client/core/SFArrowResultSetIT.java | 434 ++++++++++++++++-- .../client/core/SFTrustManagerTest.java | 2 +- 4 files changed, 407 insertions(+), 33 deletions(-) diff --git a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java index aba0d37d2..979f824a2 100644 --- a/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java +++ b/src/main/java/net/snowflake/client/jdbc/ArrowResultChunk.java @@ -443,7 +443,7 @@ public int getCurrentRowInRecordBatch() { * merge arrow result chunk with more than one batches into one record batch (Only used for the * first chunk when client side sorting is required) */ - private void mergeBatchesIntoOne() throws SnowflakeSQLException { + public void mergeBatchesIntoOne() throws SnowflakeSQLException { try { List first = batchOfVectors.get(0); for (int i = 1; i < batchOfVectors.size(); i++) { diff --git a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java index d444becb2..b21f893a0 100644 --- a/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java +++ b/src/main/java/net/snowflake/client/jdbc/SnowflakeResultSetSerializableV1.java @@ -251,7 +251,7 @@ public void setChunkFileCount(int chunkFileCount) { this.chunkFileCount = chunkFileCount; } - public void setFristChunkStringData(String firstChunkStringData) { + public void setFirstChunkStringData(String firstChunkStringData) { this.firstChunkStringData = firstChunkStringData; } diff --git a/src/test/java/net/snowflake/client/core/SFArrowResultSetIT.java b/src/test/java/net/snowflake/client/core/SFArrowResultSetIT.java index d3f93b4c4..9aae0bcf6 100644 --- a/src/test/java/net/snowflake/client/core/SFArrowResultSetIT.java +++ b/src/test/java/net/snowflake/client/core/SFArrowResultSetIT.java @@ -3,41 +3,37 @@ */ package net.snowflake.client.core; +import static net.snowflake.client.AbstractDriverIT.getConnection; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; +import java.io.*; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.Statement; +import java.time.Instant; +import java.util.*; import net.snowflake.client.category.TestCategoryArrow; -import net.snowflake.client.jdbc.ArrowResultChunk; -import net.snowflake.client.jdbc.ErrorCode; -import net.snowflake.client.jdbc.SnowflakeResultChunk; -import net.snowflake.client.jdbc.SnowflakeResultSetSerializableV1; -import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.jdbc.*; import net.snowflake.client.jdbc.telemetry.NoOpTelemetryClient; import 
org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.ipc.ArrowWriter; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -48,7 +44,7 @@ public class SFArrowResultSetIT { private Random random = new Random(); /** allocator for arrow */ - private BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + protected BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); /** temporary folder to store result files */ @Rule public TemporaryFolder resultFolder = new TemporaryFolder(); @@ -75,7 +71,7 @@ public void testNoOfflineData() throws Throwable { SnowflakeResultSetSerializableV1 resultSetSerializable = new SnowflakeResultSetSerializableV1(); resultSetSerializable.setRootAllocator(new RootAllocator(Long.MAX_VALUE)); - resultSetSerializable.setFristChunkStringData(Base64.getEncoder().encodeToString(dataBytes)); + resultSetSerializable.setFirstChunkStringData(Base64.getEncoder().encodeToString(dataBytes)); resultSetSerializable.setFirstChunkByteData(dataBytes); resultSetSerializable.setChunkFileCount(0); @@ -96,7 +92,7 @@ public void testNoOfflineData() throws Throwable { @Test public void testEmptyResultSet() throws Throwable { SnowflakeResultSetSerializableV1 resultSetSerializable = new SnowflakeResultSetSerializableV1(); - resultSetSerializable.setFristChunkStringData( + resultSetSerializable.setFirstChunkStringData( Base64.getEncoder().encodeToString("".getBytes(StandardCharsets.UTF_8))); resultSetSerializable.setChunkFileCount(0); @@ -106,7 +102,7 @@ public void testEmptyResultSet() throws Throwable { assertThat(resultSet.isLast(), is(false)); assertThat(resultSet.isAfterLast(), is(true)); - resultSetSerializable.setFristChunkStringData(null); + resultSetSerializable.setFirstChunkStringData(null); resultSet = new SFArrowResultSet(resultSetSerializable, new NoOpTelemetryClient(), false); assertThat(resultSet.next(), is(false)); @@ -198,7 +194,7 @@ public void testFirstResponseAndOfflineData() throws Throwable { is.read(dataBytes, 0, dataSize); SnowflakeResultSetSerializableV1 resultSetSerializable = new SnowflakeResultSetSerializableV1(); - resultSetSerializable.setFristChunkStringData(Base64.getEncoder().encodeToString(dataBytes)); + resultSetSerializable.setFirstChunkStringData(Base64.getEncoder().encodeToString(dataBytes)); resultSetSerializable.setFirstChunkByteData(dataBytes); resultSetSerializable.setChunkFileCount(chunkCount); resultSetSerializable.setRootAllocator(new RootAllocator(Long.MAX_VALUE)); @@ -265,21 +261,77 @@ public DownloaderMetrics terminate() { } } - private Object[][] generateData(Schema schema, int rowCount) { + Object[][] generateData(Schema schema, int rowCount) { Object[][] data = new Object[schema.getFields().size()][rowCount]; for (int i = 0; i < 
schema.getFields().size(); i++) { Types.MinorType type = Types.getMinorTypeForArrowType(schema.getFields().get(i).getType()); switch (type) { + case BIT: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = random.nextBoolean(); + } + break; + } case INT: { for (int j = 0; j < rowCount; j++) { - data[i][j] = random.nextInt(); + data[i][j] = 0; + } + break; + } + case DATEDAY: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = Date.from(Instant.now()); + } + break; + } + case BIGINT: + case DECIMAL: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = 154639183700000l; + } + break; + } + case FLOAT8: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = random.nextDouble(); + } + break; + } + case TINYINT: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = (byte) random.nextInt(1 << 8); + } + break; + } + case SMALLINT: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = (short) random.nextInt(1 << 16); + } + break; + } + case VARBINARY: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = RandomStringUtils.random(20).getBytes(); + } + break; + } + case VARCHAR: + { + for (int j = 0; j < rowCount; j++) { + data[i][j] = RandomStringUtils.random(20); } break; } - // add other data types as needed later } } @@ -287,8 +339,8 @@ private Object[][] generateData(Schema schema, int rowCount) { return data; } - private File createArrowFile( - String fileName, Schema schema, Object[][] data, int rowsPerRecordBatch) throws IOException { + File createArrowFile(String fileName, Schema schema, Object[][] data, int rowsPerRecordBatch) + throws IOException { File file = resultFolder.newFile(fileName); VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); @@ -305,9 +357,40 @@ private File createArrowFile( FieldVector vector = root.getFieldVectors().get(j); switch (vector.getMinorType()) { + case BIT: + writeBitToField(vector, data[j], i, rowsToAppend); + break; case INT: writeIntToField(vector, data[j], i, rowsToAppend); break; + case TINYINT: + writeTinyIntToField(vector, data[j], i, rowsToAppend); + break; + case SMALLINT: + writeSmallIntToField(vector, data[j], i, rowsToAppend); + break; + case DATEDAY: + writeDateToField(vector, data[j], i, rowsToAppend); + break; + case BIGINT: + writeLongToField(vector, data[j], i, rowsToAppend); + break; + case FLOAT8: + writeDoubleToField(vector, data[j], i, rowsToAppend); + break; + case VARBINARY: + writeBytesToField(vector, data[j], i, rowsToAppend); + break; + case VARCHAR: + writeTextToField(vector, data[j], i, rowsToAppend); + break; + case DECIMAL: + writeDecimalToField(vector, data[j], i, rowsToAppend); + break; + case STRUCT: + writeTimestampStructToField(vector, data[j], data[j + 1], i, rowsToAppend); + j++; + break; } } @@ -319,15 +402,306 @@ private File createArrowFile( return file; } + private void writeLongToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + BigIntVector vector = (BigIntVector) fieldVector; + + vector.setInitialCapacity(rowsToAppend); + vector.allocateNew(); + vector.setNull(0); + for (int i = 0; i < rowsToAppend; i++) { + vector.setSafe(i, 1, (long) data[startIndex + i]); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeBitToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + BitVector vector = (BitVector) fieldVector; + vector.setInitialCapacity(rowsToAppend); + vector.allocateNew(); + vector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + int val = 
(Boolean) data[startIndex + i] == true ? 1 : 0; + vector.setSafe(i, 1, val); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeDateToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + DateDayVector datedayVector = (DateDayVector) fieldVector; + datedayVector.setInitialCapacity(rowsToAppend); + datedayVector.allocateNew(); + datedayVector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + datedayVector.setSafe(i, 1, (int) ((Date) data[startIndex + i]).getTime() / 1000); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeDecimalToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + DecimalVector datedayVector = (DecimalVector) fieldVector; + datedayVector.setInitialCapacity(rowsToAppend); + datedayVector.allocateNew(); + datedayVector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + datedayVector.setSafe(i, (long) data[startIndex + i]); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeDoubleToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + Float8Vector vector = (Float8Vector) fieldVector; + vector.setInitialCapacity(rowsToAppend); + vector.allocateNew(); + vector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + vector.setSafe(i, 1, (double) data[startIndex + i]); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + private void writeIntToField( FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { IntVector intVector = (IntVector) fieldVector; intVector.setInitialCapacity(rowsToAppend); intVector.allocateNew(); - for (int i = 0; i < rowsToAppend; i++) { + intVector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { intVector.setSafe(i, 1, (int) data[startIndex + i]); } + fieldVector.setValueCount(rowsToAppend); + } + + private void writeSmallIntToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + SmallIntVector intVector = (SmallIntVector) fieldVector; + intVector.setInitialCapacity(rowsToAppend); + intVector.allocateNew(); + intVector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + intVector.setSafe(i, 1, (short) data[startIndex + i]); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeTinyIntToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + TinyIntVector vector = (TinyIntVector) fieldVector; + vector.setInitialCapacity(rowsToAppend); + vector.allocateNew(); + vector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + vector.setSafe(i, 1, (byte) data[startIndex + i]); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeBytesToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + VarBinaryVector vector = (VarBinaryVector) fieldVector; + vector.setInitialCapacity(rowsToAppend); + vector.allocateNew(); + vector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + vector.setSafe(i, (byte[]) data[startIndex + i], 0, ((byte[]) data[startIndex + i]).length); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeTextToField( + FieldVector fieldVector, Object[] data, int startIndex, int rowsToAppend) { + VarCharVector intVector = (VarCharVector) fieldVector; + intVector.setInitialCapacity(rowsToAppend); + intVector.allocateNew(); + 
intVector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + intVector.setSafe(i, new Text((String) data[startIndex + i])); + } + // how many are set + fieldVector.setValueCount(rowsToAppend); + } + + private void writeTimestampStructToField( + FieldVector fieldVector, Object[] data, Object[] data2, int startIndex, int rowsToAppend) { + StructVector vector = (StructVector) fieldVector; + vector.setInitialCapacity(rowsToAppend); + vector.allocateNew(); + vector.setNull(0); + for (int i = 1; i < rowsToAppend; i++) { + List childVectors = vector.getChildrenFromFields(); + BigIntVector v1 = (BigIntVector) childVectors.get(0); + v1.setSafe(i, 1, (long) data[startIndex + i]); + + IntVector v2 = (IntVector) childVectors.get(1); + v2.setSafe(i, 1, (int) data2[startIndex + i]); + } // how many are set fieldVector.setValueCount(rowsToAppend); } + + /** Test that first chunk containing struct vectors (used for timestamps) can be sorted */ + @Test + public void testSortedResultChunkWithStructVectors() throws Throwable { + Connection con = getConnection(); + Statement statement = con.createStatement(); + statement.execute("create or replace table teststructtimestamp (t1 timestamp_ltz)"); + ResultSet rs = statement.executeQuery("select * from teststructtimestamp"); + List resultSetSerializables = + ((SnowflakeResultSet) rs).getResultSetSerializables(100 * 1024 * 1024); + SnowflakeResultSetSerializableV1 resultSetSerializable = + (SnowflakeResultSetSerializableV1) resultSetSerializables.get(0); + + Map customFieldMeta = new HashMap<>(); + customFieldMeta.put("logicalType", "TIMESTAMP_LTZ"); + customFieldMeta.put("scale", "38"); + // test normal date + FieldType fieldType = + new FieldType(true, Types.MinorType.BIGINT.getType(), null, customFieldMeta); + FieldType fieldType2 = + new FieldType(true, Types.MinorType.INT.getType(), null, customFieldMeta); + + StructVector structVector = StructVector.empty("testListVector", allocator); + List fieldList = new LinkedList(); + Field bigIntField = new Field("epoch", fieldType, null); + + Field intField = new Field("fraction", fieldType2, null); + + fieldList.add(bigIntField); + fieldList.add(intField); + + FieldType structFieldType = + new FieldType(true, Types.MinorType.STRUCT.getType(), null, customFieldMeta); + Field structField = new Field("timestamp", structFieldType, fieldList); + + structVector.initializeChildrenFromFields(fieldList); + + List fieldListMajor = new LinkedList(); + fieldListMajor.add(structField); + Schema dataSchema = new Schema(fieldList); + Object[][] data = generateData(dataSchema, 1000); + + Schema schema = new Schema(fieldListMajor); + + File file = createArrowFile("testTimestamp", schema, data, 10); + + int dataSize = (int) file.length(); + byte[] dataBytes = new byte[dataSize]; + + InputStream is = new FileInputStream(file); + is.read(dataBytes, 0, dataSize); + + resultSetSerializable.setRootAllocator(new RootAllocator(Long.MAX_VALUE)); + resultSetSerializable.setFirstChunkStringData(Base64.getEncoder().encodeToString(dataBytes)); + resultSetSerializable.setFirstChunkByteData(dataBytes); + resultSetSerializable.setChunkFileCount(0); + + SFArrowResultSet resultSet = + new SFArrowResultSet(resultSetSerializable, new NoOpTelemetryClient(), true); + + for (int i = 0; i < 1000; i++) { + resultSet.next(); + } + // We inserted a null row at the beginning so when sorted, the last row should be null + assertEquals(null, resultSet.getObject(1)); + assertFalse(resultSet.next()); + statement.execute("drop table teststructtimestamp;"); + 
con.close(); + } + + /** Test that the first chunk can be sorted */ + @Test + public void testSortedResultChunk() throws Throwable { + Connection con = getConnection(); + Statement statement = con.createStatement(); + statement.execute( + "create or replace table alltypes (i1 int, d1 date, b1 bigint, f1 float, s1 smallint, t1 tinyint, b2 binary, t2 text, b3 boolean, d2 decimal)"); + ResultSet rs = statement.executeQuery("select * from alltypes"); + List resultSetSerializables = + ((SnowflakeResultSet) rs).getResultSetSerializables(100 * 1024 * 1024); + SnowflakeResultSetSerializableV1 resultSetSerializable = + (SnowflakeResultSetSerializableV1) resultSetSerializables.get(0); + + List fieldList = new ArrayList<>(); + Map customFieldMeta = new HashMap<>(); + customFieldMeta.put("logicalType", "FIXED"); + customFieldMeta.put("scale", "0"); + FieldType type = new FieldType(false, Types.MinorType.INT.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "DATE"); + type = new FieldType(false, Types.MinorType.DATEDAY.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "FIXED"); + type = new FieldType(false, Types.MinorType.BIGINT.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "REAL"); + type = new FieldType(false, Types.MinorType.FLOAT8.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "FIXED"); + type = new FieldType(false, Types.MinorType.SMALLINT.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "FIXED"); + type = new FieldType(false, Types.MinorType.TINYINT.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "BINARY"); + type = new FieldType(false, Types.MinorType.VARBINARY.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "TEXT"); + type = new FieldType(false, Types.MinorType.VARCHAR.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "BOOLEAN"); + type = new FieldType(false, Types.MinorType.BIT.getType(), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + customFieldMeta.put("logicalType", "REAL"); + type = new FieldType(false, new ArrowType.Decimal(38, 16, 128), null, customFieldMeta); + fieldList.add(new Field("", type, null)); + + Schema schema = new Schema(fieldList); + + Object[][] data = generateData(schema, 1000); + File file = createArrowFile("testVectorTypes", schema, data, 10); + + int dataSize = (int) file.length(); + byte[] dataBytes = new byte[dataSize]; + + InputStream is = new FileInputStream(file); + is.read(dataBytes, 0, dataSize); + + resultSetSerializable.setRootAllocator(new RootAllocator(Long.MAX_VALUE)); + resultSetSerializable.setFirstChunkStringData(Base64.getEncoder().encodeToString(dataBytes)); + resultSetSerializable.setFirstChunkByteData(dataBytes); + resultSetSerializable.setChunkFileCount(0); + + SFArrowResultSet resultSet = + new SFArrowResultSet(resultSetSerializable, new NoOpTelemetryClient(), true); + + for (int i = 0; i < 1000; i++) { + resultSet.next(); + } + // We inserted a null row at the beginning so when sorted, the last row should be null + assertEquals(null, resultSet.getObject(1)); + 
assertFalse(resultSet.next()); + statement.execute("drop table alltypes;"); + con.close(); + } } diff --git a/src/test/java/net/snowflake/client/core/SFTrustManagerTest.java b/src/test/java/net/snowflake/client/core/SFTrustManagerTest.java index fa61f35c8..b29668557 100644 --- a/src/test/java/net/snowflake/client/core/SFTrustManagerTest.java +++ b/src/test/java/net/snowflake/client/core/SFTrustManagerTest.java @@ -119,7 +119,7 @@ public void testSnowflakeResultSetSerializable_getResultSet() throws Exception { // Create an empty result set serializable object SnowflakeResultSetSerializableV1 resultSetSerializable = new SnowflakeResultSetSerializableV1(); - resultSetSerializable.setFristChunkStringData( + resultSetSerializable.setFirstChunkStringData( Base64.getEncoder().encodeToString("".getBytes(StandardCharsets.UTF_8))); resultSetSerializable.setChunkFileCount(0); resultSetSerializable.getParameters().put(CLIENT_MEMORY_LIMIT, 10);
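
Note for reviewers (not part of the patch): the new cases in SFArrowResultSetIT all follow the same flow — write an Arrow stream to a temporary file, hand its bytes to SnowflakeResultSetSerializableV1 as the "first chunk", then iterate it through SFArrowResultSet. A condensed sketch of that flow follows; it uses only the classes and setters exercised above, assumes the schema, data array, and the createArrowFile helper added by this patch are already in scope, and is illustrative rather than a verbatim excerpt.

    // Build an Arrow stream file with the test helper added in this patch.
    File file = createArrowFile("sketch", schema, data, 10);
    byte[] dataBytes = new byte[(int) file.length()];
    try (InputStream is = new FileInputStream(file)) {
      is.read(dataBytes, 0, dataBytes.length);
    }

    // Feed the bytes in as the "first chunk"; note the renamed setter
    // setFirstChunkStringData (formerly setFristChunkStringData).
    SnowflakeResultSetSerializableV1 serializable = new SnowflakeResultSetSerializableV1();
    serializable.setRootAllocator(new RootAllocator(Long.MAX_VALUE));
    serializable.setFirstChunkStringData(Base64.getEncoder().encodeToString(dataBytes));
    serializable.setFirstChunkByteData(dataBytes);
    serializable.setChunkFileCount(0); // no additional chunk files to download

    // The final constructor argument is true in the sorted-chunk tests, which is
    // the path that exercises ArrowResultChunk.mergeBatchesIntoOne().
    SFArrowResultSet resultSet =
        new SFArrowResultSet(serializable, new NoOpTelemetryClient(), true);
    while (resultSet.next()) {
      Object value = resultSet.getObject(1);
    }

The sorted tests insert a null row first, so after client-side sorting they assert that the last row read back is null.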