From 0335a8fd393b4426a270f63da663df49a47dbc1a Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sat, 16 Dec 2023 22:38:51 -0800 Subject: [PATCH] add more ut Signed-off-by: Yaliang Wu --- .../opensearch/ml/common/agent/MLAgent.java | 33 +++++++++++-------- .../agent/MLRegisterAgentRequest.java | 2 +- .../agent/MLRegisterAgentResponse.java | 2 +- .../undeploy/MLUndeployModelsResponse.java | 2 +- .../ml/common/agent/MLAgentTest.java | 8 +++++ .../agent/MLRegisterAgentRequestTest.java | 32 +++++++++++++++--- .../agent/MLRegisterAgentResponseTest.java | 28 ++++++++++++++-- .../MLUndeployModelsResponseTest.java | 32 +++++++++++++++--- 8 files changed, 110 insertions(+), 29 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index 2d2dad77bf..9033b92afc 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -65,9 +65,6 @@ public MLAgent(String name, Instant createdTime, Instant lastUpdateTime, String appType) { - if (name == null) { - throw new IllegalArgumentException("agent name is null"); - } this.name = name; this.type = type; this.description = description; @@ -78,6 +75,24 @@ public MLAgent(String name, this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; this.appType = appType; + validate(); + } + + private void validate() { + if (name == null) { + throw new IllegalArgumentException("agent name is null"); + } + Set toolNames = new HashSet<>(); + if (tools != null) { + for (MLToolSpec toolSpec : tools) { + String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); + if (toolNames.contains(toolName)) { + throw new IllegalArgumentException("Duplicate tool defined: " + toolName); + } else { + toolNames.add(toolName); + } + } + } } public MLAgent(StreamInput input) throws IOException{ @@ -103,17 +118,7 @@ public MLAgent(StreamInput input) throws IOException{ createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); appType = input.readOptionalString(); - if (!"flow".equals(type)) { - Set toolNames = new HashSet<>(); - if (tools != null) { - for (MLToolSpec toolSpec : tools) { - String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); - if (toolNames.contains(toolName)) { - throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName); - } - } - } - } + validate(); } public void writeTo(StreamOutput out) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java index 00489ccf08..4add7827d5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java @@ -70,7 +70,7 @@ public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionReque return new MLRegisterAgentRequest(input); } } catch (IOException e) { - throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelRequest", e); + throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterAgentRequest", e); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java index 3005739416..7f8b633cbe 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponse.java @@ -59,7 +59,7 @@ public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionRe return new MLRegisterAgentResponse(input); } } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e); + throw new UncheckedIOException("Failed to parse ActionResponse into MLRegisterAgentResponse", e); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java index d86b889ce6..71fc7ef38b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java @@ -67,7 +67,7 @@ public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionR return new MLUndeployModelsResponse(input); } } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionResponse into MLUndeployModelsResponse", e); + throw new UncheckedIOException("Failed to parse ActionResponse into MLUndeployModelsResponse", e); } } } diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index bfaec959c4..e00a49aeb6 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -35,6 +35,14 @@ public void constructor_NullName() { MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test"); } + @Test + public void constructor_DuplicateTool() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Duplicate tool defined: test_tool_name"); + MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false); + MLAgent agent = new MLAgent("test_name", "test_type", "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test"); + } + @Test public void writeTo() throws IOException { MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test"); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java index 2c189690dc..ee446db82f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java @@ -17,13 +17,10 @@ import org.opensearch.ml.common.agent.MLToolSpec; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Arrays; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; public class MLRegisterAgentRequestTest { @@ -88,4 +85,29 @@ public void writeTo(StreamOutput out) throws IOException { assertEquals(registerAgentRequest.getMlAgent(), parsedRequest.getMlAgent()); } + @Test + public void fromActionRequest_Success_MLRegisterAgentRequest() { + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + MLRegisterAgentRequest parsedRequest = MLRegisterAgentRequest.fromActionRequest(registerAgentRequest); + assertSame(registerAgentRequest, parsedRequest); + } + + @Test + public void fromActionRequest_Exception() { + exceptionRule.expect(UncheckedIOException.class); + exceptionRule.expectMessage("Failed to parse ActionRequest into MLRegisterAgentRequest"); + MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLRegisterAgentRequest.fromActionRequest(actionRequest); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java index 6c300e786c..9997eb0ad6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentResponseTest.java @@ -18,10 +18,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import java.io.IOException; +import java.io.UncheckedIOException; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.*; public class MLRegisterAgentResponseTest { String agentId; @@ -72,4 +71,27 @@ public void writeTo(StreamOutput out) throws IOException { assertNotSame(registerAgentResponse, parsedResponse); assertEquals(registerAgentResponse.getAgentId(), parsedResponse.getAgentId()); } + + @Test + public void fromActionResponse_Success_MLRegisterAgentResponse() { + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + MLRegisterAgentResponse parsedResponse = MLRegisterAgentResponse.fromActionResponse(registerAgentResponse); + assertSame(registerAgentResponse, parsedResponse); + } + + @Test + public void fromActionResponse_Exception() { + exceptionRule.expect(UncheckedIOException.class); + exceptionRule.expectMessage("Failed to parse ActionResponse into MLRegisterAgentResponse"); + MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId); + ActionResponse actionResponse = new ActionResponse() { + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLRegisterAgentResponse.fromActionResponse(actionResponse); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java index 747ed7bff3..69f12099e9 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponseTest.java @@ -6,7 +6,9 @@ package org.opensearch.ml.common.transport.undeploy; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; import org.opensearch.cluster.ClusterName; @@ -21,6 +23,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import java.io.IOException; +import java.io.UncheckedIOException; import java.net.InetAddress; import java.util.Arrays; import java.util.Collections; @@ -28,14 +31,14 @@ import java.util.List; import java.util.Map; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.*; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; public class MLUndeployModelsResponseTest { MLUndeployModelNodesResponse undeployModelNodesResponse; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setUp() { @@ -81,7 +84,7 @@ public void toXContent() throws IOException { } @Test - public void fromActionResponse_Sucess() { + public void fromActionResponse_Success() { MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); ActionResponse actionResponse = new ActionResponse() { @Override @@ -94,4 +97,25 @@ public void writeTo(StreamOutput out) throws IOException { assertEquals(1, parsedResponse.getResponse().getNodes().size()); assertEquals("test_node_id", parsedResponse.getResponse().getNodes().get(0).getNode().getId()); } + + @Test + public void fromActionResponse_Success_MLUndeployModelsResponse() { + MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); + MLUndeployModelsResponse parsedResponse = MLUndeployModelsResponse.fromActionResponse(undeployModelsResponse); + assertSame(undeployModelsResponse, parsedResponse); + } + + @Test + public void fromActionResponse_Exception() { + exceptionRule.expect(UncheckedIOException.class); + exceptionRule.expectMessage("Failed to parse ActionResponse into MLUndeployModelsResponse"); + MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLUndeployModelsResponse.fromActionResponse(actionResponse); + } }