Skip to content

Commit

Permalink
feat: support customize model name #50 (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
qianmoQ authored Dec 20, 2024
2 parents db0061b + d946d39 commit 8c2876f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/devlive/sdk/common/DefaultClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ public AssistantsEntity createAssistants(AssistantsEntity configure)
}

public AssistantsFileEntity createAssistantsFile(String fileId,
String assistantId)
String assistantId)
{
String url = String.format(ProviderUtils.getUrl(provider, UrlModel.FETCH_ASSISTANTS_FILES), assistantId);
Map<String, String> configure = Maps.newHashMap();
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/devlive/sdk/openai/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ public OpenAiClientBuilder model(CompletionModel model)
return this;
}

public OpenAiClientBuilder model(String model)
{
this.model = model;
return this;
}

private String getDefaultHost()
{
if (ObjectUtils.isEmpty(this.provider)) {
Expand Down
30 changes: 24 additions & 6 deletions src/main/java/org/devlive/sdk/openai/entity/ChatEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import org.devlive.sdk.openai.utils.EnumsUtils;

import java.util.List;
import java.util.Objects;

import static org.devlive.sdk.openai.model.CompletionModel.GPT_35_TURBO;

@Data
@Builder
Expand Down Expand Up @@ -47,7 +50,7 @@ public class ChatEntity
private ChatEntity(ChatEntityBuilder builder)
{
if (ObjectUtils.isEmpty(builder.model)) {
builder.model(CompletionModel.GPT_35_TURBO);
builder.model(GPT_35_TURBO);
}
this.model = builder.model;
this.messages = builder.messages;
Expand All @@ -73,7 +76,7 @@ public static class ChatEntityBuilder
public ChatEntityBuilder model(CompletionModel model)
{
if (ObjectUtils.isEmpty(model)) {
model = CompletionModel.GPT_35_TURBO;
model = GPT_35_TURBO;
}
switch (model) {
case GPT_35_TURBO:
Expand All @@ -96,6 +99,12 @@ public ChatEntityBuilder model(CompletionModel model)
return this;
}

public ChatEntityBuilder model(String model)
{
this.model = model;
return this;
}

public ChatEntityBuilder temperature(Double temperature)
{
if (temperature < 0 || temperature > 2) {
Expand All @@ -108,11 +117,20 @@ public ChatEntityBuilder temperature(Double temperature)
public ChatEntityBuilder maxTokens(Integer maxTokens)
{
CompletionModel completionModel = EnumsUtils.getCompleteModel(this.model);
if (ObjectUtils.isNotEmpty(this.model) && maxTokens > completionModel.getMaxTokens()) {
throw new ParamException(String.format("Invalid maxTokens: %s, Cannot be larger than the model default configuration %s", maxTokens, completionModel.getMaxTokens()));
if (Objects.isNull(completionModel)) {
this.maxTokens = maxTokens;
return this;
}
else {
if (ObjectUtils.isNotEmpty(this.model)
&& maxTokens > completionModel.getMaxTokens()) {
throw new ParamException(String.format(
"Invalid maxTokens: %s, Cannot be larger than the model default configuration %s",
maxTokens, completionModel.getMaxTokens()));
}
this.maxTokens = maxTokens;
return this;
}
this.maxTokens = maxTokens;
return this;
}

private ChatEntityBuilder stream()
Expand Down
25 changes: 25 additions & 0 deletions src/test/java/org/devlive/sdk/openai/OpenAiClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.devlive.sdk.common.exception.RequestException;
import org.devlive.sdk.openai.model.CompletionModel;
import org.devlive.sdk.openai.model.EditModel;
import org.devlive.sdk.openai.response.ChatResponse;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -94,6 +95,30 @@ public void testCreateCompletion()
Assert.assertTrue(client.createCompletion(configure).getChoices().size() > 0);
}

@Test
public void testCustomizedModel()
{
client = OpenAiClient.builder()
.apiHost(System.getProperty("proxy.host"))
.apiKey(System.getProperty("openai.token"))
.model("text-davinci-003")
.build();

List<MessageEntity> messages = Lists.newArrayList();
messages.add(MessageEntity.builder()
.content("Hello, please show me a jok!")
.build());

ChatEntity configure = ChatEntity.builder()
.messages(messages)
.model("text-davinci-003")
.build();
ChatResponse chatCompletion = client.createChatCompletion(configure);
String content = chatCompletion.getChoices().get(0).getMessage().getContent();
// System.out.println(content);
Assert.assertNotNull(content);
}

@Test
public void testCreateChatCompletion()
{
Expand Down

0 comments on commit 8c2876f

Please sign in to comment.