Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
manuzhang committed Jun 25, 2024
1 parent 5c6da4a commit 27cf7b7
Show file tree
Hide file tree
Showing 13 changed files with 340 additions and 154 deletions.
10 changes: 6 additions & 4 deletions api/src/main/java/org/apache/iceberg/actions/MigrateTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.iceberg.actions;

import java.util.Map;
import java.util.concurrent.ExecutorService;

/** An action that migrates an existing table to Iceberg. */
public interface MigrateTable extends Action<MigrateTable, MigrateTable.Result> {
Expand Down Expand Up @@ -61,13 +62,14 @@ default MigrateTable backupTableName(String tableName) {
}

/**
* Sets the number of threads to use for file reading. The default is 1.
* Sets the executor service to use for parallel file reading. By default, no executor service
* is used and files are read sequentially.
*
* @param numThreads the number of threads
* @param service executor service
* @return this for method chaining
*/
default MigrateTable parallelism(int numThreads) {
throw new UnsupportedOperationException("Setting parallelism is not supported");
default MigrateTable executeMigrateWith(ExecutorService service) {
throw new UnsupportedOperationException("Setting executor service is not supported");
}

/** The action result that contains a summary of the execution. */
Expand Down
10 changes: 6 additions & 4 deletions api/src/main/java/org/apache/iceberg/actions/SnapshotTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.iceberg.actions;

import java.util.Map;
import java.util.concurrent.ExecutorService;

/** An action that creates an independent snapshot of an existing table. */
public interface SnapshotTable extends Action<SnapshotTable, SnapshotTable.Result> {
Expand Down Expand Up @@ -58,13 +59,14 @@ public interface SnapshotTable extends Action<SnapshotTable, SnapshotTable.Resul
SnapshotTable tableProperty(String key, String value);

/**
* Sets the number of threads to use for file reading. The default is 1.
* Sets the executor service to use for parallel file reading. By default, no executor service
* is used and files are read sequentially.
*
* @param numThreads the number of threads
* @param service executor service
* @return this for method chaining
*/
default SnapshotTable parallelism(int numThreads) {
throw new UnsupportedOperationException("Setting parallelism is not supported");
default SnapshotTable executeSnapshotWith(ExecutorService service) {
throw new UnsupportedOperationException("Setting executor service is not supported");
}

/** The action result that contains a summary of the execution. */
Expand Down
39 changes: 0 additions & 39 deletions data/src/main/java/org/apache/iceberg/data/MigrationService.java

This file was deleted.

49 changes: 46 additions & 3 deletions data/src/main/java/org/apache/iceberg/data/TableMigrationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.iceberg.orc.OrcMetrics;
import org.apache.iceberg.parquet.ParquetUtil;
import org.apache.iceberg.util.Tasks;
import org.apache.iceberg.util.ThreadPools;

public class TableMigrationUtil {
private static final PathFilter HIDDEN_PATH_FILTER =
Expand Down Expand Up @@ -108,7 +109,46 @@ public static List<DataFile> listPartition(
MetricsConfig metricsSpec,
NameMapping mapping,
int parallelism) {
ExecutorService service = null;
return listPartition(
partition,
partitionUri,
format,
spec,
conf,
metricsSpec,
mapping,
migrationService(parallelism));
}

/**
* Returns the data files in a partition by listing the partition location. Metrics are read from
* the files, and the file reading is done in parallel using the given executor service.
*
* <p>For Parquet and ORC partitions, this will read metrics from the file footer. For Avro
* partitions, metrics other than row count are set to null.
*
* <p>Note: certain metrics, like NaN counts, that are only supported by Iceberg file writers but
* not file footers, will not be populated.
*
* @param partition map of column names to column values for the partition
* @param partitionUri partition location URI
* @param format partition format, avro, parquet or orc
* @param spec a partition spec
* @param conf a Hadoop conf
* @param metricsSpec a metrics conf
* @param mapping a name mapping
* @param service executor service to use for file reading
* @return a List of DataFile
*/
public static List<DataFile> listPartition(
Map<String, String> partition,
String partitionUri,
String format,
PartitionSpec spec,
Configuration conf,
MetricsConfig metricsSpec,
NameMapping mapping,
ExecutorService service) {
try {
List<String> partitionValues =
spec.fields().stream()
Expand All @@ -126,8 +166,7 @@ public static List<DataFile> listPartition(
Tasks.Builder<Integer> task =
Tasks.range(fileStatus.size()).stopOnFailure().throwFailureWhenFinished();

if (parallelism > 1) {
service = MigrationService.get(parallelism);
if (service != null) {
task.executeWith(service);
}

Expand Down Expand Up @@ -210,4 +249,8 @@ private static DataFile buildDataFile(
.withPartitionValues(partitionValues)
.build();
}

public static ExecutorService migrationService(int parallelism) {
return parallelism == 1 ? null : ThreadPools.newWorkerPool("table-migration", parallelism);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,19 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assumptions.assumeThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.nio.file.Files;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.data.MigrationService;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.AnalysisException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;

@ExtendWith(ParameterizedTestExtension.class)
public class TestMigrateTableProcedure extends ExtensionsTestBase {
Expand Down Expand Up @@ -245,7 +234,7 @@ public void testMigrateEmptyTable() throws Exception {
}

@TestTemplate
public void testMigrateWithSingleThread() throws IOException {
public void testMigrateWithParallelism() throws IOException {
assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");

String location = Files.createTempDirectory(temp, "junit").toFile().toString();
Expand All @@ -255,35 +244,14 @@ public void testMigrateWithSingleThread() throws IOException {
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);

MigrationService service = mock(MigrationService.class);
sql("CALL %s.system.migrate(table => '%s', parallelism => %d)", catalogName, tableName, 1);
verifyNoInteractions(service);
}

@TestTemplate
public void testMigrateWithMultiThreads() throws IOException {
assumeThat(catalogName).isEqualToIgnoringCase("spark_catalog");
List<Object[]> result =
sql("CALL %s.system.migrate(table => '%s', parallelism => %d)", catalogName, tableName, 2);
assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);

String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
tableName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);

try (MockedStatic<MigrationService> service = mockStatic(MigrationService.class)) {
int parallelism = 5;
ExecutorService executorService = mock(ExecutorService.class);
service.when(() -> MigrationService.get(eq(parallelism))).thenReturn(executorService);
Future future = mock(Future.class);
when(executorService.submit(any(Runnable.class))).thenReturn(future);
when(future.isDone()).thenReturn(true);

sql(
"CALL %s.system.migrate(table => '%s', parallelism => %d)",
catalogName, tableName, parallelism);
verify(executorService, times(2)).submit(any(Runnable.class));
}
assertEquals(
"Should have expected rows",
ImmutableList.of(row(1L, "a"), row(2L, "b")),
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@TestTemplate
Expand All @@ -297,7 +265,7 @@ public void testMigrateWithInvalidParallelism() throws IOException {
sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", tableName);

Assertions.assertThatThrownBy(
assertThatThrownBy(
() ->
sql(
"CALL %s.system.migrate(table => '%s', parallelism => %d)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,19 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assumptions.assumeThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.nio.file.Files;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.data.MigrationService;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.AnalysisException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;

@ExtendWith(ParameterizedTestExtension.class)
public class TestSnapshotTableProcedure extends ExtensionsTestBase {
Expand Down Expand Up @@ -237,55 +226,35 @@ public void testInvalidSnapshotsCases() throws IOException {
}

@TestTemplate
public void testSnapshotWithSingleThread() throws IOException {
public void testSnapshotWithParallelism() throws IOException {
String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
sourceName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);

ExecutorService service = mock(ExecutorService.class);
sql(
"CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
catalogName, sourceName, tableName, 1);
verifyNoInteractions(service);
}

@TestTemplate
public void testSnapshotWithMultiThreads() throws IOException {
String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
sourceName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);

try (MockedStatic<MigrationService> service = mockStatic(MigrationService.class)) {
int parallelism = 5;
ExecutorService executorService = mock(ExecutorService.class);
service.when(() -> MigrationService.get(eq(parallelism))).thenReturn(executorService);
Future future = mock(Future.class);
when(executorService.submit(any(Runnable.class))).thenReturn(future);
when(future.isDone()).thenReturn(true);

sql(
"CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
catalogName, sourceName, tableName, parallelism);
verify(executorService, times(2)).submit(any(Runnable.class));
}
List<Object[]> result =
sql(
"CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
catalogName, sourceName, tableName, 2);
assertEquals("Procedure output must match", ImmutableList.of(row(2L)), result);
assertEquals(
"Should have expected rows",
ImmutableList.of(row(1L, "a"), row(2L, "b")),
sql("SELECT * FROM %s ORDER BY id", tableName));
}

@TestTemplate
public void testMigrateWithInvalidParallelism() throws IOException {
public void testSnapshotWithInvalidParallelism() throws IOException {
String location = Files.createTempDirectory(temp, "junit").toFile().toString();
sql(
"CREATE TABLE %s (id bigint NOT NULL, data string) USING parquet LOCATION '%s'",
sourceName, location);
sql("INSERT INTO TABLE %s VALUES (1, 'a')", sourceName);
sql("INSERT INTO TABLE %s VALUES (2, 'b')", sourceName);

Assertions.assertThatThrownBy(
assertThatThrownBy(
() ->
sql(
"CALL %s.system.snapshot(source_table => '%s', table => '%s', parallelism => %d)",
Expand Down
Loading

0 comments on commit 27cf7b7

Please sign in to comment.