diff --git a/build.sbt b/build.sbt index ba01c7791aa..160f212972b 100644 --- a/build.sbt +++ b/build.sbt @@ -223,7 +223,7 @@ lazy val kernelApi = (project in file("kernel/kernel-api")) "com.fasterxml.jackson.core" % "jackson-databind" % "2.13.5" % "test", "org.scalatest" %% "scalatest" % scalaTestVersion % "test", - "junit" % "junit" % "4.11" % "test", + "junit" % "junit" % "4.13" % "test", "com.novocode" % "junit-interface" % "0.11" % "test" ), @@ -255,7 +255,7 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) "org.apache.parquet" % "parquet-hadoop" % "1.12.3", "org.scalatest" %% "scalatest" % scalaTestVersion % "test", - "junit" % "junit" % "4.11" % "test", + "junit" % "junit" % "4.13" % "test", "commons-io" % "commons-io" % "2.8.0" % "test", "com.novocode" % "junit-interface" % "0.11" % "test" ), diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ArrayValue.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ArrayValue.java new file mode 100644 index 00000000000..62ce0dd11f1 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ArrayValue.java @@ -0,0 +1,32 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.data; + +/** + * Abstraction to represent a single array value in a {@link ColumnVector}. 
+ */ +public interface ArrayValue { + /** + * The number of elements in the array + */ + int getSize(); + + /** + * A {@link ColumnVector} containing the array elements with exactly + * {@link ArrayValue#getSize()} elements. + */ + ColumnVector getElements(); +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java index 9adf6d3c959..5d4693071f6 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java @@ -17,8 +17,6 @@ package io.delta.kernel.data; import java.math.BigDecimal; -import java.util.List; -import java.util.Map; import io.delta.kernel.annotation.Evolving; import io.delta.kernel.types.DataType; @@ -162,15 +160,10 @@ default BigDecimal getDecimal(int rowId) { } /** - * Return the map type value located at {@code rowId}. The return value is undefined and can be - * anything, if the slot for {@code rowId} is null. - * - * @param rowId - * @param Return map key type - * @param Return map value type - * @return + * Return the map value located at {@code rowId}. Returns null if the slot for {@code rowId} + * is null */ - default Map getMap(int rowId) { + default MapValue getMap(int rowId) { throw new UnsupportedOperationException("Invalid value request for data type"); } @@ -186,14 +179,10 @@ default Row getStruct(int rowId) { } /** - * Return the array value located at {@code rowId}. The return value is undefined and can be - * anything, if the slot for {@code rowId} is null. - * - * @param rowId - * @param Array element type - * @return + * Return the array value located at {@code rowId}. 
Returns null if the slot for {@code rowId} + * is null */ - default List getArray(int rowId) { + default ArrayValue getArray(int rowId) { throw new UnsupportedOperationException("Invalid value request for data type"); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/MapValue.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/MapValue.java new file mode 100644 index 00000000000..39897e93814 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/MapValue.java @@ -0,0 +1,40 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.data; + +/** + * Abstraction to represent a single map value in a {@link ColumnVector}. + */ +public interface MapValue { + /** + * The number of elements in the map + */ + int getSize(); + + /** + * A {@link ColumnVector} containing the keys. There are exactly {@link MapValue#getSize()} keys + * in the vector, and each key maps one-to-one to the value at the same index in + * {@link MapValue#getValues()}. + */ + ColumnVector getKeys(); + + /** + * A {@link ColumnVector} containing the values. 
There are exactly {@link MapValue#getSize()} + * values in the vector, and maps one-to-one to the keys in {@link MapValue#getKeys()} + */ + ColumnVector getValues(); + +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java index 3229542190f..adcacbc0f4c 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java @@ -17,8 +17,6 @@ package io.delta.kernel.data; import java.math.BigDecimal; -import java.util.List; -import java.util.Map; import io.delta.kernel.annotation.Evolving; import io.delta.kernel.types.StructType; @@ -112,11 +110,11 @@ public interface Row { * Return array value of the column located at the given ordinal. * Throws error if the column at given ordinal is not of array type, */ - List getArray(int ordinal); + ArrayValue getArray(int ordinal); /** * Return map value of the column located at the given ordinal. * Throws error if the column at given ordinal is not of map type, */ - Map getMap(int ordinal); + MapValue getMap(int ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java index 5797e599724..8f4388936ab 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/ScalarExpression.java @@ -33,7 +33,8 @@ *
  • Name: element_at *
      *
    • Semantic: element_at(map, key). Return the value of given key - * from the map type input. Ex: `element_at(map(1, 'a', 2, 'b'), 2)` returns 'b'
    • + * from the map type input. Returns null if the given key is not in + * the map. Ex: `element_at(map(1, 'a', 2, 'b'), 2)` returns 'b' *
    • Since version: 3.0.0
    • *
    *
  • diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java index b7ef2dca092..7a3db5bf41d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/InternalScanFileUtils.java @@ -28,6 +28,7 @@ import io.delta.kernel.types.StringType; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; +import io.delta.kernel.utils.VectorUtils; import io.delta.kernel.internal.actions.AddFile; import io.delta.kernel.internal.actions.DeletionVectorDescriptor; @@ -111,7 +112,7 @@ public static FileStatus getAddFileStatus(Row scanFileInfo) { */ public static Map getPartitionValues(Row scanFileInfo) { Row addFile = getAddFileEntry(scanFileInfo); - return addFile.getMap(ADD_FILE_PARTITION_VALUES_ORDINAL); + return VectorUtils.toJavaMap(addFile.getMap(ADD_FILE_PARTITION_VALUES_ORDINAL)); } /** diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java index 8f55d8be31b..64d1cff6133 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanBuilderImpl.java @@ -28,6 +28,7 @@ import io.delta.kernel.types.TimestampType; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Tuple2; +import io.delta.kernel.utils.VectorUtils; import io.delta.kernel.internal.actions.Metadata; import io.delta.kernel.internal.actions.Protocol; @@ -86,7 +87,8 @@ public Scan build() { // TODO: support timestamp type partition columns // Timestamp partition columns have complicated semantics related to timezones so block this // for now - List partitionCols = protocolAndMetadata.get()._2.getPartitionColumns(); + List partitionCols = 
VectorUtils.toJavaList( + protocolAndMetadata.get()._2.getPartitionColumns()); for (String colName : partitionCols) { if (readSchema.indexOf(colName) >= 0 && readSchema.get(colName).getDataType() instanceof TimestampType) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index 254ea971aec..db95c514679 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -99,7 +99,7 @@ public Row getScanState(TableClient tableClient) { readSchema, snapshotSchema, protocolAndMetadata.get()._2.getConfiguration() - .getOrDefault("delta.columnMapping.mode", "none") + .getOrDefault("delta.columnMapping.mode", "none") ) ), dataPath.toUri().toString()); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Metadata.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Metadata.java index ff7291a1034..e1227804ba8 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Metadata.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Metadata.java @@ -16,20 +16,23 @@ package io.delta.kernel.internal.actions; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Optional; import static java.util.Objects.requireNonNull; import io.delta.kernel.client.TableClient; +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.ArrayType; import io.delta.kernel.types.LongType; import io.delta.kernel.types.MapType; import io.delta.kernel.types.StringType; import io.delta.kernel.types.StructType; +import io.delta.kernel.utils.VectorUtils; import static io.delta.kernel.utils.Utils.requireNonNull; +import io.delta.kernel.internal.lang.Lazy; import 
io.delta.kernel.internal.types.TableSchemaSerDe; public class Metadata { @@ -66,7 +69,7 @@ public static Metadata fromRow(Row row, TableClient tableClient) { .add("createdTime", LongType.INSTANCE, true /* contains null */) .add("configuration", new MapType(StringType.INSTANCE, StringType.INSTANCE, false), - false /* contains null */); + false /* nullable */); private final String id; private final Optional name; @@ -74,9 +77,10 @@ public static Metadata fromRow(Row row, TableClient tableClient) { private final Format format; private final String schemaString; private final StructType schema; - private final List partitionColumns; + private final ArrayValue partitionColumns; private final Optional createdTime; - private final Map configuration; + private final MapValue configurationMapValue; + private final Lazy> configuration; public Metadata( String id, @@ -85,19 +89,19 @@ public Metadata( Format format, String schemaString, StructType schema, - List partitionColumns, + ArrayValue partitionColumns, Optional createdTime, - Map configuration) { + MapValue configurationMapValue) { this.id = requireNonNull(id, "id is null"); this.name = name; this.description = requireNonNull(description, "description is null"); this.format = requireNonNull(format, "format is null"); this.schemaString = requireNonNull(schemaString, "schemaString is null"); this.schema = schema; - this.partitionColumns = - partitionColumns == null ? Collections.emptyList() : partitionColumns; + this.partitionColumns = requireNonNull(partitionColumns, "partitionColumns is null"); this.createdTime = createdTime; - this.configuration = configuration == null ? 
Collections.emptyMap() : configuration; + this.configurationMapValue = requireNonNull(configurationMapValue, "configuration is null"); + this.configuration = new Lazy<>(() -> VectorUtils.toJavaMap(configurationMapValue)); } public String getSchemaString() { @@ -108,7 +112,7 @@ public StructType getSchema() { return schema; } - public List getPartitionColumns() { + public ArrayValue getPartitionColumns() { return partitionColumns; } @@ -132,7 +136,11 @@ public Optional getCreatedTime() { return createdTime; } + public MapValue getConfigurationMapValue() { + return configurationMapValue; + } + public Map getConfiguration() { - return configuration; + return Collections.unmodifiableMap(configuration.get()); } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java index 0d7f546ef10..efd289831d8 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/Protocol.java @@ -23,6 +23,7 @@ import io.delta.kernel.types.IntegerType; import io.delta.kernel.types.StringType; import io.delta.kernel.types.StructType; +import io.delta.kernel.utils.VectorUtils; public class Protocol { public static Protocol fromRow(Row row) { @@ -32,8 +33,10 @@ public static Protocol fromRow(Row row) { return new Protocol( row.getInt(0), row.getInt(1), - row.isNullAt(2) ? Collections.emptyList() : row.getArray(2), - row.isNullAt(3) ? Collections.emptyList() : row.getArray(3)); + row.isNullAt(2) ? Collections.emptyList() : + VectorUtils.toJavaList(row.getArray(2)), + row.isNullAt(3) ? 
Collections.emptyList() : + VectorUtils.toJavaList(row.getArray(3))); } public static final StructType READ_SCHEMA = new StructType() diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ColumnarBatchRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ColumnarBatchRow.java index 3e17bd47eeb..ffe7cc507aa 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ColumnarBatchRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ColumnarBatchRow.java @@ -16,13 +16,9 @@ package io.delta.kernel.internal.data; import java.math.BigDecimal; -import java.util.List; -import java.util.Map; import java.util.Objects; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.Row; +import io.delta.kernel.data.*; import io.delta.kernel.types.StructType; /** @@ -104,12 +100,12 @@ public Row getStruct(int ordinal) { } @Override - public List getArray(int ordinal) { + public ArrayValue getArray(int ordinal) { return columnVector(ordinal).getArray(rowId); } @Override - public Map getMap(int ordinal) { + public MapValue getMap(int ordinal) { return columnVector(ordinal).getMap(rowId); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java index bfd5ba1215d..01a12bb84de 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java @@ -17,10 +17,11 @@ package io.delta.kernel.internal.data; import java.math.BigDecimal; -import java.util.List; import java.util.Map; import static java.util.Objects.requireNonNull; +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.*; @@ -31,6 +32,13 @@ public class GenericRow implements Row { private 
final StructType schema; private final Map ordinalToValue; + + /** + * @param schema the schema of the row + * @param ordinalToValue a mapping of column ordinal to objects; for each column the object + * must be of the return type corresponding to the data type's getter + * method in the Row interface + */ public GenericRow(StructType schema, Map ordinalToValue) { this.schema = requireNonNull(schema, "schema is null"); this.ordinalToValue = requireNonNull(ordinalToValue, "ordinalToValue is null"); @@ -113,17 +121,17 @@ public Row getStruct(int ordinal) { } @Override - public List getArray(int ordinal) { + public ArrayValue getArray(int ordinal) { // TODO: not sufficient check, also need to check the element type throwIfUnsafeAccess(ordinal, ArrayType.class, "array"); - return (List) getValue(ordinal); + return (ArrayValue) getValue(ordinal); } @Override - public Map getMap(int ordinal) { + public MapValue getMap(int ordinal) { // TODO: not sufficient check, also need to check the element types throwIfUnsafeAccess(ordinal, MapType.class, "map"); - return (Map) getValue(ordinal); + return (MapValue) getValue(ordinal); } private Object getValue(int ordinal) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java index cad67f6f6d3..0c350d06edf 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java @@ -25,6 +25,7 @@ import io.delta.kernel.client.TableClient; import io.delta.kernel.data.Row; import io.delta.kernel.types.*; +import io.delta.kernel.utils.VectorUtils; import io.delta.kernel.internal.actions.Metadata; import io.delta.kernel.internal.actions.Protocol; @@ -55,7 +56,7 @@ public static ScanStateRow of( String readSchemaPhysicalJson, String tablePath) { HashMap valueMap = new HashMap<>(); - 
valueMap.put(COL_NAME_TO_ORDINAL.get("configuration"), metadata.getConfiguration()); + valueMap.put(COL_NAME_TO_ORDINAL.get("configuration"), metadata.getConfigurationMapValue()); valueMap.put(COL_NAME_TO_ORDINAL.get("logicalSchemaString"), readSchemaLogicalJson); valueMap.put(COL_NAME_TO_ORDINAL.get("physicalSchemaString"), readSchemaPhysicalJson); valueMap.put(COL_NAME_TO_ORDINAL.get("partitionColumns"), metadata.getPartitionColumns()); @@ -105,7 +106,8 @@ public static StructType getPhysicalSchema(TableClient tableClient, Row scanStat * @return List of partition column names according to the scan state. */ public static List getPartitionColumns(Row scanState) { - return scanState.getArray(COL_NAME_TO_ORDINAL.get("partitionColumns")); + return VectorUtils.toJavaList( + scanState.getArray(COL_NAME_TO_ORDINAL.get("partitionColumns"))); } /** @@ -113,10 +115,9 @@ public static List getPartitionColumns(Row scanState) { * {@link Scan#getScanState(TableClient)}. */ public static String getColumnMappingMode(Row scanState) { - Map configuration = - scanState.getMap(COL_NAME_TO_ORDINAL.get("configuration")); - String cmMode = configuration.get("delta.columnMapping.mode"); - return cmMode == null ? "none" : cmMode; + Map configuration = VectorUtils.toJavaMap( + scanState.getMap(COL_NAME_TO_ORDINAL.get("configuration"))); + return configuration.getOrDefault("delta.columnMapping.mode", "none"); } /** diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java index f4f43afa6ad..b50ef646535 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/replay/LogReplay.java @@ -226,7 +226,8 @@ private void validateSupportedTable(Protocol protocol, Metadata metadata) { private void verifySupportedColumnMappingMode(Metadata metadata) { // Check if the mode is name. 
Id mode is not yet supported - String cmMode = metadata.getConfiguration().get("delta.columnMapping.mode"); + String cmMode = metadata.getConfiguration() + .getOrDefault("delta.columnMapping.mode", "none"); if (!"none".equalsIgnoreCase(cmMode) && !"name".equalsIgnoreCase(cmMode)) { throw new UnsupportedOperationException( diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/TableSchemaSerDe.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/TableSchemaSerDe.java index b119235be5f..d583b9c5f98 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/TableSchemaSerDe.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/types/TableSchemaSerDe.java @@ -15,6 +15,7 @@ */ package io.delta.kernel.internal.types; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -39,6 +40,7 @@ import io.delta.kernel.types.StructType; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Utils; +import io.delta.kernel.utils.VectorUtils; /** * Utility class to serialize and deserialize the table schema which is of type {@link StructType}. @@ -77,7 +79,7 @@ public static StructType fromJson(JsonHandler jsonHandler, String serializedStru private static StructType parseStructType(JsonHandler jsonHandler, String serializedStructType) { Function evalMethod = (row) -> { - final List fields = row.getArray(0); + final List fields = VectorUtils.toJavaList(row.getArray(0)); return new StructType( fields.stream() .map(field -> parseStructField(jsonHandler, field)) @@ -95,8 +97,8 @@ private static StructField parseStructField(JsonHandler jsonHandler, Row row) { String serializedDataType = row.getString(1); DataType type = parseDataType(jsonHandler, serializedDataType); boolean nullable = row.getBoolean(2); - Map metadata = row.getMap(3); - + Map metadata = row.isNullAt(3) ? 
Collections.emptyMap() : + VectorUtils.toJavaMap(row.getMap(3)); return new StructField(name, type, nullable, metadata); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/utils/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/VectorUtils.java new file mode 100644 index 00000000000..8a1d812d9b2 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/VectorUtils.java @@ -0,0 +1,113 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.utils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.types.*; + +public final class VectorUtils { + + private VectorUtils() {} + + /** + * Converts an {@link ArrayValue} to a Java list. Any nested complex types are also converted + * to their Java type. + */ + public static List toJavaList(ArrayValue arrayValue) { + final ColumnVector elementVector = arrayValue.getElements(); + final DataType dataType = elementVector.getDataType(); + + List elements = new ArrayList<>(); + for (int i = 0; i < arrayValue.getSize(); i++) { + elements.add((T) getValueAsObject(elementVector, dataType, i)); + } + return elements; + } + + /** + * Converts a {@link MapValue} to a Java map. 
Any nested complex types are also converted + * to their Java type. + * + * Please note not all key types override hashCode/equals. Be careful when using with keys of: + * - Struct type at any nesting level (i.e. ArrayType(StructType) does not) + * - Binary type + */ + public static Map toJavaMap(MapValue mapValue) { + final ColumnVector keyVector = mapValue.getKeys(); + final DataType keyDataType = keyVector.getDataType(); + final ColumnVector valueVector = mapValue.getValues(); + final DataType valueDataType = valueVector.getDataType(); + + Map values = new HashMap<>(); + + for (int i = 0; i < mapValue.getSize(); i++) { + Object key = getValueAsObject(keyVector, keyDataType, i); + Object value = getValueAsObject(valueVector, valueDataType, i); + values.put((K) key, (V) value); + } + return values; + } + + /** + * Gets the value at {@code rowId} from the column vector. The type of the Object returned + * depends on the data type of the column vector. For complex types array and map, returns + * the value as Java list or Java map. 
+ */ + private static Object getValueAsObject( + ColumnVector columnVector, DataType dataType, int rowId) { + if (columnVector.isNullAt(rowId)) { + return null; + } else if (dataType instanceof BooleanType) { + return columnVector.getBoolean(rowId); + } else if (dataType instanceof ByteType) { + return columnVector.getByte(rowId); + } else if (dataType instanceof ShortType) { + return columnVector.getShort(rowId); + } else if (dataType instanceof IntegerType || dataType instanceof DateType) { + // DateType data is stored internally as the number of days since 1970-01-01 + return columnVector.getInt(rowId); + } else if (dataType instanceof LongType || dataType instanceof TimestampType) { + // TimestampType data is stored internally as the number of microseconds since the unix + // epoch + return columnVector.getLong(rowId); + } else if (dataType instanceof FloatType) { + return columnVector.getFloat(rowId); + } else if (dataType instanceof DoubleType) { + return columnVector.getDouble(rowId); + } else if (dataType instanceof StringType) { + return columnVector.getString(rowId); + } else if (dataType instanceof BinaryType) { + return columnVector.getBinary(rowId); + } else if (dataType instanceof StructType) { + return columnVector.getStruct(rowId); + } else if (dataType instanceof DecimalType) { + return columnVector.getDecimal(rowId); + } else if (dataType instanceof ArrayType) { + return toJavaList(columnVector.getArray(rowId)); + } else if (dataType instanceof MapType) { + return toJavaMap(columnVector.getMap(rowId)); + } else { + throw new UnsupportedOperationException("unsupported data type"); + } + } +} diff --git a/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java b/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java index fa023a8cfe5..2987d8d8ed8 100644 --- a/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java +++ 
b/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java @@ -17,7 +17,6 @@ import java.math.BigDecimal; import java.util.ArrayList; -import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -29,10 +28,7 @@ import io.delta.kernel.client.FileReadContext; import io.delta.kernel.client.JsonHandler; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.FileDataReadResult; -import io.delta.kernel.data.Row; +import io.delta.kernel.data.*; import io.delta.kernel.expressions.Predicate; import io.delta.kernel.types.ArrayType; import io.delta.kernel.types.BooleanType; @@ -160,23 +156,51 @@ private static Object decodeElement(JsonNode jsonValue, DataType dataType) { final Object parsedElement = decodeElement(element, arrayType.getElementType()); output.add(parsedElement); } - return output; + return new ArrayValue() { + @Override + public int getSize() { + return output.size(); + } + + @Override + public ColumnVector getElements() { + return new TestColumnVector(arrayType.getElementType(), output); + } + }; } if (dataType instanceof MapType) { throwIfTypeMismatch("map", jsonValue.isObject(), jsonValue); final MapType mapType = (MapType) dataType; + final List keys = new ArrayList<>(); + final List values = new ArrayList<>(); final Iterator> iter = jsonValue.fields(); - final Map output = new HashMap<>(); while (iter.hasNext()) { Map.Entry entry = iter.next(); String keyParsed = entry.getKey(); Object valueParsed = decodeElement(entry.getValue(), mapType.getValueType()); - output.put(keyParsed, valueParsed); + keys.add(keyParsed); + values.add(valueParsed); } - return output; + return new MapValue() { + + @Override + public int getSize() { + return keys.size(); + } + + @Override + public ColumnVector getKeys() { + return new TestColumnVector(mapType.getKeyType(), keys); + } + + @Override + public ColumnVector getValues() { + return new 
TestColumnVector(mapType.getValueType(), values); + } + }; } throw new UnsupportedOperationException( @@ -279,13 +303,109 @@ public Row getStruct(int ordinal) { } @Override - public List getArray(int ordinal) { - return (List) parsedValues[ordinal]; + public ArrayValue getArray(int ordinal) { + return (ArrayValue) parsedValues[ordinal]; + } + + @Override + public MapValue getMap(int ordinal) { + return (MapValue) parsedValues[ordinal]; + } + } + + private static class TestColumnVector implements ColumnVector { + + private final DataType dataType; + private final List values; + + TestColumnVector(DataType dataType, List values) { + this.dataType = dataType; + this.values = values; + } + + @Override + public DataType getDataType() { + return dataType; + } + + @Override + public int getSize() { + return values.size(); + } + + @Override + public void close() { + + } + + @Override + public boolean isNullAt(int rowId) { + return values.get(rowId) == null; + } + + @Override + public boolean getBoolean(int rowId) { + return (boolean) values.get(rowId); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException("not yet implemented - test only"); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException("not yet implemented - test only"); + } + + @Override + public int getInt(int rowId) { + return (int) values.get(rowId); + } + + @Override + public long getLong(int rowId) { + return (long) values.get(rowId); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException("not yet implemented - test only"); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException("not yet implemented - test only"); + } + + @Override + public String getString(int rowId) { + return (String) values.get(rowId); + } + + @Override + public BigDecimal getDecimal(int rowId) { + throw new UnsupportedOperationException("not yet implemented - test only"); 
+ } + + @Override + public byte[] getBinary(int rowId) { + throw new UnsupportedOperationException("not yet implemented - test only"); + } + + @Override + public Row getStruct(int rowId) { + return (Row) values.get(rowId); + } + + @Override + public ArrayValue getArray(int rowId) { + return (ArrayValue) values.get(rowId); } @Override - public Map getMap(int ordinal) { - return (Map) parsedValues[ordinal]; + public MapValue getMap(int rowId) { + return (MapValue) values.get(rowId); } } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java index d89be7bf36f..02544eddbc2 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java @@ -16,19 +16,20 @@ package io.delta.kernel.defaults.internal.data; import java.math.BigDecimal; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.*; +import io.delta.kernel.defaults.internal.data.vector.DefaultGenericVector; + public class DefaultJsonRow implements Row { private final Object[] parsedValues; private final StructType readSchema; @@ -110,13 +111,13 @@ public Row getStruct(int ordinal) { } @Override - public List getArray(int ordinal) { - return (List) parsedValues[ordinal]; + public ArrayValue getArray(int ordinal) { + return (ArrayValue) parsedValues[ordinal]; } @Override - public Map 
getMap(int ordinal) { - return (Map) parsedValues[ordinal]; + public MapValue getMap(int ordinal) { + return (MapValue) parsedValues[ordinal]; } private static void throwIfTypeMismatch(String expType, boolean hasExpType, JsonNode jsonNode) { @@ -174,18 +175,27 @@ private static Object decodeElement(JsonNode jsonValue, DataType dataType) { throwIfTypeMismatch("array", jsonValue.isArray(), jsonValue); final ArrayType arrayType = ((ArrayType) dataType); final ArrayNode jsonArray = (ArrayNode) jsonValue; - final List output = new ArrayList<>(); - - for (Iterator it = jsonArray.elements(); it.hasNext(); ) { - final JsonNode element = it.next(); + final Object[] elements = new Object[jsonArray.size()]; + for (int i = 0; i < jsonArray.size(); i++) { + final JsonNode element = jsonArray.get(i); final Object parsedElement = decodeElement(element, arrayType.getElementType()); if (parsedElement == null && !arrayType.containsNull()) { throw new RuntimeException("Array type expects no nulls as elements, but " + - "received `null` as array element"); + "received `null` as array element"); } - output.add(parsedElement); + elements[i] = parsedElement; } - return output; + return new ArrayValue() { + @Override + public int getSize() { + return elements.length; + } + + @Override + public ColumnVector getElements() { + return new DefaultGenericVector(arrayType.getElementType(), elements); + } + }; } if (dataType instanceof MapType) { @@ -193,10 +203,11 @@ private static Object decodeElement(JsonNode jsonValue, DataType dataType) { final MapType mapType = (MapType) dataType; if (!(mapType.getKeyType() instanceof StringType)) { throw new RuntimeException("MapType with a key type of `String` is supported, " + - "received a key type: " + mapType.getKeyType()); + "received a key type: " + mapType.getKeyType()); } + List keys = new ArrayList<>(jsonValue.size()); + List values = new ArrayList<>(jsonValue.size()); final Iterator> iter = jsonValue.fields(); - final Map output = new 
HashMap<>(); while (iter.hasNext()) { Map.Entry entry = iter.next(); @@ -204,12 +215,27 @@ private static Object decodeElement(JsonNode jsonValue, DataType dataType) { Object valueParsed = decodeElement(entry.getValue(), mapType.getValueType()); if (valueParsed == null && !mapType.isValueContainsNull()) { throw new RuntimeException("Map type expects no nulls in values, but " + - "received `null` as value"); + "received `null` as value"); } - output.put(keyParsed, valueParsed); + keys.add(keyParsed); + values.add(valueParsed); } + return new MapValue() { + @Override + public int getSize() { + return jsonValue.size(); + } + + @Override + public ColumnVector getKeys() { + return new DefaultGenericVector(mapType.getKeyType(), keys.toArray()); + } - return output; + @Override + public ColumnVector getValues() { + return new DefaultGenericVector(mapType.getValueType(), values.toArray()); + } + }; } throw new UnsupportedOperationException( diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java index 9da2bc724cf..9f6e54f788f 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java @@ -18,12 +18,9 @@ import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Optional; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.Row; +import io.delta.kernel.data.*; import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; @@ -210,7 +207,7 @@ public byte[] getBinary(int rowId) { } @Override - public Map getMap(int rowId) { + public MapValue 
getMap(int rowId) { assertValidRowId(rowId); return rows.get(rowId).getMap(columnOrdinal); } @@ -222,7 +219,7 @@ public Row getStruct(int rowId) { } @Override - public List getArray(int rowId) { + public ArrayValue getArray(int rowId) { assertValidRowId(rowId); return rows.get(rowId).getArray(columnOrdinal); } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java index 0ddc6aae674..024b33bcf60 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java @@ -16,12 +16,12 @@ package io.delta.kernel.defaults.internal.data.vector; import java.math.BigDecimal; -import java.util.List; -import java.util.Map; import java.util.Optional; import static java.util.Objects.requireNonNull; +import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.DataType; @@ -130,7 +130,7 @@ public BigDecimal getDecimal(int rowId) { } @Override - public Map getMap(int rowId) { + public MapValue getMap(int rowId) { throw unsupportedDataAccessException("map"); } @@ -140,7 +140,7 @@ public Row getStruct(int rowId) { } @Override - public List getArray(int rowId) { + public ArrayValue getArray(int rowId) { throw unsupportedDataAccessException("array"); } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java index 41fab3b75cd..88cbe2fca29 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java 
+++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java @@ -15,11 +15,10 @@ */ package io.delta.kernel.defaults.internal.data.vector; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import static java.util.Objects.requireNonNull; +import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.types.DataType; @@ -64,19 +63,29 @@ public DefaultArrayVector( * @return */ @Override - public List getArray(int rowId) { + public ArrayValue getArray(int rowId) { + checkValidRowId(rowId); if (isNullAt(rowId)) { return null; } - checkValidRowId(rowId); + // use the offsets array to find the starting and ending index in the underlying vector + // for this rowId int start = offsets[rowId]; int end = offsets[rowId + 1]; + return new ArrayValue() { - List values = new ArrayList<>(); - for (int entry = start; entry < end; entry++) { - Object key = VectorUtils.getValueAsObject(elementVector, entry); - values.add((T) key); - } - return values; + // create a view over the elements for this rowId + private final ColumnVector elements = new DefaultViewVector(elementVector, start, end); + + @Override + public int getSize() { + return elements.getSize(); + } + + @Override + public ColumnVector getElements() { + return elements; + } + }; } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java index cb3152cf8a8..b96005bbf80 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java @@ -16,10 +16,10 @@ package io.delta.kernel.defaults.internal.data.vector; import java.math.BigDecimal; -import 
java.util.List; -import java.util.Map; +import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.DataType; @@ -107,8 +107,8 @@ public BigDecimal getDecimal(int rowId) { } @Override - public Map getMap(int rowId) { - return (Map) value; + public MapValue getMap(int rowId) { + return (MapValue) value; } @Override @@ -117,7 +117,7 @@ public Row getStruct(int rowId) { } @Override - public List getArray(int rowId) { - return (List) value; + public ArrayValue getArray(int rowId) { + return (ArrayValue) value; } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java new file mode 100644 index 00000000000..0253b82c93c --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java @@ -0,0 +1,147 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.data.vector; + +import java.math.BigDecimal; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.Row; +import io.delta.kernel.types.*; + +/** + * Generic column vector implementation to expose an array of objects as a column vector. + */ +public class DefaultGenericVector implements ColumnVector { + + private final DataType dataType; + private final Object[] values; + + public DefaultGenericVector(DataType dataType, Object[] values) { + this.dataType = dataType; + this.values = values; + } + + @Override + public DataType getDataType() { + return dataType; + } + + @Override + public int getSize() { + return values.length; + } + + @Override + public void close() { + + } + + @Override + public boolean isNullAt(int rowId) { + return values[rowId] == null; + } + + @Override + public boolean getBoolean(int rowId) { + throwIfUnsafeAccess(BooleanType.class, "boolean"); + return (boolean) values[rowId]; + } + + @Override + public byte getByte(int rowId) { + throwIfUnsafeAccess(ByteType.class, "byte"); + return (byte) values[rowId]; + } + + @Override + public short getShort(int rowId) { + throwIfUnsafeAccess(ShortType.class, "short"); + return (short) values[rowId]; + } + + @Override + public int getInt(int rowId) { + throwIfUnsafeAccess(IntegerType.class, "integer"); + return (int) values[rowId]; + } + + @Override + public long getLong(int rowId) { + throwIfUnsafeAccess(LongType.class, "long"); + return (long) values[rowId]; + } + + @Override + public float getFloat(int rowId) { + throwIfUnsafeAccess(FloatType.class, "float"); + return (float) values[rowId]; + } + + @Override + public double getDouble(int rowId) { + throwIfUnsafeAccess(DoubleType.class, "double"); + return (double) values[rowId]; + } + + @Override + public String getString(int rowId) { + throwIfUnsafeAccess(StringType.class, "string"); + return (String) 
values[rowId]; + } + + @Override + public BigDecimal getDecimal(int rowId) { + throwIfUnsafeAccess(DecimalType.class, "decimal"); + return (BigDecimal) values[rowId]; + } + + @Override + public byte[] getBinary(int rowId) { + throwIfUnsafeAccess(BinaryType.class, "binary"); + return (byte[]) values[rowId]; + } + + @Override + public Row getStruct(int rowId) { + throwIfUnsafeAccess(StructType.class, "struct"); + return (Row) values[rowId]; + } + + @Override + public ArrayValue getArray(int rowId) { + // TODO: not sufficient check, also need to check the element type + throwIfUnsafeAccess(ArrayType.class, "array"); + return (ArrayValue) values[rowId]; + } + + @Override + public MapValue getMap(int rowId) { + // TODO: not sufficient check, also need to check the element types + throwIfUnsafeAccess(MapType.class, "map"); + return (MapValue) values[rowId]; + } + + private void throwIfUnsafeAccess( Class expDataType, String accessType) { + if (!expDataType.isAssignableFrom(dataType.getClass())) { + String msg = String.format( + "Trying to access a `%s` value from vector of type `%s`", + accessType, + dataType); + throw new UnsupportedOperationException(msg); + } + }} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java index fa801280f64..320619f11ea 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java @@ -15,12 +15,11 @@ */ package io.delta.kernel.defaults.internal.data.vector; -import java.util.HashMap; -import java.util.Map; import java.util.Optional; import static java.util.Objects.requireNonNull; import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; import io.delta.kernel.types.DataType; import 
static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; @@ -68,20 +67,35 @@ public DefaultMapVector( * @return */ @Override - public Map getMap(int rowId) { + public MapValue getMap(int rowId) { + checkValidRowId(rowId); if (isNullAt(rowId)) { return null; } - checkValidRowId(rowId); + // use the offsets array to find the starting and ending index in the underlying vectors + // for this rowId int start = offsets[rowId]; int end = offsets[rowId + 1]; + return new MapValue() { - Map values = new HashMap<>(); - for (int entry = start; entry < end; entry++) { - Object key = VectorUtils.getValueAsObject(keyVector, entry); - Object value = VectorUtils.getValueAsObject(valueVector, entry); - values.put((K) key, (V) value); - } - return values; + // create a view over the keys and values for this rowId + private final ColumnVector keys = new DefaultViewVector(keyVector, start, end); + private final ColumnVector values = new DefaultViewVector(valueVector, start, end); + + @Override + public int getSize() { + return keys.getSize(); + } + + @Override + public ColumnVector getKeys() { + return keys; + } + + @Override + public ColumnVector getValues() { + return values; + } + }; } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultStructVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultStructVector.java index 334962a77b2..ce4158e4042 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultStructVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultStructVector.java @@ -16,12 +16,12 @@ package io.delta.kernel.defaults.internal.data.vector; import java.math.BigDecimal; -import java.util.List; -import java.util.Map; import java.util.Optional; import static java.util.Objects.requireNonNull; +import io.delta.kernel.data.ArrayValue; import 
io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructType; @@ -157,12 +157,12 @@ public Row getStruct(int ordinal) { } @Override - public List getArray(int ordinal) { + public ArrayValue getArray(int ordinal) { return structVector.memberVectors[ordinal].getArray(rowId); } @Override - public Map getMap(int ordinal) { + public MapValue getMap(int ordinal) { return structVector.memberVectors[ordinal].getMap(rowId); } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java new file mode 100644 index 00000000000..5b061f4c879 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java @@ -0,0 +1,155 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.defaults.internal.data.vector; + +import java.math.BigDecimal; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.Row; +import io.delta.kernel.types.DataType; +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; + +/** + * Provides a restricted view on an underlying column vector. + */ +public class DefaultViewVector implements ColumnVector { + + private final ColumnVector underlyingVector; + private final int offset; + private final int size; + + /** + * @param underlyingVector the underlying column vector to read + * @param start the row index of the underlyingVector where we want this vector to start + * @param end the row index of the underlyingVector where we want this vector to end + * (exclusive) + */ + public DefaultViewVector(ColumnVector underlyingVector, int start, int end) { + this.underlyingVector = underlyingVector; + this.offset = start; + this.size = end - start; + } + + @Override + public DataType getDataType() { + return underlyingVector.getDataType(); + } + + @Override + public int getSize() { + return size; + } + + @Override + public void close() { + // Don't close the underlying vector as it may still be used + } + + @Override + public boolean isNullAt(int rowId) { + checkValidRowId(rowId); + return underlyingVector.isNullAt(offset + rowId); + } + + @Override + public boolean getBoolean(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getBoolean(offset + rowId); + } + + @Override + public byte getByte(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getByte(offset + rowId); + } + + @Override + public short getShort(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getShort(offset + rowId); + } + + @Override + public int getInt(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getInt(offset + rowId); + } + + @Override + 
public long getLong(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getLong(offset + rowId); + } + + @Override + public float getFloat(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getFloat(offset + rowId); + } + + @Override + public double getDouble(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getDouble(offset + rowId); + } + + @Override + public byte[] getBinary(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getBinary(offset + rowId); + } + + @Override + public String getString(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getString(offset + rowId); + } + + @Override + public BigDecimal getDecimal(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getDecimal(offset + rowId); + } + + @Override + public MapValue getMap(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getMap(offset + rowId); + } + + @Override + public Row getStruct(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getStruct(offset + rowId); + } + + @Override + public ArrayValue getArray(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getArray(offset + rowId); + } + + private void checkValidRowId(int rowId) { + checkArgument(rowId >= 0 && rowId < size, + String.format( + "Invalid rowId=%s for size=%s", + rowId, + size + )); + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/VectorUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/VectorUtils.java deleted file mode 100644 index 7948c2cd933..00000000000 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/VectorUtils.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.delta.kernel.defaults.internal.data.vector; - -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.types.*; - -/** - * Utility methods for {@link io.delta.kernel.data.ColumnVector} implementations. - */ -public class VectorUtils { - private VectorUtils() {} - - /** - * Get the value at given {@code rowId} from the column vector. The type of the value object - * depends on the data type of the {@code vector}. - * - * @param vector - * @param rowId - * @return - */ - public static Object getValueAsObject(ColumnVector vector, int rowId) { - // TODO: may be it is better to just provide a `getObject` on the `ColumnVector` to - // avoid the nested if-else statements. 
- final DataType dataType = vector.getDataType(); - - if (vector.isNullAt(rowId)) { - return null; - } - - if (dataType instanceof BooleanType) { - return vector.getBoolean(rowId); - } else if (dataType instanceof ByteType) { - return vector.getByte(rowId); - } else if (dataType instanceof ShortType) { - return vector.getShort(rowId); - } else if (dataType instanceof IntegerType || dataType instanceof DateType) { - return vector.getInt(rowId); - } else if (dataType instanceof LongType || dataType instanceof TimestampType) { - return vector.getLong(rowId); - } else if (dataType instanceof FloatType) { - return vector.getFloat(rowId); - } else if (dataType instanceof DoubleType) { - return vector.getDouble(rowId); - } else if (dataType instanceof StringType) { - return vector.getString(rowId); - } else if (dataType instanceof BinaryType) { - return vector.getBinary(rowId); - } else if (dataType instanceof StructType) { - return vector.getStruct(rowId); - } else if (dataType instanceof MapType) { - return vector.getMap(rowId); - } else if (dataType instanceof ArrayType) { - return vector.getArray(rowId); - } else if (dataType instanceof DecimalType) { - return vector.getDecimal(rowId); - } - - throw new UnsupportedOperationException(dataType + " is not supported yet"); - } -} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java index 1028e9cb918..329a9c18f5c 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ElementAtEvaluator.java @@ -19,6 +19,7 @@ import static java.lang.String.format; import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; import io.delta.kernel.expressions.Expression; import 
io.delta.kernel.expressions.ScalarExpression; import io.delta.kernel.types.DataType; @@ -73,7 +74,7 @@ static ColumnVector eval(ColumnVector map, ColumnVector lookupKey) { // The general pattern is call `isNullAt(rowId)` followed by `getString`. // So the cache of one value is enough. private int lastLookupRowId = -1; - private Object lastLookupValue = null; + private String lastLookupValue = null; @Override public DataType getDataType() { @@ -101,20 +102,37 @@ public boolean isNullAt(int rowId) { @Override public String getString(int rowId) { lookupValue(rowId); - return lastLookupValue == null ? null : (String) lastLookupValue; + return lastLookupValue == null ? null : lastLookupValue; } private Object lookupValue(int rowId) { if (rowId == lastLookupRowId) { return lastLookupValue; } - // TODO: this needs to be updated after the new way of accessing the complex - // types is merged. lastLookupRowId = rowId; String keyValue = lookupKey.getString(rowId); - lastLookupValue = map.getMap(rowId).get(keyValue); + lastLookupValue = findValueForKey(map.getMap(rowId), keyValue); return lastLookupValue; } + + /** + * Given a {@link MapValue} and string {@code key} find the corresponding value. + * Returns null if the key is not in the map. + * @param mapValue String->String map to search + * @param key the key to look up the value for; may be null + */ + private String findValueForKey(MapValue mapValue, String key) { + ColumnVector keyVector = mapValue.getKeys(); + for (int i = 0; i < mapValue.getSize(); i++) { + if ((keyVector.isNullAt(i) && key == null) || + (!keyVector.isNullAt(i) && keyVector.getString(i).equals(key))) { + return mapValue.getValues().isNullAt(i) ? 
null : + mapValue.getValues().getString(i); + } + } + // If the key is not in the map return null + return null; + } }; } diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java index 7de2932e921..03e1d8f53c3 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java @@ -15,6 +15,7 @@ */ package io.delta.kernel.defaults.client; +import java.io.IOException; import java.util.*; import org.apache.hadoop.conf.Configuration; @@ -30,8 +31,9 @@ import io.delta.kernel.fs.FileStatus; import io.delta.kernel.types.*; import io.delta.kernel.utils.CloseableIterator; -import io.delta.kernel.utils.Utils; +import io.delta.kernel.utils.VectorUtils; import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; +import static io.delta.kernel.utils.Utils.singletonColumnVector; import io.delta.kernel.internal.InternalScanFileUtils; @@ -126,7 +128,7 @@ public void parseJsonContent() .add("dataChange", BooleanType.INSTANCE); ColumnarBatch batch = - JSON_HANDLER.parseJson(Utils.singletonColumnVector(input), readSchema); + JSON_HANDLER.parseJson(singletonColumnVector(input), readSchema); assertEquals(1, batch.getSize()); try (CloseableIterator rows = batch.getRows()) { @@ -142,12 +144,68 @@ public void parseJsonContent() put("p2", "str"); } }; - assertEquals(expPartitionValues, row.getMap(1)); + Map actualPartitionValues = VectorUtils.toJavaMap(row.getMap(1)); + assertEquals(expPartitionValues, actualPartitionValues); assertEquals(348L, row.getLong(2)); assertEquals(true, row.getBoolean(3)); } } + @Test + public void parseNestedComplexTypes() throws IOException { + String json = "{" + + " \"array\": [0, 1, null]," + + " \"nested_array\": [[\"a\", \"b\"], [\"c\"], []]," + + " \"map\": 
{\"a\": true, \"b\": false},\n" + + " \"nested_map\": {\"a\": {\"one\": [], \"two\": [1, 2, 3]}, \"b\": {}}\n" + + "}"; + StructType schema = new StructType() + .add("array", new ArrayType(IntegerType.INSTANCE, true)) + .add("nested_array", new ArrayType(new ArrayType(StringType.INSTANCE, true), true)) + .add("map", new MapType(StringType.INSTANCE, BooleanType.INSTANCE, true)) + .add("nested_map", + new MapType( + StringType.INSTANCE, + new MapType( + StringType.INSTANCE, + new ArrayType(IntegerType.INSTANCE, true), + true + ), + true + )); + ColumnarBatch batch = JSON_HANDLER.parseJson(singletonColumnVector(json), schema); + + try (CloseableIterator rows = batch.getRows()) { + Row result = rows.next(); + List exp0 = Arrays.asList(0, 1, null); + assertEquals(exp0, VectorUtils.toJavaList(result.getArray(0))); + List> exp1 = Arrays.asList(Arrays.asList("a", "b"), Arrays.asList("c"), + Collections.emptyList()); + assertEquals(exp1, VectorUtils.toJavaList(result.getArray(1))); + Map exp2 = new HashMap() { + { + put("a", true); + put("b", false); + } + }; + assertEquals(exp2, VectorUtils.toJavaMap(result.getMap(2))); + Map> nestedMap = new HashMap>() { + { + put("one", Collections.emptyList()); + put("two", Arrays.asList(1, 2, 3)); + } + }; + Map>> exp3 = + new HashMap>>() { + { + put("a", nestedMap); + put("b", Collections.emptyMap()); + } + }; + assertEquals(exp3, VectorUtils.toJavaMap(result.getMap(3))); + } + } + private static CloseableIterator testFiles() throws Exception { String listFrom = DefaultKernelTestUtils.getTestResourceFilePath("json-files/1.json"); diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java index c1aed19dff3..9585ec05cc8 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java +++ 
b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/BaseIntegration.java @@ -36,7 +36,6 @@ import io.delta.kernel.defaults.client.DefaultTableClient; import io.delta.kernel.defaults.utils.DefaultKernelTestUtils; -import io.delta.kernel.defaults.internal.data.vector.VectorUtils; /** * Base class containing utility method to write integration tests that read data from @@ -160,8 +159,8 @@ protected boolean compareRows( ColumnVector expDataVector = expDataBatch.getColumnVector(fieldId); ColumnVector actDataVector = actDataBatch.getColumnVector(fieldId); - Object expObject = VectorUtils.getValueAsObject(expDataVector, expRowId); - Object actObject = VectorUtils.getValueAsObject(actDataVector, actRowId); + Object expObject = DefaultKernelTestUtils.getValueAsObject(expDataVector, expRowId); + Object actObject = DefaultKernelTestUtils.getValueAsObject(actDataVector, actRowId); boolean matched = compareObjects(fieldDataType, expObject, actObject); if (!matched) { return false; diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java index 2ff278430d8..b4a1d07c42e 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java @@ -22,7 +22,9 @@ import java.util.Map; import java.util.stream.IntStream; +import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnarBatch; +import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.StructType; @@ -155,13 +157,15 @@ public Row getStruct(int ordinal) { } @Override - public List getArray(int ordinal) { - return (List) values.get(ordinal); + public ArrayValue getArray(int ordinal) { + throw new UnsupportedOperationException( + "array type unsupported for 
TestColumnBatchBuilder; use scala test utilities"); } @Override - public Map getMap(int ordinal) { - return (Map) values.get(ordinal); + public MapValue getMap(int ordinal) { + throw new UnsupportedOperationException( + "map type unsupported for TestColumnBatchBuilder; use scala test utilities"); } } } diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java index 5fd42a6756c..13f73113421 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/TestDeltaTableReads.java @@ -16,9 +16,6 @@ package io.delta.kernel.defaults.integration; import java.math.BigDecimal; -import java.sql.Date; -import java.util.Arrays; -import java.util.HashMap; import java.util.List; import org.junit.Rule; @@ -30,11 +27,9 @@ import io.delta.kernel.client.TableClient; import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.types.*; -import static io.delta.kernel.internal.util.InternalUtils.daysSinceEpoch; import io.delta.kernel.defaults.client.DefaultTableClient; import io.delta.kernel.defaults.integration.DataBuilderUtils.TestColumnBatchBuilder; -import static io.delta.kernel.defaults.integration.DataBuilderUtils.row; import static io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getTestResourceFilePath; /** @@ -81,220 +76,6 @@ public void tablePrimitives() compareEqualUnorderd(expData, actualData); } - @Test - public void partitionedTable() - throws Exception { - String tablePath = goldenTablePath("data-reader-partition-values"); - Snapshot snapshot = snapshot(tablePath); - StructType readSchema = removeUnsupportedType(snapshot.getSchema(tableClient)); - - List actualData = readSnapshot(readSchema, snapshot); - - TestColumnBatchBuilder builder = DataBuilderUtils.builder(readSchema); - - for 
(int i = 0; i < 2; i++) { - builder = builder.addRow( - i, - (long) i, - (byte) i, - (short) i, - i % 2 == 0, - (float) i, - (double) i, - String.valueOf(i), - "null", - daysSinceEpoch(Date.valueOf("2021-09-08")), - new BigDecimal(i), - Arrays.asList( - row(arrayElemStructTypeOf(readSchema, "as_list_of_records"), i), - row(arrayElemStructTypeOf(readSchema, "as_list_of_records"), i), - row(arrayElemStructTypeOf(readSchema, "as_list_of_records"), i) - ), - row( - structTypeOf(readSchema, "as_nested_struct"), - String.valueOf(i), - String.valueOf(i), - row( - structTypeOf( - structTypeOf(readSchema, "as_nested_struct"), - "ac" - ), - i, - (long) i - ) - ), - String.valueOf(i) - ); - } - - builder = builder.addRow( - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - Arrays.asList( - row(arrayElemStructTypeOf(readSchema, "as_list_of_records"), 2), - row(arrayElemStructTypeOf(readSchema, "as_list_of_records"), 2), - row(arrayElemStructTypeOf(readSchema, "as_list_of_records"), 2) - ), - row( - structTypeOf(readSchema, "as_nested_struct"), - "2", - "2", - row( - structTypeOf( - structTypeOf(readSchema, "as_nested_struct"), - "ac" - ), - 2, - 2L - ) - ), - "2" - ); - - ColumnarBatch expData = builder.build(); - compareEqualUnorderd(expData, actualData); - } - - @Test - public void tableWithComplexArrayTypes() - throws Exception { - String tablePath = goldenTablePath("data-reader-array-complex-objects"); - Snapshot snapshot = snapshot(tablePath); - StructType readSchema = removeUnsupportedType(snapshot.getSchema(tableClient)); - - List actualData = readSnapshot(readSchema, snapshot); - - TestColumnBatchBuilder builder = DataBuilderUtils.builder(readSchema); - - for (int i = 0; i < 10; i++) { - final int index = i; - builder.addRow( - i, - Arrays.asList( - Arrays.asList( - Arrays.asList(i, i, i), - Arrays.asList(i, i, i) - ), - Arrays.asList( - Arrays.asList(i, i, i), - Arrays.asList(i, i, i) - ) - ), - Arrays.asList( - Arrays.asList( - 
Arrays.asList( - Arrays.asList(i, i, i), - Arrays.asList(i, i, i)), - Arrays.asList( - Arrays.asList(i, i, i), - Arrays.asList(i, i, i)) - ), - Arrays.asList( - Arrays.asList( - Arrays.asList(i, i, i), - Arrays.asList(i, i, i)), - Arrays.asList( - Arrays.asList(i, i, i), - Arrays.asList(i, i, i))) - ), - Arrays.asList( - new HashMap() { - { - put(String.valueOf(index), (long) index); - } - }, - new HashMap() { - { - put(String.valueOf(index), (long) index); - } - } - ), - Arrays.asList( - row(arrayElemStructTypeOf(readSchema, "list_of_records"), i), - row(arrayElemStructTypeOf(readSchema, "list_of_records"), i), - row(arrayElemStructTypeOf(readSchema, "list_of_records"), i) - ) - ); - } - - ColumnarBatch expData = builder.build(); - compareEqualUnorderd(expData, actualData); - } - - @Test - public void tableWithComplexMapTypes() - throws Exception { - String tablePath = goldenTablePath("data-reader-map"); - Snapshot snapshot = snapshot(tablePath); - StructType readSchema = new StructType() - .add("i", IntegerType.INSTANCE) - .add("a", new MapType(IntegerType.INSTANCE, IntegerType.INSTANCE, true)) - .add("b", new MapType(LongType.INSTANCE, ByteType.INSTANCE, true)) - .add("c", new MapType(ShortType.INSTANCE, BooleanType.INSTANCE, true)) - .add("d", new MapType(FloatType.INSTANCE, DoubleType.INSTANCE, true)) - .add("f", new MapType( - IntegerType.INSTANCE, - new ArrayType(new StructType().add("val", IntegerType.INSTANCE), true), - true) - ); - - List actualData = readSnapshot(readSchema, snapshot); - - TestColumnBatchBuilder builder = DataBuilderUtils.builder(readSchema); - - for (int i = 0; i < 10; i++) { - final int index = i; - builder.addRow( - i, - new HashMap() { - { - put(index, index); - } - }, - new HashMap() { - { - put((long) index, (byte) index); - } - }, - new HashMap() { - { - put((short) index, index % 2 == 0); - } - }, - new HashMap() { - { - put((float) index, (double) index); - } - }, - new HashMap() { - { - StructType elemType = new 
StructType().add("val", IntegerType.INSTANCE); - put( - index, - Arrays.asList( - row(elemType, index), - row(elemType, index), - row(elemType, index) - ) - ); - } - } - ); - } - - ColumnarBatch expData = builder.build(); - compareEqualUnorderd(expData, actualData); - } - @Test public void tableWithCheckpoint() throws Exception { diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java index 92e02c0df4c..2875c69d365 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java @@ -30,14 +30,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.ColumnarBatch; -import io.delta.kernel.data.Row; +import io.delta.kernel.data.*; import io.delta.kernel.types.*; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Tuple2; +import io.delta.kernel.utils.VectorUtils; import io.delta.kernel.defaults.utils.DefaultKernelTestUtils; import io.delta.kernel.defaults.internal.DefaultKernelUtils; @@ -343,12 +343,14 @@ private static void verifyRowFromAllTypesFile( boolean expIsNull = rowId % 25 == 0; if (expIsNull) { assertTrue(vector.isNullAt(batchWithIdx._2)); + assertNull(vector.getArray(batchWithIdx._2)); } else if (rowId % 29 == 0) { - assertEquals(Collections.emptyList(), vector.getArray(batchWithIdx._2)); + checkArrayValue(vector.getArray(batchWithIdx._2), IntegerType.INSTANCE, + Collections.emptyList()); } else { List expArray = Arrays.asList(rowId, null, rowId + 1); - List 
actArray = vector.getArray(batchWithIdx._2); - assertEquals(expArray, actArray); + checkArrayValue(vector.getArray(batchWithIdx._2), IntegerType.INSTANCE, + expArray); } break; } @@ -357,41 +359,50 @@ private static void verifyRowFromAllTypesFile( break; case "array_of_structs": { assertFalse(vector.isNullAt(batchWithIdx._2)); - List actArray = vector.getArray(batchWithIdx._2); - assertTrue(actArray.size() == 2); - Row item0 = actArray.get(0); + ArrayValue arrayValue = vector.getArray(batchWithIdx._2); + ColumnVector elementVector = arrayValue.getElements(); + assertEquals(2, arrayValue.getSize()); + assertEquals(2, elementVector.getSize()); + assertTrue(elementVector.getDataType() instanceof StructType); + Row item0 = elementVector.getStruct(0); assertEquals(rowId, item0.getLong(0)); - assertNull(actArray.get(1)); + assertTrue(elementVector.isNullAt(1)); break; } case "map_of_prims": { boolean expIsNull = rowId % 28 == 0; if (expIsNull) { assertTrue(vector.isNullAt(batchWithIdx._2)); + assertNull(vector.getMap(batchWithIdx._2)); } else if (rowId % 30 == 0) { - assertEquals(Collections.emptyMap(), vector.getMap(batchWithIdx._2)); + checkMapValue( + vector.getMap(batchWithIdx._2), + IntegerType.INSTANCE, + LongType.INSTANCE, + Collections.emptyMap() + ); } else { - Map actValue = vector.getMap(batchWithIdx._2); - assertTrue(actValue.size() == 2); - - // entry 0: key = rowId - Integer key0 = rowId; - Long actValue0 = actValue.get(key0); - Long expValue0 = (rowId % 29 == 0) ? null : (rowId + 2L); - assertEquals(expValue0, actValue0); - - // entry 1 - Integer key1 = (rowId % 27 != 0) ? (rowId + 2) : (rowId + 3); - Long actValue1 = actValue.get(key1); - Long expValue1 = rowId + 9L; - assertEquals(expValue1, actValue1); + Map expValue = new HashMap() { + { + put(rowId, (rowId % 29 == 0) ? null : (rowId + 2L)); + put((rowId % 27 != 0) ? 
(rowId + 2) : (rowId + 3), rowId + 9L); + + } + }; + checkMapValue( + vector.getMap(batchWithIdx._2), + IntegerType.INSTANCE, + LongType.INSTANCE, + expValue + ); } break; } case "map_of_rows": { // Map(i + 1 -> (if (i % 10 == 0) Row((i*20).longValue()) else null)) assertFalse(vector.isNullAt(batchWithIdx._2)); - Map actValue = vector.getMap(batchWithIdx._2); + MapValue mapValue = vector.getMap(batchWithIdx._2); + Map actValue = VectorUtils.toJavaMap(mapValue); // entry 0: key = rowId Integer key0 = rowId + 1; @@ -492,7 +503,8 @@ private static void validateArrayOfArraysColumn( expArray = Collections.emptyList(); break; } - assertEquals(expArray, vector.getArray(batchRowId)); + DataType expDataType = new ArrayType(IntegerType.INSTANCE, true); + checkArrayValue(vector.getArray(batchRowId), expDataType, expArray); } private static void validateMapOfArraysColumn( @@ -525,7 +537,12 @@ private static void validateMapOfArraysColumn( } }; } - assertEquals(expMap, vector.getMap(batchRowId)); + checkMapValue( + vector.getMap(batchRowId), + LongType.INSTANCE, + new ArrayType(IntegerType.INSTANCE, true), + expMap + ); } private static Tuple2 getBatchForRowId( @@ -540,4 +557,39 @@ private static Tuple2 getBatchForRowId( throw new IllegalArgumentException("row id is not found: " + rowId); } + + private static void checkArrayValue( + ArrayValue arrayValue, DataType expDataType, List expList) { + int size = expList.size(); + ColumnVector elementVector = arrayValue.getElements(); + // Check the size is as expected and arrayValue.getSize == elementVector.getSize + assertEquals(size, arrayValue.getSize()); + assertEquals(size, elementVector.getSize()); + // Check the element vector has the correct data type + assertEquals(elementVector.getDataType(), expDataType); + // Check the elements are correct + assertEquals(expList, VectorUtils.toJavaList(arrayValue)); + assertThrows(IllegalArgumentException.class, + () -> DefaultKernelTestUtils.getValueAsObject(elementVector, size + 1)); + } 
+ + private static void checkMapValue( + MapValue mapValue, DataType keyDataType, DataType valueDataType, Map expMap) { + int size = expMap.size(); + ColumnVector keyVector = mapValue.getKeys(); + ColumnVector valueVector = mapValue.getValues(); + // Check the size mapValue.getSize == keyVector.getSize == valueVector.getSize + assertEquals(size, mapValue.getSize()); + assertEquals(size, keyVector.getSize()); + assertEquals(size, valueVector.getSize()); + // Check the key and value vector has the correct data type + assertEquals(keyVector.getDataType(), keyDataType); + assertEquals(valueVector.getDataType(), valueDataType); + // Check the elements are correct + assertEquals(expMap, VectorUtils.toJavaMap(mapValue)); + assertThrows(IllegalArgumentException.class, + () -> DefaultKernelTestUtils.getValueAsObject(keyVector, size + 1)); + assertThrows(IllegalArgumentException.class, + () -> DefaultKernelTestUtils.getValueAsObject(valueVector, size + 1)); + } } diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/utils/DefaultKernelTestUtils.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/utils/DefaultKernelTestUtils.java index a42a9ae1527..03ef0bd2371 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/utils/DefaultKernelTestUtils.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/utils/DefaultKernelTestUtils.java @@ -15,6 +15,7 @@ */ package io.delta.kernel.defaults.utils; +import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.Row; import io.delta.kernel.types.*; @@ -23,14 +24,12 @@ private DefaultKernelTestUtils() {} /** * Returns a URI encoded path of the resource. 
- * - * @param resourcePath - * @return */ public static String getTestResourceFilePath(String resourcePath) { return DefaultKernelTestUtils.class.getClassLoader().getResource(resourcePath).getFile(); } + // This will no longer be needed once all tests have been moved to Scala public static Object getValueAsObject(Row row, int columnOrdinal) { // TODO: may be it is better to just provide a `getObject` on the `Row` to // avoid the nested if-else statements. @@ -60,10 +59,46 @@ public static Object getValueAsObject(Row row, int columnOrdinal) { return row.getBinary(columnOrdinal); } else if (dataType instanceof StructType) { return row.getStruct(columnOrdinal); - } else if (dataType instanceof MapType) { - return row.getMap(columnOrdinal); - } else if (dataType instanceof ArrayType) { - return row.getArray(columnOrdinal); + } + + throw new UnsupportedOperationException(dataType + " is not supported yet"); + } + + /** + * Get the value at given {@code rowId} from the column vector. The type of the value object + * depends on the data type of the {@code vector}. + */ + public static Object getValueAsObject(ColumnVector vector, int rowId) { + // TODO: may be it is better to just provide a `getObject` on the `ColumnVector` to + // avoid the nested if-else statements. 
+ final DataType dataType = vector.getDataType(); + + if (vector.isNullAt(rowId)) { + return null; + } + + if (dataType instanceof BooleanType) { + return vector.getBoolean(rowId); + } else if (dataType instanceof ByteType) { + return vector.getByte(rowId); + } else if (dataType instanceof ShortType) { + return vector.getShort(rowId); + } else if (dataType instanceof IntegerType || dataType instanceof DateType) { + return vector.getInt(rowId); + } else if (dataType instanceof LongType || dataType instanceof TimestampType) { + return vector.getLong(rowId); + } else if (dataType instanceof FloatType) { + return vector.getFloat(rowId); + } else if (dataType instanceof DoubleType) { + return vector.getDouble(rowId); + } else if (dataType instanceof StringType) { + return vector.getString(rowId); + } else if (dataType instanceof BinaryType) { + return vector.getBinary(rowId); + } else if (dataType instanceof StructType) { + return vector.getStruct(rowId); + } else if (dataType instanceof DecimalType) { + return vector.getDecimal(rowId); } throw new UnsupportedOperationException(dataType + " is not supported yet"); diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala index 3d483351009..273627cfc04 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala @@ -17,13 +17,18 @@ package io.delta.kernel.defaults import java.io.File import java.math.BigDecimal +import java.sql.Date + +import scala.collection.JavaConverters._ import org.scalatest.funsuite.AnyFunSuite import io.delta.golden.GoldenTableUtils.goldenTablePath +import org.apache.hadoop.shaded.org.apache.commons.io.FileUtils + import io.delta.kernel.{Table, TableNotFoundException} import io.delta.kernel.defaults.internal.DefaultKernelUtils 
import io.delta.kernel.defaults.utils.{TestRow, TestUtils} -import org.apache.hadoop.shaded.org.apache.commons.io.FileUtils +import io.delta.kernel.internal.util.InternalUtils.daysSinceEpoch class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { @@ -34,7 +39,7 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { // TODO: for now we do not support timestamp partition columns, make sure it's blocked test("cannot read partition column of timestamp type") { val path = goldenTablePath("kernel-timestamp-TIMESTAMP_MICROS") - val snapshot = latestSnapshot(path); + val snapshot = latestSnapshot(path) val e = intercept[UnsupportedOperationException] { readSnapshot(snapshot) // request entire schema @@ -180,4 +185,102 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { s"Table at path `file:${target.getCanonicalPath}` is not found")) } } + + test("read partitioned table") { + val path = "file:" + goldenTablePath("data-reader-partition-values") + + // for now we don't support timestamp type partition columns so remove from read columns + val readCols = Table.forPath(defaultTableClient, path).getLatestSnapshot(defaultTableClient) + .getSchema(defaultTableClient) + .withoutField("as_timestamp") + .fields() + .asScala + .map(_.getName) + + val expectedAnswer = Seq(0, 1).map { i => + TestRow( + i, + i.toLong, + i.toByte, + i.toShort, + i % 2 == 0, + i.toFloat, + i.toDouble, + i.toString, + "null", + daysSinceEpoch(Date.valueOf("2021-09-08")), + new BigDecimal(i), + Seq(TestRow(i), TestRow(i), TestRow(i)), + TestRow(i.toString, i.toString, TestRow(i, i.toLong)), + i.toString + ) + } ++ (TestRow( + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + Seq(TestRow(2), TestRow(2), TestRow(2)), + TestRow("2", "2", TestRow(2, 2L)), + "2" + ) :: Nil) + + checkTable( + path = path, + expectedAnswer = expectedAnswer, + readCols = readCols + ) + } + + test("table with complex array types") { + val path = "file:" + 
goldenTablePath("data-reader-array-complex-objects") + + val expectedAnswer = (0 until 10).map { i => + TestRow( + i, + Seq(Seq(Seq(i, i, i), Seq(i, i, i)), Seq(Seq(i, i, i), Seq(i, i, i))), + Seq( + Seq(Seq(Seq(i, i, i), Seq(i, i, i)), Seq(Seq(i, i, i), Seq(i, i, i))), + Seq(Seq(Seq(i, i, i), Seq(i, i, i)), Seq(Seq(i, i, i), Seq(i, i, i))) + ), + Seq( + Map[String, Long](i.toString -> i.toLong), + Map[String, Long](i.toString -> i.toLong) + ), + Seq(TestRow(i), TestRow(i), TestRow(i)) + ) + } + + checkTable( + path = path, + expectedAnswer = expectedAnswer + ) + } + + test("table with complex map types") { + val path = "file:" + goldenTablePath("data-reader-map") + + val expectedAnswer = (0 until 10).map { i => + TestRow( + i, + Map(i -> i), + Map(i.toLong -> i.toByte), + Map(i.toShort -> (i % 2 == 0)), + Map(i.toFloat -> i.toDouble), + Map(i.toString -> new BigDecimal(i)), + Map(i -> Seq(TestRow(i), TestRow(i), TestRow(i))) + ) + } + + checkTable( + path = path, + expectedAnswer = expectedAnswer + ) + } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala index b4766a029c7..3fd8ca03d6a 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluatorSuite.scala @@ -21,10 +21,11 @@ import java.sql.Date import java.util import java.util.Optional -import io.delta.kernel.data.{ColumnarBatch, ColumnVector} +import org.scalatest.funsuite.AnyFunSuite +import io.delta.kernel.data.{ColumnarBatch, ColumnVector, MapValue} import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch -import io.delta.kernel.defaults.internal.data.vector.{DefaultIntVector, DefaultMapVector, 
DefaultStructVector} -import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject +import io.delta.kernel.defaults.internal.data.vector.{DefaultIntVector, DefaultStructVector} +import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getValueAsObject import io.delta.kernel.defaults.utils.TestUtils import io.delta.kernel.expressions._ import io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE @@ -32,7 +33,6 @@ import io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE import io.delta.kernel.expressions.Literal._ import io.delta.kernel.internal.util.InternalUtils import io.delta.kernel.types._ -import org.scalatest.funsuite.AnyFunSuite class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBase { test("evaluate expression: literal") { @@ -397,25 +397,16 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa test("evaluate expression: element_at") { import scala.collection.JavaConverters._ val nullStr = null.asInstanceOf[String] - val testMapValues = Seq( - Map("k0" -> "v00", "k1" -> "v01", "k3" -> nullStr, nullStr -> "v04").asJava, - Map("k0" -> "v10", "k1" -> nullStr, "k3" -> "v13", nullStr -> "v14").asJava, - Map("k0" -> nullStr, "k1" -> "v21", "k3" -> "v23", nullStr -> "v24").asJava, + val testMapValues: Seq[Map[AnyRef, AnyRef]] = Seq( + Map("k0" -> "v00", "k1" -> "v01", "k3" -> nullStr, nullStr -> "v04"), + Map("k0" -> "v10", "k1" -> nullStr, "k3" -> "v13", nullStr -> "v14"), + Map("k0" -> nullStr, "k1" -> "v21", "k3" -> "v23", nullStr -> "v24"), null ) - val testMapVector = new ColumnVector { - override def getDataType: DataType = - new MapType(StringType.INSTANCE, StringType.INSTANCE, true /* valueContainsNull */) + val testMapVector = buildMapVector( + testMapValues, + new MapType(StringType.INSTANCE, StringType.INSTANCE, true)) - override def getSize: Int = testMapValues.size - - override def close(): Unit = {} - - override def isNullAt(rowId: Int): Boolean = testMapValues(rowId) == 
null - - override def getMap[K, V](rowId: Int): util.Map[K, V] = - testMapValues(rowId).asInstanceOf[util.Map[K, V]] - } val inputBatch = new DefaultColumnarBatch( testMapVector.getSize, new StructType().add("partitionValues", testMapVector.getDataType), @@ -424,7 +415,7 @@ class DefaultExpressionEvaluatorSuite extends AnyFunSuite with ExpressionSuiteBa Seq("k0", "k1", "k2", null).foreach { lookupKey => val expOutput = testMapValues.map(map => { if (map == null) null - else map.get(lookupKey) + else map.getOrElse(lookupKey, null) }) val lookupKeyExpr = if (lookupKey == null) { diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala index 7aad68e5a76..1d9bb9b6e3c 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ImplicitCastExpressionSuite.scala @@ -15,13 +15,14 @@ */ package io.delta.kernel.defaults.internal.expressions +import org.scalatest.funsuite.AnyFunSuite + import io.delta.kernel.data.ColumnVector -import io.delta.kernel.defaults.internal.data.vector.VectorUtils.getValueAsObject import io.delta.kernel.defaults.internal.expressions.ImplicitCastExpression.canCastTo +import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getValueAsObject import io.delta.kernel.defaults.utils.TestUtils import io.delta.kernel.expressions.Column import io.delta.kernel.types._ -import org.scalatest.funsuite.AnyFunSuite class ImplicitCastExpressionSuite extends AnyFunSuite with TestUtils { private val allowedCasts: Set[(DataType, DataType)] = Set( diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala 
b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 8c81f4ea3ff..680f2f60777 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -17,7 +17,7 @@ package io.delta.kernel.defaults.utils import scala.collection.JavaConverters._ -import io.delta.kernel.data.Row +import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} import io.delta.kernel.types._ /** @@ -34,15 +34,13 @@ import io.delta.kernel.types._ * - TimestampType --> long (number of microseconds since the unix epoch) * - DecimalType --> java.math.BigDecimal * - BinaryType --> Array[Byte] + * - ArrayType --> Seq[Any] + * - MapType --> Map[Any, Any] + * - StructType --> TestRow * - * TODO: complex types - * - StructType? - * - ArrayType? - * - MapType? + * For complex types array and map, the inner elements types should align with this mapping. */ class TestRow(val values: Array[Any]) { - // TODO: we could make this extend Row and create a way to generate Seq(Any) from Rows but it - // would complicate a lot of the code for not much benefit def length: Int = values.length @@ -99,16 +97,55 @@ object TestRow { case _: StringType => row.getString(i) case _: BinaryType => row.getBinary(i) case _: DecimalType => row.getDecimal(i) - - // TODO complex types - // case _: StructType => row.getStruct(i) - // case _: MapType => row.getMap(i) - // case _: ArrayType => row.getArray(i) + case _: ArrayType => arrayValueToScalaSeq(row.getArray(i)) + case _: MapType => mapValueToScalaMap(row.getMap(i)) + case _: StructType => TestRow(row.getStruct(i)) case _ => throw new UnsupportedOperationException("unrecognized data type") } }) } + /** + * Retrieves the value at `rowId` in the column vector as it's corresponding scala type. + * See the [[TestRow]] docs for details. 
+ */ + private def getAsTestObject(vector: ColumnVector, rowId: Int): Any = { + vector.getDataType match { + case _ if vector.isNullAt(rowId) => null + case _: BooleanType => vector.getBoolean(rowId) + case _: ByteType => vector.getByte(rowId) + case _: IntegerType => vector.getInt(rowId) + case _: LongType => vector.getLong(rowId) + case _: ShortType => vector.getShort(rowId) + case _: DateType => vector.getInt(rowId) + case _: TimestampType => vector.getLong(rowId) + case _: FloatType => vector.getFloat(rowId) + case _: DoubleType => vector.getDouble(rowId) + case _: StringType => vector.getString(rowId) + case _: BinaryType => vector.getBinary(rowId) + case _: DecimalType => vector.getDecimal(rowId) + case _: ArrayType => arrayValueToScalaSeq(vector.getArray(rowId)) + case _: MapType => mapValueToScalaMap(vector.getMap(rowId)) + case _: StructType => TestRow(vector.getStruct(rowId)) + case _ => throw new UnsupportedOperationException("unrecognized data type") + } + } + + private def arrayValueToScalaSeq(arrayValue: ArrayValue): Seq[Any] = { + val elemVector = arrayValue.getElements + (0 until arrayValue.getSize).map { i => + getAsTestObject(elemVector, i) + } + } + + private def mapValueToScalaMap(mapValue: MapValue): Map[Any, Any] = { + val keyVector = mapValue.getKeys() + val valueVector = mapValue.getValues() + (0 until mapValue.getSize).map { i => + getAsTestObject(keyVector, i) -> getAsTestObject(valueVector, i) + }.toMap + } + /** * Construct a [[TestRow]] from the given seq of values. See the docs for [[TestRow]] for * the scala type corresponding to each Kernel data type. 
diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 20fd7a7ee77..0633091e68a 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -24,8 +24,9 @@ import scala.collection.mutable.ArrayBuffer import io.delta.kernel.{Scan, Snapshot, Table} import io.delta.kernel.client.TableClient -import io.delta.kernel.data.Row +import io.delta.kernel.data.{ColumnVector, MapValue, Row} import io.delta.kernel.defaults.client.DefaultTableClient +import io.delta.kernel.defaults.internal.data.vector.DefaultGenericVector import io.delta.kernel.types._ import io.delta.kernel.utils.CloseableIterator import org.apache.hadoop.conf.Configuration @@ -300,4 +301,34 @@ trait TestUtils extends Assertions { FileUtils.deleteDirectory(tempDir) } } + + /** + * Builds a MapType ColumnVector from a sequence of maps. + */ + def buildMapVector(mapValues: Seq[Map[AnyRef, AnyRef]], dataType: MapType): ColumnVector = { + val keyType = dataType.getKeyType + val valueType = dataType.getValueType + + def getMapValue(map: Map[AnyRef, AnyRef]): MapValue = { + if (map == null) { + null + } else { + val (keys, values) = map.unzip + new MapValue() { + override def getSize: Int = map.size + + override def getKeys = new DefaultGenericVector( + keyType, keys.toArray) + + override def getValues = new DefaultGenericVector( + valueType, values.toArray) + } + } + } + + new DefaultGenericVector( + dataType, + mapValues.map(getMapValue).toArray + ) + } }