Skip to content

Commit

Permalink
[CELEBORN-1792] MemoryManager resume should use pinnedDirectMemory in…
Browse files Browse the repository at this point in the history
…stead of usedDirectMemory

### What changes were proposed in this pull request?
Congestion and MemoryManager should use pinnedDirectMemory instead of usedDirectMemory

### Why are the changes needed?
In our production environment, after worker pausing, the usedDirectMemory keep high and does not decrease. The worker node is permanently blacklisted and cannot be used.

This problem has been bothering us for a long time. When the thred cache is turned off, in fact, **after ctx.channel().config().setAutoRead(false), the netty framework will still hold some ByteBufs**. This part of ByteBuf result in a lot of PoolChunks cannot be released.

In netty, if a chunk is 16M and 8k of this chunk has been allocated, then the pinnedMemory is 8k and the activeMemory is 16M. The remaining (16M-8k) memory can be allocated, but not yet allocated, netty allocates and releases memory in chunk units, so the 8k that has been allocated will result in 16M that cannot be returned to the operating system.

Here are some scenes from our production/test environment:

We config 10gb off-heap memory for worker, other configs as below:
```
celeborn.network.memory.allocator.allowCache                         false
celeborn.worker.monitor.memory.check.interval                         100ms
celeborn.worker.monitor.memory.report.interval                        10s
celeborn.worker.directMemoryRatioToPauseReceive                       0.75
celeborn.worker.directMemoryRatioToPauseReplicate                     0.85
celeborn.worker.directMemoryRatioToResume                             0.5
```

When receiving high traffic, the worker's usedDirectMemory increases. After triggering trim and pause, usedDirectMemory still does not reach the resume threshold, and worker was excluded.

![image](https://github.com/user-attachments/assets/40f5609e-fbf9-4841-84ec-69a69256edf4)

So we checked the heap snapshot of the abnormal worker, we can see that there are a large number of DirectByteBuffers in the heap memory. These DirectByteBuffers are all 4mb in size, which is exactly the size of chunksize. According to the path to gc root, DirectByteBuffer is held by PoolChunk, and these 4m only have 160k pinnedBytes.

![image](https://github.com/user-attachments/assets/3d755ef3-164c-4b5b-bec1-aaf039c0c0a5)

![image](https://github.com/user-attachments/assets/17907753-2f42-4617-a95e-1ee980553fb9)

There are many ByteBufs that are not released

![image](https://github.com/user-attachments/assets/b87eb1a9-313f-4f42-baa8-227fd49c19b6)

The stack shows that these ByteBufs are allocated by netty
![image](https://github.com/user-attachments/assets/f8783f99-507a-44a8-9de5-7215a5eed1db)

We tried to reproduce this situation in the test environment. When the same problem occurred, we added a restful api of the worker to force the worker to resume. After the resume, the worker returned to normal, and PushDataHandler handled many delayed requests.

![image](https://github.com/user-attachments/assets/be37039b-97b8-4ae8-a64f-a2003bea613e)

![image](https://github.com/user-attachments/assets/24b1c8ad-131c-4bd6-adcb-bad655cfbdbf)

So I think that when pinnedMemory is not high enough, we should not trigger pause and congestion, because at this time a large part of the memory can still be allocated.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing UTs.

Closes #3018 from leixm/CELEBORN-1792.

Authored-by: Xianming Lei <[email protected]>
Signed-off-by: Shuang <[email protected]>
  • Loading branch information
leixm authored and RexXiong committed Jan 22, 2025
1 parent 30e46ee commit 9131c1e
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.celeborn.common.network.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
Expand Down Expand Up @@ -47,6 +49,8 @@ public class NettyUtils {
private static final ByteBufAllocator[] _sharedByteBufAllocator = new ByteBufAllocator[2];
private static final ConcurrentHashMap<String, Integer> allocatorsIndex =
JavaUtils.newConcurrentHashMap();
private static final List<PooledByteBufAllocator> pooledByteBufAllocators = new ArrayList<>();

/** Creates a new ThreadFactory which prefixes each thread with the given name. */
public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
return new DefaultThreadFactory(threadPoolPrefix, true);
Expand Down Expand Up @@ -141,6 +145,9 @@ public static synchronized ByteBufAllocator getSharedByteBufAllocator(
_sharedByteBufAllocator[index] =
createByteBufAllocator(
conf.networkMemoryAllocatorPooled(), true, allowCache, conf.networkAllocatorArenas());
if (conf.networkMemoryAllocatorPooled()) {
pooledByteBufAllocators.add((PooledByteBufAllocator) _sharedByteBufAllocator[index]);
}
if (source != null) {
new NettyMemoryMetrics(
_sharedByteBufAllocator[index],
Expand Down Expand Up @@ -178,6 +185,9 @@ public static ByteBufAllocator getByteBufAllocator(
conf.preferDirectBufs(),
allowCache,
arenas);
if (conf.getCelebornConf().networkMemoryAllocatorPooled()) {
pooledByteBufAllocators.add((PooledByteBufAllocator) allocator);
}
if (source != null) {
String poolName = "default-netty-pool";
Map<String, String> labels = new HashMap<>();
Expand All @@ -196,4 +206,8 @@ public static ByteBufAllocator getByteBufAllocator(
}
return allocator;
}

public static List<PooledByteBufAllocator> getAllPooledByteBufAllocators() {
return pooledByteBufAllocators;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1278,9 +1278,12 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def workerDirectMemoryRatioToPauseReplicate: Double =
get(WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE)
def workerDirectMemoryRatioToResume: Double = get(WORKER_DIRECT_MEMORY_RATIO_RESUME)
def workerPinnedMemoryRatioToResume: Double = get(WORKER_PINNED_MEMORY_RATIO_RESUME)
def workerPartitionSorterDirectMemoryRatioThreshold: Double =
get(WORKER_PARTITION_SORTER_DIRECT_MEMORY_RATIO_THRESHOLD)
def workerDirectMemoryPressureCheckIntervalMs: Long = get(WORKER_DIRECT_MEMORY_CHECK_INTERVAL)
def workerPinnedMemoryCheckEnabled: Boolean = get(WORKER_PINNED_MEMORY_CHECK_ENABLED)
def workerPinnedMemoryCheckIntervalMs: Long = get(WORKER_PINNED_MEMORY_CHECK_INTERVAL)
def workerDirectMemoryReportIntervalSecond: Long = get(WORKER_DIRECT_MEMORY_REPORT_INTERVAL)
def workerDirectMemoryTrimChannelWaitInterval: Long =
get(WORKER_DIRECT_MEMORY_TRIM_CHANNEL_WAIT_INTERVAL)
Expand Down Expand Up @@ -3711,6 +3714,24 @@ object CelebornConf extends Logging {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10ms")

val WORKER_PINNED_MEMORY_CHECK_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.worker.monitor.pinnedMemory.check.enabled")
.categories("worker")
.doc("If true, MemoryManager will check worker should resume by pinned memory used.")
.version("0.6.0")
.booleanConf
.createWithDefaultString("true")

val WORKER_PINNED_MEMORY_CHECK_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.worker.monitor.pinnedMemory.check.interval")
.categories("worker")
.doc("Interval of worker direct pinned memory checking, " +
"only takes effect when celeborn.network.memory.allocator.pooled and " +
"celeborn.worker.monitor.pinnedMemory.check.enabled are enabled.")
.version("0.6.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10s")

val WORKER_DIRECT_MEMORY_REPORT_INTERVAL: ConfigEntry[Long] =
buildConf("celeborn.worker.monitor.memory.report.interval")
.withAlternative("celeborn.worker.memory.reportInterval")
Expand Down Expand Up @@ -3860,6 +3881,16 @@ object CelebornConf extends Logging {
.doubleConf
.createWithDefault(0.7)

val WORKER_PINNED_MEMORY_RATIO_RESUME: ConfigEntry[Double] =
buildConf("celeborn.worker.pinnedMemoryRatioToResume")
.categories("worker")
.doc("If pinned memory usage is less than this limit, worker will resume, " +
"only takes effect when celeborn.network.memory.allocator.pooled and " +
"celeborn.worker.monitor.pinnedMemory.check.enabled are enabled")
.version("0.6.0")
.doubleConf
.createWithDefault(0.3)

val WORKER_MEMORY_FILE_STORAGE_MAX_FILE_SIZE: ConfigEntry[Long] =
buildConf("celeborn.worker.memoryFileStorage.maxFileSize")
.categories("worker")
Expand Down
3 changes: 3 additions & 0 deletions docs/configuration/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,12 @@ license: |
| celeborn.worker.monitor.memory.report.interval | 10s | false | Interval of worker direct memory tracker reporting to log. | 0.3.0 | celeborn.worker.memory.reportInterval |
| celeborn.worker.monitor.memory.trimChannelWaitInterval | 1s | false | Wait time after worker trigger channel to trim cache. | 0.3.0 | |
| celeborn.worker.monitor.memory.trimFlushWaitInterval | 1s | false | Wait time after worker trigger StorageManger to flush data. | 0.3.0 | |
| celeborn.worker.monitor.pinnedMemory.check.enabled | true | false | If true, MemoryManager will check worker should resume by pinned memory used. | 0.6.0 | |
| celeborn.worker.monitor.pinnedMemory.check.interval | 10s | false | Interval of worker direct pinned memory checking, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled. | 0.6.0 | |
| celeborn.worker.partition.initial.readBuffersMax | 1024 | false | Max number of initial read buffers | 0.3.0 | |
| celeborn.worker.partition.initial.readBuffersMin | 1 | false | Min number of initial read buffers | 0.3.0 | |
| celeborn.worker.partitionSorter.directMemoryRatioThreshold | 0.1 | false | Max ratio of partition sorter's memory for sorting, when reserved memory is higher than max partition sorter memory, partition sorter will stop sorting. If this value is set to 0, partition files sorter will skip memory check and ServingState check. | 0.2.0 | |
| celeborn.worker.pinnedMemoryRatioToResume | 0.3 | false | If pinned memory usage is less than this limit, worker will resume, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled | 0.6.0 | |
| celeborn.worker.push.heartbeat.enabled | false | false | enable the heartbeat from worker to client when pushing data | 0.3.0 | |
| celeborn.worker.push.io.threads | &lt;undefined&gt; | false | Netty IO thread number of worker to handle client push data. The default threads number is the number of flush thread. | 0.2.0 | |
| celeborn.worker.push.port | 0 | false | Server port for Worker to receive push data request from ShuffleClient. | 0.2.0 | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.util.internal.PlatformDependent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.metrics.source.AbstractSource;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.TransportModuleConstants;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
Expand All @@ -50,7 +52,8 @@ public class MemoryManager {
@VisibleForTesting public long maxDirectMemory;
private final long pausePushDataThreshold;
private final long pauseReplicateThreshold;
private final double resumeRatio;
private final double directMemoryResumeRatio;
private final double pinnedMemoryResumeRatio;
private final long maxSortMemory;
private final int forceAppendPauseSpentTimeThreshold;
private final List<MemoryPressureListener> memoryPressureListeners = new ArrayList<>();
Expand Down Expand Up @@ -93,6 +96,9 @@ public class MemoryManager {
private long memoryFileStorageThreshold;
private final LongAdder memoryFileStorageCounter = new LongAdder();
private final StorageManager storageManager;
private boolean pinnedMemoryCheckEnabled;
private long pinnedMemoryCheckInterval;
private long pinnedMemoryLastCheckTime = 0;

@VisibleForTesting
public static MemoryManager initialize(CelebornConf conf) {
Expand Down Expand Up @@ -120,11 +126,14 @@ public static MemoryManager instance() {
private MemoryManager(CelebornConf conf, StorageManager storageManager, AbstractSource source) {
double pausePushDataRatio = conf.workerDirectMemoryRatioToPauseReceive();
double pauseReplicateRatio = conf.workerDirectMemoryRatioToPauseReplicate();
this.resumeRatio = conf.workerDirectMemoryRatioToResume();
this.directMemoryResumeRatio = conf.workerDirectMemoryRatioToResume();
this.pinnedMemoryResumeRatio = conf.workerPinnedMemoryRatioToResume();
double maxSortMemRatio = conf.workerPartitionSorterDirectMemoryRatioThreshold();
double readBufferRatio = conf.workerDirectMemoryRatioForReadBuffer();
double memoryFileStorageRatio = conf.workerDirectMemoryRatioForMemoryFilesStorage();
long checkInterval = conf.workerDirectMemoryPressureCheckIntervalMs();
this.pinnedMemoryCheckEnabled = conf.workerPinnedMemoryCheckEnabled();
this.pinnedMemoryCheckInterval = conf.workerPinnedMemoryCheckIntervalMs();
long reportInterval = conf.workerDirectMemoryReportIntervalSecond();
double readBufferTargetRatio = conf.readBufferTargetRatio();
long readBufferTargetUpdateInterval = conf.readBufferTargetUpdateInterval();
Expand All @@ -148,9 +157,10 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract
pauseReplicateRatio,
CelebornConf.WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE().key(),
pausePushDataRatio));
Preconditions.checkArgument(pausePushDataRatio > resumeRatio);
Preconditions.checkArgument(pausePushDataRatio > directMemoryResumeRatio);
if (memoryFileStorageRatio > 0) {
Preconditions.checkArgument(resumeRatio > (readBufferRatio + memoryFileStorageRatio));
Preconditions.checkArgument(
directMemoryResumeRatio > (readBufferRatio + memoryFileStorageRatio));
}

maxSortMemory = ((long) (maxDirectMemory * maxSortMemRatio));
Expand Down Expand Up @@ -275,14 +285,16 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract
+ "pause replication memory: {}, "
+ "read buffer memory limit: {} target: {}, "
+ "memory shuffle storage limit: {}, "
+ "resume memory ratio: {}",
+ "resume by direct memory ratio: {}, "
+ "resume by pinned memory ratio: {}",
Utils.bytesToString(maxDirectMemory),
Utils.bytesToString(pausePushDataThreshold),
Utils.bytesToString(pauseReplicateThreshold),
Utils.bytesToString(readBufferThreshold),
Utils.bytesToString(readBufferTarget),
Utils.bytesToString(memoryFileStorageThreshold),
resumeRatio);
directMemoryResumeRatio,
pinnedMemoryResumeRatio);
}

public boolean shouldEvict(boolean aggressiveMemoryFileEvictEnabled, double evictRatio) {
Expand All @@ -305,7 +317,7 @@ public ServingState currentServingState() {
return ServingState.PUSH_PAUSED;
}
// trigger resume
if (memoryUsage / (double) (maxDirectMemory) < resumeRatio) {
if (memoryUsage / (double) (maxDirectMemory) < directMemoryResumeRatio) {
isPaused = false;
return ServingState.NONE_PAUSED;
}
Expand All @@ -315,69 +327,70 @@ public ServingState currentServingState() {
}

@VisibleForTesting
protected void switchServingState() {
public void switchServingState() {
ServingState lastState = servingState;
servingState = currentServingState();
if (lastState == servingState) {
if (servingState != ServingState.NONE_PAUSED) {
logger.info("Serving state transformed from {} to {}", lastState, servingState);
switch (servingState) {
case PUSH_PAUSED:
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
resumeReplicate();
} else {
logger.info("Trigger action: PAUSE PUSH");
pausePushDataStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
}
}
logger.debug("Trigger action: TRIM");
trimCounter += 1;
// force to append pause spent time even we are in pause state
trimAllListeners();
if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
logger.debug(
"Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter);
appendPauseSpentTime(servingState);
}
trimAllListeners();
}
return;
}
logger.info("Serving state transformed from {} to {}", lastState, servingState);
switch (servingState) {
case PUSH_PAUSED:
pausePushDataCounter.increment();
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
} else if (lastState == ServingState.NONE_PAUSED) {
logger.info("Trigger action: PAUSE PUSH");
pausePushDataStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
}
trimAllListeners();
break;
case PUSH_AND_REPLICATE_PAUSED:
pausePushDataAndReplicateCounter.increment();
if (lastState == ServingState.NONE_PAUSED) {
if (canResumeByPinnedMemory()) {
resumeByPinnedMemory(servingState);
} else {
pausePushDataAndReplicateCounter.increment();
logger.info("Trigger action: PAUSE PUSH");
pausePushDataAndReplicateStartTime = System.currentTimeMillis();
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE));
logger.info("Trigger action: PAUSE REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
}
logger.info("Trigger action: PAUSE REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE));
logger.debug("Trigger action: TRIM");
trimCounter += 1;
trimAllListeners();
if (trimCounter >= forceAppendPauseSpentTimeThreshold) {
logger.debug(
"Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter);
appendPauseSpentTime(servingState);
}
break;
case NONE_PAUSED:
// resume from paused mode, append pause spent time
appendPauseSpentTime(lastState);
if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
resumeReplicate();
resumePush();
appendPauseSpentTime(lastState);
} else if (lastState == ServingState.PUSH_PAUSED) {
resumePush();
appendPauseSpentTime(lastState);
}
logger.info("Trigger action: RESUME PUSH");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
}
}

Expand Down Expand Up @@ -436,6 +449,16 @@ public long getMemoryUsage() {
return getNettyUsedDirectMemory() + sortMemoryCounter.get();
}

public long getPinnedMemory() {
return getNettyPinnedDirectMemory() + sortMemoryCounter.get();
}

public long getNettyPinnedDirectMemory() {
return NettyUtils.getAllPooledByteBufAllocators().stream()
.mapToLong(PooledByteBufAllocator::pinnedDirectMemory)
.sum();
}

public AtomicLong getSortMemoryCounter() {
return sortMemoryCounter;
}
Expand Down Expand Up @@ -557,6 +580,47 @@ public static void reset() {
_INSTANCE = null;
}

private void resumeByPinnedMemory(ServingState servingState) {
switch (servingState) {
case PUSH_AND_REPLICATE_PAUSED:
logger.info(
"Serving State is PUSH_AND_REPLICATE_PAUSED, but resume by lower pinned memory {}",
getNettyPinnedDirectMemory());
resumeReplicate();
resumePush();
case PUSH_PAUSED:
logger.info(
"Serving State is PUSH_PAUSED, but resume by lower pinned memory {}",
getNettyPinnedDirectMemory());
resumePush();
}
}

private boolean canResumeByPinnedMemory() {
if (pinnedMemoryCheckEnabled
&& System.currentTimeMillis() - pinnedMemoryLastCheckTime >= pinnedMemoryCheckInterval
&& getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) {
pinnedMemoryLastCheckTime = System.currentTimeMillis();
return true;
} else {
return false;
}
}

private void resumePush() {
logger.info("Trigger action: RESUME PUSH");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE));
}

private void resumeReplicate() {
logger.info("Trigger action: RESUME REPLICATE");
memoryPressureListeners.forEach(
memoryPressureListener ->
memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE));
}

public interface MemoryPressureListener {
void onPause(String moduleName);

Expand Down
Loading

0 comments on commit 9131c1e

Please sign in to comment.