Skip to content

Commit

Permalink
Spark 3.5: Parallelize reading files in snapshot and migrate procedures
Browse files Browse the repository at this point in the history
  • Loading branch information
manuzhang committed Apr 17, 2024
1 parent abf238a commit aa3aef1
Show file tree
Hide file tree
Showing 12 changed files with 270 additions and 19 deletions.
10 changes: 10 additions & 0 deletions api/src/main/java/org/apache/iceberg/actions/MigrateTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ default MigrateTable backupTableName(String tableName) {
throw new UnsupportedOperationException("Backup table name cannot be specified");
}

/**
 * Configures how many threads are used to read files. The default is 1.
 *
 * <p>This default implementation does not accept the setting and always throws; implementations
 * that support parallel file reading are expected to override it.
 *
 * @param numThreads the number of threads
 * @return this for method chaining
 * @throws UnsupportedOperationException if the implementation does not override this method
 */
default MigrateTable parallelism(int numThreads) {
  String reason = "Setting parallelism is not supported";
  throw new UnsupportedOperationException(reason);
}

/** The action result that contains a summary of the execution. */
interface Result {
/** Returns the number of migrated data files. */
Expand Down
10 changes: 10 additions & 0 deletions api/src/main/java/org/apache/iceberg/actions/SnapshotTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ public interface SnapshotTable extends Action<SnapshotTable, SnapshotTable.Resul
*/
SnapshotTable tableProperty(String key, String value);

/**
 * Configures how many threads are used to read files. The default is 1.
 *
 * <p>This default implementation does not accept the setting and always throws; implementations
 * that support parallel file reading are expected to override it.
 *
 * @param numThreads the number of threads
 * @return this for method chaining
 * @throws UnsupportedOperationException if the implementation does not override this method
 */
default SnapshotTable parallelism(int numThreads) {
  String reason = "Setting parallelism is not supported";
  throw new UnsupportedOperationException(reason);
}

/** The action result that contains a summary of the execution. */
interface Result {
/** Returns the number of imported data files. */
Expand Down
39 changes: 39 additions & 0 deletions data/src/main/java/org/apache/iceberg/data/MigrationService.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.data;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder;

/**
 * Factory for the {@link ExecutorService} used when migrating tables.
 *
 * <p>The factory lives in its own class so that tests can replace it with a static mock.
 */
public class MigrationService {

  // Static factory only; never instantiated.
  private MigrationService() {}

  /**
   * Creates a fixed-size thread pool whose threads are named {@code table-migration-N} and wraps
   * it with {@link MoreExecutors#getExitingExecutorService}.
   *
   * @param parallelism the number of threads in the pool
   * @return the wrapped executor service
   */
  public static ExecutorService get(int parallelism) {
    ThreadPoolExecutor fixedPool =
        (ThreadPoolExecutor)
            Executors.newFixedThreadPool(
                parallelism,
                new ThreadFactoryBuilder().setNameFormat("table-migration-%d").build());
    return MoreExecutors.getExitingExecutorService(fixedPool);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
Expand All @@ -44,8 +42,6 @@
import org.apache.iceberg.mapping.NameMapping;
import org.apache.iceberg.orc.OrcMetrics;
import org.apache.iceberg.parquet.ParquetUtil;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.iceberg.util.Tasks;

public class TableMigrationUtil {
Expand Down Expand Up @@ -131,7 +127,7 @@ public static List<DataFile> listPartition(
Tasks.range(fileStatus.size()).stopOnFailure().throwFailureWhenFinished();

if (parallelism > 1) {
service = migrationService(parallelism);
service = MigrationService.get(parallelism);
task.executeWith(service);
}

Expand Down Expand Up @@ -214,12 +210,4 @@ private static DataFile buildDataFile(
.withPartitionValues(partitionValues)
.build();
}

/**
 * Creates a fixed-size thread pool whose threads are named {@code table-migration-N} and wraps it
 * with {@code MoreExecutors.getExitingExecutorService}.
 *
 * @param parallelism the number of threads in the pool
 * @return the wrapped executor service
 */
private static ExecutorService migrationService(int parallelism) {
  ThreadPoolExecutor fixedPool =
      (ThreadPoolExecutor)
          Executors.newFixedThreadPool(
              parallelism,
              new ThreadFactoryBuilder().setNameFormat("table-migration-%d").build());
  return MoreExecutors.getExitingExecutorService(fixedPool);
}
}
4 changes: 3 additions & 1 deletion docs/docs/spark-procedures.md
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ See [`migrate`](#migrate) to replace an existing table with an Iceberg table.
| `table` | ✔️ | string | Name of the new Iceberg table to create |
| `location` | | string | Table location for the new table (delegated to the catalog by default) |
| `properties` | | map<string, string> | Properties to add to the newly created table |
| `parallelism` | | int | Number of threads to use for file reading (defaults to 1) |

#### Output

Expand Down Expand Up @@ -588,6 +589,7 @@ By default, the original table is retained with the name `table_BACKUP_`.
| `properties` | | map<string, string> | Properties for the new Iceberg table |
| `drop_backup` | | boolean | When true, the original table will not be retained as backup (defaults to false) |
| `backup_table_name` | | string | Name of the table that will be retained as backup (defaults to `table_BACKUP_`) |
| `parallelism` | | int | Number of threads to use for file reading (defaults to 1) |

#### Output

Expand Down Expand Up @@ -629,7 +631,7 @@ will then treat these files as if they are part of the set of files owned by Ic
| `source_table` | ✔️ | string | Table where files should come from, paths are also possible in the form of \`file_format\`.\`path\` |
| `partition_filter` | | map<string, string> | A map of partitions in the source table to import from |
| `check_duplicate_files` || boolean | Whether to prevent files existing in the table from being added (defaults to true) |
| `parallelism` | | int | number of threads to use for file reading (defaults to 1) |
| `parallelism` | | int | Number of threads to use for file reading (defaults to 1) |

Warning: Schema is not validated; adding files with a different schema to the Iceberg table will cause issues.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,31 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assumptions.assumeThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.nio.file.Files;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.data.MigrationService;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.AnalysisException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;

@ExtendWith(ParameterizedTestExtension.class)
public class TestMigrateTableProcedure extends ExtensionsTestBase {
Expand Down Expand Up @@ -232,4 +244,66 @@ public void testMigrateEmptyTable() throws Exception {
Object result = scalarSql("CALL %s.system.migrate('%s')", catalogName, tableName);
assertThat(result).isEqualTo(0L);
}

/**
 * Verifies that migrating with {@code parallelism => 1} never asks {@link MigrationService} for
 * an executor, i.e. file reading stays on the calling thread.
 */
@TestTemplate
public void testMigrateWithSingleThread() throws IOException {
  assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");

  String location = Files.createTempDirectory(temp, "junit").toFile().toString();
  sql(
      "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
      tableName, location);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
  sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);

  // MigrationService only exposes a static factory, so an instance mock is never reached by the
  // procedure and verifying it would always pass. Intercept the static call instead.
  try (MockedStatic<MigrationService> service = mockStatic(MigrationService.class)) {
    int parallelism = 1;
    ExecutorService executorService = mock(ExecutorService.class);
    service.when(() -> MigrationService.get(eq(parallelism))).thenReturn(executorService);

    sql(
        "CALL %s.system.migrate(table => '%s', parallelism => %d)",
        catalogName, tableName, parallelism);

    // Neither the factory nor the executor it would have handed out may be touched.
    service.verifyNoInteractions();
    verifyNoInteractions(executorService);
  }
}

/**
 * Verifies that migrating with {@code parallelism => 5} obtains its executor from {@link
 * MigrationService} and submits one task per data file (two inserts, so two submissions).
 */
@TestTemplate
public void testMigrateWithMultiThreads() throws IOException {
  assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");

  String location = Files.createTempDirectory(temp, "junit").toFile().toString();
  sql(
      "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
      tableName, location);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
  sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);

  int parallelism = 5;

  // Executor stand-in: every submitted task yields a future that already reports completion.
  ExecutorService pool = mock(ExecutorService.class);
  Future completed = mock(Future.class);
  when(completed.isDone()).thenReturn(true);
  when(pool.submit(any(Runnable.class))).thenReturn(completed);

  try (MockedStatic<MigrationService> factory = mockStatic(MigrationService.class)) {
    factory.when(() -> MigrationService.get(eq(parallelism))).thenReturn(pool);

    sql(
        "CALL %s.system.migrate(table => '%s', parallelism => %d)",
        catalogName, tableName, parallelism);

    verify(pool, times(2)).submit(any(Runnable.class));
  }
}

/** Verifies that the migrate procedure rejects a negative {@code parallelism} argument. */
@TestTemplate
public void testMigrateWithInvalidParallelism() throws IOException {
  assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");

  String location = Files.createTempDirectory(temp, "junit").toFile().toString();
  sql(
      "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
      tableName, location);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
  sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);

  int invalidParallelism = -1;
  Assertions.assertThatThrownBy(
          () ->
              sql(
                  "CALL %s.system.migrate(table => '%s', parallelism => %d)",
                  catalogName, tableName, invalidParallelism))
      .isInstanceOf(IllegalArgumentException.class)
      .hasMessage("Parallelism should be larger than 0");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,31 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assumptions.assumeThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.nio.file.Files;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.data.MigrationService;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.AnalysisException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;

@ExtendWith(ParameterizedTestExtension.class)
public class TestSnapshotTableProcedure extends ExtensionsTestBase {
Expand Down Expand Up @@ -223,4 +235,62 @@ public void testInvalidSnapshotsCases() throws IOException {
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Cannot handle an empty identifier for argument table");
}

/**
 * Verifies that snapshotting with {@code parallelism => 1} never asks {@link MigrationService}
 * for an executor, i.e. file reading stays on the calling thread.
 */
@TestTemplate
public void testSnapshotWithSingleThread() throws IOException {
  String location = Files.createTempDirectory(temp, "junit").toFile().toString();
  sql(
      "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
      sourceName, location);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
  sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);

  // A free-standing ExecutorService mock is never handed to the procedure, so verifying it would
  // always pass. Intercept the static factory so the mock is what the procedure would receive.
  try (MockedStatic<MigrationService> service = mockStatic(MigrationService.class)) {
    int parallelism = 1;
    ExecutorService executorService = mock(ExecutorService.class);
    service.when(() -> MigrationService.get(eq(parallelism))).thenReturn(executorService);

    sql(
        "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
        catalogName, sourceName, tableName, parallelism);

    // Neither the factory nor the executor it would have handed out may be touched.
    service.verifyNoInteractions();
    verifyNoInteractions(executorService);
  }
}

/**
 * Verifies that snapshotting with {@code parallelism => 5} obtains its executor from {@link
 * MigrationService} and submits one task per data file (two inserts, so two submissions).
 */
@TestTemplate
public void testSnapshotWithMultiThreads() throws IOException {
  String location = Files.createTempDirectory(temp, "junit").toFile().toString();
  sql(
      "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
      sourceName, location);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
  sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);

  int parallelism = 5;

  // Executor stand-in: every submitted task yields a future that already reports completion.
  ExecutorService pool = mock(ExecutorService.class);
  Future completed = mock(Future.class);
  when(completed.isDone()).thenReturn(true);
  when(pool.submit(any(Runnable.class))).thenReturn(completed);

  try (MockedStatic<MigrationService> factory = mockStatic(MigrationService.class)) {
    factory.when(() -> MigrationService.get(eq(parallelism))).thenReturn(pool);

    sql(
        "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
        catalogName, sourceName, tableName, parallelism);

    verify(pool, times(2)).submit(any(Runnable.class));
  }
}

/**
 * Verifies that the snapshot procedure rejects a negative {@code parallelism} argument.
 *
 * <p>NOTE(review): the method name says "Migrate" but the body calls {@code system.snapshot};
 * this looks like a copy-paste from the migrate tests - consider renaming to
 * testSnapshotWithInvalidParallelism.
 */
@TestTemplate
public void testMigrateWithInvalidParallelism() throws IOException {
  String location = Files.createTempDirectory(temp, "junit").toFile().toString();
  sql(
      "CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
      sourceName, location);
  sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
  sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);

  int invalidParallelism = -1;
  Assertions.assertThatThrownBy(
          () ->
              sql(
                  "CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
                  catalogName, sourceName, tableName, invalidParallelism))
      .isInstanceOf(IllegalArgumentException.class)
      .hasMessage("Parallelism should be larger than 0");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,34 @@ public static void importSparkTable(
spark, sourceTableIdent, targetTable, stagingDir, partitionFilter, checkDuplicateFiles, 1);
}

/**
 * Imports the files of an existing Spark table into an Iceberg table, reading files with the
 * given number of threads.
 *
 * <p>Table metadata is obtained through the Spark session. The import assumes that nothing else
 * is operating on the source or target table while it runs, and is therefore not thread-safe.
 *
 * <p>This overload delegates to the variant that also takes a partition filter and a
 * duplicate-file check flag, passing an empty filter and {@code false} respectively.
 *
 * @param spark a Spark session
 * @param sourceTableIdent an identifier of the source Spark table
 * @param targetTable an Iceberg table where to import the data
 * @param stagingDir a staging directory to store temporary manifest files
 * @param parallelism number of threads to use for file reading
 */
public static void importSparkTable(
    SparkSession spark,
    TableIdentifier sourceTableIdent,
    Table targetTable,
    String stagingDir,
    int parallelism) {
  boolean checkDuplicateFiles = false;
  importSparkTable(
      spark,
      sourceTableIdent,
      targetTable,
      stagingDir,
      Collections.emptyMap(),
      checkDuplicateFiles,
      parallelism);
}

/**
* Import files from an existing Spark table to an Iceberg table.
*
Expand Down Expand Up @@ -550,7 +578,7 @@ private static void importUnpartitionedSparkTable(
}

AppendFiles append = targetTable.newAppend();
files.forEach(append::appendFile);
files.stream().filter(java.util.Objects::nonNull).forEach(append::appendFile);
append.commit();
} catch (NoSuchDatabaseException e) {
throw SparkExceptionUtil.toUncheckedException(
Expand Down
Loading

0 comments on commit aa3aef1

Please sign in to comment.