Spark 3.5: Parallelize reading files in snapshot and migrate procedures
manuzhang committed Apr 9, 2024
1 parent abf238a commit 6159067
Showing 11 changed files with 211 additions and 5 deletions.
10 changes: 10 additions & 0 deletions api/src/main/java/org/apache/iceberg/actions/MigrateTable.java
@@ -60,6 +60,16 @@ default MigrateTable backupTableName(String tableName) {
throw new UnsupportedOperationException("Backup table name cannot be specified");
}

/**
* Sets the number of threads to use for file reading. The default is 1.
*
* @param numThreads the number of threads
* @return this for method chaining
*/
default MigrateTable parallelism(int numThreads) {
throw new UnsupportedOperationException("Setting parallelism is not supported");
}

/** The action result that contains a summary of the execution. */
interface Result {
/** Returns the number of migrated data files. */
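To illustrate how the new option is meant to be consumed, here is a minimal usage sketch (assuming the existing `SparkActions` entry point and a running `SparkSession` named `spark`; the table name is a placeholder):

```java
// Hypothetical usage sketch: migrate an existing Spark table in place,
// reading its data files with 4 threads instead of the default single thread.
import org.apache.iceberg.actions.MigrateTable;
import org.apache.iceberg.spark.actions.SparkActions;

MigrateTable.Result result =
    SparkActions.get(spark)
        .migrateTable("db.sample")   // table to convert to Iceberg
        .parallelism(4)              // option added in this commit
        .execute();
System.out.println("Migrated data files: " + result.migratedDataFilesCount());
```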
10 changes: 10 additions & 0 deletions api/src/main/java/org/apache/iceberg/actions/SnapshotTable.java
@@ -57,6 +57,16 @@ public interface SnapshotTable extends Action<SnapshotTable, SnapshotTable.Result> {
*/
SnapshotTable tableProperty(String key, String value);

/**
* Sets the number of threads to use for file reading. The default is 1.
*
* @param numThreads the number of threads
* @return this for method chaining
*/
default SnapshotTable parallelism(int numThreads) {
throw new UnsupportedOperationException("Setting parallelism is not supported");
}

/** The action result that contains a summary of the execution. */
interface Result {
/** Returns the number of imported data files. */
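The snapshot action follows the same pattern; a minimal sketch under the same assumptions (`SparkActions` entry point, running `SparkSession` named `spark`, placeholder table names):

```java
// Hypothetical usage sketch: snapshot an existing Spark table into a new
// Iceberg table, reading the source data files with 4 threads.
import org.apache.iceberg.actions.SnapshotTable;
import org.apache.iceberg.spark.actions.SparkActions;

SnapshotTable.Result result =
    SparkActions.get(spark)
        .snapshotTable("db.source")  // existing Spark table
        .as("db.source_snapshot")    // new Iceberg table to create
        .parallelism(4)              // option added in this commit
        .execute();
System.out.println("Imported data files: " + result.importedDataFilesCount());
```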
4 changes: 3 additions & 1 deletion docs/docs/spark-procedures.md
@@ -546,6 +546,7 @@ See [`migrate`](#migrate) to replace an existing table with an Iceberg table.
| `table` | ✔️ | string | Name of the new Iceberg table to create |
| `location` | | string | Table location for the new table (delegated to the catalog by default) |
| `properties` | | map<string, string> | Properties to add to the newly created table |
| `parallelism` | | int | Number of threads to use for file reading (defaults to 1) |

#### Output

@@ -588,6 +589,7 @@ By default, the original table is retained with the name `table_BACKUP_`.
| `properties` | | map<string, string> | Properties for the new Iceberg table |
| `drop_backup` | | boolean | When true, the original table will not be retained as backup (defaults to false) |
| `backup_table_name` | | string | Name of the table that will be retained as backup (defaults to `table_BACKUP_`) |
| `parallelism` | | int | Number of threads to use for file reading (defaults to 1) |

#### Output

@@ -629,7 +631,7 @@ will then treat these files as if they are part of the set of files owned by Iceberg.
| `source_table` | ✔️ | string | Table where files should come from, paths are also possible in the form of \`file_format\`.\`path\` |
| `partition_filter` | | map<string, string> | A map of partitions in the source table to import from |
| `check_duplicate_files` | | boolean | Whether to prevent files existing in the table from being added (defaults to true) |
| `parallelism` | | int | number of threads to use for file reading (defaults to 1) |
| `parallelism` | | int | Number of threads to use for file reading (defaults to 1) |

Warning : Schema is not validated, adding files with different schema to the Iceberg table will cause issues.

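For the procedure form documented above, a short sketch of passing the new `parallelism` argument from Spark SQL (the catalog and table names are placeholders):

```java
// Hypothetical sketch: invoke the snapshot and migrate procedures with the
// new parallelism argument; "my_catalog" and the table names are placeholders.
spark.sql(
    "CALL my_catalog.system.snapshot("
        + "source_table => 'db.sample', table => 'db.sample_snap', parallelism => 4)");
spark.sql("CALL my_catalog.system.migrate(table => 'db.sample', parallelism => 4)");
```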
@@ -21,10 +21,15 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.UUID;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.iceberg.ImmutableGenericPartitionStatisticsFile;
import org.apache.iceberg.PartitionStatisticsFile;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.PositionOutputStream;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.jetbrains.annotations.NotNull;

public class ProcedureUtil {

@@ -51,4 +56,29 @@ static String statsFileLocation(String tableLocation) {
String statsFileName = "stats-file-" + UUID.randomUUID();
return tableLocation.replaceFirst("file:", "") + "/metadata/" + statsFileName;
}

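/**
 * Single-threaded executor used by the parallelism tests: it counts every task submitted to it so
 * a test can assert how many file-reading tasks were handed to the mocked executor service.
 */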
static class TestExecutorService extends ThreadPoolExecutor {

private int executedTasks = 0;

TestExecutorService() {
super(
1,
1,
0L,
TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(),
new ThreadFactoryBuilder().build());
}

@Override
public void execute(@NotNull Runnable task) {
super.execute(task);
executedTasks++;
}

public int getExecutedTasks() {
return executedTasks;
}
}
}
Expand Up @@ -20,19 +20,24 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assumptions.assumeThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mockStatic;

import java.io.IOException;
import java.nio.file.Files;
import java.util.Map;
import java.util.concurrent.ThreadPoolExecutor;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.spark.sql.AnalysisException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;

@ExtendWith(ParameterizedTestExtension.class)
public class TestMigrateTableProcedure extends ExtensionsTestBase {
@@ -232,4 +237,42 @@ public void testMigrateEmptyTable() throws Exception {
Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName);
assertThat(result).isEqualTo(0L);
}

@TestTemplate
public void testMigrateWithParallelism() throws IOException {
assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");

testWithParallelism(-1);
testWithParallelism(0);
testWithParallelism(1);
testWithParallelism(5);
}

private void testWithParallelism(int parallelism) throws IOException {
String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
tableName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);

try (MockedStatic<MoreExecutors> executors = mockStatic(MoreExecutors.class)) {
ProcedureUtil.TestExecutorService testService = new ProcedureUtil.TestExecutorService();
executors
.when(() -> MoreExecutors.getExitingExecutorService(any(ThreadPoolExecutor.class)))
.thenReturn(testService);

Object result =
scalarSql(
"CALL %s.system.migrate(table => '%s', parallelism => %d)",
catalogName, tableName, parallelism);
assertThat(result).as("Should have added two files").isEqualTo(2L);

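// The two file-read tasks are expected to reach the mocked executor only when parallelism > 1;
// otherwise the import runs on the calling thread and the counter stays at 0.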
assertThat(testService.getExecutedTasks()).isEqualTo(parallelism > 1 ? 2 : 0);
testService.shutdown();
}

sql("DROP TABLE IF EXISTS %s PURGE", tableName);
sql("DROP TABLE IF EXISTS %s_BACKUP_", tableName);
}
}
Expand Up @@ -20,19 +20,27 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assumptions.assumeThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;

import java.io.IOException;
import java.nio.file.Files;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadPoolExecutor;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.spark.sql.AnalysisException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;

@ExtendWith(ParameterizedTestExtension.class)
public class TestSnapshotTableProcedure extends ExtensionsTestBase {
@@ -223,4 +231,53 @@ public void testInvalidSnapshotsCases() throws IOException {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Cannot handle an empty identifier for argument table");
}

@TestTemplate
public void testSnapshotWithParallelism() throws IOException {
testWithParallelism(-1);
testWithParallelism(0);
testWithParallelism(1);
testWithParallelism(5);
}

private void testWithParallelism(int parallelism) throws IOException {
String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
sourceName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);

try (MockedStatic<MoreExecutors> executors = mockStatic(MoreExecutors.class)) {
ProcedureUtil.TestExecutorService testService = new ProcedureUtil.TestExecutorService();
executors
.when(() -> MoreExecutors.getExitingExecutorService(any(ThreadPoolExecutor.class)))
.thenReturn(testService);
executors
.when(
() ->
MoreExecutors.getExitingScheduledExecutorService(
any(ScheduledThreadPoolExecutor.class)))
.thenReturn(mock(ScheduledExecutorService.class));
Object result =
scalarSql(
"CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
catalogName, sourceName, tableName, parallelism);

assertThat(result).as("Should have added two files").isEqualTo(2L);
assertThat(testService.getExecutedTasks()).isEqualTo(parallelism > 1 ? 2 : 0);
}

Table createdTable = validationCatalog.loadTable(tableIdent);
String tableLocation = createdTable.location();
assertThat(tableLocation)
.as("Table should not have the original location")
.isNotEqualTo(location);

assertThat(sql("SELECT * FROM %s ORDER BY id", tableName))
.containsExactly(row(1L, "a"), row(2L, "b"));

sql("DROP TABLE IF EXISTS %s", tableName);
sql("DROP TABLE IF EXISTS %s PURGE", sourceName);
}
}
Expand Up @@ -388,6 +388,34 @@ public static void importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
* <p>The import uses the Spark session to get table metadata. It assumes no other operation is
* modifying the original or target table and is therefore not thread-safe.
*
* @param spark a Spark session
* @param sourceTableIdent an identifier of the source Spark table
* @param targetTable an Iceberg table where to import the data
* @param stagingDir a staging directory to store temporary manifest files
* @param parallelism number of threads to use for file reading
*/
public static void importSparkTable(
SparkSession spark,
TableIdentifier sourceTableIdent,
Table targetTable,
String stagingDir,
int parallelism) {
importSparkTable(
spark,
sourceTableIdent,
targetTable,
stagingDir,
Collections.emptyMap(),
false,
parallelism);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
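A minimal sketch of calling the new overload directly (a running `SparkSession`, an already-created target Iceberg table, and a staging directory are assumed; the identifiers are placeholders, and the two-argument `TableIdentifier` constructor is assumed to be callable from Java):

```java
// Hypothetical sketch: import the data files of an existing Spark table into
// an Iceberg table, reading file metadata with 8 threads.
import org.apache.iceberg.Table;
import org.apache.iceberg.spark.SparkTableUtil;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.TableIdentifier;
import scala.Option;

public class ImportExample {
  static void importWithParallelism(SparkSession spark, Table icebergTable, String stagingDir) {
    // Source Spark table "db.sample"; the name is a placeholder.
    TableIdentifier source = new TableIdentifier("sample", Option.apply("db"));
    // The last argument is the new file-reading parallelism added in this commit.
    SparkTableUtil.importSparkTable(spark, source, icebergTable, stagingDir, 8);
  }
}
```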
Expand Up @@ -59,6 +59,7 @@ public class MigrateTableSparkAction extends BaseTableCreationSparkAction<Migrat

private Identifier backupIdent;
private boolean dropBackup = false;
private int parallelism = 1;

MigrateTableSparkAction(
SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) {
@@ -108,6 +109,12 @@ public MigrateTableSparkAction backupTableName(String tableName) {
return this;
}

@Override
public MigrateTableSparkAction parallelism(int numThreads) {
this.parallelism = numThreads;
return this;
}

@Override
public MigrateTable.Result execute() {
String desc = String.format("Migrating table %s", destTableIdent().toString());
@@ -137,7 +144,8 @@ private MigrateTable.Result doExecute() {
TableIdentifier v1BackupIdent = new TableIdentifier(backupIdent.name(), backupNamespace);
String stagingLocation = getMetadataLocation(icebergTable);
LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1BackupIdent, icebergTable, stagingLocation);
SparkTableUtil.importSparkTable(
spark(), v1BackupIdent, icebergTable, stagingLocation, parallelism);

LOG.info("Committing staged changes to {}", destTableIdent());
stagedTable.commitStagedChanges();
Expand Up @@ -54,6 +54,7 @@ public class SnapshotTableSparkAction extends BaseTableCreationSparkAction<Snaps
private StagingTableCatalog destCatalog;
private Identifier destTableIdent;
private String destTableLocation = null;
private int parallelism = 1;

SnapshotTableSparkAction(
SparkSession spark, CatalogPlugin sourceCatalog, Identifier sourceTableIdent) {
@@ -98,6 +99,12 @@ public SnapshotTableSparkAction tableProperty(String property, String value) {
return this;
}

@Override
public SnapshotTable parallelism(int numThreads) {
this.parallelism = numThreads;
return this;
}

@Override
public SnapshotTable.Result execute() {
String desc = String.format("Snapshotting table %s as %s", sourceTableIdent(), destTableIdent);
@@ -126,7 +133,8 @@ private SnapshotTable.Result doExecute() {
TableIdentifier v1TableIdent = v1SourceTable().identifier();
String stagingLocation = getMetadataLocation(icebergTable);
LOG.info("Generating Iceberg metadata for {} in {}", destTableIdent(), stagingLocation);
SparkTableUtil.importSparkTable(spark(), v1TableIdent, icebergTable, stagingLocation);
SparkTableUtil.importSparkTable(
spark(), v1TableIdent, icebergTable, stagingLocation, parallelism);

LOG.info("Committing staged changes to {}", destTableIdent());
stagedTable.commitStagedChanges();
Expand Up @@ -40,7 +40,8 @@ class MigrateTableProcedure extends BaseProcedure {
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.optional("properties", STRING_MAP),
ProcedureParameter.optional("drop_backup", DataTypes.BooleanType),
ProcedureParameter.optional("backup_table_name", DataTypes.StringType)
ProcedureParameter.optional("backup_table_name", DataTypes.StringType),
ProcedureParameter.optional("parallelism", DataTypes.IntegerType)
};

private static final StructType OUTPUT_TYPE =
@@ -105,6 +106,10 @@ public InternalRow[] call(InternalRow args) {
migrateTableSparkAction = migrateTableSparkAction.backupTableName(backupTableName);
}

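// parallelism is the optional fifth procedure parameter (index 4); forward it only when the caller set it.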
if (!args.isNullAt(4)) {
migrateTableSparkAction = migrateTableSparkAction.parallelism(args.getInt(4));
}

MigrateTable.Result result = migrateTableSparkAction.execute();
return new InternalRow[] {newInternalRow(result.migratedDataFilesCount())};
}
Expand Up @@ -38,7 +38,8 @@ class SnapshotTableProcedure extends BaseProcedure {
ProcedureParameter.required("source_table", DataTypes.StringType),
ProcedureParameter.required("table", DataTypes.StringType),
ProcedureParameter.optional("location", DataTypes.StringType),
ProcedureParameter.optional("properties", STRING_MAP)
ProcedureParameter.optional("properties", STRING_MAP),
ProcedureParameter.optional("parallelism", DataTypes.IntegerType)
};

private static final StructType OUTPUT_TYPE =
@@ -102,6 +103,10 @@ public InternalRow[] call(InternalRow args) {
action.tableLocation(snapshotLocation);
}

if (!args.isNullAt(4)) {
action = action.parallelism(args.getInt(4));
}

SnapshotTable.Result result = action.tableProperties(properties).execute();
return new InternalRow[] {newInternalRow(result.importedDataFilesCount())};
}
