diff --git a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java index 8f3e537e68..b9413c5b4d 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java +++ b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common; import lombok.extern.log4j.Log4j2; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.annotation.Connector; import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.annotation.ExecuteOutput; @@ -15,9 +16,11 @@ import org.opensearch.ml.common.annotation.MLInput; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLOutputType; import org.reflections.Reflections; +import java.io.IOException; import java.lang.reflect.Constructor; import java.security.AccessController; import java.security.PrivilegedActionException; @@ -203,12 +206,27 @@ public static , S, I extends Object> S initMLInstance(T type, @SuppressWarnings("unchecked") public static , S, I extends Object> S initExecuteInputInstance(T type, I in, Class constructorParamClass) { - return init(executeInputClassMap, type, in, constructorParamClass); + try { + return init(executeInputClassMap, type, in, constructorParamClass); + } catch (Exception e) { + return init(mlInputClassMap, type, in, constructorParamClass); + } } @SuppressWarnings("unchecked") public static , S, I extends Object> S initExecuteOutputInstance(T type, I in, Class constructorParamClass) { - return init(executeOutputClassMap, type, in, constructorParamClass); + try { + return init(executeOutputClassMap, type, in, constructorParamClass); + } catch (Exception e) { + if (in instanceof StreamInput) { + try { + return (S) MLOutput.fromStream((StreamInput) in); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + throw e; + } } @SuppressWarnings("unchecked") diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java index 2d4ade93fb..b90266c7f0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.exception.ExecuteException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.output.Output; @@ -16,5 +17,5 @@ public interface Executable { * @param input input data * @return execution result */ - Output execute(Input input) throws ExecuteException; + void execute(Input input, ActionListener listener) throws ExecuteException; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index aeaed6bd21..85f06eb89d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -9,6 +9,7 @@ import java.util.Locale; import java.util.Map; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; @@ -152,20 +153,20 @@ public MLOutput trainAndPredict(Input input) { return trainAndPredictable.trainAndPredict(mlInput); } - public Output execute(Input input) throws Exception { + public void execute(Input input, ActionListener listener) throws Exception { validateInput(input); if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) { MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class); if (executable == null) { throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName()); } - return executable.execute(input); + executable.execute(input, listener); } else { Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class); if (executable == null) { throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName()); } - return executable.execute(input); + executable.execute(input, listener); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java index fab052da19..0ae0c44dc7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java @@ -17,9 +17,9 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.io.FileUtils; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.exception.ExecuteException; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.model.MLModelFormat; @@ -52,7 +52,7 @@ public abstract class DLModelExecute implements MLExecutable { protected Device[] devices; protected AtomicInteger nextDevice = new AtomicInteger(0); - public abstract Output execute(Input input) throws ExecuteException; + public abstract void execute(Input input, ActionListener listener); protected Predictor getPredictor() { int currentDevice = nextDevice.getAndIncrement(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java index c874c903e4..b11fc9a39c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java @@ -16,13 +16,10 @@ import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.opensearch.action.LatchedActionListener; import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; import org.opensearch.action.search.SearchRequest; @@ -527,23 +524,10 @@ protected List> getAllIntervals() { } @Override - public Output execute(Input input) { - CountDownLatch latch = new CountDownLatch(1); - AtomicReference outRef = new AtomicReference<>(); - AtomicReference exRef = new AtomicReference<>(); + public void execute(Input input, ActionListener listener) { getLocalizationResults( (AnomalyLocalizationInput) input, - new LatchedActionListener(ActionListener.wrap(o -> outRef.set(o), e -> exRef.set(e)), latch) + ActionListener.wrap(o -> listener.onResponse(o), e -> listener.onFailure(e)) ); - try { - latch.await(); - } catch (InterruptedException e) { - throw new IllegalStateException(e); - } - if (exRef.get() != null) { - throw new RuntimeException(exRef.get()); - } else { - return outRef.get(); - } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index c44704f688..15bccadac0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -107,8 +107,13 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust * contains 3 properties event_window, event_pattern and suspected_metrics * @throws ExecuteException */ + // @Override + // public MetricsCorrelationOutput execute(Input input) throws ExecuteException { + // + // } + @Override - public MetricsCorrelationOutput execute(Input input) throws ExecuteException { + public void execute(Input input, ActionListener listener) { if (!(input instanceof MetricsCorrelationInput)) { throw new ExecuteException("wrong input"); } @@ -148,7 +153,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { GetRequest getModelRequest = new GetRequest(ML_MODEL_INDEX).id(FunctionName.METRICS_CORRELATION.name()); - ActionListener listener = ActionListener.wrap(r -> { + ActionListener actionListener = ActionListener.wrap(r -> { if (r.isExists()) { modelId = r.getId(); Map sourceAsMap = r.getSourceAsMap(); @@ -176,7 +181,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { ); } }, e -> { log.error("Failed to get model", e); }); - client.get(getModelRequest, ActionListener.runBefore(listener, context::restore)); + client.get(getModelRequest, ActionListener.runBefore(actionListener, context::restore)); } } } else { @@ -227,7 +232,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { } tensorOutputs.add(parseModelTensorOutput(djlOutput, null)); - return new MetricsCorrelationOutput(tensorOutputs); + listener.onResponse(new MetricsCorrelationOutput(tensorOutputs)); } @VisibleForTesting diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java index 2802cebf86..ee4353d7e3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java @@ -10,6 +10,7 @@ import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; @@ -36,7 +37,7 @@ public LocalSampleCalculator(Client client, Settings settings) { } @Override - public Output execute(Input input) { + public void execute(Input input, ActionListener listener) { if (input == null || !(input instanceof LocalSampleCalculatorInput)) { throw new IllegalArgumentException("wrong input"); } @@ -46,13 +47,16 @@ public Output execute(Input input) { switch (operation) { case "sum": double sum = inputData.stream().mapToDouble(f -> f.doubleValue()).sum(); - return new LocalSampleCalculatorOutput(sum); + listener.onResponse(new LocalSampleCalculatorOutput(sum)); + break; case "max": double max = inputData.stream().max(Comparator.naturalOrder()).get(); - return new LocalSampleCalculatorOutput(max); + listener.onResponse(new LocalSampleCalculatorOutput(max)); + break; case "min": double min = inputData.stream().min(Comparator.naturalOrder()).get(); - return new LocalSampleCalculatorOutput(min); + listener.onResponse(new LocalSampleCalculatorOutput(min)); + break; default: throw new IllegalArgumentException("can't support this operation"); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java index 27c32055a1..89d7a6c77a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java @@ -5,8 +5,7 @@ package org.opensearch.ml.engine; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +import static org.junit.Assert.*; import static org.mockito.Mockito.mock; import java.util.ArrayList; @@ -17,9 +16,11 @@ import org.junit.Test; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; +import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput; import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; @@ -43,19 +44,27 @@ public void initInstance_LocalSampleCalculator() { // set properties MLEngineClassLoader.deregister(FunctionName.LOCAL_SAMPLE_CALCULATOR); - LocalSampleCalculator instance = MLEngineClassLoader + final LocalSampleCalculator instance = MLEngineClassLoader .initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class, properties); - LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) instance.execute(input); - assertEquals(d1 + d2, output.getResult(), 1e-6); - assertEquals(client, instance.getClient()); - assertEquals(settings, instance.getSettings()); + + ActionListener actionListener = ActionListener.wrap(o -> { + LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; + assertEquals(d1 + d2, output.getResult(), 1e-6); + assertEquals(client, instance.getClient()); + assertEquals(settings, instance.getSettings()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + + instance.execute(input, actionListener); // don't set properties - instance = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class); - output = (LocalSampleCalculatorOutput) instance.execute(input); - assertEquals(d1 + d2, output.getResult(), 1e-6); - assertNull(instance.getClient()); - assertNull(instance.getSettings()); + final LocalSampleCalculator instance2 = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class); + ActionListener actionListener2 = ActionListener.wrap(o -> { + LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; + assertEquals(d1 + d2, output.getResult(), 1e-6); + assertNull(instance2.getClient()); + assertNull(instance2.getSettings()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + instance2.execute(input, actionListener2); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 11f0c207e6..8a3d7c5453 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -6,6 +6,8 @@ package org.opensearch.ml.engine; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame; import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame; @@ -23,6 +25,7 @@ import org.junit.rules.ExpectedException; import org.mockito.MockedStatic; import org.mockito.Mockito; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; @@ -40,6 +43,7 @@ import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput; import org.opensearch.ml.engine.algorithms.regression.LinearRegression; import org.opensearch.ml.engine.encryptor.Encryptor; @@ -265,8 +269,12 @@ public void trainAndPredictWithInvalidInput() { @Test public void executeLocalSampleCalculator() throws Exception { Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0)); - LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) mlEngine.execute(input); - assertEquals(3.0, output.getResult(), 1e-5); + ActionListener actionListener = ActionListener.wrap(o -> { + LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; + assertEquals(3.0, output.getResult(), 1e-5); + }, e -> { fail("Test failed: " + e.getMessage()); }); + mlEngine.execute(input, actionListener); + } @Test @@ -289,7 +297,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params return null; } }; - mlEngine.execute(input); + ActionListener actionListener = mock(ActionListener.class); + mlEngine.execute(input, actionListener); } private MLModel trainKMeansModel() { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java index 722e2f21aa..daf1698dfa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java @@ -6,6 +6,8 @@ package org.opensearch.ml.engine.algorithms.anomalylocalization; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; @@ -13,6 +15,7 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -53,7 +56,9 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilder; import org.opensearch.ml.common.input.execute.anomalylocalization.AnomalyLocalizationInput; +import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.execute.anomalylocalization.AnomalyLocalizationOutput; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.Aggregations; @@ -61,8 +66,6 @@ import org.opensearch.search.aggregations.bucket.filter.Filters; import org.opensearch.search.aggregations.metrics.NumericMetricsAggregation.SingleValue; -import com.google.common.collect.ImmutableMap; - public class AnomalyLocalizerImplTests { @Mock @@ -438,13 +441,17 @@ public void testExecuteSucceed() { when(indexNameExpressionResolver.concreteIndexNames(any(ClusterState.class), any(IndicesOptions.class), anyString())) .thenReturn(IndicesOptions); - AnomalyLocalizationOutput actualOutput = (AnomalyLocalizationOutput) anomalyLocalizer.execute(input); - - assertEquals(expectedOutput, actualOutput); + ActionListener actionListener = ActionListener.wrap(o -> { + AnomalyLocalizationOutput actualOutput = (AnomalyLocalizationOutput) o; + assertEquals(expectedOutput, actualOutput); + }, e -> { + fail("Test failed: " + e.getMessage()); + }); + anomalyLocalizer.execute(input, actionListener); } @SuppressWarnings("unchecked") - @Test(expected = RuntimeException.class) + @Test public void testExecuteFail() { doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -452,13 +459,19 @@ public void testExecuteFail() { listener.onFailure(new RuntimeException()); return null; }).when(client).multiSearch(any(), any()); - anomalyLocalizer.execute(input); + ActionListener actionListener = mock(ActionListener.class); + anomalyLocalizer.execute(input, actionListener); + ArgumentCaptor exceptionArgumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(exceptionArgumentCaptor.capture()); + assertTrue(exceptionArgumentCaptor.getValue() instanceof RuntimeException); } - @Test(expected = RuntimeException.class) + @Test public void testExecuteInterrupted() { - Thread.currentThread().interrupt(); - anomalyLocalizer.execute(input); + ActionListener actionListener = ActionListener.wrap(o -> { Thread.currentThread().interrupt(); }, e -> { + assertTrue(e.getMessage().contains("Failed to find index")); + }); + anomalyLocalizer.execute(input, actionListener); } private ClusterState setupTestClusterState() { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 223cb22289..0a05127b87 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -10,6 +10,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.anyLong; @@ -85,6 +86,7 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; +import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; @@ -105,6 +107,7 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -115,8 +118,6 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.threadpool.ThreadPool; -import com.google.common.collect.ImmutableMap; - public class MetricsCorrelationTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -328,10 +329,13 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio return null; }).when(client).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), isA(ActionListener.class)); - MetricsCorrelationOutput output = metricsCorrelation.execute(input); - List mlModelOutputs = output.getModelOutput(); - assert mlModelOutputs.size() == 1; - assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); + ActionListener actionListener = ActionListener.wrap(o -> { + MetricsCorrelationOutput output = (MetricsCorrelationOutput) o; + List mlModelOutputs = output.getModelOutput(); + assert mlModelOutputs.size() == 1; + assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + metricsCorrelation.execute(input, actionListener); } @Ignore @@ -360,10 +364,13 @@ public void testExecuteWithModelInIndexAndEmptyOutput() throws ExecuteException, return null; }).when(client).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), isA(ActionListener.class)); - MetricsCorrelationOutput output = metricsCorrelation.execute(input); - List mlModelOutputs = output.getModelOutput(); - assert mlModelOutputs.size() == 1; - assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); + ActionListener actionListener = ActionListener.wrap(o -> { + MetricsCorrelationOutput output = (MetricsCorrelationOutput) o; + List mlModelOutputs = output.getModelOutput(); + assert mlModelOutputs.size() == 1; + assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + metricsCorrelation.execute(input, actionListener); } @Test @@ -387,12 +394,15 @@ public void testExecuteWithModelInIndexAndOneEvent() throws ExecuteException, UR when(client.execute(any(MLTaskGetAction.class), any(MLTaskGetRequest.class))).thenReturn(mockedFutureResponse); when(mockedFutureResponse.actionGet(anyLong())).thenReturn(taskResponse); - MetricsCorrelationOutput output = metricsCorrelation.execute(extendedInput); - List mlModelOutputs = output.getModelOutput(); - assert mlModelOutputs.size() == 1; - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + ActionListener actionListener = ActionListener.wrap(o -> { + MetricsCorrelationOutput output = (MetricsCorrelationOutput) o; + List mlModelOutputs = output.getModelOutput(); + assert mlModelOutputs.size() == 1; + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + metricsCorrelation.execute(extendedInput, actionListener); } @Ignore @@ -428,12 +438,15 @@ public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, UR return mlRegisterModelResponse; }).when(client).execute(any(MLRegisterModelAction.class), any(MLRegisterModelRequest.class), isA(ActionListener.class)); - MetricsCorrelationOutput output = metricsCorrelation.execute(extendedInput); - List mlModelOutputs = output.getModelOutput(); - assert mlModelOutputs.size() == 1; - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + ActionListener actionListener = ActionListener.wrap(o -> { + MetricsCorrelationOutput output = (MetricsCorrelationOutput) o; + List mlModelOutputs = output.getModelOutput(); + assert mlModelOutputs.size() == 1; + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + metricsCorrelation.execute(extendedInput, actionListener); } @Ignore @@ -475,12 +488,15 @@ public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws Execu return mlDeployModelResponse; }).when(client).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), isA(ActionListener.class)); - MetricsCorrelationOutput output = metricsCorrelation.execute(extendedInput); - List mlModelOutputs = output.getModelOutput(); - assert mlModelOutputs.size() == 1; - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + ActionListener actionListener = ActionListener.wrap(o -> { + MetricsCorrelationOutput output = (MetricsCorrelationOutput) o; + List mlModelOutputs = output.getModelOutput(); + assert mlModelOutputs.size() == 1; + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + metricsCorrelation.execute(extendedInput, actionListener); } @Ignore @@ -517,12 +533,15 @@ public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, return mlRegisterModelResponse; }).when(client).execute(any(MLRegisterModelAction.class), any(MLRegisterModelRequest.class), isA(ActionListener.class)); - MetricsCorrelationOutput output = metricsCorrelation.execute(extendedInput); - List mlModelOutputs = output.getModelOutput(); - assert mlModelOutputs.size() == 1; - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); - assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + ActionListener actionListener = ActionListener.wrap(o -> { + MetricsCorrelationOutput output = (MetricsCorrelationOutput) o; + List mlModelOutputs = output.getModelOutput(); + assert mlModelOutputs.size() == 1; + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_window()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getEvent_pattern()); + assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); + }, e -> { fail("Test failed: " + e.getMessage()); }); + metricsCorrelation.execute(extendedInput, actionListener); } // working @@ -650,7 +669,7 @@ public void testDeployModelFail() { @Test public void testWrongInput() throws ExecuteException { exceptionRule.expect(ExecuteException.class); - metricsCorrelation.execute(mock(LocalSampleCalculatorInput.class)); + metricsCorrelation.execute(mock(LocalSampleCalculatorInput.class), mock(ActionListener.class)); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculatorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculatorTest.java index f9eb01db12..2f85b550c8 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculatorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculatorTest.java @@ -5,6 +5,9 @@ package org.opensearch.ml.engine.algorithms.sample; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + import java.util.Arrays; import org.junit.Assert; @@ -15,7 +18,9 @@ import org.mockito.Mock; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; +import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput; public class LocalSampleCalculatorTest { @@ -36,16 +41,25 @@ public void setUp() { @Test public void execute() { - LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) calculator.execute(input); - Assert.assertEquals(6.0, output.getResult().doubleValue(), 1e-5); + ActionListener actionListener1 = ActionListener.wrap(o -> { + LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; + Assert.assertEquals(6.0, output.getResult().doubleValue(), 1e-5); + }, e -> { fail("Test failed: " + e.getMessage()); }); + calculator.execute(input, actionListener1); + ActionListener actionListener2 = ActionListener.wrap(o -> { + LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; + Assert.assertEquals(3.0, output.getResult().doubleValue(), 1e-5); + }, e -> { fail("Test failed: " + e.getMessage()); }); input = new LocalSampleCalculatorInput("max", Arrays.asList(1.0, 2.0, 3.0)); - output = (LocalSampleCalculatorOutput) calculator.execute(input); - Assert.assertEquals(3.0, output.getResult().doubleValue(), 1e-5); + calculator.execute(input, actionListener2); + ActionListener actionListener3 = ActionListener.wrap(o -> { + LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; + Assert.assertEquals(1.0, output.getResult().doubleValue(), 1e-5); + }, e -> { fail("Test failed: " + e.getMessage()); }); input = new LocalSampleCalculatorInput("min", Arrays.asList(1.0, 2.0, 3.0)); - output = (LocalSampleCalculatorOutput) calculator.execute(input); - Assert.assertEquals(1.0, output.getResult().doubleValue(), 1e-5); + calculator.execute(input, actionListener3); } @Test @@ -53,13 +67,14 @@ public void executeWithWrongOperation() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("can't support this operation"); input = new LocalSampleCalculatorInput("wrong_operation", Arrays.asList(1.0, 2.0, 3.0)); - calculator.execute(input); + ActionListener actionListener = ActionListener.wrap(o -> {}, e -> { fail("Test failed: " + e.getMessage()); }); + calculator.execute(input, actionListener); } @Test public void executeWithNullInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("wrong input"); - calculator.execute(null); + calculator.execute(null, mock(ActionListener.class)); } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 3e82e7a20e..fb526e6e55 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -16,7 +16,6 @@ import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; -import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; @@ -104,9 +103,10 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener { + MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); + listener.onResponse(response); + }, e -> { listener.onFailure(e); })); } catch (Exception e) { mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_FAILURE_COUNT)