Skip to content

Commit

Permalink
add cancel batch prediction job API for offline inference
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Aug 31, 2024
1 parent 0354a82 commit b6be1f5
Show file tree
Hide file tree
Showing 13 changed files with 531 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ public class CommonValue {
+ USER_FIELD_MAPPING
+ " }\n"
+ "}"
+ MLTask.TRANSFORM_JOB_FIELD
+ MLTask.REMOTE_JOB_FIELD
+ "\" : {\"type\": \"flat_object\"}\n"
+ " }\n"
+ "}";
Expand Down
26 changes: 13 additions & 13 deletions common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class MLTask implements ToXContentObject, Writeable {
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
public static final String ERROR_FIELD = "error";
public static final String IS_ASYNC_TASK_FIELD = "is_async";
public static final String TRANSFORM_JOB_FIELD = "transform_job";
public static final String REMOTE_JOB_FIELD = "remote_job";
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB = CommonValue.VERSION_2_16_0;

@Setter
Expand All @@ -75,7 +75,7 @@ public class MLTask implements ToXContentObject, Writeable {
private User user; // TODO: support document level access control later
private boolean async;
@Setter
private Map<String, Object> transformJob;
private Map<String, Object> remoteJob;

@Builder(toBuilder = true)
public MLTask(
Expand All @@ -93,7 +93,7 @@ public MLTask(
String error,
User user,
boolean async,
Map<String, Object> transformJob
Map<String, Object> remoteJob
) {
this.taskId = taskId;
this.modelId = modelId;
Expand All @@ -109,7 +109,7 @@ public MLTask(
this.error = error;
this.user = user;
this.async = async;
this.transformJob = transformJob;
this.remoteJob = remoteJob;
}

public MLTask(StreamInput input) throws IOException {
Expand Down Expand Up @@ -139,7 +139,7 @@ public MLTask(StreamInput input) throws IOException {
if (streamInputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
if (input.readBoolean()) {
String mapStr = input.readString();
this.transformJob = gson.fromJson(mapStr, Map.class);
this.remoteJob = gson.fromJson(mapStr, Map.class);
}
}
}
Expand Down Expand Up @@ -171,11 +171,11 @@ public void writeTo(StreamOutput out) throws IOException {
}
out.writeBoolean(async);
if (streamOutputVersion.onOrAfter(MLTask.MINIMAL_SUPPORTED_VERSION_FOR_BATCH_TRANSFORM_JOB)) {
if (transformJob != null) {
if (remoteJob != null) {
out.writeBoolean(true);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
out.writeString(gson.toJson(transformJob));
out.writeString(gson.toJson(remoteJob));
return null;
});
} catch (PrivilegedActionException e) {
Expand Down Expand Up @@ -230,8 +230,8 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
builder.field(USER, user);
}
builder.field(IS_ASYNC_TASK_FIELD, async);
if (transformJob != null) {
builder.field(TRANSFORM_JOB_FIELD, transformJob);
if (remoteJob != null) {
builder.field(REMOTE_JOB_FIELD, remoteJob);
}
return builder.endObject();
}
Expand All @@ -256,7 +256,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
String error = null;
User user = null;
boolean async = false;
Map<String, Object> transformJob = null;
Map<String, Object> remoteJob = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -314,8 +314,8 @@ public static MLTask parse(XContentParser parser) throws IOException {
case IS_ASYNC_TASK_FIELD:
async = parser.booleanValue();
break;
case TRANSFORM_JOB_FIELD:
transformJob = parser.map();
case REMOTE_JOB_FIELD:
remoteJob = parser.map();
break;
default:
parser.skipChildren();
Expand All @@ -338,7 +338,7 @@ public static MLTask parse(XContentParser parser) throws IOException {
.error(error)
.user(user)
.async(async)
.transformJob(transformJob)
.remoteJob(remoteJob)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ public enum ActionType {
PREDICT,
EXECUTE,
BATCH_PREDICT,
CANCEL_BATCH,
BATCH_STATUS;
CANCEL_BATCH_PREDICT,
BATCH_PREDICT_STATUS;

public static ActionType from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public ModelTensors(List<ModelTensor> mlModelTensors) {
this.mlModelTensors = mlModelTensors;
}

@Builder
public ModelTensors(Integer statusCode) {
this.statusCode = statusCode;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.task;

import org.opensearch.action.ActionType;

public class MLCancelBatchJobAction extends ActionType<MLCancelBatchJobResponse> {
public static final MLCancelBatchJobAction INSTANCE = new MLCancelBatchJobAction();
public static final String NAME = "cluster:admin/opensearch/ml/tasks/cancel_batch_job";

private MLCancelBatchJobAction() {
super(NAME, MLCancelBatchJobResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.task;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import lombok.Builder;
import lombok.Getter;

public class MLCancelBatchJobRequest extends ActionRequest {
@Getter
String taskId;

@Builder
public MLCancelBatchJobRequest(String taskId) {
this.taskId = taskId;
}

public MLCancelBatchJobRequest(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.taskId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.taskId == null) {
exception = addValidationError("ML task id can't be null", exception);
}

return exception;
}

public static MLCancelBatchJobRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLCancelBatchJobRequest) {
return (MLCancelBatchJobRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLCancelBatchJobRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLCancelBatchJobRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.task;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import lombok.Builder;
import lombok.Getter;

@Getter
public class MLCancelBatchJobResponse extends ActionResponse implements ToXContentObject {

RestStatus status;

@Builder
public MLCancelBatchJobResponse(RestStatus status) {
this.status = status;
}

public MLCancelBatchJobResponse(StreamInput in) throws IOException {
super(in);
status = in.readEnum(RestStatus.class);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(status);
}

public static MLCancelBatchJobResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLCancelBatchJobResponse) {
return (MLCancelBatchJobResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLCancelBatchJobResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLTaskGetResponse", e);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return xContentBuilder.startObject().field("status", status).endObject();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;

import java.nio.ByteBuffer;
Expand Down Expand Up @@ -169,13 +170,14 @@ public void onComplete() {
}

private void response() {
String body = responseBody.toString();

if (exceptionHolder.get() != null) {
actionListener.onFailure(exceptionHolder.get());
return;
}

String body = responseBody.toString();
if (Strings.isBlank(body)) {
if (Strings.isBlank(body) && !action.equals(CANCEL_BATCH_PREDICT.toString())) {
log.error("Remote model response body is empty!");
actionListener.onFailure(new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST));
return;
Expand All @@ -187,6 +189,13 @@ private void response() {
return;
}

if (action.equals(CANCEL_BATCH_PREDICT.toString())) {
ModelTensors tensors = ModelTensors.builder().statusCode(statusCode).build();
tensors.setStatusCode(statusCode);
actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors));
return;
}

try {
ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard);
tensors.setStatusCode(statusCode);
Expand Down
Loading

0 comments on commit b6be1f5

Please sign in to comment.