Skip to content

Commit

Permalink
[Kernel] Add utility to filter for columns in a schema (#4151)
Browse files Browse the repository at this point in the history
## Description
This utility method can be used to filter a schema for columns that have invariants or
a certain data type (timestamp_ntz, variant, etc.), in order to implicitly
identify the table features that should be enabled.

## How was this patch tested?
Unit tests.
  • Loading branch information
vkorukanti authored Feb 14, 2025
1 parent 0b818bf commit dfc50d6
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import io.delta.kernel.internal.actions.Metadata;
import io.delta.kernel.internal.actions.Protocol;
import io.delta.kernel.internal.util.SchemaUtils;
import io.delta.kernel.internal.util.Tuple2;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructType;
import java.util.*;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -260,15 +262,6 @@ private static boolean metadataRequiresWriterFeatureToBeEnabled(
}
}

/**
 * Throws if any top-level field of the given schema declares a column invariant
 * (metadata key {@code "delta.invariants"}). Nested fields are not inspected.
 *
 * @param tableSchema schema of the table to validate
 */
private static void validateNoInvariants(StructType tableSchema) {
  boolean anyFieldHasInvariant =
      tableSchema.fields().stream()
          .map(field -> field.getMetadata())
          .anyMatch(metadata -> metadata.contains("delta.invariants"));
  if (anyFieldHasInvariant) {
    throw columnInvariantsNotSupported();
  }
}

private static boolean isWriterFeatureSupported(Protocol protocol, String featureName) {
List<String> writerFeatures = protocol.getWriterFeatures();
if (writerFeatures == null) {
Expand All @@ -277,4 +270,32 @@ private static boolean isWriterFeatureSupported(Protocol protocol, String featur
return writerFeatures.contains(featureName)
&& protocol.getMinWriterVersion() >= TABLE_FEATURES_MIN_WRITER_VERSION;
}

/**
 * Validates that the given table schema declares no column invariants anywhere;
 * throws {@code columnInvariantsNotSupported} otherwise.
 *
 * @param tableSchema schema of the table to validate
 */
private static void validateNoInvariants(StructType tableSchema) {
  // Guard-clause form: bail out early when the schema is invariant-free.
  if (!hasInvariants(tableSchema)) {
    return;
  }
  throw DeltaErrors.columnInvariantsNotSupported();
}

/**
 * Returns whether any field in the schema carries the {@code "delta.invariants"}
 * metadata key.
 *
 * @param tableSchema schema to scan
 * @return true if at least one field declares a column invariant
 */
static boolean hasInvariants(StructType tableSchema) {
  // Invariants/constraints are not allowed inside map or array element types,
  // so the search does not recurse into them. Stop at the first hit.
  return SchemaUtils.filterRecursively(
              tableSchema,
              /* recurseIntoMapOrArrayElements = */ false,
              /* stopOnFirstMatch = */ true,
              field -> field.getMetadata().contains("delta.invariants"))
          .size()
      > 0;
}

/**
 * Returns whether the table schema contains at least one column of the given data
 * type, recursing into struct, map, and array types. Caution: works only for the
 * primitive types, since matching relies on exact {@code DataType} equality.
 *
 * @param tableSchema schema to scan
 * @param type data type to look for
 * @return true if a column of {@code type} exists anywhere in the schema
 */
static boolean hasTypeColumn(StructType tableSchema, DataType type) {
  // Stop at the first matching column; only existence matters here.
  return SchemaUtils.filterRecursively(
              tableSchema,
              /* recurseIntoMapOrArrayElements = */ true,
              /* stopOnFirstMatch = */ true,
              field -> type.equals(field.getDataType()))
          .size()
      > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.delta.kernel.internal.DeltaErrors;
import io.delta.kernel.types.*;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -177,6 +178,30 @@ public static int findColIndex(StructType schema, String colName) {
return -1;
}

/**
 * Finds all {@code StructField}s that satisfy the given check {@code f}. For every match,
 * returns the path to the column together with the matching field.
 *
 * @param schema the data type (typically a {@code StructType}) to search
 * @param recurseIntoMapOrArrayElements whether to recurse into the element types of
 *     {@code ArrayType} and the key/value types of {@code MapType}
 * @param stopOnFirstMatch if true, stop the search as soon as the first match is found
 * @param f the check applied to each {@code StructField}
 * @return a list of pairs, each containing a list of strings (the column path) and a
 *     {@code StructField}. If {@code stopOnFirstMatch} is true, the list contains at most one
 *     element.
 */
public static List<Tuple2<List<String>, StructField>> filterRecursively(
    DataType schema,
    boolean recurseIntoMapOrArrayElements,
    boolean stopOnFirstMatch,
    Function<StructField, Boolean> f) {
  // Start the recursion with an empty path (i.e. at the schema root).
  return recurseIntoComplexTypes(
      schema, new ArrayList<>(), recurseIntoMapOrArrayElements, stopOnFirstMatch, f);
}

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Private methods ///
/////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Returns all column names in this schema as a flat list, with nested names joined by dots
 * and any literal dot inside a single name escaped. For example, a schema with a top-level
 * column {@code a} and a struct {@code s} containing {@code b} yields
 * {@code [a, s, s.b]}.
 *
 * @param schema schema whose field names are flattened
 * @return dot-separated, dot-escaped names of every field at every nesting level
 */
private static List<String> flattenNestedFieldNames(StructType schema) {
  // Match every field (the filter is always true) and also descend into
  // map/array element types so their nested struct fields are included.
  List<Tuple2<List<String>, StructField>> columnPathToStructFields =
      filterRecursively(
          schema,
          true /* recurseIntoMapOrArrayElements */,
          false /* stopOnFirstMatch */,
          sf -> true);

  return columnPathToStructFields.stream()
      .map(t -> t._1)
      .map(SchemaUtils::concatWithDot)
      .collect(Collectors.toList());
}

private static List<String> flattenNestedFieldNamesRecursive(String prefix, DataType type) {
List<String> fieldNames = new ArrayList<>();
private static List<Tuple2<List<String>, StructField>> recurseIntoComplexTypes(
DataType type,
List<String> columnPath,
boolean recurseIntoMapOrArrayElements,
boolean stopOnFirstMatch,
Function<StructField, Boolean> f) {
List<Tuple2<List<String>, StructField>> filtered = new ArrayList<>();

if (type instanceof StructType) {
for (StructField field : ((StructType) type).fields()) {
String escapedName = escapeDots(field.getName());
fieldNames.add(prefix + "." + escapedName);
fieldNames.addAll(
flattenNestedFieldNamesRecursive(prefix + "." + escapedName, field.getDataType()));
StructType s = (StructType) type;
for (StructField sf : s.fields()) {
List<String> newColumnPath = new ArrayList<>(columnPath);
newColumnPath.add(sf.getName());

if (f.apply(sf)) {
filtered.add(new Tuple2<>(newColumnPath, sf));
if (stopOnFirstMatch) {
return filtered;
}
}

filtered.addAll(
recurseIntoComplexTypes(
sf.getDataType(),
newColumnPath,
recurseIntoMapOrArrayElements,
stopOnFirstMatch,
f));

if (stopOnFirstMatch && !filtered.isEmpty()) {
return filtered;
}
}
} else if (type instanceof ArrayType) {
fieldNames.addAll(
flattenNestedFieldNamesRecursive(
prefix + ".element", ((ArrayType) type).getElementType()));
} else if (type instanceof MapType) {
MapType mapType = (MapType) type;
fieldNames.addAll(flattenNestedFieldNamesRecursive(prefix + ".key", mapType.getKeyType()));
fieldNames.addAll(
flattenNestedFieldNamesRecursive(prefix + ".value", mapType.getValueType()));
} else if (type instanceof ArrayType && recurseIntoMapOrArrayElements) {
ArrayType a = (ArrayType) type;
List<String> newColumnPath = new ArrayList<>(columnPath);
newColumnPath.add("element");
return recurseIntoComplexTypes(
a.getElementType(), newColumnPath, recurseIntoMapOrArrayElements, stopOnFirstMatch, f);
} else if (type instanceof MapType && recurseIntoMapOrArrayElements) {
MapType m = (MapType) type;
List<String> keyColumnPath = new ArrayList<>(columnPath);
keyColumnPath.add("key");
List<String> valueColumnPath = new ArrayList<>(columnPath);
valueColumnPath.add("value");
filtered.addAll(
recurseIntoComplexTypes(
m.getKeyType(), keyColumnPath, recurseIntoMapOrArrayElements, stopOnFirstMatch, f));
if (stopOnFirstMatch && !filtered.isEmpty()) {
return filtered;
}
filtered.addAll(
recurseIntoComplexTypes(
m.getValueType(),
valueColumnPath,
recurseIntoMapOrArrayElements,
stopOnFirstMatch,
f));
}
return fieldNames;

return filtered;
}

/** Builds a column name by joining the (dot-escaped) column-path elements with dots. */
private static String concatWithDot(List<String> columnPath) {
  List<String> escapedParts = new ArrayList<>();
  for (String part : columnPath) {
    escapedParts.add(escapeDots(part));
  }
  return String.join(".", escapedParts);
}

private static String escapeDots(String name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
package io.delta.kernel.internal.util

import io.delta.kernel.exceptions.KernelException
import io.delta.kernel.internal.util.SchemaUtils.validateSchema
import io.delta.kernel.internal.util.SchemaUtils.{filterRecursively, validateSchema}
import io.delta.kernel.types.IntegerType.INTEGER
import io.delta.kernel.types.{ArrayType, MapType, StringType, StructType}
import io.delta.kernel.types.LongType.LONG
import io.delta.kernel.types.TimestampType.TIMESTAMP
import io.delta.kernel.types.{ArrayType, MapType, StringType, StructField, StructType}
import org.scalatest.funsuite.AnyFunSuite

import java.util.Locale
import scala.collection.JavaConverters._

class SchemaUtilsSuite extends AnyFunSuite {
private def expectFailure(shouldContain: String*)(f: => Unit): Unit = {
Expand Down Expand Up @@ -273,4 +276,97 @@ class SchemaUtilsSuite extends AnyFunSuite {
}
}
}

///////////////////////////////////////////////////////////////////////////
// filterRecursively
///////////////////////////////////////////////////////////////////////////
// Schema under test for filterRecursively. It mixes top-level primitives with
// nested structs, maps, and arrays, and scatters TIMESTAMP fields at several
// depths (struct, map key/value, array element) so the tests can exercise both
// recursion modes and the "element"/"key"/"value" path segments.
val testSchema = new StructType()
.add("a", INTEGER)
.add("b", INTEGER)
.add("c", LONG)
.add("s", new StructType()
.add("a", TIMESTAMP)
.add("e", INTEGER)
.add("f", LONG)
.add("g", new StructType()
.add("a", INTEGER)
.add("b", TIMESTAMP)
.add("c", LONG)
).add("h", new MapType(
new StructType().add("a", TIMESTAMP),
new StructType().add("b", INTEGER),
true)
).add("i", new ArrayType(
new StructType().add("b", TIMESTAMP),
true)
)
).add("d", new MapType(
new StructType().add("b", TIMESTAMP),
new StructType().add("a", INTEGER),
true)
).add("e", new ArrayType(
new StructType()
.add("f", TIMESTAMP)
.add("b", INTEGER),
true)
)
// Every column of testSchema keyed by its dot-joined path (e.g. "s.g.b"),
// built by matching all fields while recursing into map/array element types.
// Used below to look up the expected StructField for each expected path.
val flattenedTestSchema = {
SchemaUtils.filterRecursively(
testSchema,
/* recurseIntoMapOrArrayElements = */ true,
/* stopOnFirstMatch = */ false,
(v1: StructField) => true
).asScala.map(f => f._1.asScala.mkString(".") -> f._2).toMap
}
// Table-driven tests: each row generates one test case for filterRecursively,
// pairing a filter predicate and traversal flags with the column paths it is
// expected to match in testSchema.
Seq(
// Format: (testPrefix, visitListMapTypes, stopOnFirstMatch, filter, expectedColumns)
("Filter by name 'b', stop on first match",
true, true, (field: StructField) => field.getName == "b", Seq("b")),
("Filter by name 'b', visit all matches",
false, false, (field: StructField) => field.getName == "b",
Seq("b", "s.g.b")),
("Filter by name 'b', visit all matches including nested structures",
true, false, (field: StructField) => field.getName == "b",
Seq(
"b",
"s.g.b",
"s.h.value.b",
"s.i.element.b",
"d.key.b",
"e.element.b"
)),
("Filter by TIMESTAMP type, stop on first match",
false, true, (field: StructField) => field.getDataType == TIMESTAMP,
Seq("s.a")),
("Filter by TIMESTAMP type, visit all matches including nested structures",
true, false, (field: StructField) => field.getDataType == TIMESTAMP,
Seq(
"s.a",
"s.g.b",
"s.h.key.a",
"s.i.element.b",
"d.key.b",
"e.element.f"
)),
("Filter by TIMESTAMP type and name 'f', visit all matches", true, false,
(field: StructField) => field.getDataType == TIMESTAMP && field.getName == "f",
Seq("e.element.f")),
("Filter by non-existent field name 'z'",
true, false, (field: StructField) => field.getName == "z", Seq())
).foreach {
case (testDescription, visitListMapTypes, stopOnFirstMatch, filter, expectedColumns) =>
test(s"filterRecursively - $testDescription | " +
s"visitListMapTypes=$visitListMapTypes, stopOnFirstMatch=$stopOnFirstMatch") {

val results =
filterRecursively(testSchema, visitListMapTypes, stopOnFirstMatch,
(v1: StructField) => filter(v1))
// convert to map of column path concatenated with '.' and the StructField
.asScala.map(f => (f._1.asScala.mkString("."), f._2)).toMap

// Assert that the number of results matches the expected columns
assert(results.size === expectedColumns.size)
// Each returned (path -> field) pair must equal the corresponding entry in
// the fully-flattened schema, restricted to the expected paths.
assert(results === flattenedTestSchema.filterKeys(expectedColumns.contains))
}
}
}

0 comments on commit dfc50d6

Please sign in to comment.