Skip to content

Commit

Permalink
Expose ML Config API (#2850)
Browse files Browse the repository at this point in the history
* Expose ML Config API

Signed-off-by: Ashish Agrawal <[email protected]>

* Add tests for rejected master key

Signed-off-by: Ashish Agrawal <[email protected]>

---------

Signed-off-by: Ashish Agrawal <[email protected]>
(cherry picked from commit 05eb53f)
  • Loading branch information
lezzago authored and github-actions[bot] committed Aug 28, 2024
1 parent 71ff4e2 commit 037efbb
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand Down Expand Up @@ -428,4 +429,20 @@ default ActionFuture<ToolMetadata> getTool(String toolName) {
*/
void getTool(String toolName, ActionListener<ToolMetadata> listener);

/**
* Get config
* @param configId ML config id
*/
default ActionFuture<MLConfig> getConfig(String configId) {
PlainActionFuture<MLConfig> actionFuture = PlainActionFuture.newFuture();
getConfig(configId, actionFuture);
return actionFuture;
}

/**
* Get config
* @param configId ML config id
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, ActionListener<MLConfig> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand All @@ -39,6 +40,9 @@
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetAction;
import org.opensearch.ml.common.transport.config.MLConfigGetRequest;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
Expand Down Expand Up @@ -309,6 +313,13 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, getMlGetToolResponseActionListener(listener));
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build();

client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}

private ActionListener<MLToolsListResponse> getMlListToolsResponseActionListener(ActionListener<List<ToolMetadata>> listener) {
ActionListener<MLToolsListResponse> internalListener = ActionListener.wrap(mlModelListResponse -> {
listener.onResponse(mlModelListResponse.getToolMetadataList());
Expand All @@ -331,6 +342,17 @@ private ActionListener<MLToolGetResponse> getMlGetToolResponseActionListener(Act
return actionListener;
}

private ActionListener<MLConfigGetResponse> getMlGetConfigResponseActionListener(ActionListener<MLConfig> listener) {
ActionListener<MLConfigGetResponse> internalListener = ActionListener.wrap(mlConfigGetResponse -> {
listener.onResponse(mlConfigGetResponse.getMlConfig());
}, listener::onFailure);
ActionListener<MLConfigGetResponse> actionListener = wrapActionListener(internalListener, res -> {
MLConfigGetResponse getResponse = MLConfigGetResponse.fromActionResponse(res);
return getResponse;
});
return actionListener;
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.ml.common.input.Constants.KMEANS;
import static org.opensearch.ml.common.input.Constants.TRAIN;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -28,8 +29,10 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.Configuration;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand All @@ -46,6 +49,7 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand Down Expand Up @@ -99,9 +103,13 @@ public class MachineLearningClientTest {
@Mock
MLRegisterAgentResponse registerAgentResponse;

@Mock
MLConfigGetResponse configGetResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
private MLConfig mlConfig;
private ToolMetadata toolMetadata;
private List<ToolMetadata> toolsList = new ArrayList<>();

Expand All @@ -124,6 +132,14 @@ public void setUp() {
.build();
toolsList.add(toolMetadata);

mlConfig = MLConfig
.builder()
.type("dummyType")
.configuration(Configuration.builder().agentId("agentId").build())
.createTime(Instant.now())
.lastUpdateTime(Instant.now())
.build();

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
Expand Down Expand Up @@ -231,6 +247,11 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}
};
}

Expand Down Expand Up @@ -503,4 +524,9 @@ public void getTool() {
public void listTools() {
assertEquals(toolMetadata, machineLearningClient.listTools().actionGet().get(0));
}

@Test
public void getConfig() {
assertEquals(mlConfig, machineLearningClient.getConfig("configId").actionGet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.input.Constants.ACTION;
import static org.opensearch.ml.common.input.Constants.ALGORITHM;
import static org.opensearch.ml.common.input.Constants.KMEANS;
Expand All @@ -40,6 +41,7 @@
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -51,12 +53,15 @@
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.Configuration;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLConfig;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
Expand Down Expand Up @@ -84,6 +89,9 @@
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.config.MLConfigGetAction;
import org.opensearch.ml.common.transport.config.MLConfigGetRequest;
import org.opensearch.ml.common.transport.config.MLConfigGetResponse;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
Expand Down Expand Up @@ -206,6 +214,9 @@ public class MachineLearningNodeClientTest {
@Mock
ActionListener<ToolMetadata> getToolActionListener;

@Mock
ActionListener<MLConfig> getMlConfigListener;

@InjectMocks
MachineLearningNodeClient machineLearningNodeClient;

Expand Down Expand Up @@ -951,6 +962,43 @@ public void listTools() {
assertEquals("Use this tool to search general knowledge on wikipedia.", argumentCaptor.getValue().get(0).getDescription());
}

@Test
public void getConfig() {
MLConfig mlConfig = MLConfig.builder().type("type").configuration(Configuration.builder().agentId("agentId").build()).build();

doAnswer(invocation -> {
ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2);
MLConfigGetResponse output = MLConfigGetResponse.builder().mlConfig(mlConfig).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any());

ArgumentCaptor<MLConfig> argumentCaptor = ArgumentCaptor.forClass(MLConfig.class);
machineLearningNodeClient.getConfig("agentId", getMlConfigListener);

verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any());
verify(getMlConfigListener).onResponse(argumentCaptor.capture());
assertEquals("agentId", argumentCaptor.getValue().getConfiguration().getAgentId());
assertEquals("type", argumentCaptor.getValue().getType());
}

@Test
public void getConfigRejectedMasterKey() {
doAnswer(invocation -> {
ActionListener<MLConfigGetResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new OpenSearchStatusException("You are not allowed to access this config doc", RestStatus.FORBIDDEN));
return null;
}).when(client).execute(eq(MLConfigGetAction.INSTANCE), any(), any());

ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
machineLearningNodeClient.getConfig(MASTER_KEY, getMlConfigListener);

verify(client).execute(eq(MLConfigGetAction.INSTANCE), isA(MLConfigGetRequest.class), any());
verify(getMlConfigListener).onFailure(argumentCaptor.capture());
assertEquals(RestStatus.FORBIDDEN, argumentCaptor.getValue().status());
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import org.opensearch.ml.common.MLConfig;

import lombok.Builder;
import lombok.Getter;

@Getter
public class MLConfigGetResponse extends ActionResponse implements ToXContentObject {
MLConfig mlConfig;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.action.config;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

Expand Down Expand Up @@ -58,6 +59,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLConf
String configId = mlConfigGetRequest.getConfigId();
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(configId);

if (configId.equals(MASTER_KEY)) {
actionListener.onFailure(new OpenSearchStatusException("You are not allowed to access this config doc", RestStatus.FORBIDDEN));
return;
}

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
log.debug("Completed Get Agent Request, id:{}", configId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

package org.opensearch.ml.action.config;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;

import java.io.IOException;
import java.time.Instant;
Expand All @@ -22,6 +24,7 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
Expand All @@ -30,6 +33,7 @@
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
Expand Down Expand Up @@ -168,4 +172,27 @@ public GetResponse prepareMLConfig(String configID) throws IOException {
GetResponse getResponse = new GetResponse(getResult);
return getResponse;
}

@Test
public void testDoExecute_Rejected_MASTER_KEY() throws IOException {
String configID = MASTER_KEY;
GetResponse getResponse = prepareMLConfig(configID);
ActionListener<MLConfigGetResponse> actionListener = mock(ActionListener.class);
MLConfigGetRequest request = new MLConfigGetRequest(configID);
Task task = mock(Task.class);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);

getConfigTransportAction.doExecute(task, request, actionListener);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(RestStatus.FORBIDDEN, argumentCaptor.getValue().status());
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());

}
}

0 comments on commit 037efbb

Please sign in to comment.