From 9131c1e07a321e4b54685c166b40ee0ecc762e68 Mon Sep 17 00:00:00 2001 From: Xianming Lei <31424839+leixm@users.noreply.github.com> Date: Wed, 22 Jan 2025 14:30:20 +0800 Subject: [PATCH] [CELEBORN-1792] MemoryManager resume should use pinnedDirectMemory instead of usedDirectMemory ### What changes were proposed in this pull request? Congestion and MemoryManager should use pinnedDirectMemory instead of usedDirectMemory ### Why are the changes needed? In our production environment, after worker pausing, the usedDirectMemory keep high and does not decrease. The worker node is permanently blacklisted and cannot be used. This problem has been bothering us for a long time. When the thred cache is turned off, in fact, **after ctx.channel().config().setAutoRead(false), the netty framework will still hold some ByteBufs**. This part of ByteBuf result in a lot of PoolChunks cannot be released. In netty, if a chunk is 16M and 8k of this chunk has been allocated, then the pinnedMemory is 8k and the activeMemory is 16M. The remaining (16M-8k) memory can be allocated, but not yet allocated, netty allocates and releases memory in chunk units, so the 8k that has been allocated will result in 16M that cannot be returned to the operating system. Here are some scenes from our production/test environment: We config 10gb off-heap memory for worker, other configs as below: ``` celeborn.network.memory.allocator.allowCache false celeborn.worker.monitor.memory.check.interval 100ms celeborn.worker.monitor.memory.report.interval 10s celeborn.worker.directMemoryRatioToPauseReceive 0.75 celeborn.worker.directMemoryRatioToPauseReplicate 0.85 celeborn.worker.directMemoryRatioToResume 0.5 ``` When receiving high traffic, the worker's usedDirectMemory increases. After triggering trim and pause, usedDirectMemory still does not reach the resume threshold, and worker was excluded. ![image](https://github.com/user-attachments/assets/40f5609e-fbf9-4841-84ec-69a69256edf4) So we checked the heap snapshot of the abnormal worker, we can see that there are a large number of DirectByteBuffers in the heap memory. These DirectByteBuffers are all 4mb in size, which is exactly the size of chunksize. According to the path to gc root, DirectByteBuffer is held by PoolChunk, and these 4m only have 160k pinnedBytes. ![image](https://github.com/user-attachments/assets/3d755ef3-164c-4b5b-bec1-aaf039c0c0a5) ![image](https://github.com/user-attachments/assets/17907753-2f42-4617-a95e-1ee980553fb9) There are many ByteBufs that are not released ![image](https://github.com/user-attachments/assets/b87eb1a9-313f-4f42-baa8-227fd49c19b6) The stack shows that these ByteBufs are allocated by netty ![image](https://github.com/user-attachments/assets/f8783f99-507a-44a8-9de5-7215a5eed1db) We tried to reproduce this situation in the test environment. When the same problem occurred, we added a restful api of the worker to force the worker to resume. After the resume, the worker returned to normal, and PushDataHandler handled many delayed requests. ![image](https://github.com/user-attachments/assets/be37039b-97b8-4ae8-a64f-a2003bea613e) ![image](https://github.com/user-attachments/assets/24b1c8ad-131c-4bd6-adcb-bad655cfbdbf) So I think that when pinnedMemory is not high enough, we should not trigger pause and congestion, because at this time a large part of the memory can still be allocated. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing UTs. Closes #3018 from leixm/CELEBORN-1792. Authored-by: Xianming Lei <31424839+leixm@users.noreply.github.com> Signed-off-by: Shuang --- .../common/network/util/NettyUtils.java | 14 ++ .../apache/celeborn/common/CelebornConf.scala | 31 ++++ docs/configuration/worker.md | 3 + .../deploy/worker/memory/MemoryManager.java | 158 ++++++++++++------ .../deploy/memory/MemoryManagerSuite.scala | 67 +++++++- 5 files changed, 224 insertions(+), 49 deletions(-) diff --git a/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java b/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java index 596b8078520..55a01891492 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java +++ b/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java @@ -17,8 +17,10 @@ package org.apache.celeborn.common.network.util; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ThreadFactory; @@ -47,6 +49,8 @@ public class NettyUtils { private static final ByteBufAllocator[] _sharedByteBufAllocator = new ByteBufAllocator[2]; private static final ConcurrentHashMap allocatorsIndex = JavaUtils.newConcurrentHashMap(); + private static final List pooledByteBufAllocators = new ArrayList<>(); + /** Creates a new ThreadFactory which prefixes each thread with the given name. */ public static ThreadFactory createThreadFactory(String threadPoolPrefix) { return new DefaultThreadFactory(threadPoolPrefix, true); @@ -141,6 +145,9 @@ public static synchronized ByteBufAllocator getSharedByteBufAllocator( _sharedByteBufAllocator[index] = createByteBufAllocator( conf.networkMemoryAllocatorPooled(), true, allowCache, conf.networkAllocatorArenas()); + if (conf.networkMemoryAllocatorPooled()) { + pooledByteBufAllocators.add((PooledByteBufAllocator) _sharedByteBufAllocator[index]); + } if (source != null) { new NettyMemoryMetrics( _sharedByteBufAllocator[index], @@ -178,6 +185,9 @@ public static ByteBufAllocator getByteBufAllocator( conf.preferDirectBufs(), allowCache, arenas); + if (conf.getCelebornConf().networkMemoryAllocatorPooled()) { + pooledByteBufAllocators.add((PooledByteBufAllocator) allocator); + } if (source != null) { String poolName = "default-netty-pool"; Map labels = new HashMap<>(); @@ -196,4 +206,8 @@ public static ByteBufAllocator getByteBufAllocator( } return allocator; } + + public static List getAllPooledByteBufAllocators() { + return pooledByteBufAllocators; + } } diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 791c6fc2fc4..f18cc272fb3 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1278,9 +1278,12 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def workerDirectMemoryRatioToPauseReplicate: Double = get(WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE) def workerDirectMemoryRatioToResume: Double = get(WORKER_DIRECT_MEMORY_RATIO_RESUME) + def workerPinnedMemoryRatioToResume: Double = get(WORKER_PINNED_MEMORY_RATIO_RESUME) def workerPartitionSorterDirectMemoryRatioThreshold: Double = get(WORKER_PARTITION_SORTER_DIRECT_MEMORY_RATIO_THRESHOLD) def workerDirectMemoryPressureCheckIntervalMs: Long = get(WORKER_DIRECT_MEMORY_CHECK_INTERVAL) + def workerPinnedMemoryCheckEnabled: Boolean = get(WORKER_PINNED_MEMORY_CHECK_ENABLED) + def workerPinnedMemoryCheckIntervalMs: Long = get(WORKER_PINNED_MEMORY_CHECK_INTERVAL) def workerDirectMemoryReportIntervalSecond: Long = get(WORKER_DIRECT_MEMORY_REPORT_INTERVAL) def workerDirectMemoryTrimChannelWaitInterval: Long = get(WORKER_DIRECT_MEMORY_TRIM_CHANNEL_WAIT_INTERVAL) @@ -3711,6 +3714,24 @@ object CelebornConf extends Logging { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("10ms") + val WORKER_PINNED_MEMORY_CHECK_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.worker.monitor.pinnedMemory.check.enabled") + .categories("worker") + .doc("If true, MemoryManager will check worker should resume by pinned memory used.") + .version("0.6.0") + .booleanConf + .createWithDefaultString("true") + + val WORKER_PINNED_MEMORY_CHECK_INTERVAL: ConfigEntry[Long] = + buildConf("celeborn.worker.monitor.pinnedMemory.check.interval") + .categories("worker") + .doc("Interval of worker direct pinned memory checking, " + + "only takes effect when celeborn.network.memory.allocator.pooled and " + + "celeborn.worker.monitor.pinnedMemory.check.enabled are enabled.") + .version("0.6.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("10s") + val WORKER_DIRECT_MEMORY_REPORT_INTERVAL: ConfigEntry[Long] = buildConf("celeborn.worker.monitor.memory.report.interval") .withAlternative("celeborn.worker.memory.reportInterval") @@ -3860,6 +3881,16 @@ object CelebornConf extends Logging { .doubleConf .createWithDefault(0.7) + val WORKER_PINNED_MEMORY_RATIO_RESUME: ConfigEntry[Double] = + buildConf("celeborn.worker.pinnedMemoryRatioToResume") + .categories("worker") + .doc("If pinned memory usage is less than this limit, worker will resume, " + + "only takes effect when celeborn.network.memory.allocator.pooled and " + + "celeborn.worker.monitor.pinnedMemory.check.enabled are enabled") + .version("0.6.0") + .doubleConf + .createWithDefault(0.3) + val WORKER_MEMORY_FILE_STORAGE_MAX_FILE_SIZE: ConfigEntry[Long] = buildConf("celeborn.worker.memoryFileStorage.maxFileSize") .categories("worker") diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md index 14c7e791cd1..b4018b35e0d 100644 --- a/docs/configuration/worker.md +++ b/docs/configuration/worker.md @@ -144,9 +144,12 @@ license: | | celeborn.worker.monitor.memory.report.interval | 10s | false | Interval of worker direct memory tracker reporting to log. | 0.3.0 | celeborn.worker.memory.reportInterval | | celeborn.worker.monitor.memory.trimChannelWaitInterval | 1s | false | Wait time after worker trigger channel to trim cache. | 0.3.0 | | | celeborn.worker.monitor.memory.trimFlushWaitInterval | 1s | false | Wait time after worker trigger StorageManger to flush data. | 0.3.0 | | +| celeborn.worker.monitor.pinnedMemory.check.enabled | true | false | If true, MemoryManager will check worker should resume by pinned memory used. | 0.6.0 | | +| celeborn.worker.monitor.pinnedMemory.check.interval | 10s | false | Interval of worker direct pinned memory checking, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled. | 0.6.0 | | | celeborn.worker.partition.initial.readBuffersMax | 1024 | false | Max number of initial read buffers | 0.3.0 | | | celeborn.worker.partition.initial.readBuffersMin | 1 | false | Min number of initial read buffers | 0.3.0 | | | celeborn.worker.partitionSorter.directMemoryRatioThreshold | 0.1 | false | Max ratio of partition sorter's memory for sorting, when reserved memory is higher than max partition sorter memory, partition sorter will stop sorting. If this value is set to 0, partition files sorter will skip memory check and ServingState check. | 0.2.0 | | +| celeborn.worker.pinnedMemoryRatioToResume | 0.3 | false | If pinned memory usage is less than this limit, worker will resume, only takes effect when celeborn.network.memory.allocator.pooled and celeborn.worker.monitor.pinnedMemory.check.enabled are enabled | 0.6.0 | | | celeborn.worker.push.heartbeat.enabled | false | false | enable the heartbeat from worker to client when pushing data | 0.3.0 | | | celeborn.worker.push.io.threads | <undefined> | false | Netty IO thread number of worker to handle client push data. The default threads number is the number of flush thread. | 0.2.0 | | | celeborn.worker.push.port | 0 | false | Server port for Worker to receive push data request from ShuffleClient. | 0.2.0 | | diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java index 6db63598e61..31d2e47e827 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/memory/MemoryManager.java @@ -30,12 +30,14 @@ import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.PooledByteBufAllocator; import io.netty.util.internal.PlatformDependent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.metrics.source.AbstractSource; +import org.apache.celeborn.common.network.util.NettyUtils; import org.apache.celeborn.common.protocol.TransportModuleConstants; import org.apache.celeborn.common.util.ThreadUtils; import org.apache.celeborn.common.util.Utils; @@ -50,7 +52,8 @@ public class MemoryManager { @VisibleForTesting public long maxDirectMemory; private final long pausePushDataThreshold; private final long pauseReplicateThreshold; - private final double resumeRatio; + private final double directMemoryResumeRatio; + private final double pinnedMemoryResumeRatio; private final long maxSortMemory; private final int forceAppendPauseSpentTimeThreshold; private final List memoryPressureListeners = new ArrayList<>(); @@ -93,6 +96,9 @@ public class MemoryManager { private long memoryFileStorageThreshold; private final LongAdder memoryFileStorageCounter = new LongAdder(); private final StorageManager storageManager; + private boolean pinnedMemoryCheckEnabled; + private long pinnedMemoryCheckInterval; + private long pinnedMemoryLastCheckTime = 0; @VisibleForTesting public static MemoryManager initialize(CelebornConf conf) { @@ -120,11 +126,14 @@ public static MemoryManager instance() { private MemoryManager(CelebornConf conf, StorageManager storageManager, AbstractSource source) { double pausePushDataRatio = conf.workerDirectMemoryRatioToPauseReceive(); double pauseReplicateRatio = conf.workerDirectMemoryRatioToPauseReplicate(); - this.resumeRatio = conf.workerDirectMemoryRatioToResume(); + this.directMemoryResumeRatio = conf.workerDirectMemoryRatioToResume(); + this.pinnedMemoryResumeRatio = conf.workerPinnedMemoryRatioToResume(); double maxSortMemRatio = conf.workerPartitionSorterDirectMemoryRatioThreshold(); double readBufferRatio = conf.workerDirectMemoryRatioForReadBuffer(); double memoryFileStorageRatio = conf.workerDirectMemoryRatioForMemoryFilesStorage(); long checkInterval = conf.workerDirectMemoryPressureCheckIntervalMs(); + this.pinnedMemoryCheckEnabled = conf.workerPinnedMemoryCheckEnabled(); + this.pinnedMemoryCheckInterval = conf.workerPinnedMemoryCheckIntervalMs(); long reportInterval = conf.workerDirectMemoryReportIntervalSecond(); double readBufferTargetRatio = conf.readBufferTargetRatio(); long readBufferTargetUpdateInterval = conf.readBufferTargetUpdateInterval(); @@ -148,9 +157,10 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract pauseReplicateRatio, CelebornConf.WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE().key(), pausePushDataRatio)); - Preconditions.checkArgument(pausePushDataRatio > resumeRatio); + Preconditions.checkArgument(pausePushDataRatio > directMemoryResumeRatio); if (memoryFileStorageRatio > 0) { - Preconditions.checkArgument(resumeRatio > (readBufferRatio + memoryFileStorageRatio)); + Preconditions.checkArgument( + directMemoryResumeRatio > (readBufferRatio + memoryFileStorageRatio)); } maxSortMemory = ((long) (maxDirectMemory * maxSortMemRatio)); @@ -275,14 +285,16 @@ private MemoryManager(CelebornConf conf, StorageManager storageManager, Abstract + "pause replication memory: {}, " + "read buffer memory limit: {} target: {}, " + "memory shuffle storage limit: {}, " - + "resume memory ratio: {}", + + "resume by direct memory ratio: {}, " + + "resume by pinned memory ratio: {}", Utils.bytesToString(maxDirectMemory), Utils.bytesToString(pausePushDataThreshold), Utils.bytesToString(pauseReplicateThreshold), Utils.bytesToString(readBufferThreshold), Utils.bytesToString(readBufferTarget), Utils.bytesToString(memoryFileStorageThreshold), - resumeRatio); + directMemoryResumeRatio, + pinnedMemoryResumeRatio); } public boolean shouldEvict(boolean aggressiveMemoryFileEvictEnabled, double evictRatio) { @@ -305,7 +317,7 @@ public ServingState currentServingState() { return ServingState.PUSH_PAUSED; } // trigger resume - if (memoryUsage / (double) (maxDirectMemory) < resumeRatio) { + if (memoryUsage / (double) (maxDirectMemory) < directMemoryResumeRatio) { isPaused = false; return ServingState.NONE_PAUSED; } @@ -315,69 +327,70 @@ public ServingState currentServingState() { } @VisibleForTesting - protected void switchServingState() { + public void switchServingState() { ServingState lastState = servingState; servingState = currentServingState(); - if (lastState == servingState) { - if (servingState != ServingState.NONE_PAUSED) { + logger.info("Serving state transformed from {} to {}", lastState, servingState); + switch (servingState) { + case PUSH_PAUSED: + if (canResumeByPinnedMemory()) { + resumeByPinnedMemory(servingState); + } else { + pausePushDataCounter.increment(); + if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) { + logger.info("Trigger action: RESUME REPLICATE"); + resumeReplicate(); + } else { + logger.info("Trigger action: PAUSE PUSH"); + pausePushDataStartTime = System.currentTimeMillis(); + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); + } + } logger.debug("Trigger action: TRIM"); trimCounter += 1; - // force to append pause spent time even we are in pause state + trimAllListeners(); if (trimCounter >= forceAppendPauseSpentTimeThreshold) { logger.debug( "Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter); appendPauseSpentTime(servingState); } - trimAllListeners(); - } - return; - } - logger.info("Serving state transformed from {} to {}", lastState, servingState); - switch (servingState) { - case PUSH_PAUSED: - pausePushDataCounter.increment(); - if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) { - logger.info("Trigger action: RESUME REPLICATE"); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE)); - } else if (lastState == ServingState.NONE_PAUSED) { - logger.info("Trigger action: PAUSE PUSH"); - pausePushDataStartTime = System.currentTimeMillis(); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); - } - trimAllListeners(); break; case PUSH_AND_REPLICATE_PAUSED: - pausePushDataAndReplicateCounter.increment(); - if (lastState == ServingState.NONE_PAUSED) { + if (canResumeByPinnedMemory()) { + resumeByPinnedMemory(servingState); + } else { + pausePushDataAndReplicateCounter.increment(); logger.info("Trigger action: PAUSE PUSH"); pausePushDataAndReplicateStartTime = System.currentTimeMillis(); memoryPressureListeners.forEach( memoryPressureListener -> memoryPressureListener.onPause(TransportModuleConstants.PUSH_MODULE)); + logger.info("Trigger action: PAUSE REPLICATE"); + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE)); } - logger.info("Trigger action: PAUSE REPLICATE"); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onPause(TransportModuleConstants.REPLICATE_MODULE)); + logger.debug("Trigger action: TRIM"); + trimCounter += 1; trimAllListeners(); + if (trimCounter >= forceAppendPauseSpentTimeThreshold) { + logger.debug( + "Trigger action: TRIM for {} times, force to append pause spent time.", trimCounter); + appendPauseSpentTime(servingState); + } break; case NONE_PAUSED: // resume from paused mode, append pause spent time - appendPauseSpentTime(lastState); if (lastState == ServingState.PUSH_AND_REPLICATE_PAUSED) { - logger.info("Trigger action: RESUME REPLICATE"); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE)); + resumeReplicate(); + resumePush(); + appendPauseSpentTime(lastState); + } else if (lastState == ServingState.PUSH_PAUSED) { + resumePush(); + appendPauseSpentTime(lastState); } - logger.info("Trigger action: RESUME PUSH"); - memoryPressureListeners.forEach( - memoryPressureListener -> - memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE)); } } @@ -436,6 +449,16 @@ public long getMemoryUsage() { return getNettyUsedDirectMemory() + sortMemoryCounter.get(); } + public long getPinnedMemory() { + return getNettyPinnedDirectMemory() + sortMemoryCounter.get(); + } + + public long getNettyPinnedDirectMemory() { + return NettyUtils.getAllPooledByteBufAllocators().stream() + .mapToLong(PooledByteBufAllocator::pinnedDirectMemory) + .sum(); + } + public AtomicLong getSortMemoryCounter() { return sortMemoryCounter; } @@ -557,6 +580,47 @@ public static void reset() { _INSTANCE = null; } + private void resumeByPinnedMemory(ServingState servingState) { + switch (servingState) { + case PUSH_AND_REPLICATE_PAUSED: + logger.info( + "Serving State is PUSH_AND_REPLICATE_PAUSED, but resume by lower pinned memory {}", + getNettyPinnedDirectMemory()); + resumeReplicate(); + resumePush(); + case PUSH_PAUSED: + logger.info( + "Serving State is PUSH_PAUSED, but resume by lower pinned memory {}", + getNettyPinnedDirectMemory()); + resumePush(); + } + } + + private boolean canResumeByPinnedMemory() { + if (pinnedMemoryCheckEnabled + && System.currentTimeMillis() - pinnedMemoryLastCheckTime >= pinnedMemoryCheckInterval + && getPinnedMemory() / (double) (maxDirectMemory) < pinnedMemoryResumeRatio) { + pinnedMemoryLastCheckTime = System.currentTimeMillis(); + return true; + } else { + return false; + } + } + + private void resumePush() { + logger.info("Trigger action: RESUME PUSH"); + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onResume(TransportModuleConstants.PUSH_MODULE)); + } + + private void resumeReplicate() { + logger.info("Trigger action: RESUME REPLICATE"); + memoryPressureListeners.forEach( + memoryPressureListener -> + memoryPressureListener.onResume(TransportModuleConstants.REPLICATE_MODULE)); + } + public interface MemoryPressureListener { void onPause(String moduleName); diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala index 78fc436703b..c0fe08e6173 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/memory/MemoryManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.celeborn.service.deploy.memory import scala.concurrent.duration.DurationInt +import org.mockito.{Mockito, MockitoSugar} import org.scalatest.concurrent.Eventually.eventually import org.scalatest.concurrent.Futures.{interval, timeout} @@ -27,8 +28,8 @@ import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.CelebornConf.{WORKER_DIRECT_MEMORY_RATIO_PAUSE_RECEIVE, WORKER_DIRECT_MEMORY_RATIO_PAUSE_REPLICATE} import org.apache.celeborn.common.protocol.TransportModuleConstants import org.apache.celeborn.service.deploy.worker.memory.MemoryManager -import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.MemoryPressureListener -import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.ServingState +import org.apache.celeborn.service.deploy.worker.memory.MemoryManager.{MemoryPressureListener, ServingState} + class MemoryManagerSuite extends CelebornFunSuite { // reset the memory manager before each test @@ -153,6 +154,68 @@ class MemoryManagerSuite extends CelebornFunSuite { assert(memoryManager.getPausePushDataAndReplicateTime.longValue() > 0) } + test("[CELEBORN-1792] Test MemoryManager resume by pinned memory") { + val conf = new CelebornConf() + conf.set(CelebornConf.WORKER_DIRECT_MEMORY_CHECK_INTERVAL.key, "300s") + conf.set(CelebornConf.WORKER_PINNED_MEMORY_CHECK_INTERVAL.key, "0") + MemoryManager.reset() + val memoryManager = MockitoSugar.spy(MemoryManager.initialize(conf)) + val maxDirectorMemory = memoryManager.maxDirectMemory + val pushThreshold = + (conf.workerDirectMemoryRatioToPauseReceive * maxDirectorMemory).longValue() + val replicateThreshold = + (conf.workerDirectMemoryRatioToPauseReplicate * maxDirectorMemory).longValue() + + val pushListener = new MockMemoryPressureListener(TransportModuleConstants.PUSH_MODULE) + val replicateListener = + new MockMemoryPressureListener(TransportModuleConstants.REPLICATE_MODULE) + memoryManager.registerMemoryListener(pushListener) + memoryManager.registerMemoryListener(replicateListener) + + // NONE PAUSED -> PAUSE PUSH + Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L) + Mockito.when(memoryManager.getMemoryUsage).thenReturn(pushThreshold + 1) + memoryManager.switchServingState() + assert(!pushListener.isPause) + assert(!replicateListener.isPause) + assert(memoryManager.servingState == ServingState.PUSH_PAUSED) + + // KEEP PAUSE PUSH + Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(pushThreshold + 1) + memoryManager.switchServingState() + assert(pushListener.isPause) + assert(!replicateListener.isPause) + assert(memoryManager.servingState == ServingState.PUSH_PAUSED) + + Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L) + memoryManager.switchServingState() + assert(!pushListener.isPause) + assert(!replicateListener.isPause) + assert(memoryManager.servingState == ServingState.NONE_PAUSED) + + // NONE PAUSED -> PAUSE PUSH AND REPLICATE + Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(0L) + Mockito.when(memoryManager.getMemoryUsage).thenReturn(replicateThreshold + 1) + memoryManager.switchServingState() + assert(!pushListener.isPause) + assert(!replicateListener.isPause) + assert(memoryManager.servingState == ServingState.PUSH_AND_REPLICATE_PAUSED) + + // KEEP PAUSE PUSH AND REPLICATE + Mockito.when(memoryManager.getNettyPinnedDirectMemory).thenReturn(replicateThreshold + 1) + memoryManager.switchServingState() + assert(pushListener.isPause) + assert(replicateListener.isPause) + assert(memoryManager.servingState == ServingState.PUSH_AND_REPLICATE_PAUSED) + + Mockito.when(memoryManager.getMemoryUsage).thenReturn(0L) + memoryManager.switchServingState() + assert(!pushListener.isPause) + assert(!replicateListener.isPause) + assert(memoryManager.servingState == ServingState.NONE_PAUSED) + MemoryManager.reset() + } + class MockMemoryPressureListener( val belongModuleName: String, var isPause: Boolean = false) extends MemoryPressureListener {