diff --git a/README.md b/README.md index 406a95b1d8c..946315ef89a 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ This repository contains a suite of libraries that provide facilities commonly needed when creating production-ready applications. Initially developed to support high-scale and high-availability services within Microsoft, such as Microsoft Teams, these libraries deliver functionality that can help make applications more efficient, more robust, and more manageable. The major functional areas this repo addresses are: +- AI: Abstractions and middlewares for working with generative AI models and services. - Compliance: Mechanisms to help manage application data according to privacy regulations and policies, which includes a data annotation framework, audit report generation, and telemetry redaction. - Diagnostics: Provides a set of APIs that can be used to gather and report diagnostic information about the health of a service. - Contextual Options: Extends the .NET Options model to enable experimentations in production. diff --git a/build.sh b/build.sh index 15cbc641df9..375b1fa5743 100755 --- a/build.sh +++ b/build.sh @@ -1,5 +1,14 @@ #!/usr/bin/env bash +function is_cygwin_or_mingw() +{ + case $(uname -s) in + CYGWIN*) return 0;; + MINGW*) return 0;; + *) return 1;; + esac +} + # Stop script if unbound variable found (use ${var:-} if intentional) set -u @@ -16,4 +25,12 @@ then fi DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -"$DIR/eng/build.sh" "$@" + +if is_cygwin_or_mingw; then + # if bash shell running on Windows (not WSL), + # pass control to powershell build script. + DIR=$(cygpath -d "$DIR") + powershell -c "$DIR\\build.cmd" $@ +else + "$DIR/eng/build.sh" $@ +fi diff --git a/eng/MSBuild/LegacySupport.props b/eng/MSBuild/LegacySupport.props index 8ebacbd60f7..2cfe7b73964 100644 --- a/eng/MSBuild/LegacySupport.props +++ b/eng/MSBuild/LegacySupport.props @@ -1,4 +1,12 @@ + + + + + + + + diff --git a/eng/MSBuild/ProjectStaging.targets b/eng/MSBuild/ProjectStaging.targets index 32f3920db29..e3a89d03542 100644 --- a/eng/MSBuild/ProjectStaging.targets +++ b/eng/MSBuild/ProjectStaging.targets @@ -1,8 +1,14 @@ + + + true + $(NoWarn);LA0003 + + - + diff --git a/eng/MSBuild/Shared.props b/eng/MSBuild/Shared.props index a68b0e4298f..7c5ac8424e0 100644 --- a/eng/MSBuild/Shared.props +++ b/eng/MSBuild/Shared.props @@ -1,4 +1,8 @@ + + + + diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index 7915f0da377..d645fcf0549 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -1,172 +1,188 @@ - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 99135af51fa200682ecfc585011eaba907dea4ba + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 diff --git a/eng/Versions.props b/eng/Versions.props index 0c038ad9157..3732d8e1434 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -27,48 +27,52 @@ --> - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 - 9.0.0-rtm.24507.7 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.json b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs new file mode 100644 index 00000000000..f92fcfa3bc9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -0,0 +1,659 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; +using OpenAI; +using OpenAI.Chat; + +#pragma warning disable S1135 // Track uses of "TODO" tags +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace Microsoft.Extensions.AI; + +/// An for an OpenAI or . +public sealed partial class OpenAIChatClient : IChatClient +{ + /// Default OpenAI endpoint. + private static readonly Uri _defaultOpenAIEndpoint = new("https://api.openai.com/v1"); + + /// The underlying . + private readonly OpenAIClient? _openAIClient; + + /// The underlying . + private readonly ChatClient _chatClient; + + /// Initializes a new instance of the class for the specified . + /// The underlying client. + /// The model to use. + public OpenAIChatClient(OpenAIClient openAIClient, string modelId) + { + _ = Throw.IfNull(openAIClient); + _ = Throw.IfNullOrWhitespace(modelId); + + _openAIClient = openAIClient; + _chatClient = openAIClient.GetChatClient(modelId); + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = openAIClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + Uri providerUrl = typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(openAIClient) as Uri ?? _defaultOpenAIEndpoint; + + Metadata = new(providerName, providerUrl, modelId); + } + + /// Initializes a new instance of the class for the specified . + /// The underlying client. + public OpenAIChatClient(ChatClient chatClient) + { + _ = Throw.IfNull(chatClient); + + _chatClient = chatClient; + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint and model aren't currently exposed, so use reflection to get at them, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = chatClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(chatClient) as Uri ?? _defaultOpenAIEndpoint; + string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(chatClient) as string; + + Metadata = new(providerName, providerUrl, model); + } + + /// Gets or sets to use for any serialization activities related to tool call arguments and results. + public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; } + + /// + public ChatClientMetadata Metadata { get; } + + /// + public TService? GetService(object? key = null) + where TService : class => + typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient : + typeof(TService) == typeof(ChatClient) ? (TService)(object)_chatClient : + this as TService; + + /// + public async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + // Make the call to OpenAI. + OpenAI.Chat.ChatCompletion response = (await _chatClient.CompleteChatAsync( + ToOpenAIChatMessages(chatMessages), + ToOpenAIOptions(options), + cancellationToken).ConfigureAwait(false)).Value; + + // Create the return message. + ChatMessage returnMessage = new() + { + RawRepresentation = response, + Role = ToChatRole(response.Role), + }; + + // Populate its content from those in the OpenAI response content. + foreach (ChatMessageContentPart contentPart in response.Content) + { + if (ToAIContent(contentPart, response.Model) is AIContent aiContent) + { + returnMessage.Contents.Add(aiContent); + } + } + + // Also manufacture function calling content items from any tool calls in the response. + if (options?.Tools is { Count: > 0 }) + { + foreach (ChatToolCall toolCall in response.ToolCalls) + { + if (!string.IsNullOrWhiteSpace(toolCall.FunctionName)) + { + Dictionary? arguments = FunctionCallHelpers.ParseFunctionCallArguments(toolCall.FunctionArguments, out Exception? parsingException); + + returnMessage.Contents.Add(new FunctionCallContent(toolCall.Id, toolCall.FunctionName, arguments) + { + ModelId = response.Model, + Exception = parsingException, + RawRepresentation = toolCall + }); + } + } + } + + // Wrap the content in a ChatCompletion to return. + var completion = new ChatCompletion([returnMessage]) + { + RawRepresentation = response, + CompletionId = response.Id, + CreatedAt = response.CreatedAt, + ModelId = response.Model, + FinishReason = ToFinishReason(response.FinishReason), + }; + + if (response.Usage is ChatTokenUsage tokenUsage) + { + completion.Usage = new() + { + InputTokenCount = tokenUsage.InputTokenCount, + OutputTokenCount = tokenUsage.OutputTokenCount, + TotalTokenCount = tokenUsage.TotalTokenCount, + }; + + if (tokenUsage.OutputTokenDetails is ChatOutputTokenUsageDetails details) + { + completion.Usage.AdditionalProperties = new() { [nameof(details.ReasoningTokenCount)] = details.ReasoningTokenCount }; + } + } + + if (response.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (completion.AdditionalProperties ??= [])[nameof(response.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } + + if (response.Refusal is string refusal) + { + (completion.AdditionalProperties ??= [])[nameof(response.Refusal)] = refusal; + } + + if (response.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (completion.AdditionalProperties ??= [])[nameof(response.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } + + if (response.SystemFingerprint is string systemFingerprint) + { + (completion.AdditionalProperties ??= [])[nameof(response.SystemFingerprint)] = systemFingerprint; + } + + return completion; + } + + /// + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + Dictionary? functionCallInfos = null; + ChatRole? streamedRole = null; + ChatFinishReason? finishReason = null; + StringBuilder? refusal = null; + string? completionId = null; + DateTimeOffset? createdAt = null; + string? modelId = null; + string? fingerprint = null; + + // Process each update as it arrives + await foreach (OpenAI.Chat.StreamingChatCompletionUpdate chatCompletionUpdate in _chatClient.CompleteChatStreamingAsync( + ToOpenAIChatMessages(chatMessages), ToOpenAIOptions(options), cancellationToken).ConfigureAwait(false)) + { + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= chatCompletionUpdate.Role is ChatMessageRole role ? ToChatRole(role) : null; + finishReason ??= chatCompletionUpdate.FinishReason is OpenAI.Chat.ChatFinishReason reason ? ToFinishReason(reason) : null; + completionId ??= chatCompletionUpdate.CompletionId; + createdAt ??= chatCompletionUpdate.CreatedAt; + modelId ??= chatCompletionUpdate.Model; + fingerprint ??= chatCompletionUpdate.SystemFingerprint; + + // Create the response content object. + StreamingChatCompletionUpdate completionUpdate = new() + { + CompletionId = chatCompletionUpdate.CompletionId, + CreatedAt = chatCompletionUpdate.CreatedAt, + FinishReason = finishReason, + RawRepresentation = chatCompletionUpdate, + Role = streamedRole, + }; + + // Populate it with any additional metadata from the OpenAI object. + if (chatCompletionUpdate.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } + + if (chatCompletionUpdate.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } + + if (fingerprint is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.SystemFingerprint)] = fingerprint; + } + + // Transfer over content update items. + if (chatCompletionUpdate.ContentUpdate is { Count: > 0 }) + { + foreach (ChatMessageContentPart contentPart in chatCompletionUpdate.ContentUpdate) + { + if (ToAIContent(contentPart, modelId) is AIContent aiContent) + { + completionUpdate.Contents.Add(aiContent); + } + } + } + + // Transfer over refusal updates. + if (chatCompletionUpdate.RefusalUpdate is not null) + { + _ = (refusal ??= new()).Append(chatCompletionUpdate.RefusalUpdate); + } + + // Transfer over tool call updates. + if (chatCompletionUpdate.ToolCallUpdates is { Count: > 0 } toolCallUpdates) + { + foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) + { + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? existing)) + { + functionCallInfos[toolCallUpdate.Index] = existing = new(); + } + + existing.CallId ??= toolCallUpdate.ToolCallId; + existing.Name ??= toolCallUpdate.FunctionName; + if (toolCallUpdate.FunctionArgumentsUpdate is not null) + { + _ = (existing.Arguments ??= new()).Append(toolCallUpdate.FunctionArgumentsUpdate); + } + } + } + + // Transfer over usage updates. + if (chatCompletionUpdate.Usage is ChatTokenUsage tokenUsage) + { + UsageDetails usageDetails = new() + { + InputTokenCount = tokenUsage.InputTokenCount, + OutputTokenCount = tokenUsage.OutputTokenCount, + TotalTokenCount = tokenUsage.TotalTokenCount, + }; + + if (tokenUsage.OutputTokenDetails is ChatOutputTokenUsageDetails details) + { + (usageDetails.AdditionalProperties = [])[nameof(tokenUsage.OutputTokenDetails)] = new Dictionary + { + [nameof(details.ReasoningTokenCount)] = details.ReasoningTokenCount, + }; + } + + // TODO: Add support for prompt token details (e.g. cached tokens) once it's exposed in OpenAI library. + + completionUpdate.Contents.Add(new UsageContent(usageDetails) + { + ModelId = modelId + }); + } + + // Now yield the item. + yield return completionUpdate; + } + + // Now that we've received all updates, combine any for function calls into a single item to yield. + if (functionCallInfos is not null) + { + StreamingChatCompletionUpdate completionUpdate = new() + { + CompletionId = completionId, + CreatedAt = createdAt, + FinishReason = finishReason, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) + { + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) + { + var arguments = FunctionCallHelpers.ParseFunctionCallArguments( + fci.Arguments?.ToString() ?? string.Empty, + out Exception? parsingException); + + completionUpdate.Contents.Add(new FunctionCallContent(fci.CallId!, fci.Name!, arguments) + { + ModelId = modelId, + Exception = parsingException + }); + } + } + + // Refusals are about the model not following the schema for tool calls. As such, if we have any refusal, + // add it to this function calling item. + if (refusal is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); + } + + // Propagate additional relevant metadata. + if (fingerprint is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)] = fingerprint; + } + + yield return completionUpdate; + } + } + + /// + void IDisposable.Dispose() + { + // Nothing to dispose. Implementation required for the IChatClient interface. + } + + /// POCO representing function calling info. Used to concatenation information for a single function call from across multiple streaming updates. + private sealed class FunctionCallInfo + { + public string? CallId; + public string? Name; + public StringBuilder? Arguments; + } + + /// Converts an OpenAI role to an Extensions role. + private static ChatRole ToChatRole(ChatMessageRole role) => + role switch + { + ChatMessageRole.System => ChatRole.System, + ChatMessageRole.User => ChatRole.User, + ChatMessageRole.Assistant => ChatRole.Assistant, + ChatMessageRole.Tool => ChatRole.Tool, + _ => new ChatRole(role.ToString()), + }; + + /// Converts an OpenAI finish reason to an Extensions finish reason. + private static ChatFinishReason? ToFinishReason(OpenAI.Chat.ChatFinishReason? finishReason) => + finishReason?.ToString() is not string s ? null : + finishReason switch + { + OpenAI.Chat.ChatFinishReason.Stop => ChatFinishReason.Stop, + OpenAI.Chat.ChatFinishReason.Length => ChatFinishReason.Length, + OpenAI.Chat.ChatFinishReason.ContentFilter => ChatFinishReason.ContentFilter, + OpenAI.Chat.ChatFinishReason.ToolCalls or OpenAI.Chat.ChatFinishReason.FunctionCall => ChatFinishReason.ToolCalls, + _ => new ChatFinishReason(s), + }; + + /// Converts an extensions options instance to an OpenAI options instance. + private ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) + { + ChatCompletionOptions result = new(); + + if (options is not null) + { + result.FrequencyPenalty = options.FrequencyPenalty; + result.MaxOutputTokenCount = options.MaxOutputTokens; + result.TopP = options.TopP; + result.PresencePenalty = options.PresencePenalty; + result.Temperature = options.Temperature; + + if (options.StopSequences is { Count: > 0 } stopSequences) + { + foreach (string stopSequence in stopSequences) + { + result.StopSequences.Add(stopSequence); + } + } + + if (options.AdditionalProperties is { Count: > 0 } additionalProperties) + { + if (additionalProperties.TryGetConvertedValue(nameof(result.EndUserId), out string? endUserId)) + { + result.EndUserId = endUserId; + } + + if (additionalProperties.TryGetConvertedValue(nameof(result.IncludeLogProbabilities), out bool includeLogProbabilities)) + { + result.IncludeLogProbabilities = includeLogProbabilities; + } + + if (additionalProperties.TryGetConvertedValue(nameof(result.LogitBiases), out IDictionary? logitBiases)) + { + foreach (KeyValuePair kvp in logitBiases!) + { + result.LogitBiases[kvp.Key] = kvp.Value; + } + } + + if (additionalProperties.TryGetConvertedValue(nameof(result.AllowParallelToolCalls), out bool allowParallelToolCalls)) + { + result.AllowParallelToolCalls = allowParallelToolCalls; + } + +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (additionalProperties.TryGetConvertedValue(nameof(result.Seed), out long seed)) + { + result.Seed = seed; + } +#pragma warning restore OPENAI001 + + if (additionalProperties.TryGetConvertedValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt)) + { + result.TopLogProbabilityCount = topLogProbabilityCountInt; + } + } + + if (options.Tools is { Count: > 0 } tools) + { + foreach (AITool tool in tools) + { + if (tool is AIFunction af) + { + result.Tools.Add(ToOpenAIChatTool(af)); + } + } + + switch (options.ToolMode) + { + case AutoChatToolMode: + result.ToolChoice = ChatToolChoice.CreateAutoChoice(); + break; + + case RequiredChatToolMode required: + result.ToolChoice = required.RequiredFunctionName is null ? + ChatToolChoice.CreateRequiredChoice() : + ChatToolChoice.CreateFunctionChoice(required.RequiredFunctionName); + break; + } + } + + if (options.ResponseFormat is ChatResponseFormatText) + { + result.ResponseFormat = OpenAI.Chat.ChatResponseFormat.CreateTextFormat(); + } + else if (options.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + result.ResponseFormat = jsonFormat.Schema is string jsonSchema ? + OpenAI.Chat.ChatResponseFormat.CreateJsonSchemaFormat(jsonFormat.SchemaName ?? "json_schema", BinaryData.FromString(jsonSchema), jsonFormat.SchemaDescription) : + OpenAI.Chat.ChatResponseFormat.CreateJsonObjectFormat(); + } + } + + return result; + } + + /// Converts an Extensions function to an OpenAI chat tool. + private ChatTool ToOpenAIChatTool(AIFunction aiFunction) + { + _ = aiFunction.Metadata.AdditionalProperties.TryGetConvertedValue("Strict", out bool strict); + + BinaryData resultParameters = OpenAIChatToolJson.ZeroFunctionParametersSchema; + + var parameters = aiFunction.Metadata.Parameters; + if (parameters is { Count: > 0 }) + { + OpenAIChatToolJson tool = new(); + + foreach (AIFunctionParameterMetadata parameter in parameters) + { + tool.Properties.Add( + parameter.Name, + FunctionCallHelpers.InferParameterJsonSchema(parameter, aiFunction.Metadata, ToolCallJsonSerializerOptions)); + + if (parameter.IsRequired) + { + tool.Required.Add(parameter.Name); + } + } + + resultParameters = BinaryData.FromBytes( + JsonSerializer.SerializeToUtf8Bytes(tool, JsonContext.Default.OpenAIChatToolJson)); + } + + return ChatTool.CreateFunctionTool(aiFunction.Metadata.Name, aiFunction.Metadata.Description, resultParameters, strict); + } + + /// Used to create the JSON payload for an OpenAI chat tool description. + private sealed class OpenAIChatToolJson + { + /// Gets a singleton JSON data for empty parameters. Optimization for the reasonably common case of a parameterless function. + public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray()); + + [JsonPropertyName("type")] + public string Type { get; set; } = "object"; + + [JsonPropertyName("required")] + public List Required { get; set; } = []; + + [JsonPropertyName("properties")] + public Dictionary Properties { get; set; } = []; + } + + /// Creates an from a . + /// The content part to convert into a content. + /// The model ID. + /// The constructed , or null if the content part could not be converted. + private static AIContent? ToAIContent(ChatMessageContentPart contentPart, string? modelId) + { + AIContent? aiContent = null; + + AdditionalPropertiesDictionary? additionalProperties = null; + + if (contentPart.Kind == ChatMessageContentPartKind.Text) + { + aiContent = new TextContent(contentPart.Text); + } + else if (contentPart.Kind == ChatMessageContentPartKind.Image) + { + ImageContent? imageContent; + aiContent = imageContent = + contentPart.ImageUri is not null ? new ImageContent(contentPart.ImageUri, contentPart.ImageBytesMediaType) : + contentPart.ImageBytes is not null ? new ImageContent(contentPart.ImageBytes.ToMemory(), contentPart.ImageBytesMediaType) : + null; + + if (imageContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) + { + (additionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; + } + } + + if (aiContent is not null) + { + if (contentPart.Refusal is string refusal) + { + (additionalProperties ??= [])[nameof(contentPart.Refusal)] = refusal; + } + + aiContent.ModelId = modelId; + aiContent.AdditionalProperties = additionalProperties; + aiContent.RawRepresentation = contentPart; + } + + return aiContent; + } + + /// Converts an Extensions chat message enumerable to an OpenAI chat message enumerable. + private IEnumerable ToOpenAIChatMessages(IEnumerable inputs) + { + // Maps all of the M.E.AI types to the corresponding OpenAI types. + // Unrecognized content is ignored. + + foreach (ChatMessage input in inputs) + { + if (input.Role == ChatRole.System) + { + yield return new SystemChatMessage(input.Text) { ParticipantName = input.AuthorName }; + } + else if (input.Role == ChatRole.Tool) + { + foreach (AIContent item in input.Contents) + { + if (item is FunctionResultContent resultContent) + { + string? result = resultContent.Result as string; + if (result is null && resultContent.Result is not null) + { + try + { + result = FunctionCallHelpers.FormatFunctionResultAsJson(resultContent.Result, ToolCallJsonSerializerOptions); + } + catch (NotSupportedException) + { + // If the type can't be serialized, skip it. + } + } + + yield return new ToolChatMessage(resultContent.CallId, result ?? string.Empty); + } + } + } + else if (input.Role == ChatRole.User) + { + yield return new UserChatMessage(input.Contents.Select(static (AIContent item) => item switch + { + TextContent textContent => ChatMessageContentPart.CreateTextPart(textContent.Text), + ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType) : + imageContent.Uri is string uri ? ChatMessageContentPart.CreateImagePart(new Uri(uri)) : + null, + _ => null, + }).Where(c => c is not null)) + { ParticipantName = input.AuthorName }; + } + else if (input.Role == ChatRole.Assistant) + { + Dictionary? toolCalls = null; + + foreach (var content in input.Contents) + { + if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true) + { + (toolCalls ??= []).Add( + callRequest.CallId, + ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions))); + } + } + + AssistantChatMessage message = toolCalls is not null ? + new(toolCalls.Values) { ParticipantName = input.AuthorName } : + new(input.Text) { ParticipantName = input.AuthorName }; + + if (input.AdditionalProperties?.TryGetConvertedValue(nameof(message.Refusal), out string? refusal) is true) + { + message.Refusal = refusal; + } + + yield return message; + } + } + } + + /// Source-generated JSON type information. + [JsonSerializable(typeof(OpenAIChatToolJson))] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs new file mode 100644 index 00000000000..a33fd34e1ea --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using OpenAI; +using OpenAI.Chat; +using OpenAI.Embeddings; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for working with s. +public static class OpenAIClientExtensions +{ + /// Gets an for use with this . + /// The client. + /// The model. + /// An that may be used to converse via the . + public static IChatClient AsChatClient(this OpenAIClient openAIClient, string modelId) => + new OpenAIChatClient(openAIClient, modelId); + + /// Gets an for use with this . + /// The client. + /// An that may be used to converse via the . + public static IChatClient AsChatClient(this ChatClient chatClient) => + new OpenAIChatClient(chatClient); + + /// Gets an for use with this . + /// The client. + /// The model to use. + /// The number of dimensions to generate in each embedding. + /// An that may be used to generate embeddings via the . + public static IEmbeddingGenerator> AsEmbeddingGenerator(this OpenAIClient openAIClient, string modelId, int? dimensions = null) => + new OpenAIEmbeddingGenerator(openAIClient, modelId, dimensions); + + /// Gets an for use with this . + /// The client. + /// The number of dimensions to generate in each embedding. + /// An that may be used to generate embeddings via the . + public static IEmbeddingGenerator> AsEmbeddingGenerator(this EmbeddingClient embeddingClient, int? dimensions = null) => + new OpenAIEmbeddingGenerator(embeddingClient, dimensions); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs new file mode 100644 index 00000000000..e91394befdd --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -0,0 +1,160 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; +using OpenAI; +using OpenAI.Embeddings; + +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace Microsoft.Extensions.AI; + +/// An for an OpenAI . +public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator> +{ + /// Default OpenAI endpoint. + private const string DefaultOpenAIEndpoint = "https://api.openai.com/v1"; + + /// The underlying . + private readonly OpenAIClient? _openAIClient; + + /// The underlying . + private readonly EmbeddingClient _embeddingClient; + + /// The number of dimensions produced by the generator. + private readonly int? _dimensions; + + /// Initializes a new instance of the class. + /// The underlying client. + /// The model to use. + /// The number of dimensions to generate in each embedding. + public OpenAIEmbeddingGenerator( + OpenAIClient openAIClient, string modelId, int? dimensions = null) + { + _ = Throw.IfNull(openAIClient); + _ = Throw.IfNullOrWhitespace(modelId); + if (dimensions is < 1) + { + Throw.ArgumentOutOfRangeException(nameof(dimensions), "Value must be greater than 0."); + } + + _openAIClient = openAIClient; + _embeddingClient = openAIClient.GetEmbeddingClient(modelId); + _dimensions = dimensions; + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = openAIClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + string providerUrl = (typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(openAIClient) as Uri)?.ToString() ?? + DefaultOpenAIEndpoint; + + Metadata = CreateMetadata(dimensions, providerName, providerUrl, modelId); + } + + /// Initializes a new instance of the class. + /// The underlying client. + /// The number of dimensions to generate in each embedding. + public OpenAIEmbeddingGenerator(EmbeddingClient embeddingClient, int? dimensions = null) + { + _ = Throw.IfNull(embeddingClient); + if (dimensions < 1) + { + Throw.ArgumentOutOfRangeException(nameof(dimensions), "Value must be greater than 0."); + } + + _embeddingClient = embeddingClient; + _dimensions = dimensions; + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint and model aren't currently exposed, so use reflection to get at them, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = embeddingClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + string providerUrl = (typeof(EmbeddingClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(embeddingClient) as Uri)?.ToString() ?? + DefaultOpenAIEndpoint; + + FieldInfo? modelField = typeof(EmbeddingClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + string? model = modelField?.GetValue(embeddingClient) as string; + + Metadata = CreateMetadata(dimensions, providerName, providerUrl, model); + } + + /// Creates the for this instance. + private static EmbeddingGeneratorMetadata CreateMetadata(int? dimensions, string providerName, string providerUrl, string? model) => + new(providerName, Uri.TryCreate(providerUrl, UriKind.Absolute, out Uri? providerUri) ? providerUri : null, model, dimensions); + + /// + public EmbeddingGeneratorMetadata Metadata { get; } + + /// + public TService? GetService(object? key = null) + where TService : class + => + typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient : + typeof(TService) == typeof(EmbeddingClient) ? (TService)(object)_embeddingClient : + this as TService; + + /// + public async Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + OpenAI.Embeddings.EmbeddingGenerationOptions? openAIOptions = ToOpenAIOptions(options); + + var embeddings = (await _embeddingClient.GenerateEmbeddingsAsync(values, openAIOptions, cancellationToken).ConfigureAwait(false)).Value; + + return new(embeddings.Select(e => + new Embedding(e.ToFloats()) + { + CreatedAt = DateTimeOffset.UtcNow, + ModelId = embeddings.Model, + })) + { + Usage = new() + { + InputTokenCount = embeddings.Usage.InputTokenCount, + TotalTokenCount = embeddings.Usage.TotalTokenCount + }, + }; + } + + /// + void IDisposable.Dispose() + { + // Nothing to dispose. Implementation required for the IEmbeddingGenerator interface. + } + + /// Converts an extensions options instance to an OpenAI options instance. + private OpenAI.Embeddings.EmbeddingGenerationOptions? ToOpenAIOptions(EmbeddingGenerationOptions? options) + { + OpenAI.Embeddings.EmbeddingGenerationOptions openAIOptions = new() + { + Dimensions = _dimensions, + }; + + if (options?.AdditionalProperties is { Count: > 0 } additionalProperties) + { + // Allow per-instance dimensions to be overridden by a per-call property + if (additionalProperties.TryGetConvertedValue(nameof(openAIOptions.Dimensions), out int? dimensions)) + { + openAIOptions.Dimensions = dimensions; + } + + if (additionalProperties.TryGetConvertedValue(nameof(openAIOptions.EndUserId), out string? endUserId)) + { + openAIOptions.EndUserId = endUserId; + } + } + + return openAIOptions; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md b/src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md new file mode 100644 index 00000000000..f7af212f4d7 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md @@ -0,0 +1,313 @@ +# Microsoft.Extensions.AI.OpenAI + +Provides an implementation of the `IChatClient` interface for the `OpenAI` package and OpenAI-compatible endpoints. + +## Install the package + +From the command-line: + +```console +dotnet add package Microsoft.Extensions.AI.OpenAI +``` + +Or directly in the C# project file: + +```xml + + + +``` + +## Usage Examples + +### Chat + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient client = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Chat + Conversation History + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient client = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +Console.WriteLine(await client.CompleteAsync( +[ + new ChatMessage(ChatRole.System, "You are a helpful AI assistant"), + new ChatMessage(ChatRole.User, "What is AI?"), +])); +``` + +### Chat streaming + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient client = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +await foreach (var update in client.CompleteStreamingAsync("What is AI?")) +{ + Console.Write(update); +} +``` + +### Tool calling + +```csharp +using System.ComponentModel; +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseFunctionInvocation() + .Use(openaiClient); + +ChatOptions chatOptions = new() +{ + Tools = [AIFunctionFactory.Create(GetWeather)] +}; + +await foreach (var message in client.CompleteStreamingAsync("Do I need an umbrella?", chatOptions)) +{ + Console.Write(message); +} + +[Description("Gets the weather")] +static string GetWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining"; +``` + +### Caching + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenAI; + +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .Use(openaiClient); + +for (int i = 0; i < 3; i++) +{ + await foreach (var message in client.CompleteStreamingAsync("In less than 100 words, what is AI?")) + { + Console.Write(message); + } + + Console.WriteLine(); + Console.WriteLine(); +} +``` + +### Telemetry + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; +using OpenTelemetry.Trace; + +// Configure OpenTelemetry exporter +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(openaiClient); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Telemetry, Caching, and Tool Calling + +```csharp +using System.ComponentModel; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenAI; +using OpenTelemetry.Trace; + +// Configure telemetry +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +// Configure caching +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +// Configure tool calling +var chatOptions = new ChatOptions +{ + Tools = [AIFunctionFactory.Create(GetPersonAge)] +}; + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .UseFunctionInvocation() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(openaiClient); + +for (int i = 0; i < 3; i++) +{ + Console.WriteLine(await client.CompleteAsync("How much older is Alice than Bob?", chatOptions)); +} + +[Description("Gets the age of a person specified by name.")] +static int GetPersonAge(string personName) => + personName switch + { + "Alice" => 42, + "Bob" => 35, + _ => 26, + }; +``` + +### Text embedding generation + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IEmbeddingGenerator> generator = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsEmbeddingGenerator("text-embedding-3-small"); + +var embeddings = await generator.GenerateAsync("What is AI?"); + +Console.WriteLine(string.Join(", ", embeddings[0].Vector.ToArray())); +``` + +### Text embedding generation with caching + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenAI; + +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +IEmbeddingGenerator> openAIGenerator = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsEmbeddingGenerator("text-embedding-3-small"); + +IEmbeddingGenerator> generator = new EmbeddingGeneratorBuilder>() + .UseDistributedCache(cache) + .Use(openAIGenerator); + +foreach (var prompt in new[] { "What is AI?", "What is .NET?", "What is AI?" }) +{ + var embeddings = await generator.GenerateAsync(prompt); + + Console.WriteLine(string.Join(", ", embeddings[0].Vector.ToArray())); +} +``` + +### Dependency Injection + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using OpenAI; + +// App Setup +var builder = Host.CreateApplicationBuilder(); +builder.Services.AddSingleton(new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))); +builder.Services.AddSingleton(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); +builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace)); + +builder.Services.AddChatClient(b => b + .UseDistributedCache() + .UseLogging() + .Use(b.Services.GetRequiredService().AsChatClient("gpt-4o-mini"))); + +var app = builder.Build(); + +// Elsewhere in the app +var chatClient = app.Services.GetRequiredService(); +Console.WriteLine(await chatClient.CompleteAsync("What is AI?")); +``` + +### Minimal Web API + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddSingleton(new OpenAIClient(builder.Configuration["OPENAI_API_KEY"])); + +builder.Services.AddChatClient(b => + b.Use(b.Services.GetRequiredService().AsChatClient("gpt-4o-mini"))); + +builder.Services.AddEmbeddingGenerator>(g => + g.Use(g.Services.GetRequiredService().AsEmbeddingGenerator("text-embedding-3-small"))); + +var app = builder.Build(); + +app.MapPost("/chat", async (IChatClient client, string message) => +{ + var response = await client.CompleteAsync(message); + return response.Message; +}); + +app.MapPost("/embedding", async (IEmbeddingGenerator> client, string message) => +{ + var response = await client.GenerateAsync(message); + return response[0].Vector; +}); + +app.Run(); +``` + +## Feedback & Contributing + +We welcome feedback and contributions in [our GitHub repo](https://github.com/dotnet/extensions). diff --git a/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs new file mode 100644 index 00000000000..8128926f942 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Security.Cryptography; +using System.Text.Json; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides internal helpers for implementing caching services. +internal static class CachingHelpers +{ + /// Computes a default cache key for the specified parameters. + /// Specifies the type of the data being used to compute the key. + /// The data with which to compute the key. + /// The . + /// A string that will be used as a cache key. + public static string GetCacheKey(TValue value, JsonSerializerOptions serializerOptions) + => GetCacheKey(value, false, serializerOptions); + + /// Computes a default cache key for the specified parameters. + /// Specifies the type of the data being used to compute the key. + /// The data with which to compute the key. + /// Another data item that causes the key to vary. + /// The . + /// A string that will be used as a cache key. + public static string GetCacheKey(TValue value, bool flag, JsonSerializerOptions serializerOptions) + { + _ = Throw.IfNull(value); + _ = Throw.IfNull(serializerOptions); + serializerOptions.MakeReadOnly(); + + var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue))); + + if (flag && jsonKeyBytes.Length > 0) + { + // Make an arbitrary change to the hash input based on the flag + // The alternative would be including the flag in "value" in the + // first place, but that's likely to require an extra allocation + // or the inclusion of another type in the JsonSerializerContext. + // This is a micro-optimization we can change at any time. + jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]); + } + + // The complete JSON representation is excessively long for a cache key, duplicating much of the content + // from the value. So we use a hash of it as the default key. +#if NET8_0_OR_GREATER + Span hashData = stackalloc byte[SHA256.HashSizeInBytes]; + SHA256.HashData(jsonKeyBytes, hashData); + return Convert.ToHexString(hashData); +#else + using var sha256 = SHA256.Create(); + var hashData = sha256.ComputeHash(jsonKeyBytes); + return BitConverter.ToString(hashData).Replace("-", string.Empty); +#endif + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs new file mode 100644 index 00000000000..89a778cdd1b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -0,0 +1,155 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that caches the results of chat calls. +/// +public abstract class CachingChatClient : DelegatingChatClient +{ + /// Initializes a new instance of the class. + /// The underlying . + protected CachingChatClient(IChatClient innerClient) + : base(innerClient) + { + } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + // We're only storing the final result, not the in-flight task, so that we can avoid caching failures + // or having problems when one of the callers cancels but others don't. This has the drawback that + // concurrent callers might trigger duplicate requests, but that's acceptable. + var cacheKey = GetCacheKey(false, chatMessages, options); + + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is ChatCompletion existing) + { + return existing; + } + + var result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); + return result; + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + var cacheKey = GetCacheKey(true, chatMessages, options); + if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) + { + foreach (var chunk in existingChunks) + { + yield return chunk; + } + } + else + { + var capturedItems = new List(); + StreamingChatCompletionUpdate? previousCoalescedCopy = null; + await foreach (var item in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + yield return item; + + // If this item is compatible with the previous one, we will coalesce them in the cache + var previous = capturedItems.Count > 0 ? capturedItems[capturedItems.Count - 1] : null; + if (item.ChoiceIndex == 0 + && item.Contents.Count == 1 + && item.Contents[0] is TextContent currentTextContent + && previous is { ChoiceIndex: 0 } + && previous.Role == item.Role + && previous.Contents is { Count: 1 } + && previous.Contents[0] is TextContent previousTextContent) + { + if (!ReferenceEquals(previous, previousCoalescedCopy)) + { + // We don't want to mutate any object that we also yield, since the recipient might + // not expect that. Instead make a copy we can safely mutate. + previousCoalescedCopy = new() + { + Role = previous.Role, + AuthorName = previous.AuthorName, + AdditionalProperties = previous.AdditionalProperties, + ChoiceIndex = previous.ChoiceIndex, + RawRepresentation = previous.RawRepresentation, + Contents = [new TextContent(previousTextContent.Text)] + }; + + // The last item we captured was before we knew it could be coalesced + // with this one, so replace it with the coalesced copy + capturedItems[capturedItems.Count - 1] = previousCoalescedCopy; + } + +#pragma warning disable S1643 // Strings should not be concatenated using '+' in a loop + ((TextContent)previousCoalescedCopy.Contents[0]).Text += currentTextContent.Text; +#pragma warning restore S1643 + } + else + { + capturedItems.Add(item); + } + } + + await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Computes a cache key for the specified call parameters. + /// + /// A flag to indicate if this is a streaming call. + /// The chat content. + /// The chat options to configure the request. + /// A string that will be used as a cache key. + protected abstract string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options); + + /// + /// Returns a previously cached , if available. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to monitor for cancellation requests. + /// The previously cached data, if available, otherwise . + protected abstract Task ReadCacheAsync(string key, CancellationToken cancellationToken); + + /// + /// Returns a previously cached list of values, if available. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to monitor for cancellation requests. + /// The previously cached data, if available, otherwise . + protected abstract Task?> ReadCacheStreamingAsync(string key, CancellationToken cancellationToken); + + /// + /// Stores a in the underlying cache. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to be stored. + /// The to monitor for cancellation requests. + /// A representing the completion of the operation. + protected abstract Task WriteCacheAsync(string key, ChatCompletion value, CancellationToken cancellationToken); + + /// + /// Stores a list of values in the underlying cache. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to be stored. + /// The to monitor for cancellation requests. + /// A representing the completion of the operation. + protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList value, CancellationToken cancellationToken); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs new file mode 100644 index 00000000000..d7934ba7809 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A builder for creating pipelines of . +public sealed class ChatClientBuilder +{ + /// The registered client factory instances. + private List>? _clientFactories; + + /// Initializes a new instance of the class. + /// The service provider to use for dependency injection. + public ChatClientBuilder(IServiceProvider? services = null) + { + Services = services ?? EmptyServiceProvider.Instance; + } + + /// Gets the associated with the builder instance. + public IServiceProvider Services { get; } + + /// Completes the pipeline by adding a final that represents the underlying backend. This is typically a client for an LLM service. + /// The inner client to use. + /// An instance of that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn. + public IChatClient Use(IChatClient innerClient) + { + var chatClient = Throw.IfNull(innerClient); + + // To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost. + if (_clientFactories is not null) + { + for (var i = _clientFactories.Count - 1; i >= 0; i--) + { + chatClient = _clientFactories[i](Services, chatClient) ?? + throw new InvalidOperationException( + $"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " + + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances."); + } + } + + return chatClient; + } + + /// Adds a factory for an intermediate chat client to the chat client pipeline. + /// The client factory function. + /// The updated instance. + public ChatClientBuilder Use(Func clientFactory) + { + _ = Throw.IfNull(clientFactory); + + return Use((_, innerClient) => clientFactory(innerClient)); + } + + /// Adds a factory for an intermediate chat client to the chat client pipeline. + /// The client factory function. + /// The updated instance. + public ChatClientBuilder Use(Func clientFactory) + { + _ = Throw.IfNull(clientFactory); + + (_clientFactories ??= []).Add(clientFactory); + return this; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs new file mode 100644 index 00000000000..246ac7f3689 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for registering with a . +public static class ChatClientBuilderServiceCollectionExtensions +{ + /// Adds a chat client to the . + /// The to which the client should be added. + /// The factory to use to construct the instance. + /// The collection. + /// The client is registered as a scoped service. + public static IServiceCollection AddChatClient( + this IServiceCollection services, + Func clientFactory) + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(clientFactory); + + return services.AddScoped(services => + clientFactory(new ChatClientBuilder(services))); + } + + /// Adds a chat client to the . + /// The to which the client should be added. + /// The key with which to associate the client. + /// The factory to use to construct the instance. + /// The collection. + /// The client is registered as a scoped service. + public static IServiceCollection AddKeyedChatClient( + this IServiceCollection services, + object serviceKey, + Func clientFactory) + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(serviceKey); + _ = Throw.IfNull(clientFactory); + + return services.AddKeyedScoped(serviceKey, (services, _) => + clientFactory(new ChatClientBuilder(services))); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs new file mode 100644 index 00000000000..5d16440a8fa --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -0,0 +1,227 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Schema; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides extension methods on that simplify working with structured output. +/// +public static partial class ChatClientStructuredOutputExtensions +{ + private const string UsesReflectionJsonSerializerMessage = + "This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications."; + + private static JsonSerializerOptions? _defaultJsonSerializerOptions; + + /// Sends chat messages to the model, requesting a response matching the type . + /// The . + /// The chat content to send. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// + /// The returned messages will not have been added to . However, any intermediate messages generated implicitly + /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. + /// + /// The type of structured output to request. + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + public static Task> CompleteAsync( + this IChatClient chatClient, + IList chatMessages, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class => + CompleteAsync(chatClient, chatMessages, DefaultJsonSerializerOptions, options, useNativeJsonSchema, cancellationToken); + + /// Sends a user chat text message to the model, requesting a response matching the type . + /// The . + /// The text content for the chat message to send. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// The type of structured output to request. + [RequiresDynamicCode("JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. " + + "Use System.Text.Json source generation for native AOT applications.")] + [RequiresUnreferencedCode("JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. " + + "Use System.Text.Json source generation for native AOT applications.")] + public static Task> CompleteAsync( + this IChatClient chatClient, + string chatMessage, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class => + CompleteAsync(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], options, useNativeJsonSchema, cancellationToken); + + /// Sends a user chat text message to the model, requesting a response matching the type . + /// The . + /// The text content for the chat message to send. + /// The JSON serialization options to use. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// The type of structured output to request. + public static Task> CompleteAsync( + this IChatClient chatClient, + string chatMessage, + JsonSerializerOptions serializerOptions, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class => + CompleteAsync(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], serializerOptions, options, useNativeJsonSchema, cancellationToken); + + /// Sends chat messages to the model, requesting a response matching the type . + /// The . + /// The chat content to send. + /// The JSON serialization options to use. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// + /// The returned messages will not have been added to . However, any intermediate messages generated implicitly + /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. + /// + /// The type of structured output to request. + public static async Task> CompleteAsync( + this IChatClient chatClient, + IList chatMessages, + JsonSerializerOptions serializerOptions, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class + { + _ = Throw.IfNull(chatClient); + _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(serializerOptions); + + serializerOptions.MakeReadOnly(); + + var schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(T), new() + { + TreatNullObliviousAsNonNullable = true, + TransformSchemaNode = static (context, node) => + { + if (node is JsonObject obj) + { + if (obj.TryGetPropertyValue("enum", out _) + && !obj.TryGetPropertyValue("type", out _)) + { + obj.Insert(0, "type", "string"); + } + } + + return node; + }, + }); + schemaNode.Insert(0, "$schema", "https://json-schema.org/draft/2020-12/schema"); + schemaNode.Add("additionalProperties", false); + var schema = JsonSerializer.Serialize(schemaNode, JsonNodeContext.Default.JsonNode); + + ChatMessage? promptAugmentation = null; + options = (options ?? new()).Clone(); + + // Currently there's no way for the inner IChatClient to specify whether structured output + // is supported, so we always default to false. In the future, some mechanism of declaring + // capabilities may be added (e.g., on ChatClientMetadata). + if (useNativeJsonSchema.GetValueOrDefault(false)) + { + // When using native structured output, we don't add any additional prompt, because + // the LLM backend is meant to do whatever's needed to explain the schema to the LLM. + options.ResponseFormat = ChatResponseFormat.ForJsonSchema( + schema, + schemaName: typeof(T).Name, + schemaDescription: typeof(T).GetCustomAttribute()?.Description); + } + else + { + options.ResponseFormat = ChatResponseFormat.Json; + + // When not using native structured output, augment the chat messages with a schema prompt +#pragma warning disable SA1118 // Parameter should not span multiple lines + promptAugmentation = new ChatMessage(ChatRole.System, $$""" + Respond with a JSON value conforming to the following schema: + ``` + {{schema}} + ``` + """); +#pragma warning restore SA1118 // Parameter should not span multiple lines + + chatMessages.Add(promptAugmentation); + } + + try + { + var result = await chatClient.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + return new ChatCompletion(result, serializerOptions); + } + finally + { + if (promptAugmentation is not null) + { + _ = chatMessages.Remove(promptAugmentation); + } + } + } + + private static JsonSerializerOptions DefaultJsonSerializerOptions + { + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + get => _defaultJsonSerializerOptions ?? GetOrCreateDefaultJsonSerializerOptions(); + } + + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + private static JsonSerializerOptions GetOrCreateDefaultJsonSerializerOptions() + { + var options = new JsonSerializerOptions(JsonSerializerDefaults.General) + { + Converters = { new JsonStringEnumConverter() }, + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), + WriteIndented = true, + }; + return Interlocked.CompareExchange(ref _defaultJsonSerializerOptions, options, null) ?? options; + } + + [JsonSerializable(typeof(JsonNode))] + [JsonSourceGenerationOptions(WriteIndented = true)] + private sealed partial class JsonNodeContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs new file mode 100644 index 00000000000..344a01d2c22 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs @@ -0,0 +1,147 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents the result of a chat completion request with structured output. +/// The type of value expected from the chat completion. +/// +/// Language models are not guaranteed to honor the requested schema. If the model's output is not +/// parseable as the expected type, then will return . +/// You can access the underlying JSON response on the property. +/// +public class ChatCompletion : ChatCompletion +{ + private static readonly JsonReaderOptions _allowMultipleValuesJsonReaderOptions = new JsonReaderOptions { AllowMultipleValues = true }; + private readonly JsonSerializerOptions _serializerOptions; + + private T? _deserializedResult; + private bool _hasDeserializedResult; + + /// Initializes a new instance of the class. + /// The unstructured that is being wrapped. + /// The to use when deserializing the result. + public ChatCompletion(ChatCompletion completion, JsonSerializerOptions serializerOptions) + : base(Throw.IfNull(completion).Choices) + { + _serializerOptions = Throw.IfNull(serializerOptions); + CompletionId = completion.CompletionId; + ModelId = completion.ModelId; + CreatedAt = completion.CreatedAt; + FinishReason = completion.FinishReason; + Usage = completion.Usage; + RawRepresentation = completion.RawRepresentation; + AdditionalProperties = completion.AdditionalProperties; + } + + /// + /// Gets the result of the chat completion as an instance of . + /// If the response did not contain JSON, or if deserialization fails, this property will throw. + /// To avoid exceptions, use instead. + /// + public T Result + { + get + { + var result = GetResultCore(out var failureReason); + return failureReason switch + { + FailureReason.ResultDidNotContainJson => throw new InvalidOperationException("The response did not contain text to be deserialized"), + FailureReason.DeserializationProducedNull => throw new InvalidOperationException("The deserialized response is null"), + _ => result!, + }; + } + } + + /// + /// Attempts to deserialize the result to produce an instance of . + /// + /// The result. + /// if the result was produced, otherwise . + public bool TryGetResult([NotNullWhen(true)] out T? result) + { + try + { + result = GetResultCore(out var failureReason); + return failureReason is null; + } +#pragma warning disable CA1031 // Do not catch general exception types + catch + { + result = default; + return false; + } +#pragma warning restore CA1031 // Do not catch general exception types + } + + private static T? DeserializeFirstTopLevelObject(string json, JsonTypeInfo typeInfo) + { + // We need to deserialize only the first top-level object as a workaround for a common LLM backend + // issue. GPT 3.5 Turbo commonly returns multiple top-level objects after doing a function call. + // See https://community.openai.com/t/2-json-objects-returned-when-using-function-calling-and-json-mode/574348 + var utf8ByteLength = Encoding.UTF8.GetByteCount(json); + var buffer = ArrayPool.Shared.Rent(utf8ByteLength); + try + { + var utf8SpanLength = Encoding.UTF8.GetBytes(json, 0, json.Length, buffer, 0); + var utf8Span = new ReadOnlySpan(buffer, 0, utf8SpanLength); + var reader = new Utf8JsonReader(utf8Span, _allowMultipleValuesJsonReaderOptions); + return JsonSerializer.Deserialize(ref reader, typeInfo); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + private string? GetResultAsJson() + { + var choice = Choices.Count == 1 ? Choices[0] : null; + var content = choice?.Contents.Count == 1 ? choice.Contents[0] : null; + return (content as TextContent)?.Text; + } + + private T? GetResultCore(out FailureReason? failureReason) + { + if (_hasDeserializedResult) + { + failureReason = default; + return _deserializedResult; + } + + var json = GetResultAsJson(); + if (string.IsNullOrEmpty(json)) + { + failureReason = FailureReason.ResultDidNotContainJson; + return default; + } + + // If there's an exception here, we want it to propagate, since the Result property is meant to throw directly + var deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo)_serializerOptions.GetTypeInfo(typeof(T))); + if (deserialized is null) + { + failureReason = FailureReason.DeserializationProducedNull; + return default; + } + + _deserializedResult = deserialized; + _hasDeserializedResult = true; + failureReason = default; + return deserialized; + } + + private enum FailureReason + { + ResultDidNotContainJson, + DeserializationProducedNull, + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs new file mode 100644 index 00000000000..a8a4b9269e2 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1629 // Documentation text should end with a period + +namespace Microsoft.Extensions.AI; + +/// A delegating chat client that updates or replaces the used by the remainder of the pipeline. +/// +/// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options +/// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide +/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example +/// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the +/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance +/// and mutating the clone, for example: +/// +/// options => +/// { +/// var newOptions = options?.Clone() ?? new(); +/// newOptions.MaxTokens = 1000; +/// return newOptions; +/// } +/// +/// +public sealed class ConfigureOptionsChatClient : DelegatingChatClient +{ + /// The callback delegate used to configure options. + private readonly Func _configureOptions; + + /// Initializes a new instance of the class with the specified callback. + /// The inner client. + /// + /// The delegate to invoke to configure the instance. It is passed the caller-supplied + /// instance and should return the configured instance to use. + /// + public ConfigureOptionsChatClient(IChatClient innerClient, Func configureOptions) + : base(innerClient) + { + _configureOptions = Throw.IfNull(configureOptions); + } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + return await base.CompleteAsync(chatMessages, _configureOptions(options), cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var update in base.CompleteStreamingAsync(chatMessages, _configureOptions(options), cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..12b903c0dac --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1629 // Documentation text should end with a period + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class ConfigureOptionsChatClientBuilderExtensions +{ + /// + /// Adds a callback that updates or replaces . This can be used to set default options. + /// + /// The . + /// + /// The delegate to invoke to configure the instance. It is passed the caller-supplied + /// instance and should return the configured instance to use. + /// + /// The . + /// + /// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options + /// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide + /// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example + /// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the + /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance + /// and mutating the clone, for example: + /// + /// options => + /// { + /// var newOptions = options?.Clone() ?? new(); + /// newOptions.MaxTokens = 1000; + /// return newOptions; + /// } + /// + /// + public static ChatClientBuilder UseChatOptions( + this ChatClientBuilder builder, Func configureOptions) + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(configureOptions); + + return builder.Use(innerClient => new ConfigureOptionsChatClient(innerClient, configureOptions)); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs new file mode 100644 index 00000000000..65c50c090bd --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that caches the results of completion calls, storing them as JSON in an . +/// +public class DistributedCachingChatClient : CachingChatClient +{ + private readonly IDistributedCache _storage; + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An instance that will be used as the backing store for the cache. + public DistributedCachingChatClient(IChatClient innerClient, IDistributedCache storage) + : base(innerClient) + { + _storage = Throw.IfNull(storage); + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing cache data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + protected override async Task ReadCacheAsync(string key, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _jsonSerializerOptions.MakeReadOnly(); + + if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + { + return (ChatCompletion?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion))); + } + + return null; + } + + /// + protected override async Task?> ReadCacheStreamingAsync(string key, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _jsonSerializerOptions.MakeReadOnly(); + + if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + { + return (IReadOnlyList?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); + } + + return null; + } + + /// + protected override async Task WriteCacheAsync(string key, ChatCompletion value, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _ = Throw.IfNull(value); + _jsonSerializerOptions.MakeReadOnly(); + + var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion))); + await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + } + + /// + protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList value, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _ = Throw.IfNull(value); + _jsonSerializerOptions.MakeReadOnly(); + + var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); + await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + } + + /// + protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) + { + // While it might be desirable to include ChatOptions in the cache key, it's not always possible, + // since ChatOptions can contain types that are not guaranteed to be serializable or have a stable + // hashcode across multiple calls. So the default cache key is simply the JSON representation of + // the chat contents. Developers may subclass and override this to provide custom rules. + _jsonSerializerOptions.MakeReadOnly(); + return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..d465161e1e4 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Extension methods for adding a to an pipeline. +/// +public static class DistributedCachingChatClientBuilderExtensions +{ + /// + /// Adds a as the next stage in the pipeline. + /// + /// The . + /// + /// An optional instance that will be used as the backing store for the cache. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The provided as . + public static ChatClientBuilder UseDistributedCache(this ChatClientBuilder builder, IDistributedCache? storage = null, Action? configure = null) + { + _ = Throw.IfNull(builder); + return builder.Use((services, innerClient) => + { + storage ??= services.GetRequiredService(); + var chatClient = new DistributedCachingChatClient(innerClient, storage); + configure?.Invoke(chatClient); + return chatClient; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs new file mode 100644 index 00000000000..c46d7f43156 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -0,0 +1,639 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that invokes functions defined on . +/// Include this in a chat pipeline to resolve function calls automatically. +/// +/// +/// When this client receives a in a chat completion, it responds +/// by calling the corresponding defined in , +/// producing a . +/// +public class FunctionInvokingChatClient : DelegatingChatClient +{ + /// Maximum number of roundtrips allowed to the inner client. + private int? _maximumIterationsPerRequest; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , or the next instance in a chain of clients. + public FunctionInvokingChatClient(IChatClient innerClient) + : base(innerClient) + { + } + + /// + /// Gets or sets a value indicating whether to handle exceptions that occur during function calls. + /// + /// + /// + /// If the value is , then if a function call fails with an exception, the + /// underlying will be instructed to give a response without invoking + /// any further functions. + /// + /// + /// If the value is , the underlying will be allowed + /// to continue attempting function calls until is reached. + /// + /// + /// The default value is . + /// + /// + public bool RetryOnError { get; set; } + + /// + /// Gets or sets a value indicating whether detailed exception information should be included + /// in the chat history when calling the underlying . + /// + /// + /// + /// The default value is , meaning that only a generic error message will + /// be included in the chat history. This prevents the underlying language model from disclosing + /// raw exception details to the end user, since it does not receive that information. Even in this + /// case, the raw object is available to application code by inspecting + /// the property. + /// + /// + /// If set to , the full exception message will be added to the chat history + /// when calling the underlying . This can help it to bypass problems on + /// its own, for example by retrying the function call with different arguments. However it may + /// result in disclosing the raw exception information to external users, which may be a security + /// concern depending on the application scenario. + /// + /// + public bool DetailedErrors { get; set; } + + /// + /// Gets or sets a value indicating whether to allow concurrent invocation of functions. + /// + /// + /// + /// An individual response from the inner client may contain multiple function call requests. + /// By default, such function calls may be issued to execute concurrently with each other. Set + /// to false to disable such concurrent invocation and force + /// the functions to be invoked serially. + /// + /// + /// The default value is . + /// + /// + public bool ConcurrentInvocation { get; set; } = true; + + /// + /// Gets or sets a value indicating whether to keep intermediate messages in the chat history. + /// + /// + /// When the inner returns to the + /// , the adds + /// those messages to the list of messages, along with instances + /// it creates with the results of invoking the requested functions. The resulting augmented + /// list of messages is then passed to the inner client in order to send the results back. + /// By default, is , and those + /// messages will persist in the list provided to + /// and by the caller. Set + /// to to remove those messages prior to completing the operation. + /// + public bool KeepFunctionCallingMessages { get; set; } = true; + + /// + /// Gets or sets the maximum number of iterations per request. + /// + /// + /// + /// Each request to this may end up making + /// multiple requests to the inner client. Each time the inner client responds with + /// a function call request, this client may perform that invocation and send the results + /// back to the inner client in a new request. This property limits the number of times + /// such a roundtrip is performed. If null, there is no limit applied. If set, the value + /// must be at least one, as it includes the initial request. + /// + /// + /// The default value is . + /// + /// + public int? MaximumIterationsPerRequest + { + get => _maximumIterationsPerRequest; + set + { + if (value < 1) + { + Throw.ArgumentOutOfRangeException(nameof(value)); + } + + _maximumIterationsPerRequest = value; + } + } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + ChatCompletion? response; + + HashSet? messagesToRemove = null; + HashSet? contentsToRemove = null; + try + { + for (int iteration = 0; ; iteration++) + { + // Make the call to the handler. + response = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + + // If there are no tools to call, or for any other reason we should stop, return the response. + if (options is null + || options.Tools is not { Count: > 0 } + || response.Choices.Count == 0 + || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) + { + break; + } + + // If there's more than one choice, we don't know which one to add to chat history, or which + // of their function calls to process. This should not happen except if the developer has + // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer + // doesn't realize this and is wasting their budget requesting extra choices we'd never use. + if (response.Choices.Count > 1) + { + throw new InvalidOperationException($"Automatic function call invocation only accepts a single choice, but {response.Choices.Count} choices were received."); + } + + // Extract any function call contents on the first choice. If there are none, we're done. + // We don't have any way to express a preference to use a different choice, since this + // is a niche case especially with function calling. + FunctionCallContent[] functionCallContents = response.Message.Contents.OfType().ToArray(); + if (functionCallContents.Length == 0) + { + break; + } + + // Track all added messages in order to remove them, if requested. + if (!KeepFunctionCallingMessages) + { + messagesToRemove ??= []; + } + + // Add the original response message into the history and track the message for removal. + chatMessages.Add(response.Message); + if (messagesToRemove is not null) + { + if (functionCallContents.Length == response.Message.Contents.Count) + { + // The most common case is that the response message contains only function calling content. + // In that case, we can just track the whole message for removal. + _ = messagesToRemove.Add(response.Message); + } + else + { + // In the less likely case where some content is function calling and some isn't, we don't want to remove + // the non-function calling content by removing the whole message. So we track the content directly. + (contentsToRemove ??= []).UnionWith(functionCallContents); + } + } + + // Add the responses from the function calls into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + if (modeAndMessages.MessagesAdded is not null) + { + messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded); + } + + switch (modeAndMessages.Mode) + { + case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: + // We have to reset this after the first iteration, otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = ChatToolMode.Auto; + break; + + case ContinueMode.AllowOneMoreRoundtrip: + // The LLM gets one further chance to answer, but cannot use tools. + options = options.Clone(); + options.Tools = null; + break; + + case ContinueMode.Terminate: + // Bail immediately. + return response; + } + } + + return response!; + } + finally + { + RemoveMessagesAndContentFromList(messagesToRemove, contentsToRemove, chatMessages); + } + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + HashSet? messagesToRemove = null; + try + { + for (int iteration = 0; ; iteration++) + { + List? functionCallContents = null; + await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + // We're going to emit all StreamingChatMessage items upstream, even ones that represent + // function calls, because a given StreamingChatMessage can contain other content too. + yield return chunk; + + foreach (var item in chunk.Contents.OfType()) + { + functionCallContents ??= []; + functionCallContents.Add(item); + } + } + + // If there are no tools to call, or for any other reason we should stop, return the response. + if (options is null + || options.Tools is not { Count: > 0 } + || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations) + || functionCallContents is not { Count: > 0 }) + { + break; + } + + // Track all added messages in order to remove them, if requested. + if (!KeepFunctionCallingMessages) + { + messagesToRemove ??= []; + } + + // Add a manufactured response message containing the function call contents to the chat history. + ChatMessage functionCallMessage = new(ChatRole.Assistant, [.. functionCallContents]); + chatMessages.Add(functionCallMessage); + _ = messagesToRemove?.Add(functionCallMessage); + + // Process all of the functions, adding their results into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + if (modeAndMessages.MessagesAdded is not null) + { + messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded); + } + + // Decide how to proceed based on the result of the function calls. + switch (modeAndMessages.Mode) + { + case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: + // We have to reset this after the first iteration, otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = ChatToolMode.Auto; + break; + + case ContinueMode.AllowOneMoreRoundtrip: + // The LLM gets one further chance to answer, but cannot use tools. + options = options.Clone(); + options.Tools = null; + break; + + case ContinueMode.Terminate: + // Bail immediately. + yield break; + } + } + } + finally + { + RemoveMessagesAndContentFromList(messagesToRemove, contentToRemove: null, chatMessages); + } + } + + /// + /// Removes all of the messages in from + /// and all of the content in from the messages in . + /// + private static void RemoveMessagesAndContentFromList( + HashSet? messagesToRemove, + HashSet? contentToRemove, + IList messages) + { + Debug.Assert( + contentToRemove is null || messagesToRemove is not null, + "We should only be tracking content to remove if we're also tracking messages to remove."); + + if (messagesToRemove is not null) + { + for (int m = messages.Count - 1; m >= 0; m--) + { + ChatMessage message = messages[m]; + + if (contentToRemove is not null) + { + for (int c = message.Contents.Count - 1; c >= 0; c--) + { + if (contentToRemove.Contains(message.Contents[c])) + { + message.Contents.RemoveAt(c); + } + } + } + + if (messages.Count == 0 || messagesToRemove.Contains(messages[m])) + { + messages.RemoveAt(m); + } + } + } + } + + /// + /// Processes the function calls in the list. + /// + /// The current chat contents, inclusive of the function call contents being processed. + /// The options used for the response being processed. + /// The function call contents representing the functions to be invoked. + /// The iteration number of how many roundtrips have been made to the inner client. + /// The to monitor for cancellation requests. + /// A value indicating how the caller should proceed. + private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( + IList chatMessages, ChatOptions options, IReadOnlyList functionCallContents, int iteration, CancellationToken cancellationToken) + { + // We must add a response for every tool call, regardless of whether we successfully executed it or not. + // If we successfully execute it, we'll add the result. If we don't, we'll add an error. + + int functionCount = functionCallContents.Count; + Debug.Assert(functionCount > 0, $"Expecteded {nameof(functionCount)} to be > 0, got {functionCount}."); + + // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. + if (functionCount == 1) + { + FunctionInvocationResult result = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[0], iteration, 0, 1, cancellationToken).ConfigureAwait(false); + IList added = AddResponseMessages(chatMessages, [result]); + return (result.ContinueMode, added); + } + else + { + FunctionInvocationResult[] results; + + if (ConcurrentInvocation) + { + // Schedule the invocation of every function. + results = await Task.WhenAll( + from i in Enumerable.Range(0, functionCount) + select Task.Run(() => ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken))).ConfigureAwait(false); + } + else + { + // Invoke each function serially. + results = new FunctionInvocationResult[functionCount]; + for (int i = 0; i < functionCount; i++) + { + results[i] = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken).ConfigureAwait(false); + } + } + + ContinueMode continueMode = ContinueMode.Continue; + IList added = AddResponseMessages(chatMessages, results); + foreach (FunctionInvocationResult fir in results) + { + if (fir.ContinueMode > continueMode) + { + continueMode = fir.ContinueMode; + } + } + + return (continueMode, added); + } + } + + /// Processes the function call described in . + /// The current chat contents, inclusive of the function call contents being processed. + /// The options used for the response being processed. + /// The function call content representing the function to be invoked. + /// The iteration number of how many roundtrips have been made to the inner client. + /// The 0-based index of the function being called out of total functions. + /// The number of function call requests made, of which this is one. + /// The to monitor for cancellation requests. + /// A value indicating how the caller should proceed. + private async Task ProcessFunctionCallAsync( + IList chatMessages, ChatOptions options, FunctionCallContent functionCallContent, + int iteration, int functionCallIndex, int totalFunctionCount, CancellationToken cancellationToken) + { + // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. + AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Metadata.Name == functionCallContent.Name); + if (function is null) + { + return new(ContinueMode.Continue, FunctionStatus.NotFound, functionCallContent, result: null, exception: null); + } + + FunctionInvocationContext context = new(chatMessages, functionCallContent, function) + { + Iteration = iteration, + FunctionCallIndex = functionCallIndex, + FunctionCount = totalFunctionCount, + }; + + try + { + object? result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); + return new( + context.Terminate ? ContinueMode.Terminate : ContinueMode.Continue, + FunctionStatus.CompletedSuccessfully, + functionCallContent, + result, + exception: null); + } + catch (Exception e) when (!cancellationToken.IsCancellationRequested) + { + return new( + RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer. + FunctionStatus.Failed, + functionCallContent, + result: null, + exception: e); + } + } + + /// Represents the return value of , dictating how the loop should behave. + /// These values are ordered from least severe to most severe, and code explicitly depends on the ordering. + internal enum ContinueMode + { + /// Send back the responses and continue processing. + Continue = 0, + + /// Send back the response but without any tools. + AllowOneMoreRoundtrip = 1, + + /// Immediately exit the function calling loop. + Terminate = 2, + } + + /// Adds one or more response messages for function invocation results. + /// The chat to which to add the one or more response messages. + /// Information about the function call invocations and results. + /// A list of all chat messages added to . + protected virtual IList AddResponseMessages(IList chat, ReadOnlySpan results) + { + _ = Throw.IfNull(chat); + + var contents = new AIContent[results.Length]; + for (int i = 0; i < results.Length; i++) + { + contents[i] = CreateFunctionResultContent(results[i]); + } + + ChatMessage message = new(ChatRole.Tool, contents); + chat.Add(message); + return [message]; + + FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) + { + _ = Throw.IfNull(result); + + object? functionResult; + if (result.Status == FunctionStatus.CompletedSuccessfully) + { + functionResult = result.Result ?? "Success: Function completed."; + } + else + { + string message = result.Status switch + { + FunctionStatus.NotFound => "Error: Requested function not found.", + FunctionStatus.Failed => "Error: Function failed.", + _ => "Error: Unknown error.", + }; + + if (DetailedErrors && result.Exception is not null) + { + message = $"{message} Exception: {result.Exception.Message}"; + } + + functionResult = message; + } + + return new FunctionResultContent(result.CallContent.CallId, result.CallContent.Name, functionResult, result.Exception); + } + } + + /// Invokes the function asynchronously. + /// + /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. + /// + /// The to monitor for cancellation requests. The default is . + /// The result of the function invocation. This may be null if the function invocation returned null. + protected virtual Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) + { + _ = Throw.IfNull(context); + + return context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken); + } + + /// Provides context for a function invocation. + public sealed class FunctionInvocationContext + { + /// Initializes a new instance of the class. + /// The chat contents associated with the operation that initiated this function call request. + /// The AI function to be invoked. + /// The function call content information associated with this invocation. + internal FunctionInvocationContext( + IList chatMessages, + FunctionCallContent functionCallContent, + AIFunction function) + { + Function = function; + CallContent = functionCallContent; + ChatMessages = chatMessages; + } + + /// Gets or sets the AI function to be invoked. + public AIFunction Function { get; set; } + + /// Gets or sets the function call content information associated with this invocation. + public FunctionCallContent CallContent { get; set; } + + /// Gets or sets the chat contents associated with the operation that initiated this function call request. + public IList ChatMessages { get; set; } + + /// Gets or sets the number of this iteration with the underlying client. + /// + /// The initial request to the client that passes along the chat contents provided to the + /// is iteration 1. If the client responds with a function call request, the next request to the client is iteration 2, and so on. + /// + public int Iteration { get; set; } + + /// Gets or sets the index of the function call within the iteration. + /// + /// The response from the underlying client may include multiple function call requests. + /// This index indicates the position of the function call within the iteration. + /// + public int FunctionCallIndex { get; set; } + + /// Gets or sets the total number of function call requests within the iteration. + /// + /// The response from the underlying client may include multiple function call requests. + /// This count indicates how many there were. + /// + public int FunctionCount { get; set; } + + /// Gets or sets a value indicating whether to terminate the request. + /// + /// In response to a function call request, the function may be invoked, its result added to the chat contents, + /// and a new request issued to the wrapped client. If this property is set to true, that subsequent request + /// will not be issued and instead the loop immediately terminated rather than continuing until there are no + /// more function call requests in responses. + /// + public bool Terminate { get; set; } + } + + /// Provides information about the invocation of a function call. + public sealed class FunctionInvocationResult + { + internal FunctionInvocationResult(ContinueMode continueMode, FunctionStatus status, FunctionCallContent callContent, object? result, Exception? exception) + { + ContinueMode = continueMode; + Status = status; + CallContent = callContent; + Result = result; + Exception = exception; + } + + /// Gets status about how the function invocation completed. + public FunctionStatus Status { get; } + + /// Gets the function call content information associated with this invocation. + public FunctionCallContent CallContent { get; } + + /// Gets the result of the function call. + public object? Result { get; } + + /// Gets any exception the function call threw. + public Exception? Exception { get; } + + /// Gets an indication for how the caller should continue the processing loop. + internal ContinueMode ContinueMode { get; } + } + + /// Provides error codes for when errors occur as part of the function calling loop. + public enum FunctionStatus + { + /// The operation completed successfully. + CompletedSuccessfully, + + /// The requested function could not be found. + NotFound, + + /// The function call failed with an exception. + Failed, + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..15010b42068 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides extension methods for attaching a to a chat pipeline. +/// +public static class FunctionInvokingChatClientBuilderExtensions +{ + /// + /// Enables automatic function call invocation on the chat pipeline. + /// + /// This works by adding an instance of with default options. + /// The being used to build the chat pipeline. + /// An optional callback that can be used to configure the instance. + /// The supplied . + public static ChatClientBuilder UseFunctionInvocation(this ChatClientBuilder builder, Action? configure = null) + { + _ = Throw.IfNull(builder); + + return builder.Use(innerClient => + { + var chatClient = new FunctionInvokingChatClient(innerClient); + configure?.Invoke(chatClient); + return chatClient; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs new file mode 100644 index 00000000000..f0a9e8a0d75 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs @@ -0,0 +1,154 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable EA0000 // Use source generated logging methods for improved performance +#pragma warning disable CA2254 // Template should be a static expression + +namespace Microsoft.Extensions.AI; + +/// A delegating chat client that logs chat operations to an . +public class LoggingChatClient : DelegatingChatClient +{ + /// An instance used for all logging. + private readonly ILogger _logger; + + /// The to use for serialization of state written to the logger. + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An instance that will be used for all logging. + public LoggingChatClient(IChatClient innerClient, ILogger logger) + : base(innerClient) + { + _logger = Throw.IfNull(logger); + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing logging data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + public override async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + LogStart(chatMessages, options); + try + { + var completion = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (completion, _jsonSerializerOptions), null, static (state, _) => + $"CompleteAsync completed: {JsonSerializer.Serialize(state.completion, state._jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion)))}"); + } + else + { + _logger.LogDebug("CompleteAsync completed."); + } + } + + return completion; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "CompleteAsync failed."); + throw; + } + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + LogStart(chatMessages, options); + + IAsyncEnumerator e; + try + { + e = base.CompleteStreamingAsync(chatMessages, options, cancellationToken).GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "CompleteStreamingAsync failed."); + throw; + } + + try + { + StreamingChatCompletionUpdate? update = null; + while (true) + { + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + break; + } + + update = e.Current; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "CompleteStreamingAsync failed."); + throw; + } + + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (update, _jsonSerializerOptions), null, static (state, _) => + $"CompleteStreamingAsync received update: {JsonSerializer.Serialize(state.update, state._jsonSerializerOptions.GetTypeInfo(typeof(StreamingChatCompletionUpdate)))}"); + } + else + { + _logger.LogDebug("CompleteStreamingAsync received update."); + } + } + + yield return update; + } + + _logger.LogDebug("CompleteStreamingAsync completed."); + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + private void LogStart(IList chatMessages, ChatOptions? options, [CallerMemberName] string? methodName = null) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (methodName, chatMessages, options, this), null, static (state, _) => + $"{state.methodName} invoked: " + + $"Messages: {JsonSerializer.Serialize(state.chatMessages, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(IList)))}. " + + $"Options: {JsonSerializer.Serialize(state.options, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatOptions)))}. " + + $"Metadata: {JsonSerializer.Serialize(state.Item4.Metadata, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatClientMetadata)))}."); + } + else + { + _logger.LogDebug($"{methodName} invoked."); + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..056ba5401fc --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class LoggingChatClientBuilderExtensions +{ + /// Adds logging to the chat client pipeline. + /// The . + /// + /// An optional with which logging should be performed. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The . + public static ChatClientBuilder UseLogging( + this ChatClientBuilder builder, ILogger? logger = null, Action? configure = null) + { + _ = Throw.IfNull(builder); + + return builder.Use((services, innerClient) => + { + logger ??= services.GetRequiredService().CreateLogger(nameof(LoggingChatClient)); + var chatClient = new LoggingChatClient(innerClient, logger); + configure?.Invoke(chatClient); + return chatClient; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs new file mode 100644 index 00000000000..13e2d1229dd --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -0,0 +1,509 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A delegating chat client that implements the OpenTelemetry Semantic Conventions for Generative AI systems. +/// +/// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. +/// The specification is still experimental and subject to change; as such, the telemetry output by this client is also subject to change. +/// +public sealed class OpenTelemetryChatClient : DelegatingChatClient +{ + private readonly ActivitySource _activitySource; + private readonly Meter _meter; + + private readonly Histogram _tokenUsageHistogram; + private readonly Histogram _operationDurationHistogram; + + private readonly string? _modelId; + private readonly string? _modelProvider; + private readonly string? _endpointAddress; + private readonly int _endpointPort; + + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An optional source name that will be used on the telemetry data. + public OpenTelemetryChatClient(IChatClient innerClient, string? sourceName = null) + : base(innerClient) + { + Debug.Assert(innerClient is not null, "Should have been validated by the base ctor"); + + ChatClientMetadata metadata = innerClient!.Metadata; + _modelId = metadata.ModelId; + _modelProvider = metadata.ProviderName; + _endpointAddress = metadata.ProviderUri?.GetLeftPart(UriPartial.Path); + _endpointPort = metadata.ProviderUri?.Port ?? 0; + + string name = string.IsNullOrEmpty(sourceName) ? OpenTelemetryConsts.DefaultSourceName : sourceName!; + _activitySource = new(name); + _meter = new(name); + + _tokenUsageHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.TokenUsage.Name, + OpenTelemetryConsts.TokensUnit, + OpenTelemetryConsts.GenAI.Client.TokenUsage.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.TokenUsage.ExplicitBucketBoundaries }); + + _operationDurationHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.OperationDuration.Name, + OpenTelemetryConsts.SecondsUnit, + OpenTelemetryConsts.GenAI.Client.OperationDuration.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.OperationDuration.ExplicitBucketBoundaries }); + + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when formatting chat data into telemetry strings. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + _activitySource.Dispose(); + _meter.Dispose(); + } + + base.Dispose(disposing); + } + + /// + /// Gets or sets a value indicating whether potentially sensitive information (e.g. prompts) should be included in telemetry. + /// + /// + /// The value is by default, meaning that telemetry will include metadata such as token counts but not the raw text of prompts or completions. + /// + public bool EnableSensitiveData { get; set; } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _jsonSerializerOptions.MakeReadOnly(); + + using Activity? activity = StartActivity(chatMessages, options); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + string? requestModelId = options?.ModelId ?? _modelId; + + ChatCompletion? response = null; + Exception? error = null; + try + { + response = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + SetCompletionResponse(activity, requestModelId, response, error, stopwatch); + } + + return response; + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _jsonSerializerOptions.MakeReadOnly(); + + using Activity? activity = StartActivity(chatMessages, options); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + string? requestModelId = options?.ModelId ?? _modelId; + + IAsyncEnumerable response; + try + { + response = base.CompleteStreamingAsync(chatMessages, options, cancellationToken); + } + catch (Exception ex) + { + SetCompletionResponse(activity, requestModelId, null, ex, stopwatch); + throw; + } + + var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); + List? streamedContents = activity is not null ? [] : null; + try + { + while (true) + { + StreamingChatCompletionUpdate update; + try + { + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + + update = responseEnumerator.Current; + } + catch (Exception ex) + { + SetCompletionResponse(activity, requestModelId, null, ex, stopwatch); + throw; + } + + streamedContents?.Add(update); + yield return update; + } + } + finally + { + if (activity is not null) + { + UsageContent? usageContent = streamedContents?.SelectMany(c => c.Contents).OfType().LastOrDefault(); + SetCompletionResponse( + activity, + stopwatch, + requestModelId, + OrganizeStreamingContent(streamedContents), + streamedContents?.SelectMany(c => c.Contents).OfType(), + usage: usageContent?.Details); + } + + await responseEnumerator.DisposeAsync(); + } + } + + /// Gets a value indicating whether diagnostics are enabled. + private bool Enabled => _activitySource.HasListeners(); + + /// Convert chat history to a string aligned with the OpenAI format. + private static string ToOpenAIFormat(IEnumerable messages, JsonSerializerOptions serializerOptions) + { + var sb = new StringBuilder().Append('['); + + string messageSeparator = string.Empty; + foreach (var message in messages) + { + _ = sb.Append(messageSeparator); + messageSeparator = ", \n"; + + string text = string.Concat(message.Contents.OfType().Select(c => c.Text)); + _ = sb.Append("{\"role\": \"").Append(message.Role).Append("\", \"content\": ").Append(JsonSerializer.Serialize(text, serializerOptions.GetTypeInfo(typeof(string)))); + + if (message.Contents.OfType().Any()) + { + _ = sb.Append(", \"tool_calls\": ").Append('['); + + string messageItemSeparator = string.Empty; + foreach (var functionCall in message.Contents.OfType()) + { + _ = sb.Append(messageItemSeparator); + messageItemSeparator = ", \n"; + + _ = sb.Append("{\"id\": \"").Append(functionCall.CallId) + .Append("\", \"function\": {\"arguments\": ").Append(JsonSerializer.Serialize(functionCall.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary)))) + .Append(", \"name\": \"").Append(functionCall.Name) + .Append("\"}, \"type\": \"function\"}"); + } + + _ = sb.Append(']'); + } + + _ = sb.Append('}'); + } + + _ = sb.Append(']'); + return sb.ToString(); + } + + /// Organize streaming content by choice index. + private static Dictionary> OrganizeStreamingContent(IEnumerable? contents) + { + Dictionary> choices = []; + if (contents is null) + { + return choices; + } + + foreach (var content in contents) + { + if (!choices.TryGetValue(content.ChoiceIndex, out var choiceContents)) + { + choices[content.ChoiceIndex] = choiceContents = []; + } + + choiceContents.Add(content); + } + + return choices; + } + + /// Creates an activity for a chat completion request, or returns null if not enabled. + private Activity? StartActivity(IList chatMessages, ChatOptions? options) + { + Activity? activity = null; + if (Enabled) + { + string? modelId = options?.ModelId ?? _modelId; + + activity = _activitySource.StartActivity( + $"chat.completions {modelId}", + ActivityKind.Client, + default(ActivityContext), + [ + new(OpenTelemetryConsts.GenAI.Operation.Name, "chat"), + new(OpenTelemetryConsts.GenAI.Request.Model, modelId), + new(OpenTelemetryConsts.GenAI.System, _modelProvider), + ]); + + if (activity is not null) + { + if (_endpointAddress is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Server.Address, _endpointAddress) + .SetTag(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + if (options is not null) + { + if (options.FrequencyPenalty is float frequencyPenalty) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.FrequencyPenalty, frequencyPenalty); + } + + if (options.MaxOutputTokens is int maxTokens) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.MaxTokens, maxTokens); + } + + if (options.PresencePenalty is float presencePenalty) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.PresencePenalty, presencePenalty); + } + + if (options.StopSequences is IList stopSequences) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.StopSequences, $"[{string.Join(", ", stopSequences.Select(s => $"\"{s}\""))}]"); + } + + if (options.Temperature is float temperature) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.Temperature, temperature); + } + + if (options.AdditionalProperties?.TryGetConvertedValue("top_k", out double topK) is true) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.TopK, topK); + } + + if (options.TopP is float top_p) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.TopP, top_p); + } + } + + if (EnableSensitiveData) + { + _ = activity.AddEvent(new ActivityEvent( + OpenTelemetryConsts.GenAI.Content.Prompt, + tags: new ActivityTagsCollection([new(OpenTelemetryConsts.GenAI.Prompt, ToOpenAIFormat(chatMessages, _jsonSerializerOptions))]))); + } + } + } + + return activity; + } + + /// Adds chat completion information to the activity. + private void SetCompletionResponse( + Activity? activity, + string? requestModelId, + ChatCompletion? completions, + Exception? error, + Stopwatch? stopwatch) + { + if (!Enabled) + { + return; + } + + if (_operationDurationHistogram.Enabled && stopwatch is not null) + { + TagList tags = default; + + AddMetricTags(ref tags, requestModelId, completions); + if (error is not null) + { + tags.Add(OpenTelemetryConsts.Error.Type, error.GetType().FullName); + } + + _operationDurationHistogram.Record(stopwatch.Elapsed.TotalSeconds, tags); + } + + if (_tokenUsageHistogram.Enabled && completions?.Usage is { } usage) + { + if (usage.InputTokenCount is int inputTokens) + { + TagList tags = default; + tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "input"); + AddMetricTags(ref tags, requestModelId, completions); + _tokenUsageHistogram.Record(inputTokens); + } + + if (usage.OutputTokenCount is int outputTokens) + { + TagList tags = default; + tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "output"); + AddMetricTags(ref tags, requestModelId, completions); + _tokenUsageHistogram.Record(outputTokens); + } + } + + if (activity is null) + { + return; + } + + if (error is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Error.Type, error.GetType().FullName) + .SetStatus(ActivityStatusCode.Error, error.Message); + return; + } + + if (completions is not null) + { + if (completions.FinishReason is ChatFinishReason finishReason) + { +#pragma warning disable CA1308 // Normalize strings to uppercase + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.FinishReasons, $"[\"{finishReason.Value.ToLowerInvariant()}\"]"); +#pragma warning restore CA1308 + } + + if (!string.IsNullOrWhiteSpace(completions.CompletionId)) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.Id, completions.CompletionId); + } + + if (completions.ModelId is not null) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.Model, completions.ModelId); + } + + if (completions.Usage?.InputTokenCount is int inputTokens) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.InputTokens, inputTokens); + } + + if (completions.Usage?.OutputTokenCount is int outputTokens) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.OutputTokens, outputTokens); + } + + if (EnableSensitiveData) + { + _ = activity.AddEvent(new ActivityEvent( + OpenTelemetryConsts.GenAI.Content.Completion, + tags: new ActivityTagsCollection([new(OpenTelemetryConsts.GenAI.Completion, ToOpenAIFormat(completions.Choices, _jsonSerializerOptions))]))); + } + } + } + + /// Adds streaming chat completion information to the activity. + private void SetCompletionResponse( + Activity? activity, + Stopwatch? stopwatch, + string? requestModelId, + Dictionary> choices, + IEnumerable? toolCalls, + UsageDetails? usage) + { + if (activity is null || !Enabled || choices.Count == 0) + { + return; + } + + string? id = null; + ChatFinishReason? finishReason = null; + string? modelId = null; + List messages = new(choices.Count); + + foreach (var choice in choices) + { + ChatRole? role = null; + List items = []; + foreach (var update in choice.Value) + { + id ??= update.CompletionId; + role ??= update.Role; + finishReason ??= update.FinishReason; + foreach (AIContent content in update.Contents) + { + items.Add(content); + modelId ??= content.ModelId; + } + } + + messages.Add(new ChatMessage(role ?? ChatRole.Assistant, items)); + } + + if (toolCalls is not null && messages.FirstOrDefault()?.Contents is { } c) + { + foreach (var functionCall in toolCalls) + { + c.Add(functionCall); + } + } + + ChatCompletion completion = new(messages) + { + CompletionId = id, + FinishReason = finishReason, + ModelId = modelId, + Usage = usage, + }; + + SetCompletionResponse(activity, requestModelId, completion, error: null, stopwatch); + } + + private void AddMetricTags(ref TagList tags, string? requestModelId, ChatCompletion? completions) + { + tags.Add(OpenTelemetryConsts.GenAI.Operation.Name, "chat"); + + if (requestModelId is not null) + { + tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModelId); + } + + tags.Add(OpenTelemetryConsts.GenAI.System, _modelProvider); + + if (_endpointAddress is string endpointAddress) + { + tags.Add(OpenTelemetryConsts.Server.Address, endpointAddress); + tags.Add(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + if (completions?.ModelId is string responseModel) + { + tags.Add(OpenTelemetryConsts.GenAI.Response.Model, responseModel); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..bf1ff4e9f0d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClientBuilderExtensions.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class OpenTelemetryChatClientBuilderExtensions +{ + /// + /// Adds OpenTelemetry support to the chat client pipeline, following the OpenTelemetry Semantic Conventions for Generative AI systems. + /// + /// + /// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. + /// The specification is still experimental and subject to change; as such, the telemetry output by this client is also subject to change. + /// + /// The . + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The . + public static ChatClientBuilder UseOpenTelemetry( + this ChatClientBuilder builder, string? sourceName = null, Action? configure = null) => + Throw.IfNull(builder).Use(innerClient => + { + var chatClient = new OpenTelemetryChatClient(innerClient, sourceName); + configure?.Invoke(chatClient); + return chatClient; + }); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs new file mode 100644 index 00000000000..8438d467eb6 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A delegating embedding generator that caches the results of embedding generation calls. +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +public abstract class CachingEmbeddingGenerator : DelegatingEmbeddingGenerator + where TEmbedding : Embedding +{ + /// Initializes a new instance of the class. + /// The underlying . + protected CachingEmbeddingGenerator(IEmbeddingGenerator innerGenerator) + : base(innerGenerator) + { + } + + /// + public override async Task> GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(values); + + // Optimize for the common-case of a single value in a list/array. + if (values is IList valuesList) + { + switch (valuesList.Count) + { + case 0: + return []; + + case 1: + // In the expected common case where we can cheaply tell there's only a single value and access it, + // we can avoid all the overhead of splitting the list and reassembling it. + var cacheKey = GetCacheKey(valuesList[0], options); + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding e) + { + return [e]; + } + else + { + var generated = await base.GenerateAsync(valuesList, options, cancellationToken).ConfigureAwait(false); + if (generated.Count != 1) + { + throw new InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); + } + + await WriteCacheAsync(cacheKey, generated[0], cancellationToken).ConfigureAwait(false); + return generated; + } + } + } + + // Some of the inputs may already be cached. Go through each, checking to see whether each individually is cached. + // Split those that are cached into one list and those that aren't into another. We retain their original positions + // so that we can reassemble the results in the correct order. + GeneratedEmbeddings results = []; + List<(int Index, string CacheKey, TInput Input)>? uncached = null; + foreach (TInput input in values) + { + // We're only storing the final result, not the in-flight task, so that we can avoid caching failures + // or having problems when one of the callers cancels but others don't. This has the drawback that + // concurrent callers might trigger duplicate requests, but that's acceptable. + var cacheKey = GetCacheKey(input, options); + + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding existing) + { + results.Add(existing); + } + else + { + (uncached ??= []).Add((results.Count, cacheKey, input)); + results.Add(null!); // temporary placeholder + } + } + + // If anything wasn't cached, we need to generate embeddings for those. + if (uncached is not null) + { + // Now make a single call to the wrapped generator to generate embeddings for all of the uncached inputs. + var uncachedResults = await base.GenerateAsync(uncached.Select(e => e.Input), options, cancellationToken).ConfigureAwait(false); + + // Store the resulting embeddings into the cache individually. + for (int i = 0; i < uncachedResults.Count; i++) + { + await WriteCacheAsync(uncached[i].CacheKey, uncachedResults[i], cancellationToken).ConfigureAwait(false); + } + + // Fill in the gaps with the newly generated results. + for (int i = 0; i < uncachedResults.Count; i++) + { + results[uncached[i].Index] = uncachedResults[i]; + } + } + + Debug.Assert(results.All(e => e is not null), "Expected all values to be non-null"); + return results; + } + + /// + /// Computes a cache key for the specified call parameters. + /// + /// The for which an embedding is being requested. + /// The options to configure the request. + /// A string that will be used as a cache key. + protected abstract string GetCacheKey(TInput value, EmbeddingGenerationOptions? options); + + /// Returns a previously cached , if available. + /// The cache key. + /// The to monitor for cancellation requests. + /// The previously cached data, if available, otherwise . + protected abstract Task ReadCacheAsync(string key, CancellationToken cancellationToken); + + /// Stores a in the underlying cache. + /// The cache key. + /// The to be stored. + /// The to monitor for cancellation requests. + /// A representing the completion of the operation. + protected abstract Task WriteCacheAsync(string key, TEmbedding value, CancellationToken cancellationToken); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs new file mode 100644 index 00000000000..932bb2f91b8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating embedding generator that caches the results of embedding generation calls, +/// storing them as JSON in an . +/// +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +public class DistributedCachingEmbeddingGenerator : CachingEmbeddingGenerator + where TEmbedding : Embedding +{ + private readonly IDistributedCache _storage; + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// A instance that will be used as the backing store for the cache. + public DistributedCachingEmbeddingGenerator(IEmbeddingGenerator innerGenerator, IDistributedCache storage) + : base(innerGenerator) + { + _ = Throw.IfNull(storage); + _storage = storage; + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing cache data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set + { + _ = Throw.IfNull(value); + _jsonSerializerOptions = value; + } + } + + /// + protected override async Task ReadCacheAsync(string key, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _jsonSerializerOptions.MakeReadOnly(); + + if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + { + return JsonSerializer.Deserialize(existingJson, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); + } + + return null; + } + + /// + protected override async Task WriteCacheAsync(string key, TEmbedding value, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _ = Throw.IfNull(value); + _jsonSerializerOptions.MakeReadOnly(); + + var newJson = JsonSerializer.SerializeToUtf8Bytes(value, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); + await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + } + + /// + protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) + { + // While it might be desirable to include options in the cache key, it's not always possible, + // since options can contain types that are not guaranteed to be serializable or have a stable + // hashcode across multiple calls. So the default cache key is simply the JSON representation of + // the value. Developers may subclass and override this to provide custom rules. + return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs new file mode 100644 index 00000000000..77aaa30e05d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Extension methods for adding a to an +/// pipeline. +/// +public static class DistributedCachingEmbeddingGeneratorBuilderExtensions +{ + /// + /// Adds a as the next stage in the pipeline. + /// + /// The type from which embeddings will be generated. + /// The type of embeddings to generate. + /// The . + /// + /// An optional instance that will be used as the backing store for the cache. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The provided as . + public static EmbeddingGeneratorBuilder UseDistributedCache( + this EmbeddingGeneratorBuilder builder, + IDistributedCache? storage = null, + Action>? configure = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(builder); + return builder.Use((services, innerGenerator) => + { + storage ??= services.GetRequiredService(); + var result = new DistributedCachingEmbeddingGenerator(innerGenerator, storage); + configure?.Invoke(result); + return result; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs new file mode 100644 index 00000000000..96c4c92d4a9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A builder for creating pipelines of . +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +public sealed class EmbeddingGeneratorBuilder + where TEmbedding : Embedding +{ + /// The registered client factory instances. + private List, IEmbeddingGenerator>>? _generatorFactories; + + /// Initializes a new instance of the class. + /// The service provider to use for dependency injection. + public EmbeddingGeneratorBuilder(IServiceProvider? services = null) + { + Services = services ?? EmptyServiceProvider.Instance; + } + + /// Gets the associated with the builder instance. + public IServiceProvider Services { get; } + + /// + /// Builds an instance of using the specified inner generator. + /// + /// The inner generator to use. + /// An instance of . + /// + /// If there are any factories registered with this builder, is used as a seed to + /// the last factory, and the result of each factory delegate is passed to the previously registered factory. + /// The final result is then returned from this call. + /// + public IEmbeddingGenerator Use(IEmbeddingGenerator innerGenerator) + { + var embeddingGenerator = Throw.IfNull(innerGenerator); + + // To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost. + if (_generatorFactories is not null) + { + for (var i = _generatorFactories.Count - 1; i >= 0; i--) + { + embeddingGenerator = _generatorFactories[i](Services, embeddingGenerator) ?? + throw new InvalidOperationException( + $"The {nameof(IEmbeddingGenerator)} entry at index {i} returned null. " + + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IEmbeddingGenerator)} instances."); + } + } + + return embeddingGenerator; + } + + /// Adds a factory for an intermediate embedding generator to the embedding generator pipeline. + /// The generator factory function. + /// The updated instance. + public EmbeddingGeneratorBuilder Use(Func, IEmbeddingGenerator> generatorFactory) + { + _ = Throw.IfNull(generatorFactory); + + return Use((_, innerGenerator) => generatorFactory(innerGenerator)); + } + + /// Adds a factory for an intermediate embedding generator to the embedding generator pipeline. + /// The generator factory function. + /// The updated instance. + public EmbeddingGeneratorBuilder Use(Func, IEmbeddingGenerator> generatorFactory) + { + _ = Throw.IfNull(generatorFactory); + + _generatorFactories ??= []; + _generatorFactories.Add(generatorFactory); + return this; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs new file mode 100644 index 00000000000..369de130e72 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for registering with a . +public static class EmbeddingGeneratorBuilderServiceCollectionExtensions +{ + /// Adds a embedding generator to the . + /// The type from which embeddings will be generated. + /// The type of embeddings to generate. + /// The to which the generator should be added. + /// The factory to use to construct the instance. + /// The collection. + /// The generator is registered as a scoped service. + public static IServiceCollection AddEmbeddingGenerator( + this IServiceCollection services, + Func, IEmbeddingGenerator> generatorFactory) + where TEmbedding : Embedding + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(generatorFactory); + + return services.AddScoped(services => + generatorFactory(new EmbeddingGeneratorBuilder(services))); + } + + /// Adds an embedding generator to the . + /// The type from which embeddings will be generated. + /// The type of embeddings to generate. + /// The to which the service should be added. + /// The key with which to associated the generator. + /// The factory to use to construct the instance. + /// The collection. + /// The generator is registered as a scoped service. + public static IServiceCollection AddKeyedEmbeddingGenerator( + this IServiceCollection services, + object serviceKey, + Func, IEmbeddingGenerator> generatorFactory) + where TEmbedding : Embedding + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(serviceKey); + _ = Throw.IfNull(generatorFactory); + + return services.AddKeyedScoped(serviceKey, (services, _) => + generatorFactory(new EmbeddingGeneratorBuilder(services))); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs new file mode 100644 index 00000000000..b7981de8129 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable EA0000 // Use source generated logging methods for improved performance + +namespace Microsoft.Extensions.AI; + +/// A delegating embedding generator that logs embedding generation operations to an . +/// Specifies the type of the input passed to the generator. +/// Specifies the type of the embedding instance produced by the generator. +public class LoggingEmbeddingGenerator : DelegatingEmbeddingGenerator + where TEmbedding : Embedding +{ + /// An instance used for all logging. + private readonly ILogger _logger; + + /// The to use for serialization of state written to the logger. + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An instance that will be used for all logging. + public LoggingEmbeddingGenerator(IEmbeddingGenerator innerGenerator, ILogger logger) + : base(innerGenerator) + { + _logger = Throw.IfNull(logger); + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing logging data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + public override async Task> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (values, options, this), null, static (state, _) => + "GenerateAsync invoked: " + + $"Values: {JsonSerializer.Serialize(state.values, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(IEnumerable)))}. " + + $"Options: {JsonSerializer.Serialize(state.options, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGenerationOptions)))}. " + + $"Metadata: {JsonSerializer.Serialize(state.Item3.Metadata, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGeneratorMetadata)))}."); + } + else + { + _logger.LogDebug("GenerateAsync invoked."); + } + } + + try + { + var embeddings = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.LogDebug("GenerateAsync generated {Count} embedding(s).", embeddings.Count); + } + + return embeddings; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "GenerateAsync failed."); + throw; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs new file mode 100644 index 00000000000..1335a3fd8d3 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class LoggingEmbeddingGeneratorBuilderExtensions +{ + /// Adds logging to the embedding generator pipeline. + /// Specifies the type of the input passed to the generator. + /// Specifies the type of the embedding instance produced by the generator. + /// The . + /// + /// An optional with which logging should be performed. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The . + public static EmbeddingGeneratorBuilder UseLogging( + this EmbeddingGeneratorBuilder builder, ILogger? logger = null, Action>? configure = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(builder); + + return builder.Use((services, innerGenerator) => + { + logger ??= services.GetRequiredService().CreateLogger(nameof(LoggingEmbeddingGenerator)); + var generator = new LoggingEmbeddingGenerator(innerGenerator, logger); + configure?.Invoke(generator); + return generator; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs new file mode 100644 index 00000000000..8105cc64bdf --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -0,0 +1,239 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A delegating embedding generator that implements the OpenTelemetry Semantic Conventions for Generative AI systems. +/// +/// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. +/// The specification is still experimental and subject to change; as such, the telemetry output by this generator is also subject to change. +/// +/// The type of input used to produce embeddings. +/// The type of embedding generated. +public sealed class OpenTelemetryEmbeddingGenerator : DelegatingEmbeddingGenerator + where TEmbedding : Embedding +{ + private readonly ActivitySource _activitySource; + private readonly Meter _meter; + + private readonly Histogram _tokenUsageHistogram; + private readonly Histogram _operationDurationHistogram; + + private readonly string? _modelId; + private readonly string? _modelProvider; + private readonly string? _endpointAddress; + private readonly int _endpointPort; + private readonly int? _dimensions; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , which is the next stage of the pipeline. + /// An optional source name that will be used on the telemetry data. + public OpenTelemetryEmbeddingGenerator(IEmbeddingGenerator innerGenerator, string? sourceName = null) + : base(innerGenerator) + { + Debug.Assert(innerGenerator is not null, "Should have been validated by the base ctor."); + + EmbeddingGeneratorMetadata metadata = innerGenerator!.Metadata; + _modelId = metadata.ModelId; + _modelProvider = metadata.ProviderName; + _endpointAddress = metadata.ProviderUri?.GetLeftPart(UriPartial.Path); + _endpointPort = metadata.ProviderUri?.Port ?? 0; + _dimensions = metadata.Dimensions; + + string name = string.IsNullOrEmpty(sourceName) ? OpenTelemetryConsts.DefaultSourceName : sourceName!; + _activitySource = new(name); + _meter = new(name); + + _tokenUsageHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.TokenUsage.Name, + OpenTelemetryConsts.TokensUnit, + OpenTelemetryConsts.GenAI.Client.TokenUsage.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.TokenUsage.ExplicitBucketBoundaries }); + + _operationDurationHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.OperationDuration.Name, + OpenTelemetryConsts.SecondsUnit, + OpenTelemetryConsts.GenAI.Client.OperationDuration.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.OperationDuration.ExplicitBucketBoundaries }); + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + _activitySource.Dispose(); + _meter.Dispose(); + } + + base.Dispose(disposing); + } + + /// Gets a value indicating whether diagnostics are enabled. + private bool Enabled => _activitySource.HasListeners(); + + /// + public override async Task> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(values); + + using Activity? activity = StartActivity(); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + + GeneratedEmbeddings? response = null; + Exception? error = null; + try + { + response = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + SetCompletionResponse(activity, response, error, stopwatch); + } + + return response; + } + + /// Creates an activity for an embedding generation request, or returns null if not enabled. + private Activity? StartActivity() + { + Activity? activity = null; + if (Enabled) + { + activity = _activitySource.StartActivity( + $"embedding {_modelId}", + ActivityKind.Client, + default(ActivityContext), + [ + new(OpenTelemetryConsts.GenAI.Operation.Name, "embedding"), + new(OpenTelemetryConsts.GenAI.Request.Model, _modelId), + new(OpenTelemetryConsts.GenAI.System, _modelProvider), + ]); + + if (activity is not null) + { + if (_endpointAddress is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Server.Address, _endpointAddress) + .SetTag(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + if (_dimensions is int dimensions) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.EmbeddingDimensions, dimensions); + } + } + } + + return activity; + } + + /// Adds embedding generation response information to the activity. + private void SetCompletionResponse( + Activity? activity, + GeneratedEmbeddings? embeddings, + Exception? error, + Stopwatch? stopwatch) + { + if (!Enabled) + { + return; + } + + int? inputTokens = null; + string? responseModelId = null; + if (embeddings is not null) + { + responseModelId = embeddings.FirstOrDefault()?.ModelId; + if (embeddings.Usage?.InputTokenCount is int i) + { + inputTokens = inputTokens.GetValueOrDefault() + i; + } + } + + if (_operationDurationHistogram.Enabled && stopwatch is not null) + { + TagList tags = default; + AddMetricTags(ref tags, responseModelId); + if (error is not null) + { + tags.Add(OpenTelemetryConsts.Error.Type, error.GetType().FullName); + } + + _operationDurationHistogram.Record(stopwatch.Elapsed.TotalSeconds, tags); + } + + if (_tokenUsageHistogram.Enabled && inputTokens.HasValue) + { + TagList tags = default; + tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "input"); + AddMetricTags(ref tags, responseModelId); + + _tokenUsageHistogram.Record(inputTokens.Value); + } + + if (activity is null) + { + return; + } + + if (error is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Error.Type, error.GetType().FullName) + .SetStatus(ActivityStatusCode.Error, error.Message); + return; + } + + if (inputTokens.HasValue) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.InputTokens, inputTokens); + } + + if (responseModelId is not null) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.Model, responseModelId); + } + } + + private void AddMetricTags(ref TagList tags, string? responseModelId) + { + tags.Add(OpenTelemetryConsts.GenAI.Operation.Name, "embedding"); + + if (_modelId is string requestModel) + { + tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModel); + } + + tags.Add(OpenTelemetryConsts.GenAI.System, _modelProvider); + + if (_endpointAddress is string endpointAddress) + { + tags.Add(OpenTelemetryConsts.Server.Address, endpointAddress); + tags.Add(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + // Assume all of the embeddings in the same batch used the same model + if (responseModelId is not null) + { + tags.Add(OpenTelemetryConsts.GenAI.Response.Model, responseModelId); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGeneratorBuilderExtensions.cs new file mode 100644 index 00000000000..ba60847ef93 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGeneratorBuilderExtensions.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class OpenTelemetryEmbeddingGeneratorBuilderExtensions +{ + /// + /// Adds OpenTelemetry support to the embedding generator pipeline, following the OpenTelemetry Semantic Conventions for Generative AI systems. + /// + /// + /// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. + /// The specification is still experimental and subject to change; as such, the telemetry output by this generator is also subject to change. + /// + /// The type of input used to produce embeddings. + /// The type of embedding generated. + /// The . + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The . + public static EmbeddingGeneratorBuilder UseOpenTelemetry( + this EmbeddingGeneratorBuilder builder, string? sourceName = null, Action>? configure = null) + where TEmbedding : Embedding => + Throw.IfNull(builder).Use(innerGenerator => + { + var generator = new OpenTelemetryEmbeddingGenerator(innerGenerator, sourceName); + configure?.Invoke(generator); + return generator; + }); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/EmptyServiceProvider.cs b/src/Libraries/Microsoft.Extensions.AI/EmptyServiceProvider.cs new file mode 100644 index 00000000000..5e3abc9fc0c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/EmptyServiceProvider.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.Extensions.AI; + +/// Provides an implementation of that contains no services. +internal sealed class EmptyServiceProvider : IKeyedServiceProvider +{ + /// Gets a singleton instance of . + public static EmptyServiceProvider Instance { get; } = new(); + + /// + public object? GetService(Type serviceType) => null; + + /// + public object? GetKeyedService(Type serviceType, object? serviceKey) => null; + + /// + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => + GetKeyedService(serviceType, serviceKey) ?? + throw new InvalidOperationException($"No service for type '{serviceType}' and key '{serviceKey}' has been registered."); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionContext.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionContext.cs new file mode 100644 index 00000000000..25f239f8883 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionContext.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Reflection; +using System.Threading; + +namespace Microsoft.Extensions.AI; + +/// Provides additional context to the invocation of an created by . +/// +/// A delegate or passed to methods may represent a method that has a parameter +/// of type . Whereas all other parameters are passed by name from the supplied collection of arguments, +/// a parameter is passed specially by the implementation, in order to pass relevant +/// context into the method's invocation. For example, any passed to the +/// method is available from the property. +/// +public class AIFunctionContext +{ + /// Initializes a new instance of the class. + public AIFunctionContext() + { + } + + /// Gets or sets a related to the operation. + public CancellationToken CancellationToken { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs new file mode 100644 index 00000000000..c562db8ca3a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -0,0 +1,490 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides factory methods for creating commonly-used implementations of . +public static +#if NET + partial +#endif + class AIFunctionFactory +{ + internal const string UsesReflectionJsonSerializerMessage = + "This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications."; + + /// Lazily-initialized default options instance. + private static AIFunctionFactoryCreateOptions? _defaultOptions; + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// The created for invoking . + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + public static AIFunction Create(Delegate method) => Create(method, _defaultOptions ??= new()); + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// Metadata to use to override defaults inferred from . + /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] + public static AIFunction Create(Delegate method, AIFunctionFactoryCreateOptions options) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + return new ReflectionAIFunction(method.Method, method.Target, options); + } + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// The name to use for the . + /// The description to use for the . + /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied Delegate.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied Delegate.")] + public static AIFunction Create(Delegate method, string? name, string? description = null) + => Create(method, (_defaultOptions ??= new()).SerializerOptions, name, description); + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// The used to marshal function parameters. + /// The name to use for the . + /// The description to use for the . + /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] + public static AIFunction Create(Delegate method, JsonSerializerOptions options, string? name = null, string? description = null) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + return new ReflectionAIFunction(method.Method, method.Target, new(options) { Name = name, Description = description }); + } + + /// + /// Creates an instance for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] + public static AIFunction Create(MethodInfo method, object? target = null) + => Create(method, target, _defaultOptions ??= new()); + + /// + /// Creates an instance for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// Metadata to use to override defaults inferred from . + /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] + public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + return new ReflectionAIFunction(method, target, options); + } + + private sealed +#if NET + partial +#endif + class ReflectionAIFunction : AIFunction + { + private readonly MethodInfo _method; + private readonly object? _target; + private readonly Func, AIFunctionContext?, object?>[] _parameterMarshalers; + private readonly Func> _returnMarshaler; + private readonly JsonTypeInfo? _returnTypeInfo; + private readonly bool _needsAIFunctionContext; + + /// + /// Initializes a new instance of the class for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// Function creation options. + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] + public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + + options.SerializerOptions.MakeReadOnly(); + + if (method.ContainsGenericParameters) + { + Throw.ArgumentException(nameof(method), "Open generic methods are not supported"); + } + + if (!method.IsStatic && target is null) + { + Throw.ArgumentNullException(nameof(target), "Target must not be null for an instance method."); + } + + _method = method; + _target = target; + + // Get the function name to use. + string? functionName = options.Name; + if (functionName is null) + { + functionName = SanitizeMetadataName(method.Name!); + + const string AsyncSuffix = "Async"; + if (IsAsyncMethod(method) && + functionName.EndsWith(AsyncSuffix, StringComparison.Ordinal) && + functionName.Length > AsyncSuffix.Length) + { + functionName = functionName.Substring(0, functionName.Length - AsyncSuffix.Length); + } + + static bool IsAsyncMethod(MethodInfo method) + { + Type t = method.ReturnType; + + if (t == typeof(Task) || t == typeof(ValueTask)) + { + return true; + } + + if (t.IsGenericType) + { + t = t.GetGenericTypeDefinition(); + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + { + return true; + } + } + + return false; + } + } + + // Build up a list of AIParameterMetadata for the parameters we expect to be populated + // from arguments. Some arguments are populated specially, not from arguments, and thus + // we don't want to advertise their metadata. + List? parameterMetadata = options.Parameters is not null ? null : []; + + // Get marshaling delegates for parameters and build up the parameter metadata. + var parameters = method.GetParameters(); + _parameterMarshalers = new Func, AIFunctionContext?, object?>[parameters.Length]; + bool sawAIContextParameter = false; + for (int i = 0; i < parameters.Length; i++) + { + if (GetParameterMarshaler(options.SerializerOptions, parameters[i], ref sawAIContextParameter, out _parameterMarshalers[i]) is AIFunctionParameterMetadata parameterView) + { + parameterMetadata?.Add(parameterView); + } + } + + _needsAIFunctionContext = sawAIContextParameter; + + // Get the return type and a marshaling func for the return value. + Type returnType = GetReturnMarshaler(method, out _returnMarshaler); + _returnTypeInfo = returnType != typeof(void) ? options.SerializerOptions.GetTypeInfo(returnType) : null; + + Metadata = new AIFunctionMetadata(functionName) + { + Description = options.Description ?? method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty, + Parameters = options.Parameters ?? parameterMetadata!, + ReturnParameter = options.ReturnParameter ?? new() + { + ParameterType = returnType, + Description = method.ReturnParameter.GetCustomAttribute(inherit: true)?.Description, + Schema = FunctionCallHelpers.InferReturnParameterJsonSchema(returnType, options.SerializerOptions), + }, + AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance, + JsonSerializerOptions = options.SerializerOptions, + }; + } + + /// + public override AIFunctionMetadata Metadata { get; } + + /// + protected override async Task InvokeCoreAsync( + IEnumerable>? arguments, + CancellationToken cancellationToken) + { + var paramMarshalers = _parameterMarshalers; + object?[] args = paramMarshalers.Length != 0 ? new object?[paramMarshalers.Length] : []; + + IReadOnlyDictionary argDict = + arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance : + arguments as IReadOnlyDictionary ?? + arguments. +#if NET8_0_OR_GREATER + ToDictionary(); +#else + ToDictionary(kvp => kvp.Key, kvp => kvp.Value); +#endif + AIFunctionContext? context = _needsAIFunctionContext ? + new() { CancellationToken = cancellationToken } : + null; + + for (int i = 0; i < args.Length; i++) + { + args[i] = paramMarshalers[i](argDict, context); + } + + object? result = await _returnMarshaler(ReflectionInvoke(_method, _target, args)).ConfigureAwait(false); + + switch (_returnTypeInfo) + { + case null: + Debug.Assert(Metadata.ReturnParameter.ParameterType == typeof(void), "The return parameter is not void."); + return null; + + case { Kind: JsonTypeInfoKind.None }: + // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. + return JsonSerializer.SerializeToElement(result, _returnTypeInfo); + + default: + { + // Serialize asynchronously to support potential IAsyncEnumerable responses. + using MemoryStream stream = new(); + await JsonSerializer.SerializeAsync(stream, result, _returnTypeInfo, cancellationToken).ConfigureAwait(false); + Utf8JsonReader reader = new(stream.GetBuffer().AsSpan(0, (int)stream.Length)); + return JsonElement.ParseValue(ref reader); + } + } + } + + /// + /// Gets a delegate for handling the marshaling of a parameter. + /// + private static AIFunctionParameterMetadata? GetParameterMarshaler( + JsonSerializerOptions options, + ParameterInfo parameter, + ref bool sawAIFunctionContext, + out Func, AIFunctionContext?, object?> marshaler) + { + if (string.IsNullOrWhiteSpace(parameter.Name)) + { + Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); + } + + // Special-case an AIFunctionContext parameter. + if (parameter.ParameterType == typeof(AIFunctionContext)) + { + if (sawAIFunctionContext) + { + Throw.ArgumentException(nameof(parameter), $"Only one {nameof(AIFunctionContext)} parameter is permitted."); + } + + sawAIFunctionContext = true; + + marshaler = static (_, ctx) => + { + Debug.Assert(ctx is not null, "Expected a non-null context object."); + return ctx; + }; + return null; + } + + // Resolve the contract used to marshall the value from JSON -- can throw if not supported or not found. + Type parameterType = parameter.ParameterType; + JsonTypeInfo typeInfo = options.GetTypeInfo(parameterType); + + // Create a marshaler that simply looks up the parameter by name in the arguments dictionary. + marshaler = (IReadOnlyDictionary arguments, AIFunctionContext? _) => + { + // If the parameter has an argument specified in the dictionary, return that argument. + if (arguments.TryGetValue(parameter.Name, out object? value)) + { + return value switch + { + null => null, // Return as-is if null -- if the parameter is a struct this will be handled by MethodInfo.Invoke + _ when parameterType.IsInstanceOfType(value) => value, // Do nothing if value is assignable to parameter type + JsonElement element => JsonSerializer.Deserialize(element, typeInfo), + JsonDocument doc => JsonSerializer.Deserialize(doc, typeInfo), + JsonNode node => JsonSerializer.Deserialize(node, typeInfo), + _ => MarshallViaJsonRoundtrip(value), + }; + + object? MarshallViaJsonRoundtrip(object value) + { +#pragma warning disable CA1031 // Do not catch general exception types + try + { + string json = JsonSerializer.Serialize(value, options.GetTypeInfo(value.GetType())); + return JsonSerializer.Deserialize(json, typeInfo); + } + catch + { + // Eat any exceptions and fall back to the original value to force a cast exception later on. + return value; + } +#pragma warning restore CA1031 // Do not catch general exception types + } + } + + // There was no argument for the parameter. Try to use a default value. + if (parameter.HasDefaultValue) + { + return parameter.DefaultValue; + } + + // No default either. Leave it empty. + return null; + }; + + string? description = parameter.GetCustomAttribute(inherit: true)?.Description; + return new AIFunctionParameterMetadata(parameter.Name) + { + Description = description, + HasDefaultValue = parameter.HasDefaultValue, + DefaultValue = parameter.HasDefaultValue ? parameter.DefaultValue : null, + IsRequired = !parameter.IsOptional, + ParameterType = parameter.ParameterType, + Schema = FunctionCallHelpers.InferParameterJsonSchema( + parameter.ParameterType, + parameter.Name, + description, + parameter.HasDefaultValue, + parameter.DefaultValue, + options) + }; + } + + /// + /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. + /// + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] + private static Type GetReturnMarshaler(MethodInfo method, out Func> marshaler) + { + // Handle each known return type for the method + Type returnType = method.ReturnType; + + // Task + if (returnType == typeof(Task)) + { + marshaler = async static result => + { + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + return null; + }; + return typeof(void); + } + + // ValueTask + if (returnType == typeof(ValueTask)) + { + marshaler = async static result => + { + await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); + return null; + }; + return typeof(void); + } + + if (returnType.IsGenericType) + { + // Task + if (returnType.GetGenericTypeDefinition() == typeof(Task<>) && + returnType.GetProperty(nameof(Task.Result), BindingFlags.Public | BindingFlags.Instance)?.GetGetMethod() is MethodInfo taskResultGetter) + { + marshaler = async result => + { + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + return ReflectionInvoke(taskResultGetter, result, null); + }; + return taskResultGetter.ReturnType; + } + + // ValueTask + if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>) && + returnType.GetMethod(nameof(ValueTask.AsTask), BindingFlags.Public | BindingFlags.Instance) is MethodInfo valueTaskAsTask && + valueTaskAsTask.ReturnType.GetProperty(nameof(ValueTask.Result), BindingFlags.Public | BindingFlags.Instance)?.GetGetMethod() is MethodInfo asTaskResultGetter) + { + marshaler = async result => + { + var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(result), null)!; + await task.ConfigureAwait(false); + return ReflectionInvoke(asTaskResultGetter, task, null); + }; + return asTaskResultGetter.ReturnType; + } + } + + // For everything else, just use the result as-is. + marshaler = result => new ValueTask(result); + return returnType; + + // Throws an exception if a result is found to be null unexpectedly + static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly."); + } + + /// Invokes the MethodInfo with the specified target object and arguments. + private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) + { +#if NET + return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); +#else + try + { + return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); + } + catch (TargetInvocationException e) when (e.InnerException is not null) + { + // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions + // is ignored, the original exception will be wrapped in a TargetInvocationException. + // Unwrap it and throw that original exception, maintaining its stack information. + System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + return null; + } +#endif + } + + /// + /// Remove characters from method name that are valid in metadata but shouldn't be used in a method name. + /// This is primarily intended to remove characters emitted by for compiler-generated method name mangling. + /// + private static string SanitizeMetadataName(string methodName) => + InvalidNameCharsRegex().Replace(methodName, "_"); + + /// Regex that flags any character other than ASCII digits or letters or the underscore. +#if NET + [GeneratedRegex("[^0-9A-Za-z_]")] + private static partial Regex InvalidNameCharsRegex(); +#else + private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; + private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); +#endif + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryCreateOptions.cs new file mode 100644 index 00000000000..8e0db9b4813 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryCreateOptions.cs @@ -0,0 +1,73 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Text.Json; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Options that can be provided when creating an from a method. +/// +public sealed class AIFunctionFactoryCreateOptions +{ + /// + /// Initializes a new instance of the class with default serializer options. + /// + [RequiresUnreferencedCode(AIFunctionFactory.UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(AIFunctionFactory.UsesReflectionJsonSerializerMessage)] + public AIFunctionFactoryCreateOptions() + : this(JsonSerializerOptions.Default) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The JSON serialization options used to marshal .NET types. + public AIFunctionFactoryCreateOptions(JsonSerializerOptions serializerOptions) + { + SerializerOptions = Throw.IfNull(serializerOptions); + } + + /// Gets the used to marshal .NET values being passed to the underlying delegate. + public JsonSerializerOptions SerializerOptions { get; } + + /// Gets or sets the name to use for the function. + /// + /// If , it will default to one derived from the method represented by the passed or . + /// + public string? Name { get; set; } + + /// Gets or sets the description to use for the function. + /// + /// If , it will default to one derived from the passed or , if possible + /// (e.g. via a on the method). + /// + public string? Description { get; set; } + + /// Gets or sets metadata for the parameters of the function. + /// + /// If , it will default to metadata derived from the passed or . + /// + public IReadOnlyList? Parameters { get; set; } + + /// Gets or sets metadata for function's return parameter. + /// + /// If , it will default to one derived from the passed or . + /// + public AIFunctionReturnParameterMetadata? ReturnParameter { get; set; } + + /// + /// Gets or sets additional values that will be stored on the resulting property. + /// + /// + /// This can be used to provide arbitrary information about the function. + /// + public IReadOnlyDictionary? AdditionalProperties { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs b/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs new file mode 100644 index 00000000000..06317f570a2 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace Microsoft.Extensions.AI; + +/// Provides cached options around JSON serialization to be used by the project. +internal static partial class JsonDefaults +{ + /// Gets the singleton to use for serialization-related operations. + public static JsonSerializerOptions Options { get; } = CreateDefaultOptions(); + + /// Creates the default to use for serialization-related operations. + private static JsonSerializerOptions CreateDefaultOptions() + { + // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, + // and we want to be flexible in terms of what can be put into the various collections in the object model. + // Otherwise, use the source-generated options to enable Native AOT. + + if (JsonSerializer.IsReflectionEnabledByDefault) + { + // Keep in sync with the JsonSourceGenerationOptions on JsonContext below. + var options = new JsonSerializerOptions(JsonSerializerDefaults.Web) + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, +#pragma warning disable IL3050, IL2026 // only used when reflection-based serialization is enabled + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), +#pragma warning restore IL3050, IL2026 + }; + + options.MakeReadOnly(); + return options; + } + else + { + return JsonContext.Default.Options; + } + } + + // Keep in sync with CreateDefaultOptions above. + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] + [JsonSerializable(typeof(IList))] + [JsonSerializable(typeof(ChatOptions))] + [JsonSerializable(typeof(EmbeddingGenerationOptions))] + [JsonSerializable(typeof(ChatClientMetadata))] + [JsonSerializable(typeof(EmbeddingGeneratorMetadata))] + [JsonSerializable(typeof(ChatCompletion))] + [JsonSerializable(typeof(StreamingChatCompletionUpdate))] + [JsonSerializable(typeof(IReadOnlyList))] + [JsonSerializable(typeof(Dictionary))] + [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(JsonElement))] + [JsonSerializable(typeof(IEnumerable))] + [JsonSerializable(typeof(string))] + [JsonSerializable(typeof(int))] + [JsonSerializable(typeof(long))] + [JsonSerializable(typeof(float))] + [JsonSerializable(typeof(double))] + [JsonSerializable(typeof(bool))] + [JsonSerializable(typeof(TimeSpan))] + [JsonSerializable(typeof(DateTimeOffset))] + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(Embedding))] +#if NET + [JsonSerializable(typeof(Embedding))] +#endif + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(AIContent))] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj new file mode 100644 index 00000000000..39b33458d0c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -0,0 +1,47 @@ + + + + Microsoft.Extensions.AI + Utilities for working with generative AI components. + AI + + + + preview + 0 + 0 + + + + $(TargetFrameworks);netstandard2.0 + $(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253 + true + + + + + $(NoWarn);IL2026 + + + + true + true + + + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/Libraries/Microsoft.Extensions.AI/OpenTelemetryConsts.cs b/src/Libraries/Microsoft.Extensions.AI/OpenTelemetryConsts.cs new file mode 100644 index 00000000000..31e61101a13 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/OpenTelemetryConsts.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +#pragma warning disable S3218 // Inner class members should not shadow outer class "static" or type members +#pragma warning disable CA1716 // Identifiers should not match keywords +#pragma warning disable S4041 // Type names should not match namespaces + +/// Provides constants used by various telemetry services. +internal static class OpenTelemetryConsts +{ + public const string DefaultSourceName = "Experimental.Microsoft.Extensions.AI"; + + public const string SecondsUnit = "s"; + public const string TokensUnit = "token"; + + public static class Error + { + public const string Type = "error.type"; + } + + public static class GenAI + { + public const string Completion = "gen_ai.completion"; + public const string Prompt = "gen_ai.prompt"; + public const string System = "gen_ai.system"; + + public static class Client + { + public static class OperationDuration + { + public const string Description = "Measures the duration of a GenAI operation"; + public const string Name = "gen_ai.client.operation.duration"; + public static readonly double[] ExplicitBucketBoundaries = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28, 2.56, 5.12, 10.24, 20.48, 40.96, 81.92]; + } + + public static class TokenUsage + { + public const string Description = "Measures number of input and output tokens used"; + public const string Name = "gen_ai.client.token.usage"; + public static readonly int[] ExplicitBucketBoundaries = [1, 4, 16, 64, 256, 1_024, 4_096, 16_384, 65_536, 262_144, 1_048_576, 4_194_304, 16_777_216, 67_108_864]; + } + } + + public static class Content + { + public const string Completion = "gen_ai.content.completion"; + public const string Prompt = "gen_ai.content.prompt"; + } + + public static class Operation + { + public const string Name = "gen_ai.operation.name"; + } + + public static class Request + { + public const string EmbeddingDimensions = "gen_ai.request.embedding.dimensions"; + public const string FrequencyPenalty = "gen_ai.request.frequency_penalty"; + public const string Model = "gen_ai.request.model"; + public const string MaxTokens = "gen_ai.request.max_tokens"; + public const string PresencePenalty = "gen_ai.request.presence_penalty"; + public const string StopSequences = "gen_ai.request.stop_sequences"; + public const string Temperature = "gen_ai.request.temperature"; + public const string TopK = "gen_ai.request.top_k"; + public const string TopP = "gen_ai.request.top_p"; + } + + public static class Response + { + public const string FinishReasons = "gen_ai.response.finish_reasons"; + public const string Id = "gen_ai.response.id"; + public const string InputTokens = "gen_ai.response.input_tokens"; + public const string Model = "gen_ai.response.model"; + public const string OutputTokens = "gen_ai.response.output_tokens"; + } + + public static class Token + { + public const string Type = "gen_ai.token.type"; + } + } + + public static class Server + { + public const string Address = "server.address"; + public const string Port = "server.port"; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/README.md b/src/Libraries/Microsoft.Extensions.AI/README.md new file mode 100644 index 00000000000..ef092749200 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/README.md @@ -0,0 +1,27 @@ +# Microsoft.Extensions.AI + +Provides utilities for working with generative AI components. + +## Install the package + +From the command-line: + +```console +dotnet add package Microsoft.Extensions.AI +``` + +Or directly in the C# project file: + +```xml + + + +``` + +## Usage Examples + +Please refer to the [README](https://www.nuget.org/packages/Microsoft.Extensions.AI.Abstractions/#readme-body-tab) for the [Microsoft.Extensions.AI.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.AI.Abstractions) package. + +## Feedback & Contributing + +We welcome feedback and contributions in [our GitHub repo](https://github.com/dotnet/extensions). diff --git a/src/Shared/CollectionExtensions/CollectionExtensions.cs b/src/Shared/CollectionExtensions/CollectionExtensions.cs new file mode 100644 index 00000000000..33196e6e771 --- /dev/null +++ b/src/Shared/CollectionExtensions/CollectionExtensions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; + +#pragma warning disable S108 // Nested blocks of code should not be left empty +#pragma warning disable S1067 // Expressions should not be too complex +#pragma warning disable SA1501 // Statement should not be on a single line + +#pragma warning disable CA1716 +namespace Microsoft.Shared.Collections; +#pragma warning restore CA1716 + +/// +/// Utilities to augment the basic collection types. +/// +#if !SHARED_PROJECT +[ExcludeFromCodeCoverage] +#endif + +internal static class CollectionExtensions +{ + /// Attempts to extract a typed value from the dictionary. + /// The dictionary to query. + /// The key to locate. + /// The value retrieved from the dictionary, if found; otherwise, default. + /// True if the value was found and converted to the requested type; otherwise, false. + /// + /// If a value is found for the key in the dictionary, but the value is not of the requested type but is + /// an object, the method will attempt to convert the object to the requested type. + /// is employed because these methods are primarily intended for use with primitives. + /// + public static bool TryGetConvertedValue(this IReadOnlyDictionary? input, string key, [NotNullWhen(true)] out T? value) + { + object? valueObject = null; + _ = input?.TryGetValue(key, out valueObject); + return TryConvertValue(valueObject, out value); + } + + private static bool TryConvertValue(object? obj, [NotNullWhen(true)] out T? value) + { + switch (obj) + { + case T t: + // The object is already of the requested type. Return it. + value = t; + return true; + + case IConvertible: + // The object is convertible; try to convert it to the requested type. Unfortunately, there's no + // convenient way to do this that avoids exceptions and that doesn't involve a ton of boilerplate, + // so we only try when the source object is at least an IConvertible, which is what ChangeType uses. + try + { + value = (T)Convert.ChangeType(obj, typeof(T), CultureInfo.InvariantCulture); + return true; + } + catch (ArgumentException) { } + catch (InvalidCastException) { } + catch (FormatException) { } + catch (OverflowException) { } + break; + } + + // Unable to convert the object to the requested type. Fail. + value = default; + return false; + } +} diff --git a/src/Shared/CollectionExtensions/README.md b/src/Shared/CollectionExtensions/README.md new file mode 100644 index 00000000000..a732b7c36d4 --- /dev/null +++ b/src/Shared/CollectionExtensions/README.md @@ -0,0 +1,11 @@ +# Collection Extensions + +`TryGetTypedValue` performs a ``TryGetValue` on a dictionary and then attempts to cast the value to the specified type. If the value is not of the specified type, false is returned. + +To use this in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/src/Shared/NumericExtensions/README.md b/src/Shared/NumericExtensions/README.md index bcb2d9a7cba..c93835acd3b 100644 --- a/src/Shared/NumericExtensions/README.md +++ b/src/Shared/NumericExtensions/README.md @@ -6,6 +6,6 @@ To use this in your project, add the following to your `.csproj` file: ```xml - true + true ``` diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs new file mode 100644 index 00000000000..e71b2f431e8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AdditionalPropertiesDictionaryTests +{ + [Fact] + public void Constructor_Roundtrips() + { + AdditionalPropertiesDictionary d = new(); + Assert.Empty(d); + + d = new(new Dictionary { ["key1"] = "value1" }); + Assert.Single(d); + + d = new((IEnumerable>)new Dictionary { ["key1"] = "value1", ["key2"] = "value2" }); + Assert.Equal(2, d.Count); + } + + [Fact] + public void Comparer_OrdinalIgnoreCase() + { + AdditionalPropertiesDictionary d = new() + { + ["key1"] = "value1", + ["KEY1"] = "value2", + ["key2"] = "value3", + ["key3"] = "value4", + ["KeY3"] = "value5", + }; + + Assert.Equal(3, d.Count); + + Assert.Equal("value2", d["key1"]); + Assert.Equal("value2", d["kEY1"]); + + Assert.Equal("value3", d["key2"]); + Assert.Equal("value3", d["KEY2"]); + + Assert.Equal("value5", d["Key3"]); + Assert.Equal("value5", d["KEy3"]); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs new file mode 100644 index 00000000000..2c54a6f0865 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using Xunit; +using Xunit.Sdk; + +namespace Microsoft.Extensions.AI; + +internal static class AssertExtensions +{ + /// + /// Asserts that the two function call parameters are equal, up to JSON equivalence. + /// + public static void EqualFunctionCallParameters( + IDictionary? expected, + IDictionary? actual, + JsonSerializerOptions? options = null) + { + if (expected is null || actual is null) + { + Assert.Equal(expected, actual); + return; + } + + foreach (var expectedEntry in expected) + { + if (!actual.TryGetValue(expectedEntry.Key, out object? actualValue)) + { + throw new XunitException($"Expected parameter '{expectedEntry.Key}' not found in actual value."); + } + + AreJsonEquivalentValues(expectedEntry.Value, actualValue, options, propertyName: expectedEntry.Key); + } + + if (expected.Count != actual.Count) + { + var extraParameters = actual + .Where(e => !expected.ContainsKey(e.Key)) + .Select(e => $"'{e.Key}'") + .First(); + + throw new XunitException($"Actual value contains additional parameters {string.Join(", ", extraParameters)} not found in expected value."); + } + } + + /// + /// Asserts that the two function call results are equal, up to JSON equivalence. + /// + public static void EqualFunctionCallResults(object? expected, object? actual, JsonSerializerOptions? options = null) + => AreJsonEquivalentValues(expected, actual, options); + + private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null) + { + options ??= JsonSerializerOptions.Default; + JsonElement expectedElement = NormalizeToElement(expected, options); + JsonElement actualElement = NormalizeToElement(actual, options); + if (!JsonElement.DeepEquals(expectedElement, actualElement)) + { + string message = propertyName is null + ? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}" + : $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"; + + throw new XunitException(message); + } + + static JsonElement NormalizeToElement(object? value, JsonSerializerOptions options) + => value is JsonElement e ? e : JsonSerializer.SerializeToElement(value, options); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/CapturingLogger.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/CapturingLogger.cs new file mode 100644 index 00000000000..274021988e1 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/CapturingLogger.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Logging; + +#pragma warning disable SA1402 // File may only contain a single type + +namespace Microsoft.Extensions.AI; + +internal sealed class CapturingLogger : ILogger +{ + private readonly Stack _scopes = new(); + private readonly List _entries = []; + private readonly LogLevel _enabledLevel; + + public CapturingLogger(LogLevel enabledLevel = LogLevel.Trace) + { + _enabledLevel = enabledLevel; + } + + public IReadOnlyList Entries => _entries; + + public IDisposable? BeginScope(TState state) + where TState : notnull + { + var scope = new LoggerScope(this); + _scopes.Push(scope); + return scope; + } + + public bool IsEnabled(LogLevel logLevel) => logLevel >= _enabledLevel; + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (!IsEnabled(logLevel)) + { + return; + } + + var message = formatter(state, exception); + lock (_entries) + { + _entries.Add(new LogEntry(logLevel, eventId, state, exception, message)); + } + } + + private sealed class LoggerScope(CapturingLogger owner) : IDisposable + { + public void Dispose() => owner.EndScope(this); + } + + private void EndScope(LoggerScope loggerScope) + { + if (_scopes.Peek() != loggerScope) + { + throw new InvalidOperationException("Logger scopes out of order"); + } + + _scopes.Pop(); + } + + public record LogEntry(LogLevel Level, EventId EventId, object? State, Exception? Exception, string Message); +} + +internal sealed class CapturingLoggerProvider : ILoggerProvider +{ + public CapturingLogger Logger { get; } = new(); + + public ILogger CreateLogger(string categoryName) => Logger; + + void IDisposable.Dispose() + { + // nop + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs new file mode 100644 index 00000000000..68f5ad12245 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientExtensionsTests +{ + [Fact] + public void CompleteAsync_InvalidArgs_Throws() + { + Assert.Throws("client", () => + { + _ = ChatClientExtensions.CompleteAsync(null!, "hello"); + }); + + Assert.Throws("chatMessage", () => + { + _ = ChatClientExtensions.CompleteAsync(new TestChatClient(), null!); + }); + } + + [Fact] + public void CompleteStreamingAsync_InvalidArgs_Throws() + { + Assert.Throws("client", () => + { + _ = ChatClientExtensions.CompleteStreamingAsync(null!, "hello"); + }); + + Assert.Throws("chatMessage", () => + { + _ = ChatClientExtensions.CompleteStreamingAsync(new TestChatClient(), null!); + }); + } + + [Fact] + public async Task CompleteAsync_CreatesTextMessageAsync() + { + var expectedResponse = new ChatCompletion([new ChatMessage()]); + var expectedOptions = new ChatOptions(); + using var cts = new CancellationTokenSource(); + + using TestChatClient client = new() + { + CompleteAsyncCallback = (chatMessages, options, cancellationToken) => + { + ChatMessage m = Assert.Single(chatMessages); + Assert.Equal(ChatRole.User, m.Role); + Assert.Equal("hello", m.Text); + + Assert.Same(expectedOptions, options); + + Assert.Equal(cts.Token, cancellationToken); + + return Task.FromResult(expectedResponse); + }, + }; + + ChatCompletion response = await client.CompleteAsync("hello", expectedOptions, cts.Token); + + Assert.Same(expectedResponse, response); + } + + [Fact] + public async Task CompleteStreamingAsync_CreatesTextMessageAsync() + { + var expectedOptions = new ChatOptions(); + using var cts = new CancellationTokenSource(); + + using TestChatClient client = new() + { + CompleteStreamingAsyncCallback = (chatMessages, options, cancellationToken) => + { + ChatMessage m = Assert.Single(chatMessages); + Assert.Equal(ChatRole.User, m.Role); + Assert.Equal("hello", m.Text); + + Assert.Same(expectedOptions, options); + + Assert.Equal(cts.Token, cancellationToken); + + return YieldAsync([new StreamingChatCompletionUpdate { Text = "world" }]); + }, + }; + + int count = 0; + await foreach (var update in client.CompleteStreamingAsync("hello", expectedOptions, cts.Token)) + { + Assert.Equal(0, count); + Assert.Equal("world", update.Text); + count++; + } + + Assert.Equal(1, count); + } + + private static async IAsyncEnumerable YieldAsync(params StreamingChatCompletionUpdate[] updates) + { + await Task.Yield(); + foreach (var update in updates) + { + yield return update; + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs new file mode 100644 index 00000000000..43e24e61f8e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientMetadataTests +{ + [Fact] + public void Constructor_NullValues_AllowedAndRoundtrip() + { + ChatClientMetadata metadata = new(null, null, null); + Assert.Null(metadata.ProviderName); + Assert.Null(metadata.ProviderUri); + Assert.Null(metadata.ModelId); + } + + [Fact] + public void Constructor_Value_Roundtrips() + { + var uri = new Uri("https://example.com"); + ChatClientMetadata metadata = new("providerName", uri, "theModel"); + Assert.Equal("providerName", metadata.ProviderName); + Assert.Same(uri, metadata.ProviderUri); + Assert.Equal("theModel", metadata.ModelId); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs new file mode 100644 index 00000000000..a695e686f6e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs @@ -0,0 +1,170 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatCompletionTests +{ + [Fact] + public void Constructor_InvalidArgs_Throws() + { + Assert.Throws("message", () => new ChatCompletion((ChatMessage)null!)); + Assert.Throws("choices", () => new ChatCompletion((IList)null!)); + } + + [Fact] + public void Constructor_Message_Roundtrips() + { + ChatMessage message = new(); + + ChatCompletion completion = new(message); + Assert.Same(message, completion.Message); + Assert.Same(message, Assert.Single(completion.Choices)); + } + + [Fact] + public void Constructor_Choices_Roundtrips() + { + List messages = + [ + new ChatMessage(), + new ChatMessage(), + new ChatMessage(), + ]; + + ChatCompletion completion = new(messages); + Assert.Same(messages, completion.Choices); + Assert.Equal(3, messages.Count); + } + + [Fact] + public void Message_EmptyChoices_Throws() + { + ChatCompletion completion = new([]); + + Assert.Empty(completion.Choices); + Assert.Throws(() => completion.Message); + } + + [Fact] + public void Message_SingleChoice_Returned() + { + ChatMessage message = new(); + ChatCompletion completion = new([message]); + + Assert.Same(message, completion.Message); + Assert.Same(message, completion.Choices[0]); + } + + [Fact] + public void Message_MultipleChoices_ReturnsFirst() + { + ChatMessage first = new(); + ChatCompletion completion = new([ + first, + new ChatMessage(), + ]); + + Assert.Same(first, completion.Message); + Assert.Same(first, completion.Choices[0]); + } + + [Fact] + public void Choices_SetNull_Throws() + { + ChatCompletion completion = new([]); + Assert.Throws("value", () => completion.Choices = null!); + } + + [Fact] + public void Properties_Roundtrip() + { + ChatCompletion completion = new([]); + + Assert.Null(completion.CompletionId); + completion.CompletionId = "id"; + Assert.Equal("id", completion.CompletionId); + + Assert.Null(completion.ModelId); + completion.ModelId = "modelId"; + Assert.Equal("modelId", completion.ModelId); + + Assert.Null(completion.CreatedAt); + completion.CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), completion.CreatedAt); + + Assert.Null(completion.FinishReason); + completion.FinishReason = ChatFinishReason.ContentFilter; + Assert.Equal(ChatFinishReason.ContentFilter, completion.FinishReason); + + Assert.Null(completion.Usage); + UsageDetails usage = new(); + completion.Usage = usage; + Assert.Same(usage, completion.Usage); + + Assert.Null(completion.RawRepresentation); + object raw = new(); + completion.RawRepresentation = raw; + Assert.Same(raw, completion.RawRepresentation); + + Assert.Null(completion.AdditionalProperties); + AdditionalPropertiesDictionary additionalProps = []; + completion.AdditionalProperties = additionalProps; + Assert.Same(additionalProps, completion.AdditionalProperties); + + List newChoices = [new ChatMessage(), new ChatMessage()]; + completion.Choices = newChoices; + Assert.Same(newChoices, completion.Choices); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatCompletion original = new( + [ + new ChatMessage(ChatRole.Assistant, "Choice1"), + new ChatMessage(ChatRole.Assistant, "Choice2"), + new ChatMessage(ChatRole.Assistant, "Choice3"), + new ChatMessage(ChatRole.Assistant, "Choice4"), + ]) + { + CompletionId = "id", + ModelId = "modelId", + CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + FinishReason = ChatFinishReason.ContentFilter, + Usage = new UsageDetails(), + RawRepresentation = new(), + AdditionalProperties = new() { ["key"] = "value" }, + }; + + string json = JsonSerializer.Serialize(original, TestJsonSerializerContext.Default.ChatCompletion); + + ChatCompletion? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatCompletion); + + Assert.NotNull(result); + Assert.Equal(4, result.Choices.Count); + + for (int i = 0; i < original.Choices.Count; i++) + { + Assert.Equal(ChatRole.Assistant, result.Choices[i].Role); + Assert.Equal($"Choice{i + 1}", result.Choices[i].Text); + } + + Assert.Equal("id", result.CompletionId); + Assert.Equal("modelId", result.ModelId); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), result.CreatedAt); + Assert.Equal(ChatFinishReason.ContentFilter, result.FinishReason); + Assert.NotNull(result.Usage); + + Assert.NotNull(result.AdditionalProperties); + Assert.Single(result.AdditionalProperties); + Assert.True(result.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatFinishReasonTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatFinishReasonTests.cs new file mode 100644 index 00000000000..0318a77b47b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatFinishReasonTests.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatFinishReasonTests +{ + [Fact] + public void Constructor_Value_Roundtrips() + { + Assert.Equal("abc", new ChatFinishReason("abc").Value); + } + + [Fact] + public void Constructor_NullOrWhiteSpace_Throws() + { + Assert.Throws(() => new ChatFinishReason(null!)); + Assert.Throws(() => new ChatFinishReason(" ")); + } + + [Fact] + public void Equality_UsesOrdinalIgnoreCaseComparison() + { + Assert.True(new ChatFinishReason("abc").Equals(new ChatFinishReason("ABC"))); + Assert.True(new ChatFinishReason("abc").Equals((object)new ChatFinishReason("ABC"))); + Assert.True(new ChatFinishReason("abc") == new ChatFinishReason("ABC")); + Assert.Equal(new ChatFinishReason("abc").GetHashCode(), new ChatFinishReason("ABC").GetHashCode()); + Assert.False(new ChatFinishReason("abc") != new ChatFinishReason("ABC")); + + Assert.False(new ChatFinishReason("abc").Equals(new ChatFinishReason("def"))); + Assert.False(new ChatFinishReason("abc").Equals((object)new ChatFinishReason("def"))); + Assert.False(new ChatFinishReason("abc").Equals(null)); + Assert.False(new ChatFinishReason("abc").Equals("abc")); + Assert.False(new ChatFinishReason("abc") == new ChatFinishReason("def")); + Assert.True(new ChatFinishReason("abc") != new ChatFinishReason("def")); + Assert.NotEqual(new ChatFinishReason("abc").GetHashCode(), new ChatFinishReason("def").GetHashCode()); // not guaranteed due to possible hash code collisions + } + + [Fact] + public void Singletons_UseKnownValues() + { + Assert.Equal("stop", ChatFinishReason.Stop.Value); + Assert.Equal("length", ChatFinishReason.Length.Value); + Assert.Equal("tool_calls", ChatFinishReason.ToolCalls.Value); + Assert.Equal("content_filter", ChatFinishReason.ContentFilter.Value); + } + + [Fact] + public void Value_NormalizesToStopped() + { + Assert.Equal("test", new ChatFinishReason("test").Value); + Assert.Equal("test", new ChatFinishReason("test").ToString()); + + Assert.Equal("TEST", new ChatFinishReason("TEST").Value); + Assert.Equal("TEST", new ChatFinishReason("TEST").ToString()); + + Assert.Equal("stop", default(ChatFinishReason).Value); + Assert.Equal("stop", default(ChatFinishReason).ToString()); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatFinishReason role = new("abc"); + string? json = JsonSerializer.Serialize(role, TestJsonSerializerContext.Default.ChatFinishReason); + Assert.Equal("\"abc\"", json); + + ChatFinishReason? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatFinishReason); + Assert.Equal(role, result); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs new file mode 100644 index 00000000000..dbef5f4088b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs @@ -0,0 +1,382 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatMessageTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + ChatMessage message = new(); + Assert.Null(message.AuthorName); + Assert.Empty(message.Contents); + Assert.Equal(ChatRole.User, message.Role); + Assert.Null(message.Text); + Assert.NotNull(message.Contents); + Assert.Same(message.Contents, message.Contents); + Assert.Empty(message.Contents); + Assert.Null(message.RawRepresentation); + Assert.Null(message.AdditionalProperties); + Assert.Equal(string.Empty, message.ToString()); + } + + [Theory] + [InlineData(null)] + [InlineData("text")] + public void Constructor_RoleString_PropsRoundtrip(string? text) + { + ChatMessage message = new(ChatRole.Assistant, text); + + Assert.Equal(ChatRole.Assistant, message.Role); + + Assert.Same(message.Contents, message.Contents); + if (text is null) + { + Assert.Empty(message.Contents); + } + else + { + Assert.Single(message.Contents); + TextContent tc = Assert.IsType(message.Contents[0]); + Assert.Equal(text, tc.Text); + } + + Assert.Null(message.AuthorName); + Assert.Null(message.RawRepresentation); + Assert.Null(message.AdditionalProperties); + Assert.Equal(text ?? string.Empty, message.ToString()); + } + + [Fact] + public void Constructor_RoleList_InvalidArgs_Throws() + { + Assert.Throws("contents", () => new ChatMessage(ChatRole.User, (IList)null!)); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + public void Constructor_RoleList_PropsRoundtrip(int messageCount) + { + List content = []; + for (int i = 0; i < messageCount; i++) + { + content.Add(new TextContent($"text-{i}")); + } + + ChatMessage message = new(ChatRole.System, content); + + Assert.Equal(ChatRole.System, message.Role); + + Assert.Same(message.Contents, message.Contents); + if (messageCount == 0) + { + Assert.Empty(message.Contents); + Assert.Null(message.Text); + } + else + { + Assert.Equal(messageCount, message.Contents.Count); + for (int i = 0; i < messageCount; i++) + { + TextContent tc = Assert.IsType(message.Contents[i]); + Assert.Equal($"text-{i}", tc.Text); + } + + Assert.Equal("text-0", message.Text); + Assert.Equal("text-0", message.ToString()); + } + + Assert.Null(message.AuthorName); + Assert.Null(message.RawRepresentation); + Assert.Null(message.AdditionalProperties); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" \r\n\t\v ")] + public void AuthorName_InvalidArg_UsesNull(string? authorName) + { + ChatMessage message = new() + { + AuthorName = authorName + }; + Assert.Null(message.AuthorName); + + message.AuthorName = "author"; + Assert.Equal("author", message.AuthorName); + + message.AuthorName = authorName; + Assert.Null(message.AuthorName); + } + + [Fact] + public void Text_GetSet_UsesFirstTextContent() + { + ChatMessage message = new(ChatRole.User, + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + new TextContent("text-1"), + new TextContent("text-2"), + new FunctionResultContent(new FunctionCallContent("callId1", "fc2"), "result"), + ]); + + TextContent textContent = Assert.IsType(message.Contents[3]); + Assert.Equal("text-1", textContent.Text); + Assert.Equal("text-1", message.Text); + Assert.Equal("text-1", message.ToString()); + + message.Text = "text-3"; + Assert.Equal("text-3", message.Text); + Assert.Equal("text-3", message.Text); + Assert.Same(textContent, message.Contents[3]); + Assert.Equal("text-3", message.ToString()); + } + + [Fact] + public void Text_Set_AddsTextMessageToEmptyList() + { + ChatMessage message = new(ChatRole.User, []); + Assert.Empty(message.Contents); + + message.Text = "text-1"; + Assert.Equal("text-1", message.Text); + + Assert.Single(message.Contents); + TextContent textContent = Assert.IsType(message.Contents[0]); + Assert.Equal("text-1", textContent.Text); + } + + [Fact] + public void Text_Set_AddsTextMessageToListWithNoText() + { + ChatMessage message = new(ChatRole.User, + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + ]); + Assert.Equal(3, message.Contents.Count); + + message.Text = "text-1"; + Assert.Equal("text-1", message.Text); + Assert.Equal(4, message.Contents.Count); + + message.Text = "text-2"; + Assert.Equal("text-2", message.Text); + Assert.Equal(4, message.Contents.Count); + + message.Contents.RemoveAt(3); + Assert.Equal(3, message.Contents.Count); + + message.Text = "text-3"; + Assert.Equal("text-3", message.Text); + Assert.Equal(4, message.Contents.Count); + } + + [Fact] + public void Contents_InitializesToList() + { + // This is an implementation detail, but if this test starts failing, we need to ensure + // tests are in place for whatever possibly-custom implementation of IList is being used. + Assert.IsType>(new ChatMessage().Contents); + } + + [Fact] + public void Contents_Roundtrips() + { + ChatMessage message = new(); + Assert.Empty(message.Contents); + + List contents = []; + message.Contents = contents; + + Assert.Same(contents, message.Contents); + + message.Contents = contents; + Assert.Same(contents, message.Contents); + + message.Contents = null; + Assert.NotNull(message.Contents); + Assert.NotSame(contents, message.Contents); + Assert.Empty(message.Contents); + } + + [Fact] + public void RawRepresentation_Roundtrips() + { + ChatMessage message = new(); + Assert.Null(message.RawRepresentation); + + object raw = new(); + + message.RawRepresentation = raw; + Assert.Same(raw, message.RawRepresentation); + + message.RawRepresentation = raw; + Assert.Same(raw, message.RawRepresentation); + + message.RawRepresentation = null; + Assert.Null(message.RawRepresentation); + + message.RawRepresentation = raw; + Assert.Same(raw, message.RawRepresentation); + } + + [Fact] + public void AdditionalProperties_Roundtrips() + { + ChatMessage message = new(); + Assert.Null(message.RawRepresentation); + + AdditionalPropertiesDictionary props = []; + + message.AdditionalProperties = props; + Assert.Same(props, message.AdditionalProperties); + + message.AdditionalProperties = props; + Assert.Same(props, message.AdditionalProperties); + + message.AdditionalProperties = null; + Assert.Null(message.AdditionalProperties); + + message.AdditionalProperties = props; + Assert.Same(props, message.AdditionalProperties); + } + + [Fact] + public void ItCanBeSerializeAndDeserialized() + { + // Arrange + IList items = + [ + new TextContent("content-1") + { + ModelId = "model-1", + AdditionalProperties = new() { ["metadata-key-1"] = "metadata-value-1" } + }, + new ImageContent(new Uri("https://fake-random-test-host:123"), "mime-type/2") + { + ModelId = "model-2", + AdditionalProperties = new() { ["metadata-key-2"] = "metadata-value-2" } + }, + new DataContent(new BinaryData(new[] { 1, 2, 3 }, options: TestJsonSerializerContext.Default.Options), "mime-type/3") + { + ModelId = "model-3", + AdditionalProperties = new() { ["metadata-key-3"] = "metadata-value-3" } + }, + new AudioContent(new BinaryData(new[] { 3, 2, 1 }, options: TestJsonSerializerContext.Default.Options), "mime-type/4") + { + ModelId = "model-4", + AdditionalProperties = new() { ["metadata-key-4"] = "metadata-value-4" } + }, + new ImageContent(new BinaryData(new[] { 2, 1, 3 }, options: TestJsonSerializerContext.Default.Options), "mime-type/5") + { + ModelId = "model-5", + AdditionalProperties = new() { ["metadata-key-5"] = "metadata-value-5" } + }, + new TextContent("content-6") + { + ModelId = "model-6", + AdditionalProperties = new() { ["metadata-key-6"] = "metadata-value-6" } + }, + new FunctionCallContent("function-id", "plugin-name-function-name", new Dictionary { ["parameter"] = "argument" }), + new FunctionResultContent(new FunctionCallContent("function-id", "plugin-name-function-name"), "function-result"), + ]; + + // Act + var chatMessageJson = JsonSerializer.Serialize(new ChatMessage(ChatRole.User, contents: items) + { + Text = "content-1-override", // Override the content of the first text content item that has the "content-1" content + AuthorName = "Fred", + AdditionalProperties = new() { ["message-metadata-key-1"] = "message-metadata-value-1" }, + }, TestJsonSerializerContext.Default.Options); + + var deserializedMessage = JsonSerializer.Deserialize(chatMessageJson, TestJsonSerializerContext.Default.Options)!; + + // Assert + Assert.Equal("Fred", deserializedMessage.AuthorName); + Assert.Equal("user", deserializedMessage.Role.Value); + Assert.NotNull(deserializedMessage.AdditionalProperties); + Assert.Single(deserializedMessage.AdditionalProperties); + Assert.Equal("message-metadata-value-1", deserializedMessage.AdditionalProperties["message-metadata-key-1"]?.ToString()); + + Assert.NotNull(deserializedMessage.Contents); + Assert.Equal(items.Count, deserializedMessage.Contents.Count); + + var textContent = deserializedMessage.Contents[0] as TextContent; + Assert.NotNull(textContent); + Assert.Equal("content-1-override", textContent.Text); + Assert.Equal("model-1", textContent.ModelId); + Assert.NotNull(textContent.AdditionalProperties); + Assert.Single(textContent.AdditionalProperties); + Assert.Equal("metadata-value-1", textContent.AdditionalProperties["metadata-key-1"]?.ToString()); + + var imageContent = deserializedMessage.Contents[1] as ImageContent; + Assert.NotNull(imageContent); + Assert.Equal("https://fake-random-test-host:123/", imageContent.Uri); + Assert.Equal("model-2", imageContent.ModelId); + Assert.Equal("mime-type/2", imageContent.MediaType); + Assert.NotNull(imageContent.AdditionalProperties); + Assert.Single(imageContent.AdditionalProperties); + Assert.Equal("metadata-value-2", imageContent.AdditionalProperties["metadata-key-2"]?.ToString()); + + var dataContent = deserializedMessage.Contents[2] as DataContent; + Assert.NotNull(dataContent); + Assert.True(dataContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 1, 2, 3 }, TestJsonSerializerContext.Default.Options))); + Assert.Equal("model-3", dataContent.ModelId); + Assert.Equal("mime-type/3", dataContent.MediaType); + Assert.NotNull(dataContent.AdditionalProperties); + Assert.Single(dataContent.AdditionalProperties); + Assert.Equal("metadata-value-3", dataContent.AdditionalProperties["metadata-key-3"]?.ToString()); + + var audioContent = deserializedMessage.Contents[3] as AudioContent; + Assert.NotNull(audioContent); + Assert.True(audioContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 3, 2, 1 }, TestJsonSerializerContext.Default.Options))); + Assert.Equal("model-4", audioContent.ModelId); + Assert.Equal("mime-type/4", audioContent.MediaType); + Assert.NotNull(audioContent.AdditionalProperties); + Assert.Single(audioContent.AdditionalProperties); + Assert.Equal("metadata-value-4", audioContent.AdditionalProperties["metadata-key-4"]?.ToString()); + + imageContent = deserializedMessage.Contents[4] as ImageContent; + Assert.NotNull(imageContent); + Assert.True(imageContent.Data?.Span.SequenceEqual(new BinaryData(new[] { 2, 1, 3 }, TestJsonSerializerContext.Default.Options))); + Assert.Equal("model-5", imageContent.ModelId); + Assert.Equal("mime-type/5", imageContent.MediaType); + Assert.NotNull(imageContent.AdditionalProperties); + Assert.Single(imageContent.AdditionalProperties); + Assert.Equal("metadata-value-5", imageContent.AdditionalProperties["metadata-key-5"]?.ToString()); + + textContent = deserializedMessage.Contents[5] as TextContent; + Assert.NotNull(textContent); + Assert.Equal("content-6", textContent.Text); + Assert.Equal("model-6", textContent.ModelId); + Assert.NotNull(textContent.AdditionalProperties); + Assert.Single(textContent.AdditionalProperties); + Assert.Equal("metadata-value-6", textContent.AdditionalProperties["metadata-key-6"]?.ToString()); + + var functionCallContent = deserializedMessage.Contents[6] as FunctionCallContent; + Assert.NotNull(functionCallContent); + Assert.Equal("plugin-name-function-name", functionCallContent.Name); + Assert.Equal("function-id", functionCallContent.CallId); + Assert.NotNull(functionCallContent.Arguments); + Assert.Single(functionCallContent.Arguments); + Assert.Equal("argument", functionCallContent.Arguments["parameter"]?.ToString()); + + var functionResultContent = deserializedMessage.Contents[7] as FunctionResultContent; + Assert.NotNull(functionResultContent); + Assert.Equal("function-result", functionResultContent.Result?.ToString()); + Assert.Equal("function-id", functionResultContent.CallId); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs new file mode 100644 index 00000000000..2e769ff6d7e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatOptionsTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + ChatOptions options = new(); + Assert.Null(options.Temperature); + Assert.Null(options.MaxOutputTokens); + Assert.Null(options.TopP); + Assert.Null(options.FrequencyPenalty); + Assert.Null(options.PresencePenalty); + Assert.Null(options.ResponseFormat); + Assert.Null(options.ModelId); + Assert.Null(options.StopSequences); + Assert.Same(ChatToolMode.Auto, options.ToolMode); + Assert.Null(options.Tools); + Assert.Null(options.AdditionalProperties); + + ChatOptions clone = options.Clone(); + Assert.Null(clone.Temperature); + Assert.Null(clone.MaxOutputTokens); + Assert.Null(clone.TopP); + Assert.Null(clone.FrequencyPenalty); + Assert.Null(clone.PresencePenalty); + Assert.Null(clone.ResponseFormat); + Assert.Null(clone.ModelId); + Assert.Null(clone.StopSequences); + Assert.Same(ChatToolMode.Auto, clone.ToolMode); + Assert.Null(clone.Tools); + Assert.Null(clone.AdditionalProperties); + } + + [Fact] + public void Properties_Roundtrip() + { + ChatOptions options = new(); + + List stopSequences = + [ + "stop1", + "stop2", + ]; + + List tools = + [ + AIFunctionFactory.Create(() => 42), + AIFunctionFactory.Create(() => 43), + ]; + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.Temperature = 0.1f; + options.MaxOutputTokens = 2; + options.TopP = 0.3f; + options.FrequencyPenalty = 0.4f; + options.PresencePenalty = 0.5f; + options.ResponseFormat = ChatResponseFormat.Json; + options.ModelId = "modelId"; + options.StopSequences = stopSequences; + options.ToolMode = ChatToolMode.RequireAny; + options.Tools = tools; + options.AdditionalProperties = additionalProps; + + Assert.Equal(0.1f, options.Temperature); + Assert.Equal(2, options.MaxOutputTokens); + Assert.Equal(0.3f, options.TopP); + Assert.Equal(0.4f, options.FrequencyPenalty); + Assert.Equal(0.5f, options.PresencePenalty); + Assert.Same(ChatResponseFormat.Json, options.ResponseFormat); + Assert.Equal("modelId", options.ModelId); + Assert.Same(stopSequences, options.StopSequences); + Assert.Same(ChatToolMode.RequireAny, options.ToolMode); + Assert.Same(tools, options.Tools); + Assert.Same(additionalProps, options.AdditionalProperties); + + ChatOptions clone = options.Clone(); + Assert.Equal(0.1f, clone.Temperature); + Assert.Equal(2, clone.MaxOutputTokens); + Assert.Equal(0.3f, clone.TopP); + Assert.Equal(0.4f, clone.FrequencyPenalty); + Assert.Equal(0.5f, clone.PresencePenalty); + Assert.Same(ChatResponseFormat.Json, clone.ResponseFormat); + Assert.Equal("modelId", clone.ModelId); + Assert.Equal(stopSequences, clone.StopSequences); + Assert.Same(ChatToolMode.RequireAny, clone.ToolMode); + Assert.Equal(tools, clone.Tools); + Assert.Equal(additionalProps, clone.AdditionalProperties); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatOptions options = new(); + + List stopSequences = + [ + "stop1", + "stop2", + ]; + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.Temperature = 0.1f; + options.MaxOutputTokens = 2; + options.TopP = 0.3f; + options.FrequencyPenalty = 0.4f; + options.PresencePenalty = 0.5f; + options.ResponseFormat = ChatResponseFormat.Json; + options.ModelId = "modelId"; + options.StopSequences = stopSequences; + options.ToolMode = ChatToolMode.RequireAny; + options.Tools = + [ + AIFunctionFactory.Create(() => 42), + AIFunctionFactory.Create(() => 43), + ]; + options.AdditionalProperties = additionalProps; + + string json = JsonSerializer.Serialize(options, TestJsonSerializerContext.Default.ChatOptions); + + ChatOptions? deserialized = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatOptions); + Assert.NotNull(deserialized); + + Assert.Equal(0.1f, deserialized.Temperature); + Assert.Equal(2, deserialized.MaxOutputTokens); + Assert.Equal(0.3f, deserialized.TopP); + Assert.Equal(0.4f, deserialized.FrequencyPenalty); + Assert.Equal(0.5f, deserialized.PresencePenalty); + Assert.Equal(ChatResponseFormat.Json, deserialized.ResponseFormat); + Assert.NotSame(ChatResponseFormat.Json, deserialized.ResponseFormat); + Assert.Equal("modelId", deserialized.ModelId); + Assert.NotSame(stopSequences, deserialized.StopSequences); + Assert.Equal(stopSequences, deserialized.StopSequences); + Assert.Equal(ChatToolMode.RequireAny, deserialized.ToolMode); + Assert.Null(deserialized.Tools); + + Assert.NotNull(deserialized.AdditionalProperties); + Assert.Single(deserialized.AdditionalProperties); + Assert.True(deserialized.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseFormatTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseFormatTests.cs new file mode 100644 index 00000000000..f4a63f34e05 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseFormatTests.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatResponseFormatTests +{ + [Fact] + public void Singletons_Idempotent() + { + Assert.Same(ChatResponseFormat.Text, ChatResponseFormat.Text); + Assert.Same(ChatResponseFormat.Json, ChatResponseFormat.Json); + } + + [Fact] + public void Constructor_InvalidArgs_Throws() + { + Assert.Throws(() => new ChatResponseFormatJson(null, "name")); + Assert.Throws(() => new ChatResponseFormatJson(null, null, "description")); + Assert.Throws(() => new ChatResponseFormatJson(null, "name", "description")); + } + + [Fact] + public void Constructor_PropsDefaulted() + { + ChatResponseFormatJson f = new(null); + Assert.Null(f.Schema); + Assert.Null(f.SchemaName); + Assert.Null(f.SchemaDescription); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + ChatResponseFormatJson f = new("{}", "name", "description"); + Assert.Equal("{}", f.Schema); + Assert.Equal("name", f.SchemaName); + Assert.Equal("description", f.SchemaDescription); + } + + [Fact] + public void Equality_ComparersProduceExpectedResults() + { + Assert.True(ChatResponseFormat.Text == ChatResponseFormat.Text); + Assert.True(ChatResponseFormat.Text.Equals(ChatResponseFormat.Text)); + Assert.Equal(ChatResponseFormat.Text.GetHashCode(), ChatResponseFormat.Text.GetHashCode()); + Assert.False(ChatResponseFormat.Text.Equals(ChatResponseFormat.Json)); + Assert.False(ChatResponseFormat.Text.Equals(new ChatResponseFormatJson(null))); + Assert.False(ChatResponseFormat.Text.Equals(new ChatResponseFormatJson("{}"))); + + Assert.True(ChatResponseFormat.Json == ChatResponseFormat.Json); + Assert.True(ChatResponseFormat.Json.Equals(ChatResponseFormat.Json)); + Assert.False(ChatResponseFormat.Json.Equals(ChatResponseFormat.Text)); + Assert.False(ChatResponseFormat.Json.Equals(new ChatResponseFormatJson("{}"))); + + Assert.True(ChatResponseFormat.Json.Equals(new ChatResponseFormatJson(null))); + Assert.Equal(ChatResponseFormat.Json.GetHashCode(), new ChatResponseFormatJson(null).GetHashCode()); + + Assert.True(new ChatResponseFormatJson("{}").Equals(new ChatResponseFormatJson("{}"))); + Assert.Equal(new ChatResponseFormatJson("{}").GetHashCode(), new ChatResponseFormatJson("{}").GetHashCode()); + + Assert.False(new ChatResponseFormatJson("""{ "prop": 42 }""").Equals(new ChatResponseFormatJson("""{ "prop": 43 }"""))); + Assert.NotEqual(new ChatResponseFormatJson("""{ "prop": 42 }""").GetHashCode(), new ChatResponseFormatJson("""{ "prop": 43 }""").GetHashCode()); // technically not guaranteed + + Assert.False(new ChatResponseFormatJson("""{ "prop": 42 }""").Equals(new ChatResponseFormatJson("""{ "PROP": 42 }"""))); + Assert.NotEqual(new ChatResponseFormatJson("""{ "prop": 42 }""").GetHashCode(), new ChatResponseFormatJson("""{ "PROP": 42 }""").GetHashCode()); // technically not guaranteed + + Assert.True(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name", "description"))); + Assert.False(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name", "description2"))); + Assert.False(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name2", "description"))); + Assert.False(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name2", "description2"))); + + Assert.Equal(new ChatResponseFormatJson("{}", "name", "description").GetHashCode(), new ChatResponseFormatJson("{}", "name", "description").GetHashCode()); + } + + [Fact] + public void Serialization_TextRoundtrips() + { + string json = JsonSerializer.Serialize(ChatResponseFormat.Text, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal("""{"$type":"text"}""", json); + + ChatResponseFormat? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal(ChatResponseFormat.Text, result); + } + + [Fact] + public void Serialization_JsonRoundtrips() + { + string json = JsonSerializer.Serialize(ChatResponseFormat.Json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal("""{"$type":"json"}""", json); + + ChatResponseFormat? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal(ChatResponseFormat.Json, result); + } + + [Fact] + public void Serialization_ForJsonSchemaRoundtrips() + { + string json = JsonSerializer.Serialize(ChatResponseFormat.ForJsonSchema("[1,2,3]", "name", "description"), TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal("""{"$type":"json","schema":"[1,2,3]","schemaName":"name","schemaDescription":"description"}""", json); + + ChatResponseFormat? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal(ChatResponseFormat.ForJsonSchema("[1,2,3]", "name", "description"), result); + Assert.Equal("[1,2,3]", (result as ChatResponseFormatJson)?.Schema); + Assert.Equal("name", (result as ChatResponseFormatJson)?.SchemaName); + Assert.Equal("description", (result as ChatResponseFormatJson)?.SchemaDescription); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatRoleTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatRoleTests.cs new file mode 100644 index 00000000000..7761aa2fdc3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatRoleTests.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatRoleTests +{ + [Fact] + public void Constructor_Value_Roundtrips() + { + Assert.Equal("abc", new ChatRole("abc").Value); + } + + [Fact] + public void Constructor_NullOrWhiteSpace_Throws() + { + Assert.Throws(() => new ChatRole(null!)); + Assert.Throws(() => new ChatRole(" ")); + } + + [Fact] + public void Equality_UsesOrdinalIgnoreCaseComparison() + { + Assert.True(new ChatRole("abc").Equals(new ChatRole("ABC"))); + Assert.True(new ChatRole("abc").Equals((object)new ChatRole("ABC"))); + Assert.True(new ChatRole("abc") == new ChatRole("ABC")); + Assert.False(new ChatRole("abc") != new ChatRole("ABC")); + + Assert.False(new ChatRole("abc").Equals(new ChatRole("def"))); + Assert.False(new ChatRole("abc").Equals((object)new ChatRole("def"))); + Assert.False(new ChatRole("abc").Equals(null)); + Assert.False(new ChatRole("abc").Equals("abc")); + Assert.False(new ChatRole("abc") == new ChatRole("def")); + Assert.True(new ChatRole("abc") != new ChatRole("def")); + + Assert.Equal(new ChatRole("abc").GetHashCode(), new ChatRole("abc").GetHashCode()); + Assert.Equal(new ChatRole("abc").GetHashCode(), new ChatRole("ABC").GetHashCode()); + Assert.NotEqual(new ChatRole("abc").GetHashCode(), new ChatRole("def").GetHashCode()); // not guaranteed + } + + [Fact] + public void Singletons_UseKnownValues() + { + Assert.Equal("assistant", ChatRole.Assistant.Value); + Assert.Equal("system", ChatRole.System.Value); + Assert.Equal("tool", ChatRole.Tool.Value); + Assert.Equal("user", ChatRole.User.Value); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatRole role = new("abc"); + string? json = JsonSerializer.Serialize(role, TestJsonSerializerContext.Default.ChatRole); + Assert.Equal("\"abc\"", json); + + ChatRole? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatRole); + Assert.Equal(role, result); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs new file mode 100644 index 00000000000..7cdda8ef975 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatToolModeTests +{ + [Fact] + public void Singletons_Idempotent() + { + Assert.Same(ChatToolMode.Auto, ChatToolMode.Auto); + Assert.Same(ChatToolMode.RequireAny, ChatToolMode.RequireAny); + } + + [Fact] + public void Equality_ComparersProduceExpectedResults() + { + Assert.True(ChatToolMode.Auto == ChatToolMode.Auto); + Assert.True(ChatToolMode.Auto.Equals(ChatToolMode.Auto)); + Assert.False(ChatToolMode.Auto.Equals(ChatToolMode.RequireAny)); + Assert.False(ChatToolMode.Auto.Equals(new RequiredChatToolMode(null))); + Assert.False(ChatToolMode.Auto.Equals(new RequiredChatToolMode("func"))); + Assert.Equal(ChatToolMode.Auto.GetHashCode(), ChatToolMode.Auto.GetHashCode()); + + Assert.True(ChatToolMode.RequireAny == ChatToolMode.RequireAny); + Assert.True(ChatToolMode.RequireAny.Equals(ChatToolMode.RequireAny)); + Assert.False(ChatToolMode.RequireAny.Equals(ChatToolMode.Auto)); + Assert.False(ChatToolMode.RequireAny.Equals(new RequiredChatToolMode("func"))); + + Assert.True(ChatToolMode.RequireAny.Equals(new RequiredChatToolMode(null))); + Assert.Equal(ChatToolMode.RequireAny.GetHashCode(), new RequiredChatToolMode(null).GetHashCode()); + Assert.Equal(ChatToolMode.RequireAny.GetHashCode(), ChatToolMode.RequireAny.GetHashCode()); + + Assert.True(new RequiredChatToolMode("func").Equals(new RequiredChatToolMode("func"))); + Assert.Equal(new RequiredChatToolMode("func").GetHashCode(), new RequiredChatToolMode("func").GetHashCode()); + + Assert.False(new RequiredChatToolMode("func1").Equals(new RequiredChatToolMode("func2"))); + Assert.NotEqual(new RequiredChatToolMode("func1").GetHashCode(), new RequiredChatToolMode("func2").GetHashCode()); // technically not guaranteed + + Assert.False(new RequiredChatToolMode("func1").Equals(new RequiredChatToolMode("FUNC1"))); + Assert.NotEqual(new RequiredChatToolMode("func1").GetHashCode(), new RequiredChatToolMode("FUNC1").GetHashCode()); // technically not guaranteed + } + + [Fact] + public void Serialization_AutoRoundtrips() + { + string json = JsonSerializer.Serialize(ChatToolMode.Auto, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal("""{"$type":"auto"}""", json); + + ChatToolMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal(ChatToolMode.Auto, result); + } + + [Fact] + public void Serialization_RequireAnyRoundtrips() + { + string json = JsonSerializer.Serialize(ChatToolMode.RequireAny, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal("""{"$type":"required"}""", json); + + ChatToolMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal(ChatToolMode.RequireAny, result); + } + + [Fact] + public void Serialization_RequireSpecificRoundtrips() + { + string json = JsonSerializer.Serialize(ChatToolMode.RequireSpecific("myFunc"), TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal("""{"$type":"required","requiredFunctionName":"myFunc"}""", json); + + ChatToolMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal(ChatToolMode.RequireSpecific("myFunc"), result); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs new file mode 100644 index 00000000000..51c82c7dcb7 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs @@ -0,0 +1,166 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DelegatingChatClientTests +{ + [Fact] + public void RequiresInnerChatClient() + { + Assert.Throws(() => new NoOpDelegatingChatClient(null!)); + } + + [Fact] + public void MetadataDefaultsToInnerClient() + { + using var inner = new TestChatClient(); + using var delegating = new NoOpDelegatingChatClient(inner); + + Assert.Same(inner.Metadata, delegating.Metadata); + } + + [Fact] + public async Task ChatAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedChatContents = new List(); + var expectedChatOptions = new ChatOptions(); + var expectedCancellationToken = CancellationToken.None; + var expectedResult = new TaskCompletionSource(); + var expectedCompletion = new ChatCompletion([]); + using var inner = new TestChatClient + { + CompleteAsyncCallback = (chatContents, options, cancellationToken) => + { + Assert.Same(expectedChatContents, chatContents); + Assert.Same(expectedChatOptions, options); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var resultTask = delegating.CompleteAsync(expectedChatContents, expectedChatOptions, expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(expectedCompletion); + Assert.True(resultTask.IsCompleted); + Assert.Same(expectedCompletion, await resultTask); + } + + [Fact] + public async Task ChatStreamingAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedChatContents = new List(); + var expectedChatOptions = new ChatOptions(); + var expectedCancellationToken = CancellationToken.None; + StreamingChatCompletionUpdate[] expectedResults = + [ + new() { Role = ChatRole.User, Text = "Message 1" }, + new() { Role = ChatRole.User, Text = "Message 2" } + ]; + + using var inner = new TestChatClient + { + CompleteStreamingAsyncCallback = (chatContents, options, cancellationToken) => + { + Assert.Same(expectedChatContents, chatContents); + Assert.Same(expectedChatOptions, options); + Assert.Equal(expectedCancellationToken, cancellationToken); + return YieldAsync(expectedResults); + } + }; + + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var resultAsyncEnumerable = delegating.CompleteStreamingAsync(expectedChatContents, expectedChatOptions, expectedCancellationToken); + + // Assert + var enumerator = resultAsyncEnumerable.GetAsyncEnumerator(); + Assert.True(await enumerator.MoveNextAsync()); + Assert.Same(expectedResults[0], enumerator.Current); + Assert.True(await enumerator.MoveNextAsync()); + Assert.Same(expectedResults[1], enumerator.Current); + Assert.False(await enumerator.MoveNextAsync()); + } + + [Fact] + public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull() + { + // Arrange + using var inner = new TestChatClient(); + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var client = delegating.GetService(); + + // Assert + Assert.Same(delegating, client); + } + + [Fact] + public void GetServiceDelegatesToInnerIfKeyIsNotNull() + { + // Arrange + var expectedParam = new object(); + var expectedKey = new object(); + using var expectedResult = new TestChatClient(); + using var inner = new TestChatClient + { + GetServiceCallback = (_, _) => expectedResult + }; + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var client = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, client); + } + + [Fact] + public void GetServiceDelegatesToInnerIfNotCompatibleWithRequest() + { + // Arrange + var expectedParam = new object(); + var expectedResult = TimeZoneInfo.Local; + var expectedKey = new object(); + using var inner = new TestChatClient + { + GetServiceCallback = (type, key) => type == expectedResult.GetType() && key == expectedKey + ? expectedResult + : throw new InvalidOperationException("Unexpected call") + }; + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var tzi = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, tzi); + } + + private static async IAsyncEnumerable YieldAsync(IEnumerable input) + { + await Task.Yield(); + foreach (var item in input) + { + yield return item; + } + } + + private sealed class NoOpDelegatingChatClient(IChatClient innerClient) + : DelegatingChatClient(innerClient); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateTests.cs new file mode 100644 index 00000000000..988727b1159 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateTests.cs @@ -0,0 +1,220 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class StreamingChatCompletionUpdateTests +{ + [Fact] + public void Constructor_PropsDefaulted() + { + StreamingChatCompletionUpdate update = new(); + Assert.Null(update.AuthorName); + Assert.Null(update.Role); + Assert.Null(update.Text); + Assert.Empty(update.Contents); + Assert.Null(update.RawRepresentation); + Assert.Null(update.AdditionalProperties); + Assert.Null(update.CompletionId); + Assert.Null(update.CreatedAt); + Assert.Null(update.FinishReason); + Assert.Equal(0, update.ChoiceIndex); + Assert.Equal(string.Empty, update.ToString()); + } + + [Fact] + public void Properties_Roundtrip() + { + StreamingChatCompletionUpdate update = new(); + + Assert.Null(update.AuthorName); + update.AuthorName = "author"; + Assert.Equal("author", update.AuthorName); + + Assert.Null(update.Role); + update.Role = ChatRole.Assistant; + Assert.Equal(ChatRole.Assistant, update.Role); + + Assert.Empty(update.Contents); + update.Contents.Add(new TextContent("text")); + Assert.Single(update.Contents); + Assert.Equal("text", update.Text); + Assert.Same(update.Contents, update.Contents); + IList newList = [new TextContent("text")]; + update.Contents = newList; + Assert.Same(newList, update.Contents); + update.Contents = null; + Assert.NotNull(update.Contents); + Assert.Empty(update.Contents); + + Assert.Null(update.Text); + update.Text = "text"; + Assert.Equal("text", update.Text); + + Assert.Null(update.RawRepresentation); + object raw = new(); + update.RawRepresentation = raw; + Assert.Same(raw, update.RawRepresentation); + + Assert.Null(update.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { ["key"] = "value" }; + update.AdditionalProperties = props; + Assert.Same(props, update.AdditionalProperties); + + Assert.Null(update.CompletionId); + update.CompletionId = "id"; + Assert.Equal("id", update.CompletionId); + + Assert.Null(update.CreatedAt); + update.CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), update.CreatedAt); + + Assert.Equal(0, update.ChoiceIndex); + update.ChoiceIndex = 42; + Assert.Equal(42, update.ChoiceIndex); + + Assert.Null(update.FinishReason); + update.FinishReason = ChatFinishReason.ContentFilter; + Assert.Equal(ChatFinishReason.ContentFilter, update.FinishReason); + } + + [Fact] + public void Text_GetSet_UsesFirstTextContent() + { + StreamingChatCompletionUpdate update = new() + { + Role = ChatRole.User, + Contents = + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + new TextContent("text-1"), + new TextContent("text-2"), + new FunctionResultContent(new FunctionCallContent("callId1", "fc2"), "result"), + ], + }; + + TextContent textContent = Assert.IsType(update.Contents[3]); + Assert.Equal("text-1", textContent.Text); + Assert.Equal("text-1", update.Text); + Assert.Equal("text-1", update.ToString()); + + update.Text = "text-3"; + Assert.Equal("text-3", update.Text); + Assert.Equal("text-3", update.Text); + Assert.Same(textContent, update.Contents[3]); + Assert.Equal("text-3", update.ToString()); + } + + [Fact] + public void Text_Set_AddsTextMessageToEmptyList() + { + StreamingChatCompletionUpdate update = new() + { + Role = ChatRole.User, + }; + Assert.Empty(update.Contents); + + update.Text = "text-1"; + Assert.Equal("text-1", update.Text); + + Assert.Single(update.Contents); + TextContent textContent = Assert.IsType(update.Contents[0]); + Assert.Equal("text-1", textContent.Text); + } + + [Fact] + public void Text_Set_AddsTextMessageToListWithNoText() + { + StreamingChatCompletionUpdate update = new() + { + Contents = + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + ] + }; + Assert.Equal(3, update.Contents.Count); + + update.Text = "text-1"; + Assert.Equal("text-1", update.Text); + Assert.Equal(4, update.Contents.Count); + + update.Text = "text-2"; + Assert.Equal("text-2", update.Text); + Assert.Equal(4, update.Contents.Count); + + update.Contents.RemoveAt(3); + Assert.Equal(3, update.Contents.Count); + + update.Text = "text-3"; + Assert.Equal("text-3", update.Text); + Assert.Equal(4, update.Contents.Count); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + StreamingChatCompletionUpdate original = new() + { + AuthorName = "author", + Role = ChatRole.Assistant, + Contents = + [ + new TextContent("text-1"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + new DataContent("data"u8.ToArray()), + new TextContent("text-2"), + ], + RawRepresentation = new object(), + CompletionId = "id", + CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + FinishReason = ChatFinishReason.ContentFilter, + AdditionalProperties = new() { ["key"] = "value" }, + ChoiceIndex = 42, + }; + + string json = JsonSerializer.Serialize(original, TestJsonSerializerContext.Default.StreamingChatCompletionUpdate); + + StreamingChatCompletionUpdate? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.StreamingChatCompletionUpdate); + + Assert.NotNull(result); + Assert.Equal(5, result.Contents.Count); + + Assert.IsType(result.Contents[0]); + Assert.Equal("text-1", ((TextContent)result.Contents[0]).Text); + + Assert.IsType(result.Contents[1]); + Assert.Equal("http://localhost/image", ((ImageContent)result.Contents[1]).Uri); + + Assert.IsType(result.Contents[2]); + Assert.Equal("fc1", ((FunctionCallContent)result.Contents[2]).Name); + + Assert.IsType(result.Contents[3]); + Assert.Equal("data"u8.ToArray(), ((DataContent)result.Contents[3]).Data?.ToArray()); + + Assert.IsType(result.Contents[4]); + Assert.Equal("text-2", ((TextContent)result.Contents[4]).Text); + + Assert.Equal("author", result.AuthorName); + Assert.Equal(ChatRole.Assistant, result.Role); + Assert.Equal("id", result.CompletionId); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), result.CreatedAt); + Assert.Equal(ChatFinishReason.ContentFilter, result.FinishReason); + Assert.Equal(42, result.ChoiceIndex); + + Assert.NotNull(result.AdditionalProperties); + Assert.Single(result.AdditionalProperties); + Assert.True(result.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentTests.cs new file mode 100644 index 00000000000..ece02f017bb --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentTests.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIContentTests +{ + [Fact] + public void Constructor_PropsDefault() + { + DerivedAIContent c = new(); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + DerivedAIContent c = new(); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + } + + private sealed class DerivedAIContent : AIContent; +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AudioContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AudioContentTests.cs new file mode 100644 index 00000000000..7aff849e8a1 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AudioContentTests.cs @@ -0,0 +1,6 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public sealed class AudioContentTests : DataContentTests; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs new file mode 100644 index 00000000000..18aae8c0497 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs @@ -0,0 +1,6 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public sealed class DataContentTests : DataContentTests; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests{T}.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests{T}.cs new file mode 100644 index 00000000000..ea3017cf7ea --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests{T}.cs @@ -0,0 +1,249 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Reflection; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public abstract class DataContentTests + where T : DataContent +{ + private static T Create(params object?[] args) + { + try + { + return (T)Activator.CreateInstance(typeof(T), args)!; + } + catch (TargetInvocationException e) + { + throw e.InnerException!; + } + } + + public T CreateDataContent(Uri uri, string? mediaType = null) => Create(uri, mediaType)!; + +#pragma warning disable S3997 // String URI overloads should call "System.Uri" overloads + public T CreateDataContent(string uriString, string? mediaType = null) => Create(uriString, mediaType)!; +#pragma warning restore S3997 + + public T CreateDataContent(ReadOnlyMemory data, string? mediaType = null) => Create(data, mediaType)!; + + [Theory] + + // Invalid URI + [InlineData("", typeof(ArgumentException))] + [InlineData("invalid", typeof(UriFormatException))] + + // Format errors + [InlineData("data", typeof(UriFormatException))] // data missing colon + [InlineData("data:", typeof(UriFormatException))] // data missing comma + [InlineData("data:something,", typeof(UriFormatException))] // mime type without subtype + [InlineData("data:something;else,data", typeof(UriFormatException))] // mime type without subtype + [InlineData("data:type/subtype;;parameter=value;else,", typeof(UriFormatException))] // parameter without value + [InlineData("data:type/subtype;parameter=va=lue;else,", typeof(UriFormatException))] // parameter with multiple = + [InlineData("data:type/subtype;=value;else,", typeof(UriFormatException))] // empty parameter name + [InlineData("data:image/j/peg;base64,/9j/4AAQSkZJRgABAgAAZABkAAD", typeof(UriFormatException))] // multiple slashes in media type + + // Base64 Validation Errors + [InlineData("data:text;base64,something!", typeof(UriFormatException))] // Invalid base64 due to invalid character '!' + [InlineData("data:text/plain;base64,U29tZQ==\t", typeof(UriFormatException))] // Invalid base64 due to tab character + [InlineData("data:text/plain;base64,U29tZQ==\r", typeof(UriFormatException))] // Invalid base64 due to carriage return character + [InlineData("data:text/plain;base64,U29tZQ==\n", typeof(UriFormatException))] // Invalid base64 due to line feed character + [InlineData("data:text/plain;base64,U29t\r\nZQ==", typeof(UriFormatException))] // Invalid base64 due to carriage return and line feed characters + [InlineData("data:text/plain;base64,U29", typeof(UriFormatException))] // Invalid base64 due to missing padding + [InlineData("data:text/plain;base64,U29tZQ", typeof(UriFormatException))] // Invalid base64 due to missing padding + [InlineData("data:text/plain;base64,U29tZQ=", typeof(UriFormatException))] // Invalid base64 due to missing padding + public void Ctor_InvalidUri_Throws(string path, Type exception) + { + Assert.Throws(exception, () => CreateDataContent(path)); + } + + [Theory] + [InlineData("type")] + [InlineData("type//subtype")] + [InlineData("type/subtype/")] + [InlineData("type/subtype;key=")] + [InlineData("type/subtype;=value")] + [InlineData("type/subtype;key=value;another=")] + public void Ctor_InvalidMediaType_Throws(string mediaType) + { + Assert.Throws(() => CreateDataContent("http://localhost/test", mediaType)); + } + + [Theory] + [InlineData("type/subtype")] + [InlineData("type/subtype;key=value")] + [InlineData("type/subtype;key=value;another=value")] + [InlineData("type/subtype;key=value;another=value;yet_another=value")] + public void Ctor_ValidMediaType_Roundtrips(string mediaType) + { + T content = CreateDataContent("http://localhost/test", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent("data:,", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent("data:text/plain,", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent(new Uri("data:text/plain,"), mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent(new byte[] { 0, 1, 2 }, mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent(content.Uri); + Assert.Equal(mediaType, content.MediaType); + } + + [Fact] + public void Ctor_NoMediaType_Roundtrips() + { + T content; + + foreach (string url in new[] { "http://localhost/test", "about:something", "file://c:\\path" }) + { + content = CreateDataContent(url); + Assert.Equal(url, content.Uri); + Assert.Null(content.MediaType); + Assert.Null(content.Data); + } + + content = CreateDataContent("data:,something"); + Assert.Equal("data:,something", content.Uri); + Assert.Null(content.MediaType); + Assert.Equal("something"u8.ToArray(), content.Data!.Value.ToArray()); + + content = CreateDataContent("data:,Hello+%3C%3E"); + Assert.Equal("data:,Hello+%3C%3E", content.Uri); + Assert.Null(content.MediaType); + Assert.Equal("Hello <>"u8.ToArray(), content.Data!.Value.ToArray()); + } + + [Fact] + public void Serialize_MatchesExpectedJson() + { + Assert.Equal( + """{"uri":"data:,"}""", + JsonSerializer.Serialize(CreateDataContent("data:,"), TestJsonSerializerContext.Default.Options)); + + Assert.Equal( + """{"uri":"http://localhost/"}""", + JsonSerializer.Serialize(CreateDataContent(new Uri("http://localhost/")), TestJsonSerializerContext.Default.Options)); + + Assert.Equal( + """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + JsonSerializer.Serialize(CreateDataContent( + uriString: "data:application/octet-stream;base64,AQIDBA=="), TestJsonSerializerContext.Default.Options)); + + Assert.Equal( + """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + JsonSerializer.Serialize(CreateDataContent( + new ReadOnlyMemory([0x01, 0x02, 0x03, 0x04]), "application/octet-stream"), + TestJsonSerializerContext.Default.Options)); + } + + [Theory] + [InlineData("{}")] + [InlineData("""{ "mediaType":"text/plain" }""")] + public void Deserialize_MissingUriString_Throws(string json) + { + Assert.Throws(() => JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options)!); + } + + [Fact] + public void Deserialize_MatchesExpectedData() + { + // Data + MimeType only + var content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"data:;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; + + Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri); + Assert.NotNull(content.Data); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data!.Value.ToArray()); + Assert.Equal("application/octet-stream", content.MediaType); + Assert.True(content.ContainsData); + + // Uri referenced content-only + content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"http://localhost/"}""", TestJsonSerializerContext.Default.Options)!; + + Assert.Null(content.Data); + Assert.Equal("http://localhost/", content.Uri); + Assert.Equal("application/octet-stream", content.MediaType); + Assert.False(content.ContainsData); + + // Using extra metadata + content = JsonSerializer.Deserialize(""" + { + "uri": "data:;base64,AQIDBA==", + "modelId": "gpt-4", + "additionalProperties": + { + "key": "value" + }, + "mediaType": "text/plain" + } + """, TestJsonSerializerContext.Default.Options)!; + + Assert.Equal("data:text/plain;base64,AQIDBA==", content.Uri); + Assert.NotNull(content.Data); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data!.Value.ToArray()); + Assert.Equal("text/plain", content.MediaType); + Assert.True(content.ContainsData); + Assert.Equal("gpt-4", content.ModelId); + Assert.Equal("value", content.AdditionalProperties!["key"]!.ToString()); + } + + [Theory] + [InlineData( + """{"uri": "data:;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + [InlineData( + """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + [InlineData( // Does not support non-readable content + """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", "unexpected": true}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + [InlineData( // Uri comes before mimetype + """{"mediaType": "text/plain", "uri": "http://localhost/" }""", + """{"uri":"http://localhost/","mediaType":"text/plain"}""")] + public void Serialize_Deserialize_Roundtrips(string serialized, string expectedToString) + { + var content = JsonSerializer.Deserialize(serialized, TestJsonSerializerContext.Default.Options)!; + var reSerialization = JsonSerializer.Serialize(content, TestJsonSerializerContext.Default.Options); + Assert.Equal(expectedToString, reSerialization); + } + + [Theory] + [InlineData("application/json")] + [InlineData("application/octet-stream")] + [InlineData("application/pdf")] + [InlineData("application/xml")] + [InlineData("audio/mpeg")] + [InlineData("audio/ogg")] + [InlineData("audio/wav")] + [InlineData("image/apng")] + [InlineData("image/avif")] + [InlineData("image/bmp")] + [InlineData("image/gif")] + [InlineData("image/jpeg")] + [InlineData("image/png")] + [InlineData("image/svg+xml")] + [InlineData("image/tiff")] + [InlineData("image/webp")] + [InlineData("text/css")] + [InlineData("text/csv")] + [InlineData("text/html")] + [InlineData("text/javascript")] + [InlineData("text/plain")] + [InlineData("text/plain;charset=UTF-8")] + [InlineData("text/xml")] + [InlineData("custom/mediatypethatdoesntexists")] + public void MediaType_Roundtrips(string mediaType) + { + DataContent c = new("data:,", mediaType); + Assert.Equal(mediaType, c.MediaType); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs new file mode 100644 index 00000000000..791bb4cc0e7 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs @@ -0,0 +1,302 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +#if NET +using System.Runtime.ExceptionServices; +#endif +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class FunctionCallContentTests +{ + [Fact] + public void Constructor_PropsDefault() + { + FunctionCallContent c = new("callId1", "name"); + + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.Equal("callId1", c.CallId); + Assert.Equal("name", c.Name); + + Assert.Null(c.Arguments); + Assert.Null(c.Exception); + } + + [Fact] + public void Constructor_ArgumentsRoundtrip() + { + Dictionary args = []; + + FunctionCallContent c = new("id", "name", args); + + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.Equal("name", c.Name); + Assert.Equal("id", c.CallId); + Assert.Same(args, c.Arguments); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + FunctionCallContent c = new("callId1", "name"); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + + Assert.Equal("callId1", c.CallId); + c.CallId = "id"; + Assert.Equal("id", c.CallId); + + Assert.Null(c.Arguments); + AdditionalPropertiesDictionary args = new() { { "key", "value" } }; + c.Arguments = args; + Assert.Same(args, c.Arguments); + + Assert.Null(c.Exception); + Exception e = new(); + c.Exception = e; + Assert.Same(e, c.Exception); + } + + [Fact] + public void ItShouldBeSerializableAndDeserializableWithException() + { + // Arrange + var ex = new InvalidOperationException("hello", new NullReferenceException("bye")); +#if NET + ExceptionDispatchInfo.SetRemoteStackTrace(ex, "stack trace"); +#endif + var sut = new FunctionCallContent("callId1", "functionName") { Exception = ex }; + + // Act + var json = JsonSerializer.SerializeToNode(sut, TestJsonSerializerContext.Default.Options); + var deserializedSut = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options); + + // Assert + JsonObject jsonEx = Assert.IsType(json!["exception"]); + Assert.Equal(4, jsonEx.Count); + Assert.Equal("System.InvalidOperationException", (string?)jsonEx["className"]); + Assert.Equal("hello", (string?)jsonEx["message"]); +#if NET + Assert.StartsWith("stack trace", (string?)jsonEx["stackTraceString"]); +#endif + JsonObject jsonExInner = Assert.IsType(jsonEx["innerException"]); + Assert.Equal(4, jsonExInner.Count); + Assert.Equal("System.NullReferenceException", (string?)jsonExInner["className"]); + Assert.Equal("bye", (string?)jsonExInner["message"]); + Assert.Null(jsonExInner["innerException"]); + Assert.Null(jsonExInner["stackTraceString"]); + + Assert.NotNull(deserializedSut); + Assert.IsType(deserializedSut.Exception); + Assert.Equal("hello", deserializedSut.Exception.Message); +#if NET + Assert.StartsWith("stack trace", deserializedSut.Exception.StackTrace); +#endif + + Assert.IsType(deserializedSut.Exception.InnerException); + Assert.Equal("bye", deserializedSut.Exception.InnerException.Message); + Assert.Null(deserializedSut.Exception.InnerException.StackTrace); + Assert.Null(deserializedSut.Exception.InnerException.InnerException); + } + + [Fact] + public async Task AIFunctionFactory_ObjectValues_Converted() + { + Dictionary arguments = new() + { + ["a"] = new DayOfWeek[] { DayOfWeek.Monday, DayOfWeek.Tuesday, DayOfWeek.Wednesday }, + ["b"] = 123.4M, + ["c"] = "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + ["d"] = new ReadOnlyDictionary((new Dictionary + { + ["p1"] = "42", + ["p2"] = "43", + })), + }; + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public async Task AIFunctionFactory_JsonElementValues_ValuesDeserialized() + { + Dictionary arguments = JsonSerializer.Deserialize>(""" + { + "a": ["Monday", "Tuesday", "Wednesday"], + "b": 123.4, + "c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + "d": { + "property1": "42", + "property2": "43", + "property3": "44" + } + } + """, TestJsonSerializerContext.Default.Options)!; + Assert.All(arguments.Values, v => Assert.IsType(v)); + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public void AIFunctionFactory_WhenTypesUnknownByContext_Throws() + { + var ex = Assert.Throws(() => AIFunctionFactory.Create((CustomType arg) => { }, TestJsonSerializerContext.Default.Options)); + Assert.Contains("JsonTypeInfo metadata", ex.Message); + Assert.Contains(nameof(CustomType), ex.Message); + + ex = Assert.Throws(() => AIFunctionFactory.Create(() => new CustomType(), TestJsonSerializerContext.Default.Options)); + Assert.Contains("JsonTypeInfo metadata", ex.Message); + Assert.Contains(nameof(CustomType), ex.Message); + } + + [Fact] + public async Task AIFunctionFactory_JsonDocumentValues_ValuesDeserialized() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": ["Monday", "Tuesday", "Wednesday"], + "b": 123.4, + "c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + "d": { + "property1": "42", + "property2": "43", + "property3": "44" + } + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public async Task AIFunctionFactory_JsonNodeValues_ValuesDeserialized() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": ["Monday", "Tuesday", "Wednesday"], + "b": 123.4, + "c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + "d": { + "property1": "42", + "property2": "43", + "property3": "44" + } + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public async Task TypelessAIFunction_JsonDocumentValues_AcceptsArguments() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": "string", + "b": 123.4, + "c": true, + "d": false, + "e": ["Monday", "Tuesday", "Wednesday"], + "f": null + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments); + Assert.Same(result, arguments); + } + + [Fact] + public async Task TypelessAIFunction_JsonElementValues_AcceptsArguments() + { + Dictionary arguments = JsonSerializer.Deserialize>(""" + { + "a": "string", + "b": 123.4, + "c": true, + "d": false, + "e": ["Monday", "Tuesday", "Wednesday"], + "f": null + } + """, TestJsonSerializerContext.Default.Options)!; + + var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments); + Assert.Same(result, arguments); + } + + [Fact] + public async Task TypelessAIFunction_JsonNodeValues_AcceptsArguments() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": "string", + "b": 123.4, + "c": true, + "d": false, + "e": ["Monday", "Tuesday", "Wednesday"], + "f": null + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments); + Assert.Same(result, arguments); + } + + private sealed class CustomType; + + private sealed class NetTypelessAIFunction : AIFunction + { + public static NetTypelessAIFunction Instance { get; } = new NetTypelessAIFunction(); + + public override AIFunctionMetadata Metadata => new("NetTypeless") + { + Description = "AIFunction with parameters that lack .NET types", + Parameters = + [ + new AIFunctionParameterMetadata("a"), + new AIFunctionParameterMetadata("b"), + new AIFunctionParameterMetadata("c"), + new AIFunctionParameterMetadata("d"), + new AIFunctionParameterMetadata("e"), + new AIFunctionParameterMetadata("f"), + ] + }; + + protected override Task InvokeCoreAsync(IEnumerable>? arguments, CancellationToken cancellationToken) => + Task.FromResult(arguments); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionResultContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionResultContentTests.cs new file mode 100644 index 00000000000..a24120ca9a9 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionResultContentTests.cs @@ -0,0 +1,120 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class FunctionResultContentTests +{ + [Fact] + public void Constructor_PropsDefault() + { + FunctionResultContent c = new("callId1", "functionName"); + Assert.Equal("callId1", c.CallId); + Assert.Equal("functionName", c.Name); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Null(c.Result); + Assert.Null(c.Exception); + } + + [Fact] + public void Constructor_String_PropsRoundtrip() + { + Exception e = new(); + + FunctionResultContent c = new("id", "name", "result", e); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Equal("name", c.Name); + Assert.Equal("id", c.CallId); + Assert.Equal("result", c.Result); + Assert.Same(e, c.Exception); + } + + [Fact] + public void Constructor_FunctionCallContent_PropsRoundtrip() + { + Exception e = new(); + + FunctionResultContent c = new(new FunctionCallContent("id", "name"), "result", e); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Equal("id", c.CallId); + Assert.Equal("result", c.Result); + Assert.Same(e, c.Exception); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + FunctionResultContent c = new("callId1", "functionName"); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + + Assert.Equal("callId1", c.CallId); + c.CallId = "id"; + Assert.Equal("id", c.CallId); + + Assert.Null(c.Result); + c.Result = "result"; + Assert.Equal("result", c.Result); + + Assert.Null(c.Exception); + Exception e = new(); + c.Exception = e; + Assert.Same(e, c.Exception); + } + + [Fact] + public void ItShouldBeSerializableAndDeserializable() + { + // Arrange + var sut = new FunctionResultContent(new FunctionCallContent("id", "p1-f1"), "result"); + + // Act + var json = JsonSerializer.Serialize(sut, TestJsonSerializerContext.Default.Options); + + var deserializedSut = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.NotNull(deserializedSut); + Assert.Equal(sut.Name, deserializedSut.Name); + Assert.Equal(sut.CallId, deserializedSut.CallId); + Assert.Equal(sut.Result, deserializedSut.Result?.ToString()); + } + + [Fact] + public void ItShouldBeSerializableAndDeserializableWithException() + { + // Arrange + var sut = new FunctionResultContent("callId1", "functionName") { Exception = new InvalidOperationException("hello") }; + + // Act + var json = JsonSerializer.Serialize(sut, TestJsonSerializerContext.Default.Options); + var deserializedSut = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.NotNull(deserializedSut); + Assert.IsType(deserializedSut.Exception); + Assert.Contains("hello", deserializedSut.Exception.Message); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/ImageContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/ImageContentTests.cs new file mode 100644 index 00000000000..7b088e3ebf3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/ImageContentTests.cs @@ -0,0 +1,6 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public sealed class ImageContentTests : DataContentTests; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/TextContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/TextContentTests.cs new file mode 100644 index 00000000000..d1ba5e83bc9 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/TextContentTests.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class TextContentTests +{ + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData("text")] + public void Constructor_String_PropsDefault(string? text) + { + TextContent c = new(text); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Equal(text, c.Text); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + TextContent c = new(null); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + + Assert.Null(c.Text); + c.Text = "text"; + Assert.Equal("text", c.Text); + Assert.Equal("text", c.ToString()); + + c.Text = null; + Assert.Null(c.Text); + Assert.Equal(string.Empty, c.ToString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UsageContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UsageContentTests.cs new file mode 100644 index 00000000000..109bdc8120e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UsageContentTests.cs @@ -0,0 +1,62 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class UsageContentTests +{ + [Fact] + public void Constructor_InvalidArg_Throws() + { + Assert.Throws("details", () => new UsageContent(null!)); + } + + [Fact] + public void Constructor_Parameterless_PropsDefault() + { + UsageContent c = new(); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.NotNull(c.Details); + Assert.Same(c.Details, c.Details); + Assert.Null(c.Details.InputTokenCount); + Assert.Null(c.Details.OutputTokenCount); + Assert.Null(c.Details.TotalTokenCount); + Assert.Null(c.Details.AdditionalProperties); + } + + [Fact] + public void Constructor_UsageDetails_PropsRoundtrip() + { + UsageDetails details = new(); + + UsageContent c = new(details); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.Same(details, c.Details); + + UsageDetails details2 = new(); + c.Details = details2; + Assert.Same(details2, c.Details); + } + + [Fact] + public void Details_SetNull_Throws() + { + UsageContent c = new(); + + UsageDetails d = c.Details; + Assert.NotNull(d); + + Assert.Throws("value", () => c.Details = null!); + + Assert.Same(d, c.Details); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..91640e62f4f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs @@ -0,0 +1,118 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DelegatingEmbeddingGeneratorTests +{ + [Fact] + public void RequiresInnerService() + { + Assert.Throws(() => new NoOpDelegatingEmbeddingGenerator(null!)); + } + + [Fact] + public void MetadataDefaultsToInnerService() + { + using var inner = new TestEmbeddingGenerator(); + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + Assert.Same(inner.Metadata, delegating.Metadata); + } + + [Fact] + public async Task GenerateEmbeddingsDefaultsToInnerServiceAsync() + { + // Arrange + var expectedInput = new List(); + using var cts = new CancellationTokenSource(); + var expectedCancellationToken = cts.Token; + var expectedResult = new TaskCompletionSource>>(); + var expectedEmbedding = new GeneratedEmbeddings>([new(new float[] { 1.0f, 2.0f, 3.0f })]); + using var inner = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (input, options, cancellationToken) => + { + Assert.Same(expectedInput, input); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var resultTask = delegating.GenerateAsync(expectedInput, options: null, expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(expectedEmbedding); + Assert.True(resultTask.IsCompleted); + Assert.Same(expectedEmbedding, await resultTask); + } + + [Fact] + public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull() + { + // Arrange + using var inner = new TestEmbeddingGenerator(); + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var service = delegating.GetService>>(); + + // Assert + Assert.Same(delegating, service); + } + + [Fact] + public void GetServiceDelegatesToInnerIfKeyIsNotNull() + { + // Arrange + var expectedParam = new object(); + var expectedKey = new object(); + using var expectedResult = new TestEmbeddingGenerator(); + using var inner = new TestEmbeddingGenerator + { + GetServiceCallback = (_, _) => expectedResult + }; + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var service = delegating.GetService>>(expectedKey); + + // Assert + Assert.Same(expectedResult, service); + } + + [Fact] + public void GetServiceDelegatesToInnerIfNotCompatibleWithRequest() + { + // Arrange + var expectedParam = new object(); + var expectedResult = TimeZoneInfo.Local; + var expectedKey = new object(); + using var inner = new TestEmbeddingGenerator + { + GetServiceCallback = (type, key) => type == expectedResult.GetType() && key == expectedKey + ? expectedResult + : throw new InvalidOperationException("Unexpected call") + }; + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var service = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, service); + } + + private sealed class NoOpDelegatingEmbeddingGenerator(IEmbeddingGenerator> innerGenerator) : + DelegatingEmbeddingGenerator>(innerGenerator); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs new file mode 100644 index 00000000000..e9dd45959c7 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGenerationOptionsTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + EmbeddingGenerationOptions options = new(); + Assert.Null(options.ModelId); + Assert.Null(options.AdditionalProperties); + + EmbeddingGenerationOptions clone = options.Clone(); + Assert.Null(clone.ModelId); + Assert.Null(clone.AdditionalProperties); + } + + [Fact] + public void Properties_Roundtrip() + { + EmbeddingGenerationOptions options = new(); + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.ModelId = "modelId"; + options.AdditionalProperties = additionalProps; + + Assert.Equal("modelId", options.ModelId); + Assert.Same(additionalProps, options.AdditionalProperties); + + EmbeddingGenerationOptions clone = options.Clone(); + Assert.Equal("modelId", clone.ModelId); + Assert.Equal(additionalProps, clone.AdditionalProperties); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + EmbeddingGenerationOptions options = new(); + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.ModelId = "model"; + options.AdditionalProperties = additionalProps; + + string json = JsonSerializer.Serialize(options, TestJsonSerializerContext.Default.EmbeddingGenerationOptions); + + EmbeddingGenerationOptions? deserialized = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.EmbeddingGenerationOptions); + Assert.NotNull(deserialized); + + Assert.Equal("model", deserialized.ModelId); + + Assert.NotNull(deserialized.AdditionalProperties); + Assert.Single(deserialized.AdditionalProperties); + Assert.True(deserialized.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs new file mode 100644 index 00000000000..827ed04c712 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGeneratorExtensionsTests +{ + [Fact] + public async Task GenerateAsync_InvalidArgs_ThrowsAsync() + { + await Assert.ThrowsAsync("generator", () => ((TestEmbeddingGenerator)null!).GenerateAsync("hello")); + } + + [Fact] + public async Task GenerateAsync_ReturnsSingleEmbeddingAsync() + { + Embedding result = new(new float[] { 1f, 2f, 3f }); + + using TestEmbeddingGenerator service = new() + { + GenerateAsyncCallback = (values, options, cancellationToken) => + Task.FromResult>>([result]) + }; + + Assert.Same(result, (await service.GenerateAsync("hello"))[0]); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs new file mode 100644 index 00000000000..b3cd0d59abb --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGeneratorMetadataTests +{ + [Fact] + public void Constructor_NullValues_AllowedAndRoundtrip() + { + EmbeddingGeneratorMetadata metadata = new(null, null, null, null); + Assert.Null(metadata.ProviderName); + Assert.Null(metadata.ProviderUri); + Assert.Null(metadata.ModelId); + Assert.Null(metadata.Dimensions); + } + + [Fact] + public void Constructor_Value_Roundtrips() + { + var uri = new Uri("https://example.com"); + EmbeddingGeneratorMetadata metadata = new("providerName", uri, "theModel", 42); + Assert.Equal("providerName", metadata.ProviderName); + Assert.Same(uri, metadata.ProviderUri); + Assert.Equal("theModel", metadata.ModelId); + Assert.Equal(42, metadata.Dimensions); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs new file mode 100644 index 00000000000..45fcce8ba63 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingTests +{ + [Fact] + public void Embedding_Ctor_Roundtrips() + { + float[] floats = [1f, 2f, 3f]; + UsageDetails usage = new(); + AdditionalPropertiesDictionary props = []; + var createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z"); + const string Model = "text-embedding-3-small"; + + Embedding e = new(floats) + { + CreatedAt = createdAt, + ModelId = Model, + AdditionalProperties = props, + }; + + Assert.Equal(floats, e.Vector.ToArray()); + Assert.Equal(Model, e.ModelId); + Assert.Same(props, e.AdditionalProperties); + Assert.Equal(createdAt, e.CreatedAt); + + Assert.True(MemoryMarshal.TryGetArray(e.Vector, out ArraySegment array)); + Assert.Same(floats, array.Array); + } + +#if NET + [Fact] + public void Embedding_Half_SerializationRoundtrips() + { + Half[] halfs = [(Half)1f, (Half)2f, (Half)3f]; + Embedding e = new(halfs); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"halves","vector":[1,2,3]}""", json); + + Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } +#endif + + [Fact] + public void Embedding_Single_SerializationRoundtrips() + { + float[] floats = [1f, 2f, 3f]; + Embedding e = new(floats); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"floats","vector":[1,2,3]}""", json); + + Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } + + [Fact] + public void Embedding_Double_SerializationRoundtrips() + { + double[] floats = [1f, 2f, 3f]; + Embedding e = new(floats); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"doubles","vector":[1,2,3]}""", json); + + Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs new file mode 100644 index 00000000000..4ebd9465ca8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs @@ -0,0 +1,246 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using Xunit; + +#pragma warning disable xUnit2013 // Do not use equality check to check for collection size. +#pragma warning disable xUnit2017 // Do not use Contains() to check if a value exists in a collection + +namespace Microsoft.Extensions.AI; + +public class GeneratedEmbeddingsTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("embeddings", () => new GeneratedEmbeddings>(null!)); + Assert.Throws("capacity", () => new GeneratedEmbeddings>(-1)); + } + + [Fact] + public void Ctor_ValidArgs_NoExceptions() + { + GeneratedEmbeddings>[] instances = + [ + [], + new(0), + new(42), + new([]) + ]; + + foreach (var instance in instances) + { + Assert.Empty(instance); + + Assert.False(((ICollection>)instance).IsReadOnly); + Assert.Equal(0, instance.Count); + + Assert.False(instance.Contains(new Embedding(new float[] { 1, 2, 3 }))); + Assert.False(instance.Contains(null!)); + + Assert.Equal(-1, instance.IndexOf(new Embedding(new float[] { 1, 2, 3 }))); + Assert.Equal(-1, instance.IndexOf(null!)); + + instance.CopyTo(Array.Empty>(), 0); + + Assert.Throws(() => instance[0]); + Assert.Throws(() => instance[-1]); + } + } + + [Fact] + public void Ctor_RoundtripsEnumerable() + { + List> embeddings = + [ + new(new float[] { 1, 2, 3 }), + new(new float[] { 4, 5, 6 }), + ]; + + var generatedEmbeddings = new GeneratedEmbeddings>(embeddings); + + Assert.Equal(embeddings, generatedEmbeddings); + Assert.Equal(2, generatedEmbeddings.Count); + + Assert.Same(embeddings[0], generatedEmbeddings[0]); + Assert.Same(embeddings[1], generatedEmbeddings[1]); + + Assert.Equal(0, generatedEmbeddings.IndexOf(embeddings[0])); + Assert.Equal(1, generatedEmbeddings.IndexOf(embeddings[1])); + + Assert.True(generatedEmbeddings.Contains(embeddings[0])); + Assert.True(generatedEmbeddings.Contains(embeddings[1])); + + Assert.False(generatedEmbeddings.Contains(null!)); + Assert.Equal(-1, generatedEmbeddings.IndexOf(null!)); + + Assert.Throws(() => generatedEmbeddings[-1]); + Assert.Throws(() => generatedEmbeddings[2]); + + Assert.True(embeddings.SequenceEqual(generatedEmbeddings)); + + var e = new Embedding(new float[] { 7, 8, 9 }); + generatedEmbeddings.Add(e); + Assert.Equal(3, generatedEmbeddings.Count); + Assert.Same(e, generatedEmbeddings[2]); + } + + [Fact] + public void Properties_Roundtrip() + { + GeneratedEmbeddings> embeddings = []; + + Assert.Null(embeddings.Usage); + + UsageDetails usage = new(); + embeddings.Usage = usage; + Assert.Same(usage, embeddings.Usage); + embeddings.Usage = null; + Assert.Null(embeddings.Usage); + + Assert.Null(embeddings.AdditionalProperties); + AdditionalPropertiesDictionary props = []; + embeddings.AdditionalProperties = props; + Assert.Same(props, embeddings.AdditionalProperties); + embeddings.AdditionalProperties = null; + Assert.Null(embeddings.AdditionalProperties); + } + + [Fact] + public void Add() + { + GeneratedEmbeddings> embeddings = []; + var e = new Embedding(new float[] { 1, 2, 3 }); + + embeddings.Add(e); + Assert.Equal(1, embeddings.Count); + Assert.Same(e, embeddings[0]); + } + + [Fact] + public void AddRange() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + + Assert.Equal(2, embeddings.Count); + Assert.Same(e1, embeddings[0]); + Assert.Same(e2, embeddings[1]); + } + + [Fact] + public void Clear() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + embeddings.Clear(); + Assert.Equal(0, embeddings.Count); + Assert.Empty(embeddings); + } + + [Fact] + public void Remove() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + Assert.True(embeddings.Remove(e1)); + Assert.Equal(1, embeddings.Count); + Assert.Same(e2, embeddings[0]); + + Assert.False(embeddings.Remove(e1)); + Assert.Equal(1, embeddings.Count); + Assert.Same(e2, embeddings[0]); + + Assert.True(embeddings.Remove(e2)); + Assert.Equal(0, embeddings.Count); + } + + [Fact] + public void RemoveAt() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + embeddings.RemoveAt(0); + Assert.Equal(1, embeddings.Count); + Assert.Same(e2, embeddings[0]); + + embeddings.RemoveAt(0); + Assert.Equal(0, embeddings.Count); + } + + [Fact] + public void Insert() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + var e3 = new Embedding(new float[] { 7, 8, 9 }); + embeddings.Insert(1, e3); + Assert.Equal(3, embeddings.Count); + Assert.Same(e3, embeddings[1]); + Assert.Same(e2, embeddings[2]); + } + + [Fact] + public void Indexer() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + var e3 = new Embedding(new float[] { 7, 8, 9 }); + embeddings[1] = e3; + Assert.Equal(2, embeddings.Count); + Assert.Same(e1, embeddings[0]); + Assert.Same(e3, embeddings[1]); + } + + [Fact] + public void Indexer_InvalidIndex_Throws() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + Assert.Throws(() => embeddings[-1]); + Assert.Throws(() => embeddings[2]); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionMetadataTests.cs new file mode 100644 index 00000000000..a1aa48bd115 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionMetadataTests.cs @@ -0,0 +1,97 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionMetadataTests +{ + [Fact] + public void Constructor_InvalidArg_Throws() + { + Assert.Throws("name", () => new AIFunctionMetadata((string)null!)); + Assert.Throws("name", () => new AIFunctionMetadata(" \t ")); + Assert.Throws("metadata", () => new AIFunctionMetadata((AIFunctionMetadata)null!)); + } + + [Fact] + public void Constructor_String_PropsDefaulted() + { + AIFunctionMetadata f = new("name"); + Assert.Equal("name", f.Name); + Assert.Empty(f.Description); + Assert.Empty(f.Parameters); + + Assert.NotNull(f.ReturnParameter); + Assert.Null(f.ReturnParameter.Schema); + Assert.Null(f.ReturnParameter.ParameterType); + Assert.Null(f.ReturnParameter.Description); + + Assert.NotNull(f.AdditionalProperties); + Assert.Empty(f.AdditionalProperties); + Assert.Same(f.AdditionalProperties, new AIFunctionMetadata("name2").AdditionalProperties); + } + + [Fact] + public void Constructor_Copy_PropsPropagated() + { + AIFunctionMetadata f1 = new("name") + { + Description = "description", + Parameters = [new AIFunctionParameterMetadata("param")], + ReturnParameter = new AIFunctionReturnParameterMetadata(), + AdditionalProperties = new Dictionary { { "key", "value" } }, + }; + + AIFunctionMetadata f2 = new(f1); + Assert.Equal(f1.Name, f2.Name); + Assert.Equal(f1.Description, f2.Description); + Assert.Same(f1.Parameters, f2.Parameters); + Assert.Same(f1.ReturnParameter, f2.ReturnParameter); + Assert.Same(f1.AdditionalProperties, f2.AdditionalProperties); + } + + [Fact] + public void Props_InvalidArg_Throws() + { + Assert.Throws("value", () => new AIFunctionMetadata("name") { Name = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { Parameters = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { ReturnParameter = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { AdditionalProperties = null! }); + } + + [Fact] + public void Description_NullNormalizedToEmpty() + { + AIFunctionMetadata f = new("name") { Description = null }; + Assert.Equal("", f.Description); + } + + [Fact] + public void GetParameter_EmptyCollection_ReturnsNull() + { + Assert.Null(new AIFunctionMetadata("name").GetParameter("test")); + } + + [Fact] + public void GetParameter_ByName_ReturnsParameter() + { + AIFunctionMetadata f = new("name") + { + Parameters = + [ + new AIFunctionParameterMetadata("param0"), + new AIFunctionParameterMetadata("param1"), + new AIFunctionParameterMetadata("param2"), + ] + }; + + Assert.Same(f.Parameters[0], f.GetParameter("param0")); + Assert.Same(f.Parameters[1], f.GetParameter("param1")); + Assert.Same(f.Parameters[2], f.GetParameter("param2")); + Assert.Null(f.GetParameter("param3")); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionParameterMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionParameterMetadataTests.cs new file mode 100644 index 00000000000..23c33ecf07a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionParameterMetadataTests.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionParameterMetadataTests +{ + [Fact] + public void Constructor_InvalidArg_Throws() + { + Assert.Throws("name", () => new AIFunctionParameterMetadata((string)null!)); + Assert.Throws("name", () => new AIFunctionParameterMetadata(" ")); + Assert.Throws("metadata", () => new AIFunctionParameterMetadata((AIFunctionParameterMetadata)null!)); + } + + [Fact] + public void Constructor_String_PropsDefaulted() + { + AIFunctionParameterMetadata p = new("name"); + Assert.Equal("name", p.Name); + Assert.Null(p.Description); + Assert.Null(p.DefaultValue); + Assert.False(p.IsRequired); + Assert.Null(p.ParameterType); + Assert.Null(p.Schema); + } + + [Fact] + public void Constructor_Copy_PropsPropagated() + { + AIFunctionParameterMetadata p1 = new("name") + { + Description = "description", + HasDefaultValue = true, + DefaultValue = 42, + IsRequired = true, + ParameterType = typeof(int), + Schema = JsonDocument.Parse("""{"type":"integer"}"""), + }; + + AIFunctionParameterMetadata p2 = new(p1); + + Assert.Equal(p1.Name, p2.Name); + Assert.Equal(p1.Description, p2.Description); + Assert.Equal(p1.DefaultValue, p2.DefaultValue); + Assert.Equal(p1.IsRequired, p2.IsRequired); + Assert.Equal(p1.ParameterType, p2.ParameterType); + Assert.Equal(p1.Schema, p2.Schema); + } + + [Fact] + public void Constructor_Copy_PropsPropagatedAndOverwritten() + { + AIFunctionParameterMetadata p1 = new("name") + { + Description = "description", + HasDefaultValue = true, + DefaultValue = 42, + IsRequired = true, + ParameterType = typeof(int), + Schema = JsonDocument.Parse("""{"type":"integer"}"""), + }; + + AIFunctionParameterMetadata p2 = new(p1) + { + Description = "description2", + HasDefaultValue = true, + DefaultValue = 43, + IsRequired = false, + ParameterType = typeof(long), + Schema = JsonDocument.Parse("""{"type":"number"}"""), + }; + + Assert.Equal("description2", p2.Description); + Assert.True(p2.HasDefaultValue); + Assert.Equal(43, p2.DefaultValue); + Assert.False(p2.IsRequired); + Assert.Equal(typeof(long), p2.ParameterType); + } + + [Fact] + public void Props_InvalidArg_Throws() + { + Assert.Throws("value", () => new AIFunctionMetadata("name") { Name = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { Name = "\r\n\t " }); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionReturnParameterMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionReturnParameterMetadataTests.cs new file mode 100644 index 00000000000..bb5bbeec03a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionReturnParameterMetadataTests.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionReturnParameterMetadataTests +{ + [Fact] + public void Constructor_PropsDefaulted() + { + AIFunctionReturnParameterMetadata p = new(); + Assert.Null(p.Description); + Assert.Null(p.ParameterType); + Assert.Null(p.Schema); + } + + [Fact] + public void Constructor_Copy_PropsPropagated() + { + AIFunctionReturnParameterMetadata p1 = new() + { + Description = "description", + ParameterType = typeof(int), + Schema = JsonDocument.Parse("""{"type":"integer"}"""), + }; + + AIFunctionReturnParameterMetadata p2 = new(p1); + Assert.Equal(p1.Description, p2.Description); + Assert.Equal(p1.ParameterType, p2.ParameterType); + Assert.Equal(p1.Schema, p2.Schema); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs new file mode 100644 index 00000000000..df143e8b97e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionTests +{ + [Fact] + public async Task InvokeAsync_UsesDefaultEmptyCollectionForNullArgsAsync() + { + DerivedAIFunction f = new(); + + using CancellationTokenSource cts = new(); + var result1 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!; + + Assert.NotNull(result1.Item1); + Assert.Empty(result1.Item1); + Assert.Equal(cts.Token, result1.Item2); + + var result2 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!; + Assert.Same(result1.Item1, result2.Item1); + } + + [Fact] + public void ToString_ReturnsName() + { + DerivedAIFunction f = new(); + Assert.Equal("name", f.ToString()); + } + + private sealed class DerivedAIFunction : AIFunction + { + public override AIFunctionMetadata Metadata => new("name"); + + protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken) + { + Assert.NotNull(arguments); + return Task.FromResult((arguments, cancellationToken)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj new file mode 100644 index 00000000000..0d4d5fbfa96 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj @@ -0,0 +1,24 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.Abstractions. + + + + $(NoWarn);CA1063;CA1861;CA2201;VSTHRD003 + true + + + + true + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs new file mode 100644 index 00000000000..55f4c486483 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +public sealed class TestChatClient : IChatClient +{ + public IServiceProvider? Services { get; set; } + + public ChatClientMetadata Metadata { get; set; } = new(); + + public Func, ChatOptions?, CancellationToken, Task>? CompleteAsyncCallback { get; set; } + + public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? CompleteStreamingAsyncCallback { get; set; } + + public Func? GetServiceCallback { get; set; } + + public Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => CompleteAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + + public IAsyncEnumerable CompleteStreamingAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + + public TService? GetService(object? key = null) + where TService : class + => (TService?)GetServiceCallback!(typeof(TService), key); + + void IDisposable.Dispose() + { + // No resources need disposing. + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs new file mode 100644 index 00000000000..83680a2be10 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +public sealed class TestEmbeddingGenerator : IEmbeddingGenerator> +{ + public EmbeddingGeneratorMetadata Metadata { get; } = new(); + + public Func, EmbeddingGenerationOptions?, CancellationToken, Task>>>? GenerateAsyncCallback { get; set; } + + public Func? GetServiceCallback { get; set; } + + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + => GenerateAsyncCallback!.Invoke(values, options, cancellationToken); + + public TService? GetService(object? key = null) + where TService : class + => (TService?)GetServiceCallback!(typeof(TService), key); + + void IDisposable.Dispose() + { + // No resources to dispose + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs new file mode 100644 index 00000000000..5a3e966c17b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +[JsonSourceGenerationOptions( + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + UseStringEnumConverter = true)] +[JsonSerializable(typeof(ChatCompletion))] +[JsonSerializable(typeof(StreamingChatCompletionUpdate))] +[JsonSerializable(typeof(ChatOptions))] +[JsonSerializable(typeof(EmbeddingGenerationOptions))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(int[]))] // Used in ChatMessageContentTests +[JsonSerializable(typeof(Embedding))] // Used in EmbeddingTests +[JsonSerializable(typeof(Dictionary))] // Used in Content tests +[JsonSerializable(typeof(Dictionary))] // Used in Content tests +[JsonSerializable(typeof(Dictionary))] // Used in Content tests +[JsonSerializable(typeof(ReadOnlyDictionary))] // Used in Content tests +[JsonSerializable(typeof(DayOfWeek[]))] // Used in Content tests +[JsonSerializable(typeof(Guid))] // Used in Content tests +[JsonSerializable(typeof(decimal))] // Used in Content tests +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientIntegrationTests.cs new file mode 100644 index 00000000000..29aef62fd77 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientIntegrationTests.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading.Tasks; +using Microsoft.TestUtilities; + +namespace Microsoft.Extensions.AI; + +public class AzureAIInferenceChatClientIntegrationTests : ChatClientIntegrationTests +{ + protected override IChatClient? CreateChatClient() => + IntegrationTestHelpers.GetChatCompletionsClient() + ?.AsChatClient(Environment.GetEnvironmentVariable("AZURE_AI_INFERENCE_CHAT_MODEL") ?? "gpt-4o-mini"); + + public override Task CompleteStreamingAsync_UsageDataAvailable() => + throw new SkipTestException("Azure.AI.Inference library doesn't currently surface streaming usage data."); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs new file mode 100644 index 00000000000..fd4bd11a96f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -0,0 +1,536 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Azure; +using Azure.AI.Inference; +using Azure.Core.Pipeline; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class AzureAIInferenceChatClientTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("chatCompletionsClient", () => new AzureAIInferenceChatClient(null!, "model")); + + ChatCompletionsClient client = new(new("http://somewhere"), new AzureKeyCredential("key")); + Assert.Throws("modelId", () => new AzureAIInferenceChatClient(client, " ")); + } + + [Fact] + public void AsChatClient_InvalidArgs_Throws() + { + Assert.Throws("chatCompletionsClient", () => ((ChatCompletionsClient)null!).AsChatClient("model")); + + ChatCompletionsClient client = new(new("http://somewhere"), new AzureKeyCredential("key")); + Assert.Throws("modelId", () => client.AsChatClient(" ")); + } + + [Fact] + public void AsChatClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + ChatCompletionsClient client = new(endpoint, new AzureKeyCredential("key")); + + IChatClient chatClient = client.AsChatClient(model); + Assert.Equal("AzureAIInference", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public void GetService_SuccessfullyReturnsUnderlyingClient() + { + ChatCompletionsClient client = new(new("http://localhost"), new AzureKeyCredential("key")); + IChatClient chatClient = client.AsChatClient("model"); + + Assert.Same(chatClient, chatClient.GetService()); + Assert.Same(chatClient, chatClient.GetService()); + + Assert.Same(client, chatClient.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(chatClient); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(client, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public async Task BasicRequestResponse_NonStreaming() + { + const string Input = """ + {"messages":[{"content":[{"text":"hello","type":"text"}],"role":"user"}],"max_tokens":10,"temperature":0.5,"model":"gpt-4o-mini"} + """; + + const string Output = """ + { + "id": "chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", + "object": "chat.completion", + "created": 1727888631, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 8, + "completion_tokens": 9, + "total_tokens": 17, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("hello", new() + { + MaxOutputTokens = 10, + Temperature = 0.5f, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.CompletionId); + Assert.Equal("Hello! How can I assist you today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(8, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.OutputTokenCount); + Assert.Equal(17, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task BasicRequestResponse_Streaming() + { + const string Input = """ + {"messages":[{"content":[{"text":"hello","type":"text"}],"role":"user"}],"max_tokens":20,"temperature":0.5,"stream":true,"model":"gpt-4o-mini"} + """; + + const string Output = """ + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":8,"completion_tokens":9,"total_tokens":17,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("hello", new() + { + MaxOutputTokens = 20, + Temperature = 0.5f, + })) + { + updates.Add(update); + } + + Assert.Equal("Hello! How can I assist you today?", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_889_370); + Assert.Equal(12, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.Equal(i < 10 ? 1 : 0, updates[i].Contents.Count); + Assert.Equal(i < 10 ? null : ChatFinishReason.Stop, updates[i].FinishReason); + } + } + + [Fact] + public async Task MultipleMessages_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "content": "You are a really nice friend.", + "role": "system" + }, + { + "content": [ + { + "text": "hello!", + "type": "text" + } + ], + "role": "user" + }, + { + "content": "hi, how are you?", + "role": "assistant" + }, + { + "content": [ + { + "text": "i\u0027m good. how are you?", + "type": "text" + } + ], + "role": "user" + } + ], + "temperature": 0.25, + "stop": [ + "great" + ], + "presence_penalty": 0.5, + "frequency_penalty": 0.75, + "model": "gpt-4o-mini", + "seed": 42 + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", + "object": "chat.completion", + "created": 1727894187, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I’m doing well, thank you! What’s on your mind today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 42, + "completion_tokens": 15, + "total_tokens": 57, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List messages = + [ + new(ChatRole.System, "You are a really nice friend."), + new(ChatRole.User, "hello!"), + new(ChatRole.Assistant, "hi, how are you?"), + new(ChatRole.User, "i'm good. how are you?"), + ]; + + var response = await client.CompleteAsync(messages, new() + { + Temperature = 0.25f, + FrequencyPenalty = 0.75f, + PresencePenalty = 0.5f, + StopSequences = ["great"], + AdditionalProperties = new() { ["seed"] = 42L }, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.CompletionId); + Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(42, response.Usage.InputTokenCount); + Assert.Equal(15, response.Usage.OutputTokenCount); + Assert.Equal(57, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task FunctionCallContent_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "content": [ + { + "text": "How old is Alice?", + "type": "text" + } + ], + "role": "user" + } + ], + "model": "gpt-4o-mini", + "tools": [ + { + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "required": ["personName"], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + } + }, + "type": "function" + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADydKhrSKEBWJ8gy0KCIU74rN3Hmk", + "object": "chat.completion", + "created": 1727894702, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_8qbINM045wlmKZt9bVJgwAym", + "type": "function", + "function": { + "name": "GetPersonAge", + "arguments": "{\"personName\":\"Alice\"}" + } + } + ], + "refusal": null + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 61, + "completion_tokens": 16, + "total_tokens": 77, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Null(response.Message.Text); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); + Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(61, response.Usage.InputTokenCount); + Assert.Equal(16, response.Usage.OutputTokenCount); + Assert.Equal(77, response.Usage.TotalTokenCount); + + Assert.Single(response.Choices); + Assert.Single(response.Message.Contents); + FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + } + + [Fact] + public async Task FunctionCallContent_Streaming() + { + const string Input = """ + { + "messages": [ + { + "content": [ + { + "text": "How old is Alice?", + "type": "text" + } + ], + "role": "user" + } + ], + "stream": true, + "model": "gpt-4o-mini", + "tools": [ + { + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "required": ["personName"], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + } + }, + "type": "function" + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_F9ZaqPWo69u0urxAhVt8meDW","type":"function","function":{"name":"GetPersonAge","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"person"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Alice"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":61,"completion_tokens":16,"total_tokens":77,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + })) + { + updates.Add(update); + } + + Assert.Equal("", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_895_263); + Assert.Equal(10, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.Equal(i < 7 ? null : ChatFinishReason.ToolCalls, updates[i].FinishReason); + } + + FunctionCallContent fcc = Assert.IsType(Assert.Single(updates[updates.Count - 1].Contents)); + Assert.Equal("call_F9ZaqPWo69u0urxAhVt8meDW", fcc.CallId); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + } + + private static IChatClient CreateChatClient(HttpClient httpClient, string modelId) => + new ChatCompletionsClient( + new("http://somewhere"), + new AzureKeyCredential("key"), + new ChatCompletionsClientOptions { Transport = new HttpClientTransport(httpClient) }) + .AsChatClient(modelId); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs new file mode 100644 index 00000000000..4c4086e1157 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Azure; +using Azure.AI.Inference; + +namespace Microsoft.Extensions.AI; + +/// Shared utility methods for integration tests. +internal static class IntegrationTestHelpers +{ + /// Gets an to use for testing, or null if the associated tests should be disabled. + public static ChatCompletionsClient? GetChatCompletionsClient() + { + string? apiKey = + Environment.GetEnvironmentVariable("AZURE_AI_INFERENCE_APIKEY") ?? + Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + + if (apiKey is not null) + { + string? endpoint = + Environment.GetEnvironmentVariable("AZURE_AI_INFERENCE_ENDPOINT") ?? + "https://api.openai.com/v1"; + + return new(new Uri(endpoint), new AzureKeyCredential(apiKey)); + } + + return null; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/Microsoft.Extensions.AI.AzureAIInference.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/Microsoft.Extensions.AI.AzureAIInference.Tests.csproj new file mode 100644 index 00000000000..d992413109b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/Microsoft.Extensions.AI.AzureAIInference.Tests.csproj @@ -0,0 +1,22 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.AzureAIInference + + + + true + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs new file mode 100644 index 00000000000..f538d1476b0 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +internal sealed class BinaryEmbedding : Embedding +{ + public BinaryEmbedding(ReadOnlyMemory bits) + { + Bits = bits; + } + + public ReadOnlyMemory Bits { get; } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs new file mode 100644 index 00000000000..c2aaa0d086d --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable SA1402 // File may only contain a single type + +namespace Microsoft.Extensions.AI; + +internal sealed class CallCountingChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient) +{ + private int _callCount; + + public int CallCount => _callCount; + + public override Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _callCount); + return base.CompleteAsync(chatMessages, options, cancellationToken); + } + + public override IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _callCount); + return base.CompleteStreamingAsync(chatMessages, options, cancellationToken); + } +} + +internal static class CallCountingChatClientBuilderExtensions +{ + public static ChatClientBuilder UseCallCounting(this ChatClientBuilder builder) => + builder.Use(innerClient => new CallCountingChatClient(innerClient)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingEmbeddingGenerator.cs new file mode 100644 index 00000000000..2930f94b6db --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingEmbeddingGenerator.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable SA1402 // File may only contain a single type + +namespace Microsoft.Extensions.AI; + +internal sealed class CallCountingEmbeddingGenerator(IEmbeddingGenerator> innerGenerator) + : DelegatingEmbeddingGenerator>(innerGenerator) +{ + private int _callCount; + + public int CallCount => _callCount; + + public override Task>> GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _callCount); + return base.GenerateAsync(values, options, cancellationToken); + } +} + +internal static class CallCountingEmbeddingGeneratorBuilderExtensions +{ + public static EmbeddingGeneratorBuilder> UseCallCounting( + this EmbeddingGeneratorBuilder> builder) => + builder.Use(innerGenerator => new CallCountingEmbeddingGenerator(innerGenerator)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs new file mode 100644 index 00000000000..50257544430 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -0,0 +1,650 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.TestUtilities; +using OpenTelemetry.Trace; +using Xunit; + +#pragma warning disable CA2000 // Dispose objects before losing scope +#pragma warning disable CA2214 // Do not call overridable methods in constructors + +namespace Microsoft.Extensions.AI; + +public abstract class ChatClientIntegrationTests : IDisposable +{ + private readonly IChatClient? _chatClient; + + protected ChatClientIntegrationTests() + { + _chatClient = CreateChatClient(); + } + + public void Dispose() + { + _chatClient?.Dispose(); + GC.SuppressFinalize(this); + } + + protected abstract IChatClient? CreateChatClient(); + + [ConditionalFact] + public virtual async Task CompleteAsync_SingleRequestMessage() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync("What's the biggest animal?"); + + Assert.Contains("whale", response.Message.Text, StringComparison.OrdinalIgnoreCase); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_MultipleRequestMessages() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync( + [ + new(ChatRole.User, "Pick a city, any city"), + new(ChatRole.Assistant, "Seattle"), + new(ChatRole.User, "And another one"), + new(ChatRole.Assistant, "Jakarta"), + new(ChatRole.User, "What continent are they each in?"), + ]); + + Assert.Single(response.Choices); + Assert.Contains("America", response.Message.Text); + Assert.Contains("Asia", response.Message.Text); + } + + [ConditionalFact] + public virtual async Task CompleteStreamingAsync_SingleStreamingResponseChoice() + { + SkipIfNotEnabled(); + + IList chatHistory = + [ + new(ChatRole.User, "Quote, word for word, Neil Armstrong's famous words.") + ]; + + StringBuilder sb = new(); + await foreach (var chunk in _chatClient.CompleteStreamingAsync(chatHistory)) + { + sb.Append(chunk.Text); + } + + string responseText = sb.ToString(); + Assert.Contains("one small step", responseText, StringComparison.OrdinalIgnoreCase); + Assert.Contains("one giant leap", responseText, StringComparison.OrdinalIgnoreCase); + + // The input list is left unaugmented. + Assert.Single(chatHistory); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_UsageDataAvailable() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync("Explain in 10 words how AI works"); + + Assert.Single(response.Choices); + Assert.True(response.Usage?.InputTokenCount > 1); + Assert.True(response.Usage?.OutputTokenCount > 1); + Assert.Equal(response.Usage?.InputTokenCount + response.Usage?.OutputTokenCount, response.Usage?.TotalTokenCount); + } + + [ConditionalFact] + public virtual async Task CompleteStreamingAsync_UsageDataAvailable() + { + SkipIfNotEnabled(); + + var response = _chatClient.CompleteStreamingAsync("Explain in 10 words how AI works"); + + List chunks = []; + await foreach (var chunk in response) + { + chunks.Add(chunk); + } + + Assert.True(chunks.Count > 1); + + UsageContent usage = chunks.SelectMany(c => c.Contents).OfType().Single(); + Assert.True(usage.Details.InputTokenCount > 1); + Assert.True(usage.Details.OutputTokenCount > 1); + Assert.Equal(usage.Details.InputTokenCount + usage.Details.OutputTokenCount, usage.Details.TotalTokenCount); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Parameterless() + { + SkipIfNotEnabled(); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + int secretNumber = 42; + + var response = await chatClient.CompleteAsync("What is the current secret number?", new() + { + Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] + }); + + Assert.Single(response.Choices); + Assert.Contains(secretNumber.ToString(), response.Message.Text); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_WithParameters_NonStreaming() + { + SkipIfNotEnabled(); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + var response = await chatClient.CompleteAsync("What is the result of SecretComputation on 42 and 84?", new() + { + Tools = [AIFunctionFactory.Create((int a, int b) => a * b, "SecretComputation")] + }); + + Assert.Single(response.Choices); + Assert.Contains("3528", response.Message.Text); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_WithParameters_Streaming() + { + SkipIfNotEnabled(); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + var response = chatClient.CompleteStreamingAsync("What is the result of SecretComputation on 42 and 84?", new() + { + Tools = [AIFunctionFactory.Create((int a, int b) => a * b, "SecretComputation")] + }); + + StringBuilder sb = new(); + await foreach (var chunk in response) + { + sb.Append(chunk.Text); + } + + Assert.Contains("3528", sb.ToString()); + } + + protected virtual bool SupportsParallelFunctionCalling => true; + + [ConditionalFact] + public virtual async Task FunctionInvocation_SupportsMultipleParallelRequests() + { + SkipIfNotEnabled(); + if (!SupportsParallelFunctionCalling) + { + throw new SkipTestException("Parallel function calling is not supported by this chat client"); + } + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + // The service/model isn't guaranteed to request two calls to GetPersonAge in the same turn, but it's common that it will. + var response = await chatClient.CompleteAsync("How much older is Elsa than Anna? Return the age difference as a single number.", new() + { + Tools = [AIFunctionFactory.Create((string personName) => + { + return personName switch + { + "Elsa" => 21, + "Anna" => 18, + _ => 30, + }; + }, "GetPersonAge")] + }); + + Assert.True( + Regex.IsMatch(response.Message.Text ?? "", @"\b(3|three)\b", RegexOptions.IgnoreCase), + $"Doesn't contain three: {response.Message.Text}"); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_RequireAny() + { + SkipIfNotEnabled(); + + int callCount = 0; + var tool = AIFunctionFactory.Create(() => + { + callCount++; + return 123; + }, "GetSecretNumber"); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + var response = await chatClient.CompleteAsync("Are birds real?", new() + { + Tools = [tool], + ToolMode = ChatToolMode.RequireAny, + }); + + Assert.Single(response.Choices); + Assert.True(callCount >= 1); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_RequireSpecific() + { + SkipIfNotEnabled(); + + bool shieldsUp = false; + var getSecretNumberTool = AIFunctionFactory.Create(() => 123, "GetSecretNumber"); + var shieldsUpTool = AIFunctionFactory.Create(() => shieldsUp = true, "ShieldsUp"); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + // Even though the user doesn't ask for the shields to be activated, verify that the tool is invoked + var response = await chatClient.CompleteAsync("What's the current secret number?", new() + { + Tools = [getSecretNumberTool, shieldsUpTool], + ToolMode = ChatToolMode.RequireSpecific(shieldsUpTool.Metadata.Name), + }); + + Assert.True(shieldsUp); + } + + [ConditionalFact] + public virtual async Task Caching_OutputVariesWithoutCaching() + { + SkipIfNotEnabled(); + + var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); + var firstResponse = await _chatClient.CompleteAsync([message]); + Assert.Single(firstResponse.Choices); + + var secondResponse = await _chatClient.CompleteAsync([message]); + Assert.NotEqual(firstResponse.Message.Text, secondResponse.Message.Text); + } + + [ConditionalFact] + public virtual async Task Caching_SamePromptResultsInCacheHit_NonStreaming() + { + SkipIfNotEnabled(); + + using var chatClient = new DistributedCachingChatClient( + _chatClient, + new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))); + + var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); + var firstResponse = await chatClient.CompleteAsync([message]); + Assert.Single(firstResponse.Choices); + + // No matter what it said before, we should see identical output due to caching + for (int i = 0; i < 3; i++) + { + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Equal(firstResponse.Message.Text, secondResponse.Message.Text); + } + + // ... but if the conversation differs, we should see different output + message.Text += "!"; + var thirdResponse = await chatClient.CompleteAsync([message]); + Assert.NotEqual(firstResponse.Message.Text, thirdResponse.Message.Text); + } + + [ConditionalFact] + public virtual async Task Caching_SamePromptResultsInCacheHit_Streaming() + { + SkipIfNotEnabled(); + + using var chatClient = new DistributedCachingChatClient( + _chatClient, + new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))); + + var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); + StringBuilder orig = new(); + await foreach (var update in chatClient.CompleteStreamingAsync([message])) + { + orig.Append(update.Text); + } + + // No matter what it said before, we should see identical output due to caching + for (int i = 0; i < 3; i++) + { + StringBuilder second = new(); + await foreach (var update in chatClient.CompleteStreamingAsync([message])) + { + second.Append(update.Text); + } + + Assert.Equal(orig.ToString(), second.ToString()); + } + + // ... but if the conversation differs, we should see different output + message.Text += "!"; + StringBuilder third = new(); + await foreach (var update in chatClient.CompleteStreamingAsync([message])) + { + third.Append(update.Text); + } + + Assert.NotEqual(orig.ToString(), third.ToString()); + } + + [ConditionalFact] + public virtual async Task Caching_BeforeFunctionInvocation_AvoidsExtraCalls() + { + SkipIfNotEnabled(); + + int functionCallCount = 0; + var getTemperature = AIFunctionFactory.Create([Description("Gets the current temperature")] () => + { + functionCallCount++; + return $"{100 + functionCallCount} degrees celsius"; + }, "GetTemperature"); + + // First call executes the function and calls the LLM + using var chatClient = new ChatClientBuilder() + .UseChatOptions(_ => new() { Tools = [getTemperature] }) + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseFunctionInvocation() + .UseCallCounting() + .Use(CreateChatClient()!); + + var llmCallCount = chatClient.GetService(); + var message = new ChatMessage(ChatRole.User, "What is the temperature?"); + var response = await chatClient.CompleteAsync([message]); + Assert.Contains("101", response.Message.Text); + + // First LLM call tells us to call the function, second deals with the result + Assert.Equal(2, llmCallCount!.CallCount); + + // Second call doesn't execute the function or call the LLM, but rather just returns the cached result + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(1, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + } + + [ConditionalFact] + public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchangedAsync() + { + SkipIfNotEnabled(); + + // This means that if the function call produces the same result, we can avoid calling the LLM + // whereas if the function call produces a different result, we do call the LLM + + var functionCallCount = 0; + var getTemperature = AIFunctionFactory.Create([Description("Gets the current temperature")] () => + { + functionCallCount++; + return "58 degrees celsius"; + }, "GetTemperature"); + + // First call executes the function and calls the LLM + using var chatClient = new ChatClientBuilder() + .UseChatOptions(_ => new() { Tools = [getTemperature] }) + .UseFunctionInvocation() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseCallCounting() + .Use(CreateChatClient()!); + + var llmCallCount = chatClient.GetService(); + var message = new ChatMessage(ChatRole.User, "What is the temperature?"); + var response = await chatClient.CompleteAsync([message]); + Assert.Contains("58", response.Message.Text); + + // First LLM call tells us to call the function, second deals with the result + Assert.Equal(1, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + + // Second time, the calls to the LLM don't happen, but the function is called again + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(2, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + } + + [ConditionalFact] + public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedAsync() + { + SkipIfNotEnabled(); + + // This means that if the function call produces the same result, we can avoid calling the LLM + // whereas if the function call produces a different result, we do call the LLM + + var functionCallCount = 0; + var getTemperature = AIFunctionFactory.Create([Description("Gets the current temperature")] () => + { + functionCallCount++; + return $"{80 + functionCallCount} degrees celsius"; + }, "GetTemperature"); + + // First call executes the function and calls the LLM + using var chatClient = new ChatClientBuilder() + .UseChatOptions(_ => new() { Tools = [getTemperature] }) + .UseFunctionInvocation() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseCallCounting() + .Use(CreateChatClient()!); + + var llmCallCount = chatClient.GetService(); + var message = new ChatMessage(ChatRole.User, "What is the temperature?"); + var response = await chatClient.CompleteAsync([message]); + Assert.Contains("81", response.Message.Text); + + // First LLM call tells us to call the function, second deals with the result + Assert.Equal(1, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + + // Second time, the first call to the LLM don't happen, but the function is called again, + // and since its output now differs, we no longer hit the cache so the second LLM call does happen + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Contains("82", secondResponse.Message.Text); + Assert.Equal(2, functionCallCount); + Assert.Equal(3, llmCallCount!.CallCount); + } + + [ConditionalFact] + public virtual async Task Logging_LogsCalls_NonStreaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new LoggingChatClient(CreateChatClient()!, logger); + + await chatClient.CompleteAsync([new(ChatRole.User, "What's the biggest animal?")]); + + Assert.Collection(logger.Entries, + entry => Assert.Contains("What\\u0027s the biggest animal?", entry.Message), + entry => Assert.Contains("whale", entry.Message)); + } + + [ConditionalFact] + public virtual async Task Logging_LogsCalls_Streaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new LoggingChatClient(CreateChatClient()!, logger); + + await foreach (var update in chatClient.CompleteStreamingAsync("What's the biggest animal?")) + { + // Do nothing with the updates + } + + Assert.Contains(logger.Entries, e => e.Message.Contains("What\\u0027s the biggest animal?")); + Assert.Contains(logger.Entries, e => e.Message.Contains("whale")); + } + + [ConditionalFact] + public virtual async Task Logging_LogsFunctionCalls_NonStreaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new FunctionInvokingChatClient( + new LoggingChatClient(CreateChatClient()!, logger)); + + int secretNumber = 42; + await chatClient.CompleteAsync( + "What is the current secret number?", + new ChatOptions { Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] }); + + Assert.Collection(logger.Entries, + entry => Assert.Contains("What is the current secret number?", entry.Message), + entry => Assert.Contains("\"name\":\"GetSecretNumber\"", entry.Message), + entry => Assert.Contains($"\"result\":{secretNumber}", entry.Message), + entry => Assert.Contains(secretNumber.ToString(), entry.Message)); + } + + [ConditionalFact] + public virtual async Task Logging_LogsFunctionCalls_Streaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new FunctionInvokingChatClient( + new LoggingChatClient(CreateChatClient()!, logger)); + + int secretNumber = 42; + await foreach (var update in chatClient.CompleteStreamingAsync( + "What is the current secret number?", + new ChatOptions { Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] })) + { + // Do nothing with the updates + } + + Assert.Contains(logger.Entries, e => e.Message.Contains("What is the current secret number?")); + Assert.Contains(logger.Entries, e => e.Message.Contains("\"name\":\"GetSecretNumber\"")); + Assert.Contains(logger.Entries, e => e.Message.Contains($"\"result\":{secretNumber}")); + } + + [ConditionalFact] + public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics() + { + SkipIfNotEnabled(); + + var sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + var chatClient = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, instance => { instance.EnableSensitiveData = true; }) + .Use(CreateChatClient()!); + + var response = await chatClient.CompleteAsync([new(ChatRole.User, "What's the biggest animal?")]); + + var activity = Assert.Single(activities); + Assert.StartsWith("chat.completions", activity.DisplayName); + Assert.StartsWith("http", (string)activity.GetTagItem("server.address")!); + Assert.Equal(chatClient.Metadata.ProviderUri?.Port, (int)activity.GetTagItem("server.port")!); + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + Assert.NotEqual(0, (int)activity.GetTagItem("gen_ai.response.input_tokens")!); + Assert.NotEqual(0, (int)activity.GetTagItem("gen_ai.response.output_tokens")!); + + Assert.Collection(activity.Events, + evt => + { + Assert.Equal("gen_ai.content.prompt", evt.Name); + Assert.Equal("""[{"role": "user", "content": "What\u0027s the biggest animal?"}]""", evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.prompt").Value); + }, + evt => + { + Assert.Equal("gen_ai.content.completion", evt.Name); + Assert.Contains("whale", (string)evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.completion").Value!); + }); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutput() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + Who is described in the following sentence? + Jimbo Smith is a 35-year-old software developer from Cardiff, Wales. + """); + + Assert.Equal("Jimbo Smith", response.Result.FullName); + Assert.Equal(35, response.Result.AgeInYears); + Assert.Contains("Cardiff", response.Result.HomeTown); + Assert.Equal(JobType.Programmer, response.Result.Job); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutput_WithFunctions() + { + SkipIfNotEnabled(); + + var expectedPerson = new Person + { + FullName = "Jimbo Smith", + AgeInYears = 35, + HomeTown = "Cardiff", + Job = JobType.Programmer, + }; + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + var response = await chatClient.CompleteAsync( + "Who is person with ID 123?", new ChatOptions + { + Tools = [AIFunctionFactory.Create((int personId) => + { + Assert.Equal(123, personId); + return expectedPerson; + }, "GetPersonById")] + }); + + Assert.NotSame(expectedPerson, response.Result); + Assert.Equal(expectedPerson.FullName, response.Result.FullName); + Assert.Equal(expectedPerson.AgeInYears, response.Result.AgeInYears); + Assert.Equal(expectedPerson.HomeTown, response.Result.HomeTown); + Assert.Equal(expectedPerson.Job, response.Result.Job); + } + + private class Person + { +#pragma warning disable S1144, S3459 // Unassigned members should be removed + public string? FullName { get; set; } + public int AgeInYears { get; set; } + public string? HomeTown { get; set; } + public JobType Job { get; set; } +#pragma warning restore S1144, S3459 // Unused private types or members should be removed + } + + private enum JobType + { + Surgeon, + PopStar, + Programmer, + Unknown, + } + + [MemberNotNull(nameof(_chatClient))] + protected void SkipIfNotEnabled() + { + if (_chatClient is null) + { + throw new SkipTestException("Client is not enabled."); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs new file mode 100644 index 00000000000..252427836e8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs @@ -0,0 +1,215 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +#if NET +using System.Numerics.Tensors; +#endif +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.TestUtilities; +using OpenTelemetry.Trace; +using Xunit; + +#pragma warning disable CA2214 // Do not call overridable methods in constructors +#pragma warning disable S3967 // Multidimensional arrays should not be used + +namespace Microsoft.Extensions.AI; + +public abstract class EmbeddingGeneratorIntegrationTests : IDisposable +{ + private readonly IEmbeddingGenerator>? _embeddingGenerator; + + protected EmbeddingGeneratorIntegrationTests() + { + _embeddingGenerator = CreateEmbeddingGenerator(); + } + + public void Dispose() + { + _embeddingGenerator?.Dispose(); + GC.SuppressFinalize(this); + } + + protected abstract IEmbeddingGenerator>? CreateEmbeddingGenerator(); + + [ConditionalFact] + public virtual async Task GenerateEmbedding_CreatesEmbeddingSuccessfully() + { + SkipIfNotEnabled(); + + var embeddings = await _embeddingGenerator.GenerateAsync("Using AI with .NET"); + + Assert.NotNull(embeddings.Usage); + Assert.NotNull(embeddings.Usage.InputTokenCount); + Assert.NotNull(embeddings.Usage.TotalTokenCount); + Assert.Single(embeddings); + Assert.Equal(_embeddingGenerator.Metadata.ModelId, embeddings[0].ModelId); + Assert.NotEmpty(embeddings[0].Vector.ToArray()); + } + + [ConditionalFact] + public virtual async Task GenerateEmbeddings_CreatesEmbeddingsSuccessfully() + { + SkipIfNotEnabled(); + + var embeddings = await _embeddingGenerator.GenerateAsync([ + "Red", + "White", + "Blue", + ]); + + Assert.Equal(3, embeddings.Count); + Assert.NotNull(embeddings.Usage); + Assert.NotNull(embeddings.Usage.InputTokenCount); + Assert.NotNull(embeddings.Usage.TotalTokenCount); + Assert.All(embeddings, embedding => + { + Assert.Equal(_embeddingGenerator.Metadata.ModelId, embedding.ModelId); + Assert.NotEmpty(embedding.Vector.ToArray()); + }); + } + + [ConditionalFact] + public virtual async Task Caching_SameOutputsForSameInput() + { + SkipIfNotEnabled(); + + using var generator = new EmbeddingGeneratorBuilder>() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseCallCounting() + .Use(CreateEmbeddingGenerator()!); + + string input = "Red, White, and Blue"; + var embedding1 = await generator.GenerateAsync(input); + var embedding2 = await generator.GenerateAsync(input); + var embedding3 = await generator.GenerateAsync(input + "... and Green"); + var embedding4 = await generator.GenerateAsync(input); + + var callCounter = generator.GetService(); + Assert.NotNull(callCounter); + + Assert.Equal(2, callCounter.CallCount); + } + + [ConditionalFact] + public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics() + { + SkipIfNotEnabled(); + + string sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + var embeddingGenerator = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry(sourceName) + .Use(CreateEmbeddingGenerator()!); + + _ = await embeddingGenerator.GenerateAsync("Hello, world!"); + + Assert.Single(activities); + var activity = activities.Single(); + Assert.StartsWith("embedding", activity.DisplayName); + Assert.StartsWith("http", (string)activity.GetTagItem("server.address")!); + Assert.Equal(embeddingGenerator.Metadata.ProviderUri?.Port, (int)activity.GetTagItem("server.port")!); + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + Assert.NotEqual(0, (int)activity.GetTagItem("gen_ai.response.input_tokens")!); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } + +#if NET + [ConditionalFact] + public async Task Quantization_Binary_EmbeddingsCompareSuccessfully() + { + SkipIfNotEnabled(); + + using IEmbeddingGenerator generator = + new QuantizationEmbeddingGenerator( + CreateEmbeddingGenerator()!); + + var embeddings = await generator.GenerateAsync(["dog", "cat", "fork", "spoon"]); + Assert.Equal(4, embeddings.Count); + + long[,] distances = new long[embeddings.Count, embeddings.Count]; + for (int i = 0; i < embeddings.Count; i++) + { + for (int j = 0; j < embeddings.Count; j++) + { + distances[i, j] = TensorPrimitives.HammingBitDistance(embeddings[i].Bits.Span, embeddings[j].Bits.Span); + } + } + + for (int i = 0; i < embeddings.Count; i++) + { + Assert.Equal(0, distances[i, i]); + } + + Assert.True(distances[0, 1] < distances[0, 2]); + Assert.True(distances[0, 1] < distances[0, 3]); + Assert.True(distances[0, 1] < distances[1, 2]); + Assert.True(distances[0, 1] < distances[1, 3]); + + Assert.True(distances[2, 3] < distances[0, 2]); + Assert.True(distances[2, 3] < distances[0, 3]); + Assert.True(distances[2, 3] < distances[1, 2]); + Assert.True(distances[2, 3] < distances[1, 3]); + } + + [ConditionalFact] + public async Task Quantization_Half_EmbeddingsCompareSuccessfully() + { + SkipIfNotEnabled(); + + using IEmbeddingGenerator> generator = + new QuantizationEmbeddingGenerator( + CreateEmbeddingGenerator()!); + + var embeddings = await generator.GenerateAsync(["dog", "cat", "fork", "spoon"]); + Assert.Equal(4, embeddings.Count); + + var distances = new Half[embeddings.Count, embeddings.Count]; + for (int i = 0; i < embeddings.Count; i++) + { + for (int j = 0; j < embeddings.Count; j++) + { + distances[i, j] = TensorPrimitives.CosineSimilarity(embeddings[i].Vector.Span, embeddings[j].Vector.Span); + } + } + + for (int i = 0; i < embeddings.Count; i++) + { + Assert.Equal(1.0, (double)distances[i, i], 0.001); + } + + Assert.True(distances[0, 1] > distances[0, 2]); + Assert.True(distances[0, 1] > distances[0, 3]); + Assert.True(distances[0, 1] > distances[1, 2]); + Assert.True(distances[0, 1] > distances[1, 3]); + + Assert.True(distances[2, 3] > distances[0, 2]); + Assert.True(distances[2, 3] > distances[0, 3]); + Assert.True(distances[2, 3] > distances[1, 2]); + Assert.True(distances[2, 3] > distances[1, 3]); + } +#endif + + [MemberNotNull(nameof(_embeddingGenerator))] + protected void SkipIfNotEnabled() + { + if (_embeddingGenerator is null) + { + throw new SkipTestException("Generator is not enabled."); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj new file mode 100644 index 00000000000..e38ccd3268b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj @@ -0,0 +1,37 @@ + + + Microsoft.Extensions.AI + Opt-in integration tests for Microsoft.Extensions.AI. + + + + $(NoWarn);CA1063;CA1861;SA1130;VSTHRD003 + true + + + + true + true + true + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs new file mode 100644 index 00000000000..150c984ff86 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs @@ -0,0 +1,228 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1402 // File may only contain a single type +#pragma warning disable S1144 // Unused private types or members should be removed +#pragma warning disable S3459 // Unassigned members should be removed + +namespace Microsoft.Extensions.AI; + +// This isn't a feature we're planning to ship, but demonstrates how custom clients can +// layer in non-trivial functionality. In this case we're able to upgrade non-function-calling models +// to behaving as if they do support function calling. +// +// In practice: +// - For llama3:8b or mistral:7b, this works fairly reliably, at least when it only needs to +// make a single function call with a constrained set of args. +// - For smaller models like phi3:mini, it works only on a more occasional basis (e.g., if there's +// only one function defined, and it takes no arguments, but is very hit-and-miss beyond that). + +internal sealed class PromptBasedFunctionCallingChatClient(IChatClient innerClient) + : DelegatingChatClient(innerClient) +{ + private const string MessageIntro = "You are an AI model with function calling capabilities. Call one or more functions if they are relevant to the user's query."; + + private static readonly JsonSerializerOptions _jsonOptions = new(JsonSerializerDefaults.Web) + { + WriteIndented = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + // Our goal is to convert tools into a prompt describing them, then to detect tool calls in the + // response and convert those into FunctionCallContent. + if (options?.Tools is { Count: > 0 }) + { + AddOrUpdateToolPrompt(chatMessages, options.Tools); + options = options.Clone(); + options.Tools = null; + + options.StopSequences ??= []; + if (!options.StopSequences.Contains("")) + { + options.StopSequences.Add(""); + } + + // Since the point of this client is to avoid relying on the underlying model having + // native tool call support, we have to replace any "tool" or "toolcall" messages with + // "user" or "assistant" ones. + foreach (var message in chatMessages) + { + for (var itemIndex = 0; itemIndex < message.Contents.Count; itemIndex++) + { + if (message.Contents[itemIndex] is FunctionResultContent frc) + { + var toolCallResultJson = JsonSerializer.Serialize(new ToolCallResult { Id = frc.CallId, Result = frc.Result }, _jsonOptions); + message.Role = ChatRole.User; + message.Contents[itemIndex] = new TextContent( + $"{toolCallResultJson}"); + } + else if (message.Contents[itemIndex] is FunctionCallContent fcc) + { + var toolCallJson = JsonSerializer.Serialize(new { fcc.CallId, fcc.Name, fcc.Arguments }, _jsonOptions); + message.Role = ChatRole.Assistant; + message.Contents[itemIndex] = new TextContent( + $"{toolCallJson}"); + } + } + } + } + + var result = await base.CompleteAsync(chatMessages, options, cancellationToken); + + if (result.Choices.FirstOrDefault()?.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos + && startPos >= 0) + { + var message = result.Choices.First(); + var contentItem = message.Contents.SingleOrDefault(); + content = content.Substring(startPos); + + foreach (var toolCallJson in content.Split([""], StringSplitOptions.None)) + { + var toolCall = toolCallJson.Trim(); + if (toolCall.Length == 0) + { + continue; + } + + var endPos = toolCall.IndexOf(" 0) + { + toolCall = toolCall.Substring(0, endPos); + try + { + var toolCallParsed = JsonSerializer.Deserialize(toolCall, _jsonOptions); + if (!string.IsNullOrEmpty(toolCallParsed?.Name)) + { + if (toolCallParsed!.Arguments is not null) + { + ParseArguments(toolCallParsed.Arguments); + } + + var id = Guid.NewGuid().ToString().Substring(0, 6); + message.Contents.Add(new FunctionCallContent(id, toolCallParsed.Name!, toolCallParsed.Arguments is { } args ? new ReadOnlyDictionary(args) : null)); + + if (contentItem is not null) + { + message.Contents.Remove(contentItem); + } + } + } + catch (JsonException) + { + // Ignore invalid tool calls + } + } + } + } + + return result; + } + + private static void ParseArguments(IDictionary arguments) + { + // This is a simple implementation. A more robust answer is to use other schema information given by + // the AIFunction here, as for example is done in OpenAIChatClient. + foreach (var kvp in arguments.ToArray()) + { + if (kvp.Value is JsonElement jsonElement) + { + arguments[kvp.Key] = jsonElement.ValueKind switch + { + JsonValueKind.String => jsonElement.GetString(), + JsonValueKind.Number => jsonElement.GetDouble(), + JsonValueKind.True => true, + JsonValueKind.False => false, + _ => jsonElement.ToString() + }; + } + } + } + + private static void AddOrUpdateToolPrompt(IList chatMessages, IList tools) + { + var existingToolPrompt = chatMessages.FirstOrDefault(c => c.Text?.StartsWith(MessageIntro, StringComparison.Ordinal) is true); + if (existingToolPrompt is null) + { + existingToolPrompt = new ChatMessage(ChatRole.System, (string?)null); + chatMessages.Insert(0, existingToolPrompt); + } + + var toolDescriptorsJson = JsonSerializer.Serialize(tools.OfType().Select(ToToolDescriptor), _jsonOptions); + existingToolPrompt.Text = $$""" + {{MessageIntro}} + + For each function call, return a JSON object with the function name and arguments within XML tags + as follows: + + {"name": "tool_name", "arguments": { argname1: argval1, argname2: argval2, ... } } + + Note that the contents of MUST be a valid JSON object, with no other text. + + Once you receive the result as a JSON object within XML tags, use it to + answer the user's question without repeating the same tool call. + + Here are the available tools: + {{toolDescriptorsJson}} + """; + } + + private static ToolDescriptor ToToolDescriptor(AIFunction tool) => new() + { + Name = tool.Metadata.Name, + Description = tool.Metadata.Description, + Arguments = tool.Metadata.Parameters.ToDictionary( + p => p.Name, + p => new ToolParameterDescriptor + { + Type = p.ParameterType?.Name, + Description = p.Description, + Enum = p.ParameterType?.IsEnum == true ? Enum.GetNames(p.ParameterType) : null, + Required = p.IsRequired, + }), + }; + + private sealed class ToolDescriptor + { + public string? Name { get; set; } + public string? Description { get; set; } + public IDictionary? Arguments { get; set; } + } + + private sealed class ToolParameterDescriptor + { + public string? Type { get; set; } + public string? Description { get; set; } + public bool? Required { get; set; } + public string[]? Enum { get; set; } + } + + private sealed class ToolCall + { + public string? Id { get; set; } + public string? Name { get; set; } + public IDictionary? Arguments { get; set; } + } + + private sealed class ToolCallResult + { + public string? Id { get; set; } + public object? Result { get; set; } + } +} + +public static class PromptBasedFunctionCallingChatClientExtensions +{ + public static ChatClientBuilder UsePromptBasedFunctionCalling(this ChatClientBuilder builder) + => builder.Use(innerClient => new PromptBasedFunctionCallingChatClient(innerClient)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs new file mode 100644 index 00000000000..90032f16434 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +#if NET +using System.Numerics.Tensors; +#endif +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +internal sealed class QuantizationEmbeddingGenerator : + IEmbeddingGenerator +#if NET + , IEmbeddingGenerator> +#endif +{ + private readonly IEmbeddingGenerator> _floatService; + + public QuantizationEmbeddingGenerator(IEmbeddingGenerator> floatService) + { + _floatService = floatService; + } + + public EmbeddingGeneratorMetadata Metadata => _floatService.Metadata; + + void IDisposable.Dispose() => _floatService.Dispose(); + + public TService? GetService(object? key = null) + where TService : class => + key is null && this is TService ? (TService?)(object)this : + _floatService.GetService(key); + + async Task> IEmbeddingGenerator.GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken) + { + var embeddings = await _floatService.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + return new(from e in embeddings select QuantizeToBinary(e)) + { + Usage = embeddings.Usage, + AdditionalProperties = embeddings.AdditionalProperties, + }; + } + + private static BinaryEmbedding QuantizeToBinary(Embedding embedding) + { + ReadOnlySpan vector = embedding.Vector.Span; + + var result = new byte[(int)Math.Ceiling(vector.Length / 8.0)]; + for (int i = 0; i < vector.Length; i++) + { + if (vector[i] > 0) + { + result[i / 8] |= (byte)(1 << (i % 8)); + } + } + + return new(result) + { + CreatedAt = embedding.CreatedAt, + ModelId = embedding.ModelId, + AdditionalProperties = embedding.AdditionalProperties, + }; + } + +#if NET + async Task>> IEmbeddingGenerator>.GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken) + { + var embeddings = await _floatService.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + return new(from e in embeddings select QuantizeToHalf(e)) + { + Usage = embeddings.Usage, + AdditionalProperties = embeddings.AdditionalProperties, + }; + } + + private static Embedding QuantizeToHalf(Embedding embedding) + { + ReadOnlySpan vector = embedding.Vector.Span; + var result = new Half[vector.Length]; + TensorPrimitives.ConvertToHalf(vector, result); + return new(result) + { + CreatedAt = embedding.CreatedAt, + ModelId = embedding.ModelId, + AdditionalProperties = embedding.AdditionalProperties, + }; + } +#endif +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs new file mode 100644 index 00000000000..0c436f7ccb5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs @@ -0,0 +1,201 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.ML.Tokenizers; +using Microsoft.Shared.Diagnostics; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long +#pragma warning disable SA1402 // File may only contain a single type +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + +namespace Microsoft.Extensions.AI; + +/// Provides an example of a custom for reducing chat message lists. +public class ReducingChatClientTests +{ + private static readonly Tokenizer _gpt4oTokenizer = TiktokenTokenizer.CreateForModel("gpt-4o"); + + [Fact] + public async Task Reduction_LimitsMessagesBasedOnTokenLimit() + { + using var innerClient = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Equal(2, messages.Count); + Assert.Collection(messages, + m => Assert.StartsWith("Golden retrievers are quite active", m.Text, StringComparison.Ordinal), + m => Assert.StartsWith("Are they good with kids?", m.Text, StringComparison.Ordinal)); + return Task.FromResult(new ChatCompletion([])); + } + }; + + using var client = new ChatClientBuilder() + .UseChatReducer(new TokenCountingChatReducer(_gpt4oTokenizer, 40)) + .Use(innerClient); + + List messages = + [ + new ChatMessage(ChatRole.User, "Hi there! Can you tell me about golden retrievers?"), + new ChatMessage(ChatRole.Assistant, "Of course! Golden retrievers are known for their friendly and tolerant attitudes. They're great family pets and are very intelligent and easy to train."), + new ChatMessage(ChatRole.User, "What kind of exercise do they need?"), + new ChatMessage(ChatRole.Assistant, "Golden retrievers are quite active and need regular exercise. Daily walks, playtime, and activities like fetching or swimming are great for them."), + new ChatMessage(ChatRole.User, "Are they good with kids?"), + ]; + + await client.CompleteAsync(messages); + + Assert.Equal(5, messages.Count); + } +} + +/// Provides an example of a chat client for reducing the size of a message list. +public sealed class ReducingChatClient : DelegatingChatClient +{ + private readonly IChatReducer _reducer; + private readonly bool _inPlace; + + /// Initializes a new instance of the class. + /// The inner client. + /// The reducer to be used by this instance. + /// + /// true if the should perform any modifications directly on the supplied list of messages; + /// false if it should instead create a new list when reduction is necessary. + /// + public ReducingChatClient(IChatClient innerClient, IChatReducer reducer, bool inPlace = false) + : base(innerClient) + { + _reducer = Throw.IfNull(reducer); + _inPlace = inPlace; + } + + /// + public override async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + + return await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + + await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } + + /// Runs the reducer and gets the chat message list to forward to the inner client. + private async Task> GetChatMessagesToPropagate(IList chatMessages, CancellationToken cancellationToken) => + await _reducer.ReduceAsync(chatMessages, _inPlace, cancellationToken).ConfigureAwait(false) ?? + chatMessages; +} + +/// Represents a reducer capable of shrinking the size of a list of chat messages. +public interface IChatReducer +{ + /// Reduces the size of a list of chat messages. + /// The messages. + /// true if the reducer should modify the provided list; false if a new list should be returned. + /// The to monitor for cancellation requests. The default is . + /// The new list of messages, or null if no reduction need be performed or was true. + Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken); +} + +/// Provides extensions for configuring instances. +public static class ReducingChatClientExtensions +{ + public static ChatClientBuilder UseChatReducer(this ChatClientBuilder builder, IChatReducer reducer, bool inPlace = false) + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(reducer); + + return builder.Use(innerClient => new ReducingChatClient(innerClient, reducer, inPlace)); + } +} + +/// An that culls the oldest messages once a certain token threshold is reached. +public sealed class TokenCountingChatReducer : IChatReducer +{ + private readonly Tokenizer _tokenizer; + private readonly int _tokenLimit; + + public TokenCountingChatReducer(Tokenizer tokenizer, int tokenLimit) + { + _tokenizer = Throw.IfNull(tokenizer); + _tokenLimit = Throw.IfLessThan(tokenLimit, 1); + } + + public async Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken) + { + _ = Throw.IfNull(chatMessages); + + if (chatMessages.Count > 1) + { + int totalCount = CountTokens(chatMessages[chatMessages.Count - 1]); + + if (inPlace) + { + for (int i = chatMessages.Count - 2; i >= 0; i--) + { + totalCount += CountTokens(chatMessages[i]); + if (totalCount > _tokenLimit) + { + if (chatMessages is List list) + { + list.RemoveRange(0, i + 1); + } + else + { + for (int j = i; j >= 0; j--) + { + chatMessages.RemoveAt(j); + } + } + + break; + } + } + } + else + { + for (int i = chatMessages.Count - 2; i >= 0; i--) + { + totalCount += CountTokens(chatMessages[i]); + if (totalCount > _tokenLimit) + { + return chatMessages.Skip(i + 1).ToList(); + } + } + } + } + + return null; + } + + private int CountTokens(ChatMessage message) + { + int sum = 0; + foreach (AIContent content in message.Contents) + { + if ((content as TextContent)?.Text is string text) + { + sum += _tokenizer.CountTokens(text); + } + } + + return sum; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/VerbatimHttpHandler.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/VerbatimHttpHandler.cs new file mode 100644 index 00000000000..14ba68feb7a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/VerbatimHttpHandler.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Http; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +/// +/// An that checks the request body against an expected one +/// and sends back an expected response. +/// +public sealed class VerbatimHttpHandler(string expectedInput, string sentOutput) : HttpMessageHandler +{ + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + Assert.NotNull(request.Content); + + string? input = await request.Content +#if NET + .ReadAsStringAsync(cancellationToken).ConfigureAwait(false); +#else + .ReadAsStringAsync().ConfigureAwait(false); +#endif + + Assert.NotNull(input); + Assert.Equal(RemoveWhiteSpace(expectedInput), RemoveWhiteSpace(input)); + + return new() { Content = new StringContent(sentOutput) }; + } + + public static string? RemoveWhiteSpace(string? text) => + text is null ? null : + Regex.Replace(text, @"\s*", string.Empty); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/IntegrationTestHelpers.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/IntegrationTestHelpers.cs new file mode 100644 index 00000000000..d25d750ce37 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/IntegrationTestHelpers.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +#pragma warning disable S125 // Sections of code should not be commented out + +/// Shared utility methods for integration tests. +internal static class IntegrationTestHelpers +{ + /// Gets a to use for testing, or null if the associated tests should be disabled. + public static Uri? GetOllamaUri() + { + // return new Uri("http://localhost:11434"); + return null; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/Microsoft.Extensions.AI.Ollama.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/Microsoft.Extensions.AI.Ollama.Tests.csproj new file mode 100644 index 00000000000..5db789e3b6b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/Microsoft.Extensions.AI.Ollama.Tests.csproj @@ -0,0 +1,22 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.Ollama + + + + true + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs new file mode 100644 index 00000000000..891378c0e86 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.TestUtilities; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests +{ + protected override IChatClient? CreateChatClient() => + IntegrationTestHelpers.GetOllamaUri() is Uri endpoint ? + new OllamaChatClient(endpoint, "llama3.1") : + null; + + public override Task FunctionInvocation_AutomaticallyInvokeFunction_WithParameters_Streaming() => + throw new SkipTestException("Ollama does not currently support function invocation with streaming."); + + public override Task Logging_LogsFunctionCalls_Streaming() => + throw new SkipTestException("Ollama does not currently support function invocation with streaming."); + + public override Task FunctionInvocation_RequireAny() => + throw new SkipTestException("Ollama does not currently support requiring function invocation."); + + public override Task FunctionInvocation_RequireSpecific() => + throw new SkipTestException("Ollama does not currently support requiring function invocation."); + + [ConditionalFact] + public async Task PromptBasedFunctionCalling_NoArgs() + { + SkipIfNotEnabled(); + + using var chatClient = new ChatClientBuilder() + .UseFunctionInvocation() + .UsePromptBasedFunctionCalling() + .Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient)) + .Use(CreateChatClient()!); + + var secretNumber = 42; + var response = await chatClient.CompleteAsync("What is the current secret number? Answer with digits only.", new ChatOptions + { + ModelId = "llama3:8b", + Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")], + Temperature = 0, + AdditionalProperties = new() { ["seed"] = 0L }, + }); + + Assert.Single(response.Choices); + Assert.Contains(secretNumber.ToString(), response.Message.Text); + } + + [ConditionalFact] + public async Task PromptBasedFunctionCalling_WithArgs() + { + SkipIfNotEnabled(); + + using var chatClient = new ChatClientBuilder() + .UseFunctionInvocation() + .UsePromptBasedFunctionCalling() + .Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient)) + .Use(CreateChatClient()!); + + var stockPriceTool = AIFunctionFactory.Create([Description("Returns the stock price for a given ticker symbol")] ( + [Description("The ticker symbol")] string symbol, + [Description("The currency code such as USD or JPY")] string currency) => + { + Assert.Equal("MSFT", symbol); + Assert.Equal("GBP", currency); + return 999; + }, "GetStockPrice"); + + var didCallIrrelevantTool = false; + var irrelevantTool = AIFunctionFactory.Create(() => { didCallIrrelevantTool = true; return 123; }, "GetSecretNumber"); + + var response = await chatClient.CompleteAsync("What's the stock price for Microsoft in British pounds?", new ChatOptions + { + Tools = [stockPriceTool, irrelevantTool], + Temperature = 0, + AdditionalProperties = new() { ["seed"] = 0L }, + }); + + Assert.Single(response.Choices); + Assert.Contains("999", response.Message.Text); + Assert.False(didCallIrrelevantTool); + } + + private sealed class AssertNoToolsDefinedChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient) + { + public override Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Assert.Null(options?.Tools); + return base.CompleteAsync(chatMessages, options, cancellationToken); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs new file mode 100644 index 00000000000..b09947337ed --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -0,0 +1,464 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OllamaChatClientTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("endpoint", () => new OllamaChatClient(null!)); + Assert.Throws("modelId", () => new OllamaChatClient(new("http://localhost"), " ")); + } + + [Fact] + public void GetService_SuccessfullyReturnsUnderlyingClient() + { + using OllamaChatClient client = new(new("http://localhost")); + + Assert.Same(client, client.GetService()); + Assert.Same(client, client.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(client); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(client, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public void AsChatClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + using IChatClient chatClient = new OllamaChatClient(endpoint, model); + Assert.Equal("ollama", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public async Task BasicRequestResponse_NonStreaming() + { + const string Input = """ + { + "model":"llama3.1", + "messages":[{"role":"user","content":"hello"}], + "stream":false, + "options":{"num_predict":10,"temperature":0.5} + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T15:46:10.5248793Z", + "message": { + "role": "assistant", + "content": "Hello! How are you today? Is there something" + }, + "done_reason": "length", + "done": true, + "total_duration": 22186844400, + "load_duration": 17947219100, + "prompt_eval_count": 11, + "prompt_eval_duration": 1953805000, + "eval_count": 10, + "eval_duration": 2277274000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using OllamaChatClient client = new(new("http://localhost:11434"), "llama3.1", httpClient); + var response = await client.CompleteAsync("hello", new() + { + MaxOutputTokens = 10, + Temperature = 0.5f, + }); + Assert.NotNull(response); + + Assert.Equal("Hello! How are you today? Is there something", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T15:46:10.5248793Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Length, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(11, response.Usage.InputTokenCount); + Assert.Equal(10, response.Usage.OutputTokenCount); + Assert.Equal(21, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task BasicRequestResponse_Streaming() + { + const string Input = """ + { + "model":"llama3.1", + "messages":[{"role":"user","content":"hello"}], + "stream":true, + "options":{"num_predict":20,"temperature":0.5} + } + """; + + const string Output = """ + {"model":"llama3.1","created_at":"2024-10-01T16:15:20.4965315Z","message":{"role":"assistant","content":"Hello"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:20.763058Z","message":{"role":"assistant","content":"!"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:20.9751134Z","message":{"role":"assistant","content":" How"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.1788125Z","message":{"role":"assistant","content":" are"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.3883171Z","message":{"role":"assistant","content":" you"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.5912498Z","message":{"role":"assistant","content":" today"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.7968039Z","message":{"role":"assistant","content":"?"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.0034152Z","message":{"role":"assistant","content":" Is"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.1931196Z","message":{"role":"assistant","content":" there"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.3827484Z","message":{"role":"assistant","content":" something"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.5659027Z","message":{"role":"assistant","content":" I"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.7488871Z","message":{"role":"assistant","content":" can"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.9339881Z","message":{"role":"assistant","content":" help"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.1201564Z","message":{"role":"assistant","content":" you"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.303447Z","message":{"role":"assistant","content":" with"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.4964909Z","message":{"role":"assistant","content":" or"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.6837816Z","message":{"role":"assistant","content":" would"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.8723142Z","message":{"role":"assistant","content":" you"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:24.064613Z","message":{"role":"assistant","content":" like"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:24.2504498Z","message":{"role":"assistant","content":" to"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:24.2514508Z","message":{"role":"assistant","content":""},"done_reason":"length", "done":true,"total_duration":11912402900,"load_duration":6824559200,"prompt_eval_count":11,"prompt_eval_duration":1329601000,"eval_count":20,"eval_duration":3754262000} + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), "llama3.1", httpClient); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("hello", new() + { + MaxOutputTokens = 20, + Temperature = 0.5f, + })) + { + updates.Add(update); + } + + Assert.Equal(21, updates.Count); + + DateTimeOffset[] createdAts = Regex.Matches(Output, @"2024.*?Z").Cast().Select(m => DateTimeOffset.Parse(m.Value)).ToArray(); + + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal(i < updates.Count - 1 ? 1 : 2, updates[i].Contents.Count); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.All(updates[i].Contents, u => Assert.Equal("llama3.1", u.ModelId)); + Assert.Equal(createdAts[i], updates[i].CreatedAt); + Assert.Equal(i < updates.Count - 1 ? null : ChatFinishReason.Length, updates[i].FinishReason); + } + + Assert.Equal("Hello! How are you today? Is there something I can help you with or would you like to", string.Concat(updates.Select(u => u.Text))); + Assert.Equal(2, updates[updates.Count - 1].Contents.Count); + Assert.IsType(updates[updates.Count - 1].Contents[0]); + UsageContent usage = Assert.IsType(updates[updates.Count - 1].Contents[1]); + Assert.Equal(11, usage.Details.InputTokenCount); + Assert.Equal(20, usage.Details.OutputTokenCount); + Assert.Equal(31, usage.Details.TotalTokenCount); + } + + [Fact] + public async Task MultipleMessages_NonStreaming() + { + const string Input = """ + { + "model": "llama3.1", + "messages": [ + { + "role": "user", + "content": "hello!" + }, + { + "role": "assistant", + "content": "hi, how are you?" + }, + { + "role": "user", + "content": "i\u0027m good. how are you?" + } + ], + "stream": false, + "options": { + "frequency_penalty": 0.75, + "presence_penalty": 0.5, + "seed": 42, + "stop": ["great"], + "temperature": 0.25 + } + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T17:18:46.308987Z", + "message": { + "role": "assistant", + "content": "I'm just a computer program, so I don't have feelings or emotions like humans do, but I'm functioning properly and ready to help with any questions or tasks you may have! How about we chat about something in particular or just shoot the breeze? Your choice!" + }, + "done_reason": "stop", + "done": true, + "total_duration": 23229369000, + "load_duration": 7724086300, + "prompt_eval_count": 36, + "prompt_eval_duration": 4245660000, + "eval_count": 55, + "eval_duration": 11256470000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), httpClient: httpClient); + + List messages = + [ + new(ChatRole.User, "hello!"), + new(ChatRole.Assistant, "hi, how are you?"), + new(ChatRole.User, "i'm good. how are you?"), + ]; + + var response = await client.CompleteAsync(messages, new() + { + ModelId = "llama3.1", + Temperature = 0.25f, + FrequencyPenalty = 0.75f, + PresencePenalty = 0.5f, + StopSequences = ["great"], + AdditionalProperties = new() { ["seed"] = 42 }, + }); + Assert.NotNull(response); + + Assert.Equal( + VerbatimHttpHandler.RemoveWhiteSpace(""" + I'm just a computer program, so I don't have feelings or emotions like humans do, + but I'm functioning properly and ready to help with any questions or tasks you may have! + How about we chat about something in particular or just shoot the breeze ? Your choice! + """), + VerbatimHttpHandler.RemoveWhiteSpace(response.Message.Text)); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T17:18:46.308987Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(36, response.Usage.InputTokenCount); + Assert.Equal(55, response.Usage.OutputTokenCount); + Assert.Equal(91, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task FunctionCallContent_NonStreaming() + { + const string Input = """ + { + "model": "llama3.1", + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + } + ], + "stream": false, + "tools": [ + { + "type": "function", + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + }, + "required": ["personName"] + } + } + } + ] + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T18:48:30.2669578Z", + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "GetPersonAge", + "arguments": { + "personName": "Alice" + } + } + } + ] + }, + "done_reason": "stop", + "done": true, + "total_duration": 27351311300, + "load_duration": 8041538400, + "prompt_eval_count": 170, + "prompt_eval_duration": 16078776000, + "eval_count": 19, + "eval_duration": 3227962000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler) { Timeout = Timeout.InfiniteTimeSpan }; + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), "llama3.1", httpClient) + { + ToolCallJsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + var response = await client.CompleteAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Null(response.Message.Text); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T18:48:30.2669578Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(170, response.Usage.InputTokenCount); + Assert.Equal(19, response.Usage.OutputTokenCount); + Assert.Equal(189, response.Usage.TotalTokenCount); + + Assert.Single(response.Choices); + Assert.Single(response.Message.Contents); + FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + } + + [Fact] + public async Task FunctionResultContent_NonStreaming() + { + const string Input = """ + { + "model": "llama3.1", + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + }, + { + "role": "assistant", + "content": "{\u0022call_id\u0022:\u0022abcd1234\u0022,\u0022name\u0022:\u0022GetPersonAge\u0022,\u0022arguments\u0022:{\u0022personName\u0022:\u0022Alice\u0022}}" + }, + { + "role": "tool", + "content": "{\u0022call_id\u0022:\u0022abcd1234\u0022,\u0022result\u0022:42}" + } + ], + "stream": false, + "tools": [ + { + "type": "function", + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + }, + "required": ["personName"] + } + } + } + ] + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T20:57:20.157266Z", + "message": { + "role": "assistant", + "content": "Alice is 42 years old." + }, + "done_reason": "stop", + "done": true, + "total_duration": 20320666000, + "load_duration": 8159642600, + "prompt_eval_count": 106, + "prompt_eval_duration": 10846727000, + "eval_count": 8, + "eval_duration": 1307842000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler) { Timeout = Timeout.InfiniteTimeSpan }; + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), "llama3.1", httpClient) + { + ToolCallJsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + var response = await client.CompleteAsync( + [ + new(ChatRole.User, "How old is Alice?"), + new(ChatRole.Assistant, [new FunctionCallContent("abcd1234", "GetPersonAge", new Dictionary { ["personName"] = "Alice" })]), + new(ChatRole.Tool, [new FunctionResultContent("abcd1234", "GetPersonAge", 42)]), + ], + new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Equal("Alice is 42 years old.", response.Message.Text); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T20:57:20.157266Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(106, response.Usage.InputTokenCount); + Assert.Equal(8, response.Usage.OutputTokenCount); + Assert.Equal(114, response.Usage.TotalTokenCount); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorIntegrationTests.cs new file mode 100644 index 00000000000..4333cbde636 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorIntegrationTests.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +public class OllamaEmbeddingGeneratorIntegrationTests : EmbeddingGeneratorIntegrationTests +{ + protected override IEmbeddingGenerator>? CreateEmbeddingGenerator() => + IntegrationTestHelpers.GetOllamaUri() is Uri endpoint ? + new OllamaEmbeddingGenerator(endpoint, "all-minilm") : + null; +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..205398c9a1c --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OllamaEmbeddingGeneratorTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("endpoint", () => new OllamaEmbeddingGenerator(null!)); + Assert.Throws("modelId", () => new OllamaEmbeddingGenerator(new("http://localhost"), " ")); + } + + [Fact] + public void GetService_SuccessfullyReturnsUnderlyingClient() + { + using OllamaEmbeddingGenerator generator = new(new("http://localhost")); + + Assert.Same(generator, generator.GetService()); + Assert.Same(generator, generator.GetService>>()); + + using IEmbeddingGenerator> pipeline = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(generator); + + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + + Assert.Same(generator, pipeline.GetService()); + Assert.IsType>>(pipeline.GetService>>()); + } + + [Fact] + public void AsEmbeddingGenerator_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + using IEmbeddingGenerator> chatClient = new OllamaEmbeddingGenerator(endpoint, model); + Assert.Equal("ollama", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public async Task GetEmbeddingsAsync_ExpectedRequestResponse() + { + const string Input = """ + {"model":"all-minilm","input":["hello, world!","red, white, blue"]} + """; + + const string Output = """ + { + "model":"all-minilm", + "embeddings":[ + [-0.038159743,0.032830726,-0.005602915,0.014363416,-0.04031945,-0.11662117,0.031710647,0.0019634133,-0.042558126,0.02925818,0.04254404,0.032178584,0.029820565,0.010947956,-0.05383333,-0.05031401,-0.023460664,0.010746779,-0.13776828,0.003972192,0.029283607,0.06673441,-0.015434976,0.048401773,-0.088160664,-0.012700827,0.04134059,0.0408592,-0.050058633,-0.058048956,0.048720006,0.068883754,0.0588242,0.008813041,-0.016036017,0.08514798,-0.07813561,-0.07740018,0.020856613,0.016228318,0.032506905,-0.053466275,-0.06220645,-0.024293836,0.0073994277,0.02410873,0.006477103,0.051144805,0.072868116,0.03460658,-0.0547553,-0.05937917,-0.007205277,0.020145971,0.035794333,0.005588114,0.010732389,-0.052755248,0.01006711,-0.008716047,-0.062840104,0.038445882,-0.013913384,0.07341423,0.09004691,-0.07995187,-0.016410379,0.044806693,-0.06886798,-0.03302609,-0.015488586,0.0112944925,0.03645402,0.06637969,-0.054364193,0.008732196,0.012049053,-0.038111813,0.006928739,0.05113517,0.07739711,-0.12295967,0.016389083,0.049567502,0.03162499,-0.039604694,0.0016613991,0.009564599,-0.03268798,-0.033994347,-0.13328508,0.0072719813,-0.010261588,0.038570367,-0.093384996,-0.041716397,0.069951184,-0.02632818,-0.149702,0.13445856,0.037486482,0.052814852,0.045044158,0.018727085,0.05445453,0.01727433,-0.032474063,0.046129994,-0.046679277,-0.03058037,-0.0181755,-0.048695795,0.033057086,-0.0038555008,0.050006237,-0.05828653,-0.010029618,0.01062073,-0.040105496,-0.0015263702,0.060846698,-0.04557025,0.049251337,0.026121102,0.019804202,-0.0016694543,0.059516467,-6.525171e-33,0.06351319,0.0030810465,0.028928237,0.17336167,0.0029677018,0.027755935,-0.09513812,-0.031182382,0.026697554,-0.0107956175,0.023849761,0.02378595,-0.03121345,0.049473017,-0.02506533,0.101713106,-0.079133175,-0.0032418896,0.04290832,0.094838716,-0.06652884,0.0062877694,0.02221229,0.0700068,-0.007469806,-0.0017550732,0.027011596,-0.075321496,0.114022695,0.0085597,-0.023766534,-0.04693697,0.014437173,0.01987886,-0.0046902793,0.0013660098,-0.034307938,-0.054156985,-0.09417741,-0.028919358,-0.018871028,0.04574328,0.047602862,-0.0031305805,-0.033291575,-0.0135114025,0.051019657,0.031115327,0.015239397,0.05413997,-0.085031144,0.013366392,-0.04757861,0.07102588,-0.013105953,-0.0023799809,0.050322797,-0.041649505,-0.014187793,0.0324716,0.005401626,0.091307014,0.0044665188,-0.018263677,-0.015284639,-0.04634121,0.038754962,0.014709013,0.052040145,0.0017918312,-0.014979437,0.027103048,0.03117813,0.023749126,-0.004567645,0.03617759,0.06680814,-0.001835277,0.021281,-0.057563916,0.019137124,0.031450257,-0.018432263,-0.040860977,0.10391725,0.011970765,-0.014854915,-0.10521159,-0.012288272,-0.00041675335,-0.09510029,0.058300544,0.042590536,-0.025064372,-0.09454636,4.0064686e-33,0.13224861,0.0053342036,-0.033114634,-0.09096768,-0.031561732,-0.03395822,-0.07202013,0.12591493,-0.08332582,0.052816514,0.001065021,0.022002738,0.1040207,0.013038866,0.04092958,0.018689224,0.1142518,0.024801003,0.014596161,0.006195551,-0.011214642,-0.035760444,-0.037979998,0.011274433,-0.051305123,0.007884909,0.06734877,0.0033462204,-0.09284879,0.037033774,-0.022331867,0.039951596,-0.030730229,-0.011403805,-0.014458028,0.024968812,-0.097553216,-0.03536226,-0.037567392,-0.010149212,-0.06387594,0.025570663,0.02060328,0.037549157,-0.104355134,-0.02837097,-0.052078977,0.0128349,-0.05123587,-0.029060647,-0.09632806,-0.042301137,0.067175224,-0.030890828,-0.010358077,0.027408795,-0.028092034,0.010337195,0.04303845,0.022324203,0.00797792,0.056084383,0.040727936,0.092925824,0.01653155,-0.053750493,0.00046004262,0.050728552,0.04253214,-0.029197674,0.00926312,-0.010662153,-0.037244495,0.002277273,-0.030296732,0.07459592,0.002572513,-0.017561244,0.0028881067,0.03841156,0.007247727,0.045637112,0.039992437,0.014227117,-0.014297474,0.05854321,0.03632371,0.05527864,-0.02007574,-0.08043163,-0.030238612,-0.014929122,0.022335418,0.011954643,-0.06906099,-1.8807288e-8,-0.07850291,0.046684187,-0.023935271,0.063510746,0.024001691,0.0014455577,-0.09078209,-0.066868275,-0.0801402,0.005480386,0.053663295,0.10483363,-0.066864185,0.015531167,0.06711155,0.07081655,-0.031996343,0.020819444,-0.021926524,-0.0073062326,-0.010652819,0.0041180425,0.033138428,-0.0789938,0.03876969,-0.075220205,-0.015715994,0.0059789424,0.005140016,-0.06150612,0.041992374,0.09544083,-0.043187104,0.014401576,-0.10615426,-0.027936764,0.011047429,0.069572434,0.06690283,-0.074798405,-0.07852024,0.04276141,-0.034642085,-0.106051244,-0.03581038,0.051521253,0.06865896,-0.04999753,0.0154549,-0.06452052,-0.07598782,0.02603005,0.074413665,-0.012398757,0.13330704,0.07475513,0.051348723,0.02098748,-0.02679416,0.08896129,0.039944872,-0.041040305,0.031930625,0.018114654], + [0.007228383,-0.021804843,-0.07494023,-0.021707121,-0.021184582,0.09326986,0.10764054,-0.01918113,0.007439991,0.01367952,-0.034187328,-0.044076536,0.016042138,0.007507193,-0.016432272,0.025345335,0.010598066,-0.03832474,-0.14418823,-0.033625234,0.013156937,-0.0048872638,-0.08534306,-0.00003228713,-0.08900276,-0.00008128615,0.010332802,0.053303026,-0.050233904,-0.0879366,-0.064243905,-0.017168961,0.1284308,-0.015268303,-0.049664143,-0.07491954,0.021887481,0.015997978,-0.07967111,0.08744341,-0.039261423,-0.09904984,0.02936398,0.042995434,0.057036504,0.09063012,0.0000012311281,0.06120768,-0.050825767,-0.014443322,0.02879051,-0.002343813,-0.10176559,0.104563184,0.031316753,0.08251861,-0.041213628,-0.0217945,0.0649965,-0.011131547,0.018417398,-0.014460508,-0.05108664,0.11330918,0.01863208,0.006442521,-0.039408617,-0.03609412,-0.009156692,-0.0031261789,-0.010928502,-0.021108521,0.037411734,0.012443921,0.018142054,-0.0362644,0.058286663,-0.02733258,-0.052172586,-0.08320095,-0.07089281,-0.0970049,-0.048587535,0.055343032,0.048351917,0.06892102,-0.039993215,0.06344781,-0.084417015,0.003692423,-0.059397053,0.08186814,0.0029228176,-0.010551637,-0.058019258,0.092128515,0.06862907,-0.06558893,0.021121018,0.079212844,0.09616225,0.0045106052,0.039712362,-0.053576704,0.035097837,-0.04251009,-0.013761404,0.011582285,0.02387105,0.009042205,0.054141942,-0.051263757,-0.07984356,-0.020198742,-0.051623948,-0.0013434993,-0.05825417,-0.0026240738,0.0050159167,-0.06320204,0.07872169,-0.04051374,0.04671058,-0.05804034,-0.07103668,-0.07507343,0.015222599,-3.0948323e-33,0.0076309564,-0.06283016,0.024291662,0.12532257,0.013917241,0.04869009,-0.037988827,-0.035241846,-0.041410565,-0.033772282,0.018835608,0.081035286,-0.049912665,0.044602085,0.030495265,-0.009206943,0.027668765,0.011651487,-0.10254086,0.054472663,-0.06514106,0.12192646,0.048823033,-0.015688669,0.010323047,-0.02821445,-0.030832449,-0.035029083,-0.010604268,0.0014445938,0.08670387,0.01997448,0.0101131955,0.036524937,-0.033489946,-0.026745271,-0.04709222,0.015197909,0.018787097,-0.009976326,-0.0016434817,-0.024719588,-0.09179337,0.09343157,0.029579962,-0.015174558,0.071250066,0.010549244,0.010716396,0.05435638,-0.06391847,-0.031383075,0.007916095,0.012391228,-0.012053197,-0.017409964,0.013742709,0.0594159,-0.033767693,0.04505938,-0.0017214329,0.12797962,0.03223919,-0.054756388,0.025249248,-0.02273578,-0.04701282,-0.018718086,0.009820931,-0.06267794,-0.012644738,0.0068301614,0.093209736,-0.027372226,-0.09436381,0.003861504,0.054960024,-0.058553983,-0.042971537,-0.008994571,-0.08225824,-0.013560626,-0.01880568,0.0995795,-0.040887516,-0.0036491079,-0.010253542,-0.031025425,-0.006957114,-0.038943008,-0.090270124,-0.031345647,0.029613726,-0.099465184,-0.07469079,7.844707e-34,0.024241973,0.03597121,-0.049776066,0.05084303,0.006059542,-0.020719761,0.019962702,0.092246406,0.069408394,0.062306542,0.013837189,0.054749023,0.05090263,0.04100415,-0.02573441,0.09535842,0.036858294,0.059478357,0.0070162765,0.038462427,-0.053635903,0.05912332,-0.037887845,-0.0012995935,-0.068758026,0.0671618,0.029407106,-0.061569903,-0.07481879,-0.01849014,0.014240046,-0.08064838,0.028351007,0.08456427,0.016858438,0.02053254,0.06171099,-0.028964644,-0.047633287,0.08802184,0.0017116248,0.019451816,0.03419083,0.07152118,-0.027244413,-0.04888475,-0.10314279,0.07628554,-0.045991484,-0.023299307,-0.021448445,0.04111079,-0.036342163,-0.010670482,0.01950527,-0.0648448,-0.033299454,0.05782628,0.030278979,0.079154804,-0.03679649,0.031728156,-0.034912236,0.08817754,0.059208114,-0.02319613,-0.027045371,-0.018559752,-0.051946763,-0.010635224,0.048839167,-0.043925915,-0.028300019,-0.0039419765,0.044211324,-0.067469835,-0.027534118,0.005051618,-0.034172326,0.080007285,-0.01931061,-0.005759926,0.08765162,0.08372951,-0.093784876,0.011837292,0.019019455,0.047941882,0.05504541,-0.12475821,0.012822803,0.12833545,0.08005919,0.019278418,-0.025834465,-1.9763878e-8,0.05211108,0.024891146,-0.0015623684,0.0040500895,0.015101377,-0.0031462535,0.014759316,-0.041329216,-0.029255627,0.048599463,0.062482737,0.018376771,-0.066601776,0.014752581,0.07968402,-0.015090815,-0.12100162,-0.0014005995,0.0134423375,-0.0065814927,-0.01188529,-0.01107086,-0.059613306,0.030120188,0.0418596,-0.009260598,0.028435009,0.024893047,0.031339604,0.09501834,0.027570697,0.0636991,-0.056108754,-0.0329521,-0.114633024,-0.00981398,-0.060992315,0.027551433,0.0069592255,-0.059862003,0.0008075791,0.001507554,-0.028574942,-0.011227367,0.0056030746,-0.041190825,-0.09364463,-0.04459479,-0.055058934,-0.029972456,-0.028642913,-0.015199684,0.007875299,-0.034083385,0.02143902,-0.017395096,0.027429376,0.013198211,0.005065835,0.037760753,0.08974973,0.07598824,0.0050444477,0.014734193] + ], + "total_duration":375551700, + "load_duration":354411900, + "prompt_eval_count":9 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IEmbeddingGenerator> generator = new OllamaEmbeddingGenerator(new("http://localhost:11434"), "all-minilm", httpClient); + + var response = await generator.GenerateAsync([ + "hello, world!", + "red, white, blue", + ]); + Assert.NotNull(response); + Assert.Equal(2, response.Count); + + Assert.NotNull(response.Usage); + Assert.Equal(9, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.TotalTokenCount); + + foreach (Embedding e in response) + { + Assert.Equal("all-minilm", e.ModelId); + Assert.NotNull(e.CreatedAt); + Assert.Equal(384, e.Vector.Length); + Assert.Contains(e.Vector.ToArray(), f => !f.Equals(0)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/TestJsonSerializerContext.cs new file mode 100644 index 00000000000..49560a9c451 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/TestJsonSerializerContext.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +[JsonSerializable(typeof(string))] +[JsonSerializable(typeof(int))] +[JsonSerializable(typeof(IDictionary))] +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/IntegrationTestHelpers.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/IntegrationTestHelpers.cs new file mode 100644 index 00000000000..da60e62061f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/IntegrationTestHelpers.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using Azure.AI.OpenAI; +using OpenAI; + +namespace Microsoft.Extensions.AI; + +/// Shared utility methods for integration tests. +internal static class IntegrationTestHelpers +{ + /// Gets an to use for testing, or null if the associated tests should be disabled. + public static OpenAIClient? GetOpenAIClient() + { + string? apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + + if (apiKey is not null) + { + if (string.Equals(Environment.GetEnvironmentVariable("OPENAI_MODE"), "AzureOpenAI", StringComparison.OrdinalIgnoreCase)) + { + var endpoint = Environment.GetEnvironmentVariable("OPENAI_ENDPOINT") + ?? throw new InvalidOperationException("To use AzureOpenAI, set a value for OPENAI_ENDPOINT"); + return new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey)); + } + else + { + return new OpenAIClient(apiKey); + } + } + + return null; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/Microsoft.Extensions.AI.OpenAI.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/Microsoft.Extensions.AI.OpenAI.Tests.csproj new file mode 100644 index 00000000000..0ef40e12df3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/Microsoft.Extensions.AI.OpenAI.Tests.csproj @@ -0,0 +1,26 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.OpenAI + + + + true + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs new file mode 100644 index 00000000000..c82e1abc860 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +public class OpenAIChatClientIntegrationTests : ChatClientIntegrationTests +{ + protected override IChatClient? CreateChatClient() => + IntegrationTestHelpers.GetOpenAIClient() + ?.AsChatClient(Environment.GetEnvironmentVariable("OPENAI_CHAT_MODEL") ?? "gpt-4o-mini"); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs new file mode 100644 index 00000000000..f19a19f3ce8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs @@ -0,0 +1,608 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Azure.AI.OpenAI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using OpenAI; +using OpenAI.Chat; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OpenAIChatClientTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => new OpenAIChatClient(null!, "model")); + Assert.Throws("chatClient", () => new OpenAIChatClient(null!)); + + OpenAIClient openAIClient = new("key"); + Assert.Throws("modelId", () => new OpenAIChatClient(openAIClient, null!)); + Assert.Throws("modelId", () => new OpenAIChatClient(openAIClient, "")); + Assert.Throws("modelId", () => new OpenAIChatClient(openAIClient, " ")); + } + + [Fact] + public void AsChatClient_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => ((OpenAIClient)null!).AsChatClient("model")); + Assert.Throws("chatClient", () => ((ChatClient)null!).AsChatClient()); + + OpenAIClient client = new("key"); + Assert.Throws("modelId", () => client.AsChatClient(null!)); + Assert.Throws("modelId", () => client.AsChatClient(" ")); + } + + [Fact] + public void AsChatClient_OpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + OpenAIClient client = new(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint }); + + IChatClient chatClient = client.AsChatClient(model); + Assert.Equal("openai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + + chatClient = client.GetChatClient(model).AsChatClient(); + Assert.Equal("openai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public void AsChatClient_AzureOpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + AzureOpenAIClient client = new(endpoint, new ApiKeyCredential("key")); + + IChatClient chatClient = client.AsChatClient(model); + Assert.Equal("azureopenai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + + chatClient = client.GetChatClient(model).AsChatClient(); + Assert.Equal("azureopenai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public void GetService_OpenAIClient_SuccessfullyReturnsUnderlyingClient() + { + OpenAIClient openAIClient = new(new ApiKeyCredential("key")); + IChatClient chatClient = openAIClient.AsChatClient("model"); + + Assert.Same(chatClient, chatClient.GetService()); + Assert.Same(chatClient, chatClient.GetService()); + + Assert.Same(openAIClient, chatClient.GetService()); + + Assert.NotNull(chatClient.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(chatClient); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public void GetService_ChatClient_SuccessfullyReturnsUnderlyingClient() + { + ChatClient openAIClient = new OpenAIClient(new ApiKeyCredential("key")).GetChatClient("model"); + IChatClient chatClient = openAIClient.AsChatClient(); + + Assert.Same(chatClient, chatClient.GetService()); + Assert.Same(openAIClient, chatClient.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(chatClient); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public async Task BasicRequestResponse_NonStreaming() + { + const string Input = """ + {"messages":[{"role":"user","content":"hello"}],"model":"gpt-4o-mini","max_completion_tokens":10,"temperature":0.5} + """; + + const string Output = """ + { + "id": "chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", + "object": "chat.completion", + "created": 1727888631, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 8, + "completion_tokens": 9, + "total_tokens": 17, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("hello", new() + { + MaxOutputTokens = 10, + Temperature = 0.5f, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.CompletionId); + Assert.Equal("Hello! How can I assist you today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(8, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.OutputTokenCount); + Assert.Equal(17, response.Usage.TotalTokenCount); + Assert.NotNull(response.Usage.AdditionalProperties); + + Assert.NotNull(response.AdditionalProperties); + Assert.Equal("fp_f85bea6784", response.AdditionalProperties[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + } + + [Fact] + public async Task BasicRequestResponse_Streaming() + { + const string Input = """ + {"messages":[{"role":"user","content":"hello"}],"model":"gpt-4o-mini","max_completion_tokens":20,"stream":true,"stream_options":{"include_usage":true},"temperature":0.5} + """; + + const string Output = """ + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":8,"completion_tokens":9,"total_tokens":17,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("hello", new() + { + MaxOutputTokens = 20, + Temperature = 0.5f, + })) + { + updates.Add(update); + } + + Assert.Equal("Hello! How can I assist you today?", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_889_370); + Assert.Equal(12, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.NotNull(updates[i].AdditionalProperties); + Assert.Equal("fp_f85bea6784", updates[i].AdditionalProperties![nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + Assert.Equal(i == 10 ? 0 : 1, updates[i].Contents.Count); + Assert.Equal(i < 10 ? null : ChatFinishReason.Stop, updates[i].FinishReason); + } + + UsageContent usage = updates.SelectMany(u => u.Contents).OfType().Single(); + Assert.Equal(8, usage.Details.InputTokenCount); + Assert.Equal(9, usage.Details.OutputTokenCount); + Assert.Equal(17, usage.Details.TotalTokenCount); + Assert.NotNull(usage.Details.AdditionalProperties); + Assert.Equal(new Dictionary { [nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)] = 0 }, usage.Details.AdditionalProperties[nameof(ChatTokenUsage.OutputTokenDetails)]); + } + + [Fact] + public async Task MultipleMessages_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "role": "system", + "content": "You are a really nice friend." + }, + { + "role": "user", + "content": "hello!" + }, + { + "role": "assistant", + "content": "hi, how are you?" + }, + { + "role": "user", + "content": "i\u0027m good. how are you?" + } + ], + "model": "gpt-4o-mini", + "frequency_penalty": 0.75, + "presence_penalty": 0.5, + "seed":42, + "stop": [ + "great" + ], + "temperature": 0.25 + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", + "object": "chat.completion", + "created": 1727894187, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I’m doing well, thank you! What’s on your mind today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 42, + "completion_tokens": 15, + "total_tokens": 57, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List messages = + [ + new(ChatRole.System, "You are a really nice friend."), + new(ChatRole.User, "hello!"), + new(ChatRole.Assistant, "hi, how are you?"), + new(ChatRole.User, "i'm good. how are you?"), + ]; + + var response = await client.CompleteAsync(messages, new() + { + Temperature = 0.25f, + FrequencyPenalty = 0.75f, + PresencePenalty = 0.5f, + StopSequences = ["great"], + AdditionalProperties = new() { ["seed"] = 42 }, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.CompletionId); + Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(42, response.Usage.InputTokenCount); + Assert.Equal(15, response.Usage.OutputTokenCount); + Assert.Equal(57, response.Usage.TotalTokenCount); + Assert.NotNull(response.Usage.AdditionalProperties); + + Assert.NotNull(response.AdditionalProperties); + Assert.Equal("fp_f85bea6784", response.AdditionalProperties[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + } + + [Fact] + public async Task FunctionCallContent_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + } + ], + "model": "gpt-4o-mini", + "tools": [ + { + "type": "function", + "function": { + "description": "Gets the age of the specified person.", + "name": "GetPersonAge", + "parameters": { + "type": "object", + "required": [ + "personName" + ], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + }, + "strict": false + } + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADydKhrSKEBWJ8gy0KCIU74rN3Hmk", + "object": "chat.completion", + "created": 1727894702, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_8qbINM045wlmKZt9bVJgwAym", + "type": "function", + "function": { + "name": "GetPersonAge", + "arguments": "{\"personName\":\"Alice\"}" + } + } + ], + "refusal": null + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 61, + "completion_tokens": 16, + "total_tokens": 77, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Null(response.Message.Text); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); + Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(61, response.Usage.InputTokenCount); + Assert.Equal(16, response.Usage.OutputTokenCount); + Assert.Equal(77, response.Usage.TotalTokenCount); + + Assert.Single(response.Choices); + Assert.Single(response.Message.Contents); + FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + + Assert.NotNull(response.AdditionalProperties); + Assert.Equal("fp_f85bea6784", response.AdditionalProperties[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + } + + [Fact] + public async Task FunctionCallContent_Streaming() + { + const string Input = """ + { + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + } + ], + "model": "gpt-4o-mini", + "stream": true, + "stream_options": { + "include_usage": true + }, + "tools": [ + { + "type": "function", + "function": { + "description": "Gets the age of the specified person.", + "name": "GetPersonAge", + "parameters": { + "type": "object", + "required": [ + "personName" + ], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + }, + "strict": false + } + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_F9ZaqPWo69u0urxAhVt8meDW","type":"function","function":{"name":"GetPersonAge","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"person"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Alice"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":61,"completion_tokens":16,"total_tokens":77,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + })) + { + updates.Add(update); + } + + Assert.Equal("", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_895_263); + Assert.Equal(10, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.NotNull(updates[i].AdditionalProperties); + Assert.Equal("fp_f85bea6784", updates[i].AdditionalProperties![nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + Assert.Equal(i < 7 ? null : ChatFinishReason.ToolCalls, updates[i].FinishReason); + } + + FunctionCallContent fcc = Assert.IsType(Assert.Single(updates[updates.Count - 1].Contents)); + Assert.Equal("call_F9ZaqPWo69u0urxAhVt8meDW", fcc.CallId); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + + UsageContent usage = updates.SelectMany(u => u.Contents).OfType().Single(); + Assert.Equal(61, usage.Details.InputTokenCount); + Assert.Equal(16, usage.Details.OutputTokenCount); + Assert.Equal(77, usage.Details.TotalTokenCount); + Assert.NotNull(usage.Details.AdditionalProperties); + Assert.Equal(new Dictionary { [nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)] = 0 }, usage.Details.AdditionalProperties[nameof(ChatTokenUsage.OutputTokenDetails)]); + } + + private static IChatClient CreateChatClient(HttpClient httpClient, string modelId) => + new OpenAIClient(new ApiKeyCredential("apikey"), new OpenAIClientOptions { Transport = new HttpClientPipelineTransport(httpClient) }) + .AsChatClient(modelId); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorIntegrationTests.cs new file mode 100644 index 00000000000..38283e2687b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorIntegrationTests.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +public class OpenAIEmbeddingGeneratorIntegrationTests : EmbeddingGeneratorIntegrationTests +{ + protected override IEmbeddingGenerator>? CreateEmbeddingGenerator() => + IntegrationTestHelpers.GetOpenAIClient() + ?.AsEmbeddingGenerator(Environment.GetEnvironmentVariable("OPENAI_EMBEDDING_MODEL") ?? "text-embedding-3-small"); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..d08cf295a4b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs @@ -0,0 +1,187 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Net.Http; +using System.Threading.Tasks; +using Azure.AI.OpenAI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using OpenAI; +using OpenAI.Embeddings; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OpenAIEmbeddingGeneratorTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => new OpenAIEmbeddingGenerator(null!, "model")); + Assert.Throws("embeddingClient", () => new OpenAIEmbeddingGenerator(null!)); + + OpenAIClient openAIClient = new("key"); + Assert.Throws("modelId", () => new OpenAIEmbeddingGenerator(openAIClient, null!)); + Assert.Throws("modelId", () => new OpenAIEmbeddingGenerator(openAIClient, "")); + Assert.Throws("modelId", () => new OpenAIEmbeddingGenerator(openAIClient, " ")); + } + + [Fact] + public void AsEmbeddingGenerator_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => ((OpenAIClient)null!).AsEmbeddingGenerator("model")); + Assert.Throws("embeddingClient", () => ((EmbeddingClient)null!).AsEmbeddingGenerator()); + + OpenAIClient client = new("key"); + Assert.Throws("modelId", () => client.AsEmbeddingGenerator(null!)); + Assert.Throws("modelId", () => client.AsEmbeddingGenerator(" ")); + } + + [Fact] + public void AsEmbeddingGenerator_OpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + OpenAIClient client = new(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint }); + + IEmbeddingGenerator> embeddingGenerator = client.AsEmbeddingGenerator(model); + Assert.Equal("openai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + + embeddingGenerator = client.GetEmbeddingClient(model).AsEmbeddingGenerator(); + Assert.Equal("openai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + } + + [Fact] + public void AsEmbeddingGenerator_AzureOpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + AzureOpenAIClient client = new(endpoint, new ApiKeyCredential("key")); + + IEmbeddingGenerator> embeddingGenerator = client.AsEmbeddingGenerator(model); + Assert.Equal("azureopenai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + + embeddingGenerator = client.GetEmbeddingClient(model).AsEmbeddingGenerator(); + Assert.Equal("azureopenai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + } + + [Fact] + public void GetService_OpenAIClient_SuccessfullyReturnsUnderlyingClient() + { + OpenAIClient openAIClient = new(new ApiKeyCredential("key")); + IEmbeddingGenerator> embeddingGenerator = openAIClient.AsEmbeddingGenerator("model"); + + Assert.Same(embeddingGenerator, embeddingGenerator.GetService>>()); + Assert.Same(embeddingGenerator, embeddingGenerator.GetService()); + + Assert.Same(openAIClient, embeddingGenerator.GetService()); + + Assert.NotNull(embeddingGenerator.GetService()); + + using IEmbeddingGenerator> pipeline = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(embeddingGenerator); + + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType>>(pipeline.GetService>>()); + } + + [Fact] + public void GetService_EmbeddingClient_SuccessfullyReturnsUnderlyingClient() + { + EmbeddingClient openAIClient = new OpenAIClient(new ApiKeyCredential("key")).GetEmbeddingClient("model"); + IEmbeddingGenerator> embeddingGenerator = openAIClient.AsEmbeddingGenerator(); + + Assert.Same(embeddingGenerator, embeddingGenerator.GetService>>()); + Assert.Same(openAIClient, embeddingGenerator.GetService()); + + using IEmbeddingGenerator> pipeline = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(embeddingGenerator); + + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType>>(pipeline.GetService>>()); + } + + [Fact] + public async Task GetEmbeddingsAsync_ExpectedRequestResponse() + { + const string Input = """ + {"input":["hello, world!","red, white, blue"],"model":"text-embedding-3-small","encoding_format":"base64"} + """; + + const string Output = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": "qjH+vMcj07wP1+U7kbwjOv4cwLyL3iy9DkgpvCkBQD0bthW98o6SvMMwmTrQRQa9r7b1uy4tuLzssJs7jZspPe0JG70KJy89ae4fPNLUwjytoHk9BX/1OlXCfTzc07M8JAMIPU7cibsUJiC8pTNGPWUbJztfwW69oNwOPQIQ+rwm60M7oAfOvDMAsTxb+fM77WIaPIverDqcu5S84f+rvFyr8rxqoB686/4cPVnj9ztLHw29mJqaPAhH8Lz/db86qga/PGhnYD1WST28YgWru1AdRTz/db899PIPPBzBE720ie47ujymPbh/Kb0scLs8V1Q7PGIFqzwVMR48xp+UOhNGYTxfwW67CaDvvOeEI7tgc228uQNoPXrLBztd2TI9HRqTvLuVJbytoPm8YVMsOvi6irzweJY7/WpBvI5NKL040ym95ccmPAfj8rxJCZG9bsGYvJkpVzszp7G8wOxcu6/ZN7xXrTo7Q90YvGTtZjz/SgA8RWxVPL/hXjynl8O8ZzGjvHK0Uj0dRVI954QjvaqKfTxmUeS8Abf6O0RhV7tr+R098rnRPAju8DtoiiK95SCmvGV0pjwQMOW9wJPdPPutxDxYivi8NLKvPI3pKj3UDYE9Fg5cvQsyrTz+HEC9uuMmPMEaHbzJ4E8778YXvVDERb2cFBS9tsIsPLU7bT3+R/+8b55WPLhRaTzsgls9Nb2tuhNG4btlzSW9Y7cpvO1iGr0lh0a8u8BkvadJQj24f6k9J51CvbAPdbwCEHq8CicvvIKROr0ESbg7GMvYPE6OCLxS2sG7/WrBPOzbWj3uP1i9TVXKPPJg0rtp7h87TSqLPCmowLxrfdy8XbbwPG06WT33jEo9uxlkvcQN17tAmVy8h72yPEdMFLz4Ewo7BPs2va35eLynScI8WpV2PENW2bwQBSa9lSufu32+wTwl4MU8vohfvRyT07ylCIe8dHHPPPg+ST0Ooag8EsIiO9F7w7ylM0Y7dfgOPADaPLwX7hq7iG8xPDW9Lb1Q8oU98twTPYDUvTomwIQ8akcfvUhXkj3mK6Q8syXxvAMb+DwfMI87bsGYPGUbJ71GHtS8XbbwvFQ+P70f14+7Uq+CPSXgxbvHfFK9icgwPQsEbbwm60O9EpRiPDjTKb3uFJm7p/BCPazDuzxh+iy8Xj2wvBqrl71a7nU9guq5PYNDOb1X2Pk8raD5u+bSpLsMD2u7C9ktPVS6gDzyjhI9vl2gPNO0AT0/vJ68XQTyvMMCWbubYhU9rzK3vLhRaToSlOK6qYIAvQAovrsa1la8CEdwPKOkCT1jEKm8Y7epvOv+HLsoJII704ZBPXbVTDubjVQ8aRnfOvspBr2imYs8MDi2vPFVVDxSrwK9hac2PYverLyxGnO9nqNQvfVLD71UEP+8tDDvurN+8Lzkbqc6tsKsu5WvXTtDKxo72b03PdDshryvXfY81JE/vLYbLL2Fp7Y7JbUGPEQ2GLyagla7fAxDPaVhhrxu7Ne7wzAZPOxXHDx5nUe9s35wPHcOizx1fM26FTGePAsEbbzzQBE9zCQMPW6TWDygucy8zPZLPM2oSjzfmy48EF4lvUttDj3NL4q8WIp4PRoEFzxKFA89uKpou9H3BDvK6009a33cPLq15rzv8VY9AQX8O1gxebzjCqo7EeJjPaA1DrxoZ2C65tIkvS0iOjxln2W8o0sKPMPXGb3Ak908cxhQvR8wDzzN1gq8DnNovMZGFbwUJiA9moJWPBl9VzkVA148TrlHO/nFCL1f7y68xe2VPIROtzvCJRu88YMUvaUzRj1qR5+7e6jFPGyrHL3/SgC9GMtYPJcT27yqMX688YOUO32+QT18iAS9cdeUPFbN+zvlx6a83d6xOzQLL7sZJNi8mSnXOuqan7uqin09CievvPw0hLyuq/c866Udu4T1t7wBXnu7zQFKvE5gyDxhUyw8qzx8vIrTLr0Kq+26TgdJPWmVoDzOiIk8aDwhPVug9Lq6iie9iSEwvOKxqjwMiyy7E59gPepMnjth+iw9ntGQOyDijbw76SW9i96sO7qKJ7ybYhU8R/6Su+GmLLzsgtu7inovPRG3pLwZUpi7YzvoucrAjjwOSKm8uuOmvLbt67wKUu68XCc0vbd0Kz0LXWy8lHmgPAAoPjxRpAS99oHMvOlBoDprUh09teLtOxoEl7z0mRA89tpLvVQQ/zyjdkk9ZZ/lvHLikrw76SW82LI5vXyIBLzVnL06NyGrPPXPzTta7nW8FTEePSVcB73FGFU9SFcSPbzL4rtXrbo84lirvcd8Urw9/yG9+63EvPdhCz2rPPw8PPQjvbXibbuo+0C8oWtLPWVG5juL3qw71Zw9PMUY1Tk3yKu8WWq3vLnYKL25A+i8zH2LvMW/1bxDr1g8Cqvtu3pPRr0FrbU8vVKiO0LSGj1b+fM7Why2ux1FUjwhv0s89lYNPUbFVLzJ4M88t/hpvdpvNj0EzfY7gC29u0HyW7yv2Tc8dSPOvNhZurzrpR28jUIqPM0vijxyDdK8iBYyvZ0fkrxalXa9JeBFPO/GF71dBHK8X8FuPKnY/jpQmQY9S5jNPGBz7TrpQaA87/FWvUHyWzwCEPq78HiWOhfuGr0ltYY9I/iJPamCgLwLBO28jZupu38ivzuIbzG8Cfnuu0dMlLypKQG7BzxyvR5QULwCEHo8k8ehPUXoFjzPvka9MDi2vPsphjwjfMi854QjvcW/VbzO4Yg7Li04vL/h3jsaL9a5iG8xuybrwzz3YYu8Gw8VvVGkBD1UugA99MRPuCjLArzvxhc8XICzPFyrcr0gDU296h7eu8jV0TxNKos8lSufuqT9CD1oDmE8sqGyu2PiaLz6osY5YjBqPBAFJrwIlfG8PlihOBE74zzzQJG8r112vJPHobyrPPw7YawrPb5doLqtzrk7qHcCPVIoQzz5l0i81UM+vFd/eryaVxc9xA3XO/6YgbweJZG7W840PF0Ecj19ZUI8x1GTOtb1vDyDnLg8yxkOvOywGz0kqgg8fTqDvKlUQL3Bnlu992ELvZPHobybCZa82LK5vf2NgzwnnUK8YMzsPKOkiTxDr9g6la/duz3/IbusR/q8lmFcvFbN+zztCRu95nklPVKBwjwEJnY6V9j5PPK50bz6okY7R6UTPPnFiDwCafk8N8grO/gTCr1iiWm8AhB6vHHXlLyV3Z08vtZgPMDsXDsck9O7mdBXvRLCojzkbqe8XxpuvDSyLzu0MO87cxhQvd3eMbxtDxo9JKqIvB8CT72zrDC7s37wPHvWhbuXQZs8UlYDu7ef6rzsV5y8IkYLvUo/Tjz+R/88PrGgujSyrzxsBJy8P7yeO7f46byfKpA8cFDVPLygIzsdGpO77LCbvLSJ7rtgzOy7sA91O0hXkrwhO408XKvyvMUYVT2mPsQ8d+DKu9lkuLy+iF89xZSWPJFjpDwIlfE8bC9bPBE7Y7z/+f08W6B0PAc8crhmquO7RvOUPDybJLwlXAe9cuKSvMPXGbxK5s48sZY0O+4UmT1/Ij+8oNyOvPIH07tNKos8yTnPO2RpKDwRO+O7vl2gvKSvB7xGmpW7nD9TPZpXFzyXQRs9InHKurhR6bwb4VS8iiwuO3pPxrxeD3A8CfluO//OPr0MaOq8r112vAwP6zynHgM9T+cHPJuNVLzLRE07EmkjvWHX6rzBGh285G4nPe6Y17sCafm8//n9PJkpVzv9P4K7IWbMPCtlvTxHKVK8JNXHO/uCBblAFZ48xyPTvGaqY7wXlRs9EDDlPHcOizyNQiq9W3W1O7iq6LxwqdQ69MRPvSJGC7n3CIy8HOxSvSjLAryU0p87QJncvEoUjzsi7Qu9U4xAOwn5brzfm668Wu71uu002rw/Y588o6SJPFfY+Tyfg4+8u5WlPMDBnTzVnD08ljadu3sBxbzfm668n4OPO9VDvrz0mZC8kFimPNiyOT134Mo8vquhvDA4Njyjz0i7zVpJu1rudbwmksQ794xKuhN0ITz/zj68Vvu7unBQ1bv8NAS97FecOyxwOzs1ZC68AIG9PKLyCryvtvU8ntEQPBkkWD2xwfO7QfLbOhqIVTykVog7lSufvKOkiTwpqEA9/RFCvKxHejx3tYu74woqPMS0VzoMtuu8ViZ7PL8PH72+L2C81JE/vN3eMTwoywK9z5OHOx4lkTwGBrW8c5QRu4khMDyvBPc8nR8SvdlkuLw0si+9S8aNvCkBwLsXwFo7Od4nPbo8pryp2P68GfkYPKpfvjrsV5w6zuEIvbHB8zxnMSM9C9mtu1nj97zjYym8XFJzPAiVcTyNm6m7X5YvPJ8qED1l+OS8WTx3vGKJ6bt+F0G9jk2oPAR0dzwIR/A8umdlvNLUwjzI1dE7yuvNvBdnW7zdhTI9xkaVPCVcB70Mtus7G7aVPDchK7xuwRi8oDWOu/SZkLxOuUe8c5QRPLBo9Dz/+f07zS+KvNBFBr1n2CO8TKNLO4ZZNbym5US5HsyRvGi1YTwxnDO71vW8PM3WCr3E4he816e7O7QFML2asBa8jZspPSVcBzvjvCi9ZGmoPHV8zbyyobK830KvOgw9q7xzZtG7R6WTPMpnjzxj4mg8mrAWPS+GN7xoZ2C8tsKsOVMIAj1fli89Zc0lO00qCzz+R/87XKvyvLxy4zy52Cg9YjBqvW9F1zybjVS8mwmWvLvA5DymugU9DOQrPJWvXbvT38C8TrnHvLbt67sgiQ49e32GPPTETzv7goW7cKnUOoOcuLpG85S8CoCuO7ef6rkaqxe90tTCPJ8qkDvuuxk8FFFfPK9ddrtAbh08roC4PAnOrztV8D08jemquwR09ziL3iy7xkaVumVG5rygNQ69CfnuPGBzbTyE9Tc9Z9ijPK8yNzxgoa084woqu1F2RLwN76m7hrI0vf7xgLwaXRY6JmeFO68ytzrrpR29XbZwPYI4uzvkFai8qHcCPRCJ5DxKFI+7dHHPPE65xzxvnta8BPs2vWaq4zwrvjy8tDDvvEq7D7076SU9q+N8PAsyLTxb+XM9xZQWPP7ufzxsXZu6BEk4vGXNJbwBXvu8xA3XO8lcEbuuJzk8GEeavGnun7sMPSs9ITsNu1yr8roj+Ik8To6IvKjQgbwIwzG8wqlZvDfIK7xln2W8B+Pyu1HPw7sBjDs9Ba01PGSU57w/Yx867FecPFdUu7w2b6w7X5avvA8l57ypKQE9oGBNPeyC27vGytM828i1PP9KAD2/4V68eZ1HvDHqtDvR94Q6UwgCPLMlcbz+w0C8HwJPu/I1k7yZ/pe8aLXhPHYDDT28oKO8p2wEvdVDvrxh+qy8WDF5vJBYpjpaR3U8vgQhPNItwrsJoG88UaQEu3e1C7yagtY6HOzSOw9+5ryYTBk9q+N8POMKqrwoywI9DLZrPCN8SDxYivi8b3MXPf/OvruvBHc8M6exvA3vKbxz7RA8Fdieu4rTrrwFVDa8Vvu7PF0Ecjs6N6e8BzzyPP/Ovrv2rww9t59qvEoUDz3HUZO7UJkGPRigmbz/+X28qjH+u3jACbxlzaW7DA9rvFLawbwLBO2547yoO1t1NTr1pI68Vs37PAI+Ojx8s8O8xnHUvPg+yTwLBO26ybUQPfUoTTw76SU8i96sPKWMRbwUqt46pj7EPGX4ZL3ILtG8AV77vM0BSjzKZ488CByxvIWnNjyIFrI83CwzPN2FsjzHUZO8rzK3O+iPIbyGCzQ98NGVuxpdlrxhrKs8hQC2vFWXvjsCaXm8oRJMPHyIBLz+HMA8W/nzvHkZCb0pqMC87m0YPCu+vDsM5Ks8VnR8vG0Pmrt0yk48y3KNvKcegzwGMXS9xZQWPDYWrTxxAtQ7IWZMPU4Hybw89CO8/eaCPPMSUTxuk9i8WAY6vGfYozsQMGW8Li24vI+mJzxKFI88HwJPPFru9btRz8O6L9+2u29F1zwC5bq7RGHXvMtyjbr5bIm7V626uxsPlTv1KE29UB3FPMwkDDupggC8SQkRvH4XQT1cJ7Q8nvzPvKsRvTu9+SI8JbUGuiP4iTx460i99JkQPNF7Qz26Dma8u+4kvHO/0LyzfvA8EIlkPUPdmLpmUWS8uxnku8f4E72ruL27BzxyvKeXwz1plSC8gpG6vEQ2mLvtYho91Zy9vLvA5DtnXGK7sZY0uyu+PLwXlZu8GquXvE2uSb0ezBG8wn6au470KD1Abh28YMzsvPQdT7xKP867Xg/wO81aSb0IarK7SY1PO5EKJTsMi6y8cH4VvcXtlbwdGhM8xTsXPQvZLbxgzOw7Pf8hPRsPlbzDMJm8ZGmoPM1aSb0HEbO8PPQjvX5wwDwQXiW9wlDaO7SJ7jxFE9a8FTEePG5omTvPkwc8vtZgux9bzrmwD3W8U2EBPAVUNj0hlIw7comTPAEF/DvKwI68YKGtPJ78Tz1boHQ9sOS1vHiSSTlVG307HsyRPHEwFDxQmQY8CaBvvB0aE70PfuY8+neHvHOUET3ssBu7+tCGPJl3WDx4wAk9d1yMPOqanzwGBjW8ZialPB7MEby1O+07J0RDu4yQq7xpGV88ZXQmPc3WCruRCqU8Xbbwu+0JG7kXGVq8SY1PvKblxDv/oH68r7Z1OynWgDklh0a8E/hfPBCJZL31/Y08sD21vA9+Zjy6DmY82WQ4PAJp+TxHTJQ8JKoIvUBunbwgDc26BzxyvVUb/bz+w8A8Wu51u8guUbyHZLM8Iu0LvJqCVj3nhKO96kwevVDyBb3UDYG79zNLO7KhMj1IgtE83NOzO0f+krw89CM9z5OHuz+OXj2TxyE8wOzcPP91v7zUZgA8DyVnvILqOTzn3aI8j/+mO8xPyzt1UQ48+R4IvQnOrzt1I067QtKau9vINb1+7AE8sA/1uy7UOLzpQSC8dqoNPSnWgDsJoO+8ANo8vfDRlbwefpC89wgMPI1CKrrYsrm78mBSvFFLBb1Pa0a8s1MxPHbVzLw+WCG9kbyjvNt6tLwfMA+8HwLPvGO3qTyyobK8DcFpPInIsLwXGdq7nBSUPGdc4ryTx6G8T+eHPBxolDvIqhK8rqv3u1fY+Tz3M0s9qNCBO/GDlL2N6Sq9XKtyPFMIgrw0Cy+7Y7epPLJzcrz/+X28la/du8MC2bwTn+C5YSXsvDneJzz/SoC8H9ePvHMY0Lx0nw+9lSsfvS3Jujz/SgC94rEqvQwP67zd3rE83NOzPKvj/DyYmpo8h2SzvF8abjye0ZC8vSRivCKfijs/vJ48NAuvvFIoQzzFGFU9dtVMPa2g+TtpGd88Uv2DO3kZiTwA2rw79f2Nu1ugdDx0nw+8di7MvIrTrjz08g+8j6anvGH6LLxQ8oW8LBc8Pf0/Ajxl+OQ8SQkRPYrTrrzyNRM8GquXu9ItQjz1Sw87C9mtuxXYnrwDl7m87Y1ZO2ChrbyhQIy4EsIiPWpHHz0inwo7teJtPJ0fEroHPPK7fp4APV/B7rwwODa8L4Y3OiaSxLsBBfw7RI8XvP5H/zxVlz68n1VPvEBuHbwTzSA8fOEDvV49sDs2b6y8mf6XPMVm1jvjvCg8ETvjPEQ2GLxK5s47Q92YuxOfYLyod4K8EDDlPHAlFj1zGFC8pWGGPE65R7wBMzy8nJjSvLoO5rwwkbU7Eu3hvLOsMDyyobI6YHNtPKs8fLzXp7s6AV57PV49MLsVMR68+4KFPIkhMLxeaG87mXdYulyAMzzQRQY9ljadu3YDDby7GWS7phOFPEJ5mzq6tea6Eu1hPJjzmTz+R388di5MvJn+F7wi7Qs8K768PFnj9zu5MSi8Gl2WvJfomzxHd1O8vw8fvONjqbxuaBk980ARPSNRiTwLMi272Fk6vDGcs7z60Ia8vX1hOzvppbuKLK48jZspvZkpV7pWJns7G7YVPdPfwLyruL08FFHfu7ZprbwT+N84+1TFPGpHn7y9JOI8xe2Vu08SR7zs29o8/RFCPCbAhDzfQi89OpCmvL194boeJZE8kQqlvES6VjrzEtE7eGeKu2kZX71rfdw8D6wmu6Y+xLzJXJE8DnPovJrbVbvkFai8KX0Bvfr7RbuXbNq8Gw+VPRCJ5LyA1D28uQPoPLygo7xENpi8/RHCvEOv2DwRtyS9o0uKPNshNbvmeSU8IyPJvCedQjy7GWQ8Wkf1vGKJ6bztYho8vHLju5cT2zzKZw+88jWTvFb7uznYCzm8" + }, + { + "object": "embedding", + "index": 1, + "embedding": "eyfbu150UDkC6hQ9ip9oPG7jWDw3AOm8DQlcvFiY5Lt3Z6W8BLPPOV0uOz3FlQk8h5AYvH6Aobv0z/E8nOQRvHI8H7rQA+s8F6X9vPplyDzuZ1u8T2cTvAUeoDt0v0Q9/xx5vOhqlT1EgXu8zfQavTK0CDxRxX08v3MIPAY29bzIpFm8bGAzvQkkazxCciu8mjyxvIK0rDx6mzC7Eqg3O8H2rTz9vo482RNiPUYRB7xaQMU80h8hu8kPqrtyPB+8dvxUvfplSD21bJY8oQ8YPZbCEDvxegw9bTJzvYNlEj0h2q+9mw5xPQ5P8TyWwpA7rmvvO2Go27xw2tO6luNqO2pEfTztTwa7KnbRvAbw37vkEU89uKAhPGfvF7u6I8c8DPGGvB1gjzxU2K48+oqDPLCo/zsskoc8PUclvXCUvjzOpQC9qxaKO1iY5LyT9XS9ZNzmvI74Lr03azk93CYTvFJVCTzd+FK8lwgmvcMzPr00q4O9k46FvEx5HbyIqO083xSJvC7PFzy/lOK7HPW+PF2ikDxeAHu9QnIrvSz59rl/UmG8ZNzmu2b4nD3V31Y5aXK9O/2+jrxljUw8y9jkPGuvTTxX5/48u44XPXFFpDwAiEm8lcuVvX6h+zwe7Lm8SUUSPHmkNTu9Eb08cP8OvYgcw7xU2C49Wm4FPeV8H72AA8c7eH/6vBI0Yj3L2GQ8/0G0PHg5ZTvHjAS9fNhAPcE8wzws2By6RWAhvWTcZjz+1uM8H1eKvHdnJT0TWR29KcVrPdu7wrvMQzW9VhW/Ozo09LvFtuM8OlmvPO5GAT3eHY68zTqwvIhiWLs1w1i9sGJqPaurOb0s2Jy8Z++XOwAU9Lggb988vnyNvVfGpLypKBS8IouVO60NBb26r/G6w+0ovbVslrz+kE68MQOjOxdf6DvoRdo8Z4RHPCvhIT3e7009P4Q1PQ0JXDyD8Ty8/ZnTuhu4Lj3X1lG9sVnlvMxDNb3wySY9cUWkPNZKJ73qyP+8rS7fPNhBojwpxes8kt0fPM7rlbwYEE68zoBFvdrExzsMzEu9BflkvF0uu7zNFfW8UyfJPPSJ3LrEBf68+6JYvef/xDpAe7C8f5h2vPqKA7xUTAS9eDllPVK8eL0+GeW7654gPQuGNr3/+x69YajbPAehRTyc5BE8pfQIPMGwGL2QoA87iGJYPYXoN7s4sc69f1JhPdYEkjxgkIa6uxpCvHtMljtYvR88uCzMPBeEo7wm1/U8GBDOvBkHybwyG3i7aeaSvQzMyzy3e2a9xZUJvVSSmTu7SII8x4yEPKAYHTxUTIQ8lcsVO5x5QT3VDRe963llO4K0rLqI1i07DX0xvQv6CznrniA9nL9WPTvl2Tw6WS+8NcPYvEL+VbzZfrK9NDcuO4wBNL0jXVW980PHvNZKJz1Oti09StG8vIZTiDwu8PE8zP0fO9340juv1j890vFgvMFqAz2kHui7PNxUPQehxTzjGlQ9vcunPL+U4jyfrUw8R+NGPHQF2jtSdmO8mYtLvF50ULyT1Bo9ONaJPC1kx7woznC83xQJvUdv8byEXA29keaku6Qe6Ly+fA29kKAPOxLuzLxjxJG9JnCGur58jTws2Jy8CkmmO3pVm7uwqH87Eu7Mu/SJXL0IUis9MFI9vGnmEr1Oti09Z+8XvH1DkbwcaZS8NDcuvT0BkLyPNT89Haakuza607wv5+w81KLGO80VdT3MiUq8J4hbPHHRzrwr4aG8PSJqvJOOBT3t2zC8eBgLvXchkLymOp66y9jkPDdG/jw2ulO983GHPDvl2Tt+Ooy9NwDpOzZ0Pr3xegw7bhGZvEpd57s5YjS9Gk1evIbfMjxBwcW8NnQ+PMlVPzxR6ji9M8zdPImHk7wQsby8u0gCPXtMFr22YxE9Wm4FPaXPzbygGJ093bK9OuYtBTxyXfk8iYeTvNH65byk/Q29QO+FvKbGyLxCcqs9nL/WvPtcQ72XTjs8kt2fuhaNKDxqRH08KX9WPbmXnDtXDDo96GoVPVw3QL0eeGS8ayOjvAIL7zywQZC9at0NvUMjET1Q8707eTDgvIio7Tv60Jg87kYBOw50LLx7BgE96qclPUXsSz0nQkY5aDUtvQF/RD1bZQC73fjSPHgYCzyPNT+9q315vbMvhjsvodc8tEdbPGcQ8jz8U768cYs5PIwBtL38x5M9PtPPvIex8jzfFIk9vsIivLsaQj2/uZ072y8YvSV5C7uoA9k8JA67PO5nWzvS8eC8av7nuxSWrbybpwE9f5h2vG3sXTmoA1k9sjiLvTBSPbxc8Sq9UpuePB+dHz2/cwg9BWS1vCrqJr2M3Pg86LAqPS/GEj3oRdq8GiyEvACISbuiJ+28FFAYuzBSvTzwDzy8K5uMvE5wmDpd6CW6dkJqPGlyvTwF2Iq9f1JhPSHarzwDdr88JXkLu4ADxzx5pDW7zqUAvdAoJj24wXs8doj/PH46jD2/2vc893fSuyxtTL0YnPg7IWbaPOiwqrxLDk27ZxDyPBpymbwW0z08M/odPTufRL1AVvU849Q+vBGDfD3JDyq6Z6kCPL9OzTz0rpe8FtM9vaDqXLx+W2Y7jHWJPGXT4TwJ3lW9M4bIPPCDkTwoZwE9XH1VOmksqLxLPI08cNrTvCyz4bz+Srm8kiO1vDP6nbvIpNk8MrSIvPe95zoTWR29SYsnPYC9MT2F6De93qm4PCbX9bqqhv47yky6PENE67x/DEw8JdYAvUdvcbywh6W8//ueO8fSmTyjTCi9yky6O/qr3TzvGEE8wqcTPeDmSDyuJVo8ip/ou1HqOLxOtq28y5LPuxk1Cb0Ddr+7c+2EvKQeaL1SVQk8XS47PGTcZjwdpiQ8uFqMO0QaDD1XxqS8mLmLuuSFJDz1xmy8PvgKvJAHf7yC+kE8VapuvetYC7tHCAI8oidtPOiwqjyoSW68xCo5vfzobTzz2HY88/0xPNkT4rty9om8RexLu9SiRrsVaG081gSSO5IjtTsOLpc72sTHPGCQBj0QJRI9BCclPI1sBDzCyO07QHuwvOYthTz4tGK5QHuwvWfvFz2CQNc8PviKPO8YwTuQoA89fjoMPBnBs7zGZ8m8uiPHvMdeRLx+gKE8keaku0wziDzZWfe8I4KQPJ0qpzs4sc47dyEQPEQaDDzVmcE8//uePJcIJjztTwa9ogaTOftcwztU2K48opvCuyz5drzqM1C7iYcTvfDJJjxXxiQ9o0wovO1PBrwqvGa7dSoVPbI4izvnuS88zzGrPH3POzzHXkQ9PSJqOXCUPryW4+o8ELE8PNZKp7z+Sjm8foChPPIGtzyTaUq8JA47vBiceDw3a7m6jWyEOmksKDwH59q5GMo4veALBL0SqDe7IaxvvBD3Ubxn7xc9+dkdPSBOBTxHCAI8mYvLOydCxjw5HB88zTqwvJXs77w9AZA9CxvmvIeQGL2rffm8JXkLPKqGfjyoSe464d1DPPd3UrpO/EK8qxYKvUuCojwhZlq8EPfRPKaAs7xKF9K85i0FvEYRhzyPNT88m6cBvdSiRjxnqQI9uOY2vcBFSLx4OeW7BxUbPCz59rt+W2Y7SWZsPGzUCLzE5KM7sIclvIdr3buoSW47AK0EPImHE7wgToU8IdovO7FZ5bxbzO+8uMF7PGayB7z6ioO8zzErPEcIgrxSm568FJYtvNf7jDyrffm8KaQRPcoGpTwleQu8EWKiPHPthLz44qI8pEOjvWh7QjzpPNU8lcuVPHCUPr3n/8Q8bNQIu0WmNr1Erzs95VfkPCeIW7vT0Aa7656gudH65bxw/w49ZrKHPHsn27sIUiu8mEU2vdUNF7wBf8Q809CGPFtlgDo1fcO85i2FPEcIAjwL+os653OavOu1AL2EN9K8H52fPKzoybuMdYk8T2cTO8lVPzyK5X07iNYtvD74ijzT0IY8RIF7vLLENbyZi8s8KwJ8vAne1TvGZ8k71gSSumJZwTybp4G8656gPG8IFL27SAI9arjSvKVbeDxljcy83fjSuxu4Lr2DZRK9G0TZvLFZ5bxR6ji8NPEYPbI4izyAvTE9riVaPCCUGrw0Ny48f1LhuzIb+DolBTY8UH9ou/4EpLyAvTG9CFIrvCBOBTlkIvy8WJhkvHIXZLkf47Q8GQfJvBpNXr1pcr07c8jJO2nmkrxOcJi8sy8GuzjWibu2Pta8WQO1PFPhs7z7XEO8pEMjvb9OzTz4bs08EWKiu0YyYbzeHQ695D+PPKVbeDzvGEG9B6HFO0uCojws+Xa7JQW2OpRgRbxjCqc8Sw7NPDTxmLwjXVW8sRNQvFPhszzM/Z88rVMavZPUGj06WS+8JpHgO3etursdx369uZccvKplJDws+Xa8fzGHPB1gj7yqZaQ887ecPBNZHbzoi2+7NwDpPMxDtbzfWh49H+O0PO+kaztI2kE8/xz5PImHE73fNWO8T60ovIPxPDvR2Yu8XH3VvMcYr7wfnR+9fUORPIdr3Tyn6wO9nkL8vM2uhTzGIbS66u26vE2/MrxFYKE8iwo5vLSNcLy+wiK9GTUJPK10dLzrniC8qkBpvPxTPrwzQLO8illTvFi9H7yMATS7ayOjO14Ae7z19Cy87dswPKbGyDzujJa93EdtPdsB2LYT5Ue9RhEHPKurubxm+By9+mVIvIy7HrxZj987yOpuvUdv8TvgCwS8TDMIO9xsqLsL+gs8BWS1PFRMBD1yXXm86GoVvK+QqjxRXg46TZHyu2ayhzx7TJa8uKAhPLyFkjsV3MI7niGiPGNQvDxgkIa887ccPUmLJ7yZsIa8KDnBvHgYi7yMR0m82ukCvRuK7junUvO8aeYSPXtt8LqXCKa84kgUPd5jIzxlRze93xQJPNNcMT2v1j889GiCPKRkfbxz7YQ8b06pO8cYL7xg9/U8yQ+qPGlyvbzfNWO8vZ3nPBGD/DtB5gC7yKRZPPTPcbz6q928bleuPI74rrzVDRe9CQORvMmb1Dzv0qs8DBLhu4dr3bta1fQ8aeYSvRD3UTugpMe8CxvmPP9BNDzHjAQ742DpOzXD2Dz4bk28c1T0Onxka7zEBf48uiNHvGayBz1pcj29NcPYvDnu3jz5kwg9WkBFvL58jTx/mHY8wTzDPDZ0Pru/uZ08PQGQPOFRmby4oKE8JktLPIx1iTsppBG9dyGQvHfzT7wzhki44KAzPSOCkDzv0iu8lGBFO2VHNzyKxKM72EEiPYtQzryT9fQ8UDnTPEx5nTzuZ9s8QO8FvG8IlDx7J9s6MUk4O9k4nbx7TBa7G7iuvCzYHDocr6k8/7UJPY2ymTwVIlg8KjC8OvSuFz2iJ+28cCBpvE0qAzw41ok7sgrLvPjiojyG37K6lwimvKcxGTwRHI28y5LPO/mTiDx82MC5VJIZPWkH7TwPusG8YhOsvH1DkbzUx4E8TQXIvO+ka7zKwI+8w+2oPNLxYLzxegy9zEM1PDo0dDxIINc8FdxCO46E2TwPRmw9+ooDvMmb1LwBf0S8CQMRvEXsS7zPvdU80qvLPLfvO7wbuK68iBzDO0cpXL2WndU7dXCqvOTLubytLl88LokCvZj/IDw0q4M8G7guvNkTYrq5UQe7vcunvIrEI7xuERm9RexLvAdbsDwLQCE7uVEHPYjWrbuM3Pi8g2WSO3R5L7x4XiC8vKZsu9Sixros+fa8UH/ouxxpFL3wyaa72sRHu2YZ9zuiJ2274o4pOjkcnzyagka7za4FvYrEozwCMCo7cJQ+vfqKAzzJ4em8fNhAPUB7sLylz80833v4vOU2ir1ty4M8UV4OPXQF2jyu30S9EjRivBVo7TwXX2g70ANrvEJyq7wQJRK99jE9O7c10brUxwE9SUUSPS4VLbzBsJg7FHHyPMz9n7latJo8bleuvBpN3jsF+WS8Ye7wO4nNKL0TWZ08iRM+vOn2v7sB8xm9jY3ePJ/zYbkLG+a7ZvicvGxgM73L2OS761iLPKcxmTrX+ww8J0JGu1MnyTtJZuw7pIm4PJbCED29V1K9PFCqPLBBkLxhYka8hXTiPEB7MDzrniA7h5CYvIR9ZzzARcg7TZHyu4sKOb1in9Y7nL9WO6gD2TxSduO8UaQjPQO81Lxw/w69KwL8O4FJ3D2XTju8SE6XPGDWGz0K1VC8YhMsvObCtDyndy49BCclu68cVbxemYu8sGLqOksOzTzj1L47ISBFvLly4Ttk3Oa8RhGHNwzxBj0v5+y7ogaTPA+6QbxiE6w8ubj2PDixzrstZEe9jbKZPPd30rwqMDw8TQXIPFurlTxx0c68jLsePfSJ3LuXTru8yeHpu6Ewcjx5D4a8BvBfvN8Uibs9R6W8lsIQvaEw8rvVUyw8SJQsPebCNDwu8PE8GMo4OxAlkjwJmMA8KaQRvdYlbDwNNxy9ouHXPDffDrxwZv46AK0EPJqCRrpWz6k8/0E0POAs3rxmsoe7zTqwO5mLyzyP7ym7wTzDvFB/aLx5D4a7doj/O67fxDtsO/g7uq9xvMWViTtC/tU7PhnlvIEogjxxRSQ9SJSsPIJA1zyBKAI9ockCPYC9MbxBTXC83xSJvPFVUb1n75c8uiNHOxdf6Drt27A8/FM+vJOvXz3a6QI8UaQjuvqKgzyOhNm831oevF+xYLxjCic8sn6gPDdrOTs3Rv66cP+Ou5785rycBew8J0JGPJOOBbw9Imq8q335O3MOX7xemQs8PtNPPE1L3Tx5dnU4A+EPPLrdsTzfFIm7LJIHPB4yz7zbAdi8FWjtu1h3Cj0oznA8kv55PKgDWbxIINc8xdsePa8cVbzmlHQ8IJSavAgMlrx4XiA8z3dAu2PEET3xm+a75//EvK2Zr7xbqxU8zP2fvOSFJD1xRSS7k44FvPzHkzz5+ne8+tAYvd5jIz1GMuE8yxSAO3KCNDyRuOS8wzO+vObCNDwzQLO7isQjva1TGrz6ioM79GgCPF66Zbx1KpW8qW6pu4RcDTzcJhO9SJQsO5G45LsAiMm8lRErvJqCxjzQbju7w3nTuTclpDywqP88ysCPvAF/xLxfa0u88cChPBjKODyaPLE8k69fvGFiRrvuRgG9ATmvvJEsOr21+EC9KX/WOrmXnDwDAuo8yky6PI1sBDvztxy8PviKPKInbbzbdS276mGQO2Kf1rwn/DC8ZrIHPBRxcj0z+h264d1DPdG0ULxvTqm5bDt4vToTmjuGJcg7tmMRO9YEEr3oJAC9THmdPKn607vcJhM8Zj6yvHR5r7ywYmq83fjSO5mLyzshIEU8EWKiuu9eVjw75dk7fzGHvNl+sjwJJOs8YllBPAtheztz7QQ92lDyvDEDozzEKrk7KnZRvG8pbjsdYI+7yky6OfWAVzzjYGk7NX3DOzrNhDyeIaI8joTZvFcMOryYRba8G7iuu893QDw9RyW7za6FvDUJ7rva6YK9D7rBPD1o/zxCLJa65TaKvHsGAT2g6ly8+tCYu+wqy7xeAHu8vZ1nPBv+QzwfVwo8CMYAvM+91TzKTDq8Ueo4u2uvzTsBf8Q8p+uDvKofDz12tj+8wP+yOlkDtTwYyji6ZdPhPGv14rwqdtE8YPf1vLIKy7yFLs28ouFXvO1PBj15pDU83xQJPdfWUTz8x5O64kgUPBQKA72eIaK6A3a/OyzYnLoYnPg4XMNqPdxsqLsKSaY7pfSIvBoshLupKJS8G0TZOu/SqzzFcE47cvaJPA19Mb14dQC8sVllvJmwhjycBey8cvaJOmSWUbvRtFC8WtX0O2r+57twIGm8yeFpvFuG2rzCyO08PUelPK5rbzouFS29uCxMPQAUdDqtma88wqeTu5gge7zH8/O7l067PJdOO7uKxCO8/xx5vKt9+TztTwa8OhOaO+Q/Dzw33w49CZhAvSubjDydttG8IdovPIADR7stHrI7ATmvvOAs3rzL2OQ69K4XvNccZ7zlV2S8c+0EPfNDxzydKqc6LLPhO8YhtDyJhxM9H1eKOaNMKLtOcBg9HPU+PTsrbzvT0Ia8BG26PB2mpDp7TJa8wP8yPVvM77t0ea86eTBgvFurFT1C/tW7CkkmvKOSPT2aPDG9lGDFPAhSq7u5UYc8l5TQPFh3ijz9vg68lGBFO4/vKTxViZS7eQ8GPTNAs7xmsoe8o0yoPJfaZbwlvyA8IazvO0XsS717TJY8flvmOgHFWbyWnVW8mdFgvJbCkDynDF68" + } + ], + "model": "text-embedding-3-small", + "usage": { + "prompt_tokens": 9, + "total_tokens": 9 + } + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IEmbeddingGenerator> generator = new OpenAIClient(new ApiKeyCredential("apikey"), new OpenAIClientOptions + { + Transport = new HttpClientPipelineTransport(httpClient), + }).AsEmbeddingGenerator("text-embedding-3-small"); + + var response = await generator.GenerateAsync([ + "hello, world!", + "red, white, blue", + ]); + Assert.NotNull(response); + Assert.Equal(2, response.Count); + + Assert.NotNull(response.Usage); + Assert.Equal(9, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.TotalTokenCount); + + foreach (Embedding e in response) + { + Assert.Equal("text-embedding-3-small", e.ModelId); + Assert.NotNull(e.CreatedAt); + Assert.Equal(1536, e.Vector.Length); + Assert.Contains(e.Vector.ToArray(), f => !f.Equals(0)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientBuilderTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientBuilderTest.cs new file mode 100644 index 00000000000..ba1c85d700a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientBuilderTest.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientBuilderTest +{ + [Fact] + public void PassesServiceProviderToFactories() + { + var expectedServiceProvider = new ServiceCollection().BuildServiceProvider(); + using TestChatClient expectedResult = new(); + var builder = new ChatClientBuilder(expectedServiceProvider); + + builder.Use((serviceProvider, innerClient) => + { + Assert.Same(expectedServiceProvider, serviceProvider); + return expectedResult; + }); + + using TestChatClient innerClient = new(); + Assert.Equal(expectedResult, builder.Use(innerClient: innerClient)); + } + + [Fact] + public void BuildsPipelineInOrderAdded() + { + // Arrange + using TestChatClient expectedInnerClient = new(); + var builder = new ChatClientBuilder(); + + builder.Use(next => new InnerClientCapturingChatClient("First", next)); + builder.Use(next => new InnerClientCapturingChatClient("Second", next)); + builder.Use(next => new InnerClientCapturingChatClient("Third", next)); + + // Act + var first = (InnerClientCapturingChatClient)builder.Use(expectedInnerClient); + + // Assert + Assert.Equal("First", first.Name); + var second = (InnerClientCapturingChatClient)first.InnerClient; + Assert.Equal("Second", second.Name); + var third = (InnerClientCapturingChatClient)second.InnerClient; + Assert.Equal("Third", third.Name); + Assert.Same(expectedInnerClient, third.InnerClient); + } + + [Fact] + public void DoesNotAcceptNullInnerService() + { + Assert.Throws(() => new ChatClientBuilder().Use((IChatClient)null!)); + } + + [Fact] + public void DoesNotAcceptNullFactories() + { + ChatClientBuilder builder = new(); + Assert.Throws(() => builder.Use((Func)null!)); + Assert.Throws(() => builder.Use((Func)null!)); + } + + [Fact] + public void DoesNotAllowFactoriesToReturnNull() + { + ChatClientBuilder builder = new(); + builder.Use(_ => null!); + var ex = Assert.Throws(() => builder.Use(new TestChatClient())); + Assert.Contains("entry at index 0", ex.Message); + } + + private sealed class InnerClientCapturingChatClient(string name, IChatClient innerClient) : DelegatingChatClient(innerClient) + { +#pragma warning disable S3604 // False positive: Member initializer values should not be redundant + public string Name { get; } = name; +#pragma warning restore S3604 + public new IChatClient InnerClient => base.InnerClient; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs new file mode 100644 index 00000000000..0e776b4fee5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -0,0 +1,256 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Text.Json; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientStructuredOutputExtensionsTests +{ + [Fact] + public async Task SuccessUsage() + { + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]) + { + CompletionId = "test", + CreatedAt = DateTimeOffset.UtcNow, + ModelId = "someModel", + RawRepresentation = new object(), + Usage = new(), + }; + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + var responseFormat = Assert.IsType(options!.ResponseFormat); + Assert.Null(responseFormat.Schema); + Assert.Null(responseFormat.SchemaName); + Assert.Null(responseFormat.SchemaDescription); + + // The inner client receives a trailing "system" message with the schema instruction + Assert.Collection(messages, + message => Assert.Equal("Hello", message.Text), + message => + { + Assert.Equal(ChatRole.System, message.Role); + Assert.Contains("Respond with a JSON value", message.Text); + Assert.Contains("https://json-schema.org/draft/2020-12/schema", message.Text); + foreach (Species v in Enum.GetValues(typeof(Species))) + { + Assert.Contains(v.ToString(), message.Text); // All enum values are described as strings + } + }); + + return Task.FromResult(expectedCompletion); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + Assert.Equal(expectedCompletion.CompletionId, response.CompletionId); + Assert.Equal(expectedCompletion.CreatedAt, response.CreatedAt); + Assert.Equal(expectedCompletion.ModelId, response.ModelId); + Assert.Same(expectedCompletion.RawRepresentation, response.RawRepresentation); + Assert.Same(expectedCompletion.Usage, response.Usage); + + // TryGetResult returns the same value + Assert.True(response.TryGetResult(out var tryGetResultOutput)); + Assert.Same(response.Result, tryGetResultOutput); + + // Doesn't mutate history (or at least, reverts any changes) + Assert.Equal("Hello", Assert.Single(chatHistory).Text); + } + + [Fact] + public async Task FailureUsage_InvalidJson() + { + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, "This is not valid JSON")]); + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion), + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + var ex = Assert.Throws(() => response.Result); + Assert.Contains("invalid", ex.Message); + + Assert.False(response.TryGetResult(out var tryGetResult)); + Assert.Null(tryGetResult); + } + + [Fact] + public async Task FailureUsage_NullJson() + { + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, "null")]); + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion), + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + var ex = Assert.Throws(() => response.Result); + Assert.Equal("The deserialized response is null", ex.Message); + + Assert.False(response.TryGetResult(out var tryGetResult)); + Assert.Null(tryGetResult); + } + + [Fact] + public async Task FailureUsage_NoJsonInResponse() + { + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, [new ImageContent("https://example.com")])]); + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion), + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + var ex = Assert.Throws(() => response.Result); + Assert.Equal("The response did not contain text to be deserialized", ex.Message); + + Assert.False(response.TryGetResult(out var tryGetResult)); + Assert.Null(tryGetResult); + } + + [Fact] + public async Task CanUseNativeStructuredOutput() + { + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + var responseFormat = Assert.IsType(options!.ResponseFormat); + Assert.Equal(nameof(Animal), responseFormat.SchemaName); + Assert.Equal("Some test description", responseFormat.SchemaDescription); + Assert.Contains("https://json-schema.org/draft/2020-12/schema", responseFormat.Schema); + foreach (Species v in Enum.GetValues(typeof(Species))) + { + Assert.Contains(v.ToString(), responseFormat.Schema); // All enum values are described as strings + } + + // The chat history isn't mutated any further, since native structured output is used instead of a prompt + Assert.Equal("Hello", Assert.Single(messages).Text); + + return Task.FromResult(expectedCompletion); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory, useNativeJsonSchema: true); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + + // TryGetResult returns the same value + Assert.True(response.TryGetResult(out var tryGetResultOutput)); + Assert.Same(response.Result, tryGetResultOutput); + + // History remains unmutated + Assert.Equal("Hello", Assert.Single(chatHistory).Text); + } + + [Fact] + public async Task CanSpecifyCustomJsonSerializationOptions() + { + var jso = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + }; + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, jso))]); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Collection(messages, + message => Assert.Equal("Hello", message.Text), + message => + { + Assert.Equal(ChatRole.System, message.Role); + Assert.Contains("Respond with a JSON value", message.Text); + Assert.Contains("https://json-schema.org/draft/2020-12/schema", message.Text); + Assert.DoesNotContain(nameof(Animal.FullName), message.Text); // The JSO uses snake_case + Assert.Contains("full_name", message.Text); // The JSO uses snake_case + Assert.DoesNotContain(nameof(Species.Tiger), message.Text); // The JSO doesn't use enum-to-string conversion + }); + + return Task.FromResult(expectedCompletion); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory, jso); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + } + + [Fact] + public async Task HandlesBackendReturningMultipleObjects() + { + // A very common failure mode for GPT 3.5 Turbo is that instead of returning a single top-level JSON object, + // it may return multiple, particularly when function calling is involved. + // See https://community.openai.com/t/2-json-objects-returned-when-using-function-calling-and-json-mode/574348 + // Fortunately we can work around this without breaking any cases of valid output. + + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var resultDuplicatedJson = JsonSerializer.Serialize(expectedResult) + Environment.NewLine + JsonSerializer.Serialize(expectedResult); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + return Task.FromResult(new ChatCompletion([new ChatMessage(ChatRole.Assistant, resultDuplicatedJson)])); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + } + + [Description("Some test description")] + private class Animal + { + public int Id { get; set; } + public string? FullName { get; set; } + public Species Species { get; set; } + } + + private enum Species + { + Bear, + Tiger, + Walrus, + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs new file mode 100644 index 00000000000..a27761c99ec --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ConfigureOptionsChatClientTests +{ + [Fact] + public void ConfigureOptionsChatClient_InvalidArgs_Throws() + { + Assert.Throws("innerClient", () => new ConfigureOptionsChatClient(null!, _ => new ChatOptions())); + Assert.Throws("configureOptions", () => new ConfigureOptionsChatClient(new TestChatClient(), null!)); + } + + [Fact] + public void UseChatOptions_InvalidArgs_Throws() + { + var builder = new ChatClientBuilder(); + Assert.Throws("configureOptions", () => builder.UseChatOptions(null!)); + } + + [Fact] + public async Task ConfigureOptions_ReturnedInstancePassedToNextClient() + { + ChatOptions providedOptions = new(); + ChatOptions returnedOptions = new(); + ChatCompletion expectedCompletion = new(Array.Empty()); + var expectedUpdates = Enumerable.Range(0, 3).Select(i => new StreamingChatCompletionUpdate()).ToArray(); + using CancellationTokenSource cts = new(); + + using IChatClient innerClient = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Same(returnedOptions, options); + Assert.Equal(cts.Token, cancellationToken); + return Task.FromResult(expectedCompletion); + }, + + CompleteStreamingAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Same(returnedOptions, options); + Assert.Equal(cts.Token, cancellationToken); + return YieldUpdates(expectedUpdates); + }, + }; + + using var client = new ChatClientBuilder() + .UseChatOptions(options => + { + Assert.Same(providedOptions, options); + return returnedOptions; + }) + .Use(innerClient); + + var completion = await client.CompleteAsync(Array.Empty(), providedOptions, cts.Token); + Assert.Same(expectedCompletion, completion); + + int i = 0; + await using var e = client.CompleteStreamingAsync(Array.Empty(), providedOptions, cts.Token).GetAsyncEnumerator(); + while (i < expectedUpdates.Length) + { + Assert.True(await e.MoveNextAsync()); + Assert.Same(expectedUpdates[i++], e.Current); + } + + Assert.False(await e.MoveNextAsync()); + + static async IAsyncEnumerable YieldUpdates(StreamingChatCompletionUpdate[] updates) + { + foreach (var update in updates) + { + await Task.Yield(); + yield return update; + } + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/CustomAIContentJsonContext.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/CustomAIContentJsonContext.cs new file mode 100644 index 00000000000..650a8fdd162 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/CustomAIContentJsonContext.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +[JsonSerializable(typeof(DistributedCachingChatClientTest.CustomAIContent1))] +[JsonSerializable(typeof(DistributedCachingChatClientTest.CustomAIContent2))] +internal sealed partial class CustomAIContentJsonContext : JsonSerializerContext; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs new file mode 100644 index 00000000000..9bbfbea98c3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DependencyInjectionPatterns +{ + private IServiceCollection ServiceCollection { get; } = new ServiceCollection(); + + [Fact] + public void CanRegisterScopedUsingGenericType() + { + // Arrange/Act + ServiceCollection.AddChatClient(builder => builder + .UseScopedMiddleware() + .Use(new TestChatClient())); + + // Assert + var services = ServiceCollection.BuildServiceProvider(); + using var scope1 = services.CreateScope(); + using var scope2 = services.CreateScope(); + + var instance1 = scope1.ServiceProvider.GetRequiredService(); + var instance1Copy = scope1.ServiceProvider.GetRequiredService(); + var instance2 = scope2.ServiceProvider.GetRequiredService(); + + // Each scope gets a distinct outer *AND* inner client + var outer1 = Assert.IsType(instance1); + var outer2 = Assert.IsType(instance2); + var inner1 = Assert.IsType(((ScopedChatClient)instance1).InnerClient); + var inner2 = Assert.IsType(((ScopedChatClient)instance2).InnerClient); + + Assert.NotSame(outer1.Services, outer2.Services); + Assert.NotSame(instance1, instance2); + Assert.NotSame(inner1, inner2); + Assert.Same(instance1, instance1Copy); // From the same scope + } + + [Fact] + public void CanRegisterScopedUsingFactory() + { + // Arrange/Act + ServiceCollection.AddChatClient(builder => + { + builder.UseScopedMiddleware(); + return builder.Use(new TestChatClient { Services = builder.Services }); + }); + + // Assert + var services = ServiceCollection.BuildServiceProvider(); + using var scope1 = services.CreateScope(); + using var scope2 = services.CreateScope(); + + var instance1 = scope1.ServiceProvider.GetRequiredService(); + var instance2 = scope2.ServiceProvider.GetRequiredService(); + + // Each scope gets a distinct outer *AND* inner client + var outer1 = Assert.IsType(instance1); + var outer2 = Assert.IsType(instance2); + var inner1 = Assert.IsType(((ScopedChatClient)instance1).InnerClient); + var inner2 = Assert.IsType(((ScopedChatClient)instance2).InnerClient); + + Assert.Same(outer1.Services, inner1.Services); + Assert.Same(outer2.Services, inner2.Services); + Assert.NotSame(outer1.Services, outer2.Services); + } + + [Fact] + public void CanRegisterScopedUsingSharedInstance() + { + // Arrange/Act + using var singleton = new TestChatClient(); + ServiceCollection.AddChatClient(builder => + { + builder.UseScopedMiddleware(); + return builder.Use(singleton); + }); + + // Assert + var services = ServiceCollection.BuildServiceProvider(); + using var scope1 = services.CreateScope(); + using var scope2 = services.CreateScope(); + var instance1 = scope1.ServiceProvider.GetRequiredService(); + var instance2 = scope2.ServiceProvider.GetRequiredService(); + + // Each scope gets a distinct outer instance, but the same inner client + Assert.IsType(instance1); + Assert.IsType(instance2); + Assert.Same(singleton, ((ScopedChatClient)instance1).InnerClient); + Assert.Same(singleton, ((ScopedChatClient)instance2).InnerClient); + } + + public class ScopedChatClient(IServiceProvider services, IChatClient inner) : DelegatingChatClient(inner) + { + public new IChatClient InnerClient => base.InnerClient; + public IServiceProvider Services => services; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs new file mode 100644 index 00000000000..35ced372eb2 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -0,0 +1,703 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DistributedCachingChatClientTest +{ + private readonly TestInMemoryCacheStorage _storage = new(); + + [Fact] + public async Task CachesSuccessResultsAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + var expectedCompletion = new ChatCompletion([ + new(new ChatRole("fakeRole"), "This is some content") + { + AdditionalProperties = new() { ["a"] = "b" }, + Contents = [new FunctionCallContent("someCallId", "functionName", new Dictionary + { + ["arg1"] = "value1", + ["arg2"] = 123, + ["arg3"] = 123.4, + ["arg4"] = true, + ["arg5"] = false, + ["arg6"] = null + })] + } + ]) + { + CompletionId = "someId", + Usage = new() + { + InputTokenCount = 123, + OutputTokenCount = 456, + TotalTokenCount = 99999, + }, + CreatedAt = DateTimeOffset.UtcNow, + ModelId = "someModel", + AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 123 } + }; + + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + innerCallCount++; + return Task.FromResult(expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Make the initial request and do a quick sanity check + var result1 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + Assert.Same(expectedCompletion, result1); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Equal(1, innerCallCount); + AssertCompletionsEqual(expectedCompletion, result2); + + // Act/Assert 2: Cache misses do not return cached results + await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some modified input")]); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task AllowsConcurrentCallsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async delegate + { + innerCallCount++; + await completionTcs.Task; + return new ChatCompletion([new(ChatRole.Assistant, "Hello")]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act 1: Concurrent calls before resolution are passed into the inner client + var result1 = outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + var result2 = outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert 1 + Assert.Equal(2, innerCallCount); + Assert.False(result1.IsCompleted); + Assert.False(result2.IsCompleted); + completionTcs.SetResult(true); + Assert.Equal("Hello", (await result1).Message.Text); + Assert.Equal("Hello", (await result2).Message.Text); + + // Act 2: Subsequent calls after completion are resolved from the cache + var result3 = outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + Assert.Equal(2, innerCallCount); + Assert.Equal("Hello", (await result3).Message.Text); + } + + [Fact] + public async Task DoesNotCacheExceptionResultsAsync() + { + // Arrange + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + innerCallCount++; + throw new InvalidTimeZoneException("some failure"); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + var input = new ChatMessage(ChatRole.User, "abc"); + var ex1 = await Assert.ThrowsAsync(() => outer.CompleteAsync([input])); + Assert.Equal("some failure", ex1.Message); + Assert.Equal(1, innerCallCount); + + // Act + var ex2 = await Assert.ThrowsAsync(() => outer.CompleteAsync([input])); + + // Assert + Assert.NotSame(ex1, ex2); + Assert.Equal("some failure", ex2.Message); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task DoesNotCacheCanceledResultsAsync() + { + // Arrange + var innerCallCount = 0; + var resolutionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async delegate + { + innerCallCount++; + if (innerCallCount == 1) + { + await resolutionTcs.Task; + } + + return new ChatCompletion([new(ChatRole.Assistant, "A good result")]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // First call gets cancelled + var input = new ChatMessage(ChatRole.User, "abc"); + var result1 = outer.CompleteAsync([input]); + Assert.False(result1.IsCompleted); + Assert.Equal(1, innerCallCount); + resolutionTcs.SetCanceled(); + await Assert.ThrowsAsync(() => result1); + Assert.True(result1.IsCanceled); + + // Act/Assert: Second call can succeed + var result2 = await outer.CompleteAsync([input]); + Assert.Equal(2, innerCallCount); + Assert.Equal("A good result", result2.Message.Text); + } + + [Fact] + public async Task StreamingCachesSuccessResultsAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + List expectedCompletion = + [ + new() + { + Role = new ChatRole("fakeRole1"), + ChoiceIndex = 3, + AdditionalProperties = new() { ["a"] = "b" }, + Contents = [new TextContent("Chunk1")] + }, + new() + { + Role = new ChatRole("fakeRole2"), + Text = "Chunk2", + Contents = + [ + new FunctionCallContent("someCallId", "someFn", new Dictionary { ["arg1"] = "value1" }), + new UsageContent(new() { InputTokenCount = 123, OutputTokenCount = 456, TotalTokenCount = 99999 }), + ] + } + ]; + + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync(expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Make the initial request and do a quick sanity check + var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + await AssertCompletionsEqualAsync(expectedCompletion, result1); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Equal(1, innerCallCount); + await AssertCompletionsEqualAsync(expectedCompletion, result2); + + // Act/Assert 2: Cache misses do not return cached results + await ToListAsync(outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some modified input")])); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task StreamingCoalescesConsecutiveTextChunksAsync() + { + // Arrange + List expectedCompletion = + [ + new() { Role = ChatRole.Assistant, Text = "This" }, + new() { Role = ChatRole.Assistant, Text = " becomes one chunk" }, + new() { Role = ChatRole.Assistant, Contents = [new FunctionCallContent("callId1", "separator")] }, + new() { Role = ChatRole.Assistant, Text = "... and this" }, + new() { Role = ChatRole.Assistant, Text = " becomes another" }, + new() { Role = ChatRole.Assistant, Text = " one." }, + ]; + + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate { return ToAsyncEnumerableAsync(expectedCompletion); } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + await ToListAsync(result1); + + // Act + var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Collection(await ToListAsync(result2), + c => Assert.Equal("This becomes one chunk", c.Text), + c => Assert.IsType(Assert.Single(c.Contents)), + c => Assert.Equal("... and this becomes another one.", c.Text)); + } + + [Fact] + public async Task StreamingAllowsConcurrentCallsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + List expectedCompletion = + [ + new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + new() { Role = ChatRole.System, Text = "Chunk 2" }, + ]; + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync(completionTcs.Task, expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act 1: Concurrent calls before resolution are passed into the inner client + var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert 1 + Assert.NotSame(result1, result2); + var result1Assertion = AssertCompletionsEqualAsync(expectedCompletion, result1); + var result2Assertion = AssertCompletionsEqualAsync(expectedCompletion, result2); + Assert.False(result1Assertion.IsCompleted); + Assert.False(result2Assertion.IsCompleted); + completionTcs.SetResult(true); + await result1Assertion; + await result2Assertion; + Assert.Equal(2, innerCallCount); + + // Act 2: Subsequent calls after completion are resolved from the cache + var result3 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + await AssertCompletionsEqualAsync(expectedCompletion, result3); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task StreamingDoesNotCacheExceptionResultsAsync() + { + // Arrange + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync(Task.CompletedTask, + [ + () => new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + () => throw new InvalidTimeZoneException("some failure"), + ]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + var input = new ChatMessage(ChatRole.User, "abc"); + var result1 = outer.CompleteStreamingAsync([input]); + var ex1 = await Assert.ThrowsAsync(() => ToListAsync(result1)); + Assert.Equal("some failure", ex1.Message); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = outer.CompleteStreamingAsync([input]); + var ex2 = await Assert.ThrowsAsync(() => ToListAsync(result2)); + + // Assert + Assert.NotSame(ex1, ex2); + Assert.Equal("some failure", ex2.Message); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task StreamingDoesNotCacheCanceledResultsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync( + innerCallCount == 1 ? completionTcs.Task : Task.CompletedTask, + [() => new() { Role = ChatRole.Assistant, Text = "A good result" }]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // First call gets cancelled + var input = new ChatMessage(ChatRole.User, "abc"); + var result1 = outer.CompleteStreamingAsync([input]); + var result1Assertion = ToListAsync(result1); + Assert.False(result1Assertion.IsCompleted); + completionTcs.SetCanceled(); + await Assert.ThrowsAsync(() => result1Assertion); + Assert.True(result1Assertion.IsCanceled); + Assert.Equal(1, innerCallCount); + + // Act/Assert: Second call can succeed + var result2 = await ToListAsync(outer.CompleteStreamingAsync([input])); + Assert.Equal("A good result", result2[0].Text); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task CacheKeyDoesNotVaryByChatOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async (_, options, _) => + { + innerCallCount++; + await Task.Yield(); + return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act: Call with two different ChatOptions + var result1 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 1" } } + }); + var result2 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 2" } } + }); + + // Assert: Same result + Assert.Equal(1, innerCallCount); + Assert.Equal("value 1", result1.Message.Text); + Assert.Equal("value 1", result2.Message.Text); + } + + [Fact] + public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async (_, options, _) => + { + innerCallCount++; + await Task.Yield(); + return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + } + }; + using var outer = new CachingChatClientWithCustomKey(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act: Call with two different ChatOptions + var result1 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 1" } } + }); + var result2 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 2" } } + }); + + // Assert: Different results + Assert.Equal(2, innerCallCount); + Assert.Equal("value 1", result1.Message.Text); + Assert.Equal("value 2", result2.Message.Text); + } + + [Fact] + public async Task CanCacheCustomContentTypesAsync() + { + // Arrange + var expectedCompletion = new ChatCompletion([ + new(new ChatRole("fakeRole"), + [ + new CustomAIContent1("Hello", DateTime.Now), + new CustomAIContent2("Goodbye", 42), + ]) + ]); + + var serializerOptions = new JsonSerializerOptions(TestJsonSerializerContext.Default.Options); + serializerOptions.TypeInfoResolver = serializerOptions.TypeInfoResolver!.WithAddedModifier(typeInfo => + { + if (typeInfo.Type == typeof(AIContent)) + { + foreach (var t in new Type[] { typeof(CustomAIContent1), typeof(CustomAIContent2) }) + { + typeInfo.PolymorphismOptions!.DerivedTypes.Add(new JsonDerivedType(t, t.Name)); + } + } + }); + serializerOptions.TypeInfoResolverChain.Add(CustomAIContentJsonContext.Default); + + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + innerCallCount++; + return Task.FromResult(expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = serializerOptions + }; + + // Make the initial request and do a quick sanity check + var result1 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + AssertCompletionsEqual(expectedCompletion, result1); + + // Act + var result2 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Equal(1, innerCallCount); + AssertCompletionsEqual(expectedCompletion, result2); + Assert.NotSame(result2.Message.Contents[0], expectedCompletion.Message.Contents[0]); + Assert.NotSame(result2.Message.Contents[1], expectedCompletion.Message.Contents[1]); + } + + [Fact] + public async Task CanResolveIDistributedCacheFromDI() + { + // Arrange + var services = new ServiceCollection() + .AddSingleton(_storage) + .BuildServiceProvider(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + return Task.FromResult(new ChatCompletion([ + new(ChatRole.Assistant, [new TextContent("Hey")])])); + } + }; + using var outer = new ChatClientBuilder(services) + .UseDistributedCache(configure: options => + { + options.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(testClient); + + // Act: Make a request that should populate the cache + Assert.Empty(_storage.Keys); + var result = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.NotNull(result); + Assert.Single(_storage.Keys); + } + + private static async Task> ToListAsync(IAsyncEnumerable values) + { + var result = new List(); + await foreach (var v in values) + { + result.Add(v); + } + + return result; + } + + private static IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values) + => ToAsyncEnumerableAsync(Task.CompletedTask, values); + + private static IAsyncEnumerable ToAsyncEnumerableAsync(Task preTask, IEnumerable valueFactories) + => ToAsyncEnumerableAsync(preTask, valueFactories.Select>(v => () => v)); + + private static async IAsyncEnumerable ToAsyncEnumerableAsync(Task preTask, IEnumerable> values) + { + await preTask; + + foreach (var value in values) + { + await Task.Yield(); + yield return value(); + } + } + + private static void AssertCompletionsEqual(ChatCompletion expected, ChatCompletion actual) + { + Assert.Equal(expected.CompletionId, actual.CompletionId); + Assert.Equal(expected.Usage?.InputTokenCount, actual.Usage?.InputTokenCount); + Assert.Equal(expected.Usage?.OutputTokenCount, actual.Usage?.OutputTokenCount); + Assert.Equal(expected.Usage?.TotalTokenCount, actual.Usage?.TotalTokenCount); + Assert.Equal(expected.CreatedAt, actual.CreatedAt); + Assert.Equal(expected.ModelId, actual.ModelId); + Assert.Equal( + JsonSerializer.Serialize(expected.AdditionalProperties, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actual.AdditionalProperties, TestJsonSerializerContext.Default.Options)); + Assert.Equal(expected.Choices.Count, actual.Choices.Count); + + for (var i = 0; i < expected.Choices.Count; i++) + { + Assert.IsType(expected.Choices[i].GetType(), actual.Choices[i]); + Assert.Equal(expected.Choices[i].Role, actual.Choices[i].Role); + Assert.Equal(expected.Choices[i].Text, actual.Choices[i].Text); + Assert.Equal(expected.Choices[i].Contents.Count, actual.Choices[i].Contents.Count); + + for (var itemIndex = 0; itemIndex < expected.Choices[i].Contents.Count; itemIndex++) + { + var expectedItem = expected.Choices[i].Contents[itemIndex]; + var actualItem = actual.Choices[i].Contents[itemIndex]; + Assert.Equal(expectedItem.ModelId, actualItem.ModelId); + Assert.IsType(expectedItem.GetType(), actualItem); + + if (expectedItem is FunctionCallContent expectedFcc) + { + var actualFcc = (FunctionCallContent)actualItem; + Assert.Equal(expectedFcc.Name, actualFcc.Name); + Assert.Equal(expectedFcc.CallId, actualFcc.CallId); + + // The correct JSON-round-tripping of AIContent/AIContent is not + // the responsibility of CachingChatClient, so not testing that here. + Assert.Equal( + JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); + } + } + } + } + + private static async Task AssertCompletionsEqualAsync(IReadOnlyList expected, IAsyncEnumerable actual) + { + var actualEnumerator = actual.GetAsyncEnumerator(); + + foreach (var expectedItem in expected) + { + Assert.True(await actualEnumerator.MoveNextAsync()); + + var actualItem = actualEnumerator.Current; + Assert.Equal(expectedItem.Text, actualItem.Text); + Assert.Equal(expectedItem.ChoiceIndex, actualItem.ChoiceIndex); + Assert.Equal(expectedItem.Role, actualItem.Role); + Assert.Equal(expectedItem.Contents.Count, actualItem.Contents.Count); + + for (var itemIndex = 0; itemIndex < expectedItem.Contents.Count; itemIndex++) + { + var expectedItemItem = expectedItem.Contents[itemIndex]; + var actualItemItem = actualItem.Contents[itemIndex]; + Assert.IsType(expectedItemItem.GetType(), actualItemItem); + + if (expectedItemItem is FunctionCallContent expectedFcc) + { + var actualFcc = (FunctionCallContent)actualItemItem; + Assert.Equal(expectedFcc.Name, actualFcc.Name); + Assert.Equal(expectedFcc.CallId, actualFcc.CallId); + + // The correct JSON-round-tripping of AIContent/AIContent is not + // the responsibility of CachingChatClient, so not testing that here. + Assert.Equal( + JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); + } + else if (expectedItemItem is UsageContent expectedUsage) + { + var actualUsage = (UsageContent)actualItemItem; + Assert.Equal(expectedUsage.Details.InputTokenCount, actualUsage.Details.InputTokenCount); + Assert.Equal(expectedUsage.Details.OutputTokenCount, actualUsage.Details.OutputTokenCount); + Assert.Equal(expectedUsage.Details.TotalTokenCount, actualUsage.Details.TotalTokenCount); + } + } + } + + Assert.False(await actualEnumerator.MoveNextAsync()); + } + + private sealed class CachingChatClientWithCustomKey(IChatClient innerClient, IDistributedCache storage) + : DistributedCachingChatClient(innerClient, storage) + { + protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) + { + var baseKey = base.GetCacheKey(streaming, chatMessages, options); + return baseKey + options?.AdditionalProperties?["someKey"]?.ToString(); + } + } + + public class CustomAIContent1(string text, DateTime date) : AIContent + { + public string Text => text; + public DateTime Date => date; + } + + public class CustomAIContent2(string text, int number) : AIContent + { + public string Text => text; + public int Number => number; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs new file mode 100644 index 00000000000..8ad0c6d7944 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -0,0 +1,352 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class FunctionInvokingChatClientTests +{ + [Fact] + public async Task SupportsSingleFunctionCallPerRequestAsync() + { + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentInvocation) + { + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create((int i) => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + ] + }; + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func1"), + new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 34 } }), + new FunctionCallContent("callId3", "Func2", arguments: new Dictionary { { "i", 56 } }), + ]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId1", "Func1", result: "Result 1"), + new FunctionResultContent("callId2", "Func2", result: "Result 2: 34"), + new FunctionResultContent("callId3", "Func2", result: "Result 2: 56"), + ]), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId4", "Func2", arguments: new Dictionary { { "i", 78 } }), + new FunctionCallContent("callId5", "Func1")]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId4", "Func2", result: "Result 2: 78"), + new FunctionResultContent("callId5", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation })); + } + + [Fact] + public async Task ParallelFunctionCallsInvokedConcurrentlyByDefaultAsync() + { + using var barrier = new Barrier(2); + + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create((string arg) => + { + barrier.SignalAndWait(); + return arg + arg; + }, "Func"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), + new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), + ]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId1", "Func", result: "hellohello"), + new FunctionResultContent("callId2", "Func", result: "worldworld"), + ]), + new ChatMessage(ChatRole.Assistant, "done"), + ]); + } + + [Fact] + public async Task ConcurrentInvocationOfParallelCallsCanBeDisabledAsync() + { + int activeCount = 0; + + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create(async (string arg) => + { + Interlocked.Increment(ref activeCount); + await Task.Delay(100); + Assert.Equal(1, activeCount); + Interlocked.Decrement(ref activeCount); + return arg + arg; + }, "Func"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), + new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), + ]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId1", "Func", result: "hellohello"), + new FunctionResultContent("callId2", "Func", result: "worldworld"), + ]), + new ChatMessage(ChatRole.Assistant, "done"), + ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = false })); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunctionCallingMessages) + { + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ] + }; + +#pragma warning disable SA1118 // Parameter should not span multiple lines + var finalChat = await InvokeAndAssertAsync( + options, + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], + expected: keepFunctionCallingMessages ? + null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, "world") + ], + configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); +#pragma warning restore SA1118 + + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages) + { + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ] + }; + +#pragma warning disable SA1118 // Parameter should not span multiple lines + var finalChat = await InvokeAndAssertAsync(options, + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } }), new TextContent("more")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], + expected: keepFunctionCallingMessages ? + null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Assistant, "more"), + new ChatMessage(ChatRole.Assistant, "world"), + ], + configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); +#pragma warning restore SA1118 + + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedErrors) + { + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(string () => throw new InvalidOperationException("Oh no!"), "Func1"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: detailedErrors ? "Error: Function failed. Exception: Oh no!" : "Error: Function failed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors })); + } + + [Fact] + public async Task RejectsMultipleChoicesAsync() + { + var func1 = AIFunctionFactory.Create(() => "Some result 1", "Func1"); + var func2 = AIFunctionFactory.Create(() => "Some result 2", "Func2"); + + using var innerClient = new TestChatClient + { + CompleteAsyncCallback = async (chatContents, options, cancellationToken) => + { + await Task.Yield(); + + return new ChatCompletion( + [ + new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Metadata.Name)]), + new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Metadata.Name)]), + ]); + } + }; + + IChatClient service = new ChatClientBuilder().UseFunctionInvocation().Use(innerClient); + + List chat = [new ChatMessage(ChatRole.User, "hello")]; + var ex = await Assert.ThrowsAsync( + () => service.CompleteAsync(chat, new ChatOptions { Tools = [func1, func2] })); + + Assert.Contains("only accepts a single choice", ex.Message); + Assert.Single(chat); // It didn't add anything to the chat history + } + + private static async Task> InvokeAndAssertAsync( + ChatOptions options, + List plan, + List? expected = null, + Func? configurePipeline = null) + { + Assert.NotEmpty(plan); + + configurePipeline ??= static b => b.UseFunctionInvocation(); + + using CancellationTokenSource cts = new(); + List chat = [plan[0]]; + int i = 0; + + using var innerClient = new TestChatClient + { + CompleteAsyncCallback = async (contents, actualOptions, actualCancellationToken) => + { + Assert.Same(chat, contents); + Assert.Equal(cts.Token, actualCancellationToken); + + await Task.Yield(); + + return new ChatCompletion([plan[contents.Count]]); + } + }; + + IChatClient service = configurePipeline(new ChatClientBuilder()).Use(innerClient); + + var result = await service.CompleteAsync(chat, options, cts.Token); + chat.Add(result.Message); + + expected ??= plan; + Assert.NotNull(result); + Assert.Equal(expected.Count, chat.Count); + for (; i < expected.Count; i++) + { + var expectedMessage = expected[i]; + var chatMessage = chat[i]; + + Assert.Equal(expectedMessage.Role, chatMessage.Role); + Assert.Equal(expectedMessage.Text, chatMessage.Text); + Assert.Equal(expectedMessage.GetType(), chatMessage.GetType()); + + Assert.Equal(expectedMessage.Contents.Count, chatMessage.Contents.Count); + for (int j = 0; j < expectedMessage.Contents.Count; j++) + { + var expectedItem = expectedMessage.Contents[j]; + var chatItem = chatMessage.Contents[j]; + + Assert.Equal(expectedItem.GetType(), chatItem.GetType()); + Assert.Equal(expectedItem.ToString(), chatItem.ToString()); + if (expectedItem is FunctionCallContent expectedFunctionCall) + { + var chatFunctionCall = (FunctionCallContent)chatItem; + Assert.Equal(expectedFunctionCall.Name, chatFunctionCall.Name); + AssertExtensions.EqualFunctionCallParameters(expectedFunctionCall.Arguments, chatFunctionCall.Arguments); + } + else if (expectedItem is FunctionResultContent expectedFunctionResult) + { + var chatFunctionResult = (FunctionResultContent)chatItem; + AssertExtensions.EqualFunctionCallResults(expectedFunctionResult.Result, chatFunctionResult.Result); + } + } + } + + return chat; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs new file mode 100644 index 00000000000..feb91ac925e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class LoggingChatClientTests +{ + [Fact] + public void LoggingChatClient_InvalidArgs_Throws() + { + Assert.Throws("innerClient", () => new LoggingChatClient(null!, NullLogger.Instance)); + Assert.Throws("logger", () => new LoggingChatClient(new TestChatClient(), null!)); + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level) + { + using CapturingLoggerProvider clp = new(); + + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + using IChatClient innerClient = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + return Task.FromResult(new ChatCompletion([new(ChatRole.Assistant, "blue whale")])); + }, + }; + + using IChatClient client = new ChatClientBuilder(services) + .UseLogging() + .Use(innerClient); + + await client.CompleteAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions { FrequencyPenalty = 3.0f }); + + if (level is LogLevel.Trace) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteAsync invoked:") && entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteAsync completed:") && entry.Message.Contains("blue whale"))); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteAsync invoked.") && !entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteAsync completed.") && !entry.Message.Contains("blue whale"))); + } + else + { + Assert.Empty(clp.Logger.Entries); + } + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task CompleteStreamAsync_LogsStartUpdateCompletion(LogLevel level) + { + CapturingLogger logger = new(level); + + using IChatClient innerClient = new TestChatClient + { + CompleteStreamingAsyncCallback = (messages, options, cancellationToken) => GetUpdatesAsync() + }; + + static async IAsyncEnumerable GetUpdatesAsync() + { + await Task.Yield(); + yield return new StreamingChatCompletionUpdate { Role = ChatRole.Assistant, Text = "blue " }; + yield return new StreamingChatCompletionUpdate { Role = ChatRole.Assistant, Text = "whale" }; + } + + using IChatClient client = new ChatClientBuilder() + .UseLogging(logger) + .Use(innerClient); + + await foreach (var update in client.CompleteStreamingAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions { FrequencyPenalty = 3.0f })) + { + // nop + } + + if (level is LogLevel.Trace) + { + Assert.Collection(logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync invoked:") && entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update:") && entry.Message.Contains("blue")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update:") && entry.Message.Contains("whale")), + entry => Assert.Contains("CompleteStreamingAsync completed.", entry.Message)); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync invoked.") && !entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update.") && !entry.Message.Contains("blue")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update.") && !entry.Message.Contains("whale")), + entry => Assert.Contains("CompleteStreamingAsync completed.", entry.Message)); + } + else + { + Assert.Empty(logger.Entries); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs new file mode 100644 index 00000000000..d0056b21b91 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs @@ -0,0 +1,220 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using OpenTelemetry.Trace; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class OpenTelemetryChatClientTests +{ + [Fact] + public async Task ExpectedInformationLogged_NonStreaming_Async() + { + var sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + using var innerClient = new TestChatClient + { + Metadata = new("testservice", new Uri("http://localhost:12345/something"), "amazingmodel"), + CompleteAsyncCallback = async (messages, options, cancellationToken) => + { + await Task.Yield(); + return new ChatCompletion([new ChatMessage(ChatRole.Assistant, "blue whale")]) + { + CompletionId = "id123", + FinishReason = ChatFinishReason.Stop, + Usage = new UsageDetails + { + InputTokenCount = 10, + OutputTokenCount = 20, + TotalTokenCount = 42, + }, + }; + } + }; + + var chatClient = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, instance => + { + instance.EnableSensitiveData = true; + instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(innerClient); + + await chatClient.CompleteAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions + { + FrequencyPenalty = 3.0f, + MaxOutputTokens = 123, + ModelId = "replacementmodel", + TopP = 4.0f, + PresencePenalty = 5.0f, + ResponseFormat = ChatResponseFormat.Json, + Temperature = 6.0f, + StopSequences = ["hello", "world"], + AdditionalProperties = new() { ["top_k"] = 7.0f }, + }); + + var activity = Assert.Single(activities); + + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + + Assert.Equal("http://localhost:12345/something", activity.GetTagItem("server.address")); + Assert.Equal(12345, (int)activity.GetTagItem("server.port")!); + + Assert.Equal("chat.completions replacementmodel", activity.DisplayName); + Assert.Equal("testservice", activity.GetTagItem("gen_ai.system")); + + Assert.Equal("replacementmodel", activity.GetTagItem("gen_ai.request.model")); + Assert.Equal(3.0f, activity.GetTagItem("gen_ai.request.frequency_penalty")); + Assert.Equal(4.0f, activity.GetTagItem("gen_ai.request.top_p")); + Assert.Equal(5.0f, activity.GetTagItem("gen_ai.request.presence_penalty")); + Assert.Equal(6.0f, activity.GetTagItem("gen_ai.request.temperature")); + Assert.Equal(7.0, activity.GetTagItem("gen_ai.request.top_k")); + Assert.Equal(123, activity.GetTagItem("gen_ai.request.max_tokens")); + Assert.Equal("""["hello", "world"]""", activity.GetTagItem("gen_ai.request.stop_sequences")); + + Assert.Equal("id123", activity.GetTagItem("gen_ai.response.id")); + Assert.Equal("""["stop"]""", activity.GetTagItem("gen_ai.response.finish_reasons")); + Assert.Equal(10, activity.GetTagItem("gen_ai.response.input_tokens")); + Assert.Equal(20, activity.GetTagItem("gen_ai.response.output_tokens")); + + Assert.Collection(activity.Events, + evt => + { + Assert.Equal("gen_ai.content.prompt", evt.Name); + Assert.Equal("""[{"role": "user", "content": "What\u0027s the biggest animal?"}]""", evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.prompt").Value); + }, + evt => + { + Assert.Equal("gen_ai.content.completion", evt.Name); + Assert.Contains("whale", (string)evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.completion").Value!); + }); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } + + [Fact] + public async Task ExpectedInformationLogged_Streaming_Async() + { + var sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + async static IAsyncEnumerable CallbackAsync( + IList messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + await Task.Yield(); + yield return new StreamingChatCompletionUpdate + { + Role = ChatRole.Assistant, + Text = "blue ", + CompletionId = "id123", + }; + await Task.Yield(); + yield return new StreamingChatCompletionUpdate + { + Role = ChatRole.Assistant, + Text = "whale", + FinishReason = ChatFinishReason.Stop, + }; + yield return new StreamingChatCompletionUpdate + { + Contents = [new UsageContent(new() + { + InputTokenCount = 10, + OutputTokenCount = 20, + TotalTokenCount = 42, + })], + }; + } + + using var innerClient = new TestChatClient + { + Metadata = new("testservice", new Uri("http://localhost:12345/something"), "amazingmodel"), + CompleteStreamingAsyncCallback = CallbackAsync, + }; + + var chatClient = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, instance => + { + instance.EnableSensitiveData = true; + instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(innerClient); + + await foreach (var update in chatClient.CompleteStreamingAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions + { + FrequencyPenalty = 3.0f, + MaxOutputTokens = 123, + ModelId = "replacementmodel", + TopP = 4.0f, + PresencePenalty = 5.0f, + ResponseFormat = ChatResponseFormat.Json, + Temperature = 6.0f, + StopSequences = ["hello", "world"], + AdditionalProperties = new() { ["top_k"] = 7.0 }, + })) + { + // Drain the stream. + } + + var activity = Assert.Single(activities); + + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + + Assert.Equal("http://localhost:12345/something", activity.GetTagItem("server.address")); + Assert.Equal(12345, (int)activity.GetTagItem("server.port")!); + + Assert.Equal("chat.completions replacementmodel", activity.DisplayName); + Assert.Equal("testservice", activity.GetTagItem("gen_ai.system")); + + Assert.Equal("replacementmodel", activity.GetTagItem("gen_ai.request.model")); + Assert.Equal(3.0f, activity.GetTagItem("gen_ai.request.frequency_penalty")); + Assert.Equal(4.0f, activity.GetTagItem("gen_ai.request.top_p")); + Assert.Equal(5.0f, activity.GetTagItem("gen_ai.request.presence_penalty")); + Assert.Equal(6.0f, activity.GetTagItem("gen_ai.request.temperature")); + Assert.Equal(7.0, activity.GetTagItem("gen_ai.request.top_k")); + Assert.Equal(123, activity.GetTagItem("gen_ai.request.max_tokens")); + Assert.Equal("""["hello", "world"]""", activity.GetTagItem("gen_ai.request.stop_sequences")); + + Assert.Equal("id123", activity.GetTagItem("gen_ai.response.id")); + Assert.Equal("""["stop"]""", activity.GetTagItem("gen_ai.response.finish_reasons")); + Assert.Equal(10, activity.GetTagItem("gen_ai.response.input_tokens")); + Assert.Equal(20, activity.GetTagItem("gen_ai.response.output_tokens")); + + Assert.Collection(activity.Events, + evt => + { + Assert.Equal("gen_ai.content.prompt", evt.Name); + Assert.Equal("""[{"role": "user", "content": "What\u0027s the biggest animal?"}]""", evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.prompt").Value); + }, + evt => + { + Assert.Equal("gen_ai.content.completion", evt.Name); + Assert.Contains("whale", (string)evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.completion").Value!); + }); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ScopedChatClientExtensions.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ScopedChatClientExtensions.cs new file mode 100644 index 00000000000..d9ad92dc266 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ScopedChatClientExtensions.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public static class ScopedChatClientExtensions +{ + public static ChatClientBuilder UseScopedMiddleware(this ChatClientBuilder builder) + => builder.Use((services, inner) + => new DependencyInjectionPatterns.ScopedChatClient(services, inner)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs new file mode 100644 index 00000000000..2b4370222c6 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs @@ -0,0 +1,348 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DistributedCachingEmbeddingGeneratorTest +{ + private readonly TestInMemoryCacheStorage _storage = new(); + private readonly Embedding _expectedEmbedding = new(new float[] { 1.0f, 2.0f, 3.0f }) + { + CreatedAt = DateTimeOffset.Parse("2024-08-01T00:00:00Z"), + ModelId = "someModel", + AdditionalProperties = new() { ["a"] = "b" }, + }; + + [Fact] + public async Task CachesSuccessResultsAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + var innerCallCount = 0; + using var testGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + innerCallCount++; + return Task.FromResult>>([_expectedEmbedding]); + }, + }; + using var outer = new DistributedCachingEmbeddingGenerator>(testGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Make the initial request and do a quick sanity check + var result1 = await outer.GenerateAsync("abc"); + Assert.Single(result1); + AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = await outer.GenerateAsync("abc"); + + // Assert + Assert.Single(result2); + Assert.Equal(1, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + + // Act/Assert 2: Cache misses do not return cached results + await outer.GenerateAsync(["def"]); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task SupportsPartiallyCachedBatchesAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + var innerCallCount = 0; + Embedding[] expected = Enumerable.Range(0, 10).Select(i => + new Embedding(new[] { 1.0f, 2.0f, 3.0f }) + { + CreatedAt = DateTimeOffset.Parse("2024-08-01T00:00:00Z") + TimeSpan.FromHours(i), + ModelId = $"someModel{i}", + AdditionalProperties = new() { [$"a{i}"] = $"b{i}" }, + }).ToArray(); + using var testGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + innerCallCount++; + Assert.Equal(innerCallCount == 1 ? 4 : 6, values.Count()); + return Task.FromResult>>(new(values.Select(i => expected[int.Parse(i)]))); + }, + }; + using var outer = new DistributedCachingEmbeddingGenerator>(testGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Make initial requests for some of the values + var results = await outer.GenerateAsync(["0", "4", "5", "8"]); + Assert.Equal(1, innerCallCount); + Assert.Equal(4, results.Count); + AssertEmbeddingsEqual(expected[0], results[0]); + AssertEmbeddingsEqual(expected[4], results[1]); + AssertEmbeddingsEqual(expected[5], results[2]); + AssertEmbeddingsEqual(expected[8], results[3]); + + // Act/Assert + results = await outer.GenerateAsync(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]); + Assert.Equal(2, innerCallCount); + for (int i = 0; i < 10; i++) + { + AssertEmbeddingsEqual(expected[i], results[i]); + } + + results = await outer.GenerateAsync(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]); + Assert.Equal(2, innerCallCount); + for (int i = 0; i < 10; i++) + { + AssertEmbeddingsEqual(expected[i], results[i]); + } + } + + [Fact] + public async Task AllowsConcurrentCallsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + await completionTcs.Task; + return [_expectedEmbedding]; + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Act 1: Concurrent calls before resolution are passed into the inner client + var result1 = outer.GenerateAsync("abc"); + var result2 = outer.GenerateAsync("abc"); + + // Assert 1 + Assert.Equal(2, innerCallCount); + Assert.False(result1.IsCompleted); + Assert.False(result2.IsCompleted); + completionTcs.SetResult(true); + AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]); + AssertEmbeddingsEqual(_expectedEmbedding, (await result2)[0]); + + // Act 2: Subsequent calls after completion are resolved from the cache + var result3 = await outer.GenerateAsync("abc"); + Assert.Equal(2, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]); + } + + [Fact] + public async Task DoesNotCacheExceptionResultsAsync() + { + // Arrange + var innerCallCount = 0; + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (value, options, cancellationToken) => + { + innerCallCount++; + throw new InvalidTimeZoneException("some failure"); + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + var ex1 = await Assert.ThrowsAsync(() => outer.GenerateAsync("abc")); + Assert.Equal("some failure", ex1.Message); + Assert.Equal(1, innerCallCount); + + // Act + var ex2 = await Assert.ThrowsAsync(() => outer.GenerateAsync("abc")); + + // Assert + Assert.NotSame(ex1, ex2); + Assert.Equal("some failure", ex2.Message); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task DoesNotCacheCanceledResultsAsync() + { + // Arrange + var innerCallCount = 0; + var resolutionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + if (innerCallCount == 1) + { + await resolutionTcs.Task; + } + + return [_expectedEmbedding]; + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // First call gets cancelled + var result1 = outer.GenerateAsync("abc"); + Assert.False(result1.IsCompleted); + Assert.Equal(1, innerCallCount); + resolutionTcs.SetCanceled(); + await Assert.ThrowsAnyAsync(() => result1); + Assert.True(result1.IsCanceled); + + // Act/Assert: Second call can succeed + var result2 = await outer.GenerateAsync("abc"); + Assert.Single(result2); + Assert.Equal(2, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + } + + [Fact] + public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + await Task.Yield(); + return [_expectedEmbedding]; + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Act: Call with two different options + var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 1" } + }); + var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 2" } + }); + + // Assert: Same result + Assert.Single(result1); + Assert.Single(result2); + Assert.Equal(1, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + } + + [Fact] + public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + await Task.Yield(); + return [_expectedEmbedding]; + } + }; + using var outer = new CachingEmbeddingGeneratorWithCustomKey(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Act: Call with two different options + var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 1" } + }); + var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 2" } + }); + + // Assert: Different results + Assert.Single(result1); + Assert.Single(result2); + Assert.Equal(2, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + } + + [Fact] + public async Task CanResolveIDistributedCacheFromDI() + { + // Arrange + var services = new ServiceCollection() + .AddSingleton(_storage) + .BuildServiceProvider(); + using var testGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + return Task.FromResult>>([_expectedEmbedding]); + }, + }; + using var outer = new EmbeddingGeneratorBuilder>(services) + .UseDistributedCache(configure: instance => + { + instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(testGenerator); + + // Act: Make a request that should populate the cache + Assert.Empty(_storage.Keys); + var result = await outer.GenerateAsync("abc"); + + // Assert + Assert.NotNull(result); + Assert.Single(_storage.Keys); + } + + private static void AssertEmbeddingsEqual(Embedding expected, Embedding actual) + { + Assert.Equal(expected.CreatedAt, actual.CreatedAt); + Assert.Equal(expected.ModelId, actual.ModelId); + Assert.Equal(expected.Vector.ToArray(), actual.Vector.ToArray()); + Assert.Equal( + JsonSerializer.Serialize(expected.AdditionalProperties, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actual.AdditionalProperties, TestJsonSerializerContext.Default.Options)); + } + + private sealed class CachingEmbeddingGeneratorWithCustomKey(IEmbeddingGenerator> innerGenerator, IDistributedCache storage) + : DistributedCachingEmbeddingGenerator>(innerGenerator, storage) + { + protected override string GetCacheKey(string value, EmbeddingGenerationOptions? options) => + base.GetCacheKey(value, options) + options?.AdditionalProperties?["someKey"]?.ToString(); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/EmbeddingGeneratorBuilderTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/EmbeddingGeneratorBuilderTests.cs new file mode 100644 index 00000000000..357168c3b65 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/EmbeddingGeneratorBuilderTests.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGeneratorBuilderTests +{ + [Fact] + public void PassesServiceProviderToFactories() + { + var expectedServiceProvider = new ServiceCollection().BuildServiceProvider(); + using var expectedResult = new TestEmbeddingGenerator(); + var builder = new EmbeddingGeneratorBuilder>(expectedServiceProvider); + + builder.Use((serviceProvider, innerClient) => + { + Assert.Same(expectedServiceProvider, serviceProvider); + return expectedResult; + }); + + using var innerGenerator = new TestEmbeddingGenerator(); + Assert.Equal(expectedResult, builder.Use(innerGenerator)); + } + + [Fact] + public void BuildsPipelineInOrderAdded() + { + // Arrange + using var expectedInnerService = new TestEmbeddingGenerator(); + var builder = new EmbeddingGeneratorBuilder>(); + + builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("First", next)); + builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("Second", next)); + builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("Third", next)); + + // Act + var first = (InnerServiceCapturingEmbeddingGenerator)builder.Use(expectedInnerService); + + // Assert + Assert.Equal("First", first.Name); + var second = (InnerServiceCapturingEmbeddingGenerator)first.InnerGenerator; + Assert.Equal("Second", second.Name); + var third = (InnerServiceCapturingEmbeddingGenerator)second.InnerGenerator; + Assert.Equal("Third", third.Name); + Assert.Same(expectedInnerService, third.InnerGenerator); + } + + [Fact] + public void DoesNotAcceptNullInnerService() + { + Assert.Throws(() => new EmbeddingGeneratorBuilder>().Use((IEmbeddingGenerator>)null!)); + } + + [Fact] + public void DoesNotAcceptNullFactories() + { + var builder = new EmbeddingGeneratorBuilder>(); + Assert.Throws(() => builder.Use((Func>, IEmbeddingGenerator>>)null!)); + Assert.Throws(() => builder.Use((Func>, IEmbeddingGenerator>>)null!)); + } + + [Fact] + public void DoesNotAllowFactoriesToReturnNull() + { + var builder = new EmbeddingGeneratorBuilder>(); + builder.Use(_ => null!); + var ex = Assert.Throws(() => builder.Use(new TestEmbeddingGenerator())); + Assert.Contains("entry at index 0", ex.Message); + } + + private sealed class InnerServiceCapturingEmbeddingGenerator(string name, IEmbeddingGenerator> innerGenerator) : + DelegatingEmbeddingGenerator>(innerGenerator) + { +#pragma warning disable S3604 // False positive: Member initializer values should not be redundant + public string Name { get; } = name; +#pragma warning restore S3604 + public new IEmbeddingGenerator> InnerGenerator => base.InnerGenerator; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/LoggingEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/LoggingEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..e231e8995fe --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/LoggingEmbeddingGeneratorTests.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class LoggingEmbeddingGeneratorTests +{ + [Fact] + public void LoggingEmbeddingGenerator_InvalidArgs_Throws() + { + Assert.Throws("innerGenerator", () => new LoggingEmbeddingGenerator>(null!, NullLogger.Instance)); + Assert.Throws("logger", () => new LoggingEmbeddingGenerator>(new TestEmbeddingGenerator(), null!)); + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level) + { + using CapturingLoggerProvider clp = new(); + + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + using IEmbeddingGenerator> innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + return Task.FromResult(new GeneratedEmbeddings>([new Embedding(new float[] { 1f, 2f, 3f })])); + }, + }; + + using IEmbeddingGenerator> generator = new EmbeddingGeneratorBuilder>(services) + .UseLogging() + .Use(innerGenerator); + + await generator.GenerateAsync("Blue whale"); + + if (level is LogLevel.Trace) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("GenerateAsync invoked:") && entry.Message.Contains("Blue whale")), + entry => Assert.Contains("GenerateAsync generated 1 embedding(s).", entry.Message)); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("GenerateAsync invoked.") && !entry.Message.Contains("Blue whale")), + entry => Assert.Contains("GenerateAsync generated 1 embedding(s).", entry.Message)); + } + else + { + Assert.Empty(clp.Logger.Entries); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs new file mode 100644 index 00000000000..41ed51cd2a2 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -0,0 +1,186 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionFactoryTest +{ + [Fact] + public void InvalidArguments_Throw() + { + Delegate nullDelegate = null!; + Assert.Throws(() => AIFunctionFactory.Create(nullDelegate)); + Assert.Throws(() => AIFunctionFactory.Create((MethodInfo)null!)); + Assert.Throws(() => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, null)); + Assert.Throws(() => AIFunctionFactory.Create(typeof(List<>).GetMethod("Add")!, new List())); + } + + [Fact] + public async Task Parameters_MappedByName_Async() + { + AIFunction func; + + func = AIFunctionFactory.Create((string a) => a + " " + a); + AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync([new KeyValuePair("a", "test")])); + + func = AIFunctionFactory.Create((string a, string b) => b + " " + a); + AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync([new KeyValuePair("b", "hello"), new KeyValuePair("a", "world")])); + + func = AIFunctionFactory.Create((int a, long b) => a + b); + AssertExtensions.EqualFunctionCallResults(3L, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)])); + } + + [Fact] + public async Task Parameters_DefaultValuesAreUsedButOverridable_Async() + { + AIFunction func = AIFunctionFactory.Create((string a = "test") => a + " " + a); + AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync()); + AssertExtensions.EqualFunctionCallResults("hello hello", await func.InvokeAsync([new KeyValuePair("a", "hello")])); + } + + [Fact] + public async Task Parameters_AIFunctionContextMappedByType_Async() + { + using var cts = new CancellationTokenSource(); + CancellationToken written; + AIFunction func; + + // As the only parameter + written = default; + func = AIFunctionFactory.Create((AIFunctionContext ctx) => + { + Assert.NotNull(ctx); + written = ctx.CancellationToken; + }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(cancellationToken: cts.Token)); + Assert.Equal(cts.Token, written); + + // As the last + written = default; + func = AIFunctionFactory.Create((int somethingFirst, AIFunctionContext ctx) => + { + Assert.NotNull(ctx); + written = ctx.CancellationToken; + }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new Dictionary { ["somethingFirst"] = 1, ["ctx"] = new AIFunctionContext() }, cts.Token)); + Assert.Equal(cts.Token, written); + + // As the first + written = default; + func = AIFunctionFactory.Create((AIFunctionContext ctx, int somethingAfter = 0) => + { + Assert.NotNull(ctx); + written = ctx.CancellationToken; + }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(cancellationToken: cts.Token)); + Assert.Equal(cts.Token, written); + } + + [Fact] + public async Task Returns_AsyncReturnTypesSupported_Async() + { + AIFunction func; + + func = AIFunctionFactory.Create(Task (string a) => Task.FromResult(a + " " + a)); + AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync([new KeyValuePair("a", "test")])); + + func = AIFunctionFactory.Create(ValueTask (string a, string b) => new ValueTask(b + " " + a)); + AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync([new KeyValuePair("b", "hello"), new KeyValuePair("a", "world")])); + + long result = 0; + func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)])); + Assert.Equal(3, result); + + result = 0; + func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)])); + Assert.Equal(3, result); + + func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count)); + AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync([new("count", 5)])); + + static async IAsyncEnumerable SimpleIAsyncEnumerable(int count) + { + for (int i = 0; i < count; i++) + { + await Task.Yield(); + yield return i; + } + } + + func = AIFunctionFactory.Create(() => (IAsyncEnumerable)new ThrowingAsyncEnumerable()); + await Assert.ThrowsAsync(() => func.InvokeAsync()); + } + + private sealed class ThrowingAsyncEnumerable : IAsyncEnumerable + { +#pragma warning disable S3717 // Track use of "NotImplementedException" + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => throw new NotImplementedException(); +#pragma warning restore S3717 // Track use of "NotImplementedException" + } + + [Fact] + public void Metadata_DerivedFromLambda() + { + AIFunction func; + + func = AIFunctionFactory.Create(() => "test"); + Assert.Contains("Metadata_DerivedFromLambda", func.Metadata.Name); + Assert.Empty(func.Metadata.Description); + Assert.Empty(func.Metadata.Parameters); + Assert.Equal(typeof(string), func.Metadata.ReturnParameter.ParameterType); + + func = AIFunctionFactory.Create((string a) => a + " " + a); + Assert.Contains("Metadata_DerivedFromLambda", func.Metadata.Name); + Assert.Empty(func.Metadata.Description); + Assert.Single(func.Metadata.Parameters); + + func = AIFunctionFactory.Create( + [Description("This is a test function")] ([Description("This is A")] string a, [Description("This is B")] string b) => b + " " + a); + Assert.Contains("Metadata_DerivedFromLambda", func.Metadata.Name); + Assert.Equal("This is a test function", func.Metadata.Description); + Assert.Collection(func.Metadata.Parameters, + p => Assert.Equal("This is A", p.Description), + p => Assert.Equal("This is B", p.Description)); + } + + [Fact] + public void AIFunctionFactoryCreateOptions_ValuesPropagateToAIFunction() + { + IReadOnlyList parameterMetadata = [new AIFunctionParameterMetadata("a")]; + AIFunctionReturnParameterMetadata returnParameterMetadata = new() { ParameterType = typeof(string) }; + IReadOnlyDictionary metadata = new Dictionary { ["a"] = "b" }; + + var options = new AIFunctionFactoryCreateOptions + { + Name = "test name", + Description = "test description", + Parameters = parameterMetadata, + ReturnParameter = returnParameterMetadata, + AdditionalProperties = metadata, + }; + + Assert.Equal("test name", options.Name); + Assert.Equal("test description", options.Description); + Assert.Same(parameterMetadata, options.Parameters); + Assert.Same(returnParameterMetadata, options.ReturnParameter); + Assert.Same(metadata, options.AdditionalProperties); + + AIFunction func = AIFunctionFactory.Create(() => { }, options); + + Assert.Equal("test name", func.Metadata.Name); + Assert.Equal("test description", func.Metadata.Description); + Assert.Equal(parameterMetadata, func.Metadata.Parameters); + Assert.Equal(returnParameterMetadata, func.Metadata.ReturnParameter); + Assert.Equal(metadata, func.Metadata.AdditionalProperties); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj new file mode 100644 index 00000000000..b3d5e8048f5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj @@ -0,0 +1,32 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI. + + + + $(NoWarn);CA1063;CA1861;SA1130;VSTHRD003 + true + + + + true + + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/TestInMemoryCacheStorage.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/TestInMemoryCacheStorage.cs new file mode 100644 index 00000000000..8ab2cd0cbb0 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/TestInMemoryCacheStorage.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; + +namespace Microsoft.Extensions.AI; + +internal sealed class TestInMemoryCacheStorage : IDistributedCache +{ + private readonly ConcurrentDictionary _storage = new(); + + public ICollection Keys => _storage.Keys; + + public byte[]? Get(string key) + => _storage.TryGetValue(key, out var value) ? value : null; + + public Task GetAsync(string key, CancellationToken token = default) + => Task.FromResult(Get(key)); + + public void Refresh(string key) + { + // In memory, nothing to refresh + } + + public Task RefreshAsync(string key, CancellationToken token = default) + => Task.CompletedTask; + + public void Remove(string key) + => _storage.TryRemove(key, out _); + + public Task RemoveAsync(string key, CancellationToken token = default) + { + Remove(key); + return Task.CompletedTask; + } + + public void Set(string key, byte[] value, DistributedCacheEntryOptions options) + { + _storage[key] = value; + } + + public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default) + { + Set(key, value, options); + return Task.CompletedTask; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs new file mode 100644 index 00000000000..e376da86dad --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +// These types are directly serialized by DistributedCachingChatClient +[JsonSerializable(typeof(ChatCompletion))] +[JsonSerializable(typeof(IList))] +[JsonSerializable(typeof(IReadOnlyList))] + +// These types are specific to the tests in this project +[JsonSerializable(typeof(bool))] +[JsonSerializable(typeof(double))] +[JsonSerializable(typeof(JsonElement))] +[JsonSerializable(typeof(Embedding))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(DayOfWeek[]))] +[JsonSerializable(typeof(Guid))] +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs b/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs index a27876703e7..e007d95860a 100644 --- a/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs +++ b/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs @@ -24,6 +24,7 @@ protected override IXunitTestCase CreateTestCase(ITestFrameworkDiscoveryOptions var skipReason = testMethod.EvaluateSkipConditions(); return skipReason != null ? new SkippedTestCase(skipReason, _diagnosticMessageSink, discoveryOptions.MethodDisplayOrDefault(), TestMethodDisplayOptions.None, testMethod) - : base.CreateTestCase(discoveryOptions, testMethod, factAttribute); + : new SkippedFactTestCase(DiagnosticMessageSink, discoveryOptions.MethodDisplayOrDefault(), + discoveryOptions.MethodDisplayOptionsOrDefault(), testMethod); // Test case skippable at runtime. } } diff --git a/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs b/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs index 846038f8786..b1e53b8ed77 100644 --- a/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs +++ b/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs @@ -3,7 +3,6 @@ // Borrowed from https://github.com/dotnet/aspnetcore/blob/95ed45c67/src/Testing/src/xunit/ -using System; using System.Collections.Generic; using Xunit.Abstractions; using Xunit.Sdk; diff --git a/test/TestUtilities/XUnit/SkipTestException.cs b/test/TestUtilities/XUnit/SkipTestException.cs new file mode 100644 index 00000000000..70f7d53c7d8 --- /dev/null +++ b/test/TestUtilities/XUnit/SkipTestException.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Borrowed from https://github.com/dotnet/aspnetcore/blob/95ed45c67/src/Testing/src/xunit/ + +using System; + +namespace Microsoft.TestUtilities; + +public class SkipTestException : Exception +{ + public SkipTestException(string reason) + : base(reason) + { + } +} diff --git a/test/TestUtilities/XUnit/SkippedFactTestCase.cs b/test/TestUtilities/XUnit/SkippedFactTestCase.cs new file mode 100644 index 00000000000..79ace15ea6e --- /dev/null +++ b/test/TestUtilities/XUnit/SkippedFactTestCase.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace Microsoft.TestUtilities; + +public class SkippedFactTestCase : XunitTestCase +{ + [Obsolete("Called by the de-serializer; should only be called by deriving classes for de-serialization purposes", error: true)] + public SkippedFactTestCase() + { + } + + public SkippedFactTestCase( + IMessageSink diagnosticMessageSink, TestMethodDisplay defaultMethodDisplay, TestMethodDisplayOptions defaultMethodDisplayOptions, + ITestMethod testMethod, object[]? testMethodArguments = null) + : base(diagnosticMessageSink, defaultMethodDisplay, defaultMethodDisplayOptions, testMethod, testMethodArguments) + { + } + + public override async Task RunAsync(IMessageSink diagnosticMessageSink, + IMessageBus messageBus, + object[] constructorArguments, + ExceptionAggregator aggregator, + CancellationTokenSource cancellationTokenSource) + { + using SkippedTestMessageBus skipMessageBus = new(messageBus); + var result = await base.RunAsync(diagnosticMessageSink, skipMessageBus, constructorArguments, aggregator, cancellationTokenSource); + if (skipMessageBus.SkippedTestCount > 0) + { + result.Failed -= skipMessageBus.SkippedTestCount; + result.Skipped += skipMessageBus.SkippedTestCount; + } + + return result; + } +} diff --git a/test/TestUtilities/XUnit/SkippedTestMessageBus.cs b/test/TestUtilities/XUnit/SkippedTestMessageBus.cs new file mode 100644 index 00000000000..230586852b8 --- /dev/null +++ b/test/TestUtilities/XUnit/SkippedTestMessageBus.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace Microsoft.TestUtilities; + +/// Implements message bus to communicate tests skipped via SkipTestException. +public sealed class SkippedTestMessageBus : IMessageBus +{ + private readonly IMessageBus _innerBus; + + public SkippedTestMessageBus(IMessageBus innerBus) + { + _innerBus = innerBus; + } + + public int SkippedTestCount { get; private set; } + + public void Dispose() + { + // nothing to dispose + } + + public bool QueueMessage(IMessageSinkMessage message) + { + var testFailed = message as ITestFailed; + + if (testFailed != null) + { + var exceptionType = testFailed.ExceptionTypes.FirstOrDefault(); + if (exceptionType == typeof(SkipTestException).FullName) + { + SkippedTestCount++; + return _innerBus.QueueMessage(new TestSkipped(testFailed.Test, testFailed.Messages.FirstOrDefault())); + } + } + + // Nothing we care about, send it on its way + return _innerBus.QueueMessage(message); + } +}