diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
index 5a4fa640e96..712a6d7f908 100644
--- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/java/org/apache/seatunnel/e2e/transform/TestLLMIT.java
@@ -88,6 +88,14 @@ public void testLLMWithOpenAI(TestContainer container)
         Assertions.assertEquals(0, execResult.getExitCode());
     }
 
+    @TestTemplate
+    public void testLLMWithOpenAIBoolean(TestContainer container)
+            throws IOException, InterruptedException {
+        Container.ExecResult execResult =
+                container.executeJob("/llm_openai_transform_boolean.conf");
+        Assertions.assertEquals(0, execResult.getExitCode());
+    }
+
     @TestTemplate
     public void testLLMWithCustomModel(TestContainer container)
             throws IOException, InterruptedException {
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_boolean.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_boolean.conf
new file mode 100644
index 00000000000..4aeec3c9658
--- /dev/null
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/llm_openai_transform_boolean.conf
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+######
+###### This config file is a demonstration of batch processing with the LLM transform (boolean output) in seatunnel config
+######
+
+env {
+  job.mode = "BATCH"
+}
+
+source {
+  FakeSource {
+    row.num = 5
+    schema = {
+      fields {
+        id = "int"
+        name = "string"
+      }
+    }
+    rows = [
+      {fields = [1, "Jia Fan"], kind = INSERT}
+      {fields = [2, "Hailin Wang"], kind = INSERT}
+      {fields = [3, "Tomas"], kind = INSERT}
+      {fields = [4, "Eric"], kind = INSERT}
+      {fields = [5, "Guangdong Liu"], kind = INSERT}
+    ]
+    result_table_name = "fake"
+  }
+}
+
+transform {
+  LLM {
+    source_table_name = "fake"
+    model_provider = OPENAI
+    model = gpt-4o-mini
+    api_key = sk-xxx
+    prompt = "Determine whether someone is Chinese or American by their name"
+    output_data_type = boolean
+    openai.api_path = "http://mockserver:1080/v2/chat/completions"
+    result_table_name = "llm_output"
+  }
+}
+
+sink {
+  Assert {
+    source_table_name = "llm_output"
+    rules =
+      {
+        field_rules = [
+          {
+            field_name = llm_output
+            field_type = boolean
+            field_value = [
+              {
+                rule_type = NOT_NULL
+              }
+            ]
+          }
+        ]
+      }
+  }
+}
diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
index b4a2e53bea8..f7674d3a2a2 100644
--- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
+++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/mockserver-config.json
@@ -36,5 +36,41 @@
         "Content-Type": "application/json"
       }
     }
+  },
+  {
+    "httpRequest": {
+      "method": "POST",
+      "path": "/v2/chat/completions"
+    },
+    "httpResponse": {
+      "body": {
+        "id": "chatcmpl-9s4hoBNGV0d9Mudkhvgzg64DAWPnx",
+        "object": "chat.completion",
+        "created": 1722674828,
+        "model": "gpt-4o-mini",
+        "choices": [
+          {
+            "index": 0,
+            "message": {
+              "role": "assistant",
+              "content": "[True]"
+            },
+            "logprobs": null,
+            "finish_reason": "stop"
+          }
+        ],
+        "usage": {
+          "prompt_tokens": 107,
+          "completion_tokens": 3,
+          "total_tokens": 110
+        },
+        "system_fingerprint": "fp_0f03d4f0ee",
+        "code": 0,
+        "msg": "ok"
+      },
+      "headers": {
+        "Content-Type": "application/json"
+      }
+    }
   }
 ]
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
index e658e514597..4ee271c4085 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/AbstractModel.java
@@ -66,4 +66,19 @@ public List<String> inference(List<SeaTunnelRow> rows) throws IOException {
 
     protected abstract List<String> chatWithModel(String promptWithLimit, String rowsJson)
             throws IOException;
+
+    /**
+     * Normalizes the raw text returned by the model before it is parsed as JSON.
+     *
+     * <p>When the output type is {@code BOOLEAN} the whole payload is lower-cased, so a
+     * reply such as {@code [True]} becomes the valid JSON {@code [true]}. Any other output
+     * type is returned unchanged. {@link java.util.Locale#ROOT} keeps the result independent
+     * of the JVM default locale.
+     *
+     * @param data raw model output
+     * @return the text to hand to the JSON parser
+     */
+    protected String convertData(String data) {
+        return outputType == SqlType.BOOLEAN ? data.toLowerCase(java.util.Locale.ROOT) : data;
+    }
 }
diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
index 1424eed9e4c..8dc12ec0cd3 100644
--- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
+++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/llm/remote/openai/OpenAIModel.java
@@ -82,7 +82,8 @@ protected List<String> chatWithModel(String prompt, String data) throws IOExcept
 
         JsonNode result = OBJECT_MAPPER.readTree(responseStr);
         String resultData = result.get("choices").get(0).get("message").get("content").asText();
-        return OBJECT_MAPPER.readValue(resultData, new TypeReference<List<String>>() {});
+        return OBJECT_MAPPER.readValue(
+                convertData(resultData), new TypeReference<List<String>>() {});
     }
 
     @VisibleForTesting