From 09ac78cd0d0e837fea70ad20cb9975a83acf4f37 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 21 Mar 2023 18:57:46 -0700 Subject: [PATCH] add allow custom deployment plan setting; add deploy to all nodes field in model index Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/common/CommonValue.java | 5 +- .../org/opensearch/ml/common/MLModel.java | 16 +++- .../action/load/TransportLoadModelAction.java | 26 ++++-- .../opensearch/ml/cluster/MLSyncUpCron.java | 50 +++++++++-- .../ml/plugin/MachineLearningPlugin.java | 3 +- .../ml/settings/MLCommonsSettings.java | 2 + .../load/TransportLoadModelActionTests.java | 82 +++++++++++++++---- 7 files changed, 154 insertions(+), 30 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 9b4a35e8ae..81367b9219 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -30,7 +30,7 @@ public class CommonValue { public static final String ML_MODEL_INDEX = ".plugins-ml-model"; public static final String ML_TASK_INDEX = ".plugins-ml-task"; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 3; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 4; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER @@ -94,6 +94,9 @@ public class CommonValue { + MLModel.PLANNING_WORKER_NODES_FIELD + "\": {\"type\": \"keyword\"},\n" + " \"" + + MLModel.DEPLOY_TO_ALL_NODES_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + MLModel.MODEL_CONFIG_FIELD + "\" : {\"properties\":{\"" + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index b1a8451be9..5f04a11a50 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -57,6 +57,7 @@ public class MLModel implements ToXContentObject { public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count"; public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count"; public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes"; + public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes"; private String name; private FunctionName algorithm; @@ -85,6 +86,7 @@ public class MLModel implements ToXContentObject { private Integer currentWorkerNodeCount; // model is deployed to how many nodes private String[] planningWorkerNodes; // plan to deploy model to these nodes + private boolean deployToAllNodes; @Builder(toBuilder = true) public MLModel(String name, FunctionName algorithm, @@ -106,7 +108,8 @@ public MLModel(String name, Integer totalChunks, Integer planningWorkerNodeCount, Integer currentWorkerNodeCount, - String[] planningWorkerNodes) { + String[] planningWorkerNodes, + boolean deployToAllNodes) { this.name = name; this.algorithm = algorithm; this.version = version; @@ -129,6 +132,7 @@ public MLModel(String name, this.planningWorkerNodeCount = planningWorkerNodeCount; this.currentWorkerNodeCount = currentWorkerNodeCount; this.planningWorkerNodes = planningWorkerNodes; + this.deployToAllNodes = deployToAllNodes; } public MLModel(StreamInput input) throws IOException{ @@ -165,6 +169,7 @@ public MLModel(StreamInput input) throws IOException{ planningWorkerNodeCount = input.readOptionalInt(); currentWorkerNodeCount = input.readOptionalInt(); planningWorkerNodes = input.readOptionalStringArray(); + deployToAllNodes = input.readBoolean(); } } @@ -211,6 +216,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalInt(planningWorkerNodeCount); out.writeOptionalInt(currentWorkerNodeCount); out.writeOptionalStringArray(planningWorkerNodes); + out.writeBoolean(deployToAllNodes); } @Override @@ -282,6 +288,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (planningWorkerNodes != null && planningWorkerNodes.length > 0) { builder.field(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes); } + if (deployToAllNodes) { + builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes); + } builder.endObject(); return builder; } @@ -312,6 +321,7 @@ public static MLModel parse(XContentParser parser) throws IOException { Integer planningWorkerNodeCount = null; Integer currentWorkerNodeCount = null; List planningWorkerNodes = new ArrayList<>(); + boolean deployToAllNodes = false; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -379,6 +389,9 @@ public static MLModel parse(XContentParser parser) throws IOException { planningWorkerNodes.add(parser.text()); } break; + case DEPLOY_TO_ALL_NODES_FIELD: + deployToAllNodes = parser.booleanValue(); + break; case CREATED_TIME_FIELD: createdTime = Instant.ofEpochMilli(parser.longValue()); break; @@ -422,6 +435,7 @@ public static MLModel parse(XContentParser parser) throws IOException { .planningWorkerNodeCount(planningWorkerNodeCount) .currentWorkerNodeCount(currentWorkerNodeCount) .planningWorkerNodes(planningWorkerNodes.toArray(new String[0])) + .deployToAllNodes(deployToAllNodes) .build(); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java index e1fe9fa218..861194faf8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; import java.time.Instant; @@ -31,6 +32,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; @@ -75,6 +77,8 @@ public class TransportLoadModelAction extends HandledTransportAction allowCustomDeploymentPlan = it); } @Override @@ -109,6 +118,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener eligibleNodes = new ArrayList<>(); List nodeIds = new ArrayList<>(); - if (targetNodeIds != null && targetNodeIds.length > 0) { + if (!deployToAllNodes) { for (String nodeId : targetNodeIds) { if (allEligibleNodeIds.contains(nodeId)) { eligibleNodes.add(nodeMapping.get(nodeId)); @@ -189,7 +203,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener eligibleNodes, - FunctionName algorithm + boolean deployToAllNodes ) { LoadModelInput loadModelInput = new LoadModelInput( modelId, @@ -264,7 +278,9 @@ void updateModelLoadStatusAndTriggerOnNodesAction( MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, eligibleNodes.size(), MLModel.PLANNING_WORKER_NODES_FIELD, - workerNodes + workerNodes, + MLModel.DEPLOY_TO_ALL_NODES_FIELD, + deployToAllNodes ), ActionListener .wrap( diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 55ac6d7b7d..36912d0fc4 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.time.Instant; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; @@ -15,6 +16,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.Semaphore; +import java.util.stream.Collectors; import lombok.extern.log4j.Log4j2; @@ -174,6 +176,8 @@ void refreshModelState(Map> modelWorkerNodes, Map> modelWorkerNodes, Map { SearchHit[] hits = res.getHits().getHits(); Map newModelStates = new HashMap<>(); + Map> newPlanningWorkerNodes = new HashMap<>(); for (SearchHit hit : hits) { String modelId = hit.getId(); Map sourceAsMap = hit.getSourceAsMap(); @@ -196,6 +201,24 @@ void refreshModelState(Map> modelWorkerNodes, Map planningWorkNodes = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODES_FIELD) + ? (List) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD) + : new ArrayList<>(); + if (deployToAllNodes) { + DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(); + planningWorkerNodeCount = eligibleNodes.length; + List eligibleNodeIds = Arrays + .asList(eligibleNodes) + .stream() + .map(n -> n.getId()) + .collect(Collectors.toList()); + if (eligibleNodeIds.size() != planningWorkNodes.size() || !eligibleNodeIds.containsAll(planningWorkNodes)) { + newPlanningWorkerNodes.put(modelId, eligibleNodeIds); + } + } MLModelState mlModelState = getNewModelState( loadingModels, modelWorkerNodes, @@ -209,7 +232,7 @@ void refreshModelState(Map> modelWorkerNodes, Map { updateModelStateSemaphore.release(); log.error("Failed to search models", e); @@ -270,16 +293,29 @@ private MLModelState getNewModelState( return null; } - private void bulkUpdateModelState(Map> modelWorkerNodes, Map newModelStates) { - if (newModelStates.size() > 0) { + private void bulkUpdateModelState( + Map> modelWorkerNodes, + Map newModelStates, + Map> newPlanningWorkNodes + ) { + Set updatedModelIds = new HashSet<>(); + updatedModelIds.addAll(newModelStates.keySet()); + updatedModelIds.addAll(newPlanningWorkNodes.keySet()); + + if (updatedModelIds.size() > 0) { BulkRequest bulkUpdateRequest = new BulkRequest(); - for (String modelId : newModelStates.keySet()) { + for (String modelId : updatedModelIds) { UpdateRequest updateRequest = new UpdateRequest(); Instant now = Instant.now(); ImmutableMap.Builder builder = ImmutableMap.builder(); - builder - .put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name()) - .put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli()); + if (newModelStates.containsKey(modelId)) { + builder.put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name()); + } + if (newPlanningWorkNodes.containsKey(modelId)) { + builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkNodes.get(modelId)); + builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkNodes.get(modelId).size()); + } + builder.put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli()); Set workerNodes = modelWorkerNodes.get(modelId); int currentWorkNodeCount = workerNodes == null ? 0 : workerNodes.size(); builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkNodeCount); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index fb97ffc36e..6dec71f722 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -515,7 +515,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE, MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX, MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD, - MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES + MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES, + MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index f9b69872e0..0a15c840fd 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -59,4 +59,6 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_EXCLUDE_NODE_NAMES = Setting .simpleString("plugins.ml_commons.exclude_nodes._name", Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN = Setting + .boolSetting("plugins.ml_commons.allow_custom_deployment_plan", false, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java index c1ebeff960..2a24aaecad 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/load/TransportLoadModelActionTests.java @@ -7,17 +7,20 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; import java.lang.reflect.Field; import java.nio.file.Path; import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; -import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -31,6 +34,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.NamedXContentRegistry; @@ -87,13 +91,18 @@ public class TransportLoadModelActionTests extends OpenSearchTestCase { @Mock private MLLoadModelRequest mlLoadModelRequest; - @InjectMocks private TransportLoadModelAction transportLoadModelAction; @Mock private ExecutorService executorService; @Mock MLTask mlTask; + @Mock + MLTaskDispatcher mlTaskDispatcher; + @Mock + NamedXContentRegistry namedXContentRegistry; + private Settings settings; + private ClusterSettings clusterSettings; private final String modelId = "mock_model_id"; private final MLModel mlModel = mock(MLModel.class); private final String localNodeId = "mockNodeId"; @@ -102,9 +111,16 @@ public class TransportLoadModelActionTests extends OpenSearchTestCase { private final List eligibleNodes = mock(List.class); + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + @Before public void setup() { MockitoAnnotations.openMocks(this); + settings = Settings.builder().put(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.getKey(), true).build(); + clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN))); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10))); modelHelper = new ModelHelper(mlEngine); when(mlLoadModelRequest.getModelId()).thenReturn("mockModelId"); @@ -125,6 +141,21 @@ public void setup() { MLStat mlStat = mock(MLStat.class); when(mlStats.getStat(eq(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT))).thenReturn(mlStat); + transportLoadModelAction = new TransportLoadModelAction( + transportService, + actionFilters, + modelHelper, + mlTaskManager, + clusterService, + threadPool, + client, + namedXContentRegistry, + nodeFilter, + mlTaskDispatcher, + mlModelManager, + mlStats, + settings + ); } public void testDoExecute_success() { @@ -149,6 +180,34 @@ public void testDoExecute_success() { verify(loadModelResponseListener).onResponse(any(LoadModelResponse.class)); } + public void testDoExecute_DoNotAllowCustomDeploymentPlan() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Don't allow custom deployment plan"); + Settings settings = Settings.builder().put(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.getKey(), false).build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN)) + ); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + TransportLoadModelAction transportLoadModelAction = new TransportLoadModelAction( + transportService, + actionFilters, + modelHelper, + mlTaskManager, + clusterService, + threadPool, + client, + namedXContentRegistry, + nodeFilter, + mlTaskDispatcher, + mlModelManager, + mlStats, + settings + ); + + transportLoadModelAction.doExecute(mock(Task.class), mlLoadModelRequest, mock(ActionListener.class)); + } + public void testDoExecute_whenLoadModelRequestNodeIdsEmpty_thenMLResourceNotFoundException() { DiscoveryNodeHelper nodeHelper = mock(DiscoveryNodeHelper.class); when(nodeHelper.getEligibleNodes()).thenReturn(new DiscoveryNode[] {}); @@ -161,11 +220,12 @@ public void testDoExecute_whenLoadModelRequestNodeIdsEmpty_thenMLResourceNotFoun clusterService, threadPool, client, - mock(NamedXContentRegistry.class), + namedXContentRegistry, nodeHelper, - mock(MLTaskDispatcher.class), + mlTaskDispatcher, mlModelManager, - mlStats + mlStats, + settings ) ); MLLoadModelRequest mlLoadModelRequest1 = mock(MLLoadModelRequest.class); @@ -243,7 +303,7 @@ public void testUpdateModelLoadStatusAndTriggerOnNodesAction_success() throws No localNodeId, mlTask, Arrays.asList(discoveryNode), - FunctionName.ANOMALY_LOCALIZATION + true ); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); @@ -257,15 +317,7 @@ public void testUpdateModelLoadStatusAndTriggerOnNodesAction_success() throws No public void testUpdateModelLoadStatusAndTriggerOnNodesAction_whenMLTaskManagerThrowException_ListenerOnFailureExecuted() { doCallRealMethod().when(mlModelManager).updateModel(anyString(), any(ImmutableMap.class), isA(ActionListener.class)); transportLoadModelAction - .updateModelLoadStatusAndTriggerOnNodesAction( - modelId, - "mock_task_id", - mlModel, - localNodeId, - mlTask, - eligibleNodes, - FunctionName.TEXT_EMBEDDING - ); + .updateModelLoadStatusAndTriggerOnNodesAction(modelId, "mock_task_id", mlModel, localNodeId, mlTask, eligibleNodes, false); verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); }