Skip to content

Commit

Permalink
[Google] [PaLM] Support text prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
qianmoQ committed Aug 21, 2023
1 parent a7ba02a commit b8c6fa7
Show file tree
Hide file tree
Showing 17 changed files with 329 additions and 4 deletions.
32 changes: 32 additions & 0 deletions docs/docs/reference/google_palm/completions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
---
title: Completions
---

!!! note

Support the google palm, product address: [https://developers.generativeai.google/products/palm](https://developers.generativeai.google/products/palm)

### Create completion

---

Creates a completion for the provided prompt and parameters.

```java
// Automatic resource release
try(OpenAiClient client=OpenAiClient.builder()
.provider(ProviderModel.GOOGLE_PALM)
.model(CompletionModel.TEXT_BISON_001)
.apiKey(System.getProperty("google.token"))
.build())
{
PromptEntity prompt = PromptEntity.builder()
.text("How to create a completion")
.build();
CompletionEntity configure = CompletionEntity.builder()
.prompt(prompt)
.build();
client.createPaLMCompletion(configure).getCandidates();
}
```

30 changes: 30 additions & 0 deletions docs/docs/reference/google_palm/completions.zh.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
title: Completions
---

!!! note

支持 google palm,产品地址: [https://developers.generativeai.google/products/palm](https://developers.generativeai.google/products/palm)

### Create completion

---

为提供的提示和参数创建补全。

```java
try(OpenAiClient client=OpenAiClient.builder()
.provider(ProviderModel.GOOGLE_PALM)
.model(CompletionModel.TEXT_BISON_001)
.apiKey(System.getProperty("google.token"))
.build())
{
PromptEntity prompt = PromptEntity.builder()
.text("How to create a completion")
.build();
CompletionEntity configure = CompletionEntity.builder()
.prompt(prompt)
.build();
client.createPaLMCompletion(configure).getCandidates();
}
```
2 changes: 2 additions & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,7 @@ nav:
- reference/azure/completions_chat.md
- Anthropic Claude:
- reference/anthropic/completions.md
- Google LaPM:
- reference/google_palm/completions.md
- released.md
- powered_by.md
4 changes: 4 additions & 0 deletions src/main/java/org/devlive/sdk/openai/DefaultApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ public interface DefaultApi
Single<CompleteResponse> fetchCompletions(@Url String url,
@Body CompletionEntity configure);

@POST
Single<CompleteResponse> fetchPaLMCompletions(@Url String url,
@Body org.devlive.sdk.openai.entity.google.CompletionEntity configure);

/**
* Creates a model response for the given chat conversation.
*/
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/devlive/sdk/openai/DefaultClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ public CompleteResponse createCompletion(CompletionEntity configure)
}
}

public CompleteResponse createPaLMCompletion(org.devlive.sdk.openai.entity.google.CompletionEntity configure)
{
return this.api.fetchPaLMCompletions(ProviderUtils.getUrl(provider, UrlModel.FETCH_COMPLETIONS), configure)
.blockingGet();
}

public ChatResponse createChatCompletion(ChatEntity configure)
{
String url = ProviderUtils.getUrl(provider, UrlModel.FETCH_CHAT_COMPLETIONS);
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/org/devlive/sdk/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import org.devlive.sdk.openai.interceptor.AzureInterceptor;
import org.devlive.sdk.openai.interceptor.ClaudeInterceptor;
import org.devlive.sdk.openai.interceptor.DefaultInterceptor;
import org.devlive.sdk.openai.interceptor.GooglePaLMInterceptor;
import org.devlive.sdk.openai.interceptor.OpenAiInterceptor;
import org.devlive.sdk.openai.model.CompletionModel;
import org.devlive.sdk.openai.model.ProviderModel;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
Expand Down Expand Up @@ -158,6 +160,12 @@ public OpenAiClientBuilder client(OkHttpClient client)
if (this.provider.equals(ProviderModel.CLAUDE)) {
interceptor = new ClaudeInterceptor();
}
// Google PaLM
if (this.provider.equals(ProviderModel.GOOGLE_PALM)) {
interceptor = new GooglePaLMInterceptor();
interceptor.setApiKey(this.apiKey);
interceptor.setModel(this.model);
}
interceptor.setApiKey(apiKey);
client = client.newBuilder()
.addInterceptor(interceptor)
Expand All @@ -166,6 +174,12 @@ public OpenAiClientBuilder client(OkHttpClient client)
return this;
}

public OpenAiClientBuilder model(CompletionModel model)
{
this.model = model.getName();
return this;
}

private String getDefaultHost()
{
if (ObjectUtils.isEmpty(this.provider)) {
Expand All @@ -174,6 +188,9 @@ private String getDefaultHost()
if (this.provider.equals(ProviderModel.CLAUDE)) {
return "https://api.anthropic.com";
}
if (this.provider.equals(ProviderModel.GOOGLE_PALM)) {
return "https://generativelanguage.googleapis.com";
}
return "https://api.openai.com";
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.devlive.sdk.openai.entity.google;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;

import java.util.List;

@Data
@Builder
@ToString
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public class CompletionEntity
{
@JsonProperty(value = "prompt")
private PromptEntity prompt;

@JsonProperty(value = "temperature")
@Builder.Default
private Double temperature = 0.25;

@JsonProperty(value = "top_k")
@Builder.Default
private Integer topK = 40;

@JsonProperty(value = "top_p")
@Builder.Default
private Double topP = 1.0;

@JsonProperty(value = "candidate_count")
@Builder.Default
private Integer candidateCount = 1;

@JsonProperty(value = "max_output_tokens")
@Builder.Default
private Integer maxOutputTokens = 1024;

@JsonProperty(value = "stop_sequences")
private List<String> stop;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.devlive.sdk.openai.entity.google;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;

@Data
@Builder
@ToString
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public class PromptEntity
{
@JsonProperty(value = "text")
private String text;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.devlive.sdk.openai.interceptor;

import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import okhttp3.HttpUrl;
import okhttp3.Request;
import org.apache.commons.lang3.StringUtils;
import org.devlive.sdk.openai.exception.ParamException;
import org.devlive.sdk.openai.utils.HttpUrlUtils;

import java.util.List;

@Slf4j
public class GooglePaLMInterceptor
extends DefaultInterceptor
{
public GooglePaLMInterceptor()
{
log.debug("Google PaLM Interceptor");
}

@Override
protected Request prepared(Request original)
{
if (StringUtils.isEmpty(this.getApiKey())) {
throw new ParamException("Invalid Google PaLM token, must be non-empty");
}
HttpUrl httpUrl = original.url();
List<String> pathSegments = Lists.newArrayList();
httpUrl = HttpUrlUtils.removePathSegment(httpUrl);
// https://generativelanguage.googleapis.com/v1beta2/models/text-bison-001:generateText?key=YOUR_KEY
pathSegments.add(0, String.join(":", this.getModel(), "generateText"));
pathSegments.add(0, "models");
pathSegments.add(0, "v1beta2");
httpUrl = httpUrl.newBuilder()
.host(httpUrl.host())
.port(httpUrl.port())
.addPathSegments(String.join("/", pathSegments))
.addQueryParameter("key", this.getApiKey())
.build();
log.debug("Google PaLM interceptor request url {}", httpUrl);
return original.newBuilder()
.header("Content-Type", "application/json")
.url(httpUrl)
.method(original.method(), original.body())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ public enum CompletionModel
CLAUDE_INSTANT_1("claude-instant-1",
null,
null,
Integer.MAX_VALUE);
Integer.MAX_VALUE),
/* =============================== Google PaLM ================================ */
TEXT_BISON_001("text-bison-001", null, null, Integer.MAX_VALUE);

private final String name;
private final String description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,10 @@ public enum ProviderModel
* A next-generation AI assistant for your tasks, no matter the scale
* https://www.anthropic.com/product
*/
CLAUDE
CLAUDE,
/**
* Google PaLM
* https://makersuite.google.com
*/
GOOGLE_PALM
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.devlive.sdk.openai.response;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;

import java.util.List;

@Data
@ToString
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public class CandidateResponse
{
@JsonProperty(value = "output")
private String output;

@JsonProperty(value = "safetyRatings")
private List<SafetyResponse> safetyRatings;
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,10 @@ public class CompleteResponse

@JsonProperty(value = "usage")
private UsageEntity usage;

/**
* Google PaLM
*/
@JsonProperty(value = "candidates")
private List<CandidateResponse> candidates;
}
22 changes: 22 additions & 0 deletions src/main/java/org/devlive/sdk/openai/response/SafetyResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.devlive.sdk.openai.response;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;

@Data
@ToString
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public class SafetyResponse
{
@JsonProperty(value = "category")
private String category;

@JsonProperty(value = "probability")
private String probability;
}
28 changes: 28 additions & 0 deletions src/main/java/org/devlive/sdk/openai/utils/HttpUrlUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.devlive.sdk.openai.utils;

import okhttp3.HttpUrl;

import java.util.List;

public class HttpUrlUtils
{
private HttpUrlUtils()
{}

/**
* Removes all path segments from the given HttpUrl.
*
* @param httpUrl the HttpUrl from which to remove path segments
* @return the modified HttpUrl with all path segments removed
*/
public static HttpUrl removePathSegment(HttpUrl httpUrl)
{
List<String> pathSegments = httpUrl.pathSegments();
for (int i = 0; i < pathSegments.size(); i++) {
httpUrl = httpUrl.newBuilder()
.removePathSegment(0)
.build();
}
return httpUrl;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public void before()
.apiHost("https://eus-chatgpt.openai.azure.com")
.apiKey(System.getProperty("azure.token"))
.provider(ProviderModel.AZURE)
.model("text-davinci-002")
.model(CompletionModel.TEXT_DAVINCI_002)
.version("2022-12-01")
.build();
}
Expand Down Expand Up @@ -54,7 +54,7 @@ public void testCreateChatCompletion()
.apiHost("https://eus-chatgpt.openai.azure.com")
.apiKey(System.getProperty("azure.token"))
.provider(ProviderModel.AZURE)
.model("gpt-35-turbo-0613")
.model(CompletionModel.GPT_4_32K_0613)
.version("2023-03-15-preview")
.build();

Expand Down
Loading

0 comments on commit b8c6fa7

Please sign in to comment.