[Improve][Transform] Add LLM model provider microsoft #7778

Merged · 4 commits merged on Oct 8, 2024
7 changes: 5 additions & 2 deletions docs/en/transform-v2/llm.md
@@ -11,7 +11,7 @@ more.
## Options

| name | type | required | default value |
|------------------------| ------ | -------- |---------------|
|------------------------|--------|----------|---------------|
| model_provider | enum | yes | |
| output_data_type | enum | no | String |
| output_column_name | string | no | llm_output |
@@ -28,7 +28,9 @@ more.
### model_provider

The model provider to use. The available options are:
OPENAI, DOUBAO, KIMIAI, CUSTOM
OPENAI, DOUBAO, KIMIAI, MICROSOFT, CUSTOM

> Tips: If you use MICROSOFT, make sure `api_path` is not empty.
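
For example, a minimal transform block for the MICROSOFT provider might look like this (the endpoint host and key are placeholders; `${model}` is substituted with the configured model at runtime):

```hocon
transform {
  LLM {
    model_provider = MICROSOFT
    model = gpt-35-turbo
    api_key = sk-xxx
    prompt = "Determine whether someone is Chinese or American by their name"
    # MICROSOFT requires a non-empty api_path; replace the host with your own endpoint
    api_path = "http://your-endpoint/openai/deployments/${model}/chat/completions?api-version=2024-02-01"
  }
}
```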

### output_data_type

@@ -254,6 +256,7 @@ sink {
}
}
```

### Customize the LLM model

```hocon
4 changes: 3 additions & 1 deletion docs/zh/transform-v2/llm.md
@@ -26,7 +26,9 @@

### model_provider

The model provider to use. The available options are:
OPENAI, DOUBAO, KIMIAI, CUSTOM
OPENAI, DOUBAO, KIMIAI, MICROSOFT, CUSTOM

> Tips: If you use MICROSOFT, make sure `api_path` is not empty.

### output_data_type

@@ -88,6 +88,13 @@ public void testLLMWithOpenAI(TestContainer container)
        Assertions.assertEquals(0, execResult.getExitCode());
    }

    @TestTemplate
    public void testLLMWithMicrosoft(TestContainer container)
            throws IOException, InterruptedException {
        Container.ExecResult execResult = container.executeJob("/llm_microsoft_transform.conf");
        Assertions.assertEquals(0, execResult.getExitCode());
    }

    @TestTemplate
    public void testLLMWithOpenAIBoolean(TestContainer container)
            throws IOException, InterruptedException {
@@ -0,0 +1,75 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
######
###### This config file is a demonstration of batch processing in seatunnel config
######

env {
  job.mode = "BATCH"
}

source {
  FakeSource {
    row.num = 5
    schema = {
      fields {
        id = "int"
        name = "string"
      }
    }
    rows = [
      {fields = [1, "Jia Fan"], kind = INSERT}
      {fields = [2, "Hailin Wang"], kind = INSERT}
      {fields = [3, "Tomas"], kind = INSERT}
      {fields = [4, "Eric"], kind = INSERT}
      {fields = [5, "Guangdong Liu"], kind = INSERT}
    ]
    result_table_name = "fake"
  }
}

transform {
  LLM {
    source_table_name = "fake"
    model_provider = MICROSOFT
    model = gpt-35-turbo
    api_key = sk-xxx
    prompt = "Determine whether someone is Chinese or American by their name"
    api_path = "http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01"
    result_table_name = "llm_output"
  }
}

sink {
  Assert {
    source_table_name = "llm_output"
    rules = {
      field_rules = [
        {
          field_name = llm_output
          field_type = string
          field_value = [
            {
              rule_type = NOT_NULL
            }
          ]
        }
      ]
    }
  }
}
@@ -104,5 +104,37 @@
        "Content-Type": "application/json"
      }
    }
  },
  {
    "httpRequest": {
      "method": "POST",
      "path": "/openai/deployments/gpt-35-turbo/chat/.*"
    },
    "httpResponse": {
      "body": {
        "id": "chatcmpl-6v7mkQj980V1yBec6ETrKPRqFjNw9",
        "object": "chat.completion",
        "created": 1679072642,
        "model": "gpt-35-turbo",
        "usage": {
          "prompt_tokens": 58,
          "completion_tokens": 68,
          "total_tokens": 126
        },
        "choices": [
          {
            "message": {
              "role": "assistant",
              "content": "[\"Chinese\"]"
            },
            "finish_reason": "stop",
            "index": 0
          }
        ]
      },
      "headers": {
        "Content-Type": "application/json"
      }
    }
  }
]
@@ -26,6 +26,7 @@ public enum ModelProvider {
            "https://ark.cn-beijing.volces.com/api/v3/embeddings"),
    QIANFAN("", "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"),
    KIMIAI("https://api.moonshot.cn/v1/chat/completions", ""),
    MICROSOFT("", ""),
    CUSTOM("", ""),
    LOCAL("", "");

@@ -31,6 +31,7 @@
import org.apache.seatunnel.transform.nlpmodel.llm.remote.Model;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;

import lombok.NonNull;
@@ -94,6 +95,17 @@ public void open() {
                                LLMTransformConfig.CustomRequestConfig
                                        .CUSTOM_RESPONSE_PARSE));
                break;
            case MICROSOFT:
                model =
                        new MicrosoftModel(
                                inputCatalogTable.getSeaTunnelRowType(),
                                outputDataType.getSqlType(),
                                config.get(LLMTransformConfig.INFERENCE_COLUMNS),
                                config.get(LLMTransformConfig.PROMPT),
                                config.get(LLMTransformConfig.MODEL),
                                config.get(LLMTransformConfig.API_KEY),
                                provider.usedLLMPath(config.get(LLMTransformConfig.API_PATH)));
                break;
            case OPENAI:
            case DOUBAO:
                model =
@@ -26,7 +26,6 @@
import org.apache.seatunnel.api.table.factory.TableTransformFactory;
import org.apache.seatunnel.api.table.factory.TableTransformFactoryContext;
import org.apache.seatunnel.transform.nlpmodel.ModelProvider;
import org.apache.seatunnel.transform.nlpmodel.ModelTransformConfig;

import com.google.auto.service.AutoService;

@@ -50,14 +49,17 @@ public OptionRule optionRule() {
                        LLMTransformConfig.PROCESS_BATCH_SIZE)
                .conditional(
                        LLMTransformConfig.MODEL_PROVIDER,
                        Lists.newArrayList(ModelProvider.OPENAI, ModelProvider.DOUBAO),
                        Lists.newArrayList(
                                ModelProvider.OPENAI,
                                ModelProvider.DOUBAO,
                                ModelProvider.MICROSOFT),
                        LLMTransformConfig.API_KEY)
                .conditional(
                        LLMTransformConfig.MODEL_PROVIDER,
                        ModelProvider.QIANFAN,
                        LLMTransformConfig.API_KEY,
                        LLMTransformConfig.SECRET_KEY,
                        ModelTransformConfig.OAUTH_PATH)
                        LLMTransformConfig.OAUTH_PATH)
                .conditional(
                        LLMTransformConfig.MODEL_PROVIDER,
                        ModelProvider.CUSTOM,
@@ -0,0 +1,103 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft;

import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;

import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.transform.nlpmodel.CustomConfigPlaceholder;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.AbstractModel;

import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;

import com.google.common.annotations.VisibleForTesting;

import java.io.IOException;
import java.util.List;

public class MicrosoftModel extends AbstractModel {

    private final CloseableHttpClient client;
    private final String apiKey;
    private final String model;
    private final String apiPath;

    public MicrosoftModel(
            SeaTunnelRowType rowType,
            SqlType outputType,
            List<String> projectionColumns,
            String prompt,
            String model,
            String apiKey,
            String apiPath) {
        super(rowType, outputType, projectionColumns, prompt);
        this.model = model;
        this.apiKey = apiKey;
        this.apiPath =
                CustomConfigPlaceholder.replacePlaceholders(
                        apiPath, CustomConfigPlaceholder.REPLACE_PLACEHOLDER_MODEL, model, null);
        this.client = HttpClients.createDefault();
    }

    @Override
    protected List<String> chatWithModel(String prompt, String data) throws IOException {
        HttpPost post = new HttpPost(apiPath);
        post.setHeader("Authorization", "Bearer " + apiKey);
        post.setHeader("Content-Type", "application/json");
        ObjectNode objectNode = createJsonNodeFromData(prompt, data);
        post.setEntity(new StringEntity(OBJECT_MAPPER.writeValueAsString(objectNode), "UTF-8"));
        post.setConfig(
                RequestConfig.custom().setConnectTimeout(20000).setSocketTimeout(20000).build());
        CloseableHttpResponse response = client.execute(post);
        String responseStr = EntityUtils.toString(response.getEntity());
        if (response.getStatusLine().getStatusCode() != 200) {
            throw new IOException("Failed to chat with model, response: " + responseStr);
        }

        JsonNode result = OBJECT_MAPPER.readTree(responseStr);
        String resultData = result.get("choices").get(0).get("message").get("content").asText();
        return OBJECT_MAPPER.readValue(
                convertData(resultData), new TypeReference<List<String>>() {});
    }

    @VisibleForTesting
    public ObjectNode createJsonNodeFromData(String prompt, String data) {
        ObjectNode objectNode = OBJECT_MAPPER.createObjectNode();
        ArrayNode messages = objectNode.putArray("messages");
        messages.addObject().put("role", "system").put("content", prompt);
        messages.addObject().put("role", "user").put("content", data);
        return objectNode;
    }

    @Override
    public void close() throws IOException {
        if (client != null) {
            client.close();
        }
    }
}
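
The constructor resolves `${model}` in `api_path` before any request is sent. A minimal sketch of that substitution (assuming `CustomConfigPlaceholder.replacePlaceholders` amounts to a plain string replacement of the model placeholder, which is what the unit test in this PR exercises; the class and method names here are illustrative):

```java
public class MicrosoftApiPathSketch {
    // Hypothetical stand-in for CustomConfigPlaceholder.replacePlaceholders:
    // substitutes the ${model} placeholder in the configured api_path.
    static String resolveApiPath(String template, String model) {
        return template.replace("${model}", model);
    }

    public static void main(String[] args) {
        String resolved = resolveApiPath(
                "http://mockserver:1080/openai/deployments/${model}/chat/completions?api-version=2024-02-01",
                "gpt-35-turbo");
        System.out.println(resolved);
    }
}
```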
@@ -28,6 +28,7 @@
import org.apache.seatunnel.format.json.RowToJsonConverters;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.custom.CustomModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.kimiai.KimiAIModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.microsoft.MicrosoftModel;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.openai.OpenAIModel;

import org.junit.jupiter.api.Assertions;
@@ -36,6 +37,7 @@
import com.google.common.collect.Lists;

import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -130,6 +132,38 @@ void testKimiAIRequestJson() throws IOException {
        model.close();
    }

    @Test
    void testMicrosoftRequestJson() throws Exception {
        SeaTunnelRowType rowType =
                new SeaTunnelRowType(
                        new String[] {"id", "name"},
                        new SeaTunnelDataType[] {BasicType.INT_TYPE, BasicType.STRING_TYPE});
        MicrosoftModel model =
                new MicrosoftModel(
                        rowType,
                        SqlType.STRING,
                        null,
                        "Determine whether someone is Chinese or American by their name",
                        "gpt-35-turbo",
                        "sk-xxx",
                        "https://api.moonshot.cn/openai/deployments/${model}/chat/completions?api-version=2024-02-01");
        Field apiPathField = model.getClass().getDeclaredField("apiPath");
        apiPathField.setAccessible(true);
        String apiPath = (String) apiPathField.get(model);
        Assertions.assertEquals(
                "https://api.moonshot.cn/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-02-01",
                apiPath);

        ObjectNode node =
                model.createJsonNodeFromData(
                        "Determine whether someone is Chinese or American by their name",
                        "{\"id\":1, \"name\":\"John\"}");
        Assertions.assertEquals(
                "{\"messages\":[{\"role\":\"system\",\"content\":\"Determine whether someone is Chinese or American by their name\"},{\"role\":\"user\",\"content\":\"{\\\"id\\\":1, \\\"name\\\":\\\"John\\\"}\"}]}",
                OBJECT_MAPPER.writeValueAsString(node));
        model.close();
    }

    @Test
    void testCustomRequestJson() throws IOException {
        SeaTunnelRowType rowType =