Skip to content

Commit

Permalink
fine tune execute api
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 19, 2023
1 parent fee3c7b commit 1b42c0f
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common;

import lombok.extern.log4j.Log4j2;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.ml.common.annotation.Connector;
import org.opensearch.ml.common.annotation.ExecuteInput;
import org.opensearch.ml.common.annotation.ExecuteOutput;
Expand All @@ -15,9 +16,11 @@
import org.opensearch.ml.common.annotation.MLInput;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
import org.reflections.Reflections;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
Expand Down Expand Up @@ -203,12 +206,27 @@ public static <T extends Enum<T>, S, I extends Object> S initMLInstance(T type,

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initExecuteInputInstance(T type, I in, Class<?> constructorParamClass) {
return init(executeInputClassMap, type, in, constructorParamClass);
try {
return init(executeInputClassMap, type, in, constructorParamClass);
} catch (Exception e) {
return init(mlInputClassMap, type, in, constructorParamClass);
}
}

@SuppressWarnings("unchecked")
public static <T extends Enum<T>, S, I extends Object> S initExecuteOutputInstance(T type, I in, Class<?> constructorParamClass) {
return init(executeOutputClassMap, type, in, constructorParamClass);
try {
return init(executeOutputClassMap, type, in, constructorParamClass);
} catch (Exception e) {
if (in instanceof StreamInput) {
try {
return (S) MLOutput.fromStream((StreamInput) in);
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
throw e;
}
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.output.Output;
Expand All @@ -16,5 +17,5 @@ public interface Executable {
* @param input input data
* @return execution result
*/
Output execute(Input input) throws ExecuteException;
void execute(Input input, ActionListener<Output> listener) throws ExecuteException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Locale;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand Down Expand Up @@ -152,20 +153,20 @@ public MLOutput trainAndPredict(Input input) {
return trainAndPredictable.trainAndPredict(mlInput);
}

public Output execute(Input input) throws Exception {
public void execute(Input input, ActionListener<Output> listener) throws Exception {
validateInput(input);
if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
return executable.execute(input);
executable.execute(input, listener);
} else {
Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
return executable.execute(input);
executable.execute(input, listener);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.io.FileUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.model.MLModelFormat;
Expand Down Expand Up @@ -52,7 +52,7 @@ public abstract class DLModelExecute implements MLExecutable {
protected Device[] devices;
protected AtomicInteger nextDevice = new AtomicInteger(0);

public abstract Output execute(Input input) throws ExecuteException;
public abstract void execute(Input input, ActionListener<Output> listener);

protected Predictor<float[][], ai.djl.modality.Output> getPredictor() {
int currentDevice = nextDevice.getAndIncrement();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.MultiSearchRequest;
import org.opensearch.action.search.MultiSearchResponse;
import org.opensearch.action.search.SearchRequest;
Expand Down Expand Up @@ -527,23 +524,10 @@ protected List<Map.Entry<Long, Long>> getAllIntervals() {
}

@Override
public Output execute(Input input) {
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<AnomalyLocalizationOutput> outRef = new AtomicReference<>();
AtomicReference<Exception> exRef = new AtomicReference<>();
public void execute(Input input, ActionListener<Output> listener) {
getLocalizationResults(
(AnomalyLocalizationInput) input,
new LatchedActionListener(ActionListener.<AnomalyLocalizationOutput>wrap(o -> outRef.set(o), e -> exRef.set(e)), latch)
ActionListener.wrap(o -> listener.onResponse(o), e -> listener.onFailure(e))
);
try {
latch.await();
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
if (exRef.get() != null) {
throw new RuntimeException(exRef.get());
} else {
return outRef.get();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@ public MetricsCorrelation(Client client, Settings settings, ClusterService clust
* contains 3 properties event_window, event_pattern and suspected_metrics
* @throws ExecuteException
*/
// @Override
// public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
//
// }

@Override
public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
public void execute(Input input, ActionListener<org.opensearch.ml.common.output.Output> listener) {
if (!(input instanceof MetricsCorrelationInput)) {
throw new ExecuteException("wrong input");
}
Expand Down Expand Up @@ -148,7 +153,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
} else {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequest getModelRequest = new GetRequest(ML_MODEL_INDEX).id(FunctionName.METRICS_CORRELATION.name());
ActionListener<GetResponse> listener = ActionListener.wrap(r -> {
ActionListener<GetResponse> actionListener = ActionListener.wrap(r -> {
if (r.isExists()) {
modelId = r.getId();
Map<String, Object> sourceAsMap = r.getSourceAsMap();
Expand Down Expand Up @@ -176,7 +181,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
);
}
}, e -> { log.error("Failed to get model", e); });
client.get(getModelRequest, ActionListener.runBefore(listener, context::restore));
client.get(getModelRequest, ActionListener.runBefore(actionListener, context::restore));
}
}
} else {
Expand Down Expand Up @@ -227,7 +232,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
}

tensorOutputs.add(parseModelTensorOutput(djlOutput, null));
return new MetricsCorrelationOutput(tensorOutputs);
listener.onResponse(new MetricsCorrelationOutput(tensorOutputs));
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
Expand All @@ -36,7 +37,7 @@ public LocalSampleCalculator(Client client, Settings settings) {
}

@Override
public Output execute(Input input) {
public void execute(Input input, ActionListener<Output> listener) {
if (input == null || !(input instanceof LocalSampleCalculatorInput)) {
throw new IllegalArgumentException("wrong input");
}
Expand All @@ -46,13 +47,16 @@ public Output execute(Input input) {
switch (operation) {
case "sum":
double sum = inputData.stream().mapToDouble(f -> f.doubleValue()).sum();
return new LocalSampleCalculatorOutput(sum);
listener.onResponse(new LocalSampleCalculatorOutput(sum));
break;
case "max":
double max = inputData.stream().max(Comparator.naturalOrder()).get();
return new LocalSampleCalculatorOutput(max);
listener.onResponse(new LocalSampleCalculatorOutput(max));
break;
case "min":
double min = inputData.stream().min(Comparator.naturalOrder()).get();
return new LocalSampleCalculatorOutput(min);
listener.onResponse(new LocalSampleCalculatorOutput(min));
break;
default:
throw new IllegalArgumentException("can't support this operation");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

package org.opensearch.ml.engine;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;

import java.util.ArrayList;
Expand All @@ -17,9 +16,11 @@
import org.junit.Test;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput;
import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator;

Expand All @@ -43,19 +44,27 @@ public void initInstance_LocalSampleCalculator() {

// set properties
MLEngineClassLoader.deregister(FunctionName.LOCAL_SAMPLE_CALCULATOR);
LocalSampleCalculator instance = MLEngineClassLoader
final LocalSampleCalculator instance = MLEngineClassLoader
.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class, properties);
LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) instance.execute(input);
assertEquals(d1 + d2, output.getResult(), 1e-6);
assertEquals(client, instance.getClient());
assertEquals(settings, instance.getSettings());

ActionListener<Output> actionListener = ActionListener.wrap(o -> {
LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o;
assertEquals(d1 + d2, output.getResult(), 1e-6);
assertEquals(client, instance.getClient());
assertEquals(settings, instance.getSettings());
}, e -> { fail("Test failed: " + e.getMessage()); });

instance.execute(input, actionListener);

// don't set properties
instance = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class);
output = (LocalSampleCalculatorOutput) instance.execute(input);
assertEquals(d1 + d2, output.getResult(), 1e-6);
assertNull(instance.getClient());
assertNull(instance.getSettings());
final LocalSampleCalculator instance2 = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class);
ActionListener<Output> actionListener2 = ActionListener.wrap(o -> {
LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o;
assertEquals(d1 + d2, output.getResult(), 1e-6);
assertNull(instance2.getClient());
assertNull(instance2.getSettings());
}, e -> { fail("Test failed: " + e.getMessage()); });
instance2.execute(input, actionListener2);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.ml.engine;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame;
import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionTrainDataFrame;
import static org.opensearch.ml.engine.helper.MLTestHelper.constructTestDataFrame;
Expand All @@ -23,6 +25,7 @@
import org.junit.rules.ExpectedException;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.FunctionName;
Expand All @@ -40,6 +43,7 @@
import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput;
import org.opensearch.ml.engine.algorithms.regression.LinearRegression;
import org.opensearch.ml.engine.encryptor.Encryptor;
Expand Down Expand Up @@ -265,8 +269,12 @@ public void trainAndPredictWithInvalidInput() {
@Test
public void executeLocalSampleCalculator() throws Exception {
Input input = new LocalSampleCalculatorInput("sum", Arrays.asList(1.0, 2.0));
LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) mlEngine.execute(input);
assertEquals(3.0, output.getResult(), 1e-5);
ActionListener<Output> actionListener = ActionListener.wrap(o -> {
LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o;
assertEquals(3.0, output.getResult(), 1e-5);
}, e -> { fail("Test failed: " + e.getMessage()); });
mlEngine.execute(input, actionListener);

}

@Test
Expand All @@ -289,7 +297,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
return null;
}
};
mlEngine.execute(input);
ActionListener<Output> actionListener = mock(ActionListener.class);
mlEngine.execute(input, actionListener);
}

private MLModel trainKMeansModel() {
Expand Down
Loading

0 comments on commit 1b42c0f

Please sign in to comment.