diff --git a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java index 8f3e537e68..b9413c5b4d 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java +++ b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common; import lombok.extern.log4j.Log4j2; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.annotation.Connector; import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.annotation.ExecuteOutput; @@ -15,9 +16,11 @@ import org.opensearch.ml.common.annotation.MLInput; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLOutputType; import org.reflections.Reflections; +import java.io.IOException; import java.lang.reflect.Constructor; import java.security.AccessController; import java.security.PrivilegedActionException; @@ -203,12 +206,27 @@ public static , S, I extends Object> S initMLInstance(T type, @SuppressWarnings("unchecked") public static , S, I extends Object> S initExecuteInputInstance(T type, I in, Class constructorParamClass) { - return init(executeInputClassMap, type, in, constructorParamClass); + try { + return init(executeInputClassMap, type, in, constructorParamClass); + } catch (Exception e) { + return init(mlInputClassMap, type, in, constructorParamClass); + } } @SuppressWarnings("unchecked") public static , S, I extends Object> S initExecuteOutputInstance(T type, I in, Class constructorParamClass) { - return init(executeOutputClassMap, type, in, constructorParamClass); + try { + return init(executeOutputClassMap, type, in, constructorParamClass); + } catch (Exception e) { + if (in instanceof StreamInput) { + try { + return (S) MLOutput.fromStream((StreamInput) in); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + throw e; + } } @SuppressWarnings("unchecked") diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java index 2d4ade93fb..b90266c7f0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.exception.ExecuteException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.output.Output; @@ -16,5 +17,5 @@ public interface Executable { * @param input input data * @return execution result */ - Output execute(Input input) throws ExecuteException; + void execute(Input input, ActionListener listener) throws ExecuteException; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index aeaed6bd21..85f06eb89d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -9,6 +9,7 @@ import java.util.Locale; import java.util.Map; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; @@ -152,20 +153,20 @@ public MLOutput trainAndPredict(Input input) { return trainAndPredictable.trainAndPredict(mlInput); } - public Output execute(Input input) throws Exception { + public void execute(Input input, ActionListener listener) throws Exception { validateInput(input); if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) { MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class); if (executable == null) { throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName()); } - return executable.execute(input); + executable.execute(input, listener); } else { Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class); if (executable == null) { throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName()); } - return executable.execute(input); + executable.execute(input, listener); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java index fab052da19..0ae0c44dc7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java @@ -17,9 +17,9 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.io.FileUtils; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.exception.ExecuteException; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.model.MLModelFormat; @@ -52,7 +52,7 @@ public abstract class DLModelExecute implements MLExecutable { protected Device[] devices; protected AtomicInteger nextDevice = new AtomicInteger(0); - public abstract Output execute(Input input) throws ExecuteException; + public abstract void execute(Input input, ActionListener listener); protected Predictor getPredictor() { int currentDevice = nextDevice.getAndIncrement(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java index c874c903e4..b11fc9a39c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java @@ -16,13 +16,10 @@ import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.IntStream; -import org.opensearch.action.LatchedActionListener; import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.MultiSearchResponse; import org.opensearch.action.search.SearchRequest; @@ -527,23 +524,10 @@ protected List> getAllIntervals() { } @Override - public Output execute(Input input) { - CountDownLatch latch = new CountDownLatch(1); - AtomicReference outRef = new AtomicReference<>(); - AtomicReference exRef = new AtomicReference<>(); + public void execute(Input input, ActionListener listener) { getLocalizationResults( (AnomalyLocalizationInput) input, - new LatchedActionListener(ActionListener.wrap(o -> outRef.set(o), e -> exRef.set(e)), latch) + ActionListener.wrap(o -> listener.onResponse(o), e -> listener.onFailure(e)) ); - try { - latch.await(); - } catch (InterruptedException e) { - throw new IllegalStateException(e); - } - if (exRef.get() != null) { - throw new RuntimeException(exRef.get()); - } else { - return outRef.get(); - } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index c44704f688..15bccadac0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -107,8 +107,13 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust * contains 3 properties event_window, event_pattern and suspected_metrics * @throws ExecuteException */ + // @Override + // public MetricsCorrelationOutput execute(Input input) throws ExecuteException { + // + // } + @Override - public MetricsCorrelationOutput execute(Input input) throws ExecuteException { + public void execute(Input input, ActionListener listener) { if (!(input instanceof MetricsCorrelationInput)) { throw new ExecuteException("wrong input"); } @@ -148,7 +153,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { GetRequest getModelRequest = new GetRequest(ML_MODEL_INDEX).id(FunctionName.METRICS_CORRELATION.name()); - ActionListener listener = ActionListener.wrap(r -> { + ActionListener actionListener = ActionListener.wrap(r -> { if (r.isExists()) { modelId = r.getId(); Map sourceAsMap = r.getSourceAsMap(); @@ -176,7 +181,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { ); } }, e -> { log.error("Failed to get model", e); }); - client.get(getModelRequest, ActionListener.runBefore(listener, context::restore)); + client.get(getModelRequest, ActionListener.runBefore(actionListener, context::restore)); } } } else { @@ -227,7 +232,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { } tensorOutputs.add(parseModelTensorOutput(djlOutput, null)); - return new MetricsCorrelationOutput(tensorOutputs); + listener.onResponse(new MetricsCorrelationOutput(tensorOutputs)); } @VisibleForTesting