diff --git a/docs/docs/reference/stream/console.md b/docs/docs/reference/stream/console.md
new file mode 100644
index 0000000..10f6110
--- /dev/null
+++ b/docs/docs/reference/stream/console.md
@@ -0,0 +1,35 @@
+---
+title: Console
+---
+
+!!! Note
+
+ Please build the client before calling, the build code is as follows:
+
+ ```java
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+ ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
+ .countDownLatch(countDownLatch)
+ .build();
+ OpenAiClient client = OpenAiClient.builder()
+ .apiKey(System.getProperty("openai.token"))
+ .listener(listener)
+ .build();
+ ```
+
+ `System.getProperty("openai.token")` is the key to access the API authorization.
+
+### Create completion
+
+---
+
+Creates a completion for the provided prompt and parameters.
+
+```java
+CompletionEntity configure = CompletionEntity.builder()
+ .model(CompleteModel.TEXT_DAVINCI_003.getName())
+ .prompt("How to create a completion")
+ .temperature(2D)
+ .build();
+client.createCompletion(configure);
+```
diff --git a/docs/docs/reference/stream/console.zh.md b/docs/docs/reference/stream/console.zh.md
new file mode 100644
index 0000000..1984473
--- /dev/null
+++ b/docs/docs/reference/stream/console.zh.md
@@ -0,0 +1,35 @@
+---
+title: Console
+---
+
+!!! Note
+
+ 调用前请先构建客户端,构建代码如下:
+
+ ```java
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+ ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
+ .countDownLatch(countDownLatch)
+ .build();
+ OpenAiClient client = OpenAiClient.builder()
+ .apiKey(System.getProperty("openai.token"))
+ .listener(listener)
+ .build();
+ ```
+
+ `System.getProperty("openai.token")` 是访问 API 授权的关键。
+
+### Create completion
+
+---
+
+为提供的提示和参数创建补全。
+
+```java
+CompletionEntity configure = CompletionEntity.builder()
+ .model(CompleteModel.TEXT_DAVINCI_003.getName())
+ .prompt("How to create a completion")
+ .temperature(2D)
+ .build();
+client.createCompletion(configure);
+```
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index b4084a2..20e250a 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -74,6 +74,8 @@ plugins:
nav:
- index.md
- Reference:
+ - Stream (Not provider):
+ - reference/stream/console.md
- Open Ai:
- reference/openai/users.md
- reference/openai/models.md
diff --git a/pom.xml b/pom.xml
index be5326b..13b196b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
org.devlive.sdk
openai-java-sdk
- 1.7.0
+ 1.8.0-SNAPSHOT
openai-java-sdk
@@ -103,6 +103,11 @@
okhttp
${okhttp.version}
+
+ com.squareup.okhttp3
+ okhttp-sse
+ ${okhttp.version}
+
com.google.guava
guava
diff --git a/src/main/java/org/devlive/sdk/openai/DefaultClient.java b/src/main/java/org/devlive/sdk/openai/DefaultClient.java
index 11d2f07..132dd02 100644
--- a/src/main/java/org/devlive/sdk/openai/DefaultClient.java
+++ b/src/main/java/org/devlive/sdk/openai/DefaultClient.java
@@ -1,8 +1,14 @@
package org.devlive.sdk.openai;
+import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MultipartBody;
import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.RequestBody;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import okhttp3.sse.EventSources;
import org.apache.commons.lang3.ObjectUtils;
import org.devlive.sdk.openai.entity.AudioEntity;
import org.devlive.sdk.openai.entity.ChatEntity;
@@ -14,6 +20,8 @@
import org.devlive.sdk.openai.entity.ModelEntity;
import org.devlive.sdk.openai.entity.ModerationEntity;
import org.devlive.sdk.openai.entity.UserKeyEntity;
+import org.devlive.sdk.openai.exception.RequestException;
+import org.devlive.sdk.openai.mixin.IgnoreUnknownMixin;
import org.devlive.sdk.openai.model.ProviderModel;
import org.devlive.sdk.openai.model.UrlModel;
import org.devlive.sdk.openai.response.AudioResponse;
@@ -36,6 +44,8 @@ public abstract class DefaultClient
protected DefaultApi api;
protected ProviderModel provider;
protected OkHttpClient client;
+ protected String apiHost;
+ protected EventSourceListener listener;
public ModelResponse getModels()
{
@@ -51,8 +61,16 @@ public ModelEntity getModel(String model)
public CompleteResponse createCompletion(CompletionEntity configure)
{
- return this.api.fetchCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
- .blockingGet();
+ String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS);
+ if (ObjectUtils.isNotEmpty(this.listener)) {
+ configure.setStream(true);
+ this.createEventSource(url, configure);
+ return null;
+ }
+ else {
+ return this.api.fetchCompletions(url, configure)
+ .blockingGet();
+ }
}
public ChatResponse createChatCompletion(ChatEntity configure)
@@ -168,6 +186,29 @@ public Object retrieveFileContent(String id)
.blockingGet();
}
+ private ObjectMapper createObjectMapper()
+ {
+ ObjectMapper objectMapper = new ObjectMapper();
+ objectMapper.addMixIn(Object.class, IgnoreUnknownMixin.class);
+ return objectMapper;
+ }
+
+ private void createEventSource(String url, Object configure)
+ {
+ try {
+ EventSource.Factory factory = EventSources.createFactory(this.client);
+ ObjectMapper mapper = this.createObjectMapper();
+ Request request = new Request.Builder()
+ .url(String.join("/", this.apiHost, url))
+ .post(RequestBody.create(MultipartBodyUtils.JSON, mapper.writeValueAsString(configure)))
+ .build();
+ factory.newEventSource(request, this.listener);
+ }
+ catch (Exception e) {
+ throw new RequestException(String.format("Failed to create event source: %s", e.getMessage()));
+ }
+ }
+
public void close()
{
if (ObjectUtils.isNotEmpty(this.client)) {
diff --git a/src/main/java/org/devlive/sdk/openai/OpenAiClient.java b/src/main/java/org/devlive/sdk/openai/OpenAiClient.java
index 60a8e99..08bf55e 100644
--- a/src/main/java/org/devlive/sdk/openai/OpenAiClient.java
+++ b/src/main/java/org/devlive/sdk/openai/OpenAiClient.java
@@ -5,6 +5,7 @@
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
+import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.devlive.sdk.openai.exception.ParamException;
@@ -35,6 +36,8 @@ public class OpenAiClient
// Azure provider requires
private String model; // The model name deployed in azure
private String version;
+ // Support see
+ private EventSourceListener listener;
private OpenAiClient(OpenAiClientBuilder builder)
{
@@ -69,9 +72,14 @@ private OpenAiClient(OpenAiClientBuilder builder)
if (ObjectUtils.isEmpty(builder.client)) {
builder.client(null);
}
+ if (ObjectUtils.isEmpty(builder.listener)) {
+ builder.listener(null);
+ }
super.provider = builder.provider;
super.client = builder.client;
+ super.listener = builder.listener;
+ super.apiHost = builder.apiHost;
// Build a remote API client
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
this.api = new Retrofit.Builder()
@@ -160,6 +168,9 @@ public OpenAiClientBuilder client(OkHttpClient client)
private String getDefaultHost()
{
+ if (ObjectUtils.isEmpty(this.provider)) {
+ this.provider = ProviderModel.OPENAI;
+ }
if (this.provider.equals(ProviderModel.CLAUDE)) {
return "https://api.anthropic.com";
}
diff --git a/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java b/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java
index 33337c6..9feb89e 100644
--- a/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java
+++ b/src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java
@@ -50,6 +50,13 @@ public class CompletionEntity
@JsonProperty(value = "stop")
private List stop;
+ /**
+ * Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
+ * 是否流回部分进度。如果设置,令牌将在可用时作为仅数据服务器发送事件发送,流由 data: [DONE] 消息终止。
+ */
+ @JsonProperty(value = "stream")
+ private boolean stream = false;
+
private CompletionEntity(CompletionEntityBuilder builder)
{
if (ObjectUtils.isEmpty(builder.model)) {
@@ -151,6 +158,11 @@ public CompletionEntityBuilder presencePenalty(Double presencePenalty)
return this;
}
+ private CompletionEntityBuilder stream()
+ {
+ return this;
+ }
+
public CompletionEntity build()
{
return new CompletionEntity(this);
diff --git a/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java b/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java
new file mode 100644
index 0000000..084202d
--- /dev/null
+++ b/src/main/java/org/devlive/sdk/openai/listener/ConsoleEventSourceListener.java
@@ -0,0 +1,83 @@
+package org.devlive.sdk.openai.listener;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import lombok.Builder;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.Response;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import org.apache.commons.lang3.ObjectUtils;
+import org.devlive.sdk.openai.response.CompleteResponse;
+import org.devlive.sdk.openai.utils.JsonUtils;
+
+import java.time.LocalDateTime;
+import java.util.concurrent.CountDownLatch;
+
+@Slf4j
+@Builder
+public class ConsoleEventSourceListener
+ extends EventSourceListener
+{
+ private CountDownLatch countDownLatch;
+ private JsonUtils jsonUtils;
+
+ @Override
+ public void onOpen(EventSource eventSource, Response response)
+ {
+ log.info("Console listener opened on time {}", LocalDateTime.now());
+ this.jsonUtils = JsonUtils.getInstance();
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource)
+ {
+ log.info("Console listener closed on time {}", LocalDateTime.now());
+ eventSource.cancel();
+ this.close();
+ }
+
+ @Override
+ public void onEvent(EventSource eventSource, String id, String type, String data)
+ {
+ // OpenAI ends with [DONE] by default
+ if (data.equals("[DONE]")) {
+ eventSource.cancel();
+ this.close();
+ }
+ else {
+ try {
+ CompleteResponse completeResponse = jsonUtils.getObject(data, CompleteResponse.class);
+ log.info("Console event received on time {} id {} type {} data {}", LocalDateTime.now(), id, type, completeResponse.getChoices().get(0).getContent());
+ }
+ catch (JsonProcessingException e) {
+ log.warn("Console event error on time {} id {} type {} data {}", LocalDateTime.now(), id, type, data, e);
+ }
+ }
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, Throwable throwable, Response response)
+ {
+ if (ObjectUtils.isNotEmpty(throwable)) {
+ if (throwable.getMessage().endsWith("CANCEL")) {
+ log.info("Console listener cancelled on time {}", LocalDateTime.now());
+ this.onClosed(eventSource);
+ }
+ else {
+ log.error("Console listener throwable \n{}\n response: \n{}\n", throwable, response);
+ }
+ }
+ else {
+ log.error("Console listener failure with empty throwable. Response: \n{}\n", response);
+ }
+ eventSource.cancel();
+ this.close();
+ }
+
+ private void close()
+ {
+ if (ObjectUtils.isNotEmpty(this.countDownLatch)) {
+ this.countDownLatch.countDown();
+ }
+ }
+}
diff --git a/src/main/java/org/devlive/sdk/openai/mixin/IgnoreUnknownMixin.java b/src/main/java/org/devlive/sdk/openai/mixin/IgnoreUnknownMixin.java
new file mode 100644
index 0000000..1730c01
--- /dev/null
+++ b/src/main/java/org/devlive/sdk/openai/mixin/IgnoreUnknownMixin.java
@@ -0,0 +1,8 @@
+package org.devlive.sdk.openai.mixin;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public abstract class IgnoreUnknownMixin
+{
+}
diff --git a/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java b/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java
index 716a41e..80d150f 100644
--- a/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java
+++ b/src/main/java/org/devlive/sdk/openai/utils/MultipartBodyUtils.java
@@ -9,6 +9,7 @@
public class MultipartBodyUtils
{
public static final MediaType TYPE = MediaType.parse("multipart/form-data");
+ public static final MediaType JSON = MediaType.parse("application/json; charset=utf-8");
private MultipartBodyUtils()
{
diff --git a/src/test/java/org/devlive/sdk/openai/StreamClientTest.java b/src/test/java/org/devlive/sdk/openai/StreamClientTest.java
new file mode 100644
index 0000000..8cdd1cd
--- /dev/null
+++ b/src/test/java/org/devlive/sdk/openai/StreamClientTest.java
@@ -0,0 +1,45 @@
+package org.devlive.sdk.openai;
+
+import lombok.extern.slf4j.Slf4j;
+import org.devlive.sdk.openai.entity.CompletionEntity;
+import org.devlive.sdk.openai.listener.ConsoleEventSourceListener;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.concurrent.CountDownLatch;
+
+@Slf4j
+public class StreamClientTest
+{
+ private OpenAiClient client;
+ private CountDownLatch countDownLatch;
+
+ @Before
+ public void before()
+ {
+ countDownLatch = new CountDownLatch(1);
+ ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
+ .countDownLatch(countDownLatch)
+ .build();
+ client = OpenAiClient.builder()
+ .apiKey(System.getProperty("openai.token"))
+ .listener(listener)
+ .build();
+ }
+
+ @Test
+ public void testCreateCompletion()
+ {
+ CompletionEntity configure = CompletionEntity.builder()
+ .prompt("How to create a stream completion")
+ .temperature(2D)
+ .build();
+ client.createCompletion(configure);
+ try {
+ countDownLatch.await();
+ }
+ catch (InterruptedException e) {
+ log.error("Interrupted while waiting", e);
+ }
+ }
+}