Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Stream] Support completions #28

Merged
merged 2 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions docs/docs/reference/stream/console.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
---
title: Console
---

!!! Note

Please build the client before calling, the build code is as follows:

```java
CountDownLatch countDownLatch = new CountDownLatch(1);
ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
.countDownLatch(countDownLatch)
.build();
OpenAiClient client = OpenAiClient.builder()
.apiKey(System.getProperty("openai.token"))
.listener(listener)
.build();
```

`System.getProperty("openai.token")` is the key to access the API authorization.

### Create completion

---

Creates a completion for the provided prompt and parameters.

```java
CompletionEntity configure = CompletionEntity.builder()
.model(CompleteModel.TEXT_DAVINCI_003.getName())
.prompt("How to create a completion")
.temperature(2D)
.build();
client.createCompletion(configure);
```
35 changes: 35 additions & 0 deletions docs/docs/reference/stream/console.zh.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
---
title: Console
---

!!! Note

调用前请先构建客户端,构建代码如下:

```java
CountDownLatch countDownLatch = new CountDownLatch(1);
ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
.countDownLatch(countDownLatch)
.build();
OpenAiClient client = OpenAiClient.builder()
.apiKey(System.getProperty("openai.token"))
.listener(listener)
.build();
```

`System.getProperty("openai.token")` 是访问 API 授权的关键。

### Create completion

---

为提供的提示和参数创建补全。

```java
CompletionEntity configure = CompletionEntity.builder()
.model(CompleteModel.TEXT_DAVINCI_003.getName())
.prompt("How to create a completion")
.temperature(2D)
.build();
client.createCompletion(configure);
```
2 changes: 2 additions & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ plugins:
nav:
- index.md
- Reference:
- Stream (Not provider):
- reference/stream/console.md
- Open Ai:
- reference/openai/users.md
- reference/openai/models.md
Expand Down
7 changes: 6 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.devlive.sdk</groupId>
<artifactId>openai-java-sdk</artifactId>
<version>1.7.0</version>
<version>1.8.0-SNAPSHOT</version>

<name>openai-java-sdk</name>
<description>
Expand Down Expand Up @@ -103,6 +103,11 @@
<artifactId>okhttp</artifactId>
<version>${okhttp.version}</version>
</dependency>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
<version>${okhttp.version}</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
Expand Down
45 changes: 43 additions & 2 deletions src/main/java/org/devlive/sdk/openai/DefaultClient.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package org.devlive.sdk.openai;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MultipartBody;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.lang3.ObjectUtils;
import org.devlive.sdk.openai.entity.AudioEntity;
import org.devlive.sdk.openai.entity.ChatEntity;
Expand All @@ -14,6 +20,8 @@
import org.devlive.sdk.openai.entity.ModelEntity;
import org.devlive.sdk.openai.entity.ModerationEntity;
import org.devlive.sdk.openai.entity.UserKeyEntity;
import org.devlive.sdk.openai.exception.RequestException;
import org.devlive.sdk.openai.mixin.IgnoreUnknownMixin;
import org.devlive.sdk.openai.model.ProviderModel;
import org.devlive.sdk.openai.model.UrlModel;
import org.devlive.sdk.openai.response.AudioResponse;
Expand All @@ -36,6 +44,8 @@ public abstract class DefaultClient
protected DefaultApi api;
protected ProviderModel provider;
protected OkHttpClient client;
protected String apiHost;
protected EventSourceListener listener;

public ModelResponse getModels()
{
Expand All @@ -51,8 +61,16 @@ public ModelEntity getModel(String model)

public CompleteResponse createCompletion(CompletionEntity configure)
{
return this.api.fetchCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
.blockingGet();
String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS);
if (ObjectUtils.isNotEmpty(this.listener)) {
configure.setStream(true);
this.createEventSource(url, configure);
return null;
}
else {
return this.api.fetchCompletions(url, configure)
.blockingGet();
}
}

public ChatResponse createChatCompletion(ChatEntity configure)
Expand Down Expand Up @@ -168,6 +186,29 @@ public Object retrieveFileContent(String id)
.blockingGet();
}

private ObjectMapper createObjectMapper()
{
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.addMixIn(Object.class, IgnoreUnknownMixin.class);
return objectMapper;
}

private void createEventSource(String url, Object configure)
{
try {
EventSource.Factory factory = EventSources.createFactory(this.client);
ObjectMapper mapper = this.createObjectMapper();
Request request = new Request.Builder()
.url(String.join("/", this.apiHost, url))
.post(RequestBody.create(MultipartBodyUtils.JSON, mapper.writeValueAsString(configure)))
.build();
factory.newEventSource(request, this.listener);
}
catch (Exception e) {
throw new RequestException(String.format("Failed to create event source: %s", e.getMessage()));
}
}

public void close()
{
if (ObjectUtils.isNotEmpty(this.client)) {
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/devlive/sdk/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.devlive.sdk.openai.exception.ParamException;
Expand Down Expand Up @@ -35,6 +36,8 @@ public class OpenAiClient
// Azure provider requires
private String model; // The model name deployed in azure
private String version;
// Support see
private EventSourceListener listener;

private OpenAiClient(OpenAiClientBuilder builder)
{
Expand Down Expand Up @@ -69,9 +72,14 @@ private OpenAiClient(OpenAiClientBuilder builder)
if (ObjectUtils.isEmpty(builder.client)) {
builder.client(null);
}
if (ObjectUtils.isEmpty(builder.listener)) {
builder.listener(null);
}

super.provider = builder.provider;
super.client = builder.client;
super.listener = builder.listener;
super.apiHost = builder.apiHost;
// Build a remote API client
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
this.api = new Retrofit.Builder()
Expand Down Expand Up @@ -160,6 +168,9 @@ public OpenAiClientBuilder client(OkHttpClient client)

private String getDefaultHost()
{
if (ObjectUtils.isEmpty(this.provider)) {
this.provider = ProviderModel.OPENAI;
}
if (this.provider.equals(ProviderModel.CLAUDE)) {
return "https://api.anthropic.com";
}
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/org/devlive/sdk/openai/entity/CompletionEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ public class CompletionEntity
@JsonProperty(value = "stop")
private List<String> stop;

/**
* Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent events as they become available, with the stream terminated by a data: [DONE] message.
* 是否流回部分进度。如果设置,令牌将在可用时作为仅数据服务器发送事件发送,流由 data: [DONE] 消息终止。
*/
@JsonProperty(value = "stream")
private boolean stream = false;

private CompletionEntity(CompletionEntityBuilder builder)
{
if (ObjectUtils.isEmpty(builder.model)) {
Expand Down Expand Up @@ -151,6 +158,11 @@ public CompletionEntityBuilder presencePenalty(Double presencePenalty)
return this;
}

private CompletionEntityBuilder stream()
{
return this;
}

public CompletionEntity build()
{
return new CompletionEntity(this);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.devlive.sdk.openai.listener;

import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.apache.commons.lang3.ObjectUtils;
import org.devlive.sdk.openai.response.CompleteResponse;
import org.devlive.sdk.openai.utils.JsonUtils;

import java.time.LocalDateTime;
import java.util.concurrent.CountDownLatch;

@Slf4j
@Builder
public class ConsoleEventSourceListener
extends EventSourceListener
{
private CountDownLatch countDownLatch;
private JsonUtils<CompleteResponse> jsonUtils;

@Override
public void onOpen(EventSource eventSource, Response response)
{
log.info("Console listener opened on time {}", LocalDateTime.now());
this.jsonUtils = JsonUtils.getInstance();
}

@Override
public void onClosed(EventSource eventSource)
{
log.info("Console listener closed on time {}", LocalDateTime.now());
eventSource.cancel();
this.close();
}

@Override
public void onEvent(EventSource eventSource, String id, String type, String data)
{
// OpenAI ends with [DONE] by default
if (data.equals("[DONE]")) {
eventSource.cancel();
this.close();
}
else {
try {
CompleteResponse completeResponse = jsonUtils.getObject(data, CompleteResponse.class);
log.info("Console event received on time {} id {} type {} data {}", LocalDateTime.now(), id, type, completeResponse.getChoices().get(0).getContent());
}
catch (JsonProcessingException e) {
log.warn("Console event error on time {} id {} type {} data {}", LocalDateTime.now(), id, type, data, e);
}
}
}

@Override
public void onFailure(EventSource eventSource, Throwable throwable, Response response)
{
if (ObjectUtils.isNotEmpty(throwable)) {
if (throwable.getMessage().endsWith("CANCEL")) {
log.info("Console listener cancelled on time {}", LocalDateTime.now());
this.onClosed(eventSource);
}
else {
log.error("Console listener throwable \n{}\n response: \n{}\n", throwable, response);
}
}
else {
log.error("Console listener failure with empty throwable. Response: \n{}\n", response);
}
eventSource.cancel();
this.close();
}

private void close()
{
if (ObjectUtils.isNotEmpty(this.countDownLatch)) {
this.countDownLatch.countDown();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.devlive.sdk.openai.mixin;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;

@JsonIgnoreProperties(ignoreUnknown = true)
public abstract class IgnoreUnknownMixin
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
public class MultipartBodyUtils
{
public static final MediaType TYPE = MediaType.parse("multipart/form-data");
public static final MediaType JSON = MediaType.parse("application/json; charset=utf-8");

private MultipartBodyUtils()
{
Expand Down
45 changes: 45 additions & 0 deletions src/test/java/org/devlive/sdk/openai/StreamClientTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.devlive.sdk.openai;

import lombok.extern.slf4j.Slf4j;
import org.devlive.sdk.openai.entity.CompletionEntity;
import org.devlive.sdk.openai.listener.ConsoleEventSourceListener;
import org.junit.Before;
import org.junit.Test;

import java.util.concurrent.CountDownLatch;

@Slf4j
public class StreamClientTest
{
private OpenAiClient client;
private CountDownLatch countDownLatch;

@Before
public void before()
{
countDownLatch = new CountDownLatch(1);
ConsoleEventSourceListener listener = ConsoleEventSourceListener.builder()
.countDownLatch(countDownLatch)
.build();
client = OpenAiClient.builder()
.apiKey(System.getProperty("openai.token"))
.listener(listener)
.build();
}

@Test
public void testCreateCompletion()
{
CompletionEntity configure = CompletionEntity.builder()
.prompt("How to create a stream completion")
.temperature(2D)
.build();
client.createCompletion(configure);
try {
countDownLatch.await();
}
catch (InterruptedException e) {
log.error("Interrupted while waiting", e);
}
}
}