diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ed59d52d5b..287fbb8127 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -324,16 +324,15 @@ public T createPayload(String action, Map parameters) { payload = fillNullParameters(parameters, payload); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); - boolean isJson = isJson(payload); - if (!isJson) { - String manuallyFixedJson = connectorAction.get().getRequestBody(); + if (!isJson(payload)) { + String payloadAfterEscape = connectorAction.get().getRequestBody(); Map escapedParameters = escapeMapValues(parameters); StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}"); - manuallyFixedJson = escapedSubstitutor.replace(manuallyFixedJson); - if (!isJson(manuallyFixedJson)) { + payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape); + if (!isJson(payloadAfterEscape)) { throw new IllegalArgumentException("Invalid payload: " + payload); } else { - payload = manuallyFixedJson; + payload = payloadAfterEscape; } } return (T) payload; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index e247baa500..c8c161488d 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -197,6 +197,10 @@ public void createPayloadWithString() { Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload); } + /** + * + */ + @Test public void createPayloadWithList() { String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; @@ -211,6 +215,87 @@ public void createPayloadWithList() { connector.validatePayload(predictPayload); } + @Test + public void createPayloadWithNestedList() { + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context}"); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + ArrayList NestedListOfDocuments = new ArrayList<>(); + NestedListOfDocuments.add("document2"); + listOfDocuments.add(toJson(NestedListOfDocuments)); + parameters.put("context", toJson(listOfDocuments)); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + } + + @Test + public void createPayloadWithMap() { + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context}"); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + parameters.put("context", toJson(mapOfDocuments)); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + } + + @Test + public void createPayloadWithNestedMapOfString() { + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context}"); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + Map nestedMapOfDocuments = new HashMap<>(); + nestedMapOfDocuments.put("city", "New York"); + mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments)); + parameters.put("context", toJson(mapOfDocuments)); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + } + + @Test + public void createPayloadWithNestedMapOfObject() { + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context}"); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + Map nestedMapOfDocuments = new HashMap<>(); + nestedMapOfDocuments.put("city", "New York"); + mapOfDocuments.put("hometown", nestedMapOfDocuments); + parameters.put("context", toJson(mapOfDocuments)); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + } + + @Test + public void createPayloadWithNestedListOfMapOfObject() { + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context}"); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + ArrayList NestedListOfDocuments = new ArrayList<>(); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + Map nestedMapOfDocuments = new HashMap<>(); + nestedMapOfDocuments.put("city", "New York"); + mapOfDocuments.put("hometown", nestedMapOfDocuments); + listOfDocuments.add(toJson(NestedListOfDocuments)); + parameters.put("context", toJson(listOfDocuments)); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + } + @Test public void createPayload() { HttpConnector connector = createHttpConnector();