
Commit

Added preliminary support for OpenAI Functions (#35748)
* Added preliminary support for OpenAI Functions

* WIP: serialization still fails

* Tests are passing, requests go through to nonAzure OAI

* Moved gened classes with polymorphism to impl package

* Added assertion for function call

* Added test for usage of function not supplied in the request

* Renamed test for clarity

* Addressed most of the PR comments

* Added sync version of the tests

* Renamed runner method

* Added Azure sync/async tests

* Added docs for FunctionCall

* Removed unused files from samples package

* Preparing release notes

* Renamed static member

* Moved the exception handling one level up in the call stack

* code regened

* Moved custom models under their own package

* Updated test records

* Addressed most of the style checks

* removed unused import

* Renamed type

* Added spell checker exception for DALL-E

* Updated commit hash and re-ran code gen
jpalvarezl authored Jul 13, 2023
1 parent caa742a commit 69d4813
Showing 32 changed files with 1,211 additions and 83 deletions.
1 change: 1 addition & 0 deletions .vscode/cspell.json
@@ -243,6 +243,7 @@
"creds",
"credscan",
"curr",
"DALL-E",
"databind",
"databricks",
"DAZURE",
4 changes: 4 additions & 0 deletions sdk/openai/azure-ai-openai/CHANGELOG.md
@@ -2,6 +2,10 @@

## 1.0.0-beta.3 (Unreleased)

- Added methods and models to support DALL-E
- Added methods and models to support Functions
- Added models supporting ResponsibleAI annotations

### Features Added

### Breaking Changes
2 changes: 1 addition & 1 deletion sdk/openai/azure-ai-openai/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "java",
"TagPrefix": "java/openai/azure-ai-openai",
"Tag": "java/openai/azure-ai-openai_2a6e71fe2e"
"Tag": "java/openai/azure-ai-openai_9fc7970110"
}
1 change: 1 addition & 0 deletions sdk/openai/azure-ai-openai/pom.xml
@@ -50,6 +50,7 @@
--add-exports com.azure.core/com.azure.core.implementation.util=ALL-UNNAMED
--add-opens com.azure.ai.openai/com.azure.ai.openai=ALL-UNNAMED
--add-opens com.azure.ai.openai/com.azure.ai.openai.implementation=com.fasterxml.jackson.databind
--add-opens com.azure.ai.openai/com.azure.ai.openai.functions=com.fasterxml.jackson.databind
</javaModulesSurefireArgLine>
<jacoco.skip>true</jacoco.skip>
</properties>
@@ -15,7 +15,10 @@ public enum OpenAIServiceVersion implements ServiceVersion {
V2023_05_15("2023-05-15"),

/** Enum value 2023-06-01-preview. */
V2023_06_01_PREVIEW("2023-06-01-preview");
V2023_06_01_PREVIEW("2023-06-01-preview"),

/** Enum value 2023-07-01-preview. */
V2023_07_01_PREVIEW("2023-07-01-preview");

private final String version;

@@ -35,6 +38,6 @@ public String getVersion() {
* @return The latest {@link OpenAIServiceVersion}.
*/
public static OpenAIServiceVersion getLatest() {
return V2023_06_01_PREVIEW;
return V2023_07_01_PREVIEW;
}
}
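For readers tracking the new service version, a minimal check of the behavior changed above; the wrapper class below exists only for illustration:

import com.azure.ai.openai.OpenAIServiceVersion;

// Illustrative only: after this change, getLatest() resolves to the 2023-07-01-preview version.
public class ServiceVersionCheck {
    public static void main(String[] args) {
        OpenAIServiceVersion latest = OpenAIServiceVersion.getLatest();
        System.out.println(latest.getVersion()); // prints "2023-07-01-preview"
    }
}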
@@ -3,9 +3,6 @@

package com.azure.ai.openai.implementation;

import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.CompletionsOptions;
import com.azure.ai.openai.models.EmbeddingsOptions;
import com.azure.core.annotation.BodyParam;
import com.azure.core.annotation.ExpectedResponses;
import com.azure.core.annotation.HeaderParam;
@@ -28,8 +25,14 @@
import com.azure.core.util.Context;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.serializer.SerializerAdapter;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;

/**
* Implementation for calling Non-Azure OpenAI Service
*/
@@ -66,6 +69,11 @@ public SerializerAdapter getSerializerAdapter() {
*/
public static final String OPEN_AI_ENDPOINT = "https://api.openai.com/v1";

/**
* Mapper used to add the `modelId` into the request body for a non-Azure OpenAI request.
*/
private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

/**
* Initializes an instance of OpenAIClient client.
*
@@ -289,20 +297,20 @@ public Mono<Response<BinaryData>> getEmbeddingsWithResponseAsync(String modelId,
BinaryData embeddingsOptions, RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData embeddingsOptionsUpdated = BinaryData.fromObject(
embeddingsOptions.toObject(EmbeddingsOptions.class)
.setModel(modelId)
);

return FluxUtil.withContext(
context ->
service.getEmbeddings(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
context));
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId);
return FluxUtil.withContext(
context ->
service.getEmbeddings(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
context));
} catch (JsonProcessingException e) {
return Mono.error(e);
}
}

/**
@@ -357,18 +365,18 @@ public Response<BinaryData> getEmbeddingsWithResponse(String modelId, BinaryData
RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData embeddingsOptionsUpdated = BinaryData.fromObject(
embeddingsOptions.toObject(EmbeddingsOptions.class)
.setModel(modelId)
);

return service.getEmbeddingsSync(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
Context.NONE);
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId);
return service.getEmbeddingsSync(
OPEN_AI_ENDPOINT,
accept,
embeddingsOptionsUpdated,
requestOptions,
Context.NONE);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}

/**
@@ -458,20 +466,20 @@ public Mono<Response<BinaryData>> getCompletionsWithResponseAsync(String modelId
BinaryData completionsOptions, RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData completionsOptionsUpdated = BinaryData.fromObject(
completionsOptions.toObject(CompletionsOptions.class)
.setModel(modelId)
);

return FluxUtil.withContext(
context ->
service.getCompletions(
OPEN_AI_ENDPOINT,
accept,
completionsOptionsUpdated,
requestOptions,
context));
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId);
return FluxUtil.withContext(
context ->
service.getCompletions(
OPEN_AI_ENDPOINT,
accept,
completionsOptionsUpdated,
requestOptions,
context));
} catch (JsonProcessingException e) {
return Mono.error(e);
}
}

/**
@@ -559,11 +567,14 @@ public Response<BinaryData> getCompletionsWithResponse(String modelId, BinaryDat
RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData completionsOptionsUpdated = BinaryData.fromObject(
completionsOptions.toObject(CompletionsOptions.class)
.setModel(modelId)
);
// modelId is part of the request body in nonAzure OpenAI
BinaryData completionsOptionsUpdated = null;
try {
completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

return service.getCompletionsSync(
OPEN_AI_ENDPOINT,
accept,
@@ -650,20 +661,20 @@ public Mono<Response<BinaryData>> getChatCompletionsWithResponseAsync(String mod
BinaryData chatCompletionsOptions, RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData chatCompletionsOptionsUpdated = BinaryData.fromObject(
chatCompletionsOptions.toObject(ChatCompletionsOptions.class)
.setModel(modelId)
);

return FluxUtil.withContext(
context ->
service.getChatCompletions(
OPEN_AI_ENDPOINT,
accept,
chatCompletionsOptionsUpdated,
requestOptions,
context));
// modelId is part of the request body in nonAzure OpenAI
try {
BinaryData chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId);
return FluxUtil.withContext(
context ->
service.getChatCompletions(
OPEN_AI_ENDPOINT,
accept,
chatCompletionsOptionsUpdated,
requestOptions,
context));
} catch (JsonProcessingException e) {
return Mono.error(e);
}
}

/**
@@ -743,11 +754,13 @@ public Response<BinaryData> getChatCompletionsWithResponse(String modelId, Binar
RequestOptions requestOptions) {
final String accept = "application/json";

// OpenAI has model ID in request body
BinaryData chatCompletionsOptionsUpdated = BinaryData.fromObject(
chatCompletionsOptions.toObject(ChatCompletionsOptions.class)
.setModel(modelId)
);
// modelId is part of the request body in nonAzure OpenAI
BinaryData chatCompletionsOptionsUpdated = null;
try {
chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}

return service.getChatCompletionsSync(
OPEN_AI_ENDPOINT,
@@ -870,4 +883,26 @@ public Response<BinaryData> generateImageWithResponse(
Context.NONE
);
}

/**
* Injects the modelId into the request body for requests against non-Azure OpenAI. The non-Azure service
* expects the model ID in the body of the request, whereas Azure OpenAI passes it as part of the request path.
*
* @param inputJson JSON submitted by the client
* @param modelId The LLM model ID to be injected into the JSON
* @return the request body with the {@code model} field set to {@code modelId}
* @throws JsonProcessingException if {@code inputJson} cannot be parsed as JSON
*/
private static BinaryData addModelIdJson(BinaryData inputJson, String modelId) throws JsonProcessingException {
JsonNode jsonNode = JSON_MAPPER.readTree(inputJson.toString());
if (jsonNode instanceof ObjectNode) {
ObjectNode objectNode = (ObjectNode) jsonNode;
objectNode.put("model", modelId);
inputJson = BinaryData.fromBytes(
objectNode.toString()
.getBytes(StandardCharsets.UTF_8));
}

return inputJson;
}
}
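To make the injection above easier to follow, here is a minimal stand-alone sketch of the same step using plain Jackson, mirroring what addModelIdJson does before the SDK re-wraps the result as BinaryData; the request payload and model name are illustrative only:

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

// Stand-alone illustration: non-Azure OpenAI expects the model ID inside the request body,
// so it is injected as a top-level "model" field before the request is sent.
public class ModelIdInjectionSketch {
    public static void main(String[] args) throws JsonProcessingException {
        ObjectMapper mapper = new ObjectMapper();
        String requestBody = "{\"prompt\":[\"Hello\"],\"max_tokens\":16}"; // hypothetical client payload
        ObjectNode body = (ObjectNode) mapper.readTree(requestBody);
        body.put("model", "gpt-3.5-turbo-instruct"); // hypothetical model id
        System.out.println(body);
        // {"prompt":["Hello"],"max_tokens":16,"model":"gpt-3.5-turbo-instruct"}
    }
}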
@@ -548,7 +548,7 @@ public Response<BinaryData> getEmbeddingsWithResponse(
* int (Required)
* ]
* }
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* }
* ]
* usage (Required): {
@@ -646,7 +646,7 @@ public Mono<Response<BinaryData>> getCompletionsWithResponseAsync(
* int (Required)
* ]
* }
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* }
* ]
* usage (Required): {
@@ -693,10 +693,23 @@ public Response<BinaryData> getCompletionsWithResponse(
* {
* messages (Required): [
* (Required){
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* ]
* functions (Optional): [
* (Optional){
* name: String (Required)
* description: String (Optional)
* parameters: Object (Optional)
* }
* ]
* function_call: FunctionCallModelBase (Optional)
* max_tokens: Integer (Optional)
* temperature: Double (Optional)
* top_p: Double (Optional)
@@ -724,11 +737,16 @@ public Response<BinaryData> getCompletionsWithResponse(
* choices (Required): [
* (Required){
* message (Optional): {
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* index: int (Required)
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* delta (Optional): (recursive schema, see delta above)
* }
* ]
@@ -779,10 +797,23 @@ public Mono<Response<BinaryData>> getChatCompletionsWithResponseAsync(
* {
* messages (Required): [
* (Required){
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* ]
* functions (Optional): [
* (Optional){
* name: String (Required)
* description: String (Optional)
* parameters: Object (Optional)
* }
* ]
* function_call: FunctionCallModelBase (Optional)
* max_tokens: Integer (Optional)
* temperature: Double (Optional)
* top_p: Double (Optional)
@@ -810,11 +841,16 @@ public Mono<Response<BinaryData>> getChatCompletionsWithResponseAsync(
* choices (Required): [
* (Required){
* message (Optional): {
* role: String(system/assistant/user) (Required)
* role: String(system/assistant/user/function) (Required)
* content: String (Optional)
* name: String (Optional)
* function_call (Optional): {
* name: String (Required)
* arguments: String (Required)
* }
* }
* index: int (Required)
* finish_reason: String(stop/length/content_filter) (Required)
* finish_reason: String(stop/length/content_filter/function_call) (Required)
* delta (Optional): (recursive schema, see delta above)
* }
* ]
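As a companion to the request schema above, a minimal sketch that builds a chat completions body carrying one user message and one function definition; the function name, parameters, and token limit are made up for illustration:

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

// Builds a request body shaped like the documented schema: messages plus an optional
// functions array; "function_call" is left as "auto" so the service decides when to call.
public class ChatFunctionRequestSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        ObjectNode body = mapper.createObjectNode();

        ObjectNode userMessage = body.putArray("messages").addObject();
        userMessage.put("role", "user");
        userMessage.put("content", "What is the weather in Seattle?");

        ObjectNode weatherFunction = body.putArray("functions").addObject();
        weatherFunction.put("name", "get_current_weather");
        weatherFunction.put("description", "Gets the current weather for a location");
        ObjectNode parameters = weatherFunction.putObject("parameters");
        parameters.put("type", "object");
        parameters.putObject("properties").putObject("location").put("type", "string");

        body.put("function_call", "auto"); // may also be an object naming a specific function
        body.put("max_tokens", 256);

        System.out.println(mapper.writerWithDefaultPrettyPrinter().writeValueAsString(body));
    }
}

When the model elects to call the function, the response reports finish_reason as function_call and choices[0].message.function_call carries the function name and its JSON-encoded arguments, matching the updated schema.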