Skip to content

Commit

Permalink
fix list/get tool api
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Nov 1, 2023
1 parent 7b9a1f2 commit 3f4c569
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ public class AgentTool implements Tool {
@Setter @Getter
private String alias;

private static String DEFAULT_DESCRIPTION = "Use this tool to run any agent.";
@Getter @Setter
private String description = "Use this tool to run any agent.";
private String description = DEFAULT_DESCRIPTION;

public AgentTool(Client client, String agentId) {
this.client = client;
Expand Down Expand Up @@ -92,5 +93,10 @@ public void init(Client client) {
public AgentTool create(Map<String, Object> map) {
return new AgentTool(client, (String)map.get("agent_id"));
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ public class CatIndexTool implements Tool {

@Setter @Getter
private String alias;
private static String DEFAULT_DESCRIPTION = "User this tool to get index information.";
@Getter @Setter
private String description = "User this tool to get index information.";
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String modelId;
@Setter
Expand Down Expand Up @@ -232,5 +233,10 @@ public void init(Client client, ClusterService clusterService) {
public CatIndexTool create(Map<String, Object> map) {
return new CatIndexTool(client, clusterService, (String)map.get("model_id"));
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ public class MLModelTool implements Tool {

@Setter @Getter
private String alias;
private static String DEFAULT_DESCRIPTION = "User this tool to run any model.";
@Getter @Setter
private String description = "User this tool to run any model.";
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String modelId;
@Setter
Expand Down Expand Up @@ -116,5 +117,10 @@ public void init(Client client) {
public MLModelTool create(Map<String, Object> map) {
return new MLModelTool(client, (String)map.get("model_id"));
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ public class MathTool implements Tool {
@Setter
private ScriptService scriptService;

private static String DEFAULT_DESCRIPTION = "Use this tool to calculate any math problem.";
@Getter @Setter
private String description = "Use this tool to calculate any math problem.";
private String description = DEFAULT_DESCRIPTION;

public MathTool(ScriptService scriptService) {
this.scriptService = scriptService;
Expand Down Expand Up @@ -95,5 +96,10 @@ public void init(ScriptService scriptService) {
public MathTool create(Map<String, Object> map) {
return new MathTool(scriptService);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ public class PainlessScriptTool implements Tool {

@Setter @Getter
private String alias;
private static String DEFAULT_DESCRIPTION = "User this tool to get index information.";
@Getter @Setter
private String description = "User this tool to get index information.";
private String description = DEFAULT_DESCRIPTION;
private Client client;
private String modelId;
@Setter
Expand Down Expand Up @@ -104,5 +105,10 @@ public void init(Client client, ScriptService scriptService) {
public PainlessScriptTool create(Map<String, Object> map) {
return new PainlessScriptTool(client, scriptService);
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ public class VectorDBTool implements Tool {
public static final String NAME = "VectorDBTool";
@Setter @Getter
private String alias;
private static String DEFAULT_DESCRIPTION = "Useful for when need to search my data in OpenSearch index.";
@Getter @Setter
private String description = "Useful for when need to search my data in OpenSearch index.";
private String description = DEFAULT_DESCRIPTION;

private Client client;
private NamedXContentRegistry xContentRegistry;
Expand Down Expand Up @@ -229,6 +230,11 @@ public VectorDBTool create(Map<String, Object> params) {
.docSize(docSize)
.build();
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,8 @@ public List<RestHandler> getRestHandlers(
RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction();
RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction();
RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting);
RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(externalTools);
RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(externalTools);
RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories);
RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories);
return ImmutableList
.of(
restMLStatsAction,
Expand Down Expand Up @@ -817,7 +817,7 @@ public void loadExtensions(ExtensionLoader loader) {
}
}

List<Tool.Factory> toolFactories = extension.getToolFactories();
List<Tool.Factory<? extends Tool>> toolFactories = extension.getToolFactories();
for (Tool.Factory toolFactory : toolFactories) {
ToolAnnotation toolAnnotation = toolFactory.getClass().getDeclaringClass().getAnnotation(ToolAnnotation.class);
if (toolAnnotation == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ public class RestMLGetToolAction extends BaseRestHandler {

private static final String ML_GET_TOOL_ACTION = "ml_get_tool_action";

private Map<String, Tool> externalTools;
private Map<String, Tool.Factory> toolFactories;

public RestMLGetToolAction(Map<String, Tool> externalTools) {
this.externalTools = externalTools;
public RestMLGetToolAction(Map<String, Tool.Factory> toolFactories) {
this.toolFactories = toolFactories;
}

@Override
Expand All @@ -41,7 +41,7 @@ public String getName() {
@Override
public List<Route> routes() {
return ImmutableList
.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/tool/{%s}", ML_BASE_URI, PARAMETER_TOOL_NAME)));
.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/tools/{%s}", ML_BASE_URI, PARAMETER_TOOL_NAME)));
}

/**
Expand All @@ -60,10 +60,8 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
List<ToolMetadata> toolList = new ArrayList<>();
externalTools
.forEach(
(key, value) -> toolList.add(ToolMetadata.builder().name(value.getName()).description(value.getDescription()).build())
);
toolFactories
.forEach((key, value) -> toolList.add(ToolMetadata.builder().name(key).description(value.getDefaultDescription()).build()));
String toolName = getParameterId(request, PARAMETER_TOOL_NAME);
MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName(toolName).externalTools(toolList).build();
return channel -> client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, new RestToXContentListener<>(channel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
public class RestMLListToolsAction extends BaseRestHandler {
private static final String ML_GET_MODEL_ACTION = "ml_get_tools_action";

private Map<String, Tool> externalTools;
private Map<String, Tool.Factory> toolFactories;

public RestMLListToolsAction(Map<String, Tool> externalTools) {
this.externalTools = externalTools;
public RestMLListToolsAction(Map<String, Tool.Factory> toolFactories) {
this.toolFactories = toolFactories;
}

@Override
Expand Down Expand Up @@ -58,10 +58,8 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
List<ToolMetadata> toolList = new ArrayList<>();
externalTools
.forEach(
(key, value) -> toolList.add(ToolMetadata.builder().name(value.getName()).description(value.getDescription()).build())
);
toolFactories
.forEach((key, value) -> toolList.add(ToolMetadata.builder().name(key).description(value.getDefaultDescription()).build()));
MLToolsListRequest mlToolsGetRequest = MLToolsListRequest.builder().externalTools(toolList).build();
return channel -> client.execute(MLListToolsAction.INSTANCE, mlToolsGetRequest, new RestToXContentListener<>(channel));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,6 @@ default boolean end(String input, Map<String, String> toolParameters) {
*/
interface Factory<T extends Tool> {
T create(Map<String, Object> params);
String getDefaultDescription();
}
}

0 comments on commit 3f4c569

Please sign in to comment.