Skip to content

Commit

Permalink
add register action request/response
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 17, 2023
1 parent 4d8d32e commit e38c989
Show file tree
Hide file tree
Showing 13 changed files with 415 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class LLMSpec implements ToXContentObject {
public static final String MODEL_ID_FIELD = "model_id";
Expand Down
25 changes: 14 additions & 11 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -26,7 +27,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLAgent implements ToXContentObject, Writeable {
public static final String AGENT_NAME_FIELD = "name";
Expand Down Expand Up @@ -99,15 +100,17 @@ public MLAgent(StreamInput input) throws IOException{
if (input.readBoolean()) {
memory = new MLMemorySpec(input);
}
createdTime = input.readInstant();
lastUpdateTime = input.readInstant();
appType = input.readString();
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
appType = input.readOptionalString();
if (!"flow".equals(type)) {
Set<String> toolNames = new HashSet<>();
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
if (tools != null) {
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
}
}
}
}
Expand Down Expand Up @@ -144,9 +147,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeInstant(createdTime);
out.writeInstant(lastUpdateTime);
out.writeString(appType);
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalString(appType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -18,7 +19,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


@EqualsAndHashCode
@Getter
public class MLMemorySpec implements ToXContentObject {
public static final String MEMORY_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLToolSpec implements ToXContentObject {
public static final String TOOL_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.transport.agent;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
Expand All @@ -20,6 +21,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLAgentGetResponse extends ActionResponse implements ToXContentObject {
MLAgent mlAgent;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import org.opensearch.action.ActionType;

public class MLRegisterAgentAction extends ActionType<MLRegisterAgentResponse> {
public static MLRegisterAgentAction INSTANCE = new MLRegisterAgentAction();
public static final String NAME = "cluster:admin/opensearch/ml/agents/register";

private MLRegisterAgentAction() {
super(NAME, MLRegisterAgentResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.agent.MLAgent;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLRegisterAgentRequest extends ActionRequest {

MLAgent mlAgent;

@Builder
public MLRegisterAgentRequest(MLAgent mlAgent) {
this.mlAgent = mlAgent;
}

public MLRegisterAgentRequest(StreamInput in) throws IOException {
super(in);
this.mlAgent = new MLAgent(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (mlAgent == null) {
exception = addValidationError("ML agent can't be null", exception);
}

return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.mlAgent.writeTo(out);
}

public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLRegisterAgentRequest) {
return (MLRegisterAgentRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelRequest", e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject {
public static final String AGENT_ID_FIELD = "agent_id";

private String agentId;

public MLRegisterAgentResponse(StreamInput in) throws IOException {
super(in);
this.agentId = in.readString();
}

public MLRegisterAgentResponse(String agentId) {
this.agentId= agentId;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(agentId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(AGENT_ID_FIELD, agentId);
builder.endObject();
return builder;
}

public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterAgentResponse) {
return (MLRegisterAgentResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -49,4 +54,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
return builder;
}

public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLUndeployModelsResponse) {
return (MLUndeployModelsResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUndeployModelsResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLUndeployModelsResponse", e);
}
}
}
Loading

0 comments on commit e38c989

Please sign in to comment.