Use HBO to set number of tasks for scaled writers
feilong-liu committed Oct 11, 2023
1 parent 1e249e1 commit 18ac248
Showing 20 changed files with 177 additions and 26 deletions.
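
When history-based optimization (HBO) has statistics for a matching historical query, this change seeds the scaled-writer scheduler with a task count derived from history instead of always starting from a single writer task. Concretely: TableWriterNode gains an Optional<Integer> taskCountIfScaledWriter, a new ScaledWriterRule populates it with half of the historical count (rounded up), and the scheduler reads it via PlanFragmenterUtils.getTableWriterTasks when it schedules the first writer. The feature is off by default, gated by the enable_history_based_scaled_writer session property and the optimizer.use-hbo-for-scaled-writers config flag.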
--- SystemSessionProperties.java ---
@@ -292,6 +292,7 @@ public final class SystemSessionProperties
public static final String PULL_EXPRESSION_FROM_LAMBDA_ENABLED = "pull_expression_from_lambda_enabled";
public static final String REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION = "rewrite_constant_array_contains_to_in_expression";
public static final String INFER_INEQUALITY_PREDICATES = "infer_inequality_predicates";
public static final String ENABLE_HISTORY_BASED_SCALED_WRITER = "enable_history_based_scaled_writer";

// TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future.
public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "native_simplified_expression_evaluation_enabled";
@@ -1751,6 +1752,11 @@ public SystemSessionProperties(
INFER_INEQUALITY_PREDICATES,
"Infer nonequality predicates for joins",
featuresConfig.getInferInequalityPredicates(),
false),
booleanProperty(
ENABLE_HISTORY_BASED_SCALED_WRITER,
"Enable setting the initial number of tasks for scaled writers with HBO",
featuresConfig.isUseHBOForScaledWriters(),
false));
}

@@ -2920,4 +2926,9 @@ public static boolean shouldInferInequalityPredicates(Session session)
{
return session.getSystemProperty(INFER_INEQUALITY_PREDICATES, Boolean.class);
}

public static boolean useHBOForScaledWriters(Session session)
{
return session.getSystemProperty(ENABLE_HISTORY_BASED_SCALED_WRITER, Boolean.class);
}
}
--- TableWriterNodeStatsEstimate.java ---
@@ -28,7 +28,7 @@ public class TableWriterNodeStatsEstimate
private final double taskCountIfScaledWriter;

@JsonCreator
public TableWriterNodeStatsEstimate(@JsonProperty("taskNumberIfScaledWriter") double taskCountIfScaledWriter)
public TableWriterNodeStatsEstimate(@JsonProperty("taskCountIfScaledWriter") double taskCountIfScaledWriter)
{
this.taskCountIfScaledWriter = taskCountIfScaledWriter;
}
@@ -48,7 +48,7 @@ public double getTaskCountIfScaledWriter()
public String toString()
{
return toStringHelper(this)
.add("taskNumberIfScaledWriter", taskCountIfScaledWriter)
.add("taskCountIfScaledWriter", taskCountIfScaledWriter)
.toString();
}

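The rename above aligns the @JsonCreator parameter with the JSON field name that the getter produces, so the estimate now round-trips through Jackson instead of failing to bind on deserialization. A minimal sketch of that round-trip, assuming a plain ObjectMapper and the class as defined above (the demo class name is hypothetical):

import com.fasterxml.jackson.databind.ObjectMapper;

public final class TaskCountRoundTripSketch
{
    public static void main(String[] args) throws Exception
    {
        ObjectMapper mapper = new ObjectMapper();
        // With creator and getter agreeing on "taskCountIfScaledWriter",
        // the serialized form is {"taskCountIfScaledWriter":8.0} ...
        String json = mapper.writeValueAsString(new TableWriterNodeStatsEstimate(8));
        // ... and deserialization binds the same field back into the constructor.
        TableWriterNodeStatsEstimate copy = mapper.readValue(json, TableWriterNodeStatsEstimate.class);
        System.out.println(copy.getTaskCountIfScaledWriter()); // prints 8.0
    }
}
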
--- ScaledWriterScheduler.java ---
@@ -48,6 +48,7 @@ public class ScaledWriterScheduler

private final boolean optimizedScaleWriterProducerBuffer;
private final long writerMinSizeBytes;
private final Optional<Integer> initialTaskCount;

private final Set<InternalNode> scheduledNodes = new HashSet<>();

@@ -61,7 +62,8 @@ public ScaledWriterScheduler(
NodeSelector nodeSelector,
ScheduledExecutorService executor,
DataSize writerMinSize,
boolean optimizedScaleWriterProducerBuffer)
boolean optimizedScaleWriterProducerBuffer,
Optional<Integer> initialTaskCount)
{
this.stage = requireNonNull(stage, "stage is null");
this.sourceTasksProvider = requireNonNull(sourceTasksProvider, "sourceTasksProvider is null");
@@ -70,6 +72,7 @@
this.executor = requireNonNull(executor, "executor is null");
this.writerMinSizeBytes = requireNonNull(writerMinSize, "minWriterSize is null").toBytes();
this.optimizedScaleWriterProducerBuffer = optimizedScaleWriterProducerBuffer;
this.initialTaskCount = requireNonNull(initialTaskCount, "initialTaskCount is null");
}

public void finish()
@@ -93,7 +96,7 @@ public ScheduleResult schedule()
private int getNewTaskCount()
{
if (scheduledNodes.isEmpty()) {
return 1;
return initialTaskCount.orElse(1);
}

double fullTasks = sourceTasksProvider.get().stream()
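Only the very first scheduling round changes: with no nodes scheduled yet, it can start from the HBO-derived count rather than one task, while later rounds still grow the writer set based on source-task buffer fullness. A condensed, standalone sketch of the seeding logic (InitialWriterSeedSketch is a hypothetical stand-in, not Presto's class):

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

final class InitialWriterSeedSketch
{
    private final Set<String> scheduledNodes = new HashSet<>();
    private final Optional<Integer> initialTaskCount; // e.g., Optional.of(4) from HBO

    InitialWriterSeedSketch(Optional<Integer> initialTaskCount)
    {
        this.initialTaskCount = initialTaskCount;
    }

    int newTaskCount()
    {
        if (scheduledNodes.isEmpty()) {
            // First round: seed from history, falling back to the old single-task start.
            return initialTaskCount.orElse(1);
        }
        return 1; // placeholder: the real scheduler scales with buffer utilization
    }
}
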
------
@@ -41,6 +41,7 @@
import com.facebook.presto.sql.planner.NodePartitionMap;
import com.facebook.presto.sql.planner.NodePartitioningManager;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PlanFragmenterUtils;
import com.facebook.presto.sql.planner.SplitSourceFactory;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.PlanFragmentId;
@@ -309,14 +310,17 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) {
.map(RemoteTask::getTaskStatus)
.collect(toList());

Optional<Integer> taskNumberIfScaledWriter = PlanFragmenterUtils.getTableWriterTasks(plan.getFragment().getRoot());

ScaledWriterScheduler scheduler = new ScaledWriterScheduler(
stageExecution,
sourceTasksProvider,
writerTasksProvider,
nodeScheduler.createNodeSelector(session, null, nodePredicate),
scheduledExecutor,
getWriterMinSize(session),
isOptimizedScaleWriterProducerBuffer(session));
isOptimizedScaleWriterProducerBuffer(session),
taskNumberIfScaledWriter);
whenAllStages(childStageExecutions, StageExecutionState::isDone)
.addListener(scheduler::finish, directExecutor());
return scheduler;
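The count extracted from the fragment root flows straight into the scheduler's new constructor argument; when no TableWriterNode in the fragment carries a value, taskNumberIfScaledWriter is empty and the scheduler keeps its old start-at-one behavior via orElse(1).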
--- FeaturesConfig.java ---
@@ -281,6 +281,7 @@ public class FeaturesConfig
private boolean rewriteConstantArrayContainsToIn;

private boolean preProcessMetadataCalls;
private boolean useHBOForScaledWriters;

public enum PartitioningPrecisionStrategy
{
@@ -2802,4 +2803,17 @@ public FeaturesConfig setRewriteConstantArrayContainsToInEnabled(boolean rewriteConstantArrayContainsToIn)
this.rewriteConstantArrayContainsToIn = rewriteConstantArrayContainsToIn;
return this;
}

public boolean isUseHBOForScaledWriters()
{
return this.useHBOForScaledWriters;
}

@Config("optimizer.use-hbo-for-scaled-writers")
@ConfigDescription("Enable HBO for setting initial number of tasks for scaled writers")
public FeaturesConfig setUseHBOForScaledWriters(boolean useHBOForScaledWriters)
{
this.useHBOForScaledWriters = useHBOForScaledWriters;
return this;
}
}
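
Both toggles default to false, so existing deployments see no behavior change until they opt in, e.g. in the coordinator's config.properties (assuming the standard Presto properties mechanism that @Config uses):

optimizer.use-hbo-for-scaled-writers=true

or per session via enable_history_based_scaled_writer.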
------
@@ -579,7 +579,8 @@ private TableFinishNode createTemporaryTableWrite(
outputNotNullColumnVariables,
Optional.of(partitioningScheme),
Optional.empty(),
enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getPartialAggregation()) : Optional.empty())),
enableStatsCollectionForTemporaryTable ? Optional.of(localAggregations.getPartialAggregation()) : Optional.empty(),
Optional.empty())),
variableAllocator.newVariable("intermediaterows", BIGINT),
variableAllocator.newVariable("intermediatefragments", VARBINARY),
variableAllocator.newVariable("intermediatetablecommitcontext", VARBINARY),
@@ -599,7 +600,8 @@ private TableFinishNode createTemporaryTableWrite(
outputNotNullColumnVariables,
Optional.of(partitioningScheme),
Optional.empty(),
enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getPartialAggregation()) : Optional.empty());
enableStatsCollectionForTemporaryTable ? Optional.of(aggregations.getPartialAggregation()) : Optional.empty(),
Optional.empty());
}

return new TableFinishNode(
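The remaining TableWriterNode construction sites in this commit (temporary-table writes, plan canonicalization, symbol mapping, exchange planning) just thread the new constructor argument through, passing Optional.empty() where no history applies or forwarding node.getTaskCountIfScaledWriter(); only ScaledWriterRule below actually computes a value.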
------
@@ -195,6 +195,7 @@ public Optional<PlanNode> visitTableWriter(TableWriterNode node, Context context)
ImmutableSet.of(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty());
context.addPlan(node, new CanonicalPlan(result, strategy));
return Optional.of(result);
------
@@ -421,7 +421,8 @@ private RelationPlan createTableWriterPlan(
preferredShufflePartitioningScheme,
// partial aggregation is run within the TableWriteOperator to calculate the statistics for
// the data consumed by the TableWriteOperator
Optional.of(aggregations.getPartialAggregation())),
Optional.of(aggregations.getPartialAggregation()),
Optional.empty()),
Optional.of(target),
variableAllocator.newVariable("rows", BIGINT),
// final aggregation is run within the TableFinishOperator to summarize collected statistics
@@ -448,6 +449,7 @@
notNullColumnVariables,
tablePartitioningScheme,
preferredShufflePartitioningScheme,
Optional.empty(),
Optional.empty()),
Optional.of(target),
variableAllocator.newVariable("rows", BIGINT),
--- PlanFragmenterUtils.java ---
@@ -36,6 +36,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.getExchangeMaterializationStrategy;
@@ -244,6 +245,16 @@ public static Set<PlanNodeId> getTableWriterNodeIds(PlanNode plan)
.collect(toImmutableSet());
}

public static Optional<Integer> getTableWriterTasks(PlanNode plan)
{
return stream(forTree(PlanNode::getSources).depthFirstPreOrder(plan))
.filter(node -> node instanceof TableWriterNode)
.map(x -> ((TableWriterNode) x).getTaskCountIfScaledWriter())
.filter(Optional::isPresent)
.map(Optional::get)
.max(Integer::compareTo);
}

private static final class PartitioningHandleReassigner
extends SimplePlanRewriter<Void>
{
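getTableWriterTasks walks the plan tree and, when a plan contains more than one TableWriterNode, keeps the largest recorded count; writers without an HBO-derived value are skipped. The Optional plumbing reduces to this pattern (a standalone sketch, not the Presto API; MaxPresentSketch is hypothetical):

import java.util.Optional;
import java.util.stream.Stream;

final class MaxPresentSketch
{
    // Largest present value, or Optional.empty() when no writer carries a count.
    static Optional<Integer> maxPresent(Stream<Optional<Integer>> counts)
    {
        return counts.filter(Optional::isPresent)
                .map(Optional::get)
                .max(Integer::compareTo);
    }

    public static void main(String[] args)
    {
        System.out.println(maxPresent(
                Stream.of(Optional.of(4), Optional.<Integer>empty(), Optional.of(7)))); // Optional[7]
    }
}
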
--- PlanOptimizers.java ---
@@ -117,6 +117,7 @@
import com.facebook.presto.sql.planner.iterative.rule.RewriteFilterWithExternalFunctionToProject;
import com.facebook.presto.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation;
import com.facebook.presto.sql.planner.iterative.rule.RuntimeReorderJoinSides;
import com.facebook.presto.sql.planner.iterative.rule.ScaledWriterRule;
import com.facebook.presto.sql.planner.iterative.rule.SimplifyCardinalityMap;
import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant;
import com.facebook.presto.sql.planner.iterative.rule.SimplifyRowExpressions;
@@ -706,6 +707,14 @@ public PlanOptimizers(

builder.add(new RemoveRedundantDistinctAggregation());

builder.add(
new IterativeOptimizer(
metadata,
ruleStats,
statsCalculator,
costCalculator,
ImmutableSet.of(new ScaledWriterRule())));

if (!forceSingleNode) {
builder.add(new ReplicateSemiJoinInDelete()); // Must run before AddExchanges
builder.add(new IterativeOptimizer(
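The new rule is registered ahead of the exchange-planning block, presumably so that taskCountIfScaledWriter is already set by the time the plan is fragmented and PlanFragmenterUtils.getTableWriterTasks (above) reads it at scheduling time.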
------
@@ -565,6 +565,7 @@ public Result apply(TableWriterNode node, Captures captures, Context context)
return Result.ofPlanNode(new TableWriterNode(
node.getSourceLocation(),
node.getId(),
node.getStatsEquivalentPlanNode(),
node.getSource(),
node.getTarget(),
node.getRowCountVariable(),
@@ -575,7 +576,8 @@ public Result apply(TableWriterNode node, Captures captures, Context context)
node.getNotNullColumnVariables(),
node.getTablePartitioningScheme(),
node.getPreferredShufflePartitioningScheme(),
rewrittenStatisticsAggregation));
rewrittenStatisticsAggregation,
node.getTaskCountIfScaledWriter()));
}
return Result.empty();
}
--- ScaledWriterRule.java (new file) ---
@@ -0,0 +1,70 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.TableWriterNode;

import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.useHBOForScaledWriters;
import static com.facebook.presto.sql.planner.plan.Patterns.tableWriterNode;
import static com.google.common.base.Preconditions.checkState;

public class ScaledWriterRule
implements Rule<TableWriterNode>
{
@Override
public Pattern<TableWriterNode> getPattern()
{
return tableWriterNode().matching(x -> !x.getTaskCountIfScaledWriter().isPresent());
}

@Override
public boolean isEnabled(Session session)
{
return useHBOForScaledWriters(session);
}

@Override
public Result apply(TableWriterNode node, Captures captures, Context context)
{
double taskNumber = context.getStatsProvider().getStats(node).getTableWriterNodeStatsEstimate().getTaskCountIfScaledWriter();
if (Double.isNaN(taskNumber)) {
return Result.empty();
}
// We start from half of the original number
int initialTaskNumber = (int) Math.ceil(taskNumber / 2);
checkState(initialTaskNumber > 0, "taskCountIfScaledWriter should be at least 1");
return Result.ofPlanNode(new TableWriterNode(
node.getSourceLocation(),
node.getId(),
node.getStatsEquivalentPlanNode(),
node.getSource(),
node.getTarget(),
node.getRowCountVariable(),
node.getFragmentVariable(),
node.getTableCommitContextVariable(),
node.getColumns(),
node.getColumnNames(),
node.getNotNullColumnVariables(),
node.getTablePartitioningScheme(),
node.getPreferredShufflePartitioningScheme(),
node.getStatisticsAggregation(),
Optional.of(initialTaskNumber)));
}
}
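
Worked example: if HBO recorded taskCountIfScaledWriter = 7 for a matching historical query, the rule sets the initial count to ceil(7 / 2) = 4, so the scheduler starts with 4 writer tasks instead of 1 and scales up from there. Starting from half the historical count, rather than all of it, presumably hedges against history overestimating the steady-state writer parallelism.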
------
@@ -575,7 +575,8 @@ public PlanWithProperties visitTableWriter(TableWriterNode originalTableWriterNode
originalTableWriterNode.getNotNullColumnVariables(),
originalTableWriterNode.getTablePartitioningScheme(),
originalTableWriterNode.getPreferredShufflePartitioningScheme(),
statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation)),
statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation),
originalTableWriterNode.getTaskCountIfScaledWriter()),
fixedParallelism(),
fixedParallelism());
}
@@ -599,7 +600,8 @@ public PlanWithProperties visitTableWriter(TableWriterNode originalTableWriterNode
originalTableWriterNode.getNotNullColumnVariables(),
originalTableWriterNode.getTablePartitioningScheme(),
originalTableWriterNode.getPreferredShufflePartitioningScheme(),
statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation)),
statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation),
originalTableWriterNode.getTaskCountIfScaledWriter()),
exchange.getProperties());
}
}
@@ -627,7 +629,8 @@ public PlanWithProperties visitTableWriter(TableWriterNode originalTableWriterNode
originalTableWriterNode.getNotNullColumnVariables(),
originalTableWriterNode.getTablePartitioningScheme(),
originalTableWriterNode.getPreferredShufflePartitioningScheme(),
statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation)),
statisticAggregations.map(StatisticAggregations.Parts::getPartialAggregation),
originalTableWriterNode.getTaskCountIfScaledWriter()),
exchange.getProperties());
}

------
@@ -705,7 +705,8 @@ public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Set<Variab
node.getNotNullColumnVariables(),
node.getTablePartitioningScheme(),
node.getPreferredShufflePartitioningScheme(),
node.getStatisticsAggregation());
node.getStatisticsAggregation(),
node.getTaskCountIfScaledWriter());
}

@Override
------
@@ -252,7 +252,8 @@ public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId new
node.getNotNullColumnVariables(),
node.getTablePartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)),
node.getPreferredShufflePartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)),
node.getStatisticsAggregation().map(this::map));
node.getStatisticsAggregation().map(this::map),
node.getTaskCountIfScaledWriter());
}

public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source)