Skip to content

Commit

Permalink
add more ut
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 19, 2023
1 parent fee3c7b commit 6caee78
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common;

import lombok.extern.log4j.Log4j2;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.annotation.Connector;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
Expand All @@ -15,9 +16,11 @@
import org.opensearch.ml.common.annotation.MLInput;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
import org.reflections.Reflections;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
Expand Down Expand Up @@ -203,12 +206,27 @@ public static <T extends Enum<T>, S, I extends Object> S initMLInstance(T type,

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initExecuteInputInstance(T type, I in, Class<?> constructorParamClass) {
return init(executeInputClassMap, type, in, constructorParamClass);
try {
return init(executeInputClassMap, type, in, constructorParamClass);
} catch (Exception e) {
return init(mlInputClassMap, type, in, constructorParamClass);
}
}

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initExecuteOutputInstance(T type, I in, Class<?> constructorParamClass) {
return init(executeOutputClassMap, type, in, constructorParamClass);
try {
return init(executeOutputClassMap, type, in, constructorParamClass);
} catch (Exception e) {
if (in instanceof StreamInput) {
try {
return (S) MLOutput.fromStream((StreamInput) in);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
throw e;
}
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.output.Output;
Expand All @@ -16,5 +17,5 @@ public interface Executable {
* @param input input data
* @return execution result
*/
Output execute(Input input) throws ExecuteException;
void execute(Input input, ActionListener<Output> listener) throws ExecuteException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Locale;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand Down Expand Up @@ -152,20 +153,20 @@ public MLOutput trainAndPredict(Input input) {
return trainAndPredictable.trainAndPredict(mlInput);
}

public Output execute(Input input) throws Exception {
public void execute(Input input, ActionListener<Output> listener) throws Exception {
validateInput(input);
if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
return executable.execute(input);
executable.execute(input, listener);
} else {
Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
return executable.execute(input);
executable.execute(input, listener);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.io.FileUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.model.MLModelFormat;
Expand Down Expand Up @@ -52,7 +52,7 @@ public abstract class DLModelExecute implements MLExecutable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);

public abstract Output execute(Input input) throws ExecuteException;
public abstract void execute(Input input, ActionListener<Output> listener);

protected Predictor<float[][], ai.djl.modality.Output> getPredictor() {
int currentDevice = nextDevice.getAndIncrement();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.MultiSearchResponse;
import org.opensearch.action.search.SearchRequest;
Expand Down Expand Up @@ -527,23 +524,10 @@ protected List<Map.Entry<Long, Long>> getAllIntervals() {
}

@Override
public Output execute(Input input) {
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<AnomalyLocalizationOutput> outRef = new AtomicReference<>();
AtomicReference<Exception> exRef = new AtomicReference<>();
public void execute(Input input, ActionListener<Output> listener) {
getLocalizationResults(
(AnomalyLocalizationInput) input,
new LatchedActionListener(ActionListener.<AnomalyLocalizationOutput>wrap(o -> outRef.set(o), e -> exRef.set(e)), latch)
ActionListener.wrap(o -> listener.onResponse(o), e -> listener.onFailure(e))
);
try {
latch.await();
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
if (exRef.get() != null) {
throw new RuntimeException(exRef.get());
} else {
return outRef.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust
* contains 3 properties event_window, event_pattern and suspected_metrics
* @throws ExecuteException
*/
// @Override
// public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
//
// }

@Override
public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
public void execute(Input input, ActionListener<org.opensearch.ml.common.output.Output> listener) {
if (!(input instanceof MetricsCorrelationInput)) {
throw new ExecuteException("wrong input");
}
Expand Down Expand Up @@ -148,7 +153,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
} else {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequest getModelRequest = new GetRequest(ML_MODEL_INDEX).id(FunctionName.METRICS_CORRELATION.name());
ActionListener<GetResponse> listener = ActionListener.wrap(r -> {
ActionListener<GetResponse> actionListener = ActionListener.wrap(r -> {
if (r.isExists()) {
modelId = r.getId();
Map<String, Object> sourceAsMap = r.getSourceAsMap();
Expand Down Expand Up @@ -176,7 +181,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
);
}
}, e -> { log.error("Failed to get model", e); });
client.get(getModelRequest, ActionListener.runBefore(listener, context::restore));
client.get(getModelRequest, ActionListener.runBefore(actionListener, context::restore));
}
}
} else {
Expand Down Expand Up @@ -227,7 +232,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
}

tensorOutputs.add(parseModelTensorOutput(djlOutput, null));
return new MetricsCorrelationOutput(tensorOutputs);
listener.onResponse(new MetricsCorrelationOutput(tensorOutputs));
}

@VisibleForTesting
Expand Down

0 comments on commit 6caee78

Please sign in to comment.