diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index 57279f771d..e64349a187 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -35,8 +35,9 @@ public class AgentTool implements Tool { @Setter @Getter private String alias; + private static String DEFAULT_DESCRIPTION = "Use this tool to run any agent."; @Getter @Setter - private String description = "Use this tool to run any agent."; + private String description = DEFAULT_DESCRIPTION; public AgentTool(Client client, String agentId) { this.client = client; @@ -92,5 +93,10 @@ public void init(Client client) { public AgentTool create(Map map) { return new AgentTool(client, (String)map.get("agent_id")); } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } } } \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java index 4809fa325e..0f395c4081 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/CatIndexTool.java @@ -46,8 +46,9 @@ public class CatIndexTool implements Tool { @Setter @Getter private String alias; + private static String DEFAULT_DESCRIPTION = "User this tool to get index information."; @Getter @Setter - private String description = "User this tool to get index information."; + private String description = DEFAULT_DESCRIPTION; private Client client; private String modelId; @Setter @@ -232,5 +233,10 @@ public void init(Client client, ClusterService clusterService) { public CatIndexTool create(Map map) { return new CatIndexTool(client, clusterService, (String)map.get("model_id")); } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } } } \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 22ea2cdeaf..8ae1650ff4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -36,8 +36,9 @@ public class MLModelTool implements Tool { @Setter @Getter private String alias; + private static String DEFAULT_DESCRIPTION = "User this tool to run any model."; @Getter @Setter - private String description = "User this tool to run any model."; + private String description = DEFAULT_DESCRIPTION; private Client client; private String modelId; @Setter @@ -116,5 +117,10 @@ public void init(Client client) { public MLModelTool create(Map map) { return new MLModelTool(client, (String)map.get("model_id")); } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java index 3389505958..65c79d3e0a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MathTool.java @@ -29,8 +29,9 @@ public class MathTool implements Tool { @Setter private ScriptService scriptService; + private static String DEFAULT_DESCRIPTION = "Use this tool to calculate any math problem."; @Getter @Setter - private String description = "Use this tool to calculate any math problem."; + private String description = DEFAULT_DESCRIPTION; public MathTool(ScriptService scriptService) { this.scriptService = scriptService; @@ -95,5 +96,10 @@ public void init(ScriptService scriptService) { public MathTool create(Map map) { return new MathTool(scriptService); } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } } } \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java index 2e2eabd62e..f90e3906c5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/PainlessScriptTool.java @@ -30,8 +30,9 @@ public class PainlessScriptTool implements Tool { @Setter @Getter private String alias; + private static String DEFAULT_DESCRIPTION = "User this tool to get index information."; @Getter @Setter - private String description = "User this tool to get index information."; + private String description = DEFAULT_DESCRIPTION; private Client client; private String modelId; @Setter @@ -104,5 +105,10 @@ public void init(Client client, ScriptService scriptService) { public PainlessScriptTool create(Map map) { return new PainlessScriptTool(client, scriptService); } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } } } \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java index ecfd80f80d..fe41eaabed 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VectorDBTool.java @@ -44,8 +44,9 @@ public class VectorDBTool implements Tool { public static final String NAME = "VectorDBTool"; @Setter @Getter private String alias; + private static String DEFAULT_DESCRIPTION = "Useful for when need to search my data in OpenSearch index."; @Getter @Setter - private String description = "Useful for when need to search my data in OpenSearch index."; + private String description = DEFAULT_DESCRIPTION; private Client client; private NamedXContentRegistry xContentRegistry; @@ -229,6 +230,11 @@ public VectorDBTool create(Map params) { .docSize(docSize) .build(); } + + @Override + public String getDefaultDescription() { + return DEFAULT_DESCRIPTION; + } } } \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 7b7a765c56..463cd8fe69 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -606,8 +606,8 @@ public List getRestHandlers( RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); - RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(externalTools); - RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(externalTools); + RestMLListToolsAction restMLListToolsAction = new RestMLListToolsAction(toolFactories); + RestMLGetToolAction restMLGetToolAction = new RestMLGetToolAction(toolFactories); return ImmutableList .of( restMLStatsAction, @@ -817,7 +817,7 @@ public void loadExtensions(ExtensionLoader loader) { } } - List toolFactories = extension.getToolFactories(); + List> toolFactories = extension.getToolFactories(); for (Tool.Factory toolFactory : toolFactories) { ToolAnnotation toolAnnotation = toolFactory.getClass().getDeclaringClass().getAnnotation(ToolAnnotation.class); if (toolAnnotation == null) { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java index 01f1ba510b..e0ab0bc820 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetToolAction.java @@ -27,10 +27,10 @@ public class RestMLGetToolAction extends BaseRestHandler { private static final String ML_GET_TOOL_ACTION = "ml_get_tool_action"; - private Map externalTools; + private Map toolFactories; - public RestMLGetToolAction(Map externalTools) { - this.externalTools = externalTools; + public RestMLGetToolAction(Map toolFactories) { + this.toolFactories = toolFactories; } @Override @@ -41,7 +41,7 @@ public String getName() { @Override public List routes() { return ImmutableList - .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/tool/{%s}", ML_BASE_URI, PARAMETER_TOOL_NAME))); + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/tools/{%s}", ML_BASE_URI, PARAMETER_TOOL_NAME))); } /** @@ -60,10 +60,8 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { List toolList = new ArrayList<>(); - externalTools - .forEach( - (key, value) -> toolList.add(ToolMetadata.builder().name(value.getName()).description(value.getDescription()).build()) - ); + toolFactories + .forEach((key, value) -> toolList.add(ToolMetadata.builder().name(key).description(value.getDefaultDescription()).build())); String toolName = getParameterId(request, PARAMETER_TOOL_NAME); MLToolGetRequest mlToolGetRequest = MLToolGetRequest.builder().toolName(toolName).externalTools(toolList).build(); return channel -> client.execute(MLGetToolAction.INSTANCE, mlToolGetRequest, new RestToXContentListener<>(channel)); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java index 7c240d79d8..db67296978 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListToolsAction.java @@ -26,10 +26,10 @@ public class RestMLListToolsAction extends BaseRestHandler { private static final String ML_GET_MODEL_ACTION = "ml_get_tools_action"; - private Map externalTools; + private Map toolFactories; - public RestMLListToolsAction(Map externalTools) { - this.externalTools = externalTools; + public RestMLListToolsAction(Map toolFactories) { + this.toolFactories = toolFactories; } @Override @@ -58,10 +58,8 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { List toolList = new ArrayList<>(); - externalTools - .forEach( - (key, value) -> toolList.add(ToolMetadata.builder().name(value.getName()).description(value.getDescription()).build()) - ); + toolFactories + .forEach((key, value) -> toolList.add(ToolMetadata.builder().name(key).description(value.getDefaultDescription()).build())); MLToolsListRequest mlToolsGetRequest = MLToolsListRequest.builder().externalTools(toolList).build(); return channel -> client.execute(MLListToolsAction.INSTANCE, mlToolsGetRequest, new RestToXContentListener<>(channel)); } diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java index 354749f395..05fb90b804 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java +++ b/spi/src/main/java/org/opensearch/ml/common/spi/tools/Tool.java @@ -92,5 +92,6 @@ default boolean end(String input, Map toolParameters) { */ interface Factory { T create(Map params); + String getDefaultDescription(); } }