
Spark 3.3: Support storage-partitioned joins #6371

Merged: 3 commits, Dec 24, 2022
Changes from 2 commits
4 changes: 4 additions & 0 deletions api/src/main/java/org/apache/iceberg/types/Types.java
@@ -554,6 +554,10 @@ public List<NestedField> fields() {
return lazyFieldList();
}

public boolean containsField(int id) {
return lazyFieldsById().containsKey(id);
}

public NestedField field(String name) {
return lazyFieldsByName().get(name);
}
27 changes: 21 additions & 6 deletions core/src/main/java/org/apache/iceberg/Partitioning.java
@@ -200,6 +200,16 @@ public Void alwaysNull(int fieldId, String sourceName, int sourceId) {
/**
* Builds a grouping key type considering all provided specs.
*
* @param specs one or many specs
* @return the constructed grouping key type
*/
public static StructType groupingKeyType(Collection<PartitionSpec> specs) {
return groupingKeyType(null, specs);
}

/**
* Builds a grouping key type considering the provided schema and specs.
*
* <p>A grouping key defines how data is split between files and consists of partition fields with
* non-void transforms that are present in each provided spec. Iceberg guarantees that records
* with different values for the grouping key are disjoint and are stored in separate files.
@@ -215,11 +225,15 @@ public Void alwaysNull(int fieldId, String sourceName, int sourceId) {
* that have the same field ID but use a void transform under the hood. Such fields cannot be part
* of the grouping key as void transforms always return null.
*
* <p>If the provided schema is not null, this method will only take into account partition fields
* on top of columns present in the schema. Otherwise, all partition fields will be considered.
*
* @param schema a schema specifying a set of source columns to consider (null to consider all)
* @param specs one or many specs
* @return the constructed grouping key type
*/
public static StructType groupingKeyType(Collection<PartitionSpec> specs) {
return buildPartitionProjectionType("grouping key", specs, commonActiveFieldIds(specs));
public static StructType groupingKeyType(Schema schema, Collection<PartitionSpec> specs) {
Review comment (author): We need to take into account the schema as we may project only particular columns.

return buildPartitionProjectionType("grouping key", specs, commonActiveFieldIds(schema, specs));
Review comment (reviewer): How do we limit the specs passed to this method to just the ones that are used by manifests that are scanned during planning? (I may answer this myself later, but I want to write the question down.)

Reply (author): It is determined based on scan tasks that match our filter in SparkPartitioningAwareScan.

Reply (author): If a table has multiple specs but we scan only tasks that belong to one spec, we should take into account only the one that is being queried.

}

/**
@@ -341,15 +355,15 @@ private static Set<Integer> allFieldIds(Collection<PartitionSpec> specs) {
}

// collects IDs of partition fields with non-void transforms that are present in each spec
private static Set<Integer> commonActiveFieldIds(Collection<PartitionSpec> specs) {
private static Set<Integer> commonActiveFieldIds(Schema schema, Collection<PartitionSpec> specs) {
Set<Integer> commonActiveFieldIds = Sets.newHashSet();

int specIndex = 0;
for (PartitionSpec spec : specs) {
if (specIndex == 0) {
commonActiveFieldIds.addAll(activeFieldIds(spec));
commonActiveFieldIds.addAll(activeFieldIds(schema, spec));
} else {
commonActiveFieldIds.retainAll(activeFieldIds(spec));
commonActiveFieldIds.retainAll(activeFieldIds(schema, spec));
}

specIndex++;
@@ -358,8 +372,9 @@ private static Set<Integer> commonActiveFieldIds(Collection<PartitionSpec> specs
return commonActiveFieldIds;
}

private static List<Integer> activeFieldIds(PartitionSpec spec) {
private static List<Integer> activeFieldIds(Schema schema, PartitionSpec spec) {
return spec.fields().stream()
.filter(field -> schema == null || schema.findField(field.sourceId()) != null)
.filter(field -> !isVoidTransform(field))
.map(PartitionField::fieldId)
.collect(Collectors.toList());
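To make the earlier discussion concrete, here is a minimal sketch of how the new overload can be combined with the spec-collection step described above: only the partition specs of the tasks that matched the scan filter are considered, and only partition fields on top of projected columns end up in the grouping key type. The helper class and method below are illustrative glue code, not the actual SparkPartitioningAwareScan implementation.

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Partitioning;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.types.Types.StructType;

class GroupingKeyTypeSketch {

  // Illustrative helper: derive the grouping key type for a scan from the specs of the
  // tasks that survived filtering and from the columns that were actually projected.
  static StructType groupingKeyType(Table table, List<String> projectedColumns, List<FileScanTask> tasks) {
    Schema projectedSchema = table.schema().select(projectedColumns);
    Set<PartitionSpec> specs = tasks.stream().map(FileScanTask::spec).collect(Collectors.toSet());
    return Partitioning.groupingKeyType(projectedSchema, specs);
  }
}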
13 changes: 13 additions & 0 deletions core/src/test/java/org/apache/iceberg/TestPartitioning.java
@@ -354,6 +354,19 @@ public void testGroupingKeyTypeWithEvolvedUnpartitionedSpec() {
Assert.assertEquals("Types must match", expectedType, actualType);
}

@Test
public void testGroupingKeyTypeWithProjectedSchema() {
TestTables.TestTable table =
TestTables.create(tableDir, "test", SCHEMA, BY_CATEGORY_DATA_SPEC, V1_FORMAT_VERSION);

Schema projectedSchema = table.schema().select("id", "data");
Review comment (reviewer): Ah, so is the idea to use the pushed projection to limit the fields used? I think that's a great idea!

Reply (author): Correct. Otherwise, we would break Spark because we could report non-projected columns.


StructType expectedType =
StructType.of(NestedField.optional(1001, "data", Types.StringType.get()));
StructType actualType = Partitioning.groupingKeyType(projectedSchema, table.specs().values());
Assert.assertEquals("Types must match", expectedType, actualType);
}

@Test
public void testGroupingKeyTypeWithIncompatibleSpecEvolution() {
TestTables.TestTable table =
(next file; path not shown)
@@ -404,15 +404,15 @@ public void testSparkTableAddDropPartitions() throws Exception {
assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)");

sql("ALTER TABLE %s ADD PARTITION FIELD truncate(data, 4)", tableName);
assertPartitioningEquals(sparkTable(), 2, "truncate(data, 4)");
assertPartitioningEquals(sparkTable(), 2, "truncate(4, data)");
Review comment (aokolnychyi, Dec 7, 2022): I had to change Spark3Util.toTransforms to match TruncateFunction, which expects width first.


sql("ALTER TABLE %s ADD PARTITION FIELD years(ts)", tableName);
assertPartitioningEquals(sparkTable(), 3, "years(ts)");

sql("ALTER TABLE %s DROP PARTITION FIELD years(ts)", tableName);
assertPartitioningEquals(sparkTable(), 2, "truncate(data, 4)");
assertPartitioningEquals(sparkTable(), 2, "truncate(4, data)");

sql("ALTER TABLE %s DROP PARTITION FIELD truncate(data, 4)", tableName);
sql("ALTER TABLE %s DROP PARTITION FIELD truncate(4, data)", tableName);
assertPartitioningEquals(sparkTable(), 1, "bucket(16, id)");

sql("ALTER TABLE %s DROP PARTITION FIELD shard", tableName);
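For reference, the assertion change above reflects that the reported truncate transform now renders its width argument before the column, matching Spark's TruncateFunction. A rough sketch using Spark's connector expression factories (Expressions.apply, literal, and column exist in Spark 3.3; the exact describe() output is an assumption):

import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.Transform;

class TruncateTransformSketch {
  public static void main(String[] args) {
    // Width first, matching Spark's TruncateFunction argument order.
    Transform truncate =
        Expressions.apply("truncate", Expressions.literal(4), Expressions.column("data"));
    // Expected to describe itself as "truncate(4, data)" rather than "truncate(data, 4)".
    System.out.println(truncate.describe());
  }
}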
(next file; path not shown)
@@ -32,14 +32,19 @@
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.internal.SQLConf;
import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Assume;
@@ -143,4 +148,32 @@ public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception
executorService.shutdown();
Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
}

@Test
public void testRuntimeFilteringWithReportedPartitioning() throws NoSuchTableException {
createAndInitPartitionedTable();

append(new Employee(1, "hr"), new Employee(3, "hr"));
append(new Employee(1, "hardware"), new Employee(2, "hardware"));

Map<String, String> sqlConf =
ImmutableMap.of(
SQLConf.V2_BUCKETING_ENABLED().key(),
"true",
SparkSQLProperties.PRESERVE_DATA_GROUPING,
"true");

withSQLConf(sqlConf, () -> sql("DELETE FROM %s WHERE id = 2", tableName));

Table table = validationCatalog.loadTable(tableIdent);
Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));

Snapshot currentSnapshot = table.currentSnapshot();
validateCopyOnWrite(currentSnapshot, "1", "1", "1");
Review comment (rdblue, Dec 20, 2022): Nit: it would be nice to not pass the counts as strings. Not that we should fix it in this PR, but maybe we can introduce a version of validateCopyOnWrite that accepts ints and calls the string version.

Reply (aokolnychyi, Dec 20, 2022): This is what we did historically to simplify the validation, as we compare the summary map (i.e., strings). I'll check.

Reply (aokolnychyi): Doing this properly would touch lots of existing places. I'll follow up to fix this separately.


assertEquals(
"Should have expected rows",
ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")),
sql("SELECT * FROM %s ORDER BY id, dep", tableName));
}
}
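As a possible follow-up to the review thread above (explicitly not part of this PR), an int-accepting overload could simply delegate to the existing string-based helper. A sketch, assuming the helper takes the snapshot plus three counts passed positionally as strings; the parameter names below are guesses based on the usage in these tests.

// Hypothetical convenience overload for the test base class; converts int counts to the
// strings expected by the existing validateCopyOnWrite(Snapshot, String, String, String).
protected void validateCopyOnWrite(
    Snapshot snapshot, int changedPartitionCount, int deletedDataFileCount, int addedDataFileCount) {
  validateCopyOnWrite(
      snapshot,
      String.valueOf(changedPartitionCount),
      String.valueOf(deletedDataFileCount),
      String.valueOf(addedDataFileCount));
}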
(next file; path not shown)
@@ -32,14 +32,18 @@
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.internal.SQLConf;
import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Assume;
@@ -148,4 +152,45 @@ public synchronized void testMergeWithConcurrentTableRefresh() throws Exception
executorService.shutdown();
Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
}

@Test
public void testRuntimeFilteringWithReportedPartitioning() {
createAndInitTable("id INT, dep STRING");
sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);

append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }");
append(
tableName,
"{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }");

createOrReplaceView("source", Collections.singletonList(2), Encoders.INT());

Map<String, String> sqlConf =
ImmutableMap.of(
SQLConf.V2_BUCKETING_ENABLED().key(),
"true",
SparkSQLProperties.PRESERVE_DATA_GROUPING,
"true");

withSQLConf(
sqlConf,
() ->
sql(
"MERGE INTO %s t USING source s "
+ "ON t.id == s.value "
+ "WHEN MATCHED THEN "
+ " UPDATE SET id = -1",
tableName));

Table table = validationCatalog.loadTable(tableIdent);
Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));

Snapshot currentSnapshot = table.currentSnapshot();
validateCopyOnWrite(currentSnapshot, "1", "1", "1");

assertEquals(
"Should have expected rows",
ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")),
sql("SELECT * FROM %s ORDER BY id, dep", tableName));
}
}
(next file; path not shown)
@@ -31,13 +31,17 @@
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.data.GenericRecord;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.iceberg.spark.Spark3Util;
import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.spark.sql.internal.SQLConf;
import org.assertj.core.api.Assertions;
import org.junit.Assert;
import org.junit.Assume;
@@ -140,4 +144,35 @@ public synchronized void testUpdateWithConcurrentTableRefresh() throws Exception
executorService.shutdown();
Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
}

@Test
public void testRuntimeFilteringWithReportedPartitioning() {
createAndInitTable("id INT, dep STRING");
sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);

append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }");
append(
tableName,
"{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }");

Map<String, String> sqlConf =
ImmutableMap.of(
SQLConf.V2_BUCKETING_ENABLED().key(),
"true",
SparkSQLProperties.PRESERVE_DATA_GROUPING,
"true");

withSQLConf(sqlConf, () -> sql("UPDATE %s SET id = cast('-1' AS INT) WHERE id = 2", tableName));

Table table = validationCatalog.loadTable(tableIdent);
Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));

Snapshot currentSnapshot = table.currentSnapshot();
validateCopyOnWrite(currentSnapshot, "1", "1", "1");

assertEquals(
"Should have expected rows",
ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")),
sql("SELECT * FROM %s ORDER BY id, dep", tableName));
}
}
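For context, the new tests turn on storage-partitioned joins through two settings: Spark's V2 bucketing flag and Iceberg's data-grouping preservation property. A minimal sketch of enabling them on a session follows; the literal property strings are assumed values of the SQLConf.V2_BUCKETING_ENABLED and SparkSQLProperties.PRESERVE_DATA_GROUPING constants referenced in the tests.

import org.apache.spark.sql.SparkSession;

class StoragePartitionedJoinSetup {
  public static void main(String[] args) {
    // Sketch only: the property keys below are assumed string values of the constants used in the tests.
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            .config("spark.sql.sources.v2.bucketing.enabled", "true") // SQLConf.V2_BUCKETING_ENABLED
            .config("spark.sql.iceberg.planning.preserve-data-grouping", "true") // SparkSQLProperties.PRESERVE_DATA_GROUPING
            .getOrCreate();

    // With both settings enabled and compatible partitioning reported on both sides,
    // Spark can plan joins (and runtime filtering for row-level operations) without a shuffle.
    spark.stop();
  }
}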