Skip to content

Commit

Permalink
fix register client API
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Oct 27, 2023
1 parent 0920ba7 commit e005d0a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,7 @@ public void searchTask(SearchRequest searchRequest, ActionListener<SearchRespons
@Override
public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput);
client
.execute(
MLRegisterModelAction.INSTANCE,
registerRequest,
ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
);
client.execute(MLRegisterModelAction.INSTANCE, registerRequest, getMLRegisterModelResponseActionListener(listener));
}

@Override
Expand Down Expand Up @@ -266,6 +261,14 @@ private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener
return actionListener;
}

private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseActionListener(ActionListener<MLRegisterModelResponse> listener) {
ActionListener<MLRegisterModelResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res);
return registerModelResponse;
});
return actionListener;
}

private <T extends ActionResponse> ActionListener<T> wrapActionListener(
final ActionListener<T> listener,
final Function<ActionResponse, T> recreate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.transport.MLTaskResponse;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -61,4 +67,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
builder.endObject();
return builder;
}

public static MLRegisterModelResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterModelResponse) {
return (MLRegisterModelResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterModelResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelResponse", e);
}
}
}

0 comments on commit e005d0a

Please sign in to comment.