diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 5b4cffc80b63..9b7b7a3586a7 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -243,6 +243,7 @@ "creds", "credscan", "curr", + "DALL-E", "databind", "databricks", "DAZURE", diff --git a/sdk/openai/azure-ai-openai/CHANGELOG.md b/sdk/openai/azure-ai-openai/CHANGELOG.md index 7765dbb0d608..29165ff37fc7 100644 --- a/sdk/openai/azure-ai-openai/CHANGELOG.md +++ b/sdk/openai/azure-ai-openai/CHANGELOG.md @@ -2,6 +2,10 @@ ## 1.0.0-beta.3 (Unreleased) +- Added methods and models to support DALL-E +- Added methods and models to support Functions +- Added models supporting ResponsibleAI annotations + ### Features Added ### Breaking Changes diff --git a/sdk/openai/azure-ai-openai/assets.json b/sdk/openai/azure-ai-openai/assets.json index 1fa8b11ec119..ea1e75e6a58d 100644 --- a/sdk/openai/azure-ai-openai/assets.json +++ b/sdk/openai/azure-ai-openai/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "java", "TagPrefix": "java/openai/azure-ai-openai", - "Tag": "java/openai/azure-ai-openai_2a6e71fe2e" + "Tag": "java/openai/azure-ai-openai_9fc7970110" } diff --git a/sdk/openai/azure-ai-openai/pom.xml b/sdk/openai/azure-ai-openai/pom.xml index 67bc4ea1abb3..2cd7e4b75087 100644 --- a/sdk/openai/azure-ai-openai/pom.xml +++ b/sdk/openai/azure-ai-openai/pom.xml @@ -50,6 +50,7 @@ --add-exports com.azure.core/com.azure.core.implementation.util=ALL-UNNAMED --add-opens com.azure.ai.openai/com.azure.ai.openai=ALL-UNNAMED --add-opens com.azure.ai.openai/com.azure.ai.openai.implementation=com.fasterxml.jackson.databind + --add-opens com.azure.ai.openai/com.azure.ai.openai.functions=com.fasterxml.jackson.databind true diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIServiceVersion.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIServiceVersion.java index 8d15584caf4d..6a0186a342c1 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIServiceVersion.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIServiceVersion.java @@ -15,7 +15,10 @@ public enum OpenAIServiceVersion implements ServiceVersion { V2023_05_15("2023-05-15"), /** Enum value 2023-06-01-preview. */ - V2023_06_01_PREVIEW("2023-06-01-preview"); + V2023_06_01_PREVIEW("2023-06-01-preview"), + + /** Enum value 2023-07-01-preview. */ + V2023_07_01_PREVIEW("2023-07-01-preview"); private final String version; @@ -35,6 +38,6 @@ public String getVersion() { * @return The latest {@link OpenAIServiceVersion}. */ public static OpenAIServiceVersion getLatest() { - return V2023_06_01_PREVIEW; + return V2023_07_01_PREVIEW; } } diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java index dcbcdb760166..5ecd55ec21b3 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/NonAzureOpenAIClientImpl.java @@ -3,9 +3,6 @@ package com.azure.ai.openai.implementation; -import com.azure.ai.openai.models.ChatCompletionsOptions; -import com.azure.ai.openai.models.CompletionsOptions; -import com.azure.ai.openai.models.EmbeddingsOptions; import com.azure.core.annotation.BodyParam; import com.azure.core.annotation.ExpectedResponses; import com.azure.core.annotation.HeaderParam; @@ -28,8 +25,14 @@ import com.azure.core.util.Context; import com.azure.core.util.FluxUtil; import com.azure.core.util.serializer.SerializerAdapter; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import reactor.core.publisher.Mono; +import java.nio.charset.StandardCharsets; + /** * Implementation for calling Non-Azure OpenAI Service */ @@ -66,6 +69,11 @@ public SerializerAdapter getSerializerAdapter() { */ public static final String OPEN_AI_ENDPOINT = "https://api.openai.com/v1"; + /** + * Mapper used to add the `modelId` into the request body for an nonAzure OpenAI request + */ + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); + /** * Initializes an instance of OpenAIClient client. * @@ -289,20 +297,20 @@ public Mono> getEmbeddingsWithResponseAsync(String modelId, BinaryData embeddingsOptions, RequestOptions requestOptions) { final String accept = "application/json"; - // OpenAI has model ID in request body - BinaryData embeddingsOptionsUpdated = BinaryData.fromObject( - embeddingsOptions.toObject(EmbeddingsOptions.class) - .setModel(modelId) - ); - - return FluxUtil.withContext( - context -> - service.getEmbeddings( - OPEN_AI_ENDPOINT, - accept, - embeddingsOptionsUpdated, - requestOptions, - context)); + // modelId is part of the request body in nonAzure OpenAI + try { + BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId); + return FluxUtil.withContext( + context -> + service.getEmbeddings( + OPEN_AI_ENDPOINT, + accept, + embeddingsOptionsUpdated, + requestOptions, + context)); + } catch (JsonProcessingException e) { + return Mono.error(e); + } } /** @@ -357,18 +365,18 @@ public Response getEmbeddingsWithResponse(String modelId, BinaryData RequestOptions requestOptions) { final String accept = "application/json"; - // OpenAI has model ID in request body - BinaryData embeddingsOptionsUpdated = BinaryData.fromObject( - embeddingsOptions.toObject(EmbeddingsOptions.class) - .setModel(modelId) - ); - - return service.getEmbeddingsSync( - OPEN_AI_ENDPOINT, - accept, - embeddingsOptionsUpdated, - requestOptions, - Context.NONE); + // modelId is part of the request body in nonAzure OpenAI + try { + BinaryData embeddingsOptionsUpdated = addModelIdJson(embeddingsOptions, modelId); + return service.getEmbeddingsSync( + OPEN_AI_ENDPOINT, + accept, + embeddingsOptionsUpdated, + requestOptions, + Context.NONE); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } } /** @@ -458,20 +466,20 @@ public Mono> getCompletionsWithResponseAsync(String modelId BinaryData completionsOptions, RequestOptions requestOptions) { final String accept = "application/json"; - // OpenAI has model ID in request body - BinaryData completionsOptionsUpdated = BinaryData.fromObject( - completionsOptions.toObject(CompletionsOptions.class) - .setModel(modelId) - ); - - return FluxUtil.withContext( - context -> - service.getCompletions( - OPEN_AI_ENDPOINT, - accept, - completionsOptionsUpdated, - requestOptions, - context)); + // modelId is part of the request body in nonAzure OpenAI + try { + BinaryData completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId); + return FluxUtil.withContext( + context -> + service.getCompletions( + OPEN_AI_ENDPOINT, + accept, + completionsOptionsUpdated, + requestOptions, + context)); + } catch (JsonProcessingException e) { + return Mono.error(e); + } } /** @@ -559,11 +567,14 @@ public Response getCompletionsWithResponse(String modelId, BinaryDat RequestOptions requestOptions) { final String accept = "application/json"; - // OpenAI has model ID in request body - BinaryData completionsOptionsUpdated = BinaryData.fromObject( - completionsOptions.toObject(CompletionsOptions.class) - .setModel(modelId) - ); + // modelId is part of the request body in nonAzure OpenAI + BinaryData completionsOptionsUpdated = null; + try { + completionsOptionsUpdated = addModelIdJson(completionsOptions, modelId); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + return service.getCompletionsSync( OPEN_AI_ENDPOINT, accept, @@ -650,20 +661,20 @@ public Mono> getChatCompletionsWithResponseAsync(String mod BinaryData chatCompletionsOptions, RequestOptions requestOptions) { final String accept = "application/json"; - // OpenAI has model ID in request body - BinaryData chatCompletionsOptionsUpdated = BinaryData.fromObject( - chatCompletionsOptions.toObject(ChatCompletionsOptions.class) - .setModel(modelId) - ); - - return FluxUtil.withContext( - context -> - service.getChatCompletions( - OPEN_AI_ENDPOINT, - accept, - chatCompletionsOptionsUpdated, - requestOptions, - context)); + // modelId is part of the request body in nonAzure OpenAI + try { + BinaryData chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId); + return FluxUtil.withContext( + context -> + service.getChatCompletions( + OPEN_AI_ENDPOINT, + accept, + chatCompletionsOptionsUpdated, + requestOptions, + context)); + } catch (JsonProcessingException e) { + return Mono.error(e); + } } /** @@ -743,11 +754,13 @@ public Response getChatCompletionsWithResponse(String modelId, Binar RequestOptions requestOptions) { final String accept = "application/json"; - // OpenAI has model ID in request body - BinaryData chatCompletionsOptionsUpdated = BinaryData.fromObject( - chatCompletionsOptions.toObject(ChatCompletionsOptions.class) - .setModel(modelId) - ); + // modelId is part of the request body in nonAzure OpenAI + BinaryData chatCompletionsOptionsUpdated = null; + try { + chatCompletionsOptionsUpdated = addModelIdJson(chatCompletionsOptions, modelId); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } return service.getChatCompletionsSync( OPEN_AI_ENDPOINT, @@ -870,4 +883,26 @@ public Response generateImageWithResponse( Context.NONE ); } + + /** + * This method injects the modelId in the request body for requests against nonAzure OpenAI. Unlike Azure OpenAI, + * the service expects this value in the body of the request, whereas Azure OpenAI passes it as part of the + * path of the request. + * + * @param inputJson JSON submitted by the client + * @param modelId The LLM model ID to be injected in the JSON + * @return + */ + private static BinaryData addModelIdJson(BinaryData inputJson, String modelId) throws JsonProcessingException { + JsonNode jsonNode = JSON_MAPPER.readTree(inputJson.toString()); + if (jsonNode instanceof ObjectNode) { + ObjectNode objectNode = (ObjectNode) jsonNode; + objectNode.put("model", modelId); + inputJson = BinaryData.fromBytes( + objectNode.toString() + .getBytes(StandardCharsets.UTF_8)); + } + + return inputJson; + } } diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/OpenAIClientImpl.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/OpenAIClientImpl.java index 3603ed471e2a..6ada94f34ce7 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/OpenAIClientImpl.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/OpenAIClientImpl.java @@ -548,7 +548,7 @@ public Response getEmbeddingsWithResponse( * int (Required) * ] * } - * finish_reason: String(stop/length/content_filter) (Required) + * finish_reason: String(stop/length/content_filter/function_call) (Required) * } * ] * usage (Required): { @@ -646,7 +646,7 @@ public Mono> getCompletionsWithResponseAsync( * int (Required) * ] * } - * finish_reason: String(stop/length/content_filter) (Required) + * finish_reason: String(stop/length/content_filter/function_call) (Required) * } * ] * usage (Required): { @@ -693,10 +693,23 @@ public Response getCompletionsWithResponse( * { * messages (Required): [ * (Required){ - * role: String(system/assistant/user) (Required) + * role: String(system/assistant/user/function) (Required) * content: String (Optional) + * name: String (Optional) + * function_call (Optional): { + * name: String (Required) + * arguments: String (Required) + * } + * } + * ] + * functions (Optional): [ + * (Optional){ + * name: String (Required) + * description: String (Optional) + * parameters: Object (Optional) * } * ] + * function_call: FunctionCallModelBase (Optional) * max_tokens: Integer (Optional) * temperature: Double (Optional) * top_p: Double (Optional) @@ -724,11 +737,16 @@ public Response getCompletionsWithResponse( * choices (Required): [ * (Required){ * message (Optional): { - * role: String(system/assistant/user) (Required) + * role: String(system/assistant/user/function) (Required) * content: String (Optional) + * name: String (Optional) + * function_call (Optional): { + * name: String (Required) + * arguments: String (Required) + * } * } * index: int (Required) - * finish_reason: String(stop/length/content_filter) (Required) + * finish_reason: String(stop/length/content_filter/function_call) (Required) * delta (Optional): (recursive schema, see delta above) * } * ] @@ -779,10 +797,23 @@ public Mono> getChatCompletionsWithResponseAsync( * { * messages (Required): [ * (Required){ - * role: String(system/assistant/user) (Required) + * role: String(system/assistant/user/function) (Required) * content: String (Optional) + * name: String (Optional) + * function_call (Optional): { + * name: String (Required) + * arguments: String (Required) + * } + * } + * ] + * functions (Optional): [ + * (Optional){ + * name: String (Required) + * description: String (Optional) + * parameters: Object (Optional) * } * ] + * function_call: FunctionCallModelBase (Optional) * max_tokens: Integer (Optional) * temperature: Double (Optional) * top_p: Double (Optional) @@ -810,11 +841,16 @@ public Mono> getChatCompletionsWithResponseAsync( * choices (Required): [ * (Required){ * message (Optional): { - * role: String(system/assistant/user) (Required) + * role: String(system/assistant/user/function) (Required) * content: String (Optional) + * name: String (Optional) + * function_call (Optional): { + * name: String (Required) + * arguments: String (Required) + * } * } * index: int (Required) - * finish_reason: String(stop/length/content_filter) (Required) + * finish_reason: String(stop/length/content_filter/function_call) (Required) * delta (Optional): (recursive schema, see delta above) * } * ] diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallModelBase.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallModelBase.java new file mode 100644 index 000000000000..f8bbf6f7ebf1 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallModelBase.java @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. + +package com.azure.ai.openai.implementation.models; + +/** The FunctionCallModelBase model. */ +public abstract class FunctionCallModelBase { + /** Creates an instance of FunctionCallModelBase class. */ + protected FunctionCallModelBase() {} +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallPreset.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallPreset.java new file mode 100644 index 000000000000..46bc254473ae --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallPreset.java @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. +package com.azure.ai.openai.implementation.models; + +import com.azure.core.annotation.Generated; +import com.azure.core.util.ExpandableStringEnum; +import com.fasterxml.jackson.annotation.JsonCreator; +import java.util.Collection; + +/** + * The collection of predefined behaviors for handling request-provided function information in a chat completions + * operation. + */ +public final class FunctionCallPreset extends ExpandableStringEnum { + + /** + * Specifies that the model may either use any of the functions provided in this chat completions request or instead + * return a standard chat completions response as if no functions were provided. + */ + @Generated public static final FunctionCallPreset AUTO = fromString("auto"); + + /** + * Specifies that the model should not respond with a function call and should instead provide a standard chat + * completions response. Response content may still be influenced by the provided function information. + */ + @Generated public static final FunctionCallPreset NONE = fromString("none"); + + /** + * Creates a new instance of FunctionCallPreset value. + * + * @deprecated Use the {@link #fromString(String)} factory method. + */ + @Generated + @Deprecated + public FunctionCallPreset() {} + + /** + * Creates or finds a FunctionCallPreset from its string representation. + * + * @param name a name to look for. + * @return the corresponding FunctionCallPreset. + */ + @Generated + @JsonCreator + public static FunctionCallPreset fromString(String name) { + return fromString(name, FunctionCallPreset.class); + } + + /** + * Gets known FunctionCallPreset values. + * + * @return known FunctionCallPreset values. + */ + @Generated + public static Collection values() { + return values(FunctionCallPreset.class); + } +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallPresetFunctionCallModel.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallPresetFunctionCallModel.java new file mode 100644 index 000000000000..2a4bad739d8d --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionCallPresetFunctionCallModel.java @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. + +package com.azure.ai.openai.implementation.models; + +import com.azure.core.annotation.Immutable; +import com.fasterxml.jackson.annotation.JsonValue; + +/** The FunctionCallPresetFunctionCallModel model. */ +@Immutable +public final class FunctionCallPresetFunctionCallModel extends FunctionCallModelBase { + private final FunctionCallPreset value; + + /** + * Creates an instance of FunctionCallPresetFunctionCallModel class. + * + * @param value the value. + */ + public FunctionCallPresetFunctionCallModel(FunctionCallPreset value) { + this.value = value; + } + + /** + * Gets the value. + * + * @return the value. + */ + @JsonValue + public FunctionCallPreset getValue() { + return this.value; + } +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionDefinition.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionDefinition.java new file mode 100644 index 000000000000..3d4ffcc5d08b --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionDefinition.java @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. +package com.azure.ai.openai.implementation.models; + +import com.azure.core.annotation.Fluent; +import com.azure.core.annotation.Generated; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * The definition of a caller-specified function that chat completions may invoke in response to matching user input. + */ +@Fluent +public final class FunctionDefinition { + + /* + * The name of the function to be called. + */ + @Generated + @JsonProperty(value = "name") + private String name; + + /* + * A description of what the function does. The model will use this description when selecting the function and + * interpreting its parameters. + */ + @Generated + @JsonProperty(value = "description") + private String description; + + /* + * The parameters the functions accepts, described as a JSON Schema object. + */ + @Generated + @JsonProperty(value = "parameters") + private Object parameters; + + /** + * Creates an instance of FunctionDefinition class. + * + * @param name the name value to set. + */ + @Generated + @JsonCreator + public FunctionDefinition(@JsonProperty(value = "name") String name) { + this.name = name; + } + + /** + * Get the name property: The name of the function to be called. + * + * @return the name value. + */ + @Generated + public String getName() { + return this.name; + } + + /** + * Get the description property: A description of what the function does. The model will use this description when + * selecting the function and interpreting its parameters. + * + * @return the description value. + */ + @Generated + public String getDescription() { + return this.description; + } + + /** + * Set the description property: A description of what the function does. The model will use this description when + * selecting the function and interpreting its parameters. + * + * @param description the description value to set. + * @return the FunctionDefinition object itself. + */ + @Generated + public FunctionDefinition setDescription(String description) { + this.description = description; + return this; + } + + /** + * Get the parameters property: The parameters the functions accepts, described as a JSON Schema object. + * + * @return the parameters value. + */ + @Generated + public Object getParameters() { + return this.parameters; + } + + /** + * Set the parameters property: The parameters the functions accepts, described as a JSON Schema object. + * + * @param parameters the parameters value to set. + * @return the FunctionDefinition object itself. + */ + @Generated + public FunctionDefinition setParameters(Object parameters) { + this.parameters = parameters; + return this; + } +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionNameFunctionCallModel.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionNameFunctionCallModel.java new file mode 100644 index 000000000000..14f61cf22412 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/FunctionNameFunctionCallModel.java @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. + +package com.azure.ai.openai.implementation.models; + +import com.azure.ai.openai.models.FunctionName; +import com.azure.core.annotation.Immutable; +import com.fasterxml.jackson.annotation.JsonValue; + +/** The FunctionNameFunctionCallModel model. */ +@Immutable +public final class FunctionNameFunctionCallModel extends FunctionCallModelBase { + private final FunctionName value; + + /** + * Creates an instance of FunctionNameFunctionCallModel class. + * + * @param value the value. + */ + public FunctionNameFunctionCallModel(FunctionName value) { + this.value = value; + } + + /** + * Gets the value. + * + * @return the value. + */ + @JsonValue + public FunctionName getValue() { + return this.value; + } +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/package-info.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/package-info.java new file mode 100644 index 000000000000..26eb2ce2a673 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/implementation/models/package-info.java @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. + +/** Package containing the data models for OpenAI. Azure OpenAI APIs for completions and search. */ +package com.azure.ai.openai.implementation.models; diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatCompletionsOptions.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatCompletionsOptions.java index 7df28d9666c7..02f714199d42 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatCompletionsOptions.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatCompletionsOptions.java @@ -3,9 +3,15 @@ // Code generated by Microsoft (R) AutoRest Code Generator. package com.azure.ai.openai.models; +import com.azure.ai.openai.implementation.models.FunctionCallModelBase; +import com.azure.ai.openai.implementation.models.FunctionCallPreset; +import com.azure.ai.openai.implementation.models.FunctionCallPresetFunctionCallModel; +import com.azure.ai.openai.implementation.models.FunctionDefinition; +import com.azure.ai.openai.implementation.models.FunctionNameFunctionCallModel; import com.azure.core.annotation.Fluent; import com.azure.core.annotation.Generated; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; import java.util.Map; @@ -429,4 +435,108 @@ public ChatCompletionsOptions setModel(String model) { this.model = model; return this; } + + /* + * A list of functions the model may generate JSON inputs for. + */ + @Generated + @JsonProperty(value = "functions") + private List functions; + + /* + * Controls how the model responds to function calls. "none" means the model does not call a function, + * and responds to the end-user. "auto" means the model can pick between an end-user or calling a function. + * Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + * "none" is the default when no functions are present. "auto" is the default if functions are present. + */ + @Generated + @JsonProperty(value = "function_call") + private FunctionCallModelBase functionCall; + + /* + * Field not used for serialization. This is a convenience helper field for the serialization of "function_call". + */ + @JsonIgnore private FunctionCallConfig functionCallConfig; + + /** + * Get the functions property: A list of functions the model may generate JSON inputs for. + * + * @return the functions value. + */ + @Generated + public List getFunctions() { + return this.functions; + } + + /** + * Set the functions property: A list of functions the model may generate JSON inputs for. + * + * @param functions the functions value to set. + * @return the ChatCompletionsOptions object itself. + */ + @Generated + public ChatCompletionsOptions setFunctions(List functions) { + this.functions = functions; + return this; + } + + /** + * Get the functionCall property: Controls how the model responds to function calls. "none" means the model does not + * call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a + * function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + * "none" is the default when no functions are present. "auto" is the default if functions are present. + * + * @return the functionCall value. + */ + FunctionCallModelBase getFunctionCallInternal() { + return this.functionCall; + } + + /** + * Set the functionCall property: Controls how the model responds to function calls. "none" means the model does not + * call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a + * function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + * "none" is the default when no functions are present. "auto" is the default if functions are present. + * + * @param functionCall the functionCall value to set. + * @return the ChatCompletionsOptions object itself. + */ + ChatCompletionsOptions setFunctionCallInternal(FunctionCallModelBase functionCall) { + this.functionCall = functionCall; + return this; + } + + /** + * Get the functionCall property: Controls how the model responds to function calls. "none" means the model does not + * call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a + * function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + * "none" is the default when no functions are present. "auto" is the default if functions are present. + * + * @return the functionCall value. + */ + public FunctionCallConfig getFunctionCall() { + return this.functionCallConfig; + } + + /** + * Set the functionCall property: Controls how the model responds to function calls. "none" means the model does not + * call a function, and responds to the end-user. "auto" means the model can pick between an end-user or calling a + * function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + * "none" is the default when no functions are present. "auto" is the default if functions are present. + * + * @param functionCallConfig the functionCall value to set. + * @return the ChatCompletionsOptions object itself. + */ + public ChatCompletionsOptions setFunctionCall(FunctionCallConfig functionCallConfig) { + this.functionCallConfig = functionCallConfig; + if (FunctionCallPreset.values().stream() + .anyMatch(preset -> preset.toString().equals(functionCallConfig.getName()))) { + this.functionCall = + new FunctionCallPresetFunctionCallModel( + FunctionCallPreset.fromString(this.functionCallConfig.getName())); + } else { + this.functionCall = new FunctionNameFunctionCallModel(new FunctionName(this.functionCallConfig.getName())); + } + return this; + } } diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatMessage.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatMessage.java index 2ad4cfd18b97..2594743f63f7 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatMessage.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatMessage.java @@ -68,4 +68,72 @@ public ChatMessage setContent(String content) { this.content = content; return this; } + + /* + * The name of the author of this message. `name` is required if role is `function`, and it should be the name of + * the + * function whose response is in the `content`. May contain a-z, A-Z, 0-9, and underscores, with a maximum length + * of + * 64 characters. + */ + @Generated + @JsonProperty(value = "name") + private String name; + + /* + * The name and arguments of a function that should be called, as generated by the model. + */ + @Generated + @JsonProperty(value = "function_call") + private FunctionCall functionCall; + + /** + * Get the name property: The name of the author of this message. `name` is required if role is `function`, and it + * should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and + * underscores, with a maximum length of 64 characters. + * + * @return the name value. + */ + @Generated + public String getName() { + return this.name; + } + + /** + * Set the name property: The name of the author of this message. `name` is required if role is `function`, and it + * should be the name of the function whose response is in the `content`. May contain a-z, A-Z, 0-9, and + * underscores, with a maximum length of 64 characters. + * + * @param name the name value to set. + * @return the ChatMessage object itself. + */ + @Generated + public ChatMessage setName(String name) { + this.name = name; + return this; + } + + /** + * Get the functionCall property: The name and arguments of a function that should be called, as generated by the + * model. + * + * @return the functionCall value. + */ + @Generated + public FunctionCall getFunctionCall() { + return this.functionCall; + } + + /** + * Set the functionCall property: The name and arguments of a function that should be called, as generated by the + * model. + * + * @param functionCall the functionCall value to set. + * @return the ChatMessage object itself. + */ + @Generated + public ChatMessage setFunctionCall(FunctionCall functionCall) { + this.functionCall = functionCall; + return this; + } } diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatRole.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatRole.java index dcf8fe55f839..621ea748768e 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatRole.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/ChatRole.java @@ -50,4 +50,7 @@ public static ChatRole fromString(String name) { public static Collection values() { return values(ChatRole.class); } + + /** The role that provides function results for char completions. */ + @Generated public static final ChatRole FUNCTION = fromString("function"); } diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/CompletionsFinishReason.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/CompletionsFinishReason.java index c267098eca5b..b16c05f00c3f 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/CompletionsFinishReason.java +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/CompletionsFinishReason.java @@ -52,4 +52,7 @@ public static CompletionsFinishReason fromString(String name) { public static Collection values() { return values(CompletionsFinishReason.class); } + + /** Completion ended normally, with the model requesting a function to be called. */ + @Generated public static final CompletionsFinishReason FUNCTION_CALL = fromString("function_call"); } diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionCall.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionCall.java new file mode 100644 index 000000000000..f03fa51fdb66 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionCall.java @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. +package com.azure.ai.openai.models; + +import com.azure.core.annotation.Generated; +import com.azure.core.annotation.Immutable; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** The name and arguments of a function that should be called, as generated by the model. */ +@Immutable +public final class FunctionCall { + + /* + * The name of the function to call. + */ + @Generated + @JsonProperty(value = "name") + private String name; + + /* + * The arguments to call the function with, as generated by the model in JSON format. + * Note that the model does not always generate valid JSON, and may hallucinate parameters + * not defined by your function schema. Validate the arguments in your code before calling + * your function. + */ + @Generated + @JsonProperty(value = "arguments") + private String arguments; + + /** + * Creates an instance of FunctionCall class. + * + * @param name the name value to set. + * @param arguments the arguments value to set. + */ + @Generated + @JsonCreator + public FunctionCall( + @JsonProperty(value = "name") String name, @JsonProperty(value = "arguments") String arguments) { + this.name = name; + this.arguments = arguments; + } + + /** + * Get the name property: The name of the function to call. + * + * @return the name value. + */ + @Generated + public String getName() { + return this.name; + } + + /** + * Get the arguments property: The arguments to call the function with, as generated by the model in JSON format. + * Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your + * function schema. Validate the arguments in your code before calling your function. + * + * @return the arguments value. + */ + @Generated + public String getArguments() { + return this.arguments; + } +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionCallConfig.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionCallConfig.java new file mode 100644 index 000000000000..ac1c713931b8 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionCallConfig.java @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. + +package com.azure.ai.openai.models; + +/** The name and arguments of a function that should be called, as generated by the model. */ +public class FunctionCallConfig { + + /** + * The name of the function to call. + */ + private final String name; + + /** + * AUTO will indicate the service to call any functions that are necessary for text completion generation. + */ + public static final FunctionCallConfig AUTO = new FunctionCallConfig("auto"); + + /** + * NONE will indicate the service to not call nay of the functions that may have been provided with the request for + * text completion generation. + */ + public static final FunctionCallConfig NONE = new FunctionCallConfig("none"); + + /** + * Creates an instance of FunctionCall class. + * + * @param name the name value to set. + */ + public FunctionCallConfig(String name) { + this.name = name; + } + + /** + * Get the name property: The name of the function to call. + * + * @return the name value. + */ + public String getName() { + return this.name; + } +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionName.java b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionName.java new file mode 100644 index 000000000000..82b2abb8eff4 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/models/FunctionName.java @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// Code generated by Microsoft (R) AutoRest Code Generator. +package com.azure.ai.openai.models; + +import com.azure.core.annotation.Generated; +import com.azure.core.annotation.Immutable; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A structure that specifies the exact name of a specific, request-provided function to use when processing a chat + * completions operation. + */ +@Immutable +public final class FunctionName { + + /* + * The name of the function to call. + */ + @Generated + @JsonProperty(value = "name") + private String name; + + /** + * Creates an instance of FunctionName class. + * + * @param name the name value to set. + */ + @Generated + @JsonCreator + public FunctionName(@JsonProperty(value = "name") String name) { + this.name = name; + } + + /** + * Get the name property: The name of the function to call. + * + * @return the name value. + */ + @Generated + public String getName() { + return this.name; + } +} diff --git a/sdk/openai/azure-ai-openai/src/main/java/module-info.java b/sdk/openai/azure-ai-openai/src/main/java/module-info.java index 590e8b4f6dfd..07571db1c578 100644 --- a/sdk/openai/azure-ai-openai/src/main/java/module-info.java +++ b/sdk/openai/azure-ai-openai/src/main/java/module-info.java @@ -8,8 +8,15 @@ exports com.azure.ai.openai; exports com.azure.ai.openai.models; + exports com.azure.ai.openai.implementation.models; opens com.azure.ai.openai.models to com.azure.core, com.fasterxml.jackson.databind; + opens com.azure.ai.openai.implementation.models to + com.azure.core, + com.fasterxml.jackson.databind; + opens com.azure.ai.openai.implementation to + com.azure.core, + com.fasterxml.jackson.databind; } diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAIAsyncClientTest.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAIAsyncClientTest.java index 16e3a92fcedb..9b0987d45af1 100644 --- a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAIAsyncClientTest.java +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAIAsyncClientTest.java @@ -3,12 +3,16 @@ package com.azure.ai.openai; +import com.azure.ai.openai.functions.MyFunctionCallArguments; +import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatRole; import com.azure.ai.openai.models.Completions; import com.azure.ai.openai.models.CompletionsOptions; import com.azure.ai.openai.models.CompletionsUsage; import com.azure.ai.openai.models.Embeddings; +import com.azure.ai.openai.models.FunctionCallConfig; import com.azure.core.exception.ClientAuthenticationException; import com.azure.core.exception.HttpResponseException; import com.azure.core.http.HttpClient; @@ -240,4 +244,56 @@ public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion servic .assertNext(OpenAIClientTestBase::assertImageResponse) .verifyComplete()); } + + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionAutoPreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getNonAzureOpenAIAsyncClient(httpClient); + getChatFunctionForNonAzureRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.AUTO); + StepVerifier.create(client.getChatCompletions(modelId, chatCompletionsOptions)) + .assertNext(chatCompletions -> { + assertEquals(1, chatCompletions.getChoices().size()); + ChatChoice chatChoice = chatCompletions.getChoices().get(0); + MyFunctionCallArguments arguments = assertFunctionCall( + chatChoice, + "MyFunction", + MyFunctionCallArguments.class); + assertEquals(arguments.getLocation(), "San Francisco, CA"); + assertEquals(arguments.getUnit(), "CELSIUS"); + }) + .verifyComplete(); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNonePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getNonAzureOpenAIAsyncClient(httpClient); + getChatFunctionForNonAzureRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.NONE); + StepVerifier.create(client.getChatCompletions(modelId, chatCompletionsOptions)) + .assertNext(chatCompletions -> { + assertChatCompletions(1, "stop", ChatRole.ASSISTANT, chatCompletions); + }) + .verifyComplete(); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNotSuppliedByNamePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getNonAzureOpenAIAsyncClient(httpClient); + getChatFunctionForNonAzureRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(new FunctionCallConfig("NotMyFunction")); + StepVerifier.create(client.getChatCompletions(modelId, chatCompletionsOptions)) + .verifyErrorSatisfies(throwable -> { + assertInstanceOf(HttpResponseException.class, throwable); + HttpResponseException httpResponseException = (HttpResponseException) throwable; + assertEquals(400, httpResponseException.getResponse().getStatusCode()); + assertTrue(httpResponseException.getMessage().contains("Invalid value for 'function_call'")); + }); + }); + } } diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAISyncClientTest.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAISyncClientTest.java index 51d10cabd27a..da35b3dcd071 100644 --- a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAISyncClientTest.java +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/NonAzureOpenAISyncClientTest.java @@ -3,12 +3,16 @@ package com.azure.ai.openai; +import com.azure.ai.openai.functions.MyFunctionCallArguments; +import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatRole; import com.azure.ai.openai.models.Completions; import com.azure.ai.openai.models.CompletionsOptions; import com.azure.ai.openai.models.CompletionsUsage; import com.azure.ai.openai.models.Embeddings; +import com.azure.ai.openai.models.FunctionCallConfig; import com.azure.core.exception.ClientAuthenticationException; import com.azure.core.exception.HttpResponseException; import com.azure.core.http.HttpClient; @@ -21,6 +25,7 @@ import static com.azure.ai.openai.TestUtils.DISPLAY_NAME_WITH_ARGUMENTS; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -194,4 +199,50 @@ public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion servic client = getNonAzureOpenAISyncClient(httpClient); getImageGenerationRunner(options -> assertImageResponse(client.generateImage(options))); } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionAutoPreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getNonAzureOpenAISyncClient(httpClient); + getChatFunctionForNonAzureRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.AUTO); + ChatCompletions chatCompletions = client.getChatCompletions(modelId, chatCompletionsOptions); + + assertEquals(1, chatCompletions.getChoices().size()); + ChatChoice chatChoice = chatCompletions.getChoices().get(0); + MyFunctionCallArguments arguments = assertFunctionCall( + chatChoice, + "MyFunction", + MyFunctionCallArguments.class); + assertEquals(arguments.getLocation(), "San Francisco, CA"); + assertEquals(arguments.getUnit(), "CELSIUS"); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNonePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getNonAzureOpenAISyncClient(httpClient); + getChatFunctionForNonAzureRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.NONE); + ChatCompletions chatCompletions = client.getChatCompletions(modelId, chatCompletionsOptions); + + assertChatCompletions(1, "stop", ChatRole.ASSISTANT, chatCompletions); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNotSuppliedByNamePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getNonAzureOpenAISyncClient(httpClient); + getChatFunctionForNonAzureRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(new FunctionCallConfig("NotMyFunction")); + HttpResponseException exception = assertThrows(HttpResponseException.class, + () -> client.getChatCompletions(modelId, chatCompletionsOptions)); + assertEquals(400, exception.getResponse().getStatusCode()); + + assertInstanceOf(HttpResponseException.class, exception); + assertTrue(exception.getMessage().contains("Invalid value for 'function_call'")); + }); + } } diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIAsyncClientTest.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIAsyncClientTest.java index aa69e69e08c8..f417ca1a8d33 100644 --- a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIAsyncClientTest.java +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIAsyncClientTest.java @@ -3,12 +3,17 @@ package com.azure.ai.openai; +import com.azure.ai.openai.functions.MyFunctionCallArguments; +import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatRole; import com.azure.ai.openai.models.Completions; import com.azure.ai.openai.models.CompletionsOptions; import com.azure.ai.openai.models.CompletionsUsage; import com.azure.ai.openai.models.Embeddings; +import com.azure.ai.openai.models.FunctionCallConfig; +import com.azure.core.exception.HttpResponseException; import com.azure.core.exception.ResourceNotFoundException; import com.azure.core.http.HttpClient; import com.azure.core.http.rest.RequestOptions; @@ -51,8 +56,7 @@ public void testGetCompletions(HttpClient httpClient, OpenAIServiceVersion servi @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") public void testGetCompletionsStream(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { - // TODO: use the parameterized serviceVersion once we have a service release that responds with valid data - client = getOpenAIAsyncClient(httpClient, OpenAIServiceVersion.V2023_05_15); + client = getOpenAIAsyncClient(httpClient, serviceVersion); getCompletionsRunner((deploymentId, prompt) -> { StepVerifier.create(client.getCompletionsStream(deploymentId, new CompletionsOptions(prompt))) .recordWith(ArrayList::new) @@ -161,8 +165,7 @@ public void testGetChatCompletions(HttpClient httpClient, OpenAIServiceVersion s @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") public void testGetChatCompletionsStream(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { - // TODO: use the parameterized serviceVersion once we have a service release that responds with valid data - client = getOpenAIAsyncClient(httpClient, OpenAIServiceVersion.V2023_05_15); + client = getOpenAIAsyncClient(httpClient, serviceVersion); getChatCompletionsRunner((deploymentId, chatMessages) -> { StepVerifier.create(client.getChatCompletionsStream(deploymentId, new ChatCompletionsOptions(chatMessages))) .recordWith(ArrayList::new) @@ -227,4 +230,55 @@ public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion servic .assertNext(OpenAIClientTestBase::assertImageResponse) .verifyComplete()); } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionAutoPreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getOpenAIAsyncClient(httpClient, serviceVersion); + getChatFunctionForRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.AUTO); + StepVerifier.create(client.getChatCompletions(modelId, chatCompletionsOptions)) + .assertNext(chatCompletions -> { + assertEquals(1, chatCompletions.getChoices().size()); + ChatChoice chatChoice = chatCompletions.getChoices().get(0); + MyFunctionCallArguments arguments = assertFunctionCall( + chatChoice, + "MyFunction", + MyFunctionCallArguments.class); + assertEquals(arguments.getLocation(), "San Francisco, CA"); + assertEquals(arguments.getUnit(), "CELSIUS"); + }) + .verifyComplete(); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNonePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getOpenAIAsyncClient(httpClient, serviceVersion); + getChatFunctionForRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.NONE); + StepVerifier.create(client.getChatCompletions(modelId, chatCompletionsOptions)) + .assertNext(chatCompletions -> { + assertChatCompletions(1, "stop", ChatRole.ASSISTANT, chatCompletions); + }) + .verifyComplete(); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNotSuppliedByNamePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getOpenAIAsyncClient(httpClient, serviceVersion); + getChatFunctionForRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(new FunctionCallConfig("NotMyFunction")); + StepVerifier.create(client.getChatCompletions(modelId, chatCompletionsOptions)) + .verifyErrorSatisfies(throwable -> { + assertInstanceOf(HttpResponseException.class, throwable); + HttpResponseException httpResponseException = (HttpResponseException) throwable; + assertEquals(400, httpResponseException.getResponse().getStatusCode()); + assertTrue(httpResponseException.getMessage().contains("Invalid value for 'function_call'")); + }); + }); + } } diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientTestBase.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientTestBase.java index 2075f5938603..489c418e3953 100644 --- a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientTestBase.java +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAIClientTestBase.java @@ -4,8 +4,11 @@ package com.azure.ai.openai; +import com.azure.ai.openai.functions.Parameters; +import com.azure.ai.openai.models.FunctionCall; import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; +import com.azure.ai.openai.models.ChatCompletionsOptions; import com.azure.ai.openai.models.ChatMessage; import com.azure.ai.openai.models.ChatRole; import com.azure.ai.openai.models.Choice; @@ -13,6 +16,7 @@ import com.azure.ai.openai.models.EmbeddingItem; import com.azure.ai.openai.models.Embeddings; import com.azure.ai.openai.models.EmbeddingsOptions; +import com.azure.ai.openai.implementation.models.FunctionDefinition; import com.azure.ai.openai.models.ImageGenerationOptions; import com.azure.ai.openai.models.ImageResponse; import com.azure.ai.openai.models.NonAzureOpenAIKeyCredential; @@ -134,6 +138,14 @@ void getImageGenerationRunner(Consumer testRunner) { ); } + void getChatFunctionForNonAzureRunner(BiConsumer testRunner) { + testRunner.accept("gpt-3.5-turbo-0613", getChatMessagesWithFunction()); + } + + void getChatFunctionForRunner(BiConsumer testRunner) { + testRunner.accept("gpt-4", getChatMessagesWithFunction()); + } + private List getChatMessages() { List chatMessages = new ArrayList<>(); chatMessages.add(new ChatMessage(ChatRole.SYSTEM).setContent("You are a helpful assistant. You will talk like a pirate.")); @@ -143,6 +155,20 @@ private List getChatMessages() { return chatMessages; } + private ChatCompletionsOptions getChatMessagesWithFunction() { + FunctionDefinition functionDefinition = new FunctionDefinition("MyFunction"); + Parameters parameters = new Parameters(); + functionDefinition.setParameters(parameters); + List functions = Arrays.asList(functionDefinition); + + List chatMessages = new ArrayList<>(); + chatMessages.add(new ChatMessage(ChatRole.USER).setContent("What's the weather like in San Francisco in Celsius?")); + + ChatCompletionsOptions chatCompletionOptions = new ChatCompletionsOptions(chatMessages); + chatCompletionOptions.setFunctions(functions); + return chatCompletionOptions; + } + static void assertCompletions(int choicesPerPrompt, Completions actual) { assertCompletions(choicesPerPrompt, "stop", actual); } @@ -246,4 +272,13 @@ static void assertImageResponse(ImageResponse actual) { assertNotNull(actual.getData()); assertFalse(actual.getData().isEmpty()); } + + static T assertFunctionCall(ChatChoice actual, String functionName, Class myPropertiesClazz) { + assertEquals(0, actual.getIndex()); + assertEquals("function_call", actual.getFinishReason().toString()); + FunctionCall functionCall = actual.getMessage().getFunctionCall(); + assertEquals(functionName, functionCall.getName()); + BinaryData argumentJson = BinaryData.fromString(functionCall.getArguments()); + return argumentJson.toObject(myPropertiesClazz); + } } diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAISyncClientTest.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAISyncClientTest.java index 9579283be5ef..98f300126b4a 100644 --- a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAISyncClientTest.java +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/OpenAISyncClientTest.java @@ -3,12 +3,17 @@ package com.azure.ai.openai; +import com.azure.ai.openai.functions.MyFunctionCallArguments; +import com.azure.ai.openai.models.ChatChoice; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatRole; import com.azure.ai.openai.models.Completions; import com.azure.ai.openai.models.CompletionsOptions; import com.azure.ai.openai.models.CompletionsUsage; import com.azure.ai.openai.models.Embeddings; +import com.azure.ai.openai.models.FunctionCallConfig; +import com.azure.core.exception.HttpResponseException; import com.azure.core.exception.ResourceNotFoundException; import com.azure.core.http.HttpClient; import com.azure.core.http.rest.RequestOptions; @@ -20,6 +25,7 @@ import static com.azure.ai.openai.TestUtils.DISPLAY_NAME_WITH_ARGUMENTS; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -182,4 +188,50 @@ public void testGenerateImage(HttpClient httpClient, OpenAIServiceVersion servic client = getOpenAIClient(httpClient, serviceVersion); getImageGenerationRunner(options -> assertImageResponse(client.generateImage(options))); } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionAutoPreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getOpenAIClient(httpClient, serviceVersion); + getChatFunctionForRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.AUTO); + ChatCompletions chatCompletions = client.getChatCompletions(modelId, chatCompletionsOptions); + + assertEquals(1, chatCompletions.getChoices().size()); + ChatChoice chatChoice = chatCompletions.getChoices().get(0); + MyFunctionCallArguments arguments = assertFunctionCall( + chatChoice, + "MyFunction", + MyFunctionCallArguments.class); + assertEquals(arguments.getLocation(), "San Francisco, CA"); + assertEquals(arguments.getUnit(), "CELSIUS"); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNonePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getOpenAIClient(httpClient, serviceVersion); + getChatFunctionForRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(FunctionCallConfig.NONE); + ChatCompletions chatCompletions = client.getChatCompletions(modelId, chatCompletionsOptions); + + assertChatCompletions(1, "stop", ChatRole.ASSISTANT, chatCompletions); + }); + } + + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) + @MethodSource("com.azure.ai.openai.TestUtils#getTestParameters") + public void testChatFunctionNotSuppliedByNamePreset(HttpClient httpClient, OpenAIServiceVersion serviceVersion) { + client = getOpenAIClient(httpClient, serviceVersion); + getChatFunctionForRunner((modelId, chatCompletionsOptions) -> { + chatCompletionsOptions.setFunctionCall(new FunctionCallConfig("NotMyFunction")); + HttpResponseException exception = assertThrows(HttpResponseException.class, + () -> client.getChatCompletions(modelId, chatCompletionsOptions)); + assertEquals(400, exception.getResponse().getStatusCode()); + + assertInstanceOf(HttpResponseException.class, exception); + assertTrue(exception.getMessage().contains("Invalid value for 'function_call'")); + }); + } } diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Location.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Location.java new file mode 100644 index 000000000000..d36d19354572 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Location.java @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.ai.openai.functions; + +import com.fasterxml.jackson.annotation.JsonGetter; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSetter; + +public class Location { + + /** + * type defines the JSON type of the value that the service will request if a @FunctionCall is requested + */ + @JsonProperty(value = "type") + private String type = "string"; + + /** + * Examples provided in the description appear to be used verbatim. Such as "San Francisco, CA" + */ + @JsonProperty(value = "description") + private String description = "The city and state, e.g. San Francisco, CA"; + + @JsonGetter + public String getType() { + return type; + } + + @JsonSetter + public void setType(String type) { + this.type = type; + } + + @JsonGetter + public String getDescription() { + return description; + } + + @JsonSetter + public void setDescription(String description) { + this.description = description; + } +} diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/MyFunctionCallArguments.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/MyFunctionCallArguments.java new file mode 100644 index 000000000000..4cd3b0d19d12 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/MyFunctionCallArguments.java @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.ai.openai.functions; + +import com.fasterxml.jackson.annotation.JsonGetter; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSetter; + +public class MyFunctionCallArguments { + @JsonProperty(value = "unit") + private String unit; + + @JsonProperty(value = "location") + private String location; + + @JsonGetter + public String getUnit() { + return unit; + } + + @JsonSetter + public void setUnit(String unit) { + this.unit = unit; + } + + @JsonGetter + public String getLocation() { + return location; + } + + @JsonSetter + public void setLocation(String location) { + this.location = location; + } +} diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Parameters.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Parameters.java new file mode 100644 index 000000000000..b248c138094d --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Parameters.java @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.ai.openai.functions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonGetter; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSetter; + +public class Parameters { + + @JsonProperty(value = "type") + private String type = "object"; + + @JsonProperty(value = "properties") + private Properties properties = new Properties(); + + @JsonCreator + public Parameters( + @JsonProperty(value = "type") + String type, + @JsonProperty(value = "properties") + Properties properties + ) { + this.type = type; + this.properties = properties; + } + + @JsonCreator + public Parameters() {} + + @JsonGetter + public String getType() { + return type; + } + + @JsonSetter + public void setType(String type) { + this.type = type; + } + + @JsonGetter + public Properties getProperties() { + return properties; + } + + @JsonSetter + public void setProperties(Properties properties) { + this.properties = properties; + } +} diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Properties.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Properties.java new file mode 100644 index 000000000000..4303e531cdbd --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Properties.java @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.ai.openai.functions; + +import com.fasterxml.jackson.annotation.JsonGetter; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSetter; + +public class Properties { + @JsonProperty(value = "unit") + private Unit unit = new Unit(); + + @JsonProperty + private Location location = new Location(); + + @JsonGetter + public Unit getUnit() { + return unit; + } + + @JsonSetter + public void setUnit(Unit unit) { + this.unit = unit; + } + + @JsonGetter + public Location getLocation() { + return location; + } + + @JsonSetter + public void setLocation(Location location) { + this.location = location; + } +} diff --git a/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Unit.java b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Unit.java new file mode 100644 index 000000000000..6c3e92f94627 --- /dev/null +++ b/sdk/openai/azure-ai-openai/src/test/java/com/azure/ai/openai/functions/Unit.java @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.ai.openai.functions; + +import com.fasterxml.jackson.annotation.JsonGetter; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSetter; + +import java.util.Arrays; +import java.util.List; + +public class Unit { + @JsonProperty(value = "type") + private String type = "string"; + + @JsonProperty(value = "enum") + private List enumValues = Arrays.asList("CELSIUS", "FAHRENHEIT"); + + @JsonGetter + public String getType() { + return type; + } + + @JsonSetter + public void setType(String type) { + this.type = type; + } + + @JsonGetter + public List getEnumValues() { + return enumValues; + } + + @JsonSetter + public void setEnumValues(List enumValues) { + this.enumValues = enumValues; + } +} diff --git a/sdk/openai/azure-ai-openai/tsp-location.yaml b/sdk/openai/azure-ai-openai/tsp-location.yaml index e4077d3198f6..ffcad5837832 100644 --- a/sdk/openai/azure-ai-openai/tsp-location.yaml +++ b/sdk/openai/azure-ai-openai/tsp-location.yaml @@ -1,5 +1,5 @@ directory: specification/cognitiveservices/OpenAI.Inference additionalDirectories: - specification/cognitiveservices/OpenAI.Authoring -commit: 018905ddfbba9e08961964784a5de7093815b42e +commit: e994b93c82c5d23eb377f35434354438e748cb87 repo: Azure/azure-rest-api-specs