From 8723620b6155fd871c90d75d3334a5a9736f89cf Mon Sep 17 00:00:00 2001 From: qianmoQ Date: Tue, 15 Aug 2023 00:23:42 +0800 Subject: [PATCH 1/2] [Stream] Support completions --- docs/docs/reference/stream/console.md | 35 ++++++++ docs/docs/reference/stream/console.zh.md | 35 ++++++++ docs/mkdocs.yml | 2 + pom.xml | 7 +- .../org/devlive/sdk/openai/DefaultClient.java | 45 ++++++++++- .../org/devlive/sdk/openai/OpenAiClient.java | 11 +++ .../sdk/openai/entity/CompletionEntity.java | 12 +++ .../listener/ConsoleEventSourceListener.java | 80 +++++++++++++++++++ .../sdk/openai/mixin/IgnoreUnknownMixin.java | 8 ++ .../sdk/openai/utils/MultipartBodyUtils.java | 1 + .../devlive/sdk/openai/StreamClientTest.java | 45 +++++++++++ 11 files changed, 278 insertions(+), 3 deletions(-) create mode 100644 docs/docs/reference/stream/console.md create mode 100644 docs/docs/reference/stream/console.zh.md create mode 100644 src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java create mode 100644 src/main/java/org/devlive/sdk/openai/mixin/IgnoreUnknownMixin.java create mode 100644 src/test/java/org/devlive/sdk/openai/StreamClientTest.java diff --git a/docs/docs/reference/stream/console.md b/docs/docs/reference/stream/console.md new file mode 100644 index 0000000..10f6110 --- /dev/null +++ b/docs/docs/reference/stream/console.md @@ -0,0 +1,35 @@ +--- +title: Console +--- + +!!! Note + + Please build the client before calling, the build code is as follows: + + ```java + CountDownLatch countDownLatch = new CountDownLatch(1); + ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder() + .countDownLatch(countDownLatch) + .build(); + OpenAiClient client = OpenAiClient.builder() + .apiKey(System.getProperty("openai.token")) + .listener(listener) + .build(); + ``` + + `System.getProperty("openai.token")` is the key to access the API authorization. + +### Create completion + +--- + +Creates a completion for the provided prompt and parameters. + +```java +CompletionEntity configure = CompletionEntity.builder() + .model(CompleteModel.TEXT_DAVINCI_003.getName()) + .prompt("How to create a completion") + .temperature(2D) + .build(); +client.createCompletion(configure); +``` diff --git a/docs/docs/reference/stream/console.zh.md b/docs/docs/reference/stream/console.zh.md new file mode 100644 index 0000000..1984473 --- /dev/null +++ b/docs/docs/reference/stream/console.zh.md @@ -0,0 +1,35 @@ +--- +title: Console +--- + +!!! Note + + 调用前请先构建客户端,构建代码如下: + + ```java + CountDownLatch countDownLatch = new CountDownLatch(1); + ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder() + .countDownLatch(countDownLatch) + .build(); + OpenAiClient client = OpenAiClient.builder() + .apiKey(System.getProperty("openai.token")) + .listener(listener) + .build(); + ``` + + `System.getProperty("openai.token")` 是访问 API 授权的关键。 + +### Create completion + +--- + +为提供的提示和参数创建补全。 + +```java +CompletionEntity configure = CompletionEntity.builder() + .model(CompleteModel.TEXT_DAVINCI_003.getName()) + .prompt("How to create a completion") + .temperature(2D) + .build(); +client.createCompletion(configure); +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index b4084a2..20e250a 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -74,6 +74,8 @@ plugins: nav: - index.md - Reference: + - Stream (Not provider): + - reference/stream/console.md - Open Ai: - reference/openai/users.md - reference/openai/models.md diff --git a/pom.xml b/pom.xml index be5326b..13b196b 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.devlive.sdk openai-java-sdk - 1.7.0 + 1.8.0-SNAPSHOT openai-java-sdk @@ -103,6 +103,11 @@ okhttp ${okhttp.version} + + com.squareup.okhttp3 + okhttp-sse + ${okhttp.version} + com.google.guava guava diff --git a/src/main/java/org/devlive/sdk/openai/DefaultClient.java b/src/main/java/org/devlive/sdk/openai/DefaultClient.java index 11d2f07..132dd02 100644 --- a/src/main/java/org/devlive/sdk/openai/DefaultClient.java +++ b/src/main/java/org/devlive/sdk/openai/DefaultClient.java @@ -1,8 +1,14 @@ package org.devlive.sdk.openai; +import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.slf4j.Slf4j; import okhttp3.MultipartBody; import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; import org.apache.commons.lang3.ObjectUtils; import org.devlive.sdk.openai.entity.AudioEntity; import org.devlive.sdk.openai.entity.ChatEntity; @@ -14,6 +20,8 @@ import org.devlive.sdk.openai.entity.ModelEntity; import org.devlive.sdk.openai.entity.ModerationEntity; import org.devlive.sdk.openai.entity.UserKeyEntity; +import org.devlive.sdk.openai.exception.RequestException; +import org.devlive.sdk.openai.mixin.IgnoreUnknownMixin; import org.devlive.sdk.openai.model.ProviderModel; import org.devlive.sdk.openai.model.UrlModel; import org.devlive.sdk.openai.response.AudioResponse; @@ -36,6 +44,8 @@ public abstract class DefaultClient protected DefaultApi api; protected ProviderModel provider; protected OkHttpClient client; + protected String apiHost; + protected EventSourceListener listener; public ModelResponse getModels() { @@ -51,8 +61,16 @@ public ModelEntity getModel(String model) public CompleteResponse createCompletion(CompletionEntity configure) { - return this.api.fetchCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure) - .blockingGet(); + String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS); + if (ObjectUtils.isNotEmpty(this.listener)) { + configure.setStream(true); + this.createEventSource(url, configure); + return null; + } + else { + return this.api.fetchCompletions(url, configure) + .blockingGet(); + } } public ChatResponse createChatCompletion(ChatEntity configure) @@ -168,6 +186,29 @@ public Object retrieveFileContent(String id) .blockingGet(); } + private ObjectMapper createObjectMapper() + { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.addMixIn(Object.class, IgnoreUnknownMixin.class); + return objectMapper; + } + + private void createEventSource(String url, Object configure) + { + try { + EventSource.Factory factory = EventSources.createFactory(this.client); + ObjectMapper mapper = this.createObjectMapper(); + Request request = new Request.Builder() + .url(String.join("/", this.apiHost, url)) + .post(RequestBody.create(MultipartBodyUtils.JSON, mapper.writeValueAsString(configure))) + .build(); + factory.newEventSource(request, this.listener); + } + catch (Exception e) { + throw new RequestException(String.format("Failed to create event source: %s", e.getMessage())); + } + } + public void close() { if (ObjectUtils.isNotEmpty(this.client)) { diff --git a/src/main/java/org/devlive/sdk/openai/OpenAiClient.java b/src/main/java/org/devlive/sdk/openai/OpenAiClient.java index 60a8e99..08bf55e 100644 --- a/src/main/java/org/devlive/sdk/openai/OpenAiClient.java +++ b/src/main/java/org/devlive/sdk/openai/OpenAiClient.java @@ -5,6 +5,7 @@ import lombok.Builder; import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; +import okhttp3.sse.EventSourceListener; import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.StringUtils; import org.devlive.sdk.openai.exception.ParamException; @@ -35,6 +36,8 @@ public class OpenAiClient // Azure provider requires private String model; // The model name deployed in azure private String version; + // Support see + private EventSourceListener listener; private OpenAiClient(OpenAiClientBuilder builder) { @@ -69,9 +72,14 @@ private OpenAiClient(OpenAiClientBuilder builder) if (ObjectUtils.isEmpty(builder.client)) { builder.client(null); } + if (ObjectUtils.isEmpty(builder.listener)) { + builder.listener(null); + } super.provider = builder.provider; super.client = builder.client; + super.listener = builder.listener; + super.apiHost = builder.apiHost; // Build a remote API client objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); this.api = new Retrofit.Builder() @@ -160,6 +168,9 @@ public OpenAiClientBuilder client(OkHttpClient client) private String getDefaultHost() { + if (ObjectUtils.isEmpty(this.provider)) { + this.provider = ProviderModel.OPENAI; + } if (this.provider.equals(ProviderModel.CLAUDE)) { return "https://api.anthropic.com"; } diff --git a/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java b/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java index 33337c6..9feb89e 100644 --- a/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java +++ b/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java @@ -50,6 +50,13 @@ public class CompletionEntity @JsonProperty(value = "stop") private List stop; + /** + * Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message. + * 是否流回部分进度。如果设置,令牌将在可用时作为仅数据服务器发送事件发送,流由 data: [DONE] 消息终止。 + */ + @JsonProperty(value = "stream") + private boolean stream = false; + private CompletionEntity(CompletionEntityBuilder builder) { if (ObjectUtils.isEmpty(builder.model)) { @@ -151,6 +158,11 @@ public CompletionEntityBuilder presencePenalty(Double presencePenalty) return this; } + private CompletionEntityBuilder stream() + { + return this; + } + public CompletionEntity build() { return new CompletionEntity(this); diff --git a/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java b/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java new file mode 100644 index 0000000..50ff4f5 --- /dev/null +++ b/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java @@ -0,0 +1,80 @@ +package org.devlive.sdk.openai.listener; + +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.Builder; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.apache.commons.lang3.ObjectUtils; +import org.devlive.sdk.openai.response.CompleteResponse; +import org.devlive.sdk.openai.utils.JsonUtils; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.time.LocalDateTime; +import java.util.concurrent.CountDownLatch; + +@Slf4j +@Builder +public class ConsoleEventSourceListener + extends EventSourceListener +{ + private CountDownLatch countDownLatch; + private JsonUtils jsonUtils; + + @Override + public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) + { + log.info("Console listener opened on time {}", LocalDateTime.now()); + this.jsonUtils = JsonUtils.getInstance(); + } + + @Override + public void onClosed(@NotNull EventSource eventSource) + { + log.info("Console listener closed on time {}", LocalDateTime.now()); + eventSource.cancel(); + this.close(); + } + + @Override + public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) + { + // OpenAI ends with [DONE] by default + if (data.equals("[DONE]")) { + eventSource.cancel(); + this.close(); + } + else { + try { + CompleteResponse completeResponse = jsonUtils.getObject(data, CompleteResponse.class); + log.info("Console event received on time {} id {} type {} data {}", LocalDateTime.now(), id, type, completeResponse.getChoices().get(0).getContent()); + } + catch (JsonProcessingException e) { + log.warn("Console event error on time {} id {} type {} data {}", LocalDateTime.now(), id, type, data, e); + } + } + } + + @Override + public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable throwable, @Nullable Response response) + { + if (throwable.getMessage().endsWith("CANCEL")) { + log.info("Console listener cancelled on time {}", LocalDateTime.now()); + this.onClosed(eventSource); + } + else { + log.error("Console listener throwable \n{}\n response: \n{}\n", throwable, response); + } + eventSource.cancel(); + this.close(); + } + + private void close() + { + if (ObjectUtils.isNotEmpty(this.countDownLatch)) { + this.countDownLatch.countDown(); + } + } +} diff --git a/src/main/java/org/devlive/sdk/openai/mixin/IgnoreUnknownMixin.java b/src/main/java/org/devlive/sdk/openai/mixin/IgnoreUnknownMixin.java new file mode 100644 index 0000000..1730c01 --- /dev/null +++ b/src/main/java/org/devlive/sdk/openai/mixin/IgnoreUnknownMixin.java @@ -0,0 +1,8 @@ +package org.devlive.sdk.openai.mixin; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; + +@JsonIgnoreProperties(ignoreUnknown = true) +public abstract class IgnoreUnknownMixin +{ +} diff --git a/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java b/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java index 716a41e..80d150f 100644 --- a/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java +++ b/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java @@ -9,6 +9,7 @@ public class MultipartBodyUtils { public static final MediaType TYPE = MediaType.parse("multipart/form-data"); + public static final MediaType JSON = MediaType.parse("application/json; charset=utf-8"); private MultipartBodyUtils() { diff --git a/src/test/java/org/devlive/sdk/openai/StreamClientTest.java b/src/test/java/org/devlive/sdk/openai/StreamClientTest.java new file mode 100644 index 0000000..8cdd1cd --- /dev/null +++ b/src/test/java/org/devlive/sdk/openai/StreamClientTest.java @@ -0,0 +1,45 @@ +package org.devlive.sdk.openai; + +import lombok.extern.slf4j.Slf4j; +import org.devlive.sdk.openai.entity.CompletionEntity; +import org.devlive.sdk.openai.listener.ConsoleEventSourceListener; +import org.junit.Before; +import org.junit.Test; + +import java.util.concurrent.CountDownLatch; + +@Slf4j +public class StreamClientTest +{ + private OpenAiClient client; + private CountDownLatch countDownLatch; + + @Before + public void before() + { + countDownLatch = new CountDownLatch(1); + ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder() + .countDownLatch(countDownLatch) + .build(); + client = OpenAiClient.builder() + .apiKey(System.getProperty("openai.token")) + .listener(listener) + .build(); + } + + @Test + public void testCreateCompletion() + { + CompletionEntity configure = CompletionEntity.builder() + .prompt("How to create a stream completion") + .temperature(2D) + .build(); + client.createCompletion(configure); + try { + countDownLatch.await(); + } + catch (InterruptedException e) { + log.error("Interrupted while waiting", e); + } + } +} From ceefd294ffbc6429faac688ec3303cddf70ffa58 Mon Sep 17 00:00:00 2001 From: qianmoQ Date: Tue, 15 Aug 2023 00:35:59 +0800 Subject: [PATCH 2/2] [Stream] Fixed code --- .../listener/ConsoleEventSourceListener.java | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java b/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java index 50ff4f5..084202d 100644 --- a/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java +++ b/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java @@ -9,8 +9,6 @@ import org.apache.commons.lang3.ObjectUtils; import org.devlive.sdk.openai.response.CompleteResponse; import org.devlive.sdk.openai.utils.JsonUtils; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; import java.time.LocalDateTime; import java.util.concurrent.CountDownLatch; @@ -24,14 +22,14 @@ public class ConsoleEventSourceListener private JsonUtils jsonUtils; @Override - public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) + public void onOpen(EventSource eventSource, Response response) { log.info("Console listener opened on time {}", LocalDateTime.now()); this.jsonUtils = JsonUtils.getInstance(); } @Override - public void onClosed(@NotNull EventSource eventSource) + public void onClosed(EventSource eventSource) { log.info("Console listener closed on time {}", LocalDateTime.now()); eventSource.cancel(); @@ -39,7 +37,7 @@ public void onClosed(@NotNull EventSource eventSource) } @Override - public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) + public void onEvent(EventSource eventSource, String id, String type, String data) { // OpenAI ends with [DONE] by default if (data.equals("[DONE]")) { @@ -58,14 +56,19 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null } @Override - public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable throwable, @Nullable Response response) + public void onFailure(EventSource eventSource, Throwable throwable, Response response) { - if (throwable.getMessage().endsWith("CANCEL")) { - log.info("Console listener cancelled on time {}", LocalDateTime.now()); - this.onClosed(eventSource); + if (ObjectUtils.isNotEmpty(throwable)) { + if (throwable.getMessage().endsWith("CANCEL")) { + log.info("Console listener cancelled on time {}", LocalDateTime.now()); + this.onClosed(eventSource); + } + else { + log.error("Console listener throwable \n{}\n response: \n{}\n", throwable, response); + } } else { - log.error("Console listener throwable \n{}\n response: \n{}\n", throwable, response); + log.error("Console listener failure with empty throwable. Response: \n{}\n", response); } eventSource.cancel(); this.close();