[HUDI-7277] Fix hoodie.bulkinsert.shuffle.parallelism not activated with no-partitioned table (#10532)

Signed-off-by: wulingqi <[email protected]>
KnightChess authored Jan 20, 2024
1 parent 2823d78 commit edc45df
Showing 2 changed files with 67 additions and 15 deletions.
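
The core of the fix is in the first file: the target shuffle parallelism is now deduced from the incoming DataFrame before de-duplication and is passed into dedupeRows, so hoodie.bulkinsert.shuffle.parallelism takes effect on the dedup shuffle even when the row partitioner does not repartition the data afterwards (the non-sort, non-partitioned path the new test exercises). A condensed sketch of the resulting flow, paraphrased from the diff below rather than the complete method body:

    // Paraphrased from the diff below; not the full prepareForBulkInsert body.
    val targetParallelism =
      deduceShuffleParallelism(df, config.getBulkInsertShuffleParallelism) // deduced from the input df, up front

    val dedupedRdd = if (config.shouldCombineBeforeInsert) {
      // the dedup shuffle now runs with the configured parallelism
      dedupeRows(prependedRdd, updatedSchema, config.getPreCombineField,
        SparkHoodieIndexFactory.isGlobalIndex(config), targetParallelism)
    } else {
      prependedRdd
    }

    // the same value is still used for the final repartition of the prepared rows
    partitioner.repartitionRecords(updatedDF, targetParallelism)

From the user side the knob is unchanged; a hedged example of exercising this path (option keys as used by Hudi's datasource; the path and remaining table options are placeholders):

    df.write.format("hudi")
      .option("hoodie.datasource.write.operation", "bulk_insert")
      .option("hoodie.bulkinsert.shuffle.parallelism", "7")
      // ... plus the usual table options (table name, record key, precombine field, ...)
      .save(basePath)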
@@ -76,6 +76,9 @@ object HoodieDatasetBulkInsertHelper

val updatedSchema = StructType(metaFields ++ schema.fields)

val targetParallelism =
deduceShuffleParallelism(df, config.getBulkInsertShuffleParallelism)

val updatedDF = if (populateMetaFields) {
val keyGeneratorClassName = config.getStringOrThrow(HoodieWriteConfig.KEYGENERATOR_CLASS_NAME,
"Key-generator class name is required")
@@ -110,7 +113,7 @@ object HoodieDatasetBulkInsertHelper
}

val dedupedRdd = if (config.shouldCombineBeforeInsert) {
dedupeRows(prependedRdd, updatedSchema, config.getPreCombineField, SparkHoodieIndexFactory.isGlobalIndex(config))
dedupeRows(prependedRdd, updatedSchema, config.getPreCombineField, SparkHoodieIndexFactory.isGlobalIndex(config), targetParallelism)
} else {
prependedRdd
}
@@ -127,9 +130,6 @@ object HoodieDatasetBulkInsertHelper
HoodieUnsafeUtils.createDataFrameFrom(df.sparkSession, prependedQuery)
}

val targetParallelism =
deduceShuffleParallelism(updatedDF, config.getBulkInsertShuffleParallelism)

partitioner.repartitionRecords(updatedDF, targetParallelism)
}

@@ -193,7 +193,7 @@ object HoodieDatasetBulkInsertHelper
table.getContext.parallelize(writeStatuses.toList.asJava)
}

private def dedupeRows(rdd: RDD[InternalRow], schema: StructType, preCombineFieldRef: String, isGlobalIndex: Boolean): RDD[InternalRow] = {
private def dedupeRows(rdd: RDD[InternalRow], schema: StructType, preCombineFieldRef: String, isGlobalIndex: Boolean, targetParallelism: Int): RDD[InternalRow] = {
val recordKeyMetaFieldOrd = schema.fieldIndex(HoodieRecord.RECORD_KEY_METADATA_FIELD)
val partitionPathMetaFieldOrd = schema.fieldIndex(HoodieRecord.PARTITION_PATH_METADATA_FIELD)
// NOTE: Pre-combine field could be a nested field
@@ -212,16 +212,15 @@ object HoodieDatasetBulkInsertHelper
// since Spark might be providing us with a mutable copy (updated during the iteration)
(rowKey, row.copy())
}
.reduceByKey {
(oneRow, otherRow) =>
val onePreCombineVal = getNestedInternalRowValue(oneRow, preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
val otherPreCombineVal = getNestedInternalRowValue(otherRow, preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
if (onePreCombineVal.compareTo(otherPreCombineVal.asInstanceOf[AnyRef]) >= 0) {
oneRow
} else {
otherRow
}
}
.reduceByKey ((oneRow, otherRow) => {
val onePreCombineVal = getNestedInternalRowValue(oneRow, preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
val otherPreCombineVal = getNestedInternalRowValue(otherRow, preCombineFieldPath).asInstanceOf[Comparable[AnyRef]]
if (onePreCombineVal.compareTo(otherPreCombineVal.asInstanceOf[AnyRef]) >= 0) {
oneRow
} else {
otherRow
}
}, targetParallelism)
.values
}
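
The substantive change in dedupeRows is the second argument to reduceByKey: Spark's RDD.reduceByKey(func) lets the shuffle inherit its partition count from the parent RDD or the default parallelism, while reduceByKey(func, numPartitions) forces the reduce stage to run with exactly that many tasks. A minimal standalone illustration with toy data (a local SparkSession is created only for the demonstration; this is not Hudi code):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").appName("reduceByKey-parallelism").getOrCreate()

    // three input partitions, but the reduce side is what we want to control
    val pairs = spark.sparkContext.parallelize(Seq(("k1", 1), ("k1", 2), ("k2", 3)), numSlices = 3)

    val defaultShuffle  = pairs.reduceByKey(_ max _)      // partition count inherited from parent / default parallelism
    val explicitShuffle = pairs.reduceByKey(_ max _, 7)   // reduce stage runs with 7 tasks

    assert(explicitShuffle.getNumPartitions == 7)
    spark.stop()

The second changed file adds a test asserting that the configured value actually reaches this reduce stage: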

@@ -37,8 +37,11 @@
import org.apache.avro.Schema;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.ReduceFunction;
import org.apache.spark.scheduler.SparkListener;
import org.apache.spark.scheduler.SparkListenerStageSubmitted;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.HoodieUnsafeUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.types.StructType;
@@ -59,6 +62,7 @@
import scala.Tuple2;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

@@ -348,4 +352,53 @@ public void testNoPropsSet() {
private ExpressionEncoder getEncoder(StructType schema) {
return SparkAdapterSupport$.MODULE$.sparkAdapter().getCatalystExpressionUtils().getEncoder(schema);
}

@Test
public void testBulkInsertParallelismParam() {
HoodieWriteConfig config = getConfigBuilder(schemaStr).withProps(getPropsAllSet("_row_key"))
.combineInput(true, true)
.withPreCombineField("ts").build();
int checkParallelism = 7;
config.setValue("hoodie.bulkinsert.shuffle.parallelism", String.valueOf(checkParallelism));
StageCheckBulkParallelismListener stageCheckBulkParallelismListener =
new StageCheckBulkParallelismListener("org.apache.hudi.HoodieDatasetBulkInsertHelper$.dedupeRows");
sqlContext.sparkContext().addSparkListener(stageCheckBulkParallelismListener);
List<Row> inserts = DataSourceTestUtils.generateRandomRows(10);
Dataset<Row> dataset = sqlContext.createDataFrame(inserts, structType).repartition(3);
assertNotEquals(checkParallelism, HoodieUnsafeUtils.getNumPartitions(dataset));
assertNotEquals(checkParallelism, sqlContext.sparkContext().defaultParallelism());
Dataset<Row> result = HoodieDatasetBulkInsertHelper.prepareForBulkInsert(dataset, config,
new NonSortPartitionerWithRows(), "000001111");
// trigger job
result.count();
assertEquals(checkParallelism, stageCheckBulkParallelismListener.getParallelism());
sqlContext.sparkContext().removeSparkListener(stageCheckBulkParallelismListener);
}

class StageCheckBulkParallelismListener extends SparkListener {

private boolean checkFlag = false;
private String checkMessage;
private int parallelism;

StageCheckBulkParallelismListener(String checkMessage) {
this.checkMessage = checkMessage;
}

@Override
public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
if (checkFlag) {
// the stage submitted right after the dedup stage is the reduce side of that shuffle
this.parallelism = stageSubmitted.stageInfo().numTasks();
checkFlag = false;
}
if (stageSubmitted.stageInfo().details().contains(checkMessage)) {
checkFlag = true;
}
}

public int getParallelism() {
return parallelism;
}
}
}
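
The test isolates the source of the observed parallelism: the input Dataset is repartitioned to 3, and the configured value of 7 is asserted to differ from both the input partition count and the default parallelism, so the only way the reduce stage can run with 7 tasks is via hoodie.bulkinsert.shuffle.parallelism. The listener arms itself when a submitted stage's details mention the dedupeRows call site and then records the task count of the next submitted stage, i.e. the reduce side of the dedup shuffle. A compact Scala sketch of the same listener pattern (illustrative only, mirroring the Java listener above):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted}

    class TaskCountListener(marker: String) extends SparkListener {
      @volatile private var armed = false
      @volatile var observedTasks: Int = -1

      override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = {
        if (armed) {
          // the stage submitted right after the marked one is the reduce side of the shuffle
          observedTasks = event.stageInfo.numTasks
          armed = false
        }
        if (event.stageInfo.details.contains(marker)) {
          armed = true
        }
      }
    }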
