Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
huan233usc committed Feb 14, 2025
1 parent 2d4dcf7 commit 3f535f7
Showing 1 changed file with 95 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,28 +72,6 @@ public static <K, V> Map<K, V> toJavaMap(MapValue mapValue) {
return values;
}

/**
* Creates an {@link ArrayValue} from list of object.
*
* @param values list of object
* @return an {@link ArrayValue} with the given values of type {@link StringType}
*/
public static ArrayValue buildArrayValue(List<?> values, DataType dataType) {
if (values == null) {
return null;
}
return new ArrayValue() {
@Override
public int getSize() {
return values.size();
}

@Override
public ColumnVector getElements() {
return buildColumnVector(values, dataType);
}
};
}

/**
* Creates a {@link MapValue} from map of string keys and string values. The type {@code
Expand Down Expand Up @@ -127,12 +105,32 @@ public ColumnVector getValues() {
};
}

/**
* Creates an {@link ArrayValue} from list of objects.
*/
public static ArrayValue buildArrayValue(List<?> values, DataType dataType) {
if (values == null) {
return null;
}
return new ArrayValue() {
@Override
public int getSize() {
return values.size();
}

@Override
public ColumnVector getElements() {
return buildColumnVector(values, dataType);
}
};
}

/**
* Utility method to create a {@link ColumnVector} for given list of object, the object should be
* primitive type or an Row instance.
*
* @param values list of strings
* @return a {@link ColumnVector} with the given values of type {@link StringType}
* @return a {@link ColumnVector} with the given values type.
*/
public static ColumnVector buildColumnVector(List<?> values, DataType dataType) {
return new ColumnVector() {
Expand All @@ -153,169 +151,147 @@ public void close() {

@Override
public boolean isNullAt(int rowId) {
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
validateRowId(rowId);
return values.get(rowId) == null;
}

@Override
public boolean getBoolean(int rowId) {
checkArgument(BooleanType.BOOLEAN.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof Boolean);
return (Boolean) value;
return (Boolean) getValidatedValue(rowId, Boolean.class);
}

@Override
public byte getByte(int rowId) {
checkArgument(ByteType.BYTE.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof Byte);
return (Byte) value;
return (Byte) getValidatedValue(rowId, Byte.class);
}

@Override
public short getShort(int rowId) {
checkArgument(ShortType.SHORT.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof Short);
return (Short) value;
return (Short) getValidatedValue(rowId, Short.class);
}

@Override
public int getInt(int rowId) {
checkArgument(IntegerType.INTEGER.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof Integer);
return (Integer) value;
return (Integer) getValidatedValue(rowId, Integer.class);
}

@Override
public long getLong(int rowId) {
checkArgument(LongType.LONG.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof Long);
return (Long) value;
return (Long) getValidatedValue(rowId, Long.class);
}

@Override
public float getFloat(int rowId) {
checkArgument(FloatType.FLOAT.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof Float);
return (Float) value;
return (Float) getValidatedValue(rowId, Float.class);
}

@Override
public double getDouble(int rowId) {
checkArgument(DoubleType.DOUBLE.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof Double);
return (Double) value;
return (Double) getValidatedValue(rowId, Double.class);
}

@Override
public BigDecimal getDecimal(int rowId) {
checkArgument(dataType instanceof DecimalType);
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof BigDecimal);
return (BigDecimal) value;
return (BigDecimal) getValidatedValue(rowId, BigDecimal.class);
}

@Override
public String getString(int rowId) {
checkArgument(StringType.STRING.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof String);
return (String) value;
return (String) getValidatedValue(rowId, String.class);
}

@Override
public byte[] getBinary(int rowId) {
checkArgument(BinaryType.BINARY.equals(dataType));
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof byte[]);
return (byte[]) value;
return (byte[]) getValidatedValue(rowId, byte[].class);
}

@Override
public ArrayValue getArray(int rowId) {
checkArgument(dataType instanceof ArrayValue);
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof ArrayValue);
return (ArrayValue) value;
checkArgument(dataType instanceof ArrayType);
return (ArrayValue) getValidatedValue(rowId, ArrayValue.class);
}

@Override
public MapValue getMap(int rowId) {
checkArgument(dataType instanceof MapType);
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
Object value = values.get(rowId);
checkArgument(value instanceof MapType);
return (MapValue) value;
return (MapValue) getValidatedValue(rowId, MapValue.class);
}

@Override
public ColumnVector getChild(int ordinal) {
checkArgument(dataType instanceof StructType);
checkArgument(ordinal < ((StructType) dataType).length());

DataType childDatatype = ((StructType) dataType).at(ordinal).getDataType();
return buildColumnVector(
values.stream()
.map(
e -> {
checkArgument(e instanceof Row);
Row row = (Row) e;
if (row.isNullAt(ordinal)) {
return null;
}
if (childDatatype instanceof BooleanType) {
return row.getBoolean(ordinal);
} else if (childDatatype instanceof ByteType) {
return row.getByte(ordinal);
} else if (childDatatype instanceof ShortType) {
return row.getShort(ordinal);
} else if (childDatatype instanceof IntegerType
|| childDatatype instanceof DateType) {
// DateType data is stored internally as the number of days since 1970-01-01
return row.getInt(ordinal);
} else if (childDatatype instanceof LongType
|| childDatatype instanceof TimestampType) {
// TimestampType data is stored internally as the number of microseconds
// since the unix epoch
return row.getLong(ordinal);
} else if (childDatatype instanceof FloatType) {
return row.getFloat(ordinal);
} else if (childDatatype instanceof DoubleType) {
return row.getDouble(ordinal);
} else if (childDatatype instanceof StringType) {
return row.getString(ordinal);
} else if (childDatatype instanceof BinaryType) {
return row.getBinary(ordinal);
} else if (childDatatype instanceof StructType) {
return row.getStruct(ordinal);
} else if (childDatatype instanceof DecimalType) {
return row.getDecimal(ordinal);
} else if (childDatatype instanceof ArrayType) {
return row.getArray(ordinal);
} else if (dataType instanceof MapType) {
return row.getMap(ordinal);
} else {
throw new UnsupportedOperationException("unsupported data type");
}
})
.collect(Collectors.toList()),
childDatatype);
List<?> childValues = extractChildValues(ordinal, childDatatype);

return buildColumnVector(childValues, childDatatype);
}

private void validateRowId(int rowId) {
checkArgument(rowId >= 0 && rowId < values.size(), "Invalid rowId: %s", rowId);
}

private Object getValidatedValue(int rowId, Class<?> expectedType) {
validateRowId(rowId);
Object value = values.get(rowId);
checkArgument(expectedType.isInstance(value),
"Value must be of type %s", expectedType.getSimpleName());
return value;
}


private List<?> extractChildValues(int ordinal, DataType childDatatype) {
return values.stream()
.map(e -> extractChildValue(e, ordinal, childDatatype))
.collect(Collectors.toList());
}

private Object extractChildValue(Object element, int ordinal, DataType childDatatype) {
checkArgument(element instanceof Row);
Row row = (Row) element;

if (row.isNullAt(ordinal)) {
return null;
}

return extractTypedValue(row, ordinal, childDatatype);
}

private Object extractTypedValue(Row row, int ordinal, DataType childDatatype) {
// Primitive Types
if (childDatatype instanceof BooleanType) return row.getBoolean(ordinal);
if (childDatatype instanceof ByteType) return row.getByte(ordinal);
if (childDatatype instanceof ShortType) return row.getShort(ordinal);
if (childDatatype instanceof IntegerType ||
childDatatype instanceof DateType) return row.getInt(ordinal);
if (childDatatype instanceof LongType ||
childDatatype instanceof TimestampType) return row.getLong(ordinal);
if (childDatatype instanceof FloatType) return row.getFloat(ordinal);
if (childDatatype instanceof DoubleType) return row.getDouble(ordinal);

// Complex Types
if (childDatatype instanceof StringType) return row.getString(ordinal);
if (childDatatype instanceof BinaryType) return row.getBinary(ordinal);
if (childDatatype instanceof DecimalType) return row.getDecimal(ordinal);

// Nested Types
if (childDatatype instanceof StructType) return row.getStruct(ordinal);
if (childDatatype instanceof ArrayType) return row.getArray(ordinal);
if (childDatatype instanceof MapType) return row.getMap(ordinal);

throw new UnsupportedOperationException(
String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName()));
}
};
}
Expand Down

0 comments on commit 3f535f7

Please sign in to comment.