Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add allow custom deployment plan setting; add deploy to all nodes field in model index #818

Merged
merged 1 commit into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public class CommonValue {

public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 3;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 4;
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
Expand Down Expand Up @@ -94,6 +94,9 @@ public class CommonValue {
+ MLModel.PLANNING_WORKER_NODES_FIELD
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.DEPLOY_TO_ALL_NODES_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
Expand Down
16 changes: 15 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public class MLModel implements ToXContentObject {
public static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
// Field name of the boolean "deploy to all nodes" flag in the model document (mapped as type boolean in the model index).
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";

private String name;
private FunctionName algorithm;
Expand Down Expand Up @@ -85,6 +86,7 @@ public class MLModel implements ToXContentObject {
private Integer currentWorkerNodeCount; // model is deployed to how many nodes

private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes; // written to XContent under DEPLOY_TO_ALL_NODES_FIELD only when true; parse() defaults it to false
@Builder(toBuilder = true)
public MLModel(String name,
FunctionName algorithm,
Expand All @@ -106,7 +108,8 @@ public MLModel(String name,
Integer totalChunks,
Integer planningWorkerNodeCount,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes) {
String[] planningWorkerNodes,
boolean deployToAllNodes) {
this.name = name;
this.algorithm = algorithm;
this.version = version;
Expand All @@ -129,6 +132,7 @@ public MLModel(String name,
this.planningWorkerNodeCount = planningWorkerNodeCount;
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
this.deployToAllNodes = deployToAllNodes;
}

public MLModel(StreamInput input) throws IOException{
Expand Down Expand Up @@ -165,6 +169,7 @@ public MLModel(StreamInput input) throws IOException{
planningWorkerNodeCount = input.readOptionalInt();
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
}
}

Expand Down Expand Up @@ -211,6 +216,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(planningWorkerNodeCount);
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
}

@Override
Expand Down Expand Up @@ -282,6 +288,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (planningWorkerNodes != null && planningWorkerNodes.length > 0) {
builder.field(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes);
}
if (deployToAllNodes) {
builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -312,6 +321,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
Integer planningWorkerNodeCount = null;
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();
boolean deployToAllNodes = false;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -379,6 +389,9 @@ public static MLModel parse(XContentParser parser) throws IOException {
planningWorkerNodes.add(parser.text());
}
break;
case DEPLOY_TO_ALL_NODES_FIELD:
deployToAllNodes = parser.booleanValue();
break;
case CREATED_TIME_FIELD:
createdTime = Instant.ofEpochMilli(parser.longValue());
break;
Expand Down Expand Up @@ -422,6 +435,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
.planningWorkerNodeCount(planningWorkerNodeCount)
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.deployToAllNodes(deployToAllNodes)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.LOAD_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;

import java.time.Instant;
Expand All @@ -31,6 +32,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
Expand Down Expand Up @@ -75,6 +77,8 @@ public class TransportLoadModelAction extends HandledTransportAction<ActionReque
MLModelManager mlModelManager;
MLStats mlStats;

private volatile boolean allowCustomDeploymentPlan;

@Inject
public TransportLoadModelAction(
TransportService transportService,
Expand All @@ -88,7 +92,8 @@ public TransportLoadModelAction(
DiscoveryNodeHelper nodeFilter,
MLTaskDispatcher mlTaskDispatcher,
MLModelManager mlModelManager,
MLStats mlStats
MLStats mlStats,
Settings settings
) {
super(MLLoadModelAction.NAME, transportService, actionFilters, MLLoadModelRequest::new);
this.transportService = transportService;
Expand All @@ -102,13 +107,22 @@ public TransportLoadModelAction(
this.mlTaskDispatcher = mlTaskDispatcher;
this.mlModelManager = mlModelManager;
this.mlStats = mlStats;
allowCustomDeploymentPlan = ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN, it -> allowCustomDeploymentPlan = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<LoadModelResponse> listener) {
MLLoadModelRequest deployModelRequest = MLLoadModelRequest.fromActionRequest(request);
String modelId = deployModelRequest.getModelId();
String[] targetNodeIds = deployModelRequest.getModelNodeIds();
boolean deployToAllNodes = targetNodeIds == null || targetNodeIds.length == 0;
if (!allowCustomDeploymentPlan && !deployToAllNodes) {
throw new IllegalArgumentException("Don't allow custom deployment plan");
}

// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes();
Expand All @@ -121,7 +135,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo

List<DiscoveryNode> eligibleNodes = new ArrayList<>();
List<String> nodeIds = new ArrayList<>();
if (targetNodeIds != null && targetNodeIds.length > 0) {
if (!deployToAllNodes) {
for (String nodeId : targetNodeIds) {
if (allEligibleNodeIds.contains(nodeId)) {
eligibleNodes.add(nodeMapping.get(nodeId));
Expand Down Expand Up @@ -189,7 +203,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
localNodeId,
mlTask,
eligibleNodes,
algorithm
deployToAllNodes
)
);
} catch (Exception ex) {
Expand Down Expand Up @@ -226,7 +240,7 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
String localNodeId,
MLTask mlTask,
List<DiscoveryNode> eligibleNodes,
FunctionName algorithm
boolean deployToAllNodes
) {
LoadModelInput loadModelInput = new LoadModelInput(
modelId,
Expand Down Expand Up @@ -264,7 +278,9 @@ void updateModelLoadStatusAndTriggerOnNodesAction(
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
eligibleNodes.size(),
MLModel.PLANNING_WORKER_NODES_FIELD,
workerNodes
workerNodes,
MLModel.DEPLOY_TO_ALL_NODES_FIELD,
deployToAllNodes
),
ActionListener
.wrap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

Expand Down Expand Up @@ -174,6 +176,8 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
.fetchSource(
new String[] {
MLModel.MODEL_STATE_FIELD,
MLModel.DEPLOY_TO_ALL_NODES_FIELD,
MLModel.PLANNING_WORKER_NODES_FIELD,
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
MLModel.LAST_UPDATED_TIME_FIELD,
MLModel.CURRENT_WORKER_NODE_COUNT_FIELD },
Expand All @@ -183,6 +187,7 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
client.search(searchRequest, ActionListener.wrap(res -> {
SearchHit[] hits = res.getHits().getHits();
Map<String, MLModelState> newModelStates = new HashMap<>();
Map<String, List<String>> newPlanningWorkerNodes = new HashMap<>();
for (SearchHit hit : hits) {
String modelId = hit.getId();
Map<String, Object> sourceAsMap = hit.getSourceAsMap();
Expand All @@ -196,6 +201,24 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
int currentWorkerNodeCountInIndex = sourceAsMap.containsKey(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)
? (int) sourceAsMap.get(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD)
: 0;
boolean deployToAllNodes = sourceAsMap.containsKey(MLModel.DEPLOY_TO_ALL_NODES_FIELD)
? (boolean) sourceAsMap.get(MLModel.DEPLOY_TO_ALL_NODES_FIELD)
: false;
List<String> planningWorkNodes = sourceAsMap.containsKey(MLModel.PLANNING_WORKER_NODES_FIELD)
? (List<String>) sourceAsMap.get(MLModel.PLANNING_WORKER_NODES_FIELD)
: new ArrayList<>();
if (deployToAllNodes) {
DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes();
planningWorkerNodeCount = eligibleNodes.length;
List<String> eligibleNodeIds = Arrays
.asList(eligibleNodes)
.stream()
.map(n -> n.getId())
.collect(Collectors.toList());
if (eligibleNodeIds.size() != planningWorkNodes.size() || !eligibleNodeIds.containsAll(planningWorkNodes)) {
newPlanningWorkerNodes.put(modelId, eligibleNodeIds);
}
}
MLModelState mlModelState = getNewModelState(
loadingModels,
modelWorkerNodes,
Expand All @@ -209,7 +232,7 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
newModelStates.put(modelId, mlModelState);
}
}
bulkUpdateModelState(modelWorkerNodes, newModelStates);
bulkUpdateModelState(modelWorkerNodes, newModelStates, newPlanningWorkerNodes);
}, e -> {
updateModelStateSemaphore.release();
log.error("Failed to search models", e);
Expand Down Expand Up @@ -270,16 +293,29 @@ private MLModelState getNewModelState(
return null;
}

private void bulkUpdateModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, MLModelState> newModelStates) {
if (newModelStates.size() > 0) {
private void bulkUpdateModelState(
Map<String, Set<String>> modelWorkerNodes,
Map<String, MLModelState> newModelStates,
Map<String, List<String>> newPlanningWorkNodes
) {
Set<String> updatedModelIds = new HashSet<>();
updatedModelIds.addAll(newModelStates.keySet());
updatedModelIds.addAll(newPlanningWorkNodes.keySet());

if (updatedModelIds.size() > 0) {
BulkRequest bulkUpdateRequest = new BulkRequest();
for (String modelId : newModelStates.keySet()) {
for (String modelId : updatedModelIds) {
UpdateRequest updateRequest = new UpdateRequest();
Instant now = Instant.now();
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder
.put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name())
.put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli());
if (newModelStates.containsKey(modelId)) {
builder.put(MLModel.MODEL_STATE_FIELD, newModelStates.get(modelId).name());
}
if (newPlanningWorkNodes.containsKey(modelId)) {
builder.put(MLModel.PLANNING_WORKER_NODES_FIELD, newPlanningWorkNodes.get(modelId));
builder.put(MLModel.PLANNING_WORKER_NODE_COUNT_FIELD, newPlanningWorkNodes.get(modelId).size());
}
builder.put(MLModel.LAST_UPDATED_TIME_FIELD, now.toEpochMilli());
Set<String> workerNodes = modelWorkerNodes.get(modelId);
int currentWorkNodeCount = workerNodes == null ? 0 : workerNodes.size();
builder.put(MLModel.CURRENT_WORKER_NODE_COUNT_FIELD, currentWorkNodeCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE,
MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX,
MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD,
MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES
MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES,
MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,6 @@ private MLCommonsSettings() {}

public static final Setting<String> ML_COMMONS_EXCLUDE_NODE_NAMES = Setting
.simpleString("plugins.ml_commons.exclude_nodes._name", Setting.Property.NodeScope, Setting.Property.Dynamic);
/**
 * Boolean setting {@code plugins.ml_commons.allow_custom_deployment_plan}, registered as node-scoped and
 * dynamic, with {@code false} passed as its default value.
 *
 * TransportLoadModelAction reads it at construction and tracks later updates; while it is {@code false},
 * a load request that names specific target node ids is rejected with an IllegalArgumentException.
 */
public static final Setting<Boolean> ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN = Setting
.boolSetting("plugins.ml_commons.allow_custom_deployment_plan", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Loading