diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 57392cb0585..3889de84900 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -72,28 +72,6 @@ public static Map toJavaMap(MapValue mapValue) { return values; } - /** - * Creates an {@link ArrayValue} from list of object. - * - * @param values list of object - * @return an {@link ArrayValue} with the given values of type {@link StringType} - */ - public static ArrayValue buildArrayValue(List values, DataType dataType) { - if (values == null) { - return null; - } - return new ArrayValue() { - @Override - public int getSize() { - return values.size(); - } - - @Override - public ColumnVector getElements() { - return buildColumnVector(values, dataType); - } - }; - } /** * Creates a {@link MapValue} from map of string keys and string values. The type {@code @@ -127,12 +105,32 @@ public ColumnVector getValues() { }; } + /** + * Creates an {@link ArrayValue} from a list of objects; returns {@code null} if the list is null. + */ + public static ArrayValue buildArrayValue(List values, DataType dataType) { + if (values == null) { + return null; + } + return new ArrayValue() { + @Override + public int getSize() { + return values.size(); + } + + @Override + public ColumnVector getElements() { + return buildColumnVector(values, dataType); + } + }; + } + /** * Utility method to create a {@link ColumnVector} for given list of object, the object should be * primitive type or an Row instance. * * @param values list of strings - * @return a {@link ColumnVector} with the given values of type {@link StringType} + * @return a {@link ColumnVector} over the given values, read using the given {@code dataType}. 
*/ public static ColumnVector buildColumnVector(List values, DataType dataType) { return new ColumnVector() { @@ -153,169 +151,147 @@ public void close() { @Override public boolean isNullAt(int rowId) { - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + validateRowId(rowId); return values.get(rowId) == null; } @Override public boolean getBoolean(int rowId) { checkArgument(BooleanType.BOOLEAN.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Boolean); - return (Boolean) value; + return (Boolean) getValidatedValue(rowId, Boolean.class); } @Override public byte getByte(int rowId) { checkArgument(ByteType.BYTE.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Byte); - return (Byte) value; + return (Byte) getValidatedValue(rowId, Byte.class); } @Override public short getShort(int rowId) { checkArgument(ShortType.SHORT.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Short); - return (Short) value; + return (Short) getValidatedValue(rowId, Short.class); } @Override public int getInt(int rowId) { checkArgument(IntegerType.INTEGER.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Integer); - return (Integer) value; + return (Integer) getValidatedValue(rowId, Integer.class); } @Override public long getLong(int rowId) { checkArgument(LongType.LONG.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Long); - return (Long) value; + return (Long) 
getValidatedValue(rowId, Long.class); } @Override public float getFloat(int rowId) { checkArgument(FloatType.FLOAT.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Float); - return (Float) value; + return (Float) getValidatedValue(rowId, Float.class); } @Override public double getDouble(int rowId) { checkArgument(DoubleType.DOUBLE.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof Double); - return (Double) value; + return (Double) getValidatedValue(rowId, Double.class); } @Override public BigDecimal getDecimal(int rowId) { checkArgument(dataType instanceof DecimalType); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof BigDecimal); - return (BigDecimal) value; + return (BigDecimal) getValidatedValue(rowId, BigDecimal.class); } @Override public String getString(int rowId) { checkArgument(StringType.STRING.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof String); - return (String) value; + return (String) getValidatedValue(rowId, String.class); } @Override public byte[] getBinary(int rowId) { checkArgument(BinaryType.BINARY.equals(dataType)); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof byte[]); - return (byte[]) value; + return (byte[]) getValidatedValue(rowId, byte[].class); } @Override public ArrayValue getArray(int rowId) { - checkArgument(dataType instanceof ArrayValue); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - 
checkArgument(value instanceof ArrayValue); - return (ArrayValue) value; + checkArgument(dataType instanceof ArrayType); + return (ArrayValue) getValidatedValue(rowId, ArrayValue.class); } @Override public MapValue getMap(int rowId) { checkArgument(dataType instanceof MapType); - checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); - Object value = values.get(rowId); - checkArgument(value instanceof MapType); - return (MapValue) value; + return (MapValue) getValidatedValue(rowId, MapValue.class); } @Override public ColumnVector getChild(int ordinal) { checkArgument(dataType instanceof StructType); checkArgument(ordinal < ((StructType) dataType).length()); + DataType childDatatype = ((StructType) dataType).at(ordinal).getDataType(); - return buildColumnVector( - values.stream() - .map( - e -> { - checkArgument(e instanceof Row); - Row row = (Row) e; - if (row.isNullAt(ordinal)) { - return null; - } - if (childDatatype instanceof BooleanType) { - return row.getBoolean(ordinal); - } else if (childDatatype instanceof ByteType) { - return row.getByte(ordinal); - } else if (childDatatype instanceof ShortType) { - return row.getShort(ordinal); - } else if (childDatatype instanceof IntegerType - || childDatatype instanceof DateType) { - // DateType data is stored internally as the number of days since 1970-01-01 - return row.getInt(ordinal); - } else if (childDatatype instanceof LongType - || childDatatype instanceof TimestampType) { - // TimestampType data is stored internally as the number of microseconds - // since the unix epoch - return row.getLong(ordinal); - } else if (childDatatype instanceof FloatType) { - return row.getFloat(ordinal); - } else if (childDatatype instanceof DoubleType) { - return row.getDouble(ordinal); - } else if (childDatatype instanceof StringType) { - return row.getString(ordinal); - } else if (childDatatype instanceof BinaryType) { - return row.getBinary(ordinal); - } else if (childDatatype instanceof StructType) { 
- return row.getStruct(ordinal); - } else if (childDatatype instanceof DecimalType) { - return row.getDecimal(ordinal); - } else if (childDatatype instanceof ArrayType) { - return row.getArray(ordinal); - } else if (dataType instanceof MapType) { - return row.getMap(ordinal); - } else { - throw new UnsupportedOperationException("unsupported data type"); - } - }) - .collect(Collectors.toList()), - childDatatype); + List childValues = extractChildValues(ordinal, childDatatype); + + return buildColumnVector(childValues, childDatatype); + } + + private void validateRowId(int rowId) { + checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId); + } + + private Object getValidatedValue(int rowId, Class expectedType) { + validateRowId(rowId); + Object value = values.get(rowId); + checkArgument(expectedType.isInstance(value), + "Value must be of type %s", expectedType.getSimpleName()); + return value; + } + + + private List extractChildValues(int ordinal, DataType childDatatype) { + return values.stream() + .map(e -> extractChildValue(e, ordinal, childDatatype)) + .collect(Collectors.toList()); + } + + private Object extractChildValue(Object element, int ordinal, DataType childDatatype) { + checkArgument(element instanceof Row); + Row row = (Row) element; + + if (row.isNullAt(ordinal)) { + return null; + } + + return extractTypedValue(row, ordinal, childDatatype); + } + + private Object extractTypedValue(Row row, int ordinal, DataType childDatatype) { + // Primitive Types + if (childDatatype instanceof BooleanType) return row.getBoolean(ordinal); + if (childDatatype instanceof ByteType) return row.getByte(ordinal); + if (childDatatype instanceof ShortType) return row.getShort(ordinal); + if (childDatatype instanceof IntegerType || + childDatatype instanceof DateType) return row.getInt(ordinal); + if (childDatatype instanceof LongType || + childDatatype instanceof TimestampType) return row.getLong(ordinal); + if (childDatatype instanceof FloatType) 
return row.getFloat(ordinal); + if (childDatatype instanceof DoubleType) return row.getDouble(ordinal); + + // String, binary and decimal types + if (childDatatype instanceof StringType) return row.getString(ordinal); + if (childDatatype instanceof BinaryType) return row.getBinary(ordinal); + if (childDatatype instanceof DecimalType) return row.getDecimal(ordinal); + + // Nested Types + if (childDatatype instanceof StructType) return row.getStruct(ordinal); + if (childDatatype instanceof ArrayType) return row.getArray(ordinal); + if (childDatatype instanceof MapType) return row.getMap(ordinal); + + throw new UnsupportedOperationException( + String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName())); } }; }