diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 34be01b3a2..1e9bf1f1f6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -19,6 +19,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; @@ -171,6 +172,8 @@ public void onResponse(MLModel mlModel) { ); } else if (e instanceof MLResourceNotFoundException) { wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND)); + } else if (e instanceof CircuitBreakingException) { + wrappedListener.onFailure(e); } else { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java index 5e045ae539..08762f8d10 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java @@ -18,6 +18,7 @@ public class MemoryCircuitBreaker extends ThresholdCircuitBreaker { // TODO: make this value configurable as cluster setting private static final String ML_MEMORY_CB = "Memory Circuit Breaker"; public static final short DEFAULT_JVM_HEAP_USAGE_THRESHOLD = 85; + public static final short JVM_HEAP_MAX_THRESHOLD = 100; // when threshold is 100, this CB check is ignored private final JvmService jvmService; private volatile Integer jvmHeapMemThreshold = 85; @@ -50,6 +51,6 @@ public Short getThreshold() { @Override public boolean isOpen() { - return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold(); + return getThreshold() < JVM_HEAP_MAX_THRESHOLD && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index e83493f4e5..b6a098b6b6 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -59,6 +59,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentLinkedDeque; @@ -827,7 +828,9 @@ private ThreadedActionListener threadedActionListener(String threadPoolNa * @param runningTaskLimit limit */ public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) { - checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); + if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) { + checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); + } mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index a19ffc4af3..76cc0275c1 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -137,7 +137,7 @@ public void dispatchTask( if (clusterService.localNode().getId().equals(node.getId())) { log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId()); request.setDispatchTask(false); - executeTask(request, listener); + checkCBAndExecute(functionName, request, listener); } else { log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId()); request.setDispatchTask(false); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index b2c71d6ed8..54195ab156 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) { public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener listener) { if (!request.isDispatchTask()) { log.debug("Run ML request {} locally", request.getRequestID()); - checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); - executeTask(request, listener); + checkCBAndExecute(functionName, request, listener); return; } dispatchTask(functionName, request, transportService, listener); @@ -129,4 +128,11 @@ public void dispatchTask( protected abstract TransportResponseHandler getResponseHandler(ActionListener listener); protected abstract void executeTask(Request request, ActionListener listener); + + protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener listener) { + if (functionName != FunctionName.REMOTE) { + checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); + } + executeTask(request, listener); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java index f5c0f5ba52..f8f0dcb547 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java @@ -16,12 +16,13 @@ import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.breaker.ThresholdCircuitBreaker; -import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; @@ -60,7 +61,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB(); if (openCircuitBreaker != null) { mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment(); - throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!"); + throw new CircuitBreakingException( + openCircuitBreaker.getName() + " is open, please check your resources!", + CircuitBreaker.Durability.TRANSIENT + ); } } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index a1832dcd62..3f01b1bfea 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -34,6 +34,8 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; @@ -233,4 +235,26 @@ public void testPrediction_MLResourceNotFoundException() { assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage()); } + public void testPrediction_MLLimitExceededException() { + when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); + when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT)); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[3]).onResponse(null); + return null; + }).when(mlPredictTaskRunner).run(any(), any(), any(), any()); + + transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(CircuitBreakingException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage()); + } + } diff --git a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java index cdd1f6fc22..8c7f6f41d4 100644 --- a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java @@ -84,4 +84,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() { settingsService.applySettings(newSettingsBuilder.build()); Assert.assertFalse(breaker.isOpen()); } + + @Test + public void testIsOpen_DisableMemoryCB() { + ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD); + when(clusterService.getClusterSettings()).thenReturn(settingsService); + + CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService); + + when(mem.getHeapUsedPercent()).thenReturn((short) 90); + Assert.assertTrue(breaker.isOpen()); + + when(mem.getHeapUsedPercent()).thenReturn((short) 100); + Settings.Builder newSettingsBuilder = Settings.builder(); + newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100); + settingsService.applySettings(newSettingsBuilder.build()); + Assert.assertFalse(breaker.isOpen()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 569aaee3c9..466f8231c6 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -80,9 +80,10 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.breaker.CircuitBreaker; +import org.opensearch.core.common.breaker.CircuitBreakingException; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.breaker.MLCircuitBreakerService; -import org.opensearch.ml.breaker.MemoryCircuitBreaker; import org.opensearch.ml.breaker.ThresholdCircuitBreaker; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -113,7 +114,6 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.ml.task.MLTaskManager; -import org.opensearch.monitor.jvm.JvmService; import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -322,7 +322,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() { when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker); when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker"); when(thresholdCircuitBreaker.getThreshold()).thenReturn(87); - expectedEx.expect(MLException.class); + expectedEx.expect(CircuitBreakingException.class); expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!"); modelManager.registerMLModel(registerModelInput, mlTask); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); @@ -451,21 +451,32 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException { verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } - public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() { + public void testRegisterMLRemoteModel_SkipMemoryCBOpen() { ActionListener listener = mock(ActionListener.class); - MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class)); - String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!"; - when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage)); + doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any()); + when(mlCircuitBreakerService.checkOpenCB()) + .thenThrow( + new CircuitBreakingException( + "Memory Circuit Breaker is open, please check your resources!", + CircuitBreaker.Durability.TRANSIENT + ) + ); + when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService); + when(modelHelper.isModelAllowed(any(), any())).thenReturn(true); MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true); MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build(); + mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true); + doAnswer(invocation -> { + ActionListener indexResponseActionListener = (ActionListener) invocation.getArguments()[1]; + indexResponseActionListener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + when(indexResponse.getId()).thenReturn("mockIndexId"); modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); - verify(listener, times(1)).onFailure(argCaptor.capture()); - Exception e = argCaptor.getValue(); - assertTrue(e instanceof MLLimitExceededException); - assertEquals(memCBIsOpenMessage, e.getMessage()); + assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE); + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } public void testIndexRemoteModel() throws PrivilegedActionException { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java index d1d332050e..dcaa2610b7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java @@ -58,7 +58,7 @@ public void testRunWithMemoryCircuitBreaker() throws IOException { exception.getMessage(), allOf( containsString("Memory Circuit Breaker is open, please check your resources!"), - containsString("m_l_limit_exceeded_exception") + containsString("circuit_breaking_exception") ) ); diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java index 9e2abccebb..c47578925b 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java @@ -34,7 +34,6 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; -import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.transport.MLTaskRequest; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStat; @@ -132,15 +131,15 @@ public void testHandleAsyncMLTaskComplete_SyncTask() { verify(mlTaskManager, never()).updateMLTask(eq(syncMlTask.getTaskId()), any(), anyLong(), anyBoolean()); } - public void testRun_CircuitBreakerOpen() { + public void testRemoteInferenceRun_CircuitBreakerNotOpen() { when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker); when(thresholdCircuitBreaker.getName()).thenReturn("Memory Circuit Breaker"); when(thresholdCircuitBreaker.getThreshold()).thenReturn(87); TransportService transportService = mock(TransportService.class); ActionListener listener = mock(ActionListener.class); MLTaskRequest request = new MLTaskRequest(false); - expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener)); + mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener); Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue(); - assertEquals(1L, value.longValue()); + assertEquals(0L, value.longValue()); } }