diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index e1fd6445a2..ac73b397e7 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -230,12 +230,7 @@ public void searchTask(SearchRequest searchRequest, ActionListener listener) { MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput); - client - .execute( - MLRegisterModelAction.INSTANCE, - registerRequest, - ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) - ); + client.execute(MLRegisterModelAction.INSTANCE, registerRequest, getMLRegisterModelResponseActionListener(listener)); } @Override @@ -266,6 +261,16 @@ private ActionListener getMlPredictionTaskResponseActionListener return actionListener; } + private ActionListener getMLRegisterModelResponseActionListener( + ActionListener listener + ) { + ActionListener actionListener = wrapActionListener(listener, res -> { + MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res); + return registerModelResponse; + }); + return actionListener; + } + private ActionListener wrapActionListener( final ActionListener listener, final Function recreate diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java index c7baa9b3a6..18c64c6c5f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java @@ -7,13 +7,19 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.transport.MLTaskResponse; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject { @@ -61,4 +67,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } + + public static MLRegisterModelResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterModelResponse) { + return (MLRegisterModelResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterModelResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelResponse", e); + } + } }