MSQ: Rework memory management. (apache#17057)
* MSQ: Rework memory management.

This patch reworks memory management to better support multi-threaded
workers running in shared JVMs. There are two main changes.

First, processing buffers and threads are moved from a per-JVM model to
a per-worker model. This enables queries to hold processing buffers
without blocking other concurrently-running queries. Changes:

- Introduce ProcessingBuffersSet and ProcessingBuffers to hold the
  per-worker and per-work-order processing buffers (respectively). On Peons,
  this is the JVM-wide processing pool. On Indexers, this is a per-worker
  pool of on-heap buffers. (This change fixes a bug on Indexers where
  excessive processing buffers could be used if MSQ tasks ran concurrently
  with realtime tasks.)

- Add "bufferPool" argument to GroupingEngine#process so a per-worker pool
  can be passed in.

- Add "druid.msq.task.memory.maxThreads" property, which controls the
  maximum number of processing threads to use per task. This allows usage of
  multiple processing buffers per task if admins desire.

- IndexerWorkerContext acquires processingBuffers when creating the FrameContext
  for a work order, and releases them when closing the FrameContext.

- Add "usesProcessingBuffers()" to FrameProcessorFactory so workers know
  how many sets of processing buffers are needed to run a given query.

Second, adjustments to how WorkerMemoryParameters slices up bundles, to
favor more memory for sorting and segment generation. Changes:

- Instead of using same-sized bundles for processing and for sorting,
  workers now use minimally-sized processing bundles (just enough to read
  inputs plus a little overhead). The rest is devoted to broadcast data
  buffering, sorting, and segment-building.

- Segment-building is now limited to 1 concurrent segment per work order.
  This allows each segment-building action to use more memory. Note that
  segment-building is internally multi-threaded to a degree. (Build and
  persist can run concurrently.)

- Simplify frame size calculations by removing the distinction between
  "standard" and "large" frames. The new default frame size is the same
  as the old "standard" frames, 1 MB. The original goal of the large
  frames was to reduce the number of temporary files during sorting, but
  I think we can achieve the same thing by simply merging a larger number
  of standard frames at once.

- Remove the small worker adjustment that was added in apache#14117 to account
  for an extra frame involved in writing to durable storage. Instead,
  account for the extra frame whenever we are actually using durable storage.

- Cap super-sorter parallelism using the number of output partitions, rather
  than using a hard coded cap at 4. Note that in practice, so far, this cap
  has not been relevant for tasks because they have only been using a single
  processing thread anyway.

* Remove unused import.

* Fix errorprone annotation.

* Fixes for javadocs and inspections.

* Additional test coverage.

* Fix test.
gianm authored and kfaraz committed Oct 1, 2024
1 parent 9b192bd commit 508e462
Showing 67 changed files with 1,957 additions and 831 deletions.
@@ -378,7 +378,6 @@ public String getFormatString()
final GroupingEngine groupingEngine = new GroupingEngine(
druidProcessingConfig,
configSupplier,
bufferPool,
groupByResourcesReservationPool,
TestHelper.makeJsonMapper(),
new ObjectMapper(new SmileFactory()),
@@ -387,7 +386,8 @@ public String getFormatString()

factory = new GroupByQueryRunnerFactory(
groupingEngine,
new GroupByQueryQueryToolChest(groupingEngine, groupByResourcesReservationPool)
new GroupByQueryQueryToolChest(groupingEngine, groupByResourcesReservationPool),
bufferPool
);
}

@@ -362,14 +362,13 @@ private static GroupByQueryRunnerFactory makeGroupByQueryRunnerFactory(
final GroupingEngine groupingEngine = new GroupingEngine(
processingConfig,
configSupplier,
bufferPool,
groupByResourcesReservationPool,
mapper,
mapper,
QueryRunnerTestHelper.NOOP_QUERYWATCHER
);
final GroupByQueryQueryToolChest toolChest = new GroupByQueryQueryToolChest(groupingEngine, groupByResourcesReservationPool);
return new GroupByQueryRunnerFactory(groupingEngine, toolChest);
return new GroupByQueryRunnerFactory(groupingEngine, toolChest, bufferPool);
}

@TearDown(Level.Trial)
@@ -495,7 +495,6 @@ public String getFormatString()
final GroupingEngine groupingEngine = new GroupingEngine(
druidProcessingConfig,
configSupplier,
bufferPool,
groupByResourcesReservationPool,
TestHelper.makeJsonMapper(),
new ObjectMapper(new SmileFactory()),
@@ -504,7 +503,8 @@ public String getFormatString()

factory = new GroupByQueryRunnerFactory(
groupingEngine,
new GroupByQueryQueryToolChest(groupingEngine, groupByResourcesReservationPool)
new GroupByQueryQueryToolChest(groupingEngine, groupByResourcesReservationPool),
bufferPool
);
}

@@ -99,8 +99,7 @@ public int getNumThreads()
return 1;
}
},
() -> config,
new StupidPool<>("map-virtual-column-groupby-test", () -> ByteBuffer.allocate(1024)),
GroupByQueryConfig::new,
groupByResourcesReservationPool,
TestHelper.makeJsonMapper(),
new DefaultObjectMapper(),
@@ -109,7 +108,8 @@ public int getNumThreads()

final GroupByQueryRunnerFactory factory = new GroupByQueryRunnerFactory(
groupingEngine,
new GroupByQueryQueryToolChest(groupingEngine, groupByResourcesReservationPool)
new GroupByQueryQueryToolChest(groupingEngine, groupByResourcesReservationPool),
new StupidPool<>("map-virtual-column-groupby-test", () -> ByteBuffer.allocate(1024))
);

runner = QueryRunnerTestHelper.makeQueryRunner(
@@ -19,7 +19,6 @@

package org.apache.druid.msq.exec;

import com.google.common.base.Preconditions;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
@@ -29,10 +28,10 @@
* Class for determining how much JVM heap to allocate to various purposes for {@link Controller}.
*
* First, look at how much of total JVM heap that is dedicated for MSQ; see
* {@link MemoryIntrospector#usableMemoryInJvm()}.
* {@link MemoryIntrospector#memoryPerTask()}.
*
* Then, we split up that total amount of memory into equally-sized portions per {@link Controller}; see
* {@link MemoryIntrospector#numQueriesInJvm()}. The number of controllers is based entirely on server configuration,
* {@link MemoryIntrospector#numTasksInJvm()}. The number of controllers is based entirely on server configuration,
* which makes the calculation robust to different queries running simultaneously in the same JVM.
*
* Then, we split that up into a chunk used for input channels, and a chunk used for partition statistics.
@@ -70,29 +69,28 @@ public static ControllerMemoryParameters createProductionInstance(
final int maxWorkerCount
)
{
final long usableMemoryInJvm = memoryIntrospector.usableMemoryInJvm();
final int numControllersInJvm = memoryIntrospector.numQueriesInJvm();
Preconditions.checkArgument(usableMemoryInJvm > 0, "Usable memory[%s] must be > 0", usableMemoryInJvm);
Preconditions.checkArgument(numControllersInJvm > 0, "Number of controllers[%s] must be > 0", numControllersInJvm);
Preconditions.checkArgument(maxWorkerCount > 0, "Number of workers[%s] must be > 0", maxWorkerCount);

final long memoryPerController = usableMemoryInJvm / numControllersInJvm;
final long memoryForInputChannels = WorkerMemoryParameters.memoryNeededForInputChannels(maxWorkerCount);
final long totalMemory = memoryIntrospector.memoryPerTask();
final long memoryForInputChannels =
WorkerMemoryParameters.computeProcessorMemoryForInputChannels(
maxWorkerCount,
WorkerMemoryParameters.DEFAULT_FRAME_SIZE
);
final int partitionStatisticsMaxRetainedBytes = (int) Math.min(
memoryPerController - memoryForInputChannels,
totalMemory - memoryForInputChannels,
PARTITION_STATS_MAX_MEMORY
);

if (partitionStatisticsMaxRetainedBytes < PARTITION_STATS_MIN_MEMORY) {
final long requiredMemory = memoryForInputChannels + PARTITION_STATS_MIN_MEMORY;
final long requiredTaskMemory = memoryForInputChannels + PARTITION_STATS_MIN_MEMORY;
throw new MSQException(
new NotEnoughMemoryFault(
memoryIntrospector.computeJvmMemoryRequiredForUsableMemory(requiredMemory),
memoryIntrospector.computeJvmMemoryRequiredForTaskMemory(requiredTaskMemory),
memoryIntrospector.totalMemoryInJvm(),
usableMemoryInJvm,
numControllersInJvm,
memoryIntrospector.numProcessorsInJvm(),
0
memoryIntrospector.memoryPerTask(),
memoryIntrospector.numTasksInJvm(),
memoryIntrospector.numProcessingThreads(),
maxWorkerCount,
1
)
);
}
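
The controller-side split in the hunk above can be read as plain arithmetic: take the per-task allotment, subtract what the input channels need, and give the remainder (up to a cap) to partition statistics. The sketch below is an illustration only, not the Druid implementation. The class name, the constant values, and the one-frame-per-worker formula in `memoryForInputChannels` are all assumptions made for the example; the real values live in `ControllerMemoryParameters` and `WorkerMemoryParameters` and are not shown in this diff.

```java
final class ControllerMemorySplitSketch
{
  // Assumption: the commit message says the default frame size is 1 MB; the exact byte count is a guess.
  static final long FRAME_SIZE = 1_000_000L;
  // Hypothetical bounds, chosen only to make the example concrete.
  static final long PARTITION_STATS_MAX_MEMORY = 300_000_000L;
  static final long PARTITION_STATS_MIN_MEMORY = 25_000_000L;

  // Assumption: one frame of buffering per worker. The real formula is not shown in the diff.
  static long memoryForInputChannels(final int maxWorkerCount)
  {
    return (long) maxWorkerCount * FRAME_SIZE;
  }

  // Mirrors the shape of the new code: remainder after input channels, capped at the maximum,
  // rejected if it falls below the minimum.
  static int partitionStatsBytes(final long memoryPerTask, final int maxWorkerCount)
  {
    final long inputChannels = memoryForInputChannels(maxWorkerCount);
    final int stats = (int) Math.min(memoryPerTask - inputChannels, PARTITION_STATS_MAX_MEMORY);

    if (stats < PARTITION_STATS_MIN_MEMORY) {
      // The real code throws MSQException with a NotEnoughMemoryFault; a plain exception stands in here.
      final long requiredTaskMemory = inputChannels + PARTITION_STATS_MIN_MEMORY;
      throw new IllegalStateException("Not enough memory; task needs at least " + requiredTaskMemory + " bytes");
    }

    return stats;
  }

  public static void main(String[] args)
  {
    System.out.println(partitionStatsBytes(1_000_000_000L, 10));
    System.out.println(partitionStatsBytes(100_000_000L, 10));
  }
}
```

Because the input to this calculation is now a per-task figure, the controller no longer divides usable JVM memory by the number of controllers itself; that division has moved into `memoryPerTask()`.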
@@ -24,7 +24,7 @@ public class Limits
/**
* Maximum number of columns that can appear in a frame signature.
* <p>
* Somewhat less than {@link WorkerMemoryParameters#STANDARD_FRAME_SIZE} divided by typical minimum column size:
* Somewhat less than {@link WorkerMemoryParameters#DEFAULT_FRAME_SIZE} divided by typical minimum column size:
* {@link org.apache.druid.frame.allocation.AppendableMemory#DEFAULT_INITIAL_ALLOCATION_SIZE}.
*/
public static final int MAX_FRAME_COLUMNS = 2000;
@@ -19,10 +19,8 @@

package org.apache.druid.msq.exec;

import org.apache.druid.msq.kernel.WorkOrder;

/**
* Introspector used to generate {@link ControllerMemoryParameters}.
* Introspector used to generate {@link WorkerMemoryParameters} and {@link ControllerMemoryParameters}.
*/
public interface MemoryIntrospector
{
@@ -32,34 +30,23 @@ public interface MemoryIntrospector
long totalMemoryInJvm();

/**
* Amount of memory usable for the multi-stage query engine in the entire JVM.
*
* This may be an expensive operation. For example, the production implementation {@link MemoryIntrospectorImpl}
* estimates size of all lookups as part of computing this value.
* Amount of memory allotted to each {@link Worker} or {@link Controller}.
*/
long usableMemoryInJvm();
long memoryPerTask();

/**
* Amount of total JVM memory required for a particular amount of usable memory to be available.
*
* This may be an expensive operation. For example, the production implementation {@link MemoryIntrospectorImpl}
* estimates size of all lookups as part of computing this value.
* Computes the amount of total JVM memory that would be required for a particular memory allotment per task, i.e.,
* a particular return value from {@link #memoryPerTask()}.
*/
long computeJvmMemoryRequiredForUsableMemory(long usableMemory);
long computeJvmMemoryRequiredForTaskMemory(long memoryPerTask);

/**
* Maximum number of queries that run simultaneously in this JVM.
*
* On workers, this is the maximum number of {@link Worker} that run simultaneously in this JVM. See
* {@link WorkerMemoryParameters} for how memory is divided among and within {@link WorkOrder} run by a worker.
*
* On controllers, this is the maximum number of {@link Controller} that run simultaneously. See
* {@link ControllerMemoryParameters} for how memory is used by controllers.
* Maximum number of tasks ({@link Worker} or {@link Controller}) that run simultaneously in this JVM.
*/
int numQueriesInJvm();
int numTasksInJvm();

/**
* Maximum number of processing threads that can be used at once in this JVM.
* Maximum number of processing threads that can be used at once by each {@link Worker} or {@link Controller}.
*/
int numProcessorsInJvm();
int numProcessingThreads();
}
@@ -20,12 +20,14 @@
package org.apache.druid.msq.exec;

import com.google.common.collect.ImmutableList;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.DruidProcessingConfig;
import org.apache.druid.query.lookup.LookupExtractor;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainer;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider;

import javax.annotation.Nullable;
import java.util.List;

/**
@@ -34,37 +36,47 @@
public class MemoryIntrospectorImpl implements MemoryIntrospector
{
private static final Logger log = new Logger(MemoryIntrospectorImpl.class);
private static final long LOOKUP_FOOTPRINT_INIT = Long.MIN_VALUE;

private final LookupExtractorFactoryContainerProvider lookupProvider;
private final long totalMemoryInJvm;
private final int numQueriesInJvm;
private final int numProcessorsInJvm;
private final double usableMemoryFraction;
private final int numTasksInJvm;
private final int numProcessingThreads;

/**
* Lookup footprint per task, set the first time {@link #memoryPerTask()} is called.
*/
private volatile long lookupFootprint = LOOKUP_FOOTPRINT_INIT;

@Nullable
private final LookupExtractorFactoryContainerProvider lookupProvider;

/**
* Create an introspector.
*
* @param lookupProvider provider of lookups; we use this to subtract lookup size from total JVM memory when
* computing usable memory
* @param totalMemoryInJvm maximum JVM heap memory
* @param usableMemoryFraction fraction of JVM memory, after subtracting lookup overhead, that we consider usable
* for multi-stage queries
* @param numQueriesInJvm maximum number of {@link Controller} or {@link Worker} that may run concurrently
* @param numProcessorsInJvm size of processing thread pool, typically {@link DruidProcessingConfig#getNumThreads()}
* for {@link Controller} or {@link Worker}
* @param numTasksInJvm maximum number of {@link Controller} or {@link Worker} that may run concurrently
* @param numProcessingThreads size of processing thread pool, typically {@link DruidProcessingConfig#getNumThreads()}
* @param lookupProvider provider of lookups; we use this to subtract lookup size from total JVM memory when
* computing usable memory. Ignored if null. This is used once the first time
* {@link #memoryPerTask()} is called, then the footprint is cached. As such, it provides
* a point-in-time view only.
*/
public MemoryIntrospectorImpl(
final LookupExtractorFactoryContainerProvider lookupProvider,
final long totalMemoryInJvm,
final double usableMemoryFraction,
final int numQueriesInJvm,
final int numProcessorsInJvm
final int numTasksInJvm,
final int numProcessingThreads,
@Nullable final LookupExtractorFactoryContainerProvider lookupProvider
)
{
this.lookupProvider = lookupProvider;
this.totalMemoryInJvm = totalMemoryInJvm;
this.numQueriesInJvm = numQueriesInJvm;
this.numProcessorsInJvm = numProcessorsInJvm;
this.usableMemoryFraction = usableMemoryFraction;
this.numTasksInJvm = numTasksInJvm;
this.numProcessingThreads = numProcessingThreads;
this.lookupProvider = lookupProvider;
}

@Override
@@ -74,45 +86,66 @@ public long totalMemoryInJvm()
}

@Override
public long usableMemoryInJvm()
public long memoryPerTask()
{
final long totalMemory = totalMemoryInJvm();
final long totalLookupFootprint = computeTotalLookupFootprint(true);
return Math.max(
0,
(long) ((totalMemory - totalLookupFootprint) * usableMemoryFraction)
(long) ((totalMemoryInJvm - getTotalLookupFootprint()) * usableMemoryFraction) / numTasksInJvm
);
}

@Override
public long computeJvmMemoryRequiredForUsableMemory(long usableMemory)
public long computeJvmMemoryRequiredForTaskMemory(long memoryPerTask)
{
final long totalLookupFootprint = computeTotalLookupFootprint(false);
return (long) Math.ceil(usableMemory / usableMemoryFraction + totalLookupFootprint);
if (memoryPerTask <= 0) {
throw new IAE("Invalid memoryPerTask[%d], expected a positive number", memoryPerTask);
}

return (long) Math.ceil(memoryPerTask * numTasksInJvm / usableMemoryFraction) + getTotalLookupFootprint();
}

@Override
public int numQueriesInJvm()
public int numTasksInJvm()
{
return numQueriesInJvm;
return numTasksInJvm;
}

@Override
public int numProcessorsInJvm()
public int numProcessingThreads()
{
return numProcessorsInJvm;
return numProcessingThreads;
}

/**
* Get a possibly-cached value of {@link #computeTotalLookupFootprint()}. The underlying computation method is
* called just once, meaning this is not a good way to track the size of lookups over time. This is done to keep
* memory calculations as consistent as possible.
*/
private long getTotalLookupFootprint()
{
if (lookupFootprint == LOOKUP_FOOTPRINT_INIT) {
synchronized (this) {
if (lookupFootprint == LOOKUP_FOOTPRINT_INIT) {
lookupFootprint = computeTotalLookupFootprint();
}
}
}

return lookupFootprint;
}

/**
* Compute and return total estimated lookup footprint.
*
* Correctness of this approach depends on lookups being loaded *before* calling this method. Luckily, this is the
* typical mode of operation, since by default druid.lookup.enableLookupSyncOnStartup = true.
*
* @param logFootprint whether footprint should be logged
*/
private long computeTotalLookupFootprint(final boolean logFootprint)
private long computeTotalLookupFootprint()
{
if (lookupProvider == null) {
return 0;
}

final List<String> lookupNames = ImmutableList.copyOf(lookupProvider.getAllLookupNames());

long lookupFootprint = 0;
@@ -131,10 +164,7 @@ private long computeTotalLookupFootprint(final boolean logFootprint)
}
}

if (logFootprint) {
log.info("Lookup footprint: lookup count[%d], total bytes[%,d].", lookupNames.size(), lookupFootprint);
}

log.info("Lookup footprint: lookup count[%d], total bytes[%,d].", lookupNames.size(), lookupFootprint);
return lookupFootprint;
}
}
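
To make the new per-task arithmetic easier to follow, here is a minimal standalone sketch of the two formulas in `MemoryIntrospectorImpl` together with the compute-once caching of the lookup footprint. It is not the Druid class. `MemoryPerTaskSketch` and the `LongSupplier` that stands in for the lookup provider are hypothetical names introduced for this example, and the numbers in `main` are arbitrary.

```java
import java.util.function.LongSupplier;

final class MemoryPerTaskSketch
{
  private static final long FOOTPRINT_UNSET = Long.MIN_VALUE;

  private final long totalMemoryInJvm;
  private final double usableMemoryFraction;
  private final int numTasksInJvm;
  // Hypothetical stand-in for the lookup provider: returns the estimated lookup size in bytes.
  private final LongSupplier footprintSupplier;

  // volatile so the double-checked initialization below is safe across threads.
  private volatile long lookupFootprint = FOOTPRINT_UNSET;

  MemoryPerTaskSketch(
      final long totalMemoryInJvm,
      final double usableMemoryFraction,
      final int numTasksInJvm,
      final LongSupplier footprintSupplier
  )
  {
    this.totalMemoryInJvm = totalMemoryInJvm;
    this.usableMemoryFraction = usableMemoryFraction;
    this.numTasksInJvm = numTasksInJvm;
    this.footprintSupplier = footprintSupplier;
  }

  // (total heap - lookups) * usable fraction, split evenly across tasks, never negative.
  long memoryPerTask()
  {
    return Math.max(
        0,
        (long) ((totalMemoryInJvm - getLookupFootprint()) * usableMemoryFraction) / numTasksInJvm
    );
  }

  // Inverse of memoryPerTask(): the JVM heap that would be needed for a given per-task allotment.
  long computeJvmMemoryRequiredForTaskMemory(final long memoryPerTask)
  {
    if (memoryPerTask <= 0) {
      throw new IllegalArgumentException("Invalid memoryPerTask[" + memoryPerTask + "], expected a positive number");
    }
    return (long) Math.ceil(memoryPerTask * numTasksInJvm / usableMemoryFraction) + getLookupFootprint();
  }

  // Computes the footprint on first use and then reuses it, so later calls see a point-in-time value.
  private long getLookupFootprint()
  {
    if (lookupFootprint == FOOTPRINT_UNSET) {
      synchronized (this) {
        if (lookupFootprint == FOOTPRINT_UNSET) {
          lookupFootprint = footprintSupplier.getAsLong();
        }
      }
    }
    return lookupFootprint;
  }

  public static void main(String[] args)
  {
    final MemoryPerTaskSketch sketch =
        new MemoryPerTaskSketch(1_000_000_000L, 0.75, 4, () -> 100_000_000L);
    final long perTask = sketch.memoryPerTask();
    System.out.println("memoryPerTask = " + perTask);
    System.out.println("JVM memory required = " + sketch.computeJvmMemoryRequiredForTaskMemory(perTask));
  }
}
```

Caching the footprint means both formulas use the same lookup size, which keeps the figure reported in a NotEnoughMemoryFault consistent with the allotment that triggered it. The trade-off, as the javadoc in the diff notes, is that lookups loaded after the first call are not reflected.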