Skip to content

Commit

Permalink
Fix memory accounting in InMemoryHashAggregationBuilder
Browse files Browse the repository at this point in the history
When InMemoryHashAggregationBuilder is used by SpillableHashAggregationBuilder
while processing input, its memory should be spillable and accounted as revocable memory.

The same instance is also invoked while producing output by spillable accumulators
such as DedupBasedSpillableDistinctGroupedAccumulator
(GenericAccumulatorFactory.java#L401). While producing output, the memory
is non-revocable and should be accounted as user memory.

In the previous implementation, a rehash in GroupByHash caused all memory to be
accounted as user memory, including preallocatedMemoryInBytes for the rehash.
This may cause some queries to fail even though they could execute successfully,
especially in environments with a low "query.max-memory-per-node" setting.
  • Loading branch information
vermapratyush committed Oct 16, 2023
1 parent 2a77f69 commit 3ab3e09
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ private void initializeAggregationBuilderIfNeeded()
maxPartialMemory,
joinCompiler,
true,
useSystemMemory);
useSystemMemory ? ReserveType.SYSTEM : ReserveType.USER);
}
else {
verify(!useSystemMemory, "using system memory in spillable aggregations is not supported");
Expand Down Expand Up @@ -667,4 +667,11 @@ private static long calculateDefaultOutputHash(List<Type> groupByChannels, int g
}
return result;
}

/**
 * Selects how an aggregation builder accounts the memory it reserves.
 * InMemoryHashAggregationBuilder takes a value of this type in its constructors.
 */
public enum ReserveType
{
// Memory is reported to the operator's local user memory context.
USER,
// Memory is reported to the operator's system memory context.
SYSTEM,
// No built-in accounting target: InMemoryHashAggregationBuilder requires the caller
// to supply its own memory consumer when this type is used.
REVOCABLE
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.common.type.Type;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.operator.GroupByHash;
import com.facebook.presto.operator.HashAggregationOperator.ReserveType;
import com.facebook.presto.operator.HashCollisionsCounter;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.TransformWork;
Expand All @@ -44,6 +45,7 @@
import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.function.Consumer;

import static com.facebook.presto.SystemSessionProperties.isDictionaryAggregationEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
Expand All @@ -61,7 +63,8 @@ public class InMemoryHashAggregationBuilder
private final OptionalLong maxPartialMemory;
private final LocalMemoryContext systemMemoryContext;
private final LocalMemoryContext localUserMemoryContext;
private final boolean useSystemMemory;
private final ReserveType reserveType;
private final Consumer<Long> memoryConsumer;

private boolean full;

Expand All @@ -76,7 +79,7 @@ public InMemoryHashAggregationBuilder(
Optional<DataSize> maxPartialMemory,
JoinCompiler joinCompiler,
boolean yieldForMemoryReservation,
boolean useSystemMemory)
ReserveType reserveType)
{
this(accumulatorFactories,
step,
Expand All @@ -89,7 +92,36 @@ public InMemoryHashAggregationBuilder(
Optional.empty(),
joinCompiler,
yieldForMemoryReservation,
useSystemMemory);
reserveType,
Optional.empty());
}

/**
 * Convenience constructor for builders whose memory is tracked as
 * {@link ReserveType#REVOCABLE}.
 *
 * <p>Delegates to the full constructor, leaving the intermediate channel offset
 * override empty and forwarding {@code memoryConsumer} unchanged. The full
 * constructor rejects an empty {@code memoryConsumer} for the REVOCABLE type.
 *
 * @param memoryConsumer callback that receives every memory size update
 */
public InMemoryHashAggregationBuilder(
        List<AccumulatorFactory> accumulatorFactories,
        Step step,
        int expectedGroups,
        List<Type> groupByTypes,
        List<Integer> groupByChannels,
        Optional<Integer> hashChannel,
        OperatorContext operatorContext,
        Optional<DataSize> maxPartialMemory,
        JoinCompiler joinCompiler,
        boolean yieldForMemoryReservation,
        Optional<Consumer<Long>> memoryConsumer)
{
    this(
            accumulatorFactories, step, expectedGroups,
            groupByTypes, groupByChannels, hashChannel,
            operatorContext, maxPartialMemory,
            // no overwriteIntermediateChannelOffset
            Optional.empty(),
            joinCompiler, yieldForMemoryReservation,
            // memory updates go through the supplied consumer
            ReserveType.REVOCABLE, memoryConsumer);
}

public InMemoryHashAggregationBuilder(
Expand All @@ -104,16 +136,31 @@ public InMemoryHashAggregationBuilder(
Optional<Integer> overwriteIntermediateChannelOffset,
JoinCompiler joinCompiler,
boolean yieldForMemoryReservation,
boolean useSystemMemory)
ReserveType reserveType,
Optional<Consumer<Long>> memoryConsumer)
{
// reserveType is REVOCABLE implies current InMemoryHashAggregationBuilder is built from SpillableHashAggregationBuilder
// and it will accept a customized memoryConsumer for memory update
if (reserveType == ReserveType.REVOCABLE) {
checkArgument(memoryConsumer.isPresent(),
"memoryConsumer must be present when reserve type is REVOCABLE");
}

this.reserveType = reserveType;
if (memoryConsumer.isPresent()) {
this.memoryConsumer = memoryConsumer.get();
}
else {
this.memoryConsumer = this::updateMemory;
}

UpdateMemory updateMemory;
if (yieldForMemoryReservation) {
updateMemory = this::updateMemoryWithYieldInfo;
}
else {
// Report memory usage but do not yield for memory.
// This is specially used for spillable hash aggregation operator.
// TODO: revisit this when spillable hash aggregation operator is turned on
updateMemory = () -> {
updateMemoryWithYieldInfo();
return true;
Expand All @@ -132,7 +179,6 @@ public InMemoryHashAggregationBuilder(
this.maxPartialMemory = maxPartialMemory.map(dataSize -> OptionalLong.of(dataSize.toBytes())).orElseGet(OptionalLong::empty);
this.systemMemoryContext = operatorContext.newLocalSystemMemoryContext(InMemoryHashAggregationBuilder.class.getSimpleName());
this.localUserMemoryContext = operatorContext.localUserMemoryContext();
this.useSystemMemory = useSystemMemory;

// wrapper each function with an aggregator
ImmutableList.Builder<Aggregator> builder = ImmutableList.builder();
Expand All @@ -151,7 +197,7 @@ public InMemoryHashAggregationBuilder(
@Override
public void close()
{
updateMemory(0);
memoryConsumer.accept(0L);
}

@Override
Expand Down Expand Up @@ -326,24 +372,28 @@ private boolean updateMemoryWithYieldInfo()
{
long memorySize = getSizeInMemory();
if (partial && maxPartialMemory.isPresent()) {
updateMemory(memorySize);
memoryConsumer.accept(memorySize);
full = (memorySize > maxPartialMemory.getAsLong());
return true;
}
// Operator/driver will be blocked on memory after we call setBytes.
// If memory is not available, once we return, this operator will be blocked until memory is available.
updateMemory(memorySize);
memoryConsumer.accept(memorySize);
// If memory is not available, inform the caller that we cannot proceed for allocation.
return operatorContext.isWaitingForMemory().isDone();
}

private void updateMemory(long memorySize)
{
if (useSystemMemory) {
systemMemoryContext.setBytes(memorySize);
}
else {
localUserMemoryContext.setBytes(memorySize);
switch (reserveType) {
case USER:
localUserMemoryContext.setBytes(memorySize);
break;
case SYSTEM:
systemMemoryContext.setBytes(memorySize);
break;
default:
throw new AssertionError("InMemoryHashAggregationBuilder do not support reserve type: " + reserveType);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.presto.common.Page;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.memory.context.LocalMemoryContext;
import com.facebook.presto.operator.HashAggregationOperator.ReserveType;
import com.facebook.presto.operator.OperatorContext;
import com.facebook.presto.operator.WorkProcessor;
import com.facebook.presto.operator.WorkProcessor.Transformation;
Expand Down Expand Up @@ -150,6 +151,7 @@ private void rebuildHashAggregationBuilder()
Optional.of(overwriteIntermediateChannelOffset),
joinCompiler,
false,
false);
ReserveType.USER,
Optional.empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public class SpillableHashAggregationBuilder

private long hashCollisions;
private double expectedHashCollisions;
private boolean producingOutput;
private Boolean producingOutput = Boolean.FALSE;

public SpillableHashAggregationBuilder(
List<AccumulatorFactory> accumulatorFactories,
Expand Down Expand Up @@ -192,7 +192,7 @@ private boolean shouldMergeWithMemory(long memorySize)
public WorkProcessor<Page> buildResult()
{
checkState(hasPreviousSpillCompletedSuccessfully(), "Previous spill hasn't yet finished");
producingOutput = true;
producingOutput = Boolean.TRUE;

// Convert revocable memory to user memory as returned WorkProcessor holds on to memory so we no longer can revoke.
if (localRevocableMemoryContext.getBytes() > 0) {
Expand Down Expand Up @@ -231,8 +231,13 @@ public void close()
}
merger.ifPresent(closer::register);
spiller.ifPresent(closer::register);
closer.register(() -> localUserMemoryContext.setBytes(0));
closer.register(() -> localRevocableMemoryContext.setBytes(0));

closer.register(() -> {
localUserMemoryContext.setBytes(0);
});
closer.register(() -> {
localRevocableMemoryContext.setBytes(0);
});
}
catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -256,7 +261,6 @@ private ListenableFuture<?> spillToDisk()
// ... and immediately create new hashAggregationBuilder so effectively memory ownership
// over hashAggregationBuilder is transferred from this thread to a spilling thread
rebuildHashAggregationBuilder();

return spillInProgress;
}

Expand Down Expand Up @@ -335,7 +339,16 @@ private void rebuildHashAggregationBuilder()
Optional.of(DataSize.succinctBytes(0)),
joinCompiler,
false,
false);
Optional.of((memorySize) -> {
// The userMemory lambda is invoked in spillable accumulator like: DedupBasedSpillableDistinctGroupedAccumulator
// The memory is revocable only when the operator is not producing output.
if (producingOutput) {
localUserMemoryContext.setBytes(memorySize);
}
else {
localRevocableMemoryContext.setBytes(memorySize);
}
}));
emptyHashAggregationBuilderSize = hashAggregationBuilder.getSizeInMemory();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ public static TaskContext createTaskContext(Executor notificationExecutor, Sched
.build();
}

/**
 * Creates a {@link TaskContext} with explicitly configured query memory limits.
 *
 * @param notificationExecutor executor handed to the task context builder
 * @param yieldExecutor scheduled executor handed to the task context builder
 * @param session session handed to the task context builder
 * @param maxMemory value applied through {@code setQueryMaxMemory}
 * @param maxTotalMemory value applied through {@code setQueryMaxTotalMemory}
 * @return the built task context
 */
public static TaskContext createTaskContext(
        Executor notificationExecutor,
        ScheduledExecutorService yieldExecutor,
        Session session,
        DataSize maxMemory,
        DataSize maxTotalMemory)
{
    return builder(notificationExecutor, yieldExecutor, session)
            .setQueryMaxMemory(maxMemory)
            .setQueryMaxTotalMemory(maxTotalMemory)
            .build();
}

public static TaskContext createTaskContext(Executor notificationExecutor, ScheduledExecutorService yieldExecutor, Session session, TaskStateMachine taskStateMachine)
{
return builder(notificationExecutor, yieldExecutor, session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,54 @@ public void testMergeWithMemorySpill()
assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, resultBuilder.build());
}

/**
 * Runs a SINGLE-step SUM aggregation over four sequence pages of distinct keys
 * inside a task context created with explicit 10MB / 20MB memory limits, and
 * checks that each key is returned with itself as the aggregated value.
 */
@Test
public void testMemoryLimitInSpillWhenTriggerRehash()
{
    // One number is used both as the row count of each input page and as the
    // byte-sized DataSize argument handed to the operator factory below.
    final int rowsPerPage = 100000;

    RowPagesBuilder pagesBuilder = rowPagesBuilder(BIGINT);
    List<Page> input = pagesBuilder
            .addSequencePage(rowsPerPage, 0)
            .addSequencePage(rowsPerPage, rowsPerPage)
            .addSequencePage(rowsPerPage, 2 * rowsPerPage)
            .addSequencePage(rowsPerPage, 3 * rowsPerPage)
            .build();

    HashAggregationOperatorFactory operatorFactory = new HashAggregationOperatorFactory(
            0,
            new PlanNodeId("test"),
            ImmutableList.of(BIGINT),
            ImmutableList.of(0),
            ImmutableList.of(),
            ImmutableList.of(),
            Step.SINGLE,
            false,
            ImmutableList.of(generateAccumulatorFactory(LONG_SUM, ImmutableList.of(0), Optional.empty())),
            pagesBuilder.getHashChannel(),
            Optional.empty(),
            1,
            Optional.of(new DataSize(16, MEGABYTE)),
            true,
            new DataSize(rowsPerPage, Unit.BYTE),
            succinctBytes(Integer.MAX_VALUE),
            spillerFactory,
            joinCompiler,
            false);

    // Task context with explicit query memory limits.
    DataSize maxMemory = new DataSize(10, MEGABYTE);
    DataSize maxTotalMemory = new DataSize(20, MEGABYTE);
    DriverContext driverContext = createTaskContext(executor, scheduledExecutor, TEST_SESSION, maxMemory, maxTotalMemory)
            .addPipelineContext(0, true, true, false)
            .addDriverContext();

    // Every key appears exactly once in the input, so its sum equals the key itself.
    MaterializedResult.Builder expected = resultBuilder(driverContext.getSession(), BIGINT, BIGINT);
    long totalRows = 4L * rowsPerPage;
    for (long value = 0; value < totalRows; value++) {
        expected.row(value, value);
    }

    assertOperatorEqualsIgnoreOrder(operatorFactory, driverContext, input, expected.build());
}

@Test
public void testSpillerFailure()
{
Expand Down

0 comments on commit 3ab3e09

Please sign in to comment.