Skip to content

Commit

Permalink
add more ut
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Dec 17, 2023
1 parent f339dcc commit 0335a8f
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 29 deletions.
33 changes: 19 additions & 14 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ public MLAgent(String name,
Instant createdTime,
Instant lastUpdateTime,
String appType) {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
}
this.name = name;
this.type = type;
this.description = description;
Expand All @@ -78,6 +75,24 @@ public MLAgent(String name,
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
this.appType = appType;
validate();
}

private void validate() {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
}
Set<String> toolNames = new HashSet<>();
if (tools != null) {
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Duplicate tool defined: " + toolName);
} else {
toolNames.add(toolName);
}
}
}
}

public MLAgent(StreamInput input) throws IOException{
Expand All @@ -103,17 +118,7 @@ public MLAgent(StreamInput input) throws IOException{
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
appType = input.readOptionalString();
if (!"flow".equals(type)) {
Set<String> toolNames = new HashSet<>();
if (tools != null) {
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
}
}
}
}
validate();
}

public void writeTo(StreamOutput out) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionReque
return new MLRegisterAgentRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterModelRequest", e);
throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterAgentRequest", e);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionRe
return new MLRegisterAgentResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterAgentResponse", e);
throw new UncheckedIOException("Failed to parse ActionResponse into MLRegisterAgentResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionR
return new MLUndeployModelsResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLUndeployModelsResponse", e);
throw new UncheckedIOException("Failed to parse ActionResponse into MLUndeployModelsResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ public void constructor_NullName() {
MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test");
}

@Test
public void constructor_DuplicateTool() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Duplicate tool defined: test_tool_name");
MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false);
MLAgent agent = new MLAgent("test_name", "test_type", "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test");
}

@Test
public void writeTo() throws IOException {
MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@
import org.opensearch.ml.common.agent.MLToolSpec;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;

public class MLRegisterAgentRequestTest {

Expand Down Expand Up @@ -88,4 +85,29 @@ public void writeTo(StreamOutput out) throws IOException {
assertEquals(registerAgentRequest.getMlAgent(), parsedRequest.getMlAgent());
}

@Test
public void fromActionRequest_Success_MLRegisterAgentRequest() {
MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent);
MLRegisterAgentRequest parsedRequest = MLRegisterAgentRequest.fromActionRequest(registerAgentRequest);
assertSame(registerAgentRequest, parsedRequest);
}

@Test
public void fromActionRequest_Exception() {
exceptionRule.expect(UncheckedIOException.class);
exceptionRule.expectMessage("Failed to parse ActionRequest into MLRegisterAgentRequest");
MLRegisterAgentRequest registerAgentRequest = new MLRegisterAgentRequest(mlAgent);
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLRegisterAgentRequest.fromActionRequest(actionRequest);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.io.UncheckedIOException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.*;

public class MLRegisterAgentResponseTest {
String agentId;
Expand Down Expand Up @@ -72,4 +71,27 @@ public void writeTo(StreamOutput out) throws IOException {
assertNotSame(registerAgentResponse, parsedResponse);
assertEquals(registerAgentResponse.getAgentId(), parsedResponse.getAgentId());
}

@Test
public void fromActionResponse_Success_MLRegisterAgentResponse() {
MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId);
MLRegisterAgentResponse parsedResponse = MLRegisterAgentResponse.fromActionResponse(registerAgentResponse);
assertSame(registerAgentResponse, parsedResponse);
}

@Test
public void fromActionResponse_Exception() {
exceptionRule.expect(UncheckedIOException.class);
exceptionRule.expectMessage("Failed to parse ActionResponse into MLRegisterAgentResponse");
MLRegisterAgentResponse registerAgentResponse = new MLRegisterAgentResponse(agentId);
ActionResponse actionResponse = new ActionResponse() {

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLRegisterAgentResponse.fromActionResponse(actionResponse);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package org.opensearch.ml.common.transport.undeploy;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.Version;
import org.opensearch.action.FailedNodeException;
import org.opensearch.cluster.ClusterName;
Expand All @@ -21,21 +23,22 @@
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.*;
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;

public class MLUndeployModelsResponseTest {

MLUndeployModelNodesResponse undeployModelNodesResponse;
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Before
public void setUp() {
Expand Down Expand Up @@ -81,7 +84,7 @@ public void toXContent() throws IOException {
}

@Test
public void fromActionResponse_Sucess() {
public void fromActionResponse_Success() {
MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse);
ActionResponse actionResponse = new ActionResponse() {
@Override
Expand All @@ -94,4 +97,25 @@ public void writeTo(StreamOutput out) throws IOException {
assertEquals(1, parsedResponse.getResponse().getNodes().size());
assertEquals("test_node_id", parsedResponse.getResponse().getNodes().get(0).getNode().getId());
}

@Test
public void fromActionResponse_Success_MLUndeployModelsResponse() {
MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse);
MLUndeployModelsResponse parsedResponse = MLUndeployModelsResponse.fromActionResponse(undeployModelsResponse);
assertSame(undeployModelsResponse, parsedResponse);
}

@Test
public void fromActionResponse_Exception() {
exceptionRule.expect(UncheckedIOException.class);
exceptionRule.expectMessage("Failed to parse ActionResponse into MLUndeployModelsResponse");
MLUndeployModelsResponse undeployModelsResponse = new MLUndeployModelsResponse(undeployModelNodesResponse);
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLUndeployModelsResponse.fromActionResponse(actionResponse);
}
}

0 comments on commit 0335a8f

Please sign in to comment.