From 92ae19e87a17affb8f96edddf9d7d4b7ad6ce690 Mon Sep 17 00:00:00 2001 From: Jose Perez Rodriguez Date: Wed, 2 Oct 2024 23:25:23 +0000 Subject: [PATCH 1/3] Merged PR 43283: Using RC2 aspnetcore and runtime versions Getting ready for our 9.0-preview9 release ---- #### AI description (iteration 1) #### PR Classification Dependency update #### PR Summary This pull request updates the project to use RC2 versions of ASP.NET Core and runtime dependencies. - Updated dependency versions in `/eng/Version.Details.xml` and `/eng/Versions.props` to RC2. - Removed code coverage stage from `azure-pipelines.yml`. - Added setup for private feeds credentials in `/eng/pipelines/templates/BuildAndTest.yml`. - Disabled NU1507 warning in `Directory.Build.props`. - Modified `NuGet.config` to include `dotnet9-internal` feed and removed package source mappings. --- Directory.Build.props | 5 + NuGet.config | 36 +--- azure-pipelines.yml | 48 +---- eng/Version.Details.xml | 246 +++++++++++------------ eng/Versions.props | 82 ++++---- eng/pipelines/templates/BuildAndTest.yml | 18 ++ global.json | 4 +- 7 files changed, 192 insertions(+), 247 deletions(-) diff --git a/Directory.Build.props b/Directory.Build.props index fef9d781511..19dbc088120 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -35,6 +35,11 @@ $(NetCoreTargetFrameworks) + + + $(NoWarn);NU1507 + + false latest diff --git a/NuGet.config b/NuGet.config index f91233ccab5..aba191afbbc 100644 --- a/NuGet.config +++ b/NuGet.config @@ -3,6 +3,7 @@ + @@ -15,41 +16,8 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 211058cf56a..f674e637cea 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -143,7 +143,7 @@ extends: parameters: enableMicrobuild: true enableTelemetry: true - enableSourceIndex: true + enableSourceIndex: false runAsPublic: ${{ variables['runAsPublic'] }} # Publish build logs enablePublishBuildArtifacts: true @@ -220,51 +220,6 @@ extends: isWindows: false warnAsError: 0 - # ---------------------------------------------------------------- - # This stage performs quality gates enforcements - # ---------------------------------------------------------------- - - stage: codecoverage - displayName: CodeCoverage - dependsOn: - - build - condition: and(succeeded('build'), ne(variables['SkipQualityGates'], 'true')) - variables: - - template: /eng/common/templates-official/variables/pool-providers.yml@self - jobs: - - template: /eng/common/templates-official/jobs/jobs.yml@self - parameters: - enableMicrobuild: true - enableTelemetry: true - runAsPublic: ${{ variables['runAsPublic'] }} - workspace: - clean: all - - # ---------------------------------------------------------------- - # This stage downloads the code coverage reports from the build jobs, - # merges those and validates the combined test coverage. - # ---------------------------------------------------------------- - jobs: - - job: CodeCoverageReport - timeoutInMinutes: 180 - - pool: - name: NetCore1ESPool-Internal - image: 1es-mariner-2 - os: linux - - preSteps: - - checkout: self - clean: true - persistCredentials: true - fetchDepth: 1 - - steps: - - script: $(Build.SourcesDirectory)/build.sh --ci --restore - displayName: Init toolset - - - template: /eng/pipelines/templates/VerifyCoverageReport.yml - - # ---------------------------------------------------------------- # This stage only performs a build treating warnings as errors # to detect any kind of code style violations @@ -320,7 +275,6 @@ extends: parameters: validateDependsOn: - build - - codecoverage - correctness publishingInfraVersion: 3 enableSymbolValidation: false diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index e4dcd0226ff..e2498d1c564 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -1,172 +1,172 @@ - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/runtime - 0d44aea3696bab80b11a12c6bdfdbf8de9c4e815 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 - - https://github.com/dotnet/aspnetcore - 9a34a6e3c7975f41300bd2550a089a85810cafd1 + + https://dev.azure.com/dnceng/internal/_git/dotnet-aspnetcore + c70204ae3c91d2b48fa6d9b92b62265f368421b4 diff --git a/eng/Versions.props b/eng/Versions.props index e5209b4f90b..186e160151a 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -28,48 +28,48 @@ --> - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 - 9.0.0-rtm.24476.4 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 - 9.0.0-rtm.24477.5 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + 9.0.0-rc.2.24474.3 + true + $(NoWarn);LA0003 + + - + diff --git a/eng/MSBuild/Shared.props b/eng/MSBuild/Shared.props index a68b0e4298f..7c5ac8424e0 100644 --- a/eng/MSBuild/Shared.props +++ b/eng/MSBuild/Shared.props @@ -1,4 +1,8 @@ + + + + diff --git a/eng/Version.Details.xml b/eng/Version.Details.xml index e2498d1c564..3b65583f912 100644 --- a/eng/Version.Details.xml +++ b/eng/Version.Details.xml @@ -1,5 +1,9 @@ + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime 990ebf52fc408ca45929fd176d2740675a67fab8 @@ -112,6 +116,18 @@ https://dev.azure.com/dnceng/internal/_git/dotnet-runtime 990ebf52fc408ca45929fd176d2740675a67fab8 + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + + + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime + 990ebf52fc408ca45929fd176d2740675a67fab8 + https://dev.azure.com/dnceng/internal/_git/dotnet-runtime 990ebf52fc408ca45929fd176d2740675a67fab8 diff --git a/eng/Versions.props b/eng/Versions.props index 186e160151a..68ddd98d120 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -28,6 +28,7 @@ --> + 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 @@ -55,11 +56,14 @@ 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 9.0.0-rc.2.24473.5 + 9.0.0-rc.2.24473.5 9.0.0-rc.2.24474.3 9.0.0-rc.2.24474.3 diff --git a/eng/packages/General.props b/eng/packages/General.props index b7e3259930f..ce9c0579971 100644 --- a/eng/packages/General.props +++ b/eng/packages/General.props @@ -1,7 +1,9 @@ + + @@ -33,6 +35,7 @@ + @@ -47,9 +50,12 @@ + + + diff --git a/eng/packages/TestOnly.props b/eng/packages/TestOnly.props index 9e9fefae39d..2bde3b34e05 100644 --- a/eng/packages/TestOnly.props +++ b/eng/packages/TestOnly.props @@ -2,17 +2,21 @@ + + + + diff --git a/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs b/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs new file mode 100644 index 00000000000..b979931673c --- /dev/null +++ b/src/LegacySupport/CompilerFeatureRequiredAttribute/CompilerFeatureRequiredAttribute.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable SA1623 // Property summary documentation should match accessors + +namespace System.Runtime.CompilerServices; + +/// +/// Indicates that compiler support for a particular feature is required for the location where this attribute is applied. +/// +[AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)] +internal sealed class CompilerFeatureRequiredAttribute : Attribute +{ + public CompilerFeatureRequiredAttribute(string featureName) + { + FeatureName = featureName; + } + + /// + /// The name of the compiler feature. + /// + public string FeatureName { get; } + + /// + /// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand . + /// + public bool IsOptional { get; init; } + + /// + /// The used for the ref structs C# feature. + /// + public const string RefStructs = nameof(RefStructs); + + /// + /// The used for the required members C# feature. + /// + public const string RequiredMembers = nameof(RequiredMembers); +} diff --git a/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md b/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md new file mode 100644 index 00000000000..c30799eef0b --- /dev/null +++ b/src/LegacySupport/CompilerFeatureRequiredAttribute/README.md @@ -0,0 +1,9 @@ +Enables use of C# required members on older frameworks. + +To use this source in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/src/LegacySupport/RequiredMemberAttribute/README.md b/src/LegacySupport/RequiredMemberAttribute/README.md new file mode 100644 index 00000000000..da8c9bc98ce --- /dev/null +++ b/src/LegacySupport/RequiredMemberAttribute/README.md @@ -0,0 +1,9 @@ +Enables use of C# required members on older frameworks. + +To use this source in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs b/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs new file mode 100644 index 00000000000..a83785b9655 --- /dev/null +++ b/src/LegacySupport/RequiredMemberAttribute/RequiredMemberAttribute.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel; + +namespace System.Runtime.CompilerServices; + +/// Specifies that a type has required members or that a member is required. +[AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)] +[EditorBrowsable(EditorBrowsableState.Never)] +internal sealed class RequiredMemberAttribute : Attribute; diff --git a/src/LegacySupport/TrimAttributes/RequiresDynamicCodeAttribute.cs b/src/LegacySupport/TrimAttributes/RequiresDynamicCodeAttribute.cs new file mode 100644 index 00000000000..072701f1a46 --- /dev/null +++ b/src/LegacySupport/TrimAttributes/RequiresDynamicCodeAttribute.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable IDE0079 +#pragma warning disable SA1101 +#pragma warning disable SA1116 +#pragma warning disable SA1117 +#pragma warning disable SA1512 +#pragma warning disable SA1623 +#pragma warning disable SA1642 +#pragma warning disable S3903 +#pragma warning disable S3996 + +namespace System.Diagnostics.CodeAnalysis; + +/// +/// Indicates that the specified method requires the ability to generate new code at runtime, +/// for example through . +/// +/// +/// This allows tools to understand which methods are unsafe to call when compiling ahead of time. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Class, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class RequiresDynamicCodeAttribute : Attribute +{ + /// + /// Initializes a new instance of the class + /// with the specified message. + /// + /// + /// A message that contains information about the usage of dynamic code. + /// + public RequiresDynamicCodeAttribute(string message) + { + Message = message; + } + + /// + /// Gets a message that contains information about the usage of dynamic code. + /// + public string Message { get; } + + /// + /// Gets or sets an optional URL that contains more information about the method, + /// why it requires dynamic code, and what options a consumer has to deal with it. + /// + public string? Url { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AITool.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AITool.cs new file mode 100644 index 00000000000..0cdcd60e63e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AITool.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +/// Represents a tool that may be specified to an AI service. +public class AITool +{ + /// Initializes a new instance of the class. + protected AITool() + { + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs new file mode 100644 index 00000000000..5ffc76260d9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/AdditionalPropertiesDictionary.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.Extensions.AI; + +/// Provides a dictionary used as the AdditionalProperties dictionary on Microsoft.Extensions.AI objects. +public sealed class AdditionalPropertiesDictionary : IDictionary, IReadOnlyDictionary +{ + /// The underlying dictionary. + private readonly Dictionary _dictionary; + + /// Initializes a new instance of the class. + public AdditionalPropertiesDictionary() + { + _dictionary = new(StringComparer.OrdinalIgnoreCase); + } + + /// Initializes a new instance of the class. + public AdditionalPropertiesDictionary(IDictionary dictionary) + { + _dictionary = new(dictionary, StringComparer.OrdinalIgnoreCase); + } + + /// Initializes a new instance of the class. + public AdditionalPropertiesDictionary(IEnumerable> collection) + { +#if NET + _dictionary = new(collection, StringComparer.OrdinalIgnoreCase); +#else + _dictionary = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (var item in collection) + { + _dictionary.Add(item.Key, item.Value); + } +#endif + } + + /// + public object? this[string key] + { + get => _dictionary[key]; + set => _dictionary[key] = value; + } + + /// + public ICollection Keys => _dictionary.Keys; + + /// + public ICollection Values => _dictionary.Values; + + /// + public int Count => _dictionary.Count; + + /// + bool ICollection>.IsReadOnly => false; + + /// + IEnumerable IReadOnlyDictionary.Keys => _dictionary.Keys; + + /// + IEnumerable IReadOnlyDictionary.Values => _dictionary.Values; + + /// + public void Add(string key, object? value) => _dictionary.Add(key, value); + + /// + void ICollection>.Add(KeyValuePair item) => ((ICollection>)_dictionary).Add(item); + + /// + public void Clear() => _dictionary.Clear(); + + /// + bool ICollection>.Contains(KeyValuePair item) => _dictionary.Contains(item); + + /// + public bool ContainsKey(string key) => _dictionary.ContainsKey(key); + + /// + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => + ((ICollection>)_dictionary).CopyTo(array, arrayIndex); + + /// + public IEnumerator> GetEnumerator() => _dictionary.GetEnumerator(); + + /// + public bool Remove(string key) => _dictionary.Remove(key); + + /// + bool ICollection>.Remove(KeyValuePair item) => ((ICollection>)_dictionary).Remove(item); + + /// + public bool TryGetValue(string key, out object? value) => _dictionary.TryGetValue(key, out value); + + /// + IEnumerator IEnumerable.GetEnumerator() => _dictionary.GetEnumerator(); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/AutoChatToolMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/AutoChatToolMode.cs new file mode 100644 index 00000000000..d6307477296 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/AutoChatToolMode.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Indicates that an is free to select any of the available tools, or none at all. +/// +/// +/// Use to get an instance of . +/// +[DebuggerDisplay("Auto")] +public sealed class AutoChatToolMode : ChatToolMode +{ + /// Initializes a new instance of the class. + /// Use to get an instance of . + public AutoChatToolMode() + { + } // must exist in support of polymorphic deserialization of a ChatToolMode + + /// + public override bool Equals(object? obj) => obj is AutoChatToolMode; + + /// + public override int GetHashCode() => typeof(AutoChatToolMode).GetHashCode(); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs new file mode 100644 index 00000000000..944283ccd88 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientExtensions.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides a collection of static methods for extending instances. +public static class ChatClientExtensions +{ + /// Sends a user chat text message to the model and returns the response messages. + /// The chat client. + /// The text content for the chat message to send. + /// The chat options to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + public static Task CompleteAsync( + this IChatClient client, + string chatMessage, + ChatOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(client); + _ = Throw.IfNull(chatMessage); + + return client.CompleteAsync([new ChatMessage(ChatRole.User, chatMessage)], options, cancellationToken); + } + + /// Sends a user chat text message to the model and streams the response messages. + /// The chat client. + /// The text content for the chat message to send. + /// The chat options to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + public static IAsyncEnumerable CompleteStreamingAsync( + this IChatClient client, + string chatMessage, + ChatOptions? options = null, + CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(client); + _ = Throw.IfNull(chatMessage); + + return client.CompleteStreamingAsync([new ChatMessage(ChatRole.User, chatMessage)], options, cancellationToken); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientMetadata.cs new file mode 100644 index 00000000000..b98455daf2a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatClientMetadata.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +/// Provides metadata about an . +public class ChatClientMetadata +{ + /// Initializes a new instance of the class. + /// The name of the chat completion provider, if applicable. + /// The URL for accessing the chat completion provider, if applicable. + /// The id of the chat completion model used, if applicable. + public ChatClientMetadata(string? providerName = null, Uri? providerUri = null, string? modelId = null) + { + ModelId = modelId; + ProviderName = providerName; + ProviderUri = providerUri; + } + + /// Gets the name of the chat completion provider. + public string? ProviderName { get; } + + /// Gets the URL for accessing the chat completion provider. + public Uri? ProviderUri { get; } + + /// Gets the id of the model used by this chat completion provider. + /// This may be null if either the name is unknown or there are multiple possible models associated with this instance. + public string? ModelId { get; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs new file mode 100644 index 00000000000..2a9237d9b5a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatCompletion.cs @@ -0,0 +1,89 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents the result of a chat completion request. +public class ChatCompletion +{ + /// The list of choices in the completion. + private IList _choices; + + /// Initializes a new instance of the class. + /// The list of choices in the completion, one message per choice. + [JsonConstructor] + public ChatCompletion(IList choices) + { + _choices = Throw.IfNull(choices); + } + + /// Initializes a new instance of the class. + /// The chat message representing the singular choice in the completion. + public ChatCompletion(ChatMessage message) + { + _ = Throw.IfNull(message); + _choices = [message]; + } + + /// Gets or sets the list of chat completion choices. + public IList Choices + { + get => _choices; + set => _choices = Throw.IfNull(value); + } + + /// Gets the chat completion message. + /// + /// If there are multiple choices, this property returns the first choice. + /// If is empty, this will throw. Use to access all choices directly."/>. + /// + public ChatMessage Message + { + get + { + var choices = Choices; + if (choices.Count == 0) + { + throw new InvalidOperationException($"The {nameof(ChatCompletion)} instance does not contain any {nameof(ChatMessage)} choices."); + } + + return choices[0]; + } + } + + /// Gets or sets the ID of the chat completion. + public string? CompletionId { get; set; } + + /// Gets or sets the model ID using in the creation of the chat completion. + public string? ModelId { get; set; } + + /// Gets or sets a timestamp for the chat completion. + public DateTimeOffset? CreatedAt { get; set; } + + /// Gets or sets the reason for the chat completion. + public ChatFinishReason? FinishReason { get; set; } + + /// Gets or sets usage details for the chat completion. + public UsageDetails? Usage { get; set; } + + /// Gets or sets the raw representation of the chat completion from an underlying implementation. + /// + /// If a is created to represent some underlying object from another object + /// model, this property can be used to store that original object. This can be useful for debugging or + /// for enabling a consumer to access the underlying object model if needed. + /// + [JsonIgnore] + public object? RawRepresentation { get; set; } + + /// Gets or sets any additional properties associated with the chat completion. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// + public override string ToString() => + Choices is { Count: > 0 } choices ? string.Join(Environment.NewLine, choices) : string.Empty; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatFinishReason.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatFinishReason.cs new file mode 100644 index 00000000000..08a5630c51b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatFinishReason.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents the reason a chat response completed. +[JsonConverter(typeof(Converter))] +public readonly struct ChatFinishReason : IEquatable +{ + /// The finish reason value. If null because `default(ChatFinishReason)` was used, the instance will behave like . + private readonly string? _value; + + /// Initializes a new instance of the struct with a string that describes the reason. + /// The reason value. + /// is null. + /// is empty or composed entirely of whitespace. + [JsonConstructor] + public ChatFinishReason(string value) + { + _value = Throw.IfNullOrWhitespace(value); + } + + /// Gets the finish reason value. + public string Value => _value ?? Stop.Value; + + /// + public override bool Equals([NotNullWhen(true)] object? obj) => obj is ChatFinishReason other && Equals(other); + + /// + public bool Equals(ChatFinishReason other) => StringComparer.OrdinalIgnoreCase.Equals(Value, other.Value); + + /// + public override int GetHashCode() => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); + + /// + /// Compares two instances. + /// + /// Left argument of the comparison. + /// Right argument of the comparison. + /// when equal, otherwise. + public static bool operator ==(ChatFinishReason left, ChatFinishReason right) + { + return left.Equals(right); + } + + /// + /// Compares two instances. + /// + /// Left argument of the comparison. + /// Right argument of the comparison. + /// when not equal, otherwise. + public static bool operator !=(ChatFinishReason left, ChatFinishReason right) + { + return !(left == right); + } + + /// Gets the of the finish reason. + /// The of the finish reason. + public override string ToString() => Value; + + /// Gets a representing the model encountering a natural stop point or provided stop sequence. + public static ChatFinishReason Stop { get; } = new("stop"); + + /// Gets a representing the model reaching the maximum length allowed for the request and/or response (typically in terms of tokens). + public static ChatFinishReason Length { get; } = new("length"); + + /// Gets a representing the model requesting the use of a tool that was defined in the request. + public static ChatFinishReason ToolCalls { get; } = new("tool_calls"); + + /// Gets a representing the model filtering content, whether for safety, prohibited content, sensitive content, or other such issues. + public static ChatFinishReason ContentFilter { get; } = new("content_filter"); + + /// Provides a for serializing instances. + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override ChatFinishReason Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => + new(reader.GetString()!); + + /// + public override void Write(Utf8JsonWriter writer, ChatFinishReason value, JsonSerializerOptions options) => + Throw.IfNull(writer).WriteStringValue(value.Value); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs new file mode 100644 index 00000000000..4fdb138b615 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatMessage.cs @@ -0,0 +1,99 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents a chat message used by an . +public class ChatMessage +{ + private IList? _contents; + private string? _authorName; + + /// Initializes a new instance of the class. + [JsonConstructor] + public ChatMessage() + { + } + + /// Initializes a new instance of the class. + /// Role of the author of the message. + /// Content of the message. + public ChatMessage(ChatRole role, string? content) + : this(role, content is null ? [] : [new TextContent(content)]) + { + } + + /// Initializes a new instance of the class. + /// Role of the author of the message. + /// The contents for this message. + public ChatMessage( + ChatRole role, + IList contents) + { + Role = role; + _contents = Throw.IfNull(contents); + } + + /// Gets or sets the name of the author of the message. + public string? AuthorName + { + get => _authorName; + set => _authorName = string.IsNullOrWhiteSpace(value) ? null : value; + } + + /// Gets or sets the role of the author of the message. + public ChatRole Role { get; set; } = ChatRole.User; + + /// + /// Gets or sets the text of the first instance in . + /// + /// + /// If there is no instance in , then the getter returns , + /// and the setter will add a new instance with the provided value. + /// + [JsonIgnore] + public string? Text + { + get => Contents.OfType().FirstOrDefault()?.Text; + set + { + if (Contents.OfType().FirstOrDefault() is { } textContent) + { + textContent.Text = value; + } + else if (value is not null) + { + Contents.Add(new TextContent(value)); + } + } + } + + /// Gets or sets the chat message content items. + [AllowNull] + public IList Contents + { + get => _contents ??= []; + set => _contents = value; + } + + /// Gets or sets the raw representation of the chat message from an underlying implementation. + /// + /// If a is created to represent some underlying object from another object + /// model, this property can be used to store that original object. This can be useful for debugging or + /// for enabling a consumer to access the underlying object model if needed. + /// + [JsonIgnore] + public object? RawRepresentation { get; set; } + + /// Gets or sets any additional properties associated with the message. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// + public override string ToString() => Text ?? string.Empty; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs new file mode 100644 index 00000000000..21224454000 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatOptions.cs @@ -0,0 +1,95 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +/// Represents the options for a chat request. +public class ChatOptions +{ + /// Gets or sets the temperature for generating chat responses. + public float? Temperature { get; set; } + + /// Gets or sets the maximum number of tokens in the generated chat response. + public int? MaxOutputTokens { get; set; } + + /// Gets or sets the "nucleus sampling" factor (or "top p") for generating chat responses. + public float? TopP { get; set; } + + /// Gets or sets the frequency penalty for generating chat responses. + public float? FrequencyPenalty { get; set; } + + /// Gets or sets the presence penalty for generating chat responses. + public float? PresencePenalty { get; set; } + + /// + /// Gets or sets the response format for the chat request. + /// + /// + /// If null, no response format is specified and the client will use its default. + /// This may be set to to specify that the response should be unstructured text, + /// to to specify that the response should be structured JSON data, or + /// an instance of constructed with a specific JSON schema to request that the + /// response be structured JSON data according to that schema. It is up to the client implementation if or how + /// to honor the request. If the client implementation doesn't recognize the specific kind of , + /// it may be ignored. + /// + public ChatResponseFormat? ResponseFormat { get; set; } + + /// Gets or sets the model ID for the chat request. + public string? ModelId { get; set; } + + /// Gets or sets the stop sequences for generating chat responses. + public IList? StopSequences { get; set; } + + /// Gets or sets the tool mode for the chat request. + public ChatToolMode ToolMode { get; set; } = ChatToolMode.Auto; + + /// Gets or sets the list of tools to include with a chat request. + [JsonIgnore] + public IList? Tools { get; set; } + + /// Gets or sets any additional properties associated with the options. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// Produces a clone of the current instance. + /// A clone of the current instance. + /// + /// The clone will have the same values for all properties as the original instance. Any collections, like , + /// , and , are shallow-cloned, meaning a new collection instance is created, + /// but any references contained by the collections are shared with the original. + /// + public virtual ChatOptions Clone() + { + ChatOptions options = new() + { + Temperature = Temperature, + MaxOutputTokens = MaxOutputTokens, + TopP = TopP, + FrequencyPenalty = FrequencyPenalty, + PresencePenalty = PresencePenalty, + ResponseFormat = ResponseFormat, + ModelId = ModelId, + ToolMode = ToolMode, + }; + + if (StopSequences is not null) + { + options.StopSequences = new List(StopSequences); + } + + if (Tools is not null) + { + options.Tools = new List(Tools); + } + + if (AdditionalProperties is not null) + { + options.AdditionalProperties = new(AdditionalProperties); + } + + return options; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormat.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormat.cs new file mode 100644 index 00000000000..6f1574fe400 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormat.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents the response format that is desired by the caller. +[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")] +[JsonDerivedType(typeof(ChatResponseFormatText), typeDiscriminator: "text")] +[JsonDerivedType(typeof(ChatResponseFormatJson), typeDiscriminator: "json")] +#pragma warning disable CA1052 // Static holder types should be Static or NotInheritable +public class ChatResponseFormat +#pragma warning restore CA1052 +{ + /// Initializes a new instance of the class. + /// Prevents external instantiation. Close the inheritance hierarchy for now until we have good reason to open it. + private protected ChatResponseFormat() + { + } + + /// Gets a singleton instance representing unstructured textual data. + public static ChatResponseFormatText Text { get; } = new(); + + /// Gets a singleton instance representing structured JSON data but without any particular schema. + public static ChatResponseFormatJson Json { get; } = new(schema: null); + + /// Creates a representing structured JSON data with the specified schema. + /// The JSON schema. + /// An optional name of the schema, e.g. if the schema represents a particular class, this could be the name of the class. + /// An optional description of the schema. + /// The instance. + public static ChatResponseFormatJson ForJsonSchema( + [StringSyntax(StringSyntaxAttribute.Json)] string schema, string? schemaName = null, string? schemaDescription = null) => + new(Throw.IfNull(schema), + schemaName, + schemaDescription); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormatJson.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormatJson.cs new file mode 100644 index 00000000000..e26c769ca62 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormatJson.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents a response format for structured JSON data. +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public sealed class ChatResponseFormatJson : ChatResponseFormat +{ + /// Initializes a new instance of the class with the specified schema. + /// The schema to associate with the JSON response. + /// A name for the schema. + /// A description of the schema. + [JsonConstructor] + public ChatResponseFormatJson( + [StringSyntax(StringSyntaxAttribute.Json)] string? schema, string? schemaName = null, string? schemaDescription = null) + { + if (schema is null && (schemaName is not null || schemaDescription is not null)) + { + Throw.ArgumentException( + schemaName is not null ? nameof(schemaName) : nameof(schemaDescription), + "Schema name and description can only be specified if a schema is provided."); + } + + Schema = schema; + SchemaName = schemaName; + SchemaDescription = schemaDescription; + } + + /// Gets the JSON schema associated with the response, or null if there is none. + public string? Schema { get; } + + /// Gets a name for the schema. + public string? SchemaName { get; } + + /// Gets a description of the schema. + public string? SchemaDescription { get; } + + /// + public override bool Equals(object? obj) => + obj is ChatResponseFormatJson other && + Schema == other.Schema && + SchemaName == other.SchemaName && + SchemaDescription == other.SchemaDescription; + + /// + public override int GetHashCode() => + Schema?.GetHashCode(StringComparison.Ordinal) ?? + typeof(ChatResponseFormatJson).GetHashCode(); + + /// Gets a string representing this instance to display in the debugger. + private string DebuggerDisplay => Schema ?? "JSON"; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormatText.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormatText.cs new file mode 100644 index 00000000000..71cd8b2877d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatResponseFormatText.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents a response format with no constraints around the format. +/// +/// Use to get an instance of . +/// +[DebuggerDisplay("Text")] +public sealed class ChatResponseFormatText : ChatResponseFormat +{ + /// Initializes a new instance of the class. + /// Use to get an instance of . + public ChatResponseFormatText() + { + // must exist in support of polymorphic deserialization of a ChatResponseFormat + } + + /// + public override bool Equals(object? obj) => obj is ChatResponseFormatText; + + /// + public override int GetHashCode() => typeof(ChatResponseFormatText).GetHashCode(); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatRole.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatRole.cs new file mode 100644 index 00000000000..f898bb58892 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatRole.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Describes the intended purpose of a message within a chat completion interaction. +/// +[JsonConverter(typeof(Converter))] +public readonly struct ChatRole : IEquatable +{ + /// Gets the role that instructs or sets the behavior of the assistant. + public static ChatRole System { get; } = new("system"); + + /// Gets the role that provides responses to system-instructed, user-prompted input. + public static ChatRole Assistant { get; } = new("assistant"); + + /// Gets the role that provides input for chat completions. + public static ChatRole User { get; } = new("user"); + + /// Gets the role that provides additional information and references for chat completions. + public static ChatRole Tool { get; } = new("tool"); + + /// + /// Gets the value associated with this . + /// + /// + /// The value is what will be serialized into the "role" message field of the Chat Message format. + /// + public string Value { get; } + + /// + /// Initializes a new instance of the struct with the provided value. + /// + /// The value to associate with this . + [JsonConstructor] + public ChatRole(string value) + { + Value = Throw.IfNullOrWhitespace(value); + } + + /// + /// Returns a value indicating whether two instances are equivalent, as determined by a + /// case-insensitive comparison of their values. + /// + /// the first instance to compare. + /// the second instance to compare. + /// true if left and right are both null or have equivalent values; false otherwise. + public static bool operator ==(ChatRole left, ChatRole right) + { + return left.Equals(right); + } + + /// + /// Returns a value indicating whether two instances are not equivalent, as determined by a + /// case-insensitive comparison of their values. + /// + /// the first instance to compare. + /// the second instance to compare. + /// false if left and right are both null or have equivalent values; true otherwise. + public static bool operator !=(ChatRole left, ChatRole right) + { + return !(left == right); + } + + /// + public override bool Equals([NotNullWhen(true)] object? obj) + => obj is ChatRole otherRole && Equals(otherRole); + + /// + public bool Equals(ChatRole other) + => string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase); + + /// + public override int GetHashCode() + => StringComparer.OrdinalIgnoreCase.GetHashCode(Value); + + /// + public override string ToString() => Value; + + /// Provides a for serializing instances. + [EditorBrowsable(EditorBrowsableState.Never)] + public sealed class Converter : JsonConverter + { + /// + public override ChatRole Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => + new(reader.GetString()!); + + /// + public override void Write(Utf8JsonWriter writer, ChatRole value, JsonSerializerOptions options) => + Throw.IfNull(writer).WriteStringValue(value.Value); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatToolMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatToolMode.cs new file mode 100644 index 00000000000..27b8c70e804 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatToolMode.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +/// +/// Describes how tools should be selected by a . +/// +/// +/// The predefined values and are provided. +/// To nominate a specific function, use . +/// +[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")] +[JsonDerivedType(typeof(AutoChatToolMode), typeDiscriminator: "auto")] +[JsonDerivedType(typeof(RequiredChatToolMode), typeDiscriminator: "required")] +#pragma warning disable CA1052 // Static holder types should be Static or NotInheritable +public class ChatToolMode +#pragma warning restore CA1052 +{ + /// Initializes a new instance of the class. + /// Prevents external instantiation. Close the inheritance hierarchy for now until we have good reason to open it. + private protected ChatToolMode() + { + } + + /// + /// Gets a predefined indicating that tool usage is optional. + /// + /// + /// may contain zero or more + /// instances, and the is free to invoke zero or more of them. + /// + public static AutoChatToolMode Auto { get; } = new AutoChatToolMode(); + + /// + /// Gets a predefined indicating that tool usage is required, + /// but that any tool may be selected. At least one tool must be provided in . + /// + public static RequiredChatToolMode RequireAny { get; } = new(requiredFunctionName: null); + + /// + /// Instantiates a indicating that tool usage is required, + /// and that the specified must be selected. The function name + /// must match an entry in . + /// + /// The name of the required function. + /// An instance of for the specified function name. + public static RequiredChatToolMode RequireSpecific(string functionName) => new(functionName); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs new file mode 100644 index 00000000000..a6fb40b3555 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/DelegatingChatClient.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides an optional base class for an that passes through calls to another instance. +/// +/// +/// This is recommended as a base type when building clients that can be chained in any order around an underlying . +/// The default implementation simply passes each call to the inner client instance. +/// +public class DelegatingChatClient : IChatClient +{ + /// + /// Initializes a new instance of the class. + /// + /// The wrapped client instance. + protected DelegatingChatClient(IChatClient innerClient) + { + InnerClient = Throw.IfNull(innerClient); + } + + /// + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// Gets the inner . + protected IChatClient InnerClient { get; } + + /// Provides a mechanism for releasing unmanaged resources. + /// true if being called from ; otherwise, false. + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + InnerClient.Dispose(); + } + } + + /// + public virtual ChatClientMetadata Metadata => InnerClient.Metadata; + + /// + public virtual Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + return InnerClient.CompleteAsync(chatMessages, options, cancellationToken); + } + + /// + public virtual IAsyncEnumerable CompleteStreamingAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + return InnerClient.CompleteStreamingAsync(chatMessages, options, cancellationToken); + } + + /// + public virtual TService? GetService(object? key = null) + where TService : class + { +#pragma warning disable S3060 // "is" should not be used with "this" + // If the key is non-null, we don't know what it means so pass through to the inner service + return key is null && this is TService service ? service : InnerClient.GetService(key); +#pragma warning restore S3060 + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs new file mode 100644 index 00000000000..e9839cab2ae --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/IChatClient.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +/// Represents a chat completion client. +public interface IChatClient : IDisposable +{ + /// Sends chat messages to the model and returns the response messages. + /// The chat content to send. + /// The chat options to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// + /// The returned messages will not have been added to . However, any intermediate messages generated implicitly + /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. + /// + Task CompleteAsync( + IList chatMessages, + ChatOptions? options = null, + CancellationToken cancellationToken = default); + + /// Sends chat messages to the model and streams the response messages. + /// The chat content to send. + /// The chat options to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// + /// The returned messages will not have been added to . However, any intermediate messages generated implicitly + /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. + /// + IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, + ChatOptions? options = null, + CancellationToken cancellationToken = default); + + /// Gets metadata that describes the . + ChatClientMetadata Metadata { get; } + + /// Asks the for an object of type . + /// The type of the object to be retrieved. + /// An optional key that may be used to help identify the target service. + /// The found object, otherwise . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the , + /// including itself or any services it might be wrapping. + /// + TService? GetService(object? key = null) + where TService : class; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs new file mode 100644 index 00000000000..a920afaef17 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Indicates that a chat tool must be called. It may optionally nominate a specific function, +/// or if not, indicates that any of them may be selected. +/// +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public sealed class RequiredChatToolMode : ChatToolMode +{ + /// + /// Gets the name of a specific that must be called. + /// + /// + /// If the value is , any available function may be selected (but at least one must be). + /// + public string? RequiredFunctionName { get; } + + /// + /// Initializes a new instance of the class that requires a specific function to be called. + /// + /// The name of the function that must be called. + /// + /// may be . However, it is preferable to use + /// when any function may be selected. + /// + public RequiredChatToolMode(string? requiredFunctionName) + { + if (requiredFunctionName is not null) + { + _ = Throw.IfNullOrWhitespace(requiredFunctionName); + } + + RequiredFunctionName = requiredFunctionName; + } + + // The reason for not overriding Equals/GetHashCode (e.g., so two instances are equal if they + // have the same RequiredFunctionName) is to leave open the option to unseal the type in the + // future. If we did define equality based on RequiredFunctionName but a subclass added further + // fields, this would lead to wrong behavior unless the subclass author remembers to re-override + // Equals/GetHashCode as well, which they likely won't. + + /// Gets a string representing this instance to display in the debugger. + private string DebuggerDisplay => $"Required: {RequiredFunctionName ?? "Any"}"; + + /// + public override bool Equals(object? obj) => + obj is RequiredChatToolMode other && + RequiredFunctionName == other.RequiredFunctionName; + + /// + public override int GetHashCode() => + RequiredFunctionName?.GetHashCode(StringComparison.Ordinal) ?? + typeof(RequiredChatToolMode).GetHashCode(); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs new file mode 100644 index 00000000000..8192e017f7e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/StreamingChatCompletionUpdate.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +// Conceptually this combines the roles of ChatCompletion and ChatMessage in streaming output. +// For ease of consumption, it also flattens the nested structure you see on streaming chunks in +// the OpenAI/Gemini APIs, so instead of a dictionary of choices, each update represents a single +// choice (and hence has its own role, choice ID, etc.). + +/// +/// Represents a single response chunk from an . +/// +public class StreamingChatCompletionUpdate +{ + /// The completion update content items. + private IList? _contents; + + /// The name of the author of the update. + private string? _authorName; + + /// Gets or sets the name of the author of the completion update. + public string? AuthorName + { + get => _authorName; + set => _authorName = string.IsNullOrWhiteSpace(value) ? null : value; + } + + /// Gets or sets the role of the author of the completion update. + public ChatRole? Role { get; set; } + + /// + /// Gets or sets the text of the first instance in . + /// + /// + /// If there is no instance in , then the getter returns , + /// and the setter will add new instance with the provided value. + /// + [JsonIgnore] + public string? Text + { + get => Contents.OfType().FirstOrDefault()?.Text; + set + { + if (Contents.OfType().FirstOrDefault() is { } textContent) + { + textContent.Text = value; + } + else if (value is not null) + { + Contents.Add(new TextContent(value)); + } + } + } + + /// Gets or sets the chat completion update content items. + [AllowNull] + public IList Contents + { + get => _contents ??= []; + set => _contents = value; + } + + /// Gets or sets the raw representation of the completion update from an underlying implementation. + /// + /// If a is created to represent some underlying object from another object + /// model, this property can be used to store that original object. This can be useful for debugging or + /// for enabling a consumer to access the underlying object model if needed. + /// + [JsonIgnore] + public object? RawRepresentation { get; set; } + + /// Gets or sets additional properties for the update. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// Gets or sets the ID of the completion of which this update is a part. + public string? CompletionId { get; set; } + + /// Gets or sets a timestamp for the completion update. + public DateTimeOffset? CreatedAt { get; set; } + + /// Gets or sets the zero-based index of the choice with which this update is associated in the streaming sequence. + public int ChoiceIndex { get; set; } + + /// Gets or sets the finish reason for the operation. + public ChatFinishReason? FinishReason { get; set; } + + /// + public override string ToString() => Text ?? string.Empty; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs new file mode 100644 index 00000000000..456ee4940c2 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AIContent.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +/// Provides a base class for all content used with AI services. +[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")] +[JsonDerivedType(typeof(AudioContent), typeDiscriminator: "audio")] +[JsonDerivedType(typeof(DataContent), typeDiscriminator: "data")] +[JsonDerivedType(typeof(FunctionCallContent), typeDiscriminator: "functionCall")] +[JsonDerivedType(typeof(FunctionResultContent), typeDiscriminator: "functionResult")] +[JsonDerivedType(typeof(ImageContent), typeDiscriminator: "image")] +[JsonDerivedType(typeof(TextContent), typeDiscriminator: "text")] +[JsonDerivedType(typeof(UsageContent), typeDiscriminator: "usage")] +public class AIContent +{ + /// + /// Initializes a new instance of the class. + /// + protected AIContent() + { + } + + /// Gets or sets the raw representation of the content from an underlying implementation. + /// + /// If an is created to represent some underlying object from another object + /// model, this property can be used to store that original object. This can be useful for debugging or + /// for enabling a consumer to access the underlying object model if needed. + /// + [JsonIgnore] + public object? RawRepresentation { get; set; } + + /// + /// Gets or sets the model ID used to generate the content. + /// + public string? ModelId { get; set; } + + /// Gets or sets additional properties for the content. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AudioContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AudioContent.cs new file mode 100644 index 00000000000..84354a95b1d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/AudioContent.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents audio content. +/// +public class AudioContent : DataContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The URI of the content. This may be a data URI. + /// The media type (also known as MIME type) represented by the content. + public AudioContent(Uri uri, string? mediaType = null) + : base(uri, mediaType) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The URI of the content. This may be a data URI. + /// The media type (also known as MIME type) represented by the content. + [JsonConstructor] + public AudioContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? mediaType = null) + : base(uri, mediaType) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The byte contents. + /// The media type (also known as MIME type) represented by the content. + public AudioContent(ReadOnlyMemory data, string? mediaType = null) + : base(data, mediaType) + { + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs new file mode 100644 index 00000000000..5ed17aae1b5 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataContent.cs @@ -0,0 +1,196 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable S3996 // URI properties should not be strings +#pragma warning disable CA1056 // URI-like properties should not be strings + +namespace Microsoft.Extensions.AI; + +/// +/// Represents data content, such as an image or audio. +/// +/// +/// +/// The represented content may either be the actual bytes stored in this instance, or it may +/// be a URI that references the location of the content. +/// +/// +/// always returns a valid URI string, even if the instance was constructed from +/// a . In that case, a data URI will be constructed and returned. +/// +/// +public class DataContent : AIContent +{ + // Design note: + // Ideally DataContent would be based in terms of Uri. However, Uri has a length limitation that makes it prohibitive + // for the kinds of data URIs necessary to support here. As such, this type is based in strings. + + /// The string-based representation of the URI, including any data in the instance. + private string? _uri; + + /// The data, lazily-initialized if the data is provided in a data URI. + private ReadOnlyMemory? _data; + + /// Parsed data URI information. + private DataUriParser.DataUri? _dataUri; + + /// + /// Initializes a new instance of the class. + /// + /// The URI of the content. This may be a data URI. + /// The media type (also known as MIME type) represented by the content. + public DataContent(Uri uri, string? mediaType = null) + : this(Throw.IfNull(uri).ToString(), mediaType) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The URI of the content. This may be a data URI. + /// The media type (also known as MIME type) represented by the content. + [JsonConstructor] + public DataContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? mediaType = null) + { + _uri = Throw.IfNullOrWhitespace(uri); + + ValidateMediaType(ref mediaType); + MediaType = mediaType; + + if (uri.StartsWith(DataUriParser.Scheme, StringComparison.OrdinalIgnoreCase)) + { + _dataUri = DataUriParser.Parse(uri.AsMemory()); + + // If the data URI contains a media type that's different from a non-null media type + // explicitly provided, prefer the one explicitly provided as an override. + if (MediaType is not null) + { + if (MediaType != _dataUri.MediaType) + { + // Extract the bytes from the data URI and null out the uri. + // Then we'll lazily recreate it later if needed based on the updated media type. + _data = _dataUri.ToByteArray(); + _dataUri = null; + _uri = null; + } + } + else + { + MediaType = _dataUri.MediaType; + } + } + else if (!System.Uri.TryCreate(uri, UriKind.Absolute, out _)) + { + throw new UriFormatException("The URI is not well-formed."); + } + } + + /// + /// Initializes a new instance of the class. + /// + /// The byte contents. + /// The media type (also known as MIME type) represented by the content. + public DataContent(ReadOnlyMemory data, string? mediaType = null) + { + ValidateMediaType(ref mediaType); + MediaType = mediaType; + + _data = data; + } + + /// Sets to null if it's empty or composed entirely of whitespace. + private static void ValidateMediaType(ref string? mediaType) + { + if (!DataUriParser.IsValidMediaType(mediaType.AsSpan(), ref mediaType)) + { + Throw.ArgumentException(nameof(mediaType), "Invalid media type."); + } + } + + /// Gets the URI for this . + /// + /// The returned URI is always a valid URI string, even if the instance was constructed from a + /// or from a . In the case of a , this will return a data URI containing + /// that data. + /// + [StringSyntax(StringSyntaxAttribute.Uri)] + public string Uri + { + get + { + if (_uri is null) + { + if (_dataUri is null) + { + Debug.Assert(Data is not null, "Expected Data to be initialized."); + _uri = string.Concat("data:", MediaType, ";base64,", Convert.ToBase64String(Data.GetValueOrDefault() +#if NET + .Span)); +#else + .Span.ToArray())); +#endif + } + else + { + _uri = _dataUri.IsBase64 ? +#if NET + $"data:{MediaType};base64,{_dataUri.Data.Span}" : + $"data:{MediaType};,{_dataUri.Data.Span}"; +#else + $"data:{MediaType};base64,{_dataUri.Data}" : + $"data:{MediaType};,{_dataUri.Data}"; +#endif + } + } + + return _uri; + } + } + + /// Gets the media type (also known as MIME type) of the content. + /// + /// If the media type was explicitly specified, this property will return that value. + /// If the media type was not explicitly specified, but a data URI was supplied and that data URI contained a non-default + /// media type, that media type will be returned. + /// Otherwise, this will return null. + /// + [JsonPropertyOrder(1)] + public string? MediaType { get; private set; } + + /// + /// Gets a value indicating whether the content contains data rather than only being a reference to data. + /// + /// + /// If the instance is constructed from a or from a data URI, this property will return , + /// as the instance actually contains all of the data it represents. If, however, the instance was constructed from another form of URI, one + /// that simply references where the data can be found but doesn't actually contain the data, this property will return . + /// + [JsonIgnore] + public bool ContainsData => _dataUri is not null || _data is not null; + + /// Gets the data represented by this instance. + /// + /// If is , this property will return the represented data. + /// If is , this property will return . + /// + [MemberNotNullWhen(true, nameof(ContainsData))] + [JsonIgnore] + public ReadOnlyMemory? Data + { + get + { + if (_dataUri is not null) + { + _data ??= _dataUri.ToByteArray(); + } + + return _data; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs new file mode 100644 index 00000000000..5cb33d1a55c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/DataUriParser.cs @@ -0,0 +1,182 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +#if NET8_0_OR_GREATER +using System.Buffers.Text; +#endif +using System.Diagnostics; +using System.Net; +using System.Net.Http.Headers; +using System.Text; + +namespace Microsoft.Extensions.AI; + +/// +/// Minimal data URI parser based on RFC 2397: https://datatracker.ietf.org/doc/html/rfc2397. +/// +internal static class DataUriParser +{ + public static string Scheme => "data:"; + + public static DataUri Parse(ReadOnlyMemory dataUri) + { + // Validate, then trim off the "data:" scheme. + if (!dataUri.Span.StartsWith(Scheme.AsSpan(), StringComparison.OrdinalIgnoreCase)) + { + throw new UriFormatException("Invalid data URI format: the data URI must start with 'data:'."); + } + + dataUri = dataUri.Slice(Scheme.Length); + + // Find the comma separating the metadata from the data. + int commaPos = dataUri.Span.IndexOf(','); + if (commaPos < 0) + { + throw new UriFormatException("Invalid data URI format: the data URI must contain a comma separating the metadata and the data."); + } + + ReadOnlyMemory metadata = dataUri.Slice(0, commaPos); + + ReadOnlyMemory data = dataUri.Slice(commaPos + 1); + bool isBase64 = false; + + // Determine whether the data is Base64-encoded or percent-encoded (Uri-encoded). + // If it's base64-encoded, validate it. If it's Uri-encoded, there's nothing to validate, + // as WebUtility.UrlDecode will successfully decode any input with no sequence considered invalid. + if (metadata.Span.EndsWith(";base64".AsSpan(), StringComparison.OrdinalIgnoreCase)) + { + metadata = metadata.Slice(0, metadata.Length - ";base64".Length); + isBase64 = true; + if (!IsValidBase64Data(data.Span)) + { + throw new UriFormatException("Invalid data URI format: the data URI is base64-encoded, but the data is not a valid base64 string."); + } + } + + // Validate the media type, if present. + string? mediaType = null; + if (!IsValidMediaType(metadata.Span.Trim(), ref mediaType)) + { + throw new UriFormatException("Invalid data URI format: the media type is not a valid."); + } + + return new DataUri(data, isBase64, mediaType); + } + + /// Validates that a media type is valid, and if successful, ensures we have it as a string. + public static bool IsValidMediaType(ReadOnlySpan mediaTypeSpan, ref string? mediaType) + { + Debug.Assert( + mediaType is null || mediaTypeSpan.Equals(mediaType.AsSpan(), StringComparison.Ordinal), + "mediaType string should either be null or the same as the span"); + + // If the media type is empty or all whitespace, normalize it to null. + if (mediaTypeSpan.IsWhiteSpace()) + { + mediaType = null; + return true; + } + + // For common media types, we can avoid both allocating a string for the span and avoid parsing overheads. + string? knownType = mediaTypeSpan switch + { + "application/json" => "application/json", + "application/octet-stream" => "application/octet-stream", + "application/pdf" => "application/pdf", + "application/xml" => "application/xml", + "audio/mpeg" => "audio/mpeg", + "audio/ogg" => "audio/ogg", + "audio/wav" => "audio/wav", + "image/apng" => "image/apng", + "image/avif" => "image/avif", + "image/bmp" => "image/bmp", + "image/gif" => "image/gif", + "image/jpeg" => "image/jpeg", + "image/png" => "image/png", + "image/svg+xml" => "image/svg+xml", + "image/tiff" => "image/tiff", + "image/webp" => "image/webp", + "text/css" => "text/css", + "text/csv" => "text/csv", + "text/html" => "text/html", + "text/javascript" => "text/javascript", + "text/plain" => "text/plain", + "text/plain;charset=UTF-8" => "text/plain;charset=UTF-8", + "text/xml" => "text/xml", + _ => null, + }; + if (knownType is not null) + { + mediaType ??= knownType; + return true; + } + + // Otherwise, do the full validation using the same logic as HttpClient. + mediaType ??= mediaTypeSpan.ToString(); + return MediaTypeHeaderValue.TryParse(mediaType, out _); + } + + /// Test whether the value is a base64 string without whitespace. + private static bool IsValidBase64Data(ReadOnlySpan value) + { + if (value.IsEmpty) + { + return true; + } + +#if NET8_0_OR_GREATER + return Base64.IsValid(value) && !value.ContainsAny(" \t\r\n"); +#else +#pragma warning disable S109 // Magic numbers should not be used + if (value!.Length % 4 != 0) +#pragma warning restore S109 + { + return false; + } + + var index = value.Length - 1; + + // Step back over one or two padding chars + if (value[index] == '=') + { + index--; + } + + if (value[index] == '=') + { + index--; + } + + // Now traverse over characters + for (var i = 0; i <= index; i++) + { +#pragma warning disable S1067 // Expressions should not be too complex + bool validChar = value[i] is (>= 'A' and <= 'Z') or (>= 'a' and <= 'z') or (>= '0' and <= '9') or '+' or '/'; +#pragma warning restore S1067 + if (!validChar) + { + return false; + } + } + + return true; +#endif + } + + /// Provides the parts of a parsed data URI. + public sealed class DataUri(ReadOnlyMemory data, bool isBase64, string? mediaType) + { +#pragma warning disable S3604 // False positive: Member initializer values should not be redundant + public string? MediaType { get; } = mediaType; + + public ReadOnlyMemory Data { get; } = data; + + public bool IsBase64 { get; } = isBase64; +#pragma warning restore S3604 + + public byte[] ToByteArray() => IsBase64 ? + Convert.FromBase64String(Data.ToString()) : + Encoding.UTF8.GetBytes(WebUtility.UrlDecode(Data.ToString())); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs new file mode 100644 index 00000000000..7eefdd90a09 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallContent.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents a function call request. +/// +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public sealed class FunctionCallContent : AIContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The function call ID. + /// The function name. + /// The function original arguments. + [JsonConstructor] + public FunctionCallContent(string callId, string name, IDictionary? arguments = null) + { + Name = Throw.IfNull(name); + CallId = callId; + Arguments = arguments; + } + + /// + /// Gets or sets the function call ID. + /// + public string CallId { get; set; } + + /// + /// Gets or sets the name of the function requested. + /// + public string Name { get; set; } + + /// + /// Gets or sets the arguments requested to be provided to the function. + /// + public IDictionary? Arguments { get; set; } + + /// + /// Gets or sets any exception that occurred while mapping the original function call data to this class. + /// + /// + /// When an instance of is serialized using , any exception + /// stored in this property will be serialized as a string. When deserialized, the string will be converted back to an instance + /// of the base type. As such, consumers shouldn't rely on the exact type of the exception stored in this property. + /// + [JsonConverter(typeof(FunctionCallExceptionConverter))] + public Exception? Exception { get; set; } + + /// Gets a string representing this instance to display in the debugger. + private string DebuggerDisplay + { + get + { + string display = CallId is not null ? + $"CallId = {CallId}, " : + string.Empty; + + display += Arguments is not null ? + $"Call = {Name}({string.Join(", ", Arguments)})" : + $"Call = {Name}()"; + + return display; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallExceptionConverter.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallExceptionConverter.cs new file mode 100644 index 00000000000..0c36f11ca40 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallExceptionConverter.cs @@ -0,0 +1,96 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ComponentModel; +#if NET +using System.Runtime.ExceptionServices; +#endif +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Serializes an exception as a string and deserializes it back as a base containing that contents as a message. +[EditorBrowsable(EditorBrowsableState.Never)] +public sealed class FunctionCallExceptionConverter : JsonConverter +{ + private const string ClassNamePropertyName = "className"; + private const string MessagePropertyName = "message"; + private const string InnerExceptionPropertyName = "innerException"; + private const string StackTracePropertyName = "stackTraceString"; + + /// + public override void Write(Utf8JsonWriter writer, Exception value, JsonSerializerOptions options) + { + _ = Throw.IfNull(writer); + _ = Throw.IfNull(value); + + // Schema and property order taken from Exception.GetObjectData() implementation. + + writer.WriteStartObject(); + writer.WriteString(ClassNamePropertyName, value.GetType().ToString()); + writer.WriteString(MessagePropertyName, value.Message); + writer.WritePropertyName(InnerExceptionPropertyName); + if (value.InnerException is Exception innerEx) + { + Write(writer, innerEx, options); + } + else + { + writer.WriteNullValue(); + } + + writer.WriteString(StackTracePropertyName, value.StackTrace); + writer.WriteEndObject(); + } + + /// + public override Exception? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (reader.TokenType != JsonTokenType.StartObject) + { + throw new JsonException(); + } + + using var doc = JsonDocument.ParseValue(ref reader); + return ParseExceptionCore(doc.RootElement); + + static Exception ParseExceptionCore(JsonElement element) + { + string? message = null; + string? stackTrace = null; + Exception? innerEx = null; + + foreach (JsonProperty property in element.EnumerateObject()) + { + switch (property.Name) + { + case MessagePropertyName: + message = property.Value.GetString(); + break; + + case StackTracePropertyName: + stackTrace = property.Value.GetString(); + break; + + case InnerExceptionPropertyName when property.Value.ValueKind is not JsonValueKind.Null: + innerEx = ParseExceptionCore(property.Value); + break; + } + } + +#pragma warning disable CA2201 // Do not raise reserved exception types + Exception result = new(message, innerEx); +#pragma warning restore CA2201 +#if NET + if (stackTrace != null) + { + ExceptionDispatchInfo.SetRemoteStackTrace(result, stackTrace); + } +#endif + return result; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs new file mode 100644 index 00000000000..42eb486f4c1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionCallHelpers.cs @@ -0,0 +1,378 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Schema; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +using FunctionParameterKey = (System.Type? Type, string ParameterName, string? Description, bool HasDefaultValue, object? DefaultValue); + +namespace Microsoft.Extensions.AI; + +/// Provides a collection of static utility methods for marshalling JSON data in function calling. +internal static partial class FunctionCallHelpers +{ + /// Soft limit for how many items should be stored in the dictionaries in . + private const int CacheSoftLimit = 4096; + + /// Caches of generated schemas for each that's employed. + private static readonly ConditionalWeakTable> _schemaCaches = new(); + + /// Gets a JSON schema accepting all values. + private static JsonElement TrueJsonSchema { get; } = ParseJsonElement("true"u8); + + /// Gets a JSON schema only accepting null values. + private static JsonElement NullJsonSchema { get; } = ParseJsonElement("""{"type":"null"}"""u8); + + /// Parses a JSON object into a dictionary of objects encoded as . + /// A JSON object containing the parameters. + /// If the parsing fails, the resulting exception. + /// The parsed dictionary of objects encoded as . + public static Dictionary? ParseFunctionCallArguments(string json, out Exception? parsingException) + { + _ = Throw.IfNull(json); + + parsingException = null; + try + { + return JsonSerializer.Deserialize(json, FunctionCallHelperContext.Default.DictionaryStringObject); + } + catch (JsonException ex) + { + parsingException = new InvalidOperationException($"Function call arguments contained invalid JSON: {json}", ex); + return null; + } + } + + /// Parses a JSON object into a dictionary of objects encoded as . + /// A UTF-8 encoded JSON object containing the parameters. + /// If the parsing fails, the resulting exception. + /// The parsed dictionary of objects encoded as . + public static Dictionary? ParseFunctionCallArguments(ReadOnlySpan utf8Json, out Exception? parsingException) + { + parsingException = null; + try + { + return JsonSerializer.Deserialize(utf8Json, FunctionCallHelperContext.Default.DictionaryStringObject); + } + catch (JsonException ex) + { + parsingException = new InvalidOperationException($"Function call arguments contained invalid JSON: {Encoding.UTF8.GetString(utf8Json.ToArray())}", ex); + return null; + } + } + + /// + /// Serializes a dictionary of function parameters into a JSON string. + /// + /// The dictionary of parameters. + /// A governing serialization. + /// A JSON encoding of the parameters. + public static string FormatFunctionParametersAsJson(IDictionary? parameters, JsonSerializerOptions? options = null) + { + // Fall back to the built-in context since in most cases the return value is JsonElement or JsonNode. + options ??= FunctionCallHelperContext.Default.Options; + options.MakeReadOnly(); + return JsonSerializer.Serialize(parameters, options.GetTypeInfo(typeof(IDictionary))); + } + + /// + /// Serializes a dictionary of function parameters into a . + /// + /// The dictionary of parameters. + /// A governing serialization. + /// A JSON encoding of the parameters. + public static JsonElement FormatFunctionParametersAsJsonElement(IDictionary? parameters, JsonSerializerOptions? options = null) + { + // Fall back to the built-in context since in most cases the return value is JsonElement or JsonNode. + options ??= FunctionCallHelperContext.Default.Options; + options.MakeReadOnly(); + return JsonSerializer.SerializeToElement(parameters, options.GetTypeInfo(typeof(IDictionary))); + } + + /// + /// Serializes a .NET function return parameter to a JSON string. + /// + /// The result value to be serialized. + /// A governing serialization. + /// A JSON encoding of the parameter. + public static string FormatFunctionResultAsJson(object? result, JsonSerializerOptions? options = null) + { + // Fall back to the built-in context since in most cases the return value is JsonElement or JsonNode. + options ??= FunctionCallHelperContext.Default.Options; + options.MakeReadOnly(); + return JsonSerializer.Serialize(result, options.GetTypeInfo(typeof(object))); + } + + /// + /// Serializes a .NET function return parameter to a JSON element. + /// + /// The result value to be serialized. + /// A governing serialization. + /// A JSON encoding of the parameter. + public static JsonElement FormatFunctionResultAsJsonElement(object? result, JsonSerializerOptions? options = null) + { + // Fall back to the built-in context since in most cases the return value is JsonElement or JsonNode. + options ??= FunctionCallHelperContext.Default.Options; + options.MakeReadOnly(); + return JsonSerializer.SerializeToElement(result, options.GetTypeInfo(typeof(object))); + } + + /// + /// Determines a JSON schema for the provided parameter metadata. + /// + /// The parameter metadata from which to infer the schema. + /// The containing function metadata. + /// The global governing serialization. + /// A JSON schema document encoded as a . + public static JsonElement InferParameterJsonSchema( + AIFunctionParameterMetadata parameterMetadata, + AIFunctionMetadata functionMetadata, + JsonSerializerOptions? options) + { + options ??= functionMetadata.JsonSerializerOptions; + + if (ReferenceEquals(options, functionMetadata.JsonSerializerOptions) && + parameterMetadata.Schema is JsonElement schema) + { + // If the resolved options matches that of the function metadata, + // we can just return the precomputed JSON schema value. + return schema; + } + + if (options is null) + { + return TrueJsonSchema; + } + + return InferParameterJsonSchema( + parameterMetadata.ParameterType, + parameterMetadata.Name, + parameterMetadata.Description, + parameterMetadata.HasDefaultValue, + parameterMetadata.DefaultValue, + options); + } + + /// + /// Determines a JSON schema for the provided parameter metadata. + /// + /// The type of the parameter. + /// The name of the parameter. + /// The description of the parameter. + /// Whether the parameter is optional. + /// The default value of the optional parameter, if applicable. + /// The options used to extract the schema from the specified type. + /// A JSON schema document encoded as a . + public static JsonElement InferParameterJsonSchema( + Type? type, + string name, + string? description, + bool hasDefaultValue, + object? defaultValue, + JsonSerializerOptions options) + { + _ = Throw.IfNull(name); + _ = Throw.IfNull(options); + + options.MakeReadOnly(); + + try + { + ConcurrentDictionary cache = _schemaCaches.GetOrCreateValue(options); + FunctionParameterKey key = new(type, name, description, hasDefaultValue, defaultValue); + + if (cache.Count > CacheSoftLimit) + { + return GetJsonSchemaCore(options, key); + } + + return cache.GetOrAdd( + key: key, +#if NET + valueFactory: static (key, options) => GetJsonSchemaCore(options, key), + factoryArgument: options); +#else + valueFactory: key => GetJsonSchemaCore(options, key)); +#endif + } + catch (ArgumentException) + { + // Invalid type; ignore, and leave schema as null. + // This should be exceedingly rare, as we checked for all known category of + // problematic types above. If it becomes more common that schema creation + // could fail expensively, we'll want to track whether inference was already + // attempted and avoid doing so on subsequent accesses if it was. + return TrueJsonSchema; + } + } + + /// Infers a JSON schema from the return parameter. + /// The type of the return parameter. + /// The options used to extract the schema from the specified type. + /// A representing the schema. + public static JsonElement InferReturnParameterJsonSchema(Type? type, JsonSerializerOptions options) + { + _ = Throw.IfNull(options); + + options.MakeReadOnly(); + + // If there's no type, just return a schema that allows anything. + if (type is null) + { + return TrueJsonSchema; + } + + if (type == typeof(void)) + { + return NullJsonSchema; + } + + JsonNode node = options.GetJsonSchemaAsNode(type); + return JsonSerializer.SerializeToElement(node, FunctionCallHelperContext.Default.JsonNode); + } + + private static JsonElement GetJsonSchemaCore(JsonSerializerOptions options, FunctionParameterKey key) + { + _ = Throw.IfNull(options); + + if (options.ReferenceHandler == ReferenceHandler.Preserve) + { + throw new NotSupportedException("Schema generation not supported with ReferenceHandler.Preserve enabled."); + } + + if (key.Type is null) + { + // For parameters without a type generate a rudimentary schema with available metadata. + + JsonObject schemaObj = []; + if (key.Description is not null) + { + schemaObj["description"] = key.Description; + } + + if (key.HasDefaultValue) + { + JsonNode? defaultValueNode = key.DefaultValue is { } defaultValue + ? JsonSerializer.Serialize(defaultValue, options.GetTypeInfo(defaultValue.GetType())) + : null; + + schemaObj["default"] = defaultValueNode; + } + + return JsonSerializer.SerializeToElement(schemaObj, FunctionCallHelperContext.Default.JsonNode); + } + + options.MakeReadOnly(); + + JsonSchemaExporterOptions exporterOptions = new() + { + TreatNullObliviousAsNonNullable = true, + TransformSchemaNode = TransformSchemaNode, + }; + + JsonNode node = options.GetJsonSchemaAsNode(key.Type, exporterOptions); + return JsonSerializer.SerializeToElement(node, FunctionCallHelperContext.Default.JsonNode); + + JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) + { + const string DescriptionPropertyName = "description"; + const string NotPropertyName = "not"; + const string PropertiesPropertyName = "properties"; + const string DefaultPropertyName = "default"; + const string RefPropertyName = "$ref"; + + // Find the first DescriptionAttribute, starting first from the property, then the parameter, and finally the type itself. + Type descAttrType = typeof(DescriptionAttribute); + var descriptionAttribute = + GetAttrs(descAttrType, ctx.PropertyInfo?.AttributeProvider)?.FirstOrDefault() ?? + GetAttrs(descAttrType, ctx.PropertyInfo?.AssociatedParameter?.AttributeProvider)?.FirstOrDefault() ?? + GetAttrs(descAttrType, ctx.TypeInfo.Type)?.FirstOrDefault(); + + if (descriptionAttribute is DescriptionAttribute attr) + { + ConvertSchemaToObject(ref schema).Insert(0, DescriptionPropertyName, (JsonNode)attr.Description); + } + + // If the type is recursive, the resulting schema will contain a $ref to the type itself. + // As JSON pointer doesn't support relative paths, we need to fix up such paths to accommodate + // the fact that they're being nested inside of a higher-level schema. + if (schema is JsonObject refObj && refObj.TryGetPropertyValue(RefPropertyName, out JsonNode? paramName)) + { + // Fix up any $ref URIs to match the path from the root document. + string refUri = paramName!.GetValue(); + Debug.Assert(refUri is "#" || refUri.StartsWith("#/", StringComparison.Ordinal), $"Expected {nameof(refUri)} to be either # or start with #/, got {refUri}"); + refUri = refUri == "#" + ? $"#/{PropertiesPropertyName}/{key.ParameterName}" + : $"#/{PropertiesPropertyName}/{key.ParameterName}/{refUri.AsMemory("#/".Length)}"; + + refObj[RefPropertyName] = (JsonNode)refUri; + } + + if (ctx.Path.IsEmpty) + { + // We are at the root-level schema node, append parameter-specific metadata + + if (!string.IsNullOrWhiteSpace(key.Description)) + { + ConvertSchemaToObject(ref schema).Insert(0, DescriptionPropertyName, (JsonNode)key.Description!); + } + + if (key.HasDefaultValue) + { + JsonNode? defaultValue = JsonSerializer.Serialize(key.DefaultValue, options.GetTypeInfo(typeof(object))); + ConvertSchemaToObject(ref schema)[DefaultPropertyName] = defaultValue; + } + } + + return schema; + + static object[]? GetAttrs(Type attrType, ICustomAttributeProvider? provider) => + provider?.GetCustomAttributes(attrType, inherit: false); + + static JsonObject ConvertSchemaToObject(ref JsonNode schema) + { + JsonObject obj; + JsonValueKind kind = schema.GetValueKind(); + switch (kind) + { + case JsonValueKind.Object: + return (JsonObject)schema; + + case JsonValueKind.False: + schema = obj = new() { [NotPropertyName] = true }; + return obj; + + default: + Debug.Assert(kind is JsonValueKind.True, $"Invalid schema type: {kind}"); + schema = obj = []; + return obj; + } + } + } + } + + private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) + { + Utf8JsonReader reader = new(utf8Json); + return JsonElement.ParseValue(ref reader); + } + + [JsonSerializable(typeof(Dictionary))] + [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(JsonNode))] + [JsonSerializable(typeof(JsonElement))] + [JsonSerializable(typeof(JsonDocument))] + private sealed partial class FunctionCallHelperContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionResultContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionResultContent.cs new file mode 100644 index 00000000000..0a416d64f5f --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/FunctionResultContent.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents the result of a function call. +/// +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public sealed class FunctionResultContent : AIContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The function call ID for which this is the result. + /// The function name that produced the result. + /// The function call result. + /// Any exception that occurred when invoking the function. + [JsonConstructor] + public FunctionResultContent(string callId, string name, object? result = null, Exception? exception = null) + { + CallId = Throw.IfNull(callId); + Name = Throw.IfNull(name); + Result = result; + Exception = exception; + } + + /// + /// Initializes a new instance of the class. + /// + /// The function call for which this is the result. + /// The function call result. + /// Any exception that occurred when invoking the function. + public FunctionResultContent(FunctionCallContent functionCall, object? result = null, Exception? exception = null) + : this(Throw.IfNull(functionCall).CallId, functionCall.Name, result, exception) + { + } + + /// + /// Gets or sets the ID of the function call for which this is the result. + /// + /// + /// If this is the result for a , this should contain the same + /// value. + /// + public string CallId { get; set; } + + /// + /// Gets or sets the name of the function that was called. + /// + public string Name { get; set; } + + /// + /// Gets or sets the result of the function call, or a generic error message if the function call failed. + /// + public object? Result { get; set; } + + /// + /// Gets or sets an exception that occurred if the function call failed. + /// + /// + /// When an instance of is serialized using , any exception + /// stored in this property will be serialized as a string. When deserialized, the string will be converted back to an instance + /// of the base type. As such, consumers shouldn't rely on the exact type of the exception stored in this property. + /// + [JsonConverter(typeof(FunctionCallExceptionConverter))] + public Exception? Exception { get; set; } + + /// Gets a string representing this instance to display in the debugger. + private string DebuggerDisplay + { + get + { + string display = CallId is not null ? + $"CallId = {CallId}, " : + string.Empty; + + display += Exception is not null ? + $"Error = {Exception.Message}" : + $"Result = {Result?.ToString() ?? string.Empty}"; + + return display; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/ImageContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/ImageContent.cs new file mode 100644 index 00000000000..d376586c993 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/ImageContent.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents image content. +/// +public class ImageContent : DataContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The URI of the content. This may be a data URI. + /// The media type (also known as MIME type) represented by the content. + public ImageContent(Uri uri, string? mediaType = null) + : base(uri, mediaType) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The URI of the content. This may be a data URI. + /// The media type (also known as MIME type) represented by the content. + [JsonConstructor] + public ImageContent([StringSyntax(StringSyntaxAttribute.Uri)] string uri, string? mediaType = null) + : base(uri, mediaType) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The byte contents. + /// The media type (also known as MIME type) represented by the content. + public ImageContent(ReadOnlyMemory data, string? mediaType = null) + : base(data, mediaType) + { + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/TextContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/TextContent.cs new file mode 100644 index 00000000000..d81e969e1c4 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/TextContent.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +/// +/// Represents text content in a chat. +/// +public sealed class TextContent : AIContent +{ + /// + /// Initializes a new instance of the class. + /// + /// The text content. + public TextContent(string? text) + { + Text = text; + } + + /// + /// Gets or sets the text content. + /// + public string? Text { get; set; } + + /// + public override string ToString() => Text ?? string.Empty; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UsageContent.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UsageContent.cs new file mode 100644 index 00000000000..22d86bd97cb --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Contents/UsageContent.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Represents usage information associated with a chat response. +/// +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public class UsageContent : AIContent +{ + /// Usage information. + private UsageDetails _details; + + /// Initializes a new instance of the class with an empty . + public UsageContent() + { + _details = new(); + } + + /// Initializes a new instance of the class with the specified instance. + /// The usage details to store in this content. + [JsonConstructor] + public UsageContent(UsageDetails details) + { + _details = Throw.IfNull(details); + } + + /// Gets or sets the usage information. + public UsageDetails Details + { + get => _details; + set => _details = Throw.IfNull(value); + } + + /// Gets a string representing this instance to display in the debugger. + private string DebuggerDisplay => _details.DebuggerDisplay; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs new file mode 100644 index 00000000000..6b06d32d6d7 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/DelegatingEmbeddingGenerator.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides an optional base class for an that passes through calls to another instance. +/// +/// Specifies the type of the input passed to the generator. +/// Specifies the type of the embedding instance produced by the generator. +/// +/// This is recommended as a base type when building generators that can be chained in any order around an underlying . +/// The default implementation simply passes each call to the inner generator instance. +/// +public class DelegatingEmbeddingGenerator : IEmbeddingGenerator + where TEmbedding : Embedding +{ + /// + /// Initializes a new instance of the class. + /// + /// The wrapped generator instance. + protected DelegatingEmbeddingGenerator(IEmbeddingGenerator innerGenerator) + { + InnerGenerator = Throw.IfNull(innerGenerator); + } + + /// Gets the inner . + protected IEmbeddingGenerator InnerGenerator { get; } + + /// + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// Provides a mechanism for releasing unmanaged resources. + /// true if being called from ; otherwise, false. + protected virtual void Dispose(bool disposing) + { + if (disposing) + { + InnerGenerator.Dispose(); + } + } + + /// + public virtual EmbeddingGeneratorMetadata Metadata => + InnerGenerator.Metadata; + + /// + public virtual Task> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) => + InnerGenerator.GenerateAsync(values, options, cancellationToken); + + /// + public virtual TService? GetService(object? key = null) + where TService : class + { +#pragma warning disable S3060 // "is" should not be used with "this" + // If the key is non-null, we don't know what it means so pass through to the inner service + return key is null && this is TService service ? service : InnerGenerator.GetService(key); +#pragma warning restore S3060 + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs new file mode 100644 index 00000000000..e70469eaed3 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +/// Represents an embedding generated by a . +/// This base class provides metadata about the embedding. Derived types provide the concrete data contained in the embedding. +[JsonPolymorphic(TypeDiscriminatorPropertyName = "$type")] +#if NET +[JsonDerivedType(typeof(Embedding), typeDiscriminator: "halves")] +#endif +[JsonDerivedType(typeof(Embedding), typeDiscriminator: "floats")] +[JsonDerivedType(typeof(Embedding), typeDiscriminator: "doubles")] +public class Embedding +{ + /// Initializes a new instance of the class. + protected Embedding() + { + } + + /// Gets or sets a timestamp at which the embedding was created. + public DateTimeOffset? CreatedAt { get; set; } + + /// Gets or sets the model ID using in the creation of the embedding. + public string? ModelId { get; set; } + + /// Gets or sets any additional properties associated with the embedding. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs new file mode 100644 index 00000000000..bd010d5f447 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGenerationOptions.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +/// Represents the options for an embedding generation request. +public class EmbeddingGenerationOptions +{ + /// Gets or sets the model ID for the embedding generation request. + public string? ModelId { get; set; } + + /// Gets or sets additional properties for the embedding generation request. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// Produces a clone of the current instance. + /// A clone of the current instance. + /// + /// The clone will have the same values for all properties as the original instance. Any collections, like + /// are shallow-cloned, meaning a new collection instance is created, but any references contained by the collections are shared with the original. + /// + public virtual EmbeddingGenerationOptions Clone() + { + EmbeddingGenerationOptions options = new() + { + ModelId = ModelId, + }; + + if (AdditionalProperties is not null) + { + options.AdditionalProperties = new(AdditionalProperties); + } + + return options; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs new file mode 100644 index 00000000000..fa2a1df4fbe --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorExtensions.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides a collection of static methods for extending instances. +public static class EmbeddingGeneratorExtensions +{ + /// Generates an embedding from the specified . + /// The type from which embeddings will be generated. + /// The numeric type of the embedding data. + /// The embedding generator. + /// A value from which an embedding will be generated. + /// The embedding generation options to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The generated embedding for the specified . + public static Task> GenerateAsync( + this IEmbeddingGenerator generator, + TValue value, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + where TEmbedding : Embedding + { + _ = Throw.IfNull(generator); + _ = Throw.IfNull(value); + + return generator.GenerateAsync([value], options, cancellationToken); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorMetadata.cs new file mode 100644 index 00000000000..39bdd61d3ae --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/EmbeddingGeneratorMetadata.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +/// Provides metadata about an . +public class EmbeddingGeneratorMetadata +{ + /// Initializes a new instance of the class. + /// The name of the embedding generation provider, if applicable. + /// The URL for accessing the embedding generation provider, if applicable. + /// The id of the embedding generation model used, if applicable. + /// The number of dimensions in vectors produced by this generator, if applicable. + public EmbeddingGeneratorMetadata(string? providerName = null, Uri? providerUri = null, string? modelId = null, int? dimensions = null) + { + ModelId = modelId; + ProviderName = providerName; + ProviderUri = providerUri; + Dimensions = dimensions; + } + + /// Gets the name of the embedding generation provider. + public string? ProviderName { get; } + + /// Gets the URL for accessing the embedding generation provider. + public Uri? ProviderUri { get; } + + /// Gets the id of the model used by this embedding generation provider. + /// This may be null if either the name is unknown or there are multiple possible models associated with this instance. + public string? ModelId { get; } + + /// Gets the number of dimensions in the embeddings produced by this instance. + public int? Dimensions { get; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs new file mode 100644 index 00000000000..c80e20dfda4 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/Embedding{T}.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +/// Represents an embedding composed of a vector of values. +/// The type of the values in the embedding vector. +/// Typical values of are , , or Half. +public sealed class Embedding : Embedding +{ + /// Initializes a new instance of the class with the embedding vector. + /// The embedding vector this embedding represents. + public Embedding(ReadOnlyMemory vector) + { + Vector = vector; + } + + /// Gets or sets the embedding vector this embedding represents. + public ReadOnlyMemory Vector { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/GeneratedEmbeddings.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/GeneratedEmbeddings.cs new file mode 100644 index 00000000000..e983dd3b64b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/GeneratedEmbeddings.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents the result of an operation to generate embeddings. +/// Specifies the type of the generated embeddings. +public sealed class GeneratedEmbeddings : IList, IReadOnlyList + where TEmbedding : Embedding +{ + /// The underlying list of embeddings. + private List _embeddings; + + /// Initializes a new instance of the class. + public GeneratedEmbeddings() + { + _embeddings = []; + } + + /// Initializes a new instance of the class with the specified capacity. + /// The number of embeddings that the new list can initially store. + public GeneratedEmbeddings(int capacity) + { + _embeddings = new List(Throw.IfLessThan(capacity, 0)); + } + + /// + /// Initializes a new instance of the class that contains all of the embeddings from the specified collection. + /// + /// The collection whose embeddings are copied to the new list. + public GeneratedEmbeddings(IEnumerable embeddings) + { + _embeddings = new List(Throw.IfNull(embeddings)); + } + + /// Gets or sets usage details for the embeddings' generation. + public UsageDetails? Usage { get; set; } + + /// Gets or sets any additional properties associated with the embeddings. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// + public TEmbedding this[int index] + { + get => _embeddings[index]; + set => _embeddings[index] = value; + } + + /// + public int Count => _embeddings.Count; + + /// + bool ICollection.IsReadOnly => false; + + /// + public void Add(TEmbedding item) => _embeddings.Add(item); + + /// Adds the embeddings from the specified collection to the end of this list. + /// The collection whose elements should be added to this list. + public void AddRange(IEnumerable items) => _embeddings.AddRange(items); + + /// + public void Clear() => _embeddings.Clear(); + + /// + public bool Contains(TEmbedding item) => _embeddings.Contains(item); + + /// + public void CopyTo(TEmbedding[] array, int arrayIndex) => _embeddings.CopyTo(array, arrayIndex); + + /// + public IEnumerator GetEnumerator() => _embeddings.GetEnumerator(); + + /// + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + public int IndexOf(TEmbedding item) => _embeddings.IndexOf(item); + + /// + public void Insert(int index, TEmbedding item) => _embeddings.Insert(index, item); + + /// + public bool Remove(TEmbedding item) => _embeddings.Remove(item); + + /// + public void RemoveAt(int index) => _embeddings.RemoveAt(index); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs new file mode 100644 index 00000000000..6c791ee2bf4 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Embeddings/IEmbeddingGenerator.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +/// Represents a generator of embeddings. +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +public interface IEmbeddingGenerator : IDisposable + where TEmbedding : Embedding +{ + /// Generates embeddings for each of the supplied . + /// The collection of values for which to generate embeddings. + /// The embedding generation options to configure the request. + /// The to monitor for cancellation requests. The default is . + /// The generated embeddings. + Task> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default); + + /// Gets metadata that describes the . + EmbeddingGeneratorMetadata Metadata { get; } + + /// Asks the for an object of type . + /// The type of the object to be retrieved. + /// An optional key that may be used to help identify the target service. + /// The found object, otherwise . + /// + /// The purpose of this method is to allow for the retrieval of strongly-typed services that may be provided by the , + /// including itself or any services it might be wrapping. + /// + TService? GetService(object? key = null) + where TService : class; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs new file mode 100644 index 00000000000..a4b5ecb5378 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; + +namespace Microsoft.Extensions.AI; + +/// Represents a function that can be described to an AI service and invoked. +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public abstract class AIFunction : AITool +{ + /// Gets metadata describing the function. + public abstract AIFunctionMetadata Metadata { get; } + + /// Invokes the and returns its result. + /// The arguments to pass to the function's invocation. + /// The to monitor for cancellation requests. The default is . + /// The result of the function's execution. + public Task InvokeAsync( + IEnumerable>? arguments = null, + CancellationToken cancellationToken = default) + { + arguments ??= EmptyReadOnlyDictionary.Instance; + + return InvokeCoreAsync(arguments, cancellationToken); + } + + /// + public override string ToString() => Metadata.Name; + + /// Invokes the and returns its result. + /// The arguments to pass to the function's invocation. + /// The to monitor for cancellation requests. + /// The result of the function's execution. + protected abstract Task InvokeCoreAsync( + IEnumerable> arguments, + CancellationToken cancellationToken); + + /// Gets the string to display in the debugger for this instance. + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + string.IsNullOrWhiteSpace(Metadata.Description) ? + Metadata.Name : + $"{Metadata.Name} ({Metadata.Description})"; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionMetadata.cs new file mode 100644 index 00000000000..03dac25d15f --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionMetadata.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text.Json; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides read-only metadata for an . +/// +public sealed class AIFunctionMetadata +{ + /// The name of the function. + private string _name = string.Empty; + + /// The description of the function. + private string _description = string.Empty; + + /// The function's parameters. + private IReadOnlyList _parameters = []; + + /// The function's return parameter. + private AIFunctionReturnParameterMetadata _returnParameter = AIFunctionReturnParameterMetadata.Empty; + + /// Optional additional properties in addition to the named properties already available on this class. + private IReadOnlyDictionary _additionalProperties = EmptyReadOnlyDictionary.Instance; + + /// indexed by name, lazily initialized. + private Dictionary? _parametersByName; + + /// Initializes a new instance of the class for a function with the specified name. + /// The name of the function. + /// The was null. + public AIFunctionMetadata(string name) + { + _name = Throw.IfNullOrWhitespace(name); + } + + /// Initializes a new instance of the class as a copy of another . + /// The was null. + /// + /// This creates a shallow clone of . The new instance's and + /// properties will return the same objects as in the original instance. + /// + public AIFunctionMetadata(AIFunctionMetadata metadata) + { + Name = Throw.IfNull(metadata).Name; + Description = metadata.Description; + Parameters = metadata.Parameters; + ReturnParameter = metadata.ReturnParameter; + AdditionalProperties = metadata.AdditionalProperties; + } + + /// Gets the name of the function. + public string Name + { + get => _name; + init => _name = Throw.IfNullOrWhitespace(value); + } + + /// Gets a description of the function, suitable for use in describing the purpose to a model. + [AllowNull] + public string Description + { + get => _description; + init => _description = value ?? string.Empty; + } + + /// Gets the metadata for the parameters to the function. + /// If the function has no parameters, the returned list will be empty. + public IReadOnlyList Parameters + { + get => _parameters; + init => _parameters = Throw.IfNull(value); + } + + /// Gets the for a parameter by its name. + /// The name of the parameter. + /// The corresponding , if found; otherwise, null. + public AIFunctionParameterMetadata? GetParameter(string name) + { + Dictionary? parametersByName = _parametersByName ??= _parameters.ToDictionary(p => p.Name); + + return parametersByName.TryGetValue(name, out AIFunctionParameterMetadata? parameter) ? + parameter : + null; + } + + /// Gets parameter metadata for the return parameter. + /// If the function has no return parameter, the returned value will be a default instance of a . + public AIFunctionReturnParameterMetadata ReturnParameter + { + get => _returnParameter; + init => _returnParameter = Throw.IfNull(value); + } + + /// Gets any additional properties associated with the function. + public IReadOnlyDictionary AdditionalProperties + { + get => _additionalProperties; + init => _additionalProperties = Throw.IfNull(value); + } + + /// Gets a that may be used to marshal function parameters. + public JsonSerializerOptions? JsonSerializerOptions { get; init; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionParameterMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionParameterMetadata.cs new file mode 100644 index 00000000000..b9bd4d83841 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionParameterMetadata.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides read-only metadata for a parameter. +/// +public sealed class AIFunctionParameterMetadata +{ + private string _name; + + /// Initializes a new instance of the class for a parameter with the specified name. + /// The name of the parameter. + /// The was null. + /// The was empty or composed entirely of whitespace. + public AIFunctionParameterMetadata(string name) + { + _name = Throw.IfNullOrWhitespace(name); + } + + /// Initializes a new instance of the class as a copy of another . + /// The was null. + /// This creates a shallow clone of . + public AIFunctionParameterMetadata(AIFunctionParameterMetadata metadata) + { + _ = Throw.IfNull(metadata); + _ = Throw.IfNullOrWhitespace(metadata.Name); + + _name = metadata.Name; + + Description = metadata.Description; + HasDefaultValue = metadata.HasDefaultValue; + DefaultValue = metadata.DefaultValue; + IsRequired = metadata.IsRequired; + ParameterType = metadata.ParameterType; + Schema = metadata.Schema; + } + + /// Gets the name of the parameter. + public string Name + { + get => _name; + init => _name = Throw.IfNullOrWhitespace(value); + } + + /// Gets a description of the parameter, suitable for use in describing the purpose to a model. + public string? Description { get; init; } + + /// Gets a value indicating whether the parameter has a default value. + public bool HasDefaultValue { get; init; } + + /// Gets the default value of the parameter. + public object? DefaultValue { get; init; } + + /// Gets a value indicating whether the parameter is required. + public bool IsRequired { get; init; } + + /// Gets the .NET type of the parameter. + public Type? ParameterType { get; init; } + + /// Gets a JSON Schema describing the parameter's type. + public object? Schema { get; init; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionReturnParameterMetadata.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionReturnParameterMetadata.cs new file mode 100644 index 00000000000..17aec4d2fdb --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionReturnParameterMetadata.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides read-only metadata for a 's return parameter. +/// +public sealed class AIFunctionReturnParameterMetadata +{ + /// Gets an empty return parameter metadata instance. + public static AIFunctionReturnParameterMetadata Empty { get; } = new(); + + /// Initializes a new instance of the class. + public AIFunctionReturnParameterMetadata() + { + } + + /// Initializes a new instance of the class as a copy of another . + public AIFunctionReturnParameterMetadata(AIFunctionReturnParameterMetadata metadata) + { + Description = Throw.IfNull(metadata).Description; + ParameterType = metadata.ParameterType; + Schema = metadata.Schema; + } + + /// Gets a description of the return parameter, suitable for use in describing the purpose to a model. + public string? Description { get; init; } + + /// Gets the .NET type of the return parameter. + public Type? ParameterType { get; init; } + + /// Gets a JSON Schema describing the type of the return parameter. + public object? Schema { get; init; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj new file mode 100644 index 00000000000..4aa2ab89d73 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.csproj @@ -0,0 +1,36 @@ + + + + Microsoft.Extensions.AI + Abstractions for generative AI. + AI + + + + preview + 0 + 0 + + + + $(TargetFrameworks);netstandard2.0 + $(NoWarn);CA2227;CA1034;SA1316;S3253 + true + + + + true + true + true + true + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md new file mode 100644 index 00000000000..eb9d3a28c6f --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/README.md @@ -0,0 +1,481 @@ +# Microsoft.Extensions.AI.Abstractions + +Provides abstractions representing generative AI components. + +## Install the package + +From the command-line: + +```console +dotnet add package Microsoft.Extensions.AI.Abstractions +``` + +Or directly in the C# project file: + +```xml + + + +``` + +## Usage Examples + +### `IChatClient` + +The `IChatClient` interface defines a client abstraction responsible for interacting with AI services that provide chat capabilities. It defines methods for sending and receiving messages comprised of multi-modal content (text, images, audio, etc.), either as a complete set or streamed incrementally. Additionally, it provides metadata information about the client and allows for retrieving strongly-typed services that may be provided by the client or its underlying services. + +#### Sample Implementation + +.NET libraries that provide clients for language models and services may provide an implementation of the `IChatClient` interface. Any consumers of the interface are then able to interoperate seamlessly with these models and services via the abstractions. + +Here is a sample implementation of an `IChatClient` to show the general structure. You can find other concrete implementations in the following packages: + +- [Microsoft.Extensions.AI.AzureAIInference](https://aka.ms/meai-azaiinference-nuget) +- [Microsoft.Extensions.AI.OpenAI](https://aka.ms/meai-openai-nuget) +- [Microsoft.Extensions.AI.Ollama](https://aka.ms/meai-ollama-nuget) + +```csharp +using System.Runtime.CompilerServices; +using Microsoft.Extensions.AI; + +public class SampleChatClient : IChatClient +{ + public ChatClientMetadata Metadata { get; } + + public SampleChatClient(Uri endpoint, string modelId) => + Metadata = new("SampleChatClient", endpoint, modelId); + + public async Task CompleteAsync( + IList chatMessages, + ChatOptions? options = null, + CancellationToken cancellationToken = default) + { + // Simulate some operation. + await Task.Delay(300, cancellationToken); + + // Return a sample chat completion response randomly. + string[] responses = + [ + "This is the first sample response.", + "Here is another example of a response message.", + "This is yet another response message." + ]; + + return new([new ChatMessage() + { + Role = ChatRole.Assistant, + Text = responses[Random.Shared.Next(responses.Length)], + }]); + } + + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, + ChatOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // Simulate streaming by yielding messages one by one. + string[] words = ["This ", "is ", "the ", "response ", "for ", "the ", "request."]; + foreach (string word in words) + { + // Simulate some operation. + await Task.Delay(100, cancellationToken); + + // Yield the next message in the response. + yield return new StreamingChatCompletionUpdate + { + Role = ChatRole.Assistant, + Text = word, + }; + } + } + + public TService? GetService(object? key = null) where TService : class => + this as TService; + + void IDisposable.Dispose() { } +} +``` + +#### Requesting a Chat Completion: `CompleteAsync` + +With an instance of `IChatClient`, the `CompleteAsync` method may be used to send a request. The request is composed of one or more messages, each of which is composed of one or more pieces of content. Accelerator methods exist to simplify common cases, such as constructing a request for a single piece of text content. + +```csharp +using Microsoft.Extensions.AI; + +IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); + +var response = await client.CompleteAsync("What is AI?"); + +Console.WriteLine(response.Message); +``` + +The core `CompleteAsync` method on the `IChatClient` interface accepts a list of messages. This list represents the history of all messages that are part of the conversation. + +```csharp +using Microsoft.Extensions.AI; + +IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); + +Console.WriteLine(await client.CompleteAsync( +[ + new(ChatRole.System, "You are a helpful AI assistant"), + new(ChatRole.User, "What is AI?"), +])); +``` + +#### Requesting a Streaming Chat Completion: `CompleteStreamingAsync` + +The inputs to `CompleteStreamingAsync` are identical to those of `CompleteAsync`. However, rather than returning the complete response as part of a `ChatCompletion` object, the method returns an `IAsyncEnumerable`, providing a stream of updates that together form the single response. + +```csharp +using Microsoft.Extensions.AI; + +IChatClient client = new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"); + +await foreach (var update in client.CompleteStreamingAsync("What is AI?")) +{ + Console.Write(update); +} +``` + +#### Tool calling + +Some models and services support the notion of tool calling, where requests may include information about tools that the model may request be invoked in order to gather additional information, in particular functions. Rather than sending back a response message that represents the final response to the input, the model sends back a request to invoke a given function with a given set of arguments; the client may then find and invoke the relevant function and send back the results to the model (along with all the rest of the history). The abstractions in Microsoft.Extensions.AI include representations for various forms of content that may be included in messages, and this includes representations for these function call requests and results. While it's possible for the consumer of the `IChatClient` to interact with this content directly, `Microsoft.Extensions.AI` supports automating these interactions. It provides an `AIFunction` that represents an invocable function along with metadata for describing the function to the AI model, along with an `AIFunctionFactory` for creating `AIFunction`s to represent .NET methods. It also provides a `FunctionInvokingChatClient` that both is an `IChatClient` and also wraps an `IChatClient`, enabling layering automatic function invocation capabilities around an arbitrary `IChatClient` implementation. + +```csharp +using System.ComponentModel; +using Microsoft.Extensions.AI; + +[Description("Gets the current weather")] +string GetCurrentWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining"; + +IChatClient client = new ChatClientBuilder() + .UseFunctionInvocation() + .Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")); + +var response = client.CompleteStreamingAsync( + "Should I wear a rain coat?", + new() { Tools = [AIFunctionFactory.Create(GetCurrentWeather)] }); + +await foreach (var update in response) +{ + Console.Write(update); +} +``` + +#### Caching + +`Microsoft.Extensions.AI` provides other such delegating `IChatClient` implementations. The `DistributedCachingChatClient` is an `IChatClient` that layers caching around another arbitrary `IChatClient` instance. When a unique chat history that's not been seen before is submitted to the `DistributedCachingChatClient`, it forwards it along to the underlying client, and then caches the response prior to it being forwarded back to the consumer. The next time the same history is submitted, such that a cached response can be found in the cache, the `DistributedCachingChatClient` can return back the cached response rather than needing to forward the request along the pipeline. + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))) + .Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")); + +string[] prompts = ["What is AI?", "What is .NET?", "What is AI?"]; + +foreach (var prompt in prompts) +{ + await foreach (var update in client.CompleteStreamingAsync(prompt)) + { + Console.Write(update); + } + Console.WriteLine(); +} +``` + +#### Telemetry + +Other such delegating chat clients are provided as well. The `OpenTelemetryChatClient`, for example, provides an implementation of the [OpenTelemetry Semantic Conventions for Generative AI systems](https://opentelemetry.io/docs/specs/semconv/gen-ai/). As with the aforementioned `IChatClient` delegators, this implementation layers metrics and spans around other arbitrary `IChatClient` implementations. + +```csharp +using Microsoft.Extensions.AI; +using OpenTelemetry.Trace; + +// Configure OpenTelemetry exporter +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +IChatClient client = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model")); + +Console.WriteLine((await client.CompleteAsync("What is AI?")).Message); +``` + +#### Pipelines of Functionality + +All of these `IChatClient`s may be layered, creating a pipeline of any number of components that all add additional functionality. Such components may come from `Microsoft.Extensions.AI`, may come from other NuGet packages, or may be your own custom implementations that augment the behavior in whatever ways you need. + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenTelemetry.Trace; + +// Configure OpenTelemetry exporter +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +// Explore changing the order of the intermediate "Use" calls to see that impact +// that has on what gets cached, traced, etc. +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))) + .UseFunctionInvocation() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1")); + +ChatOptions options = new() +{ + Tools = [AIFunctionFactory.Create( + () => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining", + name: "GetCurrentWeather", + description: "Gets the current weather")] +}; + +for (int i = 0; i < 3; i++) +{ + List history = + [ + new ChatMessage(ChatRole.System, "You are a helpful AI assistant"), + new ChatMessage(ChatRole.User, "Do I need an umbrella?") + ]; + + Console.WriteLine(await client.CompleteAsync(history, options)); +} +``` + +#### Custom `IChatClient` Middleware + +Anyone can layer in such additional functionality. While it's possible to implement `IChatClient` directly, the `DelegatingChatClient` class is an implementation of the `IChatClient` interface that serves as a base class for creating chat clients that delegate their operations to another `IChatClient` instance. It is designed to facilitate the chaining of multiple clients, allowing calls to be passed through to an underlying client. The class provides default implementations for methods such as `CompleteAsync`, `CompleteStreamingAsync`, and `Dispose`, simply forwarding the calls to the inner client instance. A derived type may then override just the methods it needs to in order to augment the behavior, delegating to the base implementation in order to forward the call along to the wrapped client. This setup is useful for creating flexible and modular chat clients that can be easily extended and composed. + +Here is an example class derived from `DelegatingChatClient` to provide logging functionality: +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using System.Runtime.CompilerServices; +using System.Text.Json; + +public sealed class LoggingChatClient(IChatClient innerClient, ILogger? logger = null) : + DelegatingChatClient(innerClient) +{ + public override async Task CompleteAsync( + IList chatMessages, + ChatOptions? options = null, + CancellationToken cancellationToken = default) + { + logger?.LogTrace("Request: {Messages}", chatMessages); + var chatCompletion = await base.CompleteAsync(chatMessages, options, cancellationToken); + logger?.LogTrace("Response: {Completion}", JsonSerializer.Serialize(chatCompletion)); + return chatCompletion; + } + + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, + ChatOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + logger?.LogTrace("Request: {Messages}", chatMessages); + await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken)) + { + logger?.LogTrace("Response Update: {Update}", JsonSerializer.Serialize(update)); + yield return update; + } + } +} +``` + +This can then be composed as with other `IChatClient` implementations. + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; + +var client = new LoggingChatClient( + new SampleChatClient(new Uri("http://localhost"), "test"), + LoggerFactory.Create(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace)).CreateLogger("AI")); + +await client.CompleteAsync("Hello, world!"); +``` + +#### Dependency Injection + +`IChatClient` implementations will typically be provided to an application via dependency injection (DI). In this example, an `IDistributedCache` is added into the DI container, as is an `IChatClient`. The registration for the `IChatClient` employs a builder that creates a pipeline containing a caching client (which will then use an `IDistributedCache` retrieved from DI) and the sample client. Elsewhere in the app, the injected `IChatClient` may be retrieved and used. + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Options; +using System.Runtime.CompilerServices; + +// App Setup +var builder = Host.CreateApplicationBuilder(); +builder.Services.AddSingleton( + new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); +builder.Services.AddChatClient(b => b + .UseDistributedCache() + .Use(new SampleChatClient(new Uri("http://coolsite.ai"), "my-custom-model"))); +var host = builder.Build(); + +// Elsewhere in the app +var chatClient = host.Services.GetRequiredService(); +Console.WriteLine(await chatClient.CompleteAsync("What is AI?")); +``` + +What instance and configuration is injected may differ based on the current needs of the application, and multiple pipelines may be injected with different keys. + +### IEmbeddingGenerator + +The `IEmbeddingGenerator` interface represents a generic generator of embeddings, where `TInput` is the type of input values being embedded and `TEmbedding` is the type of generated embedding, inheriting from `Embedding`. + +The `Embedding` class provides a base class for embeddings generated by an `IEmbeddingGenerator`. This class is designed to store and manage the metadata and data associated with embeddings. Types derived from `Embedding`, like `Embedding`, then provide the concrete embedding vector data. For example, an `Embedding` exposes a `ReadOnlyMemory Vector { get; }` property for access to its embedding data. + +`IEmbeddingGenerator` defines a method to asynchronously generate embeddings for a collection of input values with optional configuration and cancellation support. Additionally, it provides metadata describing the generator and allows for the retrieval of strongly-typed services that may be provided by the generator or its underlying services. + +#### Sample Implementation + +Here is a sample implementation of an `IEmbeddingGenerator` to show the general structure but that just generates random embedding vectors. You can find actual concrete implementations in the following packages: + +- [Microsoft.Extensions.AI.OpenAI](https://aka.ms/meai-openai-nuget) +- [Microsoft.Extensions.AI.Ollama](https://aka.ms/meai-ollama-nuget) + +```csharp +using Microsoft.Extensions.AI; + +public class SampleEmbeddingGenerator(Uri endpoint, string modelId) : IEmbeddingGenerator> +{ + public EmbeddingGeneratorMetadata Metadata { get; } = new("SampleEmbeddingGenerator", endpoint, modelId); + + public async Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + // Simulate some async operation + await Task.Delay(100, cancellationToken); + + // Create random embeddings + return new GeneratedEmbeddings>( + from value in values + select new Embedding( + Enumerable.Range(0, 384).Select(_ => Random.Shared.NextSingle()).ToArray())); + } + + public TService? GetService(object? key = null) where TService : class => + this as TService; + + void IDisposable.Dispose() { } +} +``` + +#### Creating an embedding: `GenerateAsync` + +The primary operation performed with an `IEmbeddingGenerator` is generating embeddings, which is accomplished with its `GenerateAsync` method. + +```csharp +using Microsoft.Extensions.AI; + +IEmbeddingGenerator> generator = + new SampleEmbeddingGenerator(new Uri("http://coolsite.ai"), "my-custom-model"); + +foreach (var embedding in await generator.GenerateAsync(["What is AI?", "What is .NET?"])) +{ + Console.WriteLine(string.Join(", ", embedding.Vector.ToArray())); +} +``` + +#### Middleware + +As with `IChatClient`, `IEmbeddingGenerator` implementations may be layered. Just as `Microsoft.Extensions.AI` provides delegating implementations of `IChatClient` for caching and telemetry, it does so for `IEmbeddingGenerator` as well. + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenTelemetry.Trace; + +// Configure OpenTelemetry exporter +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +// Explore changing the order of the intermediate "Use" calls to see that impact +// that has on what gets cached, traced, etc. +IEmbeddingGenerator> generator = new EmbeddingGeneratorBuilder>() + .UseDistributedCache(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))) + .UseOpenTelemetry(sourceName) + .Use(new SampleEmbeddingGenerator(new Uri("http://coolsite.ai"), "my-custom-model")); + +var embeddings = await generator.GenerateAsync( +[ + "What is AI?", + "What is .NET?", + "What is AI?" +]); + +foreach (var embedding in embeddings) +{ + Console.WriteLine(string.Join(", ", embedding.Vector.ToArray())); +} +``` + +Also as with `IChatClient`, `IEmbeddingGenerator` enables building custom middleware that extends the functionality of an `IEmbeddingGenerator`. The `DelegatingEmbeddingGenerator` class is an implementation of the `IEmbeddingGenerator` interface that serves as a base class for creating embedding generators which delegate their operations to another `IEmbeddingGenerator` instance. It allows for chaining multiple generators in any order, passing calls through to an underlying generator. The class provides default implementations for methods such as `GenerateAsync` and `Dispose`, which simply forward the calls to the inner generator instance, enabling flexible and modular embedding generation. + +Here is an example implementation of such a delegating embedding generator that logs embedding generation requests: +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; + +public class LoggingEmbeddingGenerator(IEmbeddingGenerator> innerGenerator, ILogger? logger = null) : + DelegatingEmbeddingGenerator>(innerGenerator) +{ + public override Task>> GenerateAsync( + IEnumerable values, + EmbeddingGenerationOptions? options = null, + CancellationToken cancellationToken = default) + { + logger?.LogInformation("Generating embeddings for {Count} values", values.Count()); + return base.GenerateAsync(values, options, cancellationToken); + } +} +``` + +This can then be layered around an arbitrary `IEmbeddingGenerator>` to log all embedding generation operations performed. + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; + +IEmbeddingGenerator> generator = + new LoggingEmbeddingGenerator( + new SampleEmbeddingGenerator(new Uri("http://coolsite.ai"), "my-custom-model"), + LoggerFactory.Create(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace)).CreateLogger("AI")); + +foreach (var embedding in await generator.GenerateAsync(["What is AI?", "What is .NET?"])) +{ + Console.WriteLine(string.Join(", ", embedding.Vector.ToArray())); +} +``` + +## Feedback & Contributing + +We welcome feedback and contributions in [our GitHub repo](https://github.com/dotnet/extensions). diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs new file mode 100644 index 00000000000..f12ed819a6e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/UsageDetails.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides usage details about a request/response. +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public class UsageDetails +{ + /// Gets or sets the number of tokens in the input. + public int? InputTokenCount { get; set; } + + /// Gets or sets the number of tokens in the output. + public int? OutputTokenCount { get; set; } + + /// Gets or sets the total number of tokens used to produce the response. + public int? TotalTokenCount { get; set; } + + /// Gets or sets additional properties for the usage details. + public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// Gets a string representing this instance to display in the debugger. + internal string DebuggerDisplay + { + get + { + List parts = []; + + if (InputTokenCount is int input) + { + parts.Add($"{nameof(InputTokenCount)} = {input}"); + } + + if (OutputTokenCount is int output) + { + parts.Add($"{nameof(OutputTokenCount)} = {output}"); + } + + if (TotalTokenCount is int total) + { + parts.Add($"{nameof(TotalTokenCount)} = {total}"); + } + + if (AdditionalProperties is { } additionalProperties) + { + foreach (var entry in additionalProperties) + { + parts.Add($"{entry.Key} = {entry.Value}"); + } + } + + return string.Join(", ", parts); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs new file mode 100644 index 00000000000..cccd9f04caf --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceChatClient.cs @@ -0,0 +1,495 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Azure.AI.Inference; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable S1135 // Track uses of "TODO" tags +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace Microsoft.Extensions.AI; + +/// An for an Azure AI Inference . +public sealed partial class AzureAIInferenceChatClient : IChatClient +{ + /// The underlying . + private readonly ChatCompletionsClient _chatCompletionsClient; + + /// Initializes a new instance of the class for the specified . + /// The underlying client. + /// The id of the model to use. If null, it may be provided per request via . + public AzureAIInferenceChatClient(ChatCompletionsClient chatCompletionsClient, string? modelId = null) + { + _ = Throw.IfNull(chatCompletionsClient); + if (modelId is not null) + { + _ = Throw.IfNullOrWhitespace(modelId); + } + + _chatCompletionsClient = chatCompletionsClient; + + // https://github.com/Azure/azure-sdk-for-net/issues/46278 + // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + var providerUrl = typeof(ChatCompletionsClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(chatCompletionsClient) as Uri; + + Metadata = new("AzureAIInference", providerUrl, modelId); + } + + /// Gets or sets to use for any serialization activities related to tool call arguments and results. + public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; } + + /// + public ChatClientMetadata Metadata { get; } + + /// + public TService? GetService(object? key = null) + where TService : class => + typeof(TService) == typeof(ChatCompletionsClient) ? (TService?)(object?)_chatCompletionsClient : + this as TService; + + /// + public async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + // Make the call. + ChatCompletions response = (await _chatCompletionsClient.CompleteAsync( + ToAzureAIOptions(chatMessages, options), + cancellationToken: cancellationToken).ConfigureAwait(false)).Value; + + // Create the return message. + List returnMessages = []; + + // Populate its content from those in the response content. + ChatFinishReason? finishReason = null; + foreach (var choice in response.Choices) + { + ChatMessage returnMessage = new() + { + RawRepresentation = choice, + Role = ToChatRole(choice.Message.Role), + AdditionalProperties = new() { [nameof(choice.Index)] = choice.Index }, + }; + + finishReason ??= ToFinishReason(choice.FinishReason); + + if (choice.Message.ToolCalls is { Count: > 0 } toolCalls) + { + foreach (var toolCall in toolCalls) + { + if (toolCall is ChatCompletionsFunctionToolCall ftc && !string.IsNullOrWhiteSpace(ftc.Name)) + { + Dictionary? arguments = FunctionCallHelpers.ParseFunctionCallArguments(ftc.Arguments, out Exception? parsingException); + + returnMessage.Contents.Add(new FunctionCallContent(toolCall.Id, ftc.Name, arguments) + { + ModelId = response.Model, + Exception = parsingException, + RawRepresentation = toolCall + }); + } + } + } + + if (!string.IsNullOrEmpty(choice.Message.Content)) + { + returnMessage.Contents.Add(new TextContent(choice.Message.Content) + { + ModelId = response.Model, + RawRepresentation = choice.Message + }); + } + + returnMessages.Add(returnMessage); + } + + UsageDetails? usage = null; + if (response.Usage is CompletionsUsage completionsUsage) + { + usage = new() + { + InputTokenCount = completionsUsage.PromptTokens, + OutputTokenCount = completionsUsage.CompletionTokens, + TotalTokenCount = completionsUsage.TotalTokens, + }; + } + + // Wrap the content in a ChatCompletion to return. + return new ChatCompletion(returnMessages) + { + RawRepresentation = response, + CompletionId = response.Id, + CreatedAt = response.Created, + ModelId = response.Model, + FinishReason = finishReason, + Usage = usage, + }; + } + + /// + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + Dictionary? functionCallInfos = null; + ChatRole? streamedRole = default; + ChatFinishReason? finishReason = default; + string? completionId = null; + DateTimeOffset? createdAt = null; + string? modelId = null; + string? authorName = null; + + // Process each update as it arrives + var updates = await _chatCompletionsClient.CompleteStreamingAsync(ToAzureAIOptions(chatMessages, options), cancellationToken).ConfigureAwait(false); + await foreach (StreamingChatCompletionsUpdate chatCompletionUpdate in updates.ConfigureAwait(false)) + { + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= chatCompletionUpdate.Role is global::Azure.AI.Inference.ChatRole role ? ToChatRole(role) : null; + finishReason ??= chatCompletionUpdate.FinishReason is CompletionsFinishReason reason ? ToFinishReason(reason) : null; + completionId ??= chatCompletionUpdate.Id; + createdAt ??= chatCompletionUpdate.Created; + modelId ??= chatCompletionUpdate.Model; + authorName ??= chatCompletionUpdate.AuthorName; + + // Create the response content object. + StreamingChatCompletionUpdate completionUpdate = new() + { + AuthorName = authorName, + CompletionId = chatCompletionUpdate.Id, + CreatedAt = chatCompletionUpdate.Created, + FinishReason = finishReason, + RawRepresentation = chatCompletionUpdate, + Role = streamedRole, + }; + + // Transfer over content update items. + if (chatCompletionUpdate.ContentUpdate is string update) + { + completionUpdate.Contents.Add(new TextContent(update) + { + ModelId = modelId, + }); + } + + // Transfer over tool call updates. + if (chatCompletionUpdate.ToolCallUpdate is StreamingFunctionToolCallUpdate toolCallUpdate) + { + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(toolCallUpdate.ToolCallIndex, out FunctionCallInfo? existing)) + { + functionCallInfos[toolCallUpdate.ToolCallIndex] = existing = new(); + } + + existing.CallId ??= toolCallUpdate.Id; + existing.Name ??= toolCallUpdate.Name; + if (toolCallUpdate.ArgumentsUpdate is not null) + { + _ = (existing.Arguments ??= new()).Append(toolCallUpdate.ArgumentsUpdate); + } + } + + // Now yield the item. + yield return completionUpdate; + } + + // TODO: Add usage as content when it's exposed by Azure.AI.Inference. + + // Now that we've received all updates, combine any for function calls into a single item to yield. + if (functionCallInfos is not null) + { + var completionUpdate = new StreamingChatCompletionUpdate + { + AuthorName = authorName, + CompletionId = completionId, + CreatedAt = createdAt, + FinishReason = finishReason, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) + { + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) + { + var arguments = FunctionCallHelpers.ParseFunctionCallArguments( + fci.Arguments?.ToString() ?? string.Empty, + out Exception? parsingException); + + completionUpdate.Contents.Add(new FunctionCallContent(fci.CallId!, fci.Name!, arguments) + { + ModelId = modelId, + Exception = parsingException + }); + } + } + + yield return completionUpdate; + } + } + + /// + void IDisposable.Dispose() + { + // Nothing to dispose. Implementation required for the IChatClient interface. + } + + /// POCO representing function calling info. Used to concatenation information for a single function call from across multiple streaming updates. + private sealed class FunctionCallInfo + { + public string? CallId; + public string? Name; + public StringBuilder? Arguments; + } + + /// Converts an AzureAI role to an Extensions role. + private static ChatRole ToChatRole(global::Azure.AI.Inference.ChatRole role) => + role.Equals(global::Azure.AI.Inference.ChatRole.System) ? ChatRole.System : + role.Equals(global::Azure.AI.Inference.ChatRole.User) ? ChatRole.User : + role.Equals(global::Azure.AI.Inference.ChatRole.Assistant) ? ChatRole.Assistant : + role.Equals(global::Azure.AI.Inference.ChatRole.Tool) ? ChatRole.Tool : + new ChatRole(role.ToString()); + + /// Converts an AzureAI finish reason to an Extensions finish reason. + private static ChatFinishReason? ToFinishReason(CompletionsFinishReason? finishReason) => + finishReason?.ToString() is not string s ? null : + finishReason == CompletionsFinishReason.Stopped ? ChatFinishReason.Stop : + finishReason == CompletionsFinishReason.TokenLimitReached ? ChatFinishReason.Length : + finishReason == CompletionsFinishReason.ContentFiltered ? ChatFinishReason.ContentFilter : + finishReason == CompletionsFinishReason.ToolCalls ? ChatFinishReason.ToolCalls : + new(s); + + /// Converts an extensions options instance to an AzureAI options instance. + private ChatCompletionsOptions ToAzureAIOptions(IList chatContents, ChatOptions? options) + { + ChatCompletionsOptions result = new(ToAzureAIInferenceChatMessages(chatContents)) + { + Model = options?.ModelId ?? Metadata.ModelId ?? throw new InvalidOperationException("No model id was provided when either constructing the client or in the chat options.") + }; + + if (options is not null) + { + result.FrequencyPenalty = options.FrequencyPenalty; + result.MaxTokens = options.MaxOutputTokens; + result.NucleusSamplingFactor = options.TopP; + result.PresencePenalty = options.PresencePenalty; + result.Temperature = options.Temperature; + + if (options.StopSequences is { Count: > 0 } stopSequences) + { + foreach (string stopSequence in stopSequences) + { + result.StopSequences.Add(stopSequence); + } + } + + if (options.AdditionalProperties is { } props) + { + foreach (var prop in props) + { + switch (prop.Key) + { + // These properties are strongly-typed on the ChatCompletionsOptions class. + case nameof(result.Seed) when prop.Value is long seed: + result.Seed = seed; + break; + + // Propagate everything else to the ChatCompletionOptions' AdditionalProperties. + default: + if (prop.Value is not null) + { + result.AdditionalProperties[prop.Key] = BinaryData.FromObjectAsJson(prop.Value, ToolCallJsonSerializerOptions); + } + + break; + } + } + } + + if (options.Tools is { Count: > 0 } tools) + { + foreach (AITool tool in tools) + { + if (tool is AIFunction af) + { + result.Tools.Add(ToAzureAIChatTool(af)); + } + } + + switch (options.ToolMode) + { + case AutoChatToolMode: + result.ToolChoice = ChatCompletionsToolChoice.Auto; + break; + + case RequiredChatToolMode required: + result.ToolChoice = required.RequiredFunctionName is null ? + ChatCompletionsToolChoice.Required : + new ChatCompletionsToolChoice(new FunctionDefinition(required.RequiredFunctionName)); + break; + } + } + + if (options.ResponseFormat is ChatResponseFormatText) + { + result.ResponseFormat = new ChatCompletionsResponseFormatText(); + } + else if (options.ResponseFormat is ChatResponseFormatJson) + { + result.ResponseFormat = new ChatCompletionsResponseFormatJSON(); + } + } + + return result; + } + + /// Converts an Extensions function to an AzureAI chat tool. + private ChatCompletionsFunctionToolDefinition ToAzureAIChatTool(AIFunction aiFunction) + { + BinaryData resultParameters = AzureAIChatToolJson.ZeroFunctionParametersSchema; + + var parameters = aiFunction.Metadata.Parameters; + if (parameters is { Count: > 0 }) + { + AzureAIChatToolJson tool = new(); + + foreach (AIFunctionParameterMetadata parameter in parameters) + { + tool.Properties.Add( + parameter.Name, + FunctionCallHelpers.InferParameterJsonSchema(parameter, aiFunction.Metadata, ToolCallJsonSerializerOptions)); + + if (parameter.IsRequired) + { + tool.Required.Add(parameter.Name); + } + } + + resultParameters = BinaryData.FromBytes( + JsonSerializer.SerializeToUtf8Bytes(tool, JsonContext.Default.AzureAIChatToolJson)); + } + + return new() + { + Name = aiFunction.Metadata.Name, + Description = aiFunction.Metadata.Description, + Parameters = resultParameters, + }; + } + + /// Used to create the JSON payload for an AzureAI chat tool description. + private sealed class AzureAIChatToolJson + { + /// Gets a singleton JSON data for empty parameters. Optimization for the reasonably common case of a parameterless function. + public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray()); + + [JsonPropertyName("type")] + public string Type { get; set; } = "object"; + + [JsonPropertyName("required")] + public List Required { get; set; } = []; + + [JsonPropertyName("properties")] + public Dictionary Properties { get; set; } = []; + } + + /// Converts an Extensions chat message enumerable to an AzureAI chat message enumerable. + private IEnumerable ToAzureAIInferenceChatMessages(IEnumerable inputs) + { + // Maps all of the M.E.AI types to the corresponding AzureAI types. + // Unrecognized content is ignored. + + foreach (ChatMessage input in inputs) + { + if (input.Role == ChatRole.System) + { + yield return new ChatRequestSystemMessage(input.Text); + } + else if (input.Role == ChatRole.Tool) + { + foreach (AIContent item in input.Contents) + { + if (item is FunctionResultContent resultContent) + { + string? result = resultContent.Result as string; + if (result is null && resultContent.Result is not null) + { + try + { + result = FunctionCallHelpers.FormatFunctionResultAsJson(resultContent.Result, ToolCallJsonSerializerOptions); + } + catch (NotSupportedException) + { + // If the type can't be serialized, skip it. + } + } + + yield return new ChatRequestToolMessage(result ?? string.Empty, resultContent.CallId); + } + } + } + else if (input.Role == ChatRole.User) + { + yield return new ChatRequestUserMessage(input.Contents.Select(static (AIContent item) => item switch + { + TextContent textContent => new ChatMessageTextContentItem(textContent.Text), + ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? new ChatMessageImageContentItem(BinaryData.FromBytes(data), imageContent.MediaType) : + imageContent.Uri is string uri ? new ChatMessageImageContentItem(new Uri(uri)) : + (ChatMessageContentItem?)null, + _ => null, + }).Where(c => c is not null)); + } + else if (input.Role == ChatRole.Assistant) + { + Dictionary? toolCalls = null; + + foreach (var content in input.Contents) + { + if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true) + { + string jsonArguments = FunctionCallHelpers.FormatFunctionParametersAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions); + (toolCalls ??= []).Add( + callRequest.CallId, + new ChatCompletionsFunctionToolCall( + callRequest.CallId, + callRequest.Name, + jsonArguments)); + } + } + + ChatRequestAssistantMessage message = new(); + if (toolCalls is not null) + { + foreach (var entry in toolCalls) + { + message.ToolCalls.Add(entry.Value); + } + } + else + { + message.Content = input.Text; + } + + yield return message; + } + } + } + + /// Source-generated JSON type information. + [JsonSerializable(typeof(AzureAIChatToolJson))] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceExtensions.cs new file mode 100644 index 00000000000..d8ba7616316 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceExtensions.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Azure.AI.Inference; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for working with Azure AI Inference. +public static class AzureAIInferenceExtensions +{ + /// Gets an for use with this . + /// The client. + /// The id of the model to use. If null, it may be provided per request via . + /// An that may be used to converse via the . + public static IChatClient AsChatClient(this ChatCompletionsClient chatCompletionsClient, string? modelId = null) => + new AzureAIInferenceChatClient(chatCompletionsClient, modelId); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj new file mode 100644 index 00000000000..d1f802ace8a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.csproj @@ -0,0 +1,43 @@ + + + + Microsoft.Extensions.AI + Implementation of generative AI abstractions for Azure.AI.Inference. + AI + + + + preview + 0 + 0 + + + + $(TargetFrameworks);netstandard2.0 + $(NoWarn);CA1063;CA2227;SA1316;S1067;S1121;S3358 + true + + + + true + true + true + true + true + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.json b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/Microsoft.Extensions.AI.AzureAIInference.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/README.md b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/README.md new file mode 100644 index 00000000000..3fd34c7897b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.AzureAIInference/README.md @@ -0,0 +1,283 @@ +# Microsoft.Extensions.AI.AzureAIInference + +Provides an implementation of the `IChatClient` interface for the `Azure.AI.Inference` package. + +## Install the package + +From the command-line: + +```console +dotnet add package Microsoft.Extensions.AI.AzureAIInference +``` + +Or directly in the C# project file: + +```xml + + + +``` + +## Usage Examples + +### Chat + +```csharp +using Azure; +using Microsoft.Extensions.AI; + +IChatClient client = + new Azure.AI.Inference.ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!)) + .AsChatClient("gpt-4o-mini"); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Chat + Conversation History + +```csharp +using Azure; +using Microsoft.Extensions.AI; + +IChatClient client = + new Azure.AI.Inference.ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!)) + .AsChatClient("gpt-4o-mini"); + +Console.WriteLine(await client.CompleteAsync( +[ + new ChatMessage(ChatRole.System, "You are a helpful AI assistant"), + new ChatMessage(ChatRole.User, "What is AI?"), +])); +``` + +### Chat streaming + +```csharp +using Azure; +using Microsoft.Extensions.AI; + +IChatClient client = + new Azure.AI.Inference.ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!)) + .AsChatClient("gpt-4o-mini"); + +await foreach (var update in client.CompleteStreamingAsync("What is AI?")) +{ + Console.Write(update); +} +``` + +### Tool calling + +```csharp +using System.ComponentModel; +using Azure; +using Microsoft.Extensions.AI; + +IChatClient azureClient = + new Azure.AI.Inference.ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!)) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseFunctionInvocation() + .Use(azureClient); + +ChatOptions chatOptions = new() +{ + Tools = [AIFunctionFactory.Create(GetWeather)] +}; + +await foreach (var message in client.CompleteStreamingAsync("Do I need an umbrella?", chatOptions)) +{ + Console.Write(message); +} + +[Description("Gets the weather")] +static string GetWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining"; +``` + +### Caching + +```csharp +using Azure; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; + +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +IChatClient azureClient = + new Azure.AI.Inference.ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!)) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .Use(azureClient); + +for (int i = 0; i < 3; i++) +{ + await foreach (var message in client.CompleteStreamingAsync("In less than 100 words, what is AI?")) + { + Console.Write(message); + } + + Console.WriteLine(); + Console.WriteLine(); +} +``` + +### Telemetry + +```csharp +using Azure; +using Microsoft.Extensions.AI; +using OpenTelemetry.Trace; + +// Configure OpenTelemetry exporter +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +IChatClient azureClient = + new Azure.AI.Inference.ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!)) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(azureClient); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Telemetry, Caching, and Tool Calling + +```csharp +using System.ComponentModel; +using Azure; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenTelemetry.Trace; + +// Configure telemetry +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +// Configure caching +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +// Configure tool calling +var chatOptions = new ChatOptions +{ + Tools = [AIFunctionFactory.Create(GetPersonAge)] +}; + +IChatClient azureClient = + new Azure.AI.Inference.ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!)) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .UseFunctionInvocation() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(azureClient); + +for (int i = 0; i < 3; i++) +{ + Console.WriteLine(await client.CompleteAsync("How much older is Alice than Bob?", chatOptions)); +} + +[Description("Gets the age of a person specified by name.")] +static int GetPersonAge(string personName) => + personName switch + { + "Alice" => 42, + "Bob" => 35, + _ => 26, + }; +``` + +### Dependency Injection + +```csharp +using Azure; +using Azure.AI.Inference; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +// App Setup +var builder = Host.CreateApplicationBuilder(); +builder.Services.AddSingleton( + new ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(Environment.GetEnvironmentVariable("GH_TOKEN")!))); +builder.Services.AddSingleton(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); +builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace)); + +builder.Services.AddChatClient(b => b + .UseDistributedCache() + .UseLogging() + .Use(b.Services.GetRequiredService().AsChatClient("gpt-4o-mini"))); + +var app = builder.Build(); + +// Elsewhere in the app +var chatClient = app.Services.GetRequiredService(); +Console.WriteLine(await chatClient.CompleteAsync("What is AI?")); +``` + +### Minimal Web API + +```csharp +using Azure; +using Azure.AI.Inference; +using Microsoft.Extensions.AI; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddSingleton(new ChatCompletionsClient( + new("https://models.inference.ai.azure.com"), + new AzureKeyCredential(builder.Configuration["GH_TOKEN"]!))); + +builder.Services.AddChatClient(b => + b.Use(b.Services.GetRequiredService().AsChatClient("gpt-4o-mini"))); + +var app = builder.Build(); + +app.MapPost("/chat", async (IChatClient client, string message) => +{ + var response = await client.CompleteAsync(message); + return response.Message; +}); + +app.Run(); +``` + +## Feedback & Contributing + +We welcome feedback and contributions in [our GitHub repo](https://github.com/dotnet/extensions). diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/JsonContext.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/JsonContext.cs new file mode 100644 index 00000000000..6de0144c7cf --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/JsonContext.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.SnakeCaseLower, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] +[JsonSerializable(typeof(OllamaChatRequest))] +[JsonSerializable(typeof(OllamaChatRequestMessage))] +[JsonSerializable(typeof(OllamaChatResponse))] +[JsonSerializable(typeof(OllamaChatResponseMessage))] +[JsonSerializable(typeof(OllamaFunctionCallContent))] +[JsonSerializable(typeof(OllamaFunctionResultContent))] +[JsonSerializable(typeof(OllamaFunctionTool))] +[JsonSerializable(typeof(OllamaFunctionToolCall))] +[JsonSerializable(typeof(OllamaFunctionToolParameter))] +[JsonSerializable(typeof(OllamaFunctionToolParameters))] +[JsonSerializable(typeof(OllamaRequestOptions))] +[JsonSerializable(typeof(OllamaTool))] +[JsonSerializable(typeof(OllamaToolCall))] +[JsonSerializable(typeof(OllamaEmbeddingRequest))] +[JsonSerializable(typeof(OllamaEmbeddingResponse))] +internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj new file mode 100644 index 00000000000..ac0abe33c10 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.csproj @@ -0,0 +1,47 @@ + + + + Microsoft.Extensions.AI + Implementation of generative AI abstractions for Ollama. + AI + + + + preview + 0 + 0 + + + + $(TargetFrameworks);netstandard2.0 + $(NoWarn);CA2227;SA1316;S1121;EA0002 + true + + + + true + true + true + true + true + true + + + + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.json b/src/Libraries/Microsoft.Extensions.AI.Ollama/Microsoft.Extensions.AI.Ollama.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs new file mode 100644 index 00000000000..61827d45cc9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs @@ -0,0 +1,408 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Net.Http; +using System.Net.Http.Json; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable EA0011 // Consider removing unnecessary conditional access operator (?) + +namespace Microsoft.Extensions.AI; + +/// An for Ollama. +public sealed class OllamaChatClient : IChatClient +{ + /// The api/chat endpoint URI. + private readonly Uri _apiChatEndpoint; + + /// The to use for sending requests. + private readonly HttpClient _httpClient; + + /// Initializes a new instance of the class. + /// The endpoint URI where Ollama is hosted. + /// + /// The id of the model to use. This may also be overridden per request via . + /// Either this parameter or must provide a valid model id. + /// + /// An instance to use for HTTP operations. + public OllamaChatClient(Uri endpoint, string? modelId = null, HttpClient? httpClient = null) + { + _ = Throw.IfNull(endpoint); + if (modelId is not null) + { + _ = Throw.IfNullOrWhitespace(modelId); + } + + _apiChatEndpoint = new Uri(endpoint, "api/chat"); + _httpClient = httpClient ?? OllamaUtilities.SharedClient; + Metadata = new("ollama", endpoint, modelId); + } + + /// + public ChatClientMetadata Metadata { get; } + + /// Gets or sets to use for any serialization activities related to tool call arguments and results. + public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; } + + /// + public async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + using var httpResponse = await _httpClient.PostAsJsonAsync( + _apiChatEndpoint, + ToOllamaChatRequest(chatMessages, options, stream: false), + JsonContext.Default.OllamaChatRequest, + cancellationToken).ConfigureAwait(false); + + var response = (await httpResponse.Content.ReadFromJsonAsync( + JsonContext.Default.OllamaChatResponse, + cancellationToken).ConfigureAwait(false))!; + + if (!string.IsNullOrEmpty(response.Error)) + { + throw new InvalidOperationException($"Ollama error: {response.Error}"); + } + + return new([FromOllamaMessage(response.Message!)]) + { + CompletionId = response.CreatedAt, + ModelId = response.Model ?? options?.ModelId ?? Metadata.ModelId, + CreatedAt = DateTimeOffset.TryParse(response.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, + AdditionalProperties = ParseOllamaChatResponseProps(response), + FinishReason = ToFinishReason(response), + Usage = ParseOllamaChatResponseUsage(response), + }; + } + + /// + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + if (options?.Tools is { Count: > 0 }) + { + // We can actually make it work by using the /generate endpoint like the eShopSupport sample does, + // but it's complicated. Really it should be Ollama's job to support this. + throw new NotSupportedException( + "Currently, Ollama does not support function calls in streaming mode. " + + "See Ollama docs at https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1 to see whether support has since been added."); + } + + using HttpRequestMessage request = new(HttpMethod.Post, _apiChatEndpoint) + { + Content = JsonContent.Create(ToOllamaChatRequest(chatMessages, options, stream: true), JsonContext.Default.OllamaChatRequest) + }; + using var httpResponse = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false); + using var httpResponseStream = await httpResponse.Content +#if NET + .ReadAsStreamAsync(cancellationToken) +#else + .ReadAsStreamAsync() +#endif + .ConfigureAwait(false); + + await foreach (OllamaChatResponse? chunk in JsonSerializer.DeserializeAsyncEnumerable( + httpResponseStream, + JsonContext.Default.OllamaChatResponse, + topLevelValues: true, + cancellationToken).ConfigureAwait(false)) + { + if (chunk is null) + { + continue; + } + + StreamingChatCompletionUpdate update = new() + { + Role = chunk.Message?.Role is not null ? new ChatRole(chunk.Message.Role) : null, + CreatedAt = DateTimeOffset.TryParse(chunk.CreatedAt, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateTimeOffset createdAt) ? createdAt : null, + AdditionalProperties = ParseOllamaChatResponseProps(chunk), + FinishReason = ToFinishReason(chunk), + }; + + string? modelId = chunk.Model ?? Metadata.ModelId; + + if (chunk.Message is { } message) + { + update.Contents.Add(new TextContent(message.Content) { ModelId = modelId }); + } + + if (ParseOllamaChatResponseUsage(chunk) is { } usage) + { + update.Contents.Add(new UsageContent(usage) { ModelId = modelId }); + } + + yield return update; + } + } + + /// + public TService? GetService(object? key = null) + where TService : class + => key is null ? this as TService : null; + + /// + public void Dispose() + { + if (_httpClient != OllamaUtilities.SharedClient) + { + _httpClient.Dispose(); + } + } + + private static UsageDetails? ParseOllamaChatResponseUsage(OllamaChatResponse response) + { + if (response.PromptEvalCount is not null || response.EvalCount is not null) + { + return new() + { + InputTokenCount = response.PromptEvalCount, + OutputTokenCount = response.EvalCount, + TotalTokenCount = response.PromptEvalCount.GetValueOrDefault() + response.EvalCount.GetValueOrDefault(), + }; + } + + return null; + } + + private static AdditionalPropertiesDictionary? ParseOllamaChatResponseProps(OllamaChatResponse response) + { + AdditionalPropertiesDictionary? metadata = null; + + OllamaUtilities.TransferNanosecondsTime(response, static r => r.LoadDuration, "load_duration", ref metadata); + OllamaUtilities.TransferNanosecondsTime(response, static r => r.TotalDuration, "total_duration", ref metadata); + OllamaUtilities.TransferNanosecondsTime(response, static r => r.PromptEvalDuration, "prompt_eval_duration", ref metadata); + OllamaUtilities.TransferNanosecondsTime(response, static r => r.EvalDuration, "eval_duration", ref metadata); + + return metadata; + } + + private static ChatFinishReason? ToFinishReason(OllamaChatResponse response) => + response.DoneReason switch + { + null => null, + "length" => ChatFinishReason.Length, + "stop" => ChatFinishReason.Stop, + _ => new ChatFinishReason(response.DoneReason), + }; + + private static ChatMessage FromOllamaMessage(OllamaChatResponseMessage message) + { + List contents = []; + + // Add any tool calls. + if (message.ToolCalls is { Length: > 0 }) + { + foreach (var toolCall in message.ToolCalls) + { + if (toolCall.Function is { } function) + { + var id = Guid.NewGuid().ToString().Substring(0, 8); + contents.Add(new FunctionCallContent(id, function.Name, function.Arguments)); + } + } + } + + // Ollama frequently sends back empty content with tool calls. Rather than always adding an empty + // content, we only add the content if either it's not empty or there weren't any tool calls. + if (message.Content?.Length > 0 || contents.Count == 0) + { + contents.Insert(0, new TextContent(message.Content)); + } + + return new ChatMessage(new(message.Role), contents); + } + + private OllamaChatRequest ToOllamaChatRequest(IList chatMessages, ChatOptions? options, bool stream) + { + OllamaChatRequest request = new() + { + Format = options?.ResponseFormat is ChatResponseFormatJson ? "json" : null, + Messages = chatMessages.SelectMany(ToOllamaChatRequestMessages).ToArray(), + Model = options?.ModelId ?? Metadata.ModelId ?? string.Empty, + Stream = stream, + Tools = options?.Tools is { Count: > 0 } tools ? tools.OfType().Select(ToOllamaTool) : null, + }; + + if (options is not null) + { + TransferMetadataValue(nameof(OllamaRequestOptions.embedding_only), (options, value) => options.embedding_only = value); + TransferMetadataValue(nameof(OllamaRequestOptions.f16_kv), (options, value) => options.f16_kv = value); + TransferMetadataValue(nameof(OllamaRequestOptions.logits_all), (options, value) => options.logits_all = value); + TransferMetadataValue(nameof(OllamaRequestOptions.low_vram), (options, value) => options.low_vram = value); + TransferMetadataValue(nameof(OllamaRequestOptions.main_gpu), (options, value) => options.main_gpu = value); + TransferMetadataValue(nameof(OllamaRequestOptions.min_p), (options, value) => options.min_p = value); + TransferMetadataValue(nameof(OllamaRequestOptions.mirostat), (options, value) => options.mirostat = value); + TransferMetadataValue(nameof(OllamaRequestOptions.mirostat_eta), (options, value) => options.mirostat_eta = value); + TransferMetadataValue(nameof(OllamaRequestOptions.mirostat_tau), (options, value) => options.mirostat_tau = value); + TransferMetadataValue(nameof(OllamaRequestOptions.num_batch), (options, value) => options.num_batch = value); + TransferMetadataValue(nameof(OllamaRequestOptions.num_ctx), (options, value) => options.num_ctx = value); + TransferMetadataValue(nameof(OllamaRequestOptions.num_gpu), (options, value) => options.num_gpu = value); + TransferMetadataValue(nameof(OllamaRequestOptions.num_keep), (options, value) => options.num_keep = value); + TransferMetadataValue(nameof(OllamaRequestOptions.num_thread), (options, value) => options.num_thread = value); + TransferMetadataValue(nameof(OllamaRequestOptions.numa), (options, value) => options.numa = value); + TransferMetadataValue(nameof(OllamaRequestOptions.penalize_newline), (options, value) => options.penalize_newline = value); + TransferMetadataValue(nameof(OllamaRequestOptions.repeat_last_n), (options, value) => options.repeat_last_n = value); + TransferMetadataValue(nameof(OllamaRequestOptions.repeat_penalty), (options, value) => options.repeat_penalty = value); + TransferMetadataValue(nameof(OllamaRequestOptions.seed), (options, value) => options.seed = value); + TransferMetadataValue(nameof(OllamaRequestOptions.tfs_z), (options, value) => options.tfs_z = value); + TransferMetadataValue(nameof(OllamaRequestOptions.top_k), (options, value) => options.top_k = value); + TransferMetadataValue(nameof(OllamaRequestOptions.typical_p), (options, value) => options.typical_p = value); + TransferMetadataValue(nameof(OllamaRequestOptions.use_mmap), (options, value) => options.use_mmap = value); + TransferMetadataValue(nameof(OllamaRequestOptions.use_mlock), (options, value) => options.use_mlock = value); + TransferMetadataValue(nameof(OllamaRequestOptions.vocab_only), (options, value) => options.vocab_only = value); + + if (options.FrequencyPenalty is float frequencyPenalty) + { + (request.Options ??= new()).frequency_penalty = frequencyPenalty; + } + + if (options.MaxOutputTokens is int maxOutputTokens) + { + (request.Options ??= new()).num_predict = maxOutputTokens; + } + + if (options.PresencePenalty is float presencePenalty) + { + (request.Options ??= new()).presence_penalty = presencePenalty; + } + + if (options.StopSequences is { Count: > 0 }) + { + (request.Options ??= new()).stop = [.. options.StopSequences]; + } + + if (options.Temperature is float temperature) + { + (request.Options ??= new()).temperature = temperature; + } + + if (options.TopP is float topP) + { + (request.Options ??= new()).top_p = topP; + } + } + + return request; + + void TransferMetadataValue(string propertyName, Action setOption) + { + if (options.AdditionalProperties?.TryGetConvertedValue(propertyName, out T? t) is true) + { + request.Options ??= new(); + setOption(request.Options, t); + } + } + } + + private IEnumerable ToOllamaChatRequestMessages(ChatMessage content) + { + // In general, we return a single request message for each understood content item. + // However, various image models expect both text and images in the same request message. + // To handle that, attach images to a previous text message if one exists. + + OllamaChatRequestMessage? currentTextMessage = null; + foreach (var item in content.Contents) + { + if (currentTextMessage is not null && item is not ImageContent) + { + yield return currentTextMessage; + currentTextMessage = null; + } + + switch (item) + { + case TextContent textContent: + currentTextMessage = new OllamaChatRequestMessage + { + Role = content.Role.Value, + Content = textContent.Text ?? string.Empty, + }; + break; + + case ImageContent imageContent when imageContent.Data is not null: + IList images = currentTextMessage?.Images ?? []; + images.Add(Convert.ToBase64String(imageContent.Data.Value +#if NET + .Span)); +#else + .ToArray())); +#endif + + if (currentTextMessage is not null) + { + currentTextMessage.Images = images; + } + else + { + yield return new OllamaChatRequestMessage + { + Role = content.Role.Value, + Images = images, + }; + } + + break; + + case FunctionCallContent fcc: + yield return new OllamaChatRequestMessage + { + Role = "assistant", + Content = JsonSerializer.Serialize(new OllamaFunctionCallContent + { + CallId = fcc.CallId, + Name = fcc.Name, + Arguments = FunctionCallHelpers.FormatFunctionParametersAsJsonElement(fcc.Arguments, ToolCallJsonSerializerOptions), + }, JsonContext.Default.OllamaFunctionCallContent) + }; + break; + + case FunctionResultContent frc: + JsonElement jsonResult = FunctionCallHelpers.FormatFunctionResultAsJsonElement(frc.Result, ToolCallJsonSerializerOptions); + yield return new OllamaChatRequestMessage + { + Role = "tool", + Content = JsonSerializer.Serialize(new OllamaFunctionResultContent + { + CallId = frc.CallId, + Result = jsonResult, + }, JsonContext.Default.OllamaFunctionResultContent) + }; + break; + } + } + + if (currentTextMessage is not null) + { + yield return currentTextMessage; + } + } + + private OllamaTool ToOllamaTool(AIFunction function) => new() + { + Type = "function", + Function = new OllamaFunctionTool + { + Name = function.Metadata.Name, + Description = function.Metadata.Description, + Parameters = new OllamaFunctionToolParameters + { + Properties = function.Metadata.Parameters.ToDictionary( + p => p.Name, + p => FunctionCallHelpers.InferParameterJsonSchema(p, function.Metadata, ToolCallJsonSerializerOptions)), + Required = function.Metadata.Parameters.Where(p => p.IsRequired).Select(p => p.Name).ToList(), + }, + } + }; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatRequest.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatRequest.cs new file mode 100644 index 00000000000..5d2f63ddfe5 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatRequest.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaChatRequest +{ + public required string Model { get; set; } + public required OllamaChatRequestMessage[] Messages { get; set; } + public string? Format { get; set; } + public bool Stream { get; set; } + public IEnumerable? Tools { get; set; } + public OllamaRequestOptions? Options { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatRequestMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatRequestMessage.cs new file mode 100644 index 00000000000..5a377b1eb34 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatRequestMessage.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaChatRequestMessage +{ + public required string Role { get; set; } + public string? Content { get; set; } + public IList? Images { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatResponse.cs new file mode 100644 index 00000000000..8c39f9ab598 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatResponse.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaChatResponse +{ + public string? Model { get; set; } + public string? CreatedAt { get; set; } + public long? TotalDuration { get; set; } + public long? LoadDuration { get; set; } + public string? DoneReason { get; set; } + public int? PromptEvalCount { get; set; } + public long? PromptEvalDuration { get; set; } + public int? EvalCount { get; set; } + public long? EvalDuration { get; set; } + public OllamaChatResponseMessage? Message { get; set; } + public bool Done { get; set; } + public string? Error { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatResponseMessage.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatResponseMessage.cs new file mode 100644 index 00000000000..bf73c08d793 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatResponseMessage.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaChatResponseMessage +{ + public required string Role { get; set; } + public required string Content { get; set; } + public OllamaToolCall[]? ToolCalls { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs new file mode 100644 index 00000000000..b0ecf08895c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingGenerator.cs @@ -0,0 +1,137 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Net.Http.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// An for Ollama. +public sealed class OllamaEmbeddingGenerator : IEmbeddingGenerator> +{ + /// The api/embeddings endpoint URI. + private readonly Uri _apiEmbeddingsEndpoint; + + /// The to use for sending requests. + private readonly HttpClient _httpClient; + + /// Initializes a new instance of the class. + /// The endpoint URI where Ollama is hosted. + /// + /// The id of the model to use. This may also be overridden per request via . + /// Either this parameter or must provide a valid model id. + /// + /// An instance to use for HTTP operations. + public OllamaEmbeddingGenerator(Uri endpoint, string? modelId = null, HttpClient? httpClient = null) + { + _ = Throw.IfNull(endpoint); + if (modelId is not null) + { + _ = Throw.IfNullOrWhitespace(modelId); + } + + _apiEmbeddingsEndpoint = new Uri(endpoint, "api/embed"); + _httpClient = httpClient ?? OllamaUtilities.SharedClient; + Metadata = new("ollama", endpoint, modelId); + } + + /// + public EmbeddingGeneratorMetadata Metadata { get; } + + /// + public TService? GetService(object? key = null) + where TService : class + => key is null ? this as TService : null; + + /// + public void Dispose() + { + if (_httpClient != OllamaUtilities.SharedClient) + { + _httpClient.Dispose(); + } + } + + /// + public async Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(values); + + // Create request. + string[] inputs = values.ToArray(); + string? requestModel = options?.ModelId ?? Metadata.ModelId; + var request = new OllamaEmbeddingRequest + { + Model = requestModel ?? string.Empty, + Input = inputs, + }; + + if (options?.AdditionalProperties is { } requestProps) + { + if (requestProps.TryGetConvertedValue("keep_alive", out long keepAlive)) + { + request.KeepAlive = keepAlive; + } + + if (requestProps.TryGetConvertedValue("truncate", out bool truncate)) + { + request.Truncate = truncate; + } + } + + // Send request and get response. + var httpResponse = await _httpClient.PostAsJsonAsync( + _apiEmbeddingsEndpoint, + request, + JsonContext.Default.OllamaEmbeddingRequest, + cancellationToken).ConfigureAwait(false); + + var response = (await httpResponse.Content.ReadFromJsonAsync( + JsonContext.Default.OllamaEmbeddingResponse, + cancellationToken).ConfigureAwait(false))!; + + // Validate response. + if (!string.IsNullOrEmpty(response.Error)) + { + throw new InvalidOperationException($"Ollama error: {response.Error}"); + } + + if (response.Embeddings is null || response.Embeddings.Length != inputs.Length) + { + throw new InvalidOperationException($"Ollama generated {response.Embeddings?.Length ?? 0} embeddings but {inputs.Length} were expected."); + } + + // Convert response into result objects. + AdditionalPropertiesDictionary? responseProps = null; + OllamaUtilities.TransferNanosecondsTime(response, r => r.TotalDuration, "total_duration", ref responseProps); + OllamaUtilities.TransferNanosecondsTime(response, r => r.LoadDuration, "load_duration", ref responseProps); + + UsageDetails? usage = null; + if (response.PromptEvalCount is int tokens) + { + usage = new() + { + InputTokenCount = tokens, + TotalTokenCount = tokens, + }; + } + + return new(response.Embeddings.Select(e => + new Embedding(e) + { + CreatedAt = DateTimeOffset.UtcNow, + ModelId = response.Model ?? requestModel, + })) + { + Usage = usage, + AdditionalProperties = responseProps, + }; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingRequest.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingRequest.cs new file mode 100644 index 00000000000..07e3530b8ed --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingRequest.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaEmbeddingRequest +{ + public required string Model { get; set; } + public required string[] Input { get; set; } + public OllamaRequestOptions? Options { get; set; } + public bool? Truncate { get; set; } + public long? KeepAlive { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingResponse.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingResponse.cs new file mode 100644 index 00000000000..c4fd2cde87c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaEmbeddingResponse.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaEmbeddingResponse +{ + [JsonPropertyName("model")] + public string? Model { get; set; } + [JsonPropertyName("embeddings")] + public float[][]? Embeddings { get; set; } + [JsonPropertyName("total_duration")] + public long? TotalDuration { get; set; } + [JsonPropertyName("load_duration")] + public long? LoadDuration { get; set; } + [JsonPropertyName("prompt_eval_count")] + public int? PromptEvalCount { get; set; } + public string? Error { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionCallContent.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionCallContent.cs new file mode 100644 index 00000000000..f518413586a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionCallContent.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaFunctionCallContent +{ + public string? CallId { get; set; } + public string? Name { get; set; } + public JsonElement Arguments { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionResultContent.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionResultContent.cs new file mode 100644 index 00000000000..ba3eab607b8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionResultContent.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaFunctionResultContent +{ + public string? CallId { get; set; } + public JsonElement Result { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionTool.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionTool.cs new file mode 100644 index 00000000000..880e37bec2a --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionTool.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaFunctionTool +{ + public required string Name { get; set; } + public required string Description { get; set; } + public required OllamaFunctionToolParameters Parameters { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolCall.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolCall.cs new file mode 100644 index 00000000000..c94d41bd3f3 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolCall.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaFunctionToolCall +{ + public required string Name { get; set; } + public IDictionary? Arguments { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolParameter.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolParameter.cs new file mode 100644 index 00000000000..77ba2a5561c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolParameter.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaFunctionToolParameter +{ + public string? Type { get; set; } + public string? Description { get; set; } + public IEnumerable? Enum { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolParameters.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolParameters.cs new file mode 100644 index 00000000000..1e01d4d5d62 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaFunctionToolParameters.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json; + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaFunctionToolParameters +{ + public string Type { get; set; } = "object"; + public required IDictionary Properties { get; set; } + public required IList Required { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaRequestOptions.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaRequestOptions.cs new file mode 100644 index 00000000000..cc8b548c1a1 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaRequestOptions.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +#pragma warning disable IDE1006 // Naming Styles + +internal sealed class OllamaRequestOptions +{ + public bool? embedding_only { get; set; } + public bool? f16_kv { get; set; } + public float? frequency_penalty { get; set; } + public bool? logits_all { get; set; } + public bool? low_vram { get; set; } + public int? main_gpu { get; set; } + public float? min_p { get; set; } + public int? mirostat { get; set; } + public float? mirostat_eta { get; set; } + public float? mirostat_tau { get; set; } + public int? num_batch { get; set; } + public int? num_ctx { get; set; } + public int? num_gpu { get; set; } + public int? num_keep { get; set; } + public int? num_predict { get; set; } + public int? num_thread { get; set; } + public bool? numa { get; set; } + public bool? penalize_newline { get; set; } + public float? presence_penalty { get; set; } + public int? repeat_last_n { get; set; } + public float? repeat_penalty { get; set; } + public long? seed { get; set; } + public string[]? stop { get; set; } + public float? temperature { get; set; } + public float? tfs_z { get; set; } + public int? top_k { get; set; } + public float? top_p { get; set; } + public float? typical_p { get; set; } + public bool? use_mlock { get; set; } + public bool? use_mmap { get; set; } + public bool? vocab_only { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaTool.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaTool.cs new file mode 100644 index 00000000000..457793dc476 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaTool.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaTool +{ + public required string Type { get; set; } + public required OllamaFunctionTool Function { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaToolCall.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaToolCall.cs new file mode 100644 index 00000000000..a00d0e0e290 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaToolCall.cs @@ -0,0 +1,9 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +internal sealed class OllamaToolCall +{ + public OllamaFunctionToolCall? Function { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaUtilities.cs b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaUtilities.cs new file mode 100644 index 00000000000..ba823cde7f8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaUtilities.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Net.Http; +using System.Threading; + +namespace Microsoft.Extensions.AI; + +internal static class OllamaUtilities +{ + /// Gets a singleton used when no other instance is supplied. + public static HttpClient SharedClient { get; } = new() + { + // Expected use is localhost access for non-production use. Typical production use should supply + // an HttpClient configured with whatever more robust resilience policy / handlers are appropriate. + Timeout = Timeout.InfiniteTimeSpan, + }; + + public static void TransferNanosecondsTime(TResponse response, Func getNanoseconds, string key, ref AdditionalPropertiesDictionary? metadata) + { + if (getNanoseconds(response) is long duration) + { + try + { + const double NanosecondsPerMillisecond = 1_000_000; + (metadata ??= [])[key] = TimeSpan.FromMilliseconds(duration / NanosecondsPerMillisecond); + } + catch (OverflowException) + { + // Ignore options that don't convert + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.Ollama/README.md b/src/Libraries/Microsoft.Extensions.AI.Ollama/README.md new file mode 100644 index 00000000000..ef8c60ff7b2 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.Ollama/README.md @@ -0,0 +1,285 @@ +# Microsoft.Extensions.AI.Ollama + +Provides an implementation of the `IChatClient` interface for Ollama. + +## Install the package + +From the command-line: + +```console +dotnet add package Microsoft.Extensions.AI.Ollama +``` + +Or directly in the C# project file: + +```xml + + + +``` + +## Usage Examples + +### Chat + +```csharp +using Microsoft.Extensions.AI; + +IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Chat + Conversation History + +```csharp +using Microsoft.Extensions.AI; + +IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"); + +Console.WriteLine(await client.CompleteAsync( +[ + new ChatMessage(ChatRole.System, "You are a helpful AI assistant"), + new ChatMessage(ChatRole.User, "What is AI?"), +])); +``` + +### Chat Streaming + +```csharp +using Microsoft.Extensions.AI; + +IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"); + +await foreach (var update in client.CompleteStreamingAsync("What is AI?")) +{ + Console.Write(update); +} +``` + +### Tool Calling + +Known limitations: + +- Only a subset of models provided by Ollama support tool calling. +- Tool calling is currently not supported with streaming requests. + +```csharp +using System.ComponentModel; +using Microsoft.Extensions.AI; + +IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"); + +IChatClient client = new ChatClientBuilder() + .UseFunctionInvocation() + .Use(ollamaClient); + +ChatOptions chatOptions = new() +{ + Tools = [AIFunctionFactory.Create(GetWeather)] +}; + +Console.WriteLine(await client.CompleteAsync("Do I need an umbrella?", chatOptions)); + +[Description("Gets the weather")] +static string GetWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining"; +``` + +### Caching + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; + +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .Use(ollamaClient); + +for (int i = 0; i < 3; i++) +{ + await foreach (var message in client.CompleteStreamingAsync("In less than 100 words, what is AI?")) + { + Console.Write(message); + } + + Console.WriteLine(); + Console.WriteLine(); +} +``` + +### Telemetry + +```csharp +using Microsoft.Extensions.AI; +using OpenTelemetry.Trace; + +// Configure OpenTelemetry exporter +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"); + +IChatClient client = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(ollamaClient); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Telemetry, Caching, and Tool Calling + +```csharp +using System.ComponentModel; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenTelemetry.Trace; + +// Configure telemetry +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +// Configure caching +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +// Configure tool calling +var chatOptions = new ChatOptions +{ + Tools = [AIFunctionFactory.Create(GetPersonAge)] +}; + +IChatClient ollamaClient = new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .UseFunctionInvocation() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(ollamaClient); + +for (int i = 0; i < 3; i++) +{ + Console.WriteLine(await client.CompleteAsync("How much older is Alice than Bob?", chatOptions)); +} + +[Description("Gets the age of a person specified by name.")] +static int GetPersonAge(string personName) => + personName switch + { + "Alice" => 42, + "Bob" => 35, + _ => 26, + }; +``` + +### Text embedding generation + +```csharp +using Microsoft.Extensions.AI; + +IEmbeddingGenerator> generator = + new OllamaEmbeddingGenerator(new Uri("http://localhost:11434/"), "all-minilm"); + +var embeddings = await generator.GenerateAsync("What is AI?"); + +Console.WriteLine(string.Join(", ", embeddings[0].Vector.ToArray())); +``` + +### Text embedding generation with caching + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; + +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +IEmbeddingGenerator> ollamaGenerator = + new OllamaEmbeddingGenerator(new Uri("http://localhost:11434/"), "all-minilm"); + +IEmbeddingGenerator> generator = new EmbeddingGeneratorBuilder>() + .UseDistributedCache(cache) + .Use(ollamaGenerator); + +foreach (var prompt in new[] { "What is AI?", "What is .NET?", "What is AI?" }) +{ + var embeddings = await generator.GenerateAsync(prompt); + + Console.WriteLine(string.Join(", ", embeddings[0].Vector.ToArray())); +} +``` + +### Dependency Injection + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +// App Setup +var builder = Host.CreateApplicationBuilder(); +builder.Services.AddSingleton(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); +builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace)); + +builder.Services.AddChatClient(b => b + .UseDistributedCache() + .UseLogging() + .Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"))); + +var app = builder.Build(); + +// Elsewhere in the app +var chatClient = app.Services.GetRequiredService(); +Console.WriteLine(await chatClient.CompleteAsync("What is AI?")); +``` + +### Minimal Web API + +```csharp +using Microsoft.Extensions.AI; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddChatClient(c => + c.Use(new OllamaChatClient(new Uri("http://localhost:11434/"), "llama3.1"))); + +builder.Services.AddEmbeddingGenerator>(g => + g.Use(new OllamaEmbeddingGenerator(endpoint, "all-minilm"))); + +var app = builder.Build(); + +app.MapPost("/chat", async (IChatClient client, string message) => +{ + var response = await client.CompleteAsync(message, cancellationToken: default); + return response.Message; +}); + +app.MapPost("/embedding", async (IEmbeddingGenerator> client, string message) => +{ + var response = await client.GenerateAsync(message); + return response[0].Vector; +}); + +app.Run(); +``` + +## Feedback & Contributing + +We welcome feedback and contributions in [our GitHub repo](https://github.com/dotnet/extensions). diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj new file mode 100644 index 00000000000..1efedb13f11 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.csproj @@ -0,0 +1,43 @@ + + + + Microsoft.Extensions.AI + Implementation of generative AI abstractions for OpenAI-compatible endpoints. + AI + + + + preview + 0 + 0 + + + + $(TargetFrameworks);netstandard2.0 + $(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002 + true + + + + true + true + true + + + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.json b/src/Libraries/Microsoft.Extensions.AI.OpenAI/Microsoft.Extensions.AI.OpenAI.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs new file mode 100644 index 00000000000..f92fcfa3bc9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs @@ -0,0 +1,659 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; +using OpenAI; +using OpenAI.Chat; + +#pragma warning disable S1135 // Track uses of "TODO" tags +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace Microsoft.Extensions.AI; + +/// An for an OpenAI or . +public sealed partial class OpenAIChatClient : IChatClient +{ + /// Default OpenAI endpoint. + private static readonly Uri _defaultOpenAIEndpoint = new("https://api.openai.com/v1"); + + /// The underlying . + private readonly OpenAIClient? _openAIClient; + + /// The underlying . + private readonly ChatClient _chatClient; + + /// Initializes a new instance of the class for the specified . + /// The underlying client. + /// The model to use. + public OpenAIChatClient(OpenAIClient openAIClient, string modelId) + { + _ = Throw.IfNull(openAIClient); + _ = Throw.IfNullOrWhitespace(modelId); + + _openAIClient = openAIClient; + _chatClient = openAIClient.GetChatClient(modelId); + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = openAIClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + Uri providerUrl = typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(openAIClient) as Uri ?? _defaultOpenAIEndpoint; + + Metadata = new(providerName, providerUrl, modelId); + } + + /// Initializes a new instance of the class for the specified . + /// The underlying client. + public OpenAIChatClient(ChatClient chatClient) + { + _ = Throw.IfNull(chatClient); + + _chatClient = chatClient; + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint and model aren't currently exposed, so use reflection to get at them, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = chatClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + Uri providerUrl = typeof(ChatClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(chatClient) as Uri ?? _defaultOpenAIEndpoint; + string? model = typeof(ChatClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(chatClient) as string; + + Metadata = new(providerName, providerUrl, model); + } + + /// Gets or sets to use for any serialization activities related to tool call arguments and results. + public JsonSerializerOptions? ToolCallJsonSerializerOptions { get; set; } + + /// + public ChatClientMetadata Metadata { get; } + + /// + public TService? GetService(object? key = null) + where TService : class => + typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient : + typeof(TService) == typeof(ChatClient) ? (TService)(object)_chatClient : + this as TService; + + /// + public async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + // Make the call to OpenAI. + OpenAI.Chat.ChatCompletion response = (await _chatClient.CompleteChatAsync( + ToOpenAIChatMessages(chatMessages), + ToOpenAIOptions(options), + cancellationToken).ConfigureAwait(false)).Value; + + // Create the return message. + ChatMessage returnMessage = new() + { + RawRepresentation = response, + Role = ToChatRole(response.Role), + }; + + // Populate its content from those in the OpenAI response content. + foreach (ChatMessageContentPart contentPart in response.Content) + { + if (ToAIContent(contentPart, response.Model) is AIContent aiContent) + { + returnMessage.Contents.Add(aiContent); + } + } + + // Also manufacture function calling content items from any tool calls in the response. + if (options?.Tools is { Count: > 0 }) + { + foreach (ChatToolCall toolCall in response.ToolCalls) + { + if (!string.IsNullOrWhiteSpace(toolCall.FunctionName)) + { + Dictionary? arguments = FunctionCallHelpers.ParseFunctionCallArguments(toolCall.FunctionArguments, out Exception? parsingException); + + returnMessage.Contents.Add(new FunctionCallContent(toolCall.Id, toolCall.FunctionName, arguments) + { + ModelId = response.Model, + Exception = parsingException, + RawRepresentation = toolCall + }); + } + } + } + + // Wrap the content in a ChatCompletion to return. + var completion = new ChatCompletion([returnMessage]) + { + RawRepresentation = response, + CompletionId = response.Id, + CreatedAt = response.CreatedAt, + ModelId = response.Model, + FinishReason = ToFinishReason(response.FinishReason), + }; + + if (response.Usage is ChatTokenUsage tokenUsage) + { + completion.Usage = new() + { + InputTokenCount = tokenUsage.InputTokenCount, + OutputTokenCount = tokenUsage.OutputTokenCount, + TotalTokenCount = tokenUsage.TotalTokenCount, + }; + + if (tokenUsage.OutputTokenDetails is ChatOutputTokenUsageDetails details) + { + completion.Usage.AdditionalProperties = new() { [nameof(details.ReasoningTokenCount)] = details.ReasoningTokenCount }; + } + } + + if (response.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (completion.AdditionalProperties ??= [])[nameof(response.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } + + if (response.Refusal is string refusal) + { + (completion.AdditionalProperties ??= [])[nameof(response.Refusal)] = refusal; + } + + if (response.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (completion.AdditionalProperties ??= [])[nameof(response.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } + + if (response.SystemFingerprint is string systemFingerprint) + { + (completion.AdditionalProperties ??= [])[nameof(response.SystemFingerprint)] = systemFingerprint; + } + + return completion; + } + + /// + public async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + Dictionary? functionCallInfos = null; + ChatRole? streamedRole = null; + ChatFinishReason? finishReason = null; + StringBuilder? refusal = null; + string? completionId = null; + DateTimeOffset? createdAt = null; + string? modelId = null; + string? fingerprint = null; + + // Process each update as it arrives + await foreach (OpenAI.Chat.StreamingChatCompletionUpdate chatCompletionUpdate in _chatClient.CompleteChatStreamingAsync( + ToOpenAIChatMessages(chatMessages), ToOpenAIOptions(options), cancellationToken).ConfigureAwait(false)) + { + // The role and finish reason may arrive during any update, but once they've arrived, the same value should be the same for all subsequent updates. + streamedRole ??= chatCompletionUpdate.Role is ChatMessageRole role ? ToChatRole(role) : null; + finishReason ??= chatCompletionUpdate.FinishReason is OpenAI.Chat.ChatFinishReason reason ? ToFinishReason(reason) : null; + completionId ??= chatCompletionUpdate.CompletionId; + createdAt ??= chatCompletionUpdate.CreatedAt; + modelId ??= chatCompletionUpdate.Model; + fingerprint ??= chatCompletionUpdate.SystemFingerprint; + + // Create the response content object. + StreamingChatCompletionUpdate completionUpdate = new() + { + CompletionId = chatCompletionUpdate.CompletionId, + CreatedAt = chatCompletionUpdate.CreatedAt, + FinishReason = finishReason, + RawRepresentation = chatCompletionUpdate, + Role = streamedRole, + }; + + // Populate it with any additional metadata from the OpenAI object. + if (chatCompletionUpdate.ContentTokenLogProbabilities is { Count: > 0 } contentTokenLogProbs) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.ContentTokenLogProbabilities)] = contentTokenLogProbs; + } + + if (chatCompletionUpdate.RefusalTokenLogProbabilities is { Count: > 0 } refusalTokenLogProbs) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.RefusalTokenLogProbabilities)] = refusalTokenLogProbs; + } + + if (fingerprint is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(chatCompletionUpdate.SystemFingerprint)] = fingerprint; + } + + // Transfer over content update items. + if (chatCompletionUpdate.ContentUpdate is { Count: > 0 }) + { + foreach (ChatMessageContentPart contentPart in chatCompletionUpdate.ContentUpdate) + { + if (ToAIContent(contentPart, modelId) is AIContent aiContent) + { + completionUpdate.Contents.Add(aiContent); + } + } + } + + // Transfer over refusal updates. + if (chatCompletionUpdate.RefusalUpdate is not null) + { + _ = (refusal ??= new()).Append(chatCompletionUpdate.RefusalUpdate); + } + + // Transfer over tool call updates. + if (chatCompletionUpdate.ToolCallUpdates is { Count: > 0 } toolCallUpdates) + { + foreach (StreamingChatToolCallUpdate toolCallUpdate in toolCallUpdates) + { + functionCallInfos ??= []; + if (!functionCallInfos.TryGetValue(toolCallUpdate.Index, out FunctionCallInfo? existing)) + { + functionCallInfos[toolCallUpdate.Index] = existing = new(); + } + + existing.CallId ??= toolCallUpdate.ToolCallId; + existing.Name ??= toolCallUpdate.FunctionName; + if (toolCallUpdate.FunctionArgumentsUpdate is not null) + { + _ = (existing.Arguments ??= new()).Append(toolCallUpdate.FunctionArgumentsUpdate); + } + } + } + + // Transfer over usage updates. + if (chatCompletionUpdate.Usage is ChatTokenUsage tokenUsage) + { + UsageDetails usageDetails = new() + { + InputTokenCount = tokenUsage.InputTokenCount, + OutputTokenCount = tokenUsage.OutputTokenCount, + TotalTokenCount = tokenUsage.TotalTokenCount, + }; + + if (tokenUsage.OutputTokenDetails is ChatOutputTokenUsageDetails details) + { + (usageDetails.AdditionalProperties = [])[nameof(tokenUsage.OutputTokenDetails)] = new Dictionary + { + [nameof(details.ReasoningTokenCount)] = details.ReasoningTokenCount, + }; + } + + // TODO: Add support for prompt token details (e.g. cached tokens) once it's exposed in OpenAI library. + + completionUpdate.Contents.Add(new UsageContent(usageDetails) + { + ModelId = modelId + }); + } + + // Now yield the item. + yield return completionUpdate; + } + + // Now that we've received all updates, combine any for function calls into a single item to yield. + if (functionCallInfos is not null) + { + StreamingChatCompletionUpdate completionUpdate = new() + { + CompletionId = completionId, + CreatedAt = createdAt, + FinishReason = finishReason, + Role = streamedRole, + }; + + foreach (var entry in functionCallInfos) + { + FunctionCallInfo fci = entry.Value; + if (!string.IsNullOrWhiteSpace(fci.Name)) + { + var arguments = FunctionCallHelpers.ParseFunctionCallArguments( + fci.Arguments?.ToString() ?? string.Empty, + out Exception? parsingException); + + completionUpdate.Contents.Add(new FunctionCallContent(fci.CallId!, fci.Name!, arguments) + { + ModelId = modelId, + Exception = parsingException + }); + } + } + + // Refusals are about the model not following the schema for tool calls. As such, if we have any refusal, + // add it to this function calling item. + if (refusal is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(ChatMessageContentPart.Refusal)] = refusal.ToString(); + } + + // Propagate additional relevant metadata. + if (fingerprint is not null) + { + (completionUpdate.AdditionalProperties ??= [])[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)] = fingerprint; + } + + yield return completionUpdate; + } + } + + /// + void IDisposable.Dispose() + { + // Nothing to dispose. Implementation required for the IChatClient interface. + } + + /// POCO representing function calling info. Used to concatenation information for a single function call from across multiple streaming updates. + private sealed class FunctionCallInfo + { + public string? CallId; + public string? Name; + public StringBuilder? Arguments; + } + + /// Converts an OpenAI role to an Extensions role. + private static ChatRole ToChatRole(ChatMessageRole role) => + role switch + { + ChatMessageRole.System => ChatRole.System, + ChatMessageRole.User => ChatRole.User, + ChatMessageRole.Assistant => ChatRole.Assistant, + ChatMessageRole.Tool => ChatRole.Tool, + _ => new ChatRole(role.ToString()), + }; + + /// Converts an OpenAI finish reason to an Extensions finish reason. + private static ChatFinishReason? ToFinishReason(OpenAI.Chat.ChatFinishReason? finishReason) => + finishReason?.ToString() is not string s ? null : + finishReason switch + { + OpenAI.Chat.ChatFinishReason.Stop => ChatFinishReason.Stop, + OpenAI.Chat.ChatFinishReason.Length => ChatFinishReason.Length, + OpenAI.Chat.ChatFinishReason.ContentFilter => ChatFinishReason.ContentFilter, + OpenAI.Chat.ChatFinishReason.ToolCalls or OpenAI.Chat.ChatFinishReason.FunctionCall => ChatFinishReason.ToolCalls, + _ => new ChatFinishReason(s), + }; + + /// Converts an extensions options instance to an OpenAI options instance. + private ChatCompletionOptions ToOpenAIOptions(ChatOptions? options) + { + ChatCompletionOptions result = new(); + + if (options is not null) + { + result.FrequencyPenalty = options.FrequencyPenalty; + result.MaxOutputTokenCount = options.MaxOutputTokens; + result.TopP = options.TopP; + result.PresencePenalty = options.PresencePenalty; + result.Temperature = options.Temperature; + + if (options.StopSequences is { Count: > 0 } stopSequences) + { + foreach (string stopSequence in stopSequences) + { + result.StopSequences.Add(stopSequence); + } + } + + if (options.AdditionalProperties is { Count: > 0 } additionalProperties) + { + if (additionalProperties.TryGetConvertedValue(nameof(result.EndUserId), out string? endUserId)) + { + result.EndUserId = endUserId; + } + + if (additionalProperties.TryGetConvertedValue(nameof(result.IncludeLogProbabilities), out bool includeLogProbabilities)) + { + result.IncludeLogProbabilities = includeLogProbabilities; + } + + if (additionalProperties.TryGetConvertedValue(nameof(result.LogitBiases), out IDictionary? logitBiases)) + { + foreach (KeyValuePair kvp in logitBiases!) + { + result.LogitBiases[kvp.Key] = kvp.Value; + } + } + + if (additionalProperties.TryGetConvertedValue(nameof(result.AllowParallelToolCalls), out bool allowParallelToolCalls)) + { + result.AllowParallelToolCalls = allowParallelToolCalls; + } + +#pragma warning disable OPENAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + if (additionalProperties.TryGetConvertedValue(nameof(result.Seed), out long seed)) + { + result.Seed = seed; + } +#pragma warning restore OPENAI001 + + if (additionalProperties.TryGetConvertedValue(nameof(result.TopLogProbabilityCount), out int topLogProbabilityCountInt)) + { + result.TopLogProbabilityCount = topLogProbabilityCountInt; + } + } + + if (options.Tools is { Count: > 0 } tools) + { + foreach (AITool tool in tools) + { + if (tool is AIFunction af) + { + result.Tools.Add(ToOpenAIChatTool(af)); + } + } + + switch (options.ToolMode) + { + case AutoChatToolMode: + result.ToolChoice = ChatToolChoice.CreateAutoChoice(); + break; + + case RequiredChatToolMode required: + result.ToolChoice = required.RequiredFunctionName is null ? + ChatToolChoice.CreateRequiredChoice() : + ChatToolChoice.CreateFunctionChoice(required.RequiredFunctionName); + break; + } + } + + if (options.ResponseFormat is ChatResponseFormatText) + { + result.ResponseFormat = OpenAI.Chat.ChatResponseFormat.CreateTextFormat(); + } + else if (options.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + result.ResponseFormat = jsonFormat.Schema is string jsonSchema ? + OpenAI.Chat.ChatResponseFormat.CreateJsonSchemaFormat(jsonFormat.SchemaName ?? "json_schema", BinaryData.FromString(jsonSchema), jsonFormat.SchemaDescription) : + OpenAI.Chat.ChatResponseFormat.CreateJsonObjectFormat(); + } + } + + return result; + } + + /// Converts an Extensions function to an OpenAI chat tool. + private ChatTool ToOpenAIChatTool(AIFunction aiFunction) + { + _ = aiFunction.Metadata.AdditionalProperties.TryGetConvertedValue("Strict", out bool strict); + + BinaryData resultParameters = OpenAIChatToolJson.ZeroFunctionParametersSchema; + + var parameters = aiFunction.Metadata.Parameters; + if (parameters is { Count: > 0 }) + { + OpenAIChatToolJson tool = new(); + + foreach (AIFunctionParameterMetadata parameter in parameters) + { + tool.Properties.Add( + parameter.Name, + FunctionCallHelpers.InferParameterJsonSchema(parameter, aiFunction.Metadata, ToolCallJsonSerializerOptions)); + + if (parameter.IsRequired) + { + tool.Required.Add(parameter.Name); + } + } + + resultParameters = BinaryData.FromBytes( + JsonSerializer.SerializeToUtf8Bytes(tool, JsonContext.Default.OpenAIChatToolJson)); + } + + return ChatTool.CreateFunctionTool(aiFunction.Metadata.Name, aiFunction.Metadata.Description, resultParameters, strict); + } + + /// Used to create the JSON payload for an OpenAI chat tool description. + private sealed class OpenAIChatToolJson + { + /// Gets a singleton JSON data for empty parameters. Optimization for the reasonably common case of a parameterless function. + public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray()); + + [JsonPropertyName("type")] + public string Type { get; set; } = "object"; + + [JsonPropertyName("required")] + public List Required { get; set; } = []; + + [JsonPropertyName("properties")] + public Dictionary Properties { get; set; } = []; + } + + /// Creates an from a . + /// The content part to convert into a content. + /// The model ID. + /// The constructed , or null if the content part could not be converted. + private static AIContent? ToAIContent(ChatMessageContentPart contentPart, string? modelId) + { + AIContent? aiContent = null; + + AdditionalPropertiesDictionary? additionalProperties = null; + + if (contentPart.Kind == ChatMessageContentPartKind.Text) + { + aiContent = new TextContent(contentPart.Text); + } + else if (contentPart.Kind == ChatMessageContentPartKind.Image) + { + ImageContent? imageContent; + aiContent = imageContent = + contentPart.ImageUri is not null ? new ImageContent(contentPart.ImageUri, contentPart.ImageBytesMediaType) : + contentPart.ImageBytes is not null ? new ImageContent(contentPart.ImageBytes.ToMemory(), contentPart.ImageBytesMediaType) : + null; + + if (imageContent is not null && contentPart.ImageDetailLevel?.ToString() is string detail) + { + (additionalProperties ??= [])[nameof(contentPart.ImageDetailLevel)] = detail; + } + } + + if (aiContent is not null) + { + if (contentPart.Refusal is string refusal) + { + (additionalProperties ??= [])[nameof(contentPart.Refusal)] = refusal; + } + + aiContent.ModelId = modelId; + aiContent.AdditionalProperties = additionalProperties; + aiContent.RawRepresentation = contentPart; + } + + return aiContent; + } + + /// Converts an Extensions chat message enumerable to an OpenAI chat message enumerable. + private IEnumerable ToOpenAIChatMessages(IEnumerable inputs) + { + // Maps all of the M.E.AI types to the corresponding OpenAI types. + // Unrecognized content is ignored. + + foreach (ChatMessage input in inputs) + { + if (input.Role == ChatRole.System) + { + yield return new SystemChatMessage(input.Text) { ParticipantName = input.AuthorName }; + } + else if (input.Role == ChatRole.Tool) + { + foreach (AIContent item in input.Contents) + { + if (item is FunctionResultContent resultContent) + { + string? result = resultContent.Result as string; + if (result is null && resultContent.Result is not null) + { + try + { + result = FunctionCallHelpers.FormatFunctionResultAsJson(resultContent.Result, ToolCallJsonSerializerOptions); + } + catch (NotSupportedException) + { + // If the type can't be serialized, skip it. + } + } + + yield return new ToolChatMessage(resultContent.CallId, result ?? string.Empty); + } + } + } + else if (input.Role == ChatRole.User) + { + yield return new UserChatMessage(input.Contents.Select(static (AIContent item) => item switch + { + TextContent textContent => ChatMessageContentPart.CreateTextPart(textContent.Text), + ImageContent imageContent => imageContent.Data is { IsEmpty: false } data ? ChatMessageContentPart.CreateImagePart(BinaryData.FromBytes(data), imageContent.MediaType) : + imageContent.Uri is string uri ? ChatMessageContentPart.CreateImagePart(new Uri(uri)) : + null, + _ => null, + }).Where(c => c is not null)) + { ParticipantName = input.AuthorName }; + } + else if (input.Role == ChatRole.Assistant) + { + Dictionary? toolCalls = null; + + foreach (var content in input.Contents) + { + if (content is FunctionCallContent callRequest && callRequest.CallId is not null && toolCalls?.ContainsKey(callRequest.CallId) is not true) + { + (toolCalls ??= []).Add( + callRequest.CallId, + ChatToolCall.CreateFunctionToolCall( + callRequest.CallId, + callRequest.Name, + BinaryData.FromObjectAsJson(callRequest.Arguments, ToolCallJsonSerializerOptions))); + } + } + + AssistantChatMessage message = toolCalls is not null ? + new(toolCalls.Values) { ParticipantName = input.AuthorName } : + new(input.Text) { ParticipantName = input.AuthorName }; + + if (input.AdditionalProperties?.TryGetConvertedValue(nameof(message.Refusal), out string? refusal) is true) + { + message.Refusal = refusal; + } + + yield return message; + } + } + } + + /// Source-generated JSON type information. + [JsonSerializable(typeof(OpenAIChatToolJson))] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs new file mode 100644 index 00000000000..a33fd34e1ea --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIClientExtensions.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using OpenAI; +using OpenAI.Chat; +using OpenAI.Embeddings; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for working with s. +public static class OpenAIClientExtensions +{ + /// Gets an for use with this . + /// The client. + /// The model. + /// An that may be used to converse via the . + public static IChatClient AsChatClient(this OpenAIClient openAIClient, string modelId) => + new OpenAIChatClient(openAIClient, modelId); + + /// Gets an for use with this . + /// The client. + /// An that may be used to converse via the . + public static IChatClient AsChatClient(this ChatClient chatClient) => + new OpenAIChatClient(chatClient); + + /// Gets an for use with this . + /// The client. + /// The model to use. + /// The number of dimensions to generate in each embedding. + /// An that may be used to generate embeddings via the . + public static IEmbeddingGenerator> AsEmbeddingGenerator(this OpenAIClient openAIClient, string modelId, int? dimensions = null) => + new OpenAIEmbeddingGenerator(openAIClient, modelId, dimensions); + + /// Gets an for use with this . + /// The client. + /// The number of dimensions to generate in each embedding. + /// An that may be used to generate embeddings via the . + public static IEmbeddingGenerator> AsEmbeddingGenerator(this EmbeddingClient embeddingClient, int? dimensions = null) => + new OpenAIEmbeddingGenerator(embeddingClient, dimensions); +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs new file mode 100644 index 00000000000..e91394befdd --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIEmbeddingGenerator.cs @@ -0,0 +1,160 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; +using OpenAI; +using OpenAI.Embeddings; + +#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields + +namespace Microsoft.Extensions.AI; + +/// An for an OpenAI . +public sealed class OpenAIEmbeddingGenerator : IEmbeddingGenerator> +{ + /// Default OpenAI endpoint. + private const string DefaultOpenAIEndpoint = "https://api.openai.com/v1"; + + /// The underlying . + private readonly OpenAIClient? _openAIClient; + + /// The underlying . + private readonly EmbeddingClient _embeddingClient; + + /// The number of dimensions produced by the generator. + private readonly int? _dimensions; + + /// Initializes a new instance of the class. + /// The underlying client. + /// The model to use. + /// The number of dimensions to generate in each embedding. + public OpenAIEmbeddingGenerator( + OpenAIClient openAIClient, string modelId, int? dimensions = null) + { + _ = Throw.IfNull(openAIClient); + _ = Throw.IfNullOrWhitespace(modelId); + if (dimensions is < 1) + { + Throw.ArgumentOutOfRangeException(nameof(dimensions), "Value must be greater than 0."); + } + + _openAIClient = openAIClient; + _embeddingClient = openAIClient.GetEmbeddingClient(modelId); + _dimensions = dimensions; + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = openAIClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + string providerUrl = (typeof(OpenAIClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(openAIClient) as Uri)?.ToString() ?? + DefaultOpenAIEndpoint; + + Metadata = CreateMetadata(dimensions, providerName, providerUrl, modelId); + } + + /// Initializes a new instance of the class. + /// The underlying client. + /// The number of dimensions to generate in each embedding. + public OpenAIEmbeddingGenerator(EmbeddingClient embeddingClient, int? dimensions = null) + { + _ = Throw.IfNull(embeddingClient); + if (dimensions < 1) + { + Throw.ArgumentOutOfRangeException(nameof(dimensions), "Value must be greater than 0."); + } + + _embeddingClient = embeddingClient; + _dimensions = dimensions; + + // https://github.com/openai/openai-dotnet/issues/215 + // The endpoint and model aren't currently exposed, so use reflection to get at them, temporarily. Once packages + // implement the abstractions directly rather than providing adapters on top of the public APIs, + // the package can provide such implementations separate from what's exposed in the public API. + string providerName = embeddingClient.GetType().Name.StartsWith("Azure", StringComparison.Ordinal) ? "azureopenai" : "openai"; + string providerUrl = (typeof(EmbeddingClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance) + ?.GetValue(embeddingClient) as Uri)?.ToString() ?? + DefaultOpenAIEndpoint; + + FieldInfo? modelField = typeof(EmbeddingClient).GetField("_model", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + string? model = modelField?.GetValue(embeddingClient) as string; + + Metadata = CreateMetadata(dimensions, providerName, providerUrl, model); + } + + /// Creates the for this instance. + private static EmbeddingGeneratorMetadata CreateMetadata(int? dimensions, string providerName, string providerUrl, string? model) => + new(providerName, Uri.TryCreate(providerUrl, UriKind.Absolute, out Uri? providerUri) ? providerUri : null, model, dimensions); + + /// + public EmbeddingGeneratorMetadata Metadata { get; } + + /// + public TService? GetService(object? key = null) + where TService : class + => + typeof(TService) == typeof(OpenAIClient) ? (TService?)(object?)_openAIClient : + typeof(TService) == typeof(EmbeddingClient) ? (TService)(object)_embeddingClient : + this as TService; + + /// + public async Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + OpenAI.Embeddings.EmbeddingGenerationOptions? openAIOptions = ToOpenAIOptions(options); + + var embeddings = (await _embeddingClient.GenerateEmbeddingsAsync(values, openAIOptions, cancellationToken).ConfigureAwait(false)).Value; + + return new(embeddings.Select(e => + new Embedding(e.ToFloats()) + { + CreatedAt = DateTimeOffset.UtcNow, + ModelId = embeddings.Model, + })) + { + Usage = new() + { + InputTokenCount = embeddings.Usage.InputTokenCount, + TotalTokenCount = embeddings.Usage.TotalTokenCount + }, + }; + } + + /// + void IDisposable.Dispose() + { + // Nothing to dispose. Implementation required for the IEmbeddingGenerator interface. + } + + /// Converts an extensions options instance to an OpenAI options instance. + private OpenAI.Embeddings.EmbeddingGenerationOptions? ToOpenAIOptions(EmbeddingGenerationOptions? options) + { + OpenAI.Embeddings.EmbeddingGenerationOptions openAIOptions = new() + { + Dimensions = _dimensions, + }; + + if (options?.AdditionalProperties is { Count: > 0 } additionalProperties) + { + // Allow per-instance dimensions to be overridden by a per-call property + if (additionalProperties.TryGetConvertedValue(nameof(openAIOptions.Dimensions), out int? dimensions)) + { + openAIOptions.Dimensions = dimensions; + } + + if (additionalProperties.TryGetConvertedValue(nameof(openAIOptions.EndUserId), out string? endUserId)) + { + openAIOptions.EndUserId = endUserId; + } + } + + return openAIOptions; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md b/src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md new file mode 100644 index 00000000000..f7af212f4d7 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/README.md @@ -0,0 +1,313 @@ +# Microsoft.Extensions.AI.OpenAI + +Provides an implementation of the `IChatClient` interface for the `OpenAI` package and OpenAI-compatible endpoints. + +## Install the package + +From the command-line: + +```console +dotnet add package Microsoft.Extensions.AI.OpenAI +``` + +Or directly in the C# project file: + +```xml + + + +``` + +## Usage Examples + +### Chat + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient client = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Chat + Conversation History + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient client = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +Console.WriteLine(await client.CompleteAsync( +[ + new ChatMessage(ChatRole.System, "You are a helpful AI assistant"), + new ChatMessage(ChatRole.User, "What is AI?"), +])); +``` + +### Chat streaming + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient client = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +await foreach (var update in client.CompleteStreamingAsync("What is AI?")) +{ + Console.Write(update); +} +``` + +### Tool calling + +```csharp +using System.ComponentModel; +using Microsoft.Extensions.AI; +using OpenAI; + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseFunctionInvocation() + .Use(openaiClient); + +ChatOptions chatOptions = new() +{ + Tools = [AIFunctionFactory.Create(GetWeather)] +}; + +await foreach (var message in client.CompleteStreamingAsync("Do I need an umbrella?", chatOptions)) +{ + Console.Write(message); +} + +[Description("Gets the weather")] +static string GetWeather() => Random.Shared.NextDouble() > 0.5 ? "It's sunny" : "It's raining"; +``` + +### Caching + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenAI; + +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .Use(openaiClient); + +for (int i = 0; i < 3; i++) +{ + await foreach (var message in client.CompleteStreamingAsync("In less than 100 words, what is AI?")) + { + Console.Write(message); + } + + Console.WriteLine(); + Console.WriteLine(); +} +``` + +### Telemetry + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; +using OpenTelemetry.Trace; + +// Configure OpenTelemetry exporter +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(openaiClient); + +Console.WriteLine(await client.CompleteAsync("What is AI?")); +``` + +### Telemetry, Caching, and Tool Calling + +```csharp +using System.ComponentModel; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenAI; +using OpenTelemetry.Trace; + +// Configure telemetry +var sourceName = Guid.NewGuid().ToString(); +var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddConsoleExporter() + .Build(); + +// Configure caching +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +// Configure tool calling +var chatOptions = new ChatOptions +{ + Tools = [AIFunctionFactory.Create(GetPersonAge)] +}; + +IChatClient openaiClient = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsChatClient("gpt-4o-mini"); + +IChatClient client = new ChatClientBuilder() + .UseDistributedCache(cache) + .UseFunctionInvocation() + .UseOpenTelemetry(sourceName, c => c.EnableSensitiveData = true) + .Use(openaiClient); + +for (int i = 0; i < 3; i++) +{ + Console.WriteLine(await client.CompleteAsync("How much older is Alice than Bob?", chatOptions)); +} + +[Description("Gets the age of a person specified by name.")] +static int GetPersonAge(string personName) => + personName switch + { + "Alice" => 42, + "Bob" => 35, + _ => 26, + }; +``` + +### Text embedding generation + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +IEmbeddingGenerator> generator = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsEmbeddingGenerator("text-embedding-3-small"); + +var embeddings = await generator.GenerateAsync("What is AI?"); + +Console.WriteLine(string.Join(", ", embeddings[0].Vector.ToArray())); +``` + +### Text embedding generation with caching + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using OpenAI; + +IDistributedCache cache = new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions())); + +IEmbeddingGenerator> openAIGenerator = + new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY")) + .AsEmbeddingGenerator("text-embedding-3-small"); + +IEmbeddingGenerator> generator = new EmbeddingGeneratorBuilder>() + .UseDistributedCache(cache) + .Use(openAIGenerator); + +foreach (var prompt in new[] { "What is AI?", "What is .NET?", "What is AI?" }) +{ + var embeddings = await generator.GenerateAsync(prompt); + + Console.WriteLine(string.Join(", ", embeddings[0].Vector.ToArray())); +} +``` + +### Dependency Injection + +```csharp +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using OpenAI; + +// App Setup +var builder = Host.CreateApplicationBuilder(); +builder.Services.AddSingleton(new OpenAIClient(Environment.GetEnvironmentVariable("OPENAI_API_KEY"))); +builder.Services.AddSingleton(new MemoryDistributedCache(Options.Create(new MemoryDistributedCacheOptions()))); +builder.Services.AddLogging(b => b.AddConsole().SetMinimumLevel(LogLevel.Trace)); + +builder.Services.AddChatClient(b => b + .UseDistributedCache() + .UseLogging() + .Use(b.Services.GetRequiredService().AsChatClient("gpt-4o-mini"))); + +var app = builder.Build(); + +// Elsewhere in the app +var chatClient = app.Services.GetRequiredService(); +Console.WriteLine(await chatClient.CompleteAsync("What is AI?")); +``` + +### Minimal Web API + +```csharp +using Microsoft.Extensions.AI; +using OpenAI; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddSingleton(new OpenAIClient(builder.Configuration["OPENAI_API_KEY"])); + +builder.Services.AddChatClient(b => + b.Use(b.Services.GetRequiredService().AsChatClient("gpt-4o-mini"))); + +builder.Services.AddEmbeddingGenerator>(g => + g.Use(g.Services.GetRequiredService().AsEmbeddingGenerator("text-embedding-3-small"))); + +var app = builder.Build(); + +app.MapPost("/chat", async (IChatClient client, string message) => +{ + var response = await client.CompleteAsync(message); + return response.Message; +}); + +app.MapPost("/embedding", async (IEmbeddingGenerator> client, string message) => +{ + var response = await client.GenerateAsync(message); + return response[0].Vector; +}); + +app.Run(); +``` + +## Feedback & Contributing + +We welcome feedback and contributions in [our GitHub repo](https://github.com/dotnet/extensions). diff --git a/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs new file mode 100644 index 00000000000..8128926f942 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/CachingHelpers.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Security.Cryptography; +using System.Text.Json; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides internal helpers for implementing caching services. +internal static class CachingHelpers +{ + /// Computes a default cache key for the specified parameters. + /// Specifies the type of the data being used to compute the key. + /// The data with which to compute the key. + /// The . + /// A string that will be used as a cache key. + public static string GetCacheKey(TValue value, JsonSerializerOptions serializerOptions) + => GetCacheKey(value, false, serializerOptions); + + /// Computes a default cache key for the specified parameters. + /// Specifies the type of the data being used to compute the key. + /// The data with which to compute the key. + /// Another data item that causes the key to vary. + /// The . + /// A string that will be used as a cache key. + public static string GetCacheKey(TValue value, bool flag, JsonSerializerOptions serializerOptions) + { + _ = Throw.IfNull(value); + _ = Throw.IfNull(serializerOptions); + serializerOptions.MakeReadOnly(); + + var jsonKeyBytes = JsonSerializer.SerializeToUtf8Bytes(value, serializerOptions.GetTypeInfo(typeof(TValue))); + + if (flag && jsonKeyBytes.Length > 0) + { + // Make an arbitrary change to the hash input based on the flag + // The alternative would be including the flag in "value" in the + // first place, but that's likely to require an extra allocation + // or the inclusion of another type in the JsonSerializerContext. + // This is a micro-optimization we can change at any time. + jsonKeyBytes[0] = (byte)(byte.MaxValue - jsonKeyBytes[0]); + } + + // The complete JSON representation is excessively long for a cache key, duplicating much of the content + // from the value. So we use a hash of it as the default key. +#if NET8_0_OR_GREATER + Span hashData = stackalloc byte[SHA256.HashSizeInBytes]; + SHA256.HashData(jsonKeyBytes, hashData); + return Convert.ToHexString(hashData); +#else + using var sha256 = SHA256.Create(); + var hashData = sha256.ComputeHash(jsonKeyBytes); + return BitConverter.ToString(hashData).Replace("-", string.Empty); +#endif + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs new file mode 100644 index 00000000000..89a778cdd1b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/CachingChatClient.cs @@ -0,0 +1,155 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that caches the results of chat calls. +/// +public abstract class CachingChatClient : DelegatingChatClient +{ + /// Initializes a new instance of the class. + /// The underlying . + protected CachingChatClient(IChatClient innerClient) + : base(innerClient) + { + } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + // We're only storing the final result, not the in-flight task, so that we can avoid caching failures + // or having problems when one of the callers cancels but others don't. This has the drawback that + // concurrent callers might trigger duplicate requests, but that's acceptable. + var cacheKey = GetCacheKey(false, chatMessages, options); + + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is ChatCompletion existing) + { + return existing; + } + + var result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false); + return result; + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + var cacheKey = GetCacheKey(true, chatMessages, options); + if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks) + { + foreach (var chunk in existingChunks) + { + yield return chunk; + } + } + else + { + var capturedItems = new List(); + StreamingChatCompletionUpdate? previousCoalescedCopy = null; + await foreach (var item in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + yield return item; + + // If this item is compatible with the previous one, we will coalesce them in the cache + var previous = capturedItems.Count > 0 ? capturedItems[capturedItems.Count - 1] : null; + if (item.ChoiceIndex == 0 + && item.Contents.Count == 1 + && item.Contents[0] is TextContent currentTextContent + && previous is { ChoiceIndex: 0 } + && previous.Role == item.Role + && previous.Contents is { Count: 1 } + && previous.Contents[0] is TextContent previousTextContent) + { + if (!ReferenceEquals(previous, previousCoalescedCopy)) + { + // We don't want to mutate any object that we also yield, since the recipient might + // not expect that. Instead make a copy we can safely mutate. + previousCoalescedCopy = new() + { + Role = previous.Role, + AuthorName = previous.AuthorName, + AdditionalProperties = previous.AdditionalProperties, + ChoiceIndex = previous.ChoiceIndex, + RawRepresentation = previous.RawRepresentation, + Contents = [new TextContent(previousTextContent.Text)] + }; + + // The last item we captured was before we knew it could be coalesced + // with this one, so replace it with the coalesced copy + capturedItems[capturedItems.Count - 1] = previousCoalescedCopy; + } + +#pragma warning disable S1643 // Strings should not be concatenated using '+' in a loop + ((TextContent)previousCoalescedCopy.Contents[0]).Text += currentTextContent.Text; +#pragma warning restore S1643 + } + else + { + capturedItems.Add(item); + } + } + + await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Computes a cache key for the specified call parameters. + /// + /// A flag to indicate if this is a streaming call. + /// The chat content. + /// The chat options to configure the request. + /// A string that will be used as a cache key. + protected abstract string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options); + + /// + /// Returns a previously cached , if available. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to monitor for cancellation requests. + /// The previously cached data, if available, otherwise . + protected abstract Task ReadCacheAsync(string key, CancellationToken cancellationToken); + + /// + /// Returns a previously cached list of values, if available. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to monitor for cancellation requests. + /// The previously cached data, if available, otherwise . + protected abstract Task?> ReadCacheStreamingAsync(string key, CancellationToken cancellationToken); + + /// + /// Stores a in the underlying cache. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to be stored. + /// The to monitor for cancellation requests. + /// A representing the completion of the operation. + protected abstract Task WriteCacheAsync(string key, ChatCompletion value, CancellationToken cancellationToken); + + /// + /// Stores a list of values in the underlying cache. + /// This is used when there is a call to . + /// + /// The cache key. + /// The to be stored. + /// The to monitor for cancellation requests. + /// A representing the completion of the operation. + protected abstract Task WriteCacheStreamingAsync(string key, IReadOnlyList value, CancellationToken cancellationToken); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs new file mode 100644 index 00000000000..d7934ba7809 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilder.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A builder for creating pipelines of . +public sealed class ChatClientBuilder +{ + /// The registered client factory instances. + private List>? _clientFactories; + + /// Initializes a new instance of the class. + /// The service provider to use for dependency injection. + public ChatClientBuilder(IServiceProvider? services = null) + { + Services = services ?? EmptyServiceProvider.Instance; + } + + /// Gets the associated with the builder instance. + public IServiceProvider Services { get; } + + /// Completes the pipeline by adding a final that represents the underlying backend. This is typically a client for an LLM service. + /// The inner client to use. + /// An instance of that represents the entire pipeline. Calls to this instance will pass through each of the pipeline stages in turn. + public IChatClient Use(IChatClient innerClient) + { + var chatClient = Throw.IfNull(innerClient); + + // To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost. + if (_clientFactories is not null) + { + for (var i = _clientFactories.Count - 1; i >= 0; i--) + { + chatClient = _clientFactories[i](Services, chatClient) ?? + throw new InvalidOperationException( + $"The {nameof(ChatClientBuilder)} entry at index {i} returned null. " + + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IChatClient)} instances."); + } + } + + return chatClient; + } + + /// Adds a factory for an intermediate chat client to the chat client pipeline. + /// The client factory function. + /// The updated instance. + public ChatClientBuilder Use(Func clientFactory) + { + _ = Throw.IfNull(clientFactory); + + return Use((_, innerClient) => clientFactory(innerClient)); + } + + /// Adds a factory for an intermediate chat client to the chat client pipeline. + /// The client factory function. + /// The updated instance. + public ChatClientBuilder Use(Func clientFactory) + { + _ = Throw.IfNull(clientFactory); + + (_clientFactories ??= []).Add(clientFactory); + return this; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs new file mode 100644 index 00000000000..246ac7f3689 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientBuilderServiceCollectionExtensions.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for registering with a . +public static class ChatClientBuilderServiceCollectionExtensions +{ + /// Adds a chat client to the . + /// The to which the client should be added. + /// The factory to use to construct the instance. + /// The collection. + /// The client is registered as a scoped service. + public static IServiceCollection AddChatClient( + this IServiceCollection services, + Func clientFactory) + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(clientFactory); + + return services.AddScoped(services => + clientFactory(new ChatClientBuilder(services))); + } + + /// Adds a chat client to the . + /// The to which the client should be added. + /// The key with which to associate the client. + /// The factory to use to construct the instance. + /// The collection. + /// The client is registered as a scoped service. + public static IServiceCollection AddKeyedChatClient( + this IServiceCollection services, + object serviceKey, + Func clientFactory) + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(serviceKey); + _ = Throw.IfNull(clientFactory); + + return services.AddKeyedScoped(serviceKey, (services, _) => + clientFactory(new ChatClientBuilder(services))); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs new file mode 100644 index 00000000000..2a8b794c50e --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -0,0 +1,225 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Schema; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides extension methods on that simplify working with structured output. +/// +public static partial class ChatClientStructuredOutputExtensions +{ + private const string UsesReflectionJsonSerializerMessage = + "This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications."; + + private static JsonSerializerOptions? _defaultJsonSerializerOptions; + + /// Sends chat messages to the model, requesting a response matching the type . + /// The . + /// The chat content to send. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// + /// The returned messages will not have been added to . However, any intermediate messages generated implicitly + /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. + /// + /// The type of structured output to request. + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + public static Task> CompleteAsync( + this IChatClient chatClient, + IList chatMessages, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class => + CompleteAsync(chatClient, chatMessages, DefaultJsonSerializerOptions, options, useNativeJsonSchema, cancellationToken); + + /// Sends a user chat text message to the model, requesting a response matching the type . + /// The . + /// The text content for the chat message to send. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// The type of structured output to request. + [RequiresDynamicCode("JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. " + + "Use System.Text.Json source generation for native AOT applications.")] + public static Task> CompleteAsync( + this IChatClient chatClient, + string chatMessage, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class => + CompleteAsync(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], options, useNativeJsonSchema, cancellationToken); + + /// Sends a user chat text message to the model, requesting a response matching the type . + /// The . + /// The text content for the chat message to send. + /// The JSON serialization options to use. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// The type of structured output to request. + public static Task> CompleteAsync( + this IChatClient chatClient, + string chatMessage, + JsonSerializerOptions serializerOptions, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class => + CompleteAsync(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], serializerOptions, options, useNativeJsonSchema, cancellationToken); + + /// Sends chat messages to the model, requesting a response matching the type . + /// The . + /// The chat content to send. + /// The JSON serialization options to use. + /// The chat options to configure the request. + /// + /// Optionally specifies whether to set a JSON schema on the . + /// This improves reliability if the underlying model supports native structured output with a schema, but may cause an error if the model does not support it. + /// If not specified, the underlying provider's default will be used. + /// + /// The to monitor for cancellation requests. The default is . + /// The response messages generated by the client. + /// + /// The returned messages will not have been added to . However, any intermediate messages generated implicitly + /// by the client, including any messages for roundtrips to the model as part of the implementation of this request, will be included. + /// + /// The type of structured output to request. + public static async Task> CompleteAsync( + this IChatClient chatClient, + IList chatMessages, + JsonSerializerOptions serializerOptions, + ChatOptions? options = null, + bool? useNativeJsonSchema = null, + CancellationToken cancellationToken = default) + where T : class + { + _ = Throw.IfNull(chatClient); + _ = Throw.IfNull(chatMessages); + _ = Throw.IfNull(serializerOptions); + + serializerOptions.MakeReadOnly(); + + var schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(T), new() + { + TreatNullObliviousAsNonNullable = true, + TransformSchemaNode = static (context, node) => + { + if (node is JsonObject obj) + { + if (obj.TryGetPropertyValue("enum", out _) + && !obj.TryGetPropertyValue("type", out _)) + { + obj.Insert(0, "type", "string"); + } + } + + return node; + }, + }); + schemaNode.Insert(0, "$schema", "https://json-schema.org/draft/2020-12/schema"); + schemaNode.Add("additionalProperties", false); + var schema = JsonSerializer.Serialize(schemaNode, JsonNodeContext.Default.JsonNode); + + ChatMessage? promptAugmentation = null; + options = (options ?? new()).Clone(); + + // Currently there's no way for the inner IChatClient to specify whether structured output + // is supported, so we always default to false. In the future, some mechanism of declaring + // capabilities may be added (e.g., on ChatClientMetadata). + if (useNativeJsonSchema.GetValueOrDefault(false)) + { + // When using native structured output, we don't add any additional prompt, because + // the LLM backend is meant to do whatever's needed to explain the schema to the LLM. + options.ResponseFormat = ChatResponseFormat.ForJsonSchema( + schema, + schemaName: typeof(T).Name, + schemaDescription: typeof(T).GetCustomAttribute()?.Description); + } + else + { + options.ResponseFormat = ChatResponseFormat.Json; + + // When not using native structured output, augment the chat messages with a schema prompt +#pragma warning disable SA1118 // Parameter should not span multiple lines + promptAugmentation = new ChatMessage(ChatRole.System, $$""" + Respond with a JSON value conforming to the following schema: + ``` + {{schema}} + ``` + """); +#pragma warning restore SA1118 // Parameter should not span multiple lines + + chatMessages.Add(promptAugmentation); + } + + try + { + var result = await chatClient.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + return new ChatCompletion(result, serializerOptions); + } + finally + { + if (promptAugmentation is not null) + { + _ = chatMessages.Remove(promptAugmentation); + } + } + } + + private static JsonSerializerOptions DefaultJsonSerializerOptions + { + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + get => _defaultJsonSerializerOptions ?? GetOrCreateDefaultJsonSerializerOptions(); + } + + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + private static JsonSerializerOptions GetOrCreateDefaultJsonSerializerOptions() + { + var options = new JsonSerializerOptions(JsonSerializerDefaults.General) + { + Converters = { new JsonStringEnumConverter() }, + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), + WriteIndented = true, + }; + return Interlocked.CompareExchange(ref _defaultJsonSerializerOptions, options, null) ?? options; + } + + [JsonSerializable(typeof(JsonNode))] + [JsonSourceGenerationOptions(WriteIndented = true)] + private sealed partial class JsonNodeContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs new file mode 100644 index 00000000000..344a01d2c22 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs @@ -0,0 +1,147 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Represents the result of a chat completion request with structured output. +/// The type of value expected from the chat completion. +/// +/// Language models are not guaranteed to honor the requested schema. If the model's output is not +/// parseable as the expected type, then will return . +/// You can access the underlying JSON response on the property. +/// +public class ChatCompletion : ChatCompletion +{ + private static readonly JsonReaderOptions _allowMultipleValuesJsonReaderOptions = new JsonReaderOptions { AllowMultipleValues = true }; + private readonly JsonSerializerOptions _serializerOptions; + + private T? _deserializedResult; + private bool _hasDeserializedResult; + + /// Initializes a new instance of the class. + /// The unstructured that is being wrapped. + /// The to use when deserializing the result. + public ChatCompletion(ChatCompletion completion, JsonSerializerOptions serializerOptions) + : base(Throw.IfNull(completion).Choices) + { + _serializerOptions = Throw.IfNull(serializerOptions); + CompletionId = completion.CompletionId; + ModelId = completion.ModelId; + CreatedAt = completion.CreatedAt; + FinishReason = completion.FinishReason; + Usage = completion.Usage; + RawRepresentation = completion.RawRepresentation; + AdditionalProperties = completion.AdditionalProperties; + } + + /// + /// Gets the result of the chat completion as an instance of . + /// If the response did not contain JSON, or if deserialization fails, this property will throw. + /// To avoid exceptions, use instead. + /// + public T Result + { + get + { + var result = GetResultCore(out var failureReason); + return failureReason switch + { + FailureReason.ResultDidNotContainJson => throw new InvalidOperationException("The response did not contain text to be deserialized"), + FailureReason.DeserializationProducedNull => throw new InvalidOperationException("The deserialized response is null"), + _ => result!, + }; + } + } + + /// + /// Attempts to deserialize the result to produce an instance of . + /// + /// The result. + /// if the result was produced, otherwise . + public bool TryGetResult([NotNullWhen(true)] out T? result) + { + try + { + result = GetResultCore(out var failureReason); + return failureReason is null; + } +#pragma warning disable CA1031 // Do not catch general exception types + catch + { + result = default; + return false; + } +#pragma warning restore CA1031 // Do not catch general exception types + } + + private static T? DeserializeFirstTopLevelObject(string json, JsonTypeInfo typeInfo) + { + // We need to deserialize only the first top-level object as a workaround for a common LLM backend + // issue. GPT 3.5 Turbo commonly returns multiple top-level objects after doing a function call. + // See https://community.openai.com/t/2-json-objects-returned-when-using-function-calling-and-json-mode/574348 + var utf8ByteLength = Encoding.UTF8.GetByteCount(json); + var buffer = ArrayPool.Shared.Rent(utf8ByteLength); + try + { + var utf8SpanLength = Encoding.UTF8.GetBytes(json, 0, json.Length, buffer, 0); + var utf8Span = new ReadOnlySpan(buffer, 0, utf8SpanLength); + var reader = new Utf8JsonReader(utf8Span, _allowMultipleValuesJsonReaderOptions); + return JsonSerializer.Deserialize(ref reader, typeInfo); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + private string? GetResultAsJson() + { + var choice = Choices.Count == 1 ? Choices[0] : null; + var content = choice?.Contents.Count == 1 ? choice.Contents[0] : null; + return (content as TextContent)?.Text; + } + + private T? GetResultCore(out FailureReason? failureReason) + { + if (_hasDeserializedResult) + { + failureReason = default; + return _deserializedResult; + } + + var json = GetResultAsJson(); + if (string.IsNullOrEmpty(json)) + { + failureReason = FailureReason.ResultDidNotContainJson; + return default; + } + + // If there's an exception here, we want it to propagate, since the Result property is meant to throw directly + var deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo)_serializerOptions.GetTypeInfo(typeof(T))); + if (deserialized is null) + { + failureReason = FailureReason.DeserializationProducedNull; + return default; + } + + _deserializedResult = deserialized; + _hasDeserializedResult = true; + failureReason = default; + return deserialized; + } + + private enum FailureReason + { + ResultDidNotContainJson, + DeserializationProducedNull, + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs new file mode 100644 index 00000000000..a8a4b9269e2 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClient.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1629 // Documentation text should end with a period + +namespace Microsoft.Extensions.AI; + +/// A delegating chat client that updates or replaces the used by the remainder of the pipeline. +/// +/// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options +/// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide +/// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example +/// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the +/// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance +/// and mutating the clone, for example: +/// +/// options => +/// { +/// var newOptions = options?.Clone() ?? new(); +/// newOptions.MaxTokens = 1000; +/// return newOptions; +/// } +/// +/// +public sealed class ConfigureOptionsChatClient : DelegatingChatClient +{ + /// The callback delegate used to configure options. + private readonly Func _configureOptions; + + /// Initializes a new instance of the class with the specified callback. + /// The inner client. + /// + /// The delegate to invoke to configure the instance. It is passed the caller-supplied + /// instance and should return the configured instance to use. + /// + public ConfigureOptionsChatClient(IChatClient innerClient, Func configureOptions) + : base(innerClient) + { + _configureOptions = Throw.IfNull(configureOptions); + } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + return await base.CompleteAsync(chatMessages, _configureOptions(options), cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await foreach (var update in base.CompleteStreamingAsync(chatMessages, _configureOptions(options), cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..12b903c0dac --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ConfigureOptionsChatClientBuilderExtensions.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable SA1629 // Documentation text should end with a period + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class ConfigureOptionsChatClientBuilderExtensions +{ + /// + /// Adds a callback that updates or replaces . This can be used to set default options. + /// + /// The . + /// + /// The delegate to invoke to configure the instance. It is passed the caller-supplied + /// instance and should return the configured instance to use. + /// + /// The . + /// + /// The configuration callback is invoked with the caller-supplied instance. To override the caller-supplied options + /// with a new instance, the callback may simply return that new instance, for example _ => new ChatOptions() { MaxTokens = 1000 }. To provide + /// a new instance only if the caller-supplied instance is `null`, the callback may conditionally return a new instance, for example + /// options => options ?? new ChatOptions() { MaxTokens = 1000 }. Any changes to the caller-provided options instance will persist on the + /// original instance, so the callback must take care to only do so when such mutations are acceptable, such as by cloning the original instance + /// and mutating the clone, for example: + /// + /// options => + /// { + /// var newOptions = options?.Clone() ?? new(); + /// newOptions.MaxTokens = 1000; + /// return newOptions; + /// } + /// + /// + public static ChatClientBuilder UseChatOptions( + this ChatClientBuilder builder, Func configureOptions) + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(configureOptions); + + return builder.Use(innerClient => new ConfigureOptionsChatClient(innerClient, configureOptions)); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs new file mode 100644 index 00000000000..65c50c090bd --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClient.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that caches the results of completion calls, storing them as JSON in an . +/// +public class DistributedCachingChatClient : CachingChatClient +{ + private readonly IDistributedCache _storage; + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An instance that will be used as the backing store for the cache. + public DistributedCachingChatClient(IChatClient innerClient, IDistributedCache storage) + : base(innerClient) + { + _storage = Throw.IfNull(storage); + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing cache data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + protected override async Task ReadCacheAsync(string key, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _jsonSerializerOptions.MakeReadOnly(); + + if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + { + return (ChatCompletion?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion))); + } + + return null; + } + + /// + protected override async Task?> ReadCacheStreamingAsync(string key, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _jsonSerializerOptions.MakeReadOnly(); + + if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + { + return (IReadOnlyList?)JsonSerializer.Deserialize(existingJson, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); + } + + return null; + } + + /// + protected override async Task WriteCacheAsync(string key, ChatCompletion value, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _ = Throw.IfNull(value); + _jsonSerializerOptions.MakeReadOnly(); + + var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion))); + await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + } + + /// + protected override async Task WriteCacheStreamingAsync(string key, IReadOnlyList value, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _ = Throw.IfNull(value); + _jsonSerializerOptions.MakeReadOnly(); + + var newJson = JsonSerializer.SerializeToUtf8Bytes(value, _jsonSerializerOptions.GetTypeInfo(typeof(IReadOnlyList))); + await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + } + + /// + protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) + { + // While it might be desirable to include ChatOptions in the cache key, it's not always possible, + // since ChatOptions can contain types that are not guaranteed to be serializable or have a stable + // hashcode across multiple calls. So the default cache key is simply the JSON representation of + // the chat contents. Developers may subclass and override this to provide custom rules. + _jsonSerializerOptions.MakeReadOnly(); + return CachingHelpers.GetCacheKey(chatMessages, streaming, _jsonSerializerOptions); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..d465161e1e4 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/DistributedCachingChatClientBuilderExtensions.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Extension methods for adding a to an pipeline. +/// +public static class DistributedCachingChatClientBuilderExtensions +{ + /// + /// Adds a as the next stage in the pipeline. + /// + /// The . + /// + /// An optional instance that will be used as the backing store for the cache. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The provided as . + public static ChatClientBuilder UseDistributedCache(this ChatClientBuilder builder, IDistributedCache? storage = null, Action? configure = null) + { + _ = Throw.IfNull(builder); + return builder.Use((services, innerClient) => + { + storage ??= services.GetRequiredService(); + var chatClient = new DistributedCachingChatClient(innerClient, storage); + configure?.Invoke(chatClient); + return chatClient; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs new file mode 100644 index 00000000000..c46d7f43156 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -0,0 +1,639 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating chat client that invokes functions defined on . +/// Include this in a chat pipeline to resolve function calls automatically. +/// +/// +/// When this client receives a in a chat completion, it responds +/// by calling the corresponding defined in , +/// producing a . +/// +public class FunctionInvokingChatClient : DelegatingChatClient +{ + /// Maximum number of roundtrips allowed to the inner client. + private int? _maximumIterationsPerRequest; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , or the next instance in a chain of clients. + public FunctionInvokingChatClient(IChatClient innerClient) + : base(innerClient) + { + } + + /// + /// Gets or sets a value indicating whether to handle exceptions that occur during function calls. + /// + /// + /// + /// If the value is , then if a function call fails with an exception, the + /// underlying will be instructed to give a response without invoking + /// any further functions. + /// + /// + /// If the value is , the underlying will be allowed + /// to continue attempting function calls until is reached. + /// + /// + /// The default value is . + /// + /// + public bool RetryOnError { get; set; } + + /// + /// Gets or sets a value indicating whether detailed exception information should be included + /// in the chat history when calling the underlying . + /// + /// + /// + /// The default value is , meaning that only a generic error message will + /// be included in the chat history. This prevents the underlying language model from disclosing + /// raw exception details to the end user, since it does not receive that information. Even in this + /// case, the raw object is available to application code by inspecting + /// the property. + /// + /// + /// If set to , the full exception message will be added to the chat history + /// when calling the underlying . This can help it to bypass problems on + /// its own, for example by retrying the function call with different arguments. However it may + /// result in disclosing the raw exception information to external users, which may be a security + /// concern depending on the application scenario. + /// + /// + public bool DetailedErrors { get; set; } + + /// + /// Gets or sets a value indicating whether to allow concurrent invocation of functions. + /// + /// + /// + /// An individual response from the inner client may contain multiple function call requests. + /// By default, such function calls may be issued to execute concurrently with each other. Set + /// to false to disable such concurrent invocation and force + /// the functions to be invoked serially. + /// + /// + /// The default value is . + /// + /// + public bool ConcurrentInvocation { get; set; } = true; + + /// + /// Gets or sets a value indicating whether to keep intermediate messages in the chat history. + /// + /// + /// When the inner returns to the + /// , the adds + /// those messages to the list of messages, along with instances + /// it creates with the results of invoking the requested functions. The resulting augmented + /// list of messages is then passed to the inner client in order to send the results back. + /// By default, is , and those + /// messages will persist in the list provided to + /// and by the caller. Set + /// to to remove those messages prior to completing the operation. + /// + public bool KeepFunctionCallingMessages { get; set; } = true; + + /// + /// Gets or sets the maximum number of iterations per request. + /// + /// + /// + /// Each request to this may end up making + /// multiple requests to the inner client. Each time the inner client responds with + /// a function call request, this client may perform that invocation and send the results + /// back to the inner client in a new request. This property limits the number of times + /// such a roundtrip is performed. If null, there is no limit applied. If set, the value + /// must be at least one, as it includes the initial request. + /// + /// + /// The default value is . + /// + /// + public int? MaximumIterationsPerRequest + { + get => _maximumIterationsPerRequest; + set + { + if (value < 1) + { + Throw.ArgumentOutOfRangeException(nameof(value)); + } + + _maximumIterationsPerRequest = value; + } + } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + ChatCompletion? response; + + HashSet? messagesToRemove = null; + HashSet? contentsToRemove = null; + try + { + for (int iteration = 0; ; iteration++) + { + // Make the call to the handler. + response = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + + // If there are no tools to call, or for any other reason we should stop, return the response. + if (options is null + || options.Tools is not { Count: > 0 } + || response.Choices.Count == 0 + || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations)) + { + break; + } + + // If there's more than one choice, we don't know which one to add to chat history, or which + // of their function calls to process. This should not happen except if the developer has + // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer + // doesn't realize this and is wasting their budget requesting extra choices we'd never use. + if (response.Choices.Count > 1) + { + throw new InvalidOperationException($"Automatic function call invocation only accepts a single choice, but {response.Choices.Count} choices were received."); + } + + // Extract any function call contents on the first choice. If there are none, we're done. + // We don't have any way to express a preference to use a different choice, since this + // is a niche case especially with function calling. + FunctionCallContent[] functionCallContents = response.Message.Contents.OfType().ToArray(); + if (functionCallContents.Length == 0) + { + break; + } + + // Track all added messages in order to remove them, if requested. + if (!KeepFunctionCallingMessages) + { + messagesToRemove ??= []; + } + + // Add the original response message into the history and track the message for removal. + chatMessages.Add(response.Message); + if (messagesToRemove is not null) + { + if (functionCallContents.Length == response.Message.Contents.Count) + { + // The most common case is that the response message contains only function calling content. + // In that case, we can just track the whole message for removal. + _ = messagesToRemove.Add(response.Message); + } + else + { + // In the less likely case where some content is function calling and some isn't, we don't want to remove + // the non-function calling content by removing the whole message. So we track the content directly. + (contentsToRemove ??= []).UnionWith(functionCallContents); + } + } + + // Add the responses from the function calls into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + if (modeAndMessages.MessagesAdded is not null) + { + messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded); + } + + switch (modeAndMessages.Mode) + { + case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: + // We have to reset this after the first iteration, otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = ChatToolMode.Auto; + break; + + case ContinueMode.AllowOneMoreRoundtrip: + // The LLM gets one further chance to answer, but cannot use tools. + options = options.Clone(); + options.Tools = null; + break; + + case ContinueMode.Terminate: + // Bail immediately. + return response; + } + } + + return response!; + } + finally + { + RemoveMessagesAndContentFromList(messagesToRemove, contentsToRemove, chatMessages); + } + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(chatMessages); + + HashSet? messagesToRemove = null; + try + { + for (int iteration = 0; ; iteration++) + { + List? functionCallContents = null; + await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + // We're going to emit all StreamingChatMessage items upstream, even ones that represent + // function calls, because a given StreamingChatMessage can contain other content too. + yield return chunk; + + foreach (var item in chunk.Contents.OfType()) + { + functionCallContents ??= []; + functionCallContents.Add(item); + } + } + + // If there are no tools to call, or for any other reason we should stop, return the response. + if (options is null + || options.Tools is not { Count: > 0 } + || (MaximumIterationsPerRequest is { } maxIterations && iteration >= maxIterations) + || functionCallContents is not { Count: > 0 }) + { + break; + } + + // Track all added messages in order to remove them, if requested. + if (!KeepFunctionCallingMessages) + { + messagesToRemove ??= []; + } + + // Add a manufactured response message containing the function call contents to the chat history. + ChatMessage functionCallMessage = new(ChatRole.Assistant, [.. functionCallContents]); + chatMessages.Add(functionCallMessage); + _ = messagesToRemove?.Add(functionCallMessage); + + // Process all of the functions, adding their results into the history. + var modeAndMessages = await ProcessFunctionCallsAsync(chatMessages, options, functionCallContents, iteration, cancellationToken).ConfigureAwait(false); + if (modeAndMessages.MessagesAdded is not null) + { + messagesToRemove?.UnionWith(modeAndMessages.MessagesAdded); + } + + // Decide how to proceed based on the result of the function calls. + switch (modeAndMessages.Mode) + { + case ContinueMode.Continue when options.ToolMode is RequiredChatToolMode: + // We have to reset this after the first iteration, otherwise we'll be in an infinite loop. + options = options.Clone(); + options.ToolMode = ChatToolMode.Auto; + break; + + case ContinueMode.AllowOneMoreRoundtrip: + // The LLM gets one further chance to answer, but cannot use tools. + options = options.Clone(); + options.Tools = null; + break; + + case ContinueMode.Terminate: + // Bail immediately. + yield break; + } + } + } + finally + { + RemoveMessagesAndContentFromList(messagesToRemove, contentToRemove: null, chatMessages); + } + } + + /// + /// Removes all of the messages in from + /// and all of the content in from the messages in . + /// + private static void RemoveMessagesAndContentFromList( + HashSet? messagesToRemove, + HashSet? contentToRemove, + IList messages) + { + Debug.Assert( + contentToRemove is null || messagesToRemove is not null, + "We should only be tracking content to remove if we're also tracking messages to remove."); + + if (messagesToRemove is not null) + { + for (int m = messages.Count - 1; m >= 0; m--) + { + ChatMessage message = messages[m]; + + if (contentToRemove is not null) + { + for (int c = message.Contents.Count - 1; c >= 0; c--) + { + if (contentToRemove.Contains(message.Contents[c])) + { + message.Contents.RemoveAt(c); + } + } + } + + if (messages.Count == 0 || messagesToRemove.Contains(messages[m])) + { + messages.RemoveAt(m); + } + } + } + } + + /// + /// Processes the function calls in the list. + /// + /// The current chat contents, inclusive of the function call contents being processed. + /// The options used for the response being processed. + /// The function call contents representing the functions to be invoked. + /// The iteration number of how many roundtrips have been made to the inner client. + /// The to monitor for cancellation requests. + /// A value indicating how the caller should proceed. + private async Task<(ContinueMode Mode, IList MessagesAdded)> ProcessFunctionCallsAsync( + IList chatMessages, ChatOptions options, IReadOnlyList functionCallContents, int iteration, CancellationToken cancellationToken) + { + // We must add a response for every tool call, regardless of whether we successfully executed it or not. + // If we successfully execute it, we'll add the result. If we don't, we'll add an error. + + int functionCount = functionCallContents.Count; + Debug.Assert(functionCount > 0, $"Expecteded {nameof(functionCount)} to be > 0, got {functionCount}."); + + // Process all functions. If there's more than one and concurrent invocation is enabled, do so in parallel. + if (functionCount == 1) + { + FunctionInvocationResult result = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[0], iteration, 0, 1, cancellationToken).ConfigureAwait(false); + IList added = AddResponseMessages(chatMessages, [result]); + return (result.ContinueMode, added); + } + else + { + FunctionInvocationResult[] results; + + if (ConcurrentInvocation) + { + // Schedule the invocation of every function. + results = await Task.WhenAll( + from i in Enumerable.Range(0, functionCount) + select Task.Run(() => ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken))).ConfigureAwait(false); + } + else + { + // Invoke each function serially. + results = new FunctionInvocationResult[functionCount]; + for (int i = 0; i < functionCount; i++) + { + results[i] = await ProcessFunctionCallAsync(chatMessages, options, functionCallContents[i], iteration, i, functionCount, cancellationToken).ConfigureAwait(false); + } + } + + ContinueMode continueMode = ContinueMode.Continue; + IList added = AddResponseMessages(chatMessages, results); + foreach (FunctionInvocationResult fir in results) + { + if (fir.ContinueMode > continueMode) + { + continueMode = fir.ContinueMode; + } + } + + return (continueMode, added); + } + } + + /// Processes the function call described in . + /// The current chat contents, inclusive of the function call contents being processed. + /// The options used for the response being processed. + /// The function call content representing the function to be invoked. + /// The iteration number of how many roundtrips have been made to the inner client. + /// The 0-based index of the function being called out of total functions. + /// The number of function call requests made, of which this is one. + /// The to monitor for cancellation requests. + /// A value indicating how the caller should proceed. + private async Task ProcessFunctionCallAsync( + IList chatMessages, ChatOptions options, FunctionCallContent functionCallContent, + int iteration, int functionCallIndex, int totalFunctionCount, CancellationToken cancellationToken) + { + // Look up the AIFunction for the function call. If the requested function isn't available, send back an error. + AIFunction? function = options.Tools!.OfType().FirstOrDefault(t => t.Metadata.Name == functionCallContent.Name); + if (function is null) + { + return new(ContinueMode.Continue, FunctionStatus.NotFound, functionCallContent, result: null, exception: null); + } + + FunctionInvocationContext context = new(chatMessages, functionCallContent, function) + { + Iteration = iteration, + FunctionCallIndex = functionCallIndex, + FunctionCount = totalFunctionCount, + }; + + try + { + object? result = await InvokeFunctionAsync(context, cancellationToken).ConfigureAwait(false); + return new( + context.Terminate ? ContinueMode.Terminate : ContinueMode.Continue, + FunctionStatus.CompletedSuccessfully, + functionCallContent, + result, + exception: null); + } + catch (Exception e) when (!cancellationToken.IsCancellationRequested) + { + return new( + RetryOnError ? ContinueMode.Continue : ContinueMode.AllowOneMoreRoundtrip, // We won't allow further function calls, hence the LLM will just get one more chance to give a final answer. + FunctionStatus.Failed, + functionCallContent, + result: null, + exception: e); + } + } + + /// Represents the return value of , dictating how the loop should behave. + /// These values are ordered from least severe to most severe, and code explicitly depends on the ordering. + internal enum ContinueMode + { + /// Send back the responses and continue processing. + Continue = 0, + + /// Send back the response but without any tools. + AllowOneMoreRoundtrip = 1, + + /// Immediately exit the function calling loop. + Terminate = 2, + } + + /// Adds one or more response messages for function invocation results. + /// The chat to which to add the one or more response messages. + /// Information about the function call invocations and results. + /// A list of all chat messages added to . + protected virtual IList AddResponseMessages(IList chat, ReadOnlySpan results) + { + _ = Throw.IfNull(chat); + + var contents = new AIContent[results.Length]; + for (int i = 0; i < results.Length; i++) + { + contents[i] = CreateFunctionResultContent(results[i]); + } + + ChatMessage message = new(ChatRole.Tool, contents); + chat.Add(message); + return [message]; + + FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult result) + { + _ = Throw.IfNull(result); + + object? functionResult; + if (result.Status == FunctionStatus.CompletedSuccessfully) + { + functionResult = result.Result ?? "Success: Function completed."; + } + else + { + string message = result.Status switch + { + FunctionStatus.NotFound => "Error: Requested function not found.", + FunctionStatus.Failed => "Error: Function failed.", + _ => "Error: Unknown error.", + }; + + if (DetailedErrors && result.Exception is not null) + { + message = $"{message} Exception: {result.Exception.Message}"; + } + + functionResult = message; + } + + return new FunctionResultContent(result.CallContent.CallId, result.CallContent.Name, functionResult, result.Exception); + } + } + + /// Invokes the function asynchronously. + /// + /// The function invocation context detailing the function to be invoked and its arguments along with additional request information. + /// + /// The to monitor for cancellation requests. The default is . + /// The result of the function invocation. This may be null if the function invocation returned null. + protected virtual Task InvokeFunctionAsync(FunctionInvocationContext context, CancellationToken cancellationToken) + { + _ = Throw.IfNull(context); + + return context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken); + } + + /// Provides context for a function invocation. + public sealed class FunctionInvocationContext + { + /// Initializes a new instance of the class. + /// The chat contents associated with the operation that initiated this function call request. + /// The AI function to be invoked. + /// The function call content information associated with this invocation. + internal FunctionInvocationContext( + IList chatMessages, + FunctionCallContent functionCallContent, + AIFunction function) + { + Function = function; + CallContent = functionCallContent; + ChatMessages = chatMessages; + } + + /// Gets or sets the AI function to be invoked. + public AIFunction Function { get; set; } + + /// Gets or sets the function call content information associated with this invocation. + public FunctionCallContent CallContent { get; set; } + + /// Gets or sets the chat contents associated with the operation that initiated this function call request. + public IList ChatMessages { get; set; } + + /// Gets or sets the number of this iteration with the underlying client. + /// + /// The initial request to the client that passes along the chat contents provided to the + /// is iteration 1. If the client responds with a function call request, the next request to the client is iteration 2, and so on. + /// + public int Iteration { get; set; } + + /// Gets or sets the index of the function call within the iteration. + /// + /// The response from the underlying client may include multiple function call requests. + /// This index indicates the position of the function call within the iteration. + /// + public int FunctionCallIndex { get; set; } + + /// Gets or sets the total number of function call requests within the iteration. + /// + /// The response from the underlying client may include multiple function call requests. + /// This count indicates how many there were. + /// + public int FunctionCount { get; set; } + + /// Gets or sets a value indicating whether to terminate the request. + /// + /// In response to a function call request, the function may be invoked, its result added to the chat contents, + /// and a new request issued to the wrapped client. If this property is set to true, that subsequent request + /// will not be issued and instead the loop immediately terminated rather than continuing until there are no + /// more function call requests in responses. + /// + public bool Terminate { get; set; } + } + + /// Provides information about the invocation of a function call. + public sealed class FunctionInvocationResult + { + internal FunctionInvocationResult(ContinueMode continueMode, FunctionStatus status, FunctionCallContent callContent, object? result, Exception? exception) + { + ContinueMode = continueMode; + Status = status; + CallContent = callContent; + Result = result; + Exception = exception; + } + + /// Gets status about how the function invocation completed. + public FunctionStatus Status { get; } + + /// Gets the function call content information associated with this invocation. + public FunctionCallContent CallContent { get; } + + /// Gets the result of the function call. + public object? Result { get; } + + /// Gets any exception the function call threw. + public Exception? Exception { get; } + + /// Gets an indication for how the caller should continue the processing loop. + internal ContinueMode ContinueMode { get; } + } + + /// Provides error codes for when errors occur as part of the function calling loop. + public enum FunctionStatus + { + /// The operation completed successfully. + CompletedSuccessfully, + + /// The requested function could not be found. + NotFound, + + /// The function call failed with an exception. + Failed, + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..15010b42068 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClientBuilderExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Provides extension methods for attaching a to a chat pipeline. +/// +public static class FunctionInvokingChatClientBuilderExtensions +{ + /// + /// Enables automatic function call invocation on the chat pipeline. + /// + /// This works by adding an instance of with default options. + /// The being used to build the chat pipeline. + /// An optional callback that can be used to configure the instance. + /// The supplied . + public static ChatClientBuilder UseFunctionInvocation(this ChatClientBuilder builder, Action? configure = null) + { + _ = Throw.IfNull(builder); + + return builder.Use(innerClient => + { + var chatClient = new FunctionInvokingChatClient(innerClient); + configure?.Invoke(chatClient); + return chatClient; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs new file mode 100644 index 00000000000..f0a9e8a0d75 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClient.cs @@ -0,0 +1,154 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable EA0000 // Use source generated logging methods for improved performance +#pragma warning disable CA2254 // Template should be a static expression + +namespace Microsoft.Extensions.AI; + +/// A delegating chat client that logs chat operations to an . +public class LoggingChatClient : DelegatingChatClient +{ + /// An instance used for all logging. + private readonly ILogger _logger; + + /// The to use for serialization of state written to the logger. + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An instance that will be used for all logging. + public LoggingChatClient(IChatClient innerClient, ILogger logger) + : base(innerClient) + { + _logger = Throw.IfNull(logger); + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing logging data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + public override async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + LogStart(chatMessages, options); + try + { + var completion = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (completion, _jsonSerializerOptions), null, static (state, _) => + $"CompleteAsync completed: {JsonSerializer.Serialize(state.completion, state._jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion)))}"); + } + else + { + _logger.LogDebug("CompleteAsync completed."); + } + } + + return completion; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "CompleteAsync failed."); + throw; + } + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + LogStart(chatMessages, options); + + IAsyncEnumerator e; + try + { + e = base.CompleteStreamingAsync(chatMessages, options, cancellationToken).GetAsyncEnumerator(cancellationToken); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "CompleteStreamingAsync failed."); + throw; + } + + try + { + StreamingChatCompletionUpdate? update = null; + while (true) + { + try + { + if (!await e.MoveNextAsync().ConfigureAwait(false)) + { + break; + } + + update = e.Current; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "CompleteStreamingAsync failed."); + throw; + } + + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (update, _jsonSerializerOptions), null, static (state, _) => + $"CompleteStreamingAsync received update: {JsonSerializer.Serialize(state.update, state._jsonSerializerOptions.GetTypeInfo(typeof(StreamingChatCompletionUpdate)))}"); + } + else + { + _logger.LogDebug("CompleteStreamingAsync received update."); + } + } + + yield return update; + } + + _logger.LogDebug("CompleteStreamingAsync completed."); + } + finally + { + await e.DisposeAsync().ConfigureAwait(false); + } + } + + private void LogStart(IList chatMessages, ChatOptions? options, [CallerMemberName] string? methodName = null) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (methodName, chatMessages, options, this), null, static (state, _) => + $"{state.methodName} invoked: " + + $"Messages: {JsonSerializer.Serialize(state.chatMessages, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(IList)))}. " + + $"Options: {JsonSerializer.Serialize(state.options, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatOptions)))}. " + + $"Metadata: {JsonSerializer.Serialize(state.Item4.Metadata, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatClientMetadata)))}."); + } + else + { + _logger.LogDebug($"{methodName} invoked."); + } + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..056ba5401fc --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/LoggingChatClientBuilderExtensions.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class LoggingChatClientBuilderExtensions +{ + /// Adds logging to the chat client pipeline. + /// The . + /// + /// An optional with which logging should be performed. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The . + public static ChatClientBuilder UseLogging( + this ChatClientBuilder builder, ILogger? logger = null, Action? configure = null) + { + _ = Throw.IfNull(builder); + + return builder.Use((services, innerClient) => + { + logger ??= services.GetRequiredService().CreateLogger(nameof(LoggingChatClient)); + var chatClient = new LoggingChatClient(innerClient, logger); + configure?.Invoke(chatClient); + return chatClient; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs new file mode 100644 index 00000000000..13e2d1229dd --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClient.cs @@ -0,0 +1,509 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A delegating chat client that implements the OpenTelemetry Semantic Conventions for Generative AI systems. +/// +/// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. +/// The specification is still experimental and subject to change; as such, the telemetry output by this client is also subject to change. +/// +public sealed class OpenTelemetryChatClient : DelegatingChatClient +{ + private readonly ActivitySource _activitySource; + private readonly Meter _meter; + + private readonly Histogram _tokenUsageHistogram; + private readonly Histogram _operationDurationHistogram; + + private readonly string? _modelId; + private readonly string? _modelProvider; + private readonly string? _endpointAddress; + private readonly int _endpointPort; + + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An optional source name that will be used on the telemetry data. + public OpenTelemetryChatClient(IChatClient innerClient, string? sourceName = null) + : base(innerClient) + { + Debug.Assert(innerClient is not null, "Should have been validated by the base ctor"); + + ChatClientMetadata metadata = innerClient!.Metadata; + _modelId = metadata.ModelId; + _modelProvider = metadata.ProviderName; + _endpointAddress = metadata.ProviderUri?.GetLeftPart(UriPartial.Path); + _endpointPort = metadata.ProviderUri?.Port ?? 0; + + string name = string.IsNullOrEmpty(sourceName) ? OpenTelemetryConsts.DefaultSourceName : sourceName!; + _activitySource = new(name); + _meter = new(name); + + _tokenUsageHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.TokenUsage.Name, + OpenTelemetryConsts.TokensUnit, + OpenTelemetryConsts.GenAI.Client.TokenUsage.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.TokenUsage.ExplicitBucketBoundaries }); + + _operationDurationHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.OperationDuration.Name, + OpenTelemetryConsts.SecondsUnit, + OpenTelemetryConsts.GenAI.Client.OperationDuration.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.OperationDuration.ExplicitBucketBoundaries }); + + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when formatting chat data into telemetry strings. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + _activitySource.Dispose(); + _meter.Dispose(); + } + + base.Dispose(disposing); + } + + /// + /// Gets or sets a value indicating whether potentially sensitive information (e.g. prompts) should be included in telemetry. + /// + /// + /// The value is by default, meaning that telemetry will include metadata such as token counts but not the raw text of prompts or completions. + /// + public bool EnableSensitiveData { get; set; } + + /// + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + _jsonSerializerOptions.MakeReadOnly(); + + using Activity? activity = StartActivity(chatMessages, options); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + string? requestModelId = options?.ModelId ?? _modelId; + + ChatCompletion? response = null; + Exception? error = null; + try + { + response = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + SetCompletionResponse(activity, requestModelId, response, error, stopwatch); + } + + return response; + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _jsonSerializerOptions.MakeReadOnly(); + + using Activity? activity = StartActivity(chatMessages, options); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + string? requestModelId = options?.ModelId ?? _modelId; + + IAsyncEnumerable response; + try + { + response = base.CompleteStreamingAsync(chatMessages, options, cancellationToken); + } + catch (Exception ex) + { + SetCompletionResponse(activity, requestModelId, null, ex, stopwatch); + throw; + } + + var responseEnumerator = response.ConfigureAwait(false).GetAsyncEnumerator(); + List? streamedContents = activity is not null ? [] : null; + try + { + while (true) + { + StreamingChatCompletionUpdate update; + try + { + if (!await responseEnumerator.MoveNextAsync()) + { + break; + } + + update = responseEnumerator.Current; + } + catch (Exception ex) + { + SetCompletionResponse(activity, requestModelId, null, ex, stopwatch); + throw; + } + + streamedContents?.Add(update); + yield return update; + } + } + finally + { + if (activity is not null) + { + UsageContent? usageContent = streamedContents?.SelectMany(c => c.Contents).OfType().LastOrDefault(); + SetCompletionResponse( + activity, + stopwatch, + requestModelId, + OrganizeStreamingContent(streamedContents), + streamedContents?.SelectMany(c => c.Contents).OfType(), + usage: usageContent?.Details); + } + + await responseEnumerator.DisposeAsync(); + } + } + + /// Gets a value indicating whether diagnostics are enabled. + private bool Enabled => _activitySource.HasListeners(); + + /// Convert chat history to a string aligned with the OpenAI format. + private static string ToOpenAIFormat(IEnumerable messages, JsonSerializerOptions serializerOptions) + { + var sb = new StringBuilder().Append('['); + + string messageSeparator = string.Empty; + foreach (var message in messages) + { + _ = sb.Append(messageSeparator); + messageSeparator = ", \n"; + + string text = string.Concat(message.Contents.OfType().Select(c => c.Text)); + _ = sb.Append("{\"role\": \"").Append(message.Role).Append("\", \"content\": ").Append(JsonSerializer.Serialize(text, serializerOptions.GetTypeInfo(typeof(string)))); + + if (message.Contents.OfType().Any()) + { + _ = sb.Append(", \"tool_calls\": ").Append('['); + + string messageItemSeparator = string.Empty; + foreach (var functionCall in message.Contents.OfType()) + { + _ = sb.Append(messageItemSeparator); + messageItemSeparator = ", \n"; + + _ = sb.Append("{\"id\": \"").Append(functionCall.CallId) + .Append("\", \"function\": {\"arguments\": ").Append(JsonSerializer.Serialize(functionCall.Arguments, serializerOptions.GetTypeInfo(typeof(IDictionary)))) + .Append(", \"name\": \"").Append(functionCall.Name) + .Append("\"}, \"type\": \"function\"}"); + } + + _ = sb.Append(']'); + } + + _ = sb.Append('}'); + } + + _ = sb.Append(']'); + return sb.ToString(); + } + + /// Organize streaming content by choice index. + private static Dictionary> OrganizeStreamingContent(IEnumerable? contents) + { + Dictionary> choices = []; + if (contents is null) + { + return choices; + } + + foreach (var content in contents) + { + if (!choices.TryGetValue(content.ChoiceIndex, out var choiceContents)) + { + choices[content.ChoiceIndex] = choiceContents = []; + } + + choiceContents.Add(content); + } + + return choices; + } + + /// Creates an activity for a chat completion request, or returns null if not enabled. + private Activity? StartActivity(IList chatMessages, ChatOptions? options) + { + Activity? activity = null; + if (Enabled) + { + string? modelId = options?.ModelId ?? _modelId; + + activity = _activitySource.StartActivity( + $"chat.completions {modelId}", + ActivityKind.Client, + default(ActivityContext), + [ + new(OpenTelemetryConsts.GenAI.Operation.Name, "chat"), + new(OpenTelemetryConsts.GenAI.Request.Model, modelId), + new(OpenTelemetryConsts.GenAI.System, _modelProvider), + ]); + + if (activity is not null) + { + if (_endpointAddress is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Server.Address, _endpointAddress) + .SetTag(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + if (options is not null) + { + if (options.FrequencyPenalty is float frequencyPenalty) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.FrequencyPenalty, frequencyPenalty); + } + + if (options.MaxOutputTokens is int maxTokens) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.MaxTokens, maxTokens); + } + + if (options.PresencePenalty is float presencePenalty) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.PresencePenalty, presencePenalty); + } + + if (options.StopSequences is IList stopSequences) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.StopSequences, $"[{string.Join(", ", stopSequences.Select(s => $"\"{s}\""))}]"); + } + + if (options.Temperature is float temperature) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.Temperature, temperature); + } + + if (options.AdditionalProperties?.TryGetConvertedValue("top_k", out double topK) is true) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.TopK, topK); + } + + if (options.TopP is float top_p) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.TopP, top_p); + } + } + + if (EnableSensitiveData) + { + _ = activity.AddEvent(new ActivityEvent( + OpenTelemetryConsts.GenAI.Content.Prompt, + tags: new ActivityTagsCollection([new(OpenTelemetryConsts.GenAI.Prompt, ToOpenAIFormat(chatMessages, _jsonSerializerOptions))]))); + } + } + } + + return activity; + } + + /// Adds chat completion information to the activity. + private void SetCompletionResponse( + Activity? activity, + string? requestModelId, + ChatCompletion? completions, + Exception? error, + Stopwatch? stopwatch) + { + if (!Enabled) + { + return; + } + + if (_operationDurationHistogram.Enabled && stopwatch is not null) + { + TagList tags = default; + + AddMetricTags(ref tags, requestModelId, completions); + if (error is not null) + { + tags.Add(OpenTelemetryConsts.Error.Type, error.GetType().FullName); + } + + _operationDurationHistogram.Record(stopwatch.Elapsed.TotalSeconds, tags); + } + + if (_tokenUsageHistogram.Enabled && completions?.Usage is { } usage) + { + if (usage.InputTokenCount is int inputTokens) + { + TagList tags = default; + tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "input"); + AddMetricTags(ref tags, requestModelId, completions); + _tokenUsageHistogram.Record(inputTokens); + } + + if (usage.OutputTokenCount is int outputTokens) + { + TagList tags = default; + tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "output"); + AddMetricTags(ref tags, requestModelId, completions); + _tokenUsageHistogram.Record(outputTokens); + } + } + + if (activity is null) + { + return; + } + + if (error is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Error.Type, error.GetType().FullName) + .SetStatus(ActivityStatusCode.Error, error.Message); + return; + } + + if (completions is not null) + { + if (completions.FinishReason is ChatFinishReason finishReason) + { +#pragma warning disable CA1308 // Normalize strings to uppercase + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.FinishReasons, $"[\"{finishReason.Value.ToLowerInvariant()}\"]"); +#pragma warning restore CA1308 + } + + if (!string.IsNullOrWhiteSpace(completions.CompletionId)) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.Id, completions.CompletionId); + } + + if (completions.ModelId is not null) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.Model, completions.ModelId); + } + + if (completions.Usage?.InputTokenCount is int inputTokens) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.InputTokens, inputTokens); + } + + if (completions.Usage?.OutputTokenCount is int outputTokens) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.OutputTokens, outputTokens); + } + + if (EnableSensitiveData) + { + _ = activity.AddEvent(new ActivityEvent( + OpenTelemetryConsts.GenAI.Content.Completion, + tags: new ActivityTagsCollection([new(OpenTelemetryConsts.GenAI.Completion, ToOpenAIFormat(completions.Choices, _jsonSerializerOptions))]))); + } + } + } + + /// Adds streaming chat completion information to the activity. + private void SetCompletionResponse( + Activity? activity, + Stopwatch? stopwatch, + string? requestModelId, + Dictionary> choices, + IEnumerable? toolCalls, + UsageDetails? usage) + { + if (activity is null || !Enabled || choices.Count == 0) + { + return; + } + + string? id = null; + ChatFinishReason? finishReason = null; + string? modelId = null; + List messages = new(choices.Count); + + foreach (var choice in choices) + { + ChatRole? role = null; + List items = []; + foreach (var update in choice.Value) + { + id ??= update.CompletionId; + role ??= update.Role; + finishReason ??= update.FinishReason; + foreach (AIContent content in update.Contents) + { + items.Add(content); + modelId ??= content.ModelId; + } + } + + messages.Add(new ChatMessage(role ?? ChatRole.Assistant, items)); + } + + if (toolCalls is not null && messages.FirstOrDefault()?.Contents is { } c) + { + foreach (var functionCall in toolCalls) + { + c.Add(functionCall); + } + } + + ChatCompletion completion = new(messages) + { + CompletionId = id, + FinishReason = finishReason, + ModelId = modelId, + Usage = usage, + }; + + SetCompletionResponse(activity, requestModelId, completion, error: null, stopwatch); + } + + private void AddMetricTags(ref TagList tags, string? requestModelId, ChatCompletion? completions) + { + tags.Add(OpenTelemetryConsts.GenAI.Operation.Name, "chat"); + + if (requestModelId is not null) + { + tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModelId); + } + + tags.Add(OpenTelemetryConsts.GenAI.System, _modelProvider); + + if (_endpointAddress is string endpointAddress) + { + tags.Add(OpenTelemetryConsts.Server.Address, endpointAddress); + tags.Add(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + if (completions?.ModelId is string responseModel) + { + tags.Add(OpenTelemetryConsts.GenAI.Response.Model, responseModel); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClientBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClientBuilderExtensions.cs new file mode 100644 index 00000000000..bf1ff4e9f0d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/OpenTelemetryChatClientBuilderExtensions.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class OpenTelemetryChatClientBuilderExtensions +{ + /// + /// Adds OpenTelemetry support to the chat client pipeline, following the OpenTelemetry Semantic Conventions for Generative AI systems. + /// + /// + /// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. + /// The specification is still experimental and subject to change; as such, the telemetry output by this client is also subject to change. + /// + /// The . + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The . + public static ChatClientBuilder UseOpenTelemetry( + this ChatClientBuilder builder, string? sourceName = null, Action? configure = null) => + Throw.IfNull(builder).Use(innerClient => + { + var chatClient = new OpenTelemetryChatClient(innerClient, sourceName); + configure?.Invoke(chatClient); + return chatClient; + }); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs new file mode 100644 index 00000000000..8438d467eb6 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/CachingEmbeddingGenerator.cs @@ -0,0 +1,129 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A delegating embedding generator that caches the results of embedding generation calls. +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +public abstract class CachingEmbeddingGenerator : DelegatingEmbeddingGenerator + where TEmbedding : Embedding +{ + /// Initializes a new instance of the class. + /// The underlying . + protected CachingEmbeddingGenerator(IEmbeddingGenerator innerGenerator) + : base(innerGenerator) + { + } + + /// + public override async Task> GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(values); + + // Optimize for the common-case of a single value in a list/array. + if (values is IList valuesList) + { + switch (valuesList.Count) + { + case 0: + return []; + + case 1: + // In the expected common case where we can cheaply tell there's only a single value and access it, + // we can avoid all the overhead of splitting the list and reassembling it. + var cacheKey = GetCacheKey(valuesList[0], options); + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding e) + { + return [e]; + } + else + { + var generated = await base.GenerateAsync(valuesList, options, cancellationToken).ConfigureAwait(false); + if (generated.Count != 1) + { + throw new InvalidOperationException($"Expected exactly one embedding to be generated, but received {generated.Count}."); + } + + await WriteCacheAsync(cacheKey, generated[0], cancellationToken).ConfigureAwait(false); + return generated; + } + } + } + + // Some of the inputs may already be cached. Go through each, checking to see whether each individually is cached. + // Split those that are cached into one list and those that aren't into another. We retain their original positions + // so that we can reassemble the results in the correct order. + GeneratedEmbeddings results = []; + List<(int Index, string CacheKey, TInput Input)>? uncached = null; + foreach (TInput input in values) + { + // We're only storing the final result, not the in-flight task, so that we can avoid caching failures + // or having problems when one of the callers cancels but others don't. This has the drawback that + // concurrent callers might trigger duplicate requests, but that's acceptable. + var cacheKey = GetCacheKey(input, options); + + if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is TEmbedding existing) + { + results.Add(existing); + } + else + { + (uncached ??= []).Add((results.Count, cacheKey, input)); + results.Add(null!); // temporary placeholder + } + } + + // If anything wasn't cached, we need to generate embeddings for those. + if (uncached is not null) + { + // Now make a single call to the wrapped generator to generate embeddings for all of the uncached inputs. + var uncachedResults = await base.GenerateAsync(uncached.Select(e => e.Input), options, cancellationToken).ConfigureAwait(false); + + // Store the resulting embeddings into the cache individually. + for (int i = 0; i < uncachedResults.Count; i++) + { + await WriteCacheAsync(uncached[i].CacheKey, uncachedResults[i], cancellationToken).ConfigureAwait(false); + } + + // Fill in the gaps with the newly generated results. + for (int i = 0; i < uncachedResults.Count; i++) + { + results[uncached[i].Index] = uncachedResults[i]; + } + } + + Debug.Assert(results.All(e => e is not null), "Expected all values to be non-null"); + return results; + } + + /// + /// Computes a cache key for the specified call parameters. + /// + /// The for which an embedding is being requested. + /// The options to configure the request. + /// A string that will be used as a cache key. + protected abstract string GetCacheKey(TInput value, EmbeddingGenerationOptions? options); + + /// Returns a previously cached , if available. + /// The cache key. + /// The to monitor for cancellation requests. + /// The previously cached data, if available, otherwise . + protected abstract Task ReadCacheAsync(string key, CancellationToken cancellationToken); + + /// Stores a in the underlying cache. + /// The cache key. + /// The to be stored. + /// The to monitor for cancellation requests. + /// A representing the completion of the operation. + protected abstract Task WriteCacheAsync(string key, TEmbedding value, CancellationToken cancellationToken); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs new file mode 100644 index 00000000000..932bb2f91b8 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGenerator.cs @@ -0,0 +1,81 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// A delegating embedding generator that caches the results of embedding generation calls, +/// storing them as JSON in an . +/// +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +public class DistributedCachingEmbeddingGenerator : CachingEmbeddingGenerator + where TEmbedding : Embedding +{ + private readonly IDistributedCache _storage; + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// A instance that will be used as the backing store for the cache. + public DistributedCachingEmbeddingGenerator(IEmbeddingGenerator innerGenerator, IDistributedCache storage) + : base(innerGenerator) + { + _ = Throw.IfNull(storage); + _storage = storage; + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing cache data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set + { + _ = Throw.IfNull(value); + _jsonSerializerOptions = value; + } + } + + /// + protected override async Task ReadCacheAsync(string key, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _jsonSerializerOptions.MakeReadOnly(); + + if (await _storage.GetAsync(key, cancellationToken).ConfigureAwait(false) is byte[] existingJson) + { + return JsonSerializer.Deserialize(existingJson, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); + } + + return null; + } + + /// + protected override async Task WriteCacheAsync(string key, TEmbedding value, CancellationToken cancellationToken) + { + _ = Throw.IfNull(key); + _ = Throw.IfNull(value); + _jsonSerializerOptions.MakeReadOnly(); + + var newJson = JsonSerializer.SerializeToUtf8Bytes(value, (JsonTypeInfo)_jsonSerializerOptions.GetTypeInfo(typeof(TEmbedding))); + await _storage.SetAsync(key, newJson, cancellationToken).ConfigureAwait(false); + } + + /// + protected override string GetCacheKey(TInput value, EmbeddingGenerationOptions? options) + { + // While it might be desirable to include options in the cache key, it's not always possible, + // since options can contain types that are not guaranteed to be serializable or have a stable + // hashcode across multiple calls. So the default cache key is simply the JSON representation of + // the value. Developers may subclass and override this to provide custom rules. + return CachingHelpers.GetCacheKey(value, _jsonSerializerOptions); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs new file mode 100644 index 00000000000..77aaa30e05d --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/DistributedCachingEmbeddingGeneratorBuilderExtensions.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Extension methods for adding a to an +/// pipeline. +/// +public static class DistributedCachingEmbeddingGeneratorBuilderExtensions +{ + /// + /// Adds a as the next stage in the pipeline. + /// + /// The type from which embeddings will be generated. + /// The type of embeddings to generate. + /// The . + /// + /// An optional instance that will be used as the backing store for the cache. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The provided as . + public static EmbeddingGeneratorBuilder UseDistributedCache( + this EmbeddingGeneratorBuilder builder, + IDistributedCache? storage = null, + Action>? configure = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(builder); + return builder.Use((services, innerGenerator) => + { + storage ??= services.GetRequiredService(); + var result = new DistributedCachingEmbeddingGenerator(innerGenerator, storage); + configure?.Invoke(result); + return result; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs new file mode 100644 index 00000000000..96c4c92d4a9 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilder.cs @@ -0,0 +1,79 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A builder for creating pipelines of . +/// The type from which embeddings will be generated. +/// The type of embeddings to generate. +public sealed class EmbeddingGeneratorBuilder + where TEmbedding : Embedding +{ + /// The registered client factory instances. + private List, IEmbeddingGenerator>>? _generatorFactories; + + /// Initializes a new instance of the class. + /// The service provider to use for dependency injection. + public EmbeddingGeneratorBuilder(IServiceProvider? services = null) + { + Services = services ?? EmptyServiceProvider.Instance; + } + + /// Gets the associated with the builder instance. + public IServiceProvider Services { get; } + + /// + /// Builds an instance of using the specified inner generator. + /// + /// The inner generator to use. + /// An instance of . + /// + /// If there are any factories registered with this builder, is used as a seed to + /// the last factory, and the result of each factory delegate is passed to the previously registered factory. + /// The final result is then returned from this call. + /// + public IEmbeddingGenerator Use(IEmbeddingGenerator innerGenerator) + { + var embeddingGenerator = Throw.IfNull(innerGenerator); + + // To match intuitive expectations, apply the factories in reverse order, so that the first factory added is the outermost. + if (_generatorFactories is not null) + { + for (var i = _generatorFactories.Count - 1; i >= 0; i--) + { + embeddingGenerator = _generatorFactories[i](Services, embeddingGenerator) ?? + throw new InvalidOperationException( + $"The {nameof(IEmbeddingGenerator)} entry at index {i} returned null. " + + $"Ensure that the callbacks passed to {nameof(Use)} return non-null {nameof(IEmbeddingGenerator)} instances."); + } + } + + return embeddingGenerator; + } + + /// Adds a factory for an intermediate embedding generator to the embedding generator pipeline. + /// The generator factory function. + /// The updated instance. + public EmbeddingGeneratorBuilder Use(Func, IEmbeddingGenerator> generatorFactory) + { + _ = Throw.IfNull(generatorFactory); + + return Use((_, innerGenerator) => generatorFactory(innerGenerator)); + } + + /// Adds a factory for an intermediate embedding generator to the embedding generator pipeline. + /// The generator factory function. + /// The updated instance. + public EmbeddingGeneratorBuilder Use(Func, IEmbeddingGenerator> generatorFactory) + { + _ = Throw.IfNull(generatorFactory); + + _generatorFactories ??= []; + _generatorFactories.Add(generatorFactory); + return this; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs new file mode 100644 index 00000000000..369de130e72 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/EmbeddingGeneratorBuilderServiceCollectionExtensions.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extension methods for registering with a . +public static class EmbeddingGeneratorBuilderServiceCollectionExtensions +{ + /// Adds a embedding generator to the . + /// The type from which embeddings will be generated. + /// The type of embeddings to generate. + /// The to which the generator should be added. + /// The factory to use to construct the instance. + /// The collection. + /// The generator is registered as a scoped service. + public static IServiceCollection AddEmbeddingGenerator( + this IServiceCollection services, + Func, IEmbeddingGenerator> generatorFactory) + where TEmbedding : Embedding + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(generatorFactory); + + return services.AddScoped(services => + generatorFactory(new EmbeddingGeneratorBuilder(services))); + } + + /// Adds an embedding generator to the . + /// The type from which embeddings will be generated. + /// The type of embeddings to generate. + /// The to which the service should be added. + /// The key with which to associated the generator. + /// The factory to use to construct the instance. + /// The collection. + /// The generator is registered as a scoped service. + public static IServiceCollection AddKeyedEmbeddingGenerator( + this IServiceCollection services, + object serviceKey, + Func, IEmbeddingGenerator> generatorFactory) + where TEmbedding : Embedding + { + _ = Throw.IfNull(services); + _ = Throw.IfNull(serviceKey); + _ = Throw.IfNull(generatorFactory); + + return services.AddKeyedScoped(serviceKey, (services, _) => + generatorFactory(new EmbeddingGeneratorBuilder(services))); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs new file mode 100644 index 00000000000..b7981de8129 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGenerator.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable EA0000 // Use source generated logging methods for improved performance + +namespace Microsoft.Extensions.AI; + +/// A delegating embedding generator that logs embedding generation operations to an . +/// Specifies the type of the input passed to the generator. +/// Specifies the type of the embedding instance produced by the generator. +public class LoggingEmbeddingGenerator : DelegatingEmbeddingGenerator + where TEmbedding : Embedding +{ + /// An instance used for all logging. + private readonly ILogger _logger; + + /// The to use for serialization of state written to the logger. + private JsonSerializerOptions _jsonSerializerOptions; + + /// Initializes a new instance of the class. + /// The underlying . + /// An instance that will be used for all logging. + public LoggingEmbeddingGenerator(IEmbeddingGenerator innerGenerator, ILogger logger) + : base(innerGenerator) + { + _logger = Throw.IfNull(logger); + _jsonSerializerOptions = JsonDefaults.Options; + } + + /// Gets or sets JSON serialization options to use when serializing logging data. + public JsonSerializerOptions JsonSerializerOptions + { + get => _jsonSerializerOptions; + set => _jsonSerializerOptions = Throw.IfNull(value); + } + + /// + public override async Task> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + if (_logger.IsEnabled(LogLevel.Debug)) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + _logger.Log(LogLevel.Trace, 0, (values, options, this), null, static (state, _) => + "GenerateAsync invoked: " + + $"Values: {JsonSerializer.Serialize(state.values, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(IEnumerable)))}. " + + $"Options: {JsonSerializer.Serialize(state.options, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGenerationOptions)))}. " + + $"Metadata: {JsonSerializer.Serialize(state.Item3.Metadata, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGeneratorMetadata)))}."); + } + else + { + _logger.LogDebug("GenerateAsync invoked."); + } + } + + try + { + var embeddings = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + + if (_logger.IsEnabled(LogLevel.Debug)) + { + _logger.LogDebug("GenerateAsync generated {Count} embedding(s).", embeddings.Count); + } + + return embeddings; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, "GenerateAsync failed."); + throw; + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs new file mode 100644 index 00000000000..1335a3fd8d3 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/LoggingEmbeddingGeneratorBuilderExtensions.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class LoggingEmbeddingGeneratorBuilderExtensions +{ + /// Adds logging to the embedding generator pipeline. + /// Specifies the type of the input passed to the generator. + /// Specifies the type of the embedding instance produced by the generator. + /// The . + /// + /// An optional with which logging should be performed. If not supplied, an instance will be resolved from the service provider. + /// + /// An optional callback that can be used to configure the instance. + /// The . + public static EmbeddingGeneratorBuilder UseLogging( + this EmbeddingGeneratorBuilder builder, ILogger? logger = null, Action>? configure = null) + where TEmbedding : Embedding + { + _ = Throw.IfNull(builder); + + return builder.Use((services, innerGenerator) => + { + logger ??= services.GetRequiredService().CreateLogger(nameof(LoggingEmbeddingGenerator)); + var generator = new LoggingEmbeddingGenerator(innerGenerator, logger); + configure?.Invoke(generator); + return generator; + }); + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs new file mode 100644 index 00000000000..8105cc64bdf --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGenerator.cs @@ -0,0 +1,239 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// A delegating embedding generator that implements the OpenTelemetry Semantic Conventions for Generative AI systems. +/// +/// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. +/// The specification is still experimental and subject to change; as such, the telemetry output by this generator is also subject to change. +/// +/// The type of input used to produce embeddings. +/// The type of embedding generated. +public sealed class OpenTelemetryEmbeddingGenerator : DelegatingEmbeddingGenerator + where TEmbedding : Embedding +{ + private readonly ActivitySource _activitySource; + private readonly Meter _meter; + + private readonly Histogram _tokenUsageHistogram; + private readonly Histogram _operationDurationHistogram; + + private readonly string? _modelId; + private readonly string? _modelProvider; + private readonly string? _endpointAddress; + private readonly int _endpointPort; + private readonly int? _dimensions; + + /// + /// Initializes a new instance of the class. + /// + /// The underlying , which is the next stage of the pipeline. + /// An optional source name that will be used on the telemetry data. + public OpenTelemetryEmbeddingGenerator(IEmbeddingGenerator innerGenerator, string? sourceName = null) + : base(innerGenerator) + { + Debug.Assert(innerGenerator is not null, "Should have been validated by the base ctor."); + + EmbeddingGeneratorMetadata metadata = innerGenerator!.Metadata; + _modelId = metadata.ModelId; + _modelProvider = metadata.ProviderName; + _endpointAddress = metadata.ProviderUri?.GetLeftPart(UriPartial.Path); + _endpointPort = metadata.ProviderUri?.Port ?? 0; + _dimensions = metadata.Dimensions; + + string name = string.IsNullOrEmpty(sourceName) ? OpenTelemetryConsts.DefaultSourceName : sourceName!; + _activitySource = new(name); + _meter = new(name); + + _tokenUsageHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.TokenUsage.Name, + OpenTelemetryConsts.TokensUnit, + OpenTelemetryConsts.GenAI.Client.TokenUsage.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.TokenUsage.ExplicitBucketBoundaries }); + + _operationDurationHistogram = _meter.CreateHistogram( + OpenTelemetryConsts.GenAI.Client.OperationDuration.Name, + OpenTelemetryConsts.SecondsUnit, + OpenTelemetryConsts.GenAI.Client.OperationDuration.Description, + advice: new() { HistogramBucketBoundaries = OpenTelemetryConsts.GenAI.Client.OperationDuration.ExplicitBucketBoundaries }); + } + + /// + protected override void Dispose(bool disposing) + { + if (disposing) + { + _activitySource.Dispose(); + _meter.Dispose(); + } + + base.Dispose(disposing); + } + + /// Gets a value indicating whether diagnostics are enabled. + private bool Enabled => _activitySource.HasListeners(); + + /// + public override async Task> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + _ = Throw.IfNull(values); + + using Activity? activity = StartActivity(); + Stopwatch? stopwatch = _operationDurationHistogram.Enabled ? Stopwatch.StartNew() : null; + + GeneratedEmbeddings? response = null; + Exception? error = null; + try + { + response = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + SetCompletionResponse(activity, response, error, stopwatch); + } + + return response; + } + + /// Creates an activity for an embedding generation request, or returns null if not enabled. + private Activity? StartActivity() + { + Activity? activity = null; + if (Enabled) + { + activity = _activitySource.StartActivity( + $"embedding {_modelId}", + ActivityKind.Client, + default(ActivityContext), + [ + new(OpenTelemetryConsts.GenAI.Operation.Name, "embedding"), + new(OpenTelemetryConsts.GenAI.Request.Model, _modelId), + new(OpenTelemetryConsts.GenAI.System, _modelProvider), + ]); + + if (activity is not null) + { + if (_endpointAddress is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Server.Address, _endpointAddress) + .SetTag(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + if (_dimensions is int dimensions) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Request.EmbeddingDimensions, dimensions); + } + } + } + + return activity; + } + + /// Adds embedding generation response information to the activity. + private void SetCompletionResponse( + Activity? activity, + GeneratedEmbeddings? embeddings, + Exception? error, + Stopwatch? stopwatch) + { + if (!Enabled) + { + return; + } + + int? inputTokens = null; + string? responseModelId = null; + if (embeddings is not null) + { + responseModelId = embeddings.FirstOrDefault()?.ModelId; + if (embeddings.Usage?.InputTokenCount is int i) + { + inputTokens = inputTokens.GetValueOrDefault() + i; + } + } + + if (_operationDurationHistogram.Enabled && stopwatch is not null) + { + TagList tags = default; + AddMetricTags(ref tags, responseModelId); + if (error is not null) + { + tags.Add(OpenTelemetryConsts.Error.Type, error.GetType().FullName); + } + + _operationDurationHistogram.Record(stopwatch.Elapsed.TotalSeconds, tags); + } + + if (_tokenUsageHistogram.Enabled && inputTokens.HasValue) + { + TagList tags = default; + tags.Add(OpenTelemetryConsts.GenAI.Token.Type, "input"); + AddMetricTags(ref tags, responseModelId); + + _tokenUsageHistogram.Record(inputTokens.Value); + } + + if (activity is null) + { + return; + } + + if (error is not null) + { + _ = activity + .SetTag(OpenTelemetryConsts.Error.Type, error.GetType().FullName) + .SetStatus(ActivityStatusCode.Error, error.Message); + return; + } + + if (inputTokens.HasValue) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.InputTokens, inputTokens); + } + + if (responseModelId is not null) + { + _ = activity.SetTag(OpenTelemetryConsts.GenAI.Response.Model, responseModelId); + } + } + + private void AddMetricTags(ref TagList tags, string? responseModelId) + { + tags.Add(OpenTelemetryConsts.GenAI.Operation.Name, "embedding"); + + if (_modelId is string requestModel) + { + tags.Add(OpenTelemetryConsts.GenAI.Request.Model, requestModel); + } + + tags.Add(OpenTelemetryConsts.GenAI.System, _modelProvider); + + if (_endpointAddress is string endpointAddress) + { + tags.Add(OpenTelemetryConsts.Server.Address, endpointAddress); + tags.Add(OpenTelemetryConsts.Server.Port, _endpointPort); + } + + // Assume all of the embeddings in the same batch used the same model + if (responseModelId is not null) + { + tags.Add(OpenTelemetryConsts.GenAI.Response.Model, responseModelId); + } + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGeneratorBuilderExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGeneratorBuilderExtensions.cs new file mode 100644 index 00000000000..ba60847ef93 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Embeddings/OpenTelemetryEmbeddingGeneratorBuilderExtensions.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides extensions for configuring instances. +public static class OpenTelemetryEmbeddingGeneratorBuilderExtensions +{ + /// + /// Adds OpenTelemetry support to the embedding generator pipeline, following the OpenTelemetry Semantic Conventions for Generative AI systems. + /// + /// + /// The draft specification this follows is available at https://opentelemetry.io/docs/specs/semconv/gen-ai/. + /// The specification is still experimental and subject to change; as such, the telemetry output by this generator is also subject to change. + /// + /// The type of input used to produce embeddings. + /// The type of embedding generated. + /// The . + /// An optional source name that will be used on the telemetry data. + /// An optional callback that can be used to configure the instance. + /// The . + public static EmbeddingGeneratorBuilder UseOpenTelemetry( + this EmbeddingGeneratorBuilder builder, string? sourceName = null, Action>? configure = null) + where TEmbedding : Embedding => + Throw.IfNull(builder).Use(innerGenerator => + { + var generator = new OpenTelemetryEmbeddingGenerator(innerGenerator, sourceName); + configure?.Invoke(generator); + return generator; + }); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/EmptyServiceProvider.cs b/src/Libraries/Microsoft.Extensions.AI/EmptyServiceProvider.cs new file mode 100644 index 00000000000..5e3abc9fc0c --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/EmptyServiceProvider.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.Extensions.AI; + +/// Provides an implementation of that contains no services. +internal sealed class EmptyServiceProvider : IKeyedServiceProvider +{ + /// Gets a singleton instance of . + public static EmptyServiceProvider Instance { get; } = new(); + + /// + public object? GetService(Type serviceType) => null; + + /// + public object? GetKeyedService(Type serviceType, object? serviceKey) => null; + + /// + public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => + GetKeyedService(serviceType, serviceKey) ?? + throw new InvalidOperationException($"No service for type '{serviceType}' and key '{serviceKey}' has been registered."); +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionContext.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionContext.cs new file mode 100644 index 00000000000..25f239f8883 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionContext.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Reflection; +using System.Threading; + +namespace Microsoft.Extensions.AI; + +/// Provides additional context to the invocation of an created by . +/// +/// A delegate or passed to methods may represent a method that has a parameter +/// of type . Whereas all other parameters are passed by name from the supplied collection of arguments, +/// a parameter is passed specially by the implementation, in order to pass relevant +/// context into the method's invocation. For example, any passed to the +/// method is available from the property. +/// +public class AIFunctionContext +{ + /// Initializes a new instance of the class. + public AIFunctionContext() + { + } + + /// Gets or sets a related to the operation. + public CancellationToken CancellationToken { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs new file mode 100644 index 00000000000..0fff0cd64fa --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -0,0 +1,480 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Shared.Collections; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// Provides factory methods for creating commonly-used implementations of . +public static +#if NET + partial +#endif + class AIFunctionFactory +{ + internal const string UsesReflectionJsonSerializerMessage = + "This method uses the reflection-based JsonSerializer which can break in trimmed or AOT applications."; + + /// Lazily-initialized default options instance. + private static AIFunctionFactoryCreateOptions? _defaultOptions; + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// The created for invoking . + [RequiresUnreferencedCode(UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(UsesReflectionJsonSerializerMessage)] + public static AIFunction Create(Delegate method) => Create(method, _defaultOptions ??= new()); + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// Metadata to use to override defaults inferred from . + /// The created for invoking . + public static AIFunction Create(Delegate method, AIFunctionFactoryCreateOptions options) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + return new ReflectionAIFunction(method.Method, method.Target, options); + } + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// The name to use for the . + /// The description to use for the . + /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied Delegate.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied Delegate.")] + public static AIFunction Create(Delegate method, string? name, string? description = null) + => Create(method, (_defaultOptions ??= new()).SerializerOptions, name, description); + + /// Creates an instance for a method, specified via a delegate. + /// The method to be represented via the created . + /// The used to marshal function parameters. + /// The name to use for the . + /// The description to use for the . + /// The created for invoking . + public static AIFunction Create(Delegate method, JsonSerializerOptions options, string? name = null, string? description = null) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + return new ReflectionAIFunction(method.Method, method.Target, new(options) { Name = name, Description = description }); + } + + /// + /// Creates an instance for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] + public static AIFunction Create(MethodInfo method, object? target = null) + => Create(method, target, _defaultOptions ??= new()); + + /// + /// Creates an instance for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// Metadata to use to override defaults inferred from . + /// The created for invoking . + public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + return new ReflectionAIFunction(method, target, options); + } + + private sealed +#if NET + partial +#endif + class ReflectionAIFunction : AIFunction + { + private readonly MethodInfo _method; + private readonly object? _target; + private readonly Func, AIFunctionContext?, object?>[] _parameterMarshalers; + private readonly Func> _returnMarshaler; + private readonly JsonTypeInfo? _returnTypeInfo; + private readonly bool _needsAIFunctionContext; + + /// + /// Initializes a new instance of the class for a method, specified via an instance + /// and an optional target object if the method is an instance method. + /// + /// The method to be represented via the created . + /// + /// The target object for the if it represents an instance method. + /// This should be if and only if is a static method. + /// + /// Function creation options. + public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options) + { + _ = Throw.IfNull(method); + _ = Throw.IfNull(options); + + options.SerializerOptions.MakeReadOnly(); + + if (method.ContainsGenericParameters) + { + Throw.ArgumentException(nameof(method), "Open generic methods are not supported"); + } + + if (!method.IsStatic && target is null) + { + Throw.ArgumentNullException(nameof(target), "Target must not be null for an instance method."); + } + + _method = method; + _target = target; + + // Get the function name to use. + string? functionName = options.Name; + if (functionName is null) + { + functionName = SanitizeMetadataName(method.Name!); + + const string AsyncSuffix = "Async"; + if (IsAsyncMethod(method) && + functionName.EndsWith(AsyncSuffix, StringComparison.Ordinal) && + functionName.Length > AsyncSuffix.Length) + { + functionName = functionName.Substring(0, functionName.Length - AsyncSuffix.Length); + } + + static bool IsAsyncMethod(MethodInfo method) + { + Type t = method.ReturnType; + + if (t == typeof(Task) || t == typeof(ValueTask)) + { + return true; + } + + if (t.IsGenericType) + { + t = t.GetGenericTypeDefinition(); + if (t == typeof(Task<>) || t == typeof(ValueTask<>) || t == typeof(IAsyncEnumerable<>)) + { + return true; + } + } + + return false; + } + } + + // Build up a list of AIParameterMetadata for the parameters we expect to be populated + // from arguments. Some arguments are populated specially, not from arguments, and thus + // we don't want to advertise their metadata. + List? parameterMetadata = options.Parameters is not null ? null : []; + + // Get marshaling delegates for parameters and build up the parameter metadata. + var parameters = method.GetParameters(); + _parameterMarshalers = new Func, AIFunctionContext?, object?>[parameters.Length]; + bool sawAIContextParameter = false; + for (int i = 0; i < parameters.Length; i++) + { + if (GetParameterMarshaler(options.SerializerOptions, parameters[i], ref sawAIContextParameter, out _parameterMarshalers[i]) is AIFunctionParameterMetadata parameterView) + { + parameterMetadata?.Add(parameterView); + } + } + + _needsAIFunctionContext = sawAIContextParameter; + + // Get the return type and a marshaling func for the return value. + Type returnType = GetReturnMarshaler(method, out _returnMarshaler); + _returnTypeInfo = returnType != typeof(void) ? options.SerializerOptions.GetTypeInfo(returnType) : null; + + Metadata = new AIFunctionMetadata(functionName) + { + Description = options.Description ?? method.GetCustomAttribute(inherit: true)?.Description ?? string.Empty, + Parameters = options.Parameters ?? parameterMetadata!, + ReturnParameter = options.ReturnParameter ?? new() + { + ParameterType = returnType, + Description = method.ReturnParameter.GetCustomAttribute(inherit: true)?.Description, + Schema = FunctionCallHelpers.InferReturnParameterJsonSchema(returnType, options.SerializerOptions), + }, + AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary.Instance, + JsonSerializerOptions = options.SerializerOptions, + }; + } + + /// + public override AIFunctionMetadata Metadata { get; } + + /// + protected override async Task InvokeCoreAsync( + IEnumerable>? arguments, + CancellationToken cancellationToken) + { + var paramMarshalers = _parameterMarshalers; + object?[] args = paramMarshalers.Length != 0 ? new object?[paramMarshalers.Length] : []; + + IReadOnlyDictionary argDict = + arguments is null || args.Length == 0 ? EmptyReadOnlyDictionary.Instance : + arguments as IReadOnlyDictionary ?? + arguments. +#if NET8_0_OR_GREATER + ToDictionary(); +#else + ToDictionary(kvp => kvp.Key, kvp => kvp.Value); +#endif + AIFunctionContext? context = _needsAIFunctionContext ? + new() { CancellationToken = cancellationToken } : + null; + + for (int i = 0; i < args.Length; i++) + { + args[i] = paramMarshalers[i](argDict, context); + } + + object? result = await _returnMarshaler(ReflectionInvoke(_method, _target, args)).ConfigureAwait(false); + + switch (_returnTypeInfo) + { + case null: + Debug.Assert(Metadata.ReturnParameter.ParameterType == typeof(void), "The return parameter is not void."); + return null; + + case { Kind: JsonTypeInfoKind.None }: + // Special-case trivial contracts to avoid the more expensive general-purpose serialization path. + return JsonSerializer.SerializeToElement(result, _returnTypeInfo); + + default: + { + // Serialize asynchronously to support potential IAsyncEnumerable responses. + using MemoryStream stream = new(); + await JsonSerializer.SerializeAsync(stream, result, _returnTypeInfo, cancellationToken).ConfigureAwait(false); + Utf8JsonReader reader = new(stream.GetBuffer().AsSpan(0, (int)stream.Length)); + return JsonElement.ParseValue(ref reader); + } + } + } + + /// + /// Gets a delegate for handling the marshaling of a parameter. + /// + private static AIFunctionParameterMetadata? GetParameterMarshaler( + JsonSerializerOptions options, + ParameterInfo parameter, + ref bool sawAIFunctionContext, + out Func, AIFunctionContext?, object?> marshaler) + { + if (string.IsNullOrWhiteSpace(parameter.Name)) + { + Throw.ArgumentException(nameof(parameter), "Parameter is missing a name."); + } + + // Special-case an AIFunctionContext parameter. + if (parameter.ParameterType == typeof(AIFunctionContext)) + { + if (sawAIFunctionContext) + { + Throw.ArgumentException(nameof(parameter), $"Only one {nameof(AIFunctionContext)} parameter is permitted."); + } + + sawAIFunctionContext = true; + + marshaler = static (_, ctx) => + { + Debug.Assert(ctx is not null, "Expected a non-null context object."); + return ctx; + }; + return null; + } + + // Resolve the contract used to marshall the value from JSON -- can throw if not supported or not found. + Type parameterType = parameter.ParameterType; + JsonTypeInfo typeInfo = options.GetTypeInfo(parameterType); + + // Create a marshaler that simply looks up the parameter by name in the arguments dictionary. + marshaler = (IReadOnlyDictionary arguments, AIFunctionContext? _) => + { + // If the parameter has an argument specified in the dictionary, return that argument. + if (arguments.TryGetValue(parameter.Name, out object? value)) + { + return value switch + { + null => null, // Return as-is if null -- if the parameter is a struct this will be handled by MethodInfo.Invoke + _ when parameterType.IsInstanceOfType(value) => value, // Do nothing if value is assignable to parameter type + JsonElement element => JsonSerializer.Deserialize(element, typeInfo), + JsonDocument doc => JsonSerializer.Deserialize(doc, typeInfo), + JsonNode node => JsonSerializer.Deserialize(node, typeInfo), + _ => MarshallViaJsonRoundtrip(value), + }; + + object? MarshallViaJsonRoundtrip(object value) + { +#pragma warning disable CA1031 // Do not catch general exception types + try + { + string json = JsonSerializer.Serialize(value, options.GetTypeInfo(value.GetType())); + return JsonSerializer.Deserialize(json, typeInfo); + } + catch + { + // Eat any exceptions and fall back to the original value to force a cast exception later on. + return value; + } +#pragma warning restore CA1031 // Do not catch general exception types + } + } + + // There was no argument for the parameter. Try to use a default value. + if (parameter.HasDefaultValue) + { + return parameter.DefaultValue; + } + + // No default either. Leave it empty. + return null; + }; + + string? description = parameter.GetCustomAttribute(inherit: true)?.Description; + return new AIFunctionParameterMetadata(parameter.Name) + { + Description = description, + HasDefaultValue = parameter.HasDefaultValue, + DefaultValue = parameter.HasDefaultValue ? parameter.DefaultValue : null, + IsRequired = !parameter.IsOptional, + ParameterType = parameter.ParameterType, + Schema = FunctionCallHelpers.InferParameterJsonSchema( + parameter.ParameterType, + parameter.Name, + description, + parameter.HasDefaultValue, + parameter.DefaultValue, + options) + }; + } + + /// + /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. + /// + private static Type GetReturnMarshaler(MethodInfo method, out Func> marshaler) + { + // Handle each known return type for the method + Type returnType = method.ReturnType; + + // Task + if (returnType == typeof(Task)) + { + marshaler = async static result => + { + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + return null; + }; + return typeof(void); + } + + // ValueTask + if (returnType == typeof(ValueTask)) + { + marshaler = async static result => + { + await ((ValueTask)ThrowIfNullResult(result)).ConfigureAwait(false); + return null; + }; + return typeof(void); + } + + if (returnType.IsGenericType) + { + // Task + if (returnType.GetGenericTypeDefinition() == typeof(Task<>) && + returnType.GetProperty(nameof(Task.Result), BindingFlags.Public | BindingFlags.Instance)?.GetGetMethod() is MethodInfo taskResultGetter) + { + marshaler = async result => + { + await ((Task)ThrowIfNullResult(result)).ConfigureAwait(false); + return ReflectionInvoke(taskResultGetter, result, null); + }; + return taskResultGetter.ReturnType; + } + + // ValueTask + if (returnType.GetGenericTypeDefinition() == typeof(ValueTask<>) && + returnType.GetMethod(nameof(ValueTask.AsTask), BindingFlags.Public | BindingFlags.Instance) is MethodInfo valueTaskAsTask && + valueTaskAsTask.ReturnType.GetProperty(nameof(ValueTask.Result), BindingFlags.Public | BindingFlags.Instance)?.GetGetMethod() is MethodInfo asTaskResultGetter) + { + marshaler = async result => + { + var task = (Task)ReflectionInvoke(valueTaskAsTask, ThrowIfNullResult(result), null)!; + await task.ConfigureAwait(false); + return ReflectionInvoke(asTaskResultGetter, task, null); + }; + return asTaskResultGetter.ReturnType; + } + } + + // For everything else, just use the result as-is. + marshaler = result => new ValueTask(result); + return returnType; + + // Throws an exception if a result is found to be null unexpectedly + static object ThrowIfNullResult(object? result) => result ?? throw new InvalidOperationException("Function returned null unexpectedly."); + } + + /// Invokes the MethodInfo with the specified target object and arguments. + private static object? ReflectionInvoke(MethodInfo method, object? target, object?[]? arguments) + { +#if NET + return method.Invoke(target, BindingFlags.DoNotWrapExceptions, binder: null, arguments, culture: null); +#else + try + { + return method.Invoke(target, BindingFlags.Default, binder: null, arguments, culture: null); + } + catch (TargetInvocationException e) when (e.InnerException is not null) + { + // If we're targeting .NET Framework, such that BindingFlags.DoNotWrapExceptions + // is ignored, the original exception will be wrapped in a TargetInvocationException. + // Unwrap it and throw that original exception, maintaining its stack information. + System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw(); + return null; + } +#endif + } + + /// + /// Remove characters from method name that are valid in metadata but shouldn't be used in a method name. + /// This is primarily intended to remove characters emitted by for compiler-generated method name mangling. + /// + private static string SanitizeMetadataName(string methodName) => + InvalidNameCharsRegex().Replace(methodName, "_"); + + /// Regex that flags any character other than ASCII digits or letters or the underscore. +#if NET + [GeneratedRegex("[^0-9A-Za-z_]")] + private static partial Regex InvalidNameCharsRegex(); +#else + private static Regex InvalidNameCharsRegex() => _invalidNameCharsRegex; + private static readonly Regex _invalidNameCharsRegex = new("[^0-9A-Za-z_]", RegexOptions.Compiled); +#endif + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryCreateOptions.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryCreateOptions.cs new file mode 100644 index 00000000000..8e0db9b4813 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryCreateOptions.cs @@ -0,0 +1,73 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Text.Json; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Extensions.AI; + +/// +/// Options that can be provided when creating an from a method. +/// +public sealed class AIFunctionFactoryCreateOptions +{ + /// + /// Initializes a new instance of the class with default serializer options. + /// + [RequiresUnreferencedCode(AIFunctionFactory.UsesReflectionJsonSerializerMessage)] + [RequiresDynamicCode(AIFunctionFactory.UsesReflectionJsonSerializerMessage)] + public AIFunctionFactoryCreateOptions() + : this(JsonSerializerOptions.Default) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The JSON serialization options used to marshal .NET types. + public AIFunctionFactoryCreateOptions(JsonSerializerOptions serializerOptions) + { + SerializerOptions = Throw.IfNull(serializerOptions); + } + + /// Gets the used to marshal .NET values being passed to the underlying delegate. + public JsonSerializerOptions SerializerOptions { get; } + + /// Gets or sets the name to use for the function. + /// + /// If , it will default to one derived from the method represented by the passed or . + /// + public string? Name { get; set; } + + /// Gets or sets the description to use for the function. + /// + /// If , it will default to one derived from the passed or , if possible + /// (e.g. via a on the method). + /// + public string? Description { get; set; } + + /// Gets or sets metadata for the parameters of the function. + /// + /// If , it will default to metadata derived from the passed or . + /// + public IReadOnlyList? Parameters { get; set; } + + /// Gets or sets metadata for function's return parameter. + /// + /// If , it will default to one derived from the passed or . + /// + public AIFunctionReturnParameterMetadata? ReturnParameter { get; set; } + + /// + /// Gets or sets additional values that will be stored on the resulting property. + /// + /// + /// This can be used to provide arbitrary information about the function. + /// + public IReadOnlyDictionary? AdditionalProperties { get; set; } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs b/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs new file mode 100644 index 00000000000..71edc9404b6 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Text.Json.Serialization.Metadata; + +namespace Microsoft.Extensions.AI; + +/// Provides cached options around JSON serialization to be used by the project. +internal static partial class JsonDefaults +{ + /// Gets the singleton to use for serialization-related operations. + public static JsonSerializerOptions Options { get; } = CreateDefaultOptions(); + + /// Creates the default to use for serialization-related operations. + private static JsonSerializerOptions CreateDefaultOptions() + { + // If reflection-based serialization is enabled by default, use it, as it's the most permissive in terms of what it can serialize, + // and we want to be flexible in terms of what can be put into the various collections in the object model. + // Otherwise, use the source-generated options to enable Native AOT. + + if (JsonSerializer.IsReflectionEnabledByDefault) + { + // Keep in sync with the JsonSourceGenerationOptions on JsonContext below. + var options = new JsonSerializerOptions(JsonSerializerDefaults.Web) + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, +#pragma warning disable IL3050 + TypeInfoResolver = new DefaultJsonTypeInfoResolver(), +#pragma warning restore IL3050 + }; + + options.MakeReadOnly(); + return options; + } + else + { + return JsonContext.Default.Options; + } + } + + // Keep in sync with CreateDefaultOptions above. + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] + [JsonSerializable(typeof(IList))] + [JsonSerializable(typeof(ChatOptions))] + [JsonSerializable(typeof(EmbeddingGenerationOptions))] + [JsonSerializable(typeof(ChatClientMetadata))] + [JsonSerializable(typeof(EmbeddingGeneratorMetadata))] + [JsonSerializable(typeof(ChatCompletion))] + [JsonSerializable(typeof(StreamingChatCompletionUpdate))] + [JsonSerializable(typeof(IReadOnlyList))] + [JsonSerializable(typeof(Dictionary))] + [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(IDictionary))] + [JsonSerializable(typeof(JsonElement))] + [JsonSerializable(typeof(IEnumerable))] + [JsonSerializable(typeof(string))] + [JsonSerializable(typeof(int))] + [JsonSerializable(typeof(long))] + [JsonSerializable(typeof(float))] + [JsonSerializable(typeof(double))] + [JsonSerializable(typeof(bool))] + [JsonSerializable(typeof(TimeSpan))] + [JsonSerializable(typeof(DateTimeOffset))] + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(Embedding))] +#if NET + [JsonSerializable(typeof(Embedding))] +#endif + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(Embedding))] + [JsonSerializable(typeof(AIContent))] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj new file mode 100644 index 00000000000..8e389b61652 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -0,0 +1,42 @@ + + + + Microsoft.Extensions.AI + Utilities for working with generative AI components. + AI + + + + preview + 0 + 0 + + + + $(TargetFrameworks);netstandard2.0 + $(NoWarn);CA2227;CA1034;SA1316;S1067;S1121;S1994;S3253 + true + + + + true + true + + + + + + + + + + + + + + + + + + + diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.json new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/Libraries/Microsoft.Extensions.AI/OpenTelemetryConsts.cs b/src/Libraries/Microsoft.Extensions.AI/OpenTelemetryConsts.cs new file mode 100644 index 00000000000..31e61101a13 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/OpenTelemetryConsts.cs @@ -0,0 +1,90 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +#pragma warning disable S3218 // Inner class members should not shadow outer class "static" or type members +#pragma warning disable CA1716 // Identifiers should not match keywords +#pragma warning disable S4041 // Type names should not match namespaces + +/// Provides constants used by various telemetry services. +internal static class OpenTelemetryConsts +{ + public const string DefaultSourceName = "Experimental.Microsoft.Extensions.AI"; + + public const string SecondsUnit = "s"; + public const string TokensUnit = "token"; + + public static class Error + { + public const string Type = "error.type"; + } + + public static class GenAI + { + public const string Completion = "gen_ai.completion"; + public const string Prompt = "gen_ai.prompt"; + public const string System = "gen_ai.system"; + + public static class Client + { + public static class OperationDuration + { + public const string Description = "Measures the duration of a GenAI operation"; + public const string Name = "gen_ai.client.operation.duration"; + public static readonly double[] ExplicitBucketBoundaries = [0.01, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64, 1.28, 2.56, 5.12, 10.24, 20.48, 40.96, 81.92]; + } + + public static class TokenUsage + { + public const string Description = "Measures number of input and output tokens used"; + public const string Name = "gen_ai.client.token.usage"; + public static readonly int[] ExplicitBucketBoundaries = [1, 4, 16, 64, 256, 1_024, 4_096, 16_384, 65_536, 262_144, 1_048_576, 4_194_304, 16_777_216, 67_108_864]; + } + } + + public static class Content + { + public const string Completion = "gen_ai.content.completion"; + public const string Prompt = "gen_ai.content.prompt"; + } + + public static class Operation + { + public const string Name = "gen_ai.operation.name"; + } + + public static class Request + { + public const string EmbeddingDimensions = "gen_ai.request.embedding.dimensions"; + public const string FrequencyPenalty = "gen_ai.request.frequency_penalty"; + public const string Model = "gen_ai.request.model"; + public const string MaxTokens = "gen_ai.request.max_tokens"; + public const string PresencePenalty = "gen_ai.request.presence_penalty"; + public const string StopSequences = "gen_ai.request.stop_sequences"; + public const string Temperature = "gen_ai.request.temperature"; + public const string TopK = "gen_ai.request.top_k"; + public const string TopP = "gen_ai.request.top_p"; + } + + public static class Response + { + public const string FinishReasons = "gen_ai.response.finish_reasons"; + public const string Id = "gen_ai.response.id"; + public const string InputTokens = "gen_ai.response.input_tokens"; + public const string Model = "gen_ai.response.model"; + public const string OutputTokens = "gen_ai.response.output_tokens"; + } + + public static class Token + { + public const string Type = "gen_ai.token.type"; + } + } + + public static class Server + { + public const string Address = "server.address"; + public const string Port = "server.port"; + } +} diff --git a/src/Libraries/Microsoft.Extensions.AI/README.md b/src/Libraries/Microsoft.Extensions.AI/README.md new file mode 100644 index 00000000000..ef092749200 --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.AI/README.md @@ -0,0 +1,27 @@ +# Microsoft.Extensions.AI + +Provides utilities for working with generative AI components. + +## Install the package + +From the command-line: + +```console +dotnet add package Microsoft.Extensions.AI +``` + +Or directly in the C# project file: + +```xml + + + +``` + +## Usage Examples + +Please refer to the [README](https://www.nuget.org/packages/Microsoft.Extensions.AI.Abstractions/#readme-body-tab) for the [Microsoft.Extensions.AI.Abstractions](https://www.nuget.org/packages/Microsoft.Extensions.AI.Abstractions) package. + +## Feedback & Contributing + +We welcome feedback and contributions in [our GitHub repo](https://github.com/dotnet/extensions). diff --git a/src/Shared/CollectionExtensions/CollectionExtensions.cs b/src/Shared/CollectionExtensions/CollectionExtensions.cs new file mode 100644 index 00000000000..33196e6e771 --- /dev/null +++ b/src/Shared/CollectionExtensions/CollectionExtensions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; + +#pragma warning disable S108 // Nested blocks of code should not be left empty +#pragma warning disable S1067 // Expressions should not be too complex +#pragma warning disable SA1501 // Statement should not be on a single line + +#pragma warning disable CA1716 +namespace Microsoft.Shared.Collections; +#pragma warning restore CA1716 + +/// +/// Utilities to augment the basic collection types. +/// +#if !SHARED_PROJECT +[ExcludeFromCodeCoverage] +#endif + +internal static class CollectionExtensions +{ + /// Attempts to extract a typed value from the dictionary. + /// The dictionary to query. + /// The key to locate. + /// The value retrieved from the dictionary, if found; otherwise, default. + /// True if the value was found and converted to the requested type; otherwise, false. + /// + /// If a value is found for the key in the dictionary, but the value is not of the requested type but is + /// an object, the method will attempt to convert the object to the requested type. + /// is employed because these methods are primarily intended for use with primitives. + /// + public static bool TryGetConvertedValue(this IReadOnlyDictionary? input, string key, [NotNullWhen(true)] out T? value) + { + object? valueObject = null; + _ = input?.TryGetValue(key, out valueObject); + return TryConvertValue(valueObject, out value); + } + + private static bool TryConvertValue(object? obj, [NotNullWhen(true)] out T? value) + { + switch (obj) + { + case T t: + // The object is already of the requested type. Return it. + value = t; + return true; + + case IConvertible: + // The object is convertible; try to convert it to the requested type. Unfortunately, there's no + // convenient way to do this that avoids exceptions and that doesn't involve a ton of boilerplate, + // so we only try when the source object is at least an IConvertible, which is what ChangeType uses. + try + { + value = (T)Convert.ChangeType(obj, typeof(T), CultureInfo.InvariantCulture); + return true; + } + catch (ArgumentException) { } + catch (InvalidCastException) { } + catch (FormatException) { } + catch (OverflowException) { } + break; + } + + // Unable to convert the object to the requested type. Fail. + value = default; + return false; + } +} diff --git a/src/Shared/CollectionExtensions/README.md b/src/Shared/CollectionExtensions/README.md new file mode 100644 index 00000000000..a732b7c36d4 --- /dev/null +++ b/src/Shared/CollectionExtensions/README.md @@ -0,0 +1,11 @@ +# Collection Extensions + +`TryGetTypedValue` performs a ``TryGetValue` on a dictionary and then attempts to cast the value to the specified type. If the value is not of the specified type, false is returned. + +To use this in your project, add the following to your `.csproj` file: + +```xml + + true + +``` diff --git a/src/Shared/NumericExtensions/README.md b/src/Shared/NumericExtensions/README.md index bcb2d9a7cba..c93835acd3b 100644 --- a/src/Shared/NumericExtensions/README.md +++ b/src/Shared/NumericExtensions/README.md @@ -6,6 +6,6 @@ To use this in your project, add the following to your `.csproj` file: ```xml - true + true ``` diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs new file mode 100644 index 00000000000..e71b2f431e8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AdditionalPropertiesDictionaryTests.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AdditionalPropertiesDictionaryTests +{ + [Fact] + public void Constructor_Roundtrips() + { + AdditionalPropertiesDictionary d = new(); + Assert.Empty(d); + + d = new(new Dictionary { ["key1"] = "value1" }); + Assert.Single(d); + + d = new((IEnumerable>)new Dictionary { ["key1"] = "value1", ["key2"] = "value2" }); + Assert.Equal(2, d.Count); + } + + [Fact] + public void Comparer_OrdinalIgnoreCase() + { + AdditionalPropertiesDictionary d = new() + { + ["key1"] = "value1", + ["KEY1"] = "value2", + ["key2"] = "value3", + ["key3"] = "value4", + ["KeY3"] = "value5", + }; + + Assert.Equal(3, d.Count); + + Assert.Equal("value2", d["key1"]); + Assert.Equal("value2", d["kEY1"]); + + Assert.Equal("value3", d["key2"]); + Assert.Equal("value3", d["KEY2"]); + + Assert.Equal("value5", d["Key3"]); + Assert.Equal("value5", d["KEy3"]); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs new file mode 100644 index 00000000000..2c54a6f0865 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AssertExtensions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using Xunit; +using Xunit.Sdk; + +namespace Microsoft.Extensions.AI; + +internal static class AssertExtensions +{ + /// + /// Asserts that the two function call parameters are equal, up to JSON equivalence. + /// + public static void EqualFunctionCallParameters( + IDictionary? expected, + IDictionary? actual, + JsonSerializerOptions? options = null) + { + if (expected is null || actual is null) + { + Assert.Equal(expected, actual); + return; + } + + foreach (var expectedEntry in expected) + { + if (!actual.TryGetValue(expectedEntry.Key, out object? actualValue)) + { + throw new XunitException($"Expected parameter '{expectedEntry.Key}' not found in actual value."); + } + + AreJsonEquivalentValues(expectedEntry.Value, actualValue, options, propertyName: expectedEntry.Key); + } + + if (expected.Count != actual.Count) + { + var extraParameters = actual + .Where(e => !expected.ContainsKey(e.Key)) + .Select(e => $"'{e.Key}'") + .First(); + + throw new XunitException($"Actual value contains additional parameters {string.Join(", ", extraParameters)} not found in expected value."); + } + } + + /// + /// Asserts that the two function call results are equal, up to JSON equivalence. + /// + public static void EqualFunctionCallResults(object? expected, object? actual, JsonSerializerOptions? options = null) + => AreJsonEquivalentValues(expected, actual, options); + + private static void AreJsonEquivalentValues(object? expected, object? actual, JsonSerializerOptions? options, string? propertyName = null) + { + options ??= JsonSerializerOptions.Default; + JsonElement expectedElement = NormalizeToElement(expected, options); + JsonElement actualElement = NormalizeToElement(actual, options); + if (!JsonElement.DeepEquals(expectedElement, actualElement)) + { + string message = propertyName is null + ? $"Function result does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}" + : $"Parameter '{propertyName}' does not match expected JSON.\r\nExpected: {expectedElement.GetRawText()}\r\nActual: {actualElement.GetRawText()}"; + + throw new XunitException(message); + } + + static JsonElement NormalizeToElement(object? value, JsonSerializerOptions options) + => value is JsonElement e ? e : JsonSerializer.SerializeToElement(value, options); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/CapturingLogger.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/CapturingLogger.cs new file mode 100644 index 00000000000..274021988e1 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/CapturingLogger.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Logging; + +#pragma warning disable SA1402 // File may only contain a single type + +namespace Microsoft.Extensions.AI; + +internal sealed class CapturingLogger : ILogger +{ + private readonly Stack _scopes = new(); + private readonly List _entries = []; + private readonly LogLevel _enabledLevel; + + public CapturingLogger(LogLevel enabledLevel = LogLevel.Trace) + { + _enabledLevel = enabledLevel; + } + + public IReadOnlyList Entries => _entries; + + public IDisposable? BeginScope(TState state) + where TState : notnull + { + var scope = new LoggerScope(this); + _scopes.Push(scope); + return scope; + } + + public bool IsEnabled(LogLevel logLevel) => logLevel >= _enabledLevel; + + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (!IsEnabled(logLevel)) + { + return; + } + + var message = formatter(state, exception); + lock (_entries) + { + _entries.Add(new LogEntry(logLevel, eventId, state, exception, message)); + } + } + + private sealed class LoggerScope(CapturingLogger owner) : IDisposable + { + public void Dispose() => owner.EndScope(this); + } + + private void EndScope(LoggerScope loggerScope) + { + if (_scopes.Peek() != loggerScope) + { + throw new InvalidOperationException("Logger scopes out of order"); + } + + _scopes.Pop(); + } + + public record LogEntry(LogLevel Level, EventId EventId, object? State, Exception? Exception, string Message); +} + +internal sealed class CapturingLoggerProvider : ILoggerProvider +{ + public CapturingLogger Logger { get; } = new(); + + public ILogger CreateLogger(string categoryName) => Logger; + + void IDisposable.Dispose() + { + // nop + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs new file mode 100644 index 00000000000..68f5ad12245 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientExtensionsTests.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientExtensionsTests +{ + [Fact] + public void CompleteAsync_InvalidArgs_Throws() + { + Assert.Throws("client", () => + { + _ = ChatClientExtensions.CompleteAsync(null!, "hello"); + }); + + Assert.Throws("chatMessage", () => + { + _ = ChatClientExtensions.CompleteAsync(new TestChatClient(), null!); + }); + } + + [Fact] + public void CompleteStreamingAsync_InvalidArgs_Throws() + { + Assert.Throws("client", () => + { + _ = ChatClientExtensions.CompleteStreamingAsync(null!, "hello"); + }); + + Assert.Throws("chatMessage", () => + { + _ = ChatClientExtensions.CompleteStreamingAsync(new TestChatClient(), null!); + }); + } + + [Fact] + public async Task CompleteAsync_CreatesTextMessageAsync() + { + var expectedResponse = new ChatCompletion([new ChatMessage()]); + var expectedOptions = new ChatOptions(); + using var cts = new CancellationTokenSource(); + + using TestChatClient client = new() + { + CompleteAsyncCallback = (chatMessages, options, cancellationToken) => + { + ChatMessage m = Assert.Single(chatMessages); + Assert.Equal(ChatRole.User, m.Role); + Assert.Equal("hello", m.Text); + + Assert.Same(expectedOptions, options); + + Assert.Equal(cts.Token, cancellationToken); + + return Task.FromResult(expectedResponse); + }, + }; + + ChatCompletion response = await client.CompleteAsync("hello", expectedOptions, cts.Token); + + Assert.Same(expectedResponse, response); + } + + [Fact] + public async Task CompleteStreamingAsync_CreatesTextMessageAsync() + { + var expectedOptions = new ChatOptions(); + using var cts = new CancellationTokenSource(); + + using TestChatClient client = new() + { + CompleteStreamingAsyncCallback = (chatMessages, options, cancellationToken) => + { + ChatMessage m = Assert.Single(chatMessages); + Assert.Equal(ChatRole.User, m.Role); + Assert.Equal("hello", m.Text); + + Assert.Same(expectedOptions, options); + + Assert.Equal(cts.Token, cancellationToken); + + return YieldAsync([new StreamingChatCompletionUpdate { Text = "world" }]); + }, + }; + + int count = 0; + await foreach (var update in client.CompleteStreamingAsync("hello", expectedOptions, cts.Token)) + { + Assert.Equal(0, count); + Assert.Equal("world", update.Text); + count++; + } + + Assert.Equal(1, count); + } + + private static async IAsyncEnumerable YieldAsync(params StreamingChatCompletionUpdate[] updates) + { + await Task.Yield(); + foreach (var update in updates) + { + yield return update; + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs new file mode 100644 index 00000000000..43e24e61f8e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatClientMetadataTests.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientMetadataTests +{ + [Fact] + public void Constructor_NullValues_AllowedAndRoundtrip() + { + ChatClientMetadata metadata = new(null, null, null); + Assert.Null(metadata.ProviderName); + Assert.Null(metadata.ProviderUri); + Assert.Null(metadata.ModelId); + } + + [Fact] + public void Constructor_Value_Roundtrips() + { + var uri = new Uri("https://example.com"); + ChatClientMetadata metadata = new("providerName", uri, "theModel"); + Assert.Equal("providerName", metadata.ProviderName); + Assert.Same(uri, metadata.ProviderUri); + Assert.Equal("theModel", metadata.ModelId); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs new file mode 100644 index 00000000000..a695e686f6e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatCompletionTests.cs @@ -0,0 +1,170 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatCompletionTests +{ + [Fact] + public void Constructor_InvalidArgs_Throws() + { + Assert.Throws("message", () => new ChatCompletion((ChatMessage)null!)); + Assert.Throws("choices", () => new ChatCompletion((IList)null!)); + } + + [Fact] + public void Constructor_Message_Roundtrips() + { + ChatMessage message = new(); + + ChatCompletion completion = new(message); + Assert.Same(message, completion.Message); + Assert.Same(message, Assert.Single(completion.Choices)); + } + + [Fact] + public void Constructor_Choices_Roundtrips() + { + List messages = + [ + new ChatMessage(), + new ChatMessage(), + new ChatMessage(), + ]; + + ChatCompletion completion = new(messages); + Assert.Same(messages, completion.Choices); + Assert.Equal(3, messages.Count); + } + + [Fact] + public void Message_EmptyChoices_Throws() + { + ChatCompletion completion = new([]); + + Assert.Empty(completion.Choices); + Assert.Throws(() => completion.Message); + } + + [Fact] + public void Message_SingleChoice_Returned() + { + ChatMessage message = new(); + ChatCompletion completion = new([message]); + + Assert.Same(message, completion.Message); + Assert.Same(message, completion.Choices[0]); + } + + [Fact] + public void Message_MultipleChoices_ReturnsFirst() + { + ChatMessage first = new(); + ChatCompletion completion = new([ + first, + new ChatMessage(), + ]); + + Assert.Same(first, completion.Message); + Assert.Same(first, completion.Choices[0]); + } + + [Fact] + public void Choices_SetNull_Throws() + { + ChatCompletion completion = new([]); + Assert.Throws("value", () => completion.Choices = null!); + } + + [Fact] + public void Properties_Roundtrip() + { + ChatCompletion completion = new([]); + + Assert.Null(completion.CompletionId); + completion.CompletionId = "id"; + Assert.Equal("id", completion.CompletionId); + + Assert.Null(completion.ModelId); + completion.ModelId = "modelId"; + Assert.Equal("modelId", completion.ModelId); + + Assert.Null(completion.CreatedAt); + completion.CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), completion.CreatedAt); + + Assert.Null(completion.FinishReason); + completion.FinishReason = ChatFinishReason.ContentFilter; + Assert.Equal(ChatFinishReason.ContentFilter, completion.FinishReason); + + Assert.Null(completion.Usage); + UsageDetails usage = new(); + completion.Usage = usage; + Assert.Same(usage, completion.Usage); + + Assert.Null(completion.RawRepresentation); + object raw = new(); + completion.RawRepresentation = raw; + Assert.Same(raw, completion.RawRepresentation); + + Assert.Null(completion.AdditionalProperties); + AdditionalPropertiesDictionary additionalProps = []; + completion.AdditionalProperties = additionalProps; + Assert.Same(additionalProps, completion.AdditionalProperties); + + List newChoices = [new ChatMessage(), new ChatMessage()]; + completion.Choices = newChoices; + Assert.Same(newChoices, completion.Choices); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatCompletion original = new( + [ + new ChatMessage(ChatRole.Assistant, "Choice1"), + new ChatMessage(ChatRole.Assistant, "Choice2"), + new ChatMessage(ChatRole.Assistant, "Choice3"), + new ChatMessage(ChatRole.Assistant, "Choice4"), + ]) + { + CompletionId = "id", + ModelId = "modelId", + CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + FinishReason = ChatFinishReason.ContentFilter, + Usage = new UsageDetails(), + RawRepresentation = new(), + AdditionalProperties = new() { ["key"] = "value" }, + }; + + string json = JsonSerializer.Serialize(original, TestJsonSerializerContext.Default.ChatCompletion); + + ChatCompletion? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatCompletion); + + Assert.NotNull(result); + Assert.Equal(4, result.Choices.Count); + + for (int i = 0; i < original.Choices.Count; i++) + { + Assert.Equal(ChatRole.Assistant, result.Choices[i].Role); + Assert.Equal($"Choice{i + 1}", result.Choices[i].Text); + } + + Assert.Equal("id", result.CompletionId); + Assert.Equal("modelId", result.ModelId); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), result.CreatedAt); + Assert.Equal(ChatFinishReason.ContentFilter, result.FinishReason); + Assert.NotNull(result.Usage); + + Assert.NotNull(result.AdditionalProperties); + Assert.Single(result.AdditionalProperties); + Assert.True(result.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatFinishReasonTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatFinishReasonTests.cs new file mode 100644 index 00000000000..0318a77b47b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatFinishReasonTests.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatFinishReasonTests +{ + [Fact] + public void Constructor_Value_Roundtrips() + { + Assert.Equal("abc", new ChatFinishReason("abc").Value); + } + + [Fact] + public void Constructor_NullOrWhiteSpace_Throws() + { + Assert.Throws(() => new ChatFinishReason(null!)); + Assert.Throws(() => new ChatFinishReason(" ")); + } + + [Fact] + public void Equality_UsesOrdinalIgnoreCaseComparison() + { + Assert.True(new ChatFinishReason("abc").Equals(new ChatFinishReason("ABC"))); + Assert.True(new ChatFinishReason("abc").Equals((object)new ChatFinishReason("ABC"))); + Assert.True(new ChatFinishReason("abc") == new ChatFinishReason("ABC")); + Assert.Equal(new ChatFinishReason("abc").GetHashCode(), new ChatFinishReason("ABC").GetHashCode()); + Assert.False(new ChatFinishReason("abc") != new ChatFinishReason("ABC")); + + Assert.False(new ChatFinishReason("abc").Equals(new ChatFinishReason("def"))); + Assert.False(new ChatFinishReason("abc").Equals((object)new ChatFinishReason("def"))); + Assert.False(new ChatFinishReason("abc").Equals(null)); + Assert.False(new ChatFinishReason("abc").Equals("abc")); + Assert.False(new ChatFinishReason("abc") == new ChatFinishReason("def")); + Assert.True(new ChatFinishReason("abc") != new ChatFinishReason("def")); + Assert.NotEqual(new ChatFinishReason("abc").GetHashCode(), new ChatFinishReason("def").GetHashCode()); // not guaranteed due to possible hash code collisions + } + + [Fact] + public void Singletons_UseKnownValues() + { + Assert.Equal("stop", ChatFinishReason.Stop.Value); + Assert.Equal("length", ChatFinishReason.Length.Value); + Assert.Equal("tool_calls", ChatFinishReason.ToolCalls.Value); + Assert.Equal("content_filter", ChatFinishReason.ContentFilter.Value); + } + + [Fact] + public void Value_NormalizesToStopped() + { + Assert.Equal("test", new ChatFinishReason("test").Value); + Assert.Equal("test", new ChatFinishReason("test").ToString()); + + Assert.Equal("TEST", new ChatFinishReason("TEST").Value); + Assert.Equal("TEST", new ChatFinishReason("TEST").ToString()); + + Assert.Equal("stop", default(ChatFinishReason).Value); + Assert.Equal("stop", default(ChatFinishReason).ToString()); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatFinishReason role = new("abc"); + string? json = JsonSerializer.Serialize(role, TestJsonSerializerContext.Default.ChatFinishReason); + Assert.Equal("\"abc\"", json); + + ChatFinishReason? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatFinishReason); + Assert.Equal(role, result); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs new file mode 100644 index 00000000000..dbef5f4088b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatMessageTests.cs @@ -0,0 +1,382 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatMessageTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + ChatMessage message = new(); + Assert.Null(message.AuthorName); + Assert.Empty(message.Contents); + Assert.Equal(ChatRole.User, message.Role); + Assert.Null(message.Text); + Assert.NotNull(message.Contents); + Assert.Same(message.Contents, message.Contents); + Assert.Empty(message.Contents); + Assert.Null(message.RawRepresentation); + Assert.Null(message.AdditionalProperties); + Assert.Equal(string.Empty, message.ToString()); + } + + [Theory] + [InlineData(null)] + [InlineData("text")] + public void Constructor_RoleString_PropsRoundtrip(string? text) + { + ChatMessage message = new(ChatRole.Assistant, text); + + Assert.Equal(ChatRole.Assistant, message.Role); + + Assert.Same(message.Contents, message.Contents); + if (text is null) + { + Assert.Empty(message.Contents); + } + else + { + Assert.Single(message.Contents); + TextContent tc = Assert.IsType(message.Contents[0]); + Assert.Equal(text, tc.Text); + } + + Assert.Null(message.AuthorName); + Assert.Null(message.RawRepresentation); + Assert.Null(message.AdditionalProperties); + Assert.Equal(text ?? string.Empty, message.ToString()); + } + + [Fact] + public void Constructor_RoleList_InvalidArgs_Throws() + { + Assert.Throws("contents", () => new ChatMessage(ChatRole.User, (IList)null!)); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + public void Constructor_RoleList_PropsRoundtrip(int messageCount) + { + List content = []; + for (int i = 0; i < messageCount; i++) + { + content.Add(new TextContent($"text-{i}")); + } + + ChatMessage message = new(ChatRole.System, content); + + Assert.Equal(ChatRole.System, message.Role); + + Assert.Same(message.Contents, message.Contents); + if (messageCount == 0) + { + Assert.Empty(message.Contents); + Assert.Null(message.Text); + } + else + { + Assert.Equal(messageCount, message.Contents.Count); + for (int i = 0; i < messageCount; i++) + { + TextContent tc = Assert.IsType(message.Contents[i]); + Assert.Equal($"text-{i}", tc.Text); + } + + Assert.Equal("text-0", message.Text); + Assert.Equal("text-0", message.ToString()); + } + + Assert.Null(message.AuthorName); + Assert.Null(message.RawRepresentation); + Assert.Null(message.AdditionalProperties); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" \r\n\t\v ")] + public void AuthorName_InvalidArg_UsesNull(string? authorName) + { + ChatMessage message = new() + { + AuthorName = authorName + }; + Assert.Null(message.AuthorName); + + message.AuthorName = "author"; + Assert.Equal("author", message.AuthorName); + + message.AuthorName = authorName; + Assert.Null(message.AuthorName); + } + + [Fact] + public void Text_GetSet_UsesFirstTextContent() + { + ChatMessage message = new(ChatRole.User, + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + new TextContent("text-1"), + new TextContent("text-2"), + new FunctionResultContent(new FunctionCallContent("callId1", "fc2"), "result"), + ]); + + TextContent textContent = Assert.IsType(message.Contents[3]); + Assert.Equal("text-1", textContent.Text); + Assert.Equal("text-1", message.Text); + Assert.Equal("text-1", message.ToString()); + + message.Text = "text-3"; + Assert.Equal("text-3", message.Text); + Assert.Equal("text-3", message.Text); + Assert.Same(textContent, message.Contents[3]); + Assert.Equal("text-3", message.ToString()); + } + + [Fact] + public void Text_Set_AddsTextMessageToEmptyList() + { + ChatMessage message = new(ChatRole.User, []); + Assert.Empty(message.Contents); + + message.Text = "text-1"; + Assert.Equal("text-1", message.Text); + + Assert.Single(message.Contents); + TextContent textContent = Assert.IsType(message.Contents[0]); + Assert.Equal("text-1", textContent.Text); + } + + [Fact] + public void Text_Set_AddsTextMessageToListWithNoText() + { + ChatMessage message = new(ChatRole.User, + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + ]); + Assert.Equal(3, message.Contents.Count); + + message.Text = "text-1"; + Assert.Equal("text-1", message.Text); + Assert.Equal(4, message.Contents.Count); + + message.Text = "text-2"; + Assert.Equal("text-2", message.Text); + Assert.Equal(4, message.Contents.Count); + + message.Contents.RemoveAt(3); + Assert.Equal(3, message.Contents.Count); + + message.Text = "text-3"; + Assert.Equal("text-3", message.Text); + Assert.Equal(4, message.Contents.Count); + } + + [Fact] + public void Contents_InitializesToList() + { + // This is an implementation detail, but if this test starts failing, we need to ensure + // tests are in place for whatever possibly-custom implementation of IList is being used. + Assert.IsType>(new ChatMessage().Contents); + } + + [Fact] + public void Contents_Roundtrips() + { + ChatMessage message = new(); + Assert.Empty(message.Contents); + + List contents = []; + message.Contents = contents; + + Assert.Same(contents, message.Contents); + + message.Contents = contents; + Assert.Same(contents, message.Contents); + + message.Contents = null; + Assert.NotNull(message.Contents); + Assert.NotSame(contents, message.Contents); + Assert.Empty(message.Contents); + } + + [Fact] + public void RawRepresentation_Roundtrips() + { + ChatMessage message = new(); + Assert.Null(message.RawRepresentation); + + object raw = new(); + + message.RawRepresentation = raw; + Assert.Same(raw, message.RawRepresentation); + + message.RawRepresentation = raw; + Assert.Same(raw, message.RawRepresentation); + + message.RawRepresentation = null; + Assert.Null(message.RawRepresentation); + + message.RawRepresentation = raw; + Assert.Same(raw, message.RawRepresentation); + } + + [Fact] + public void AdditionalProperties_Roundtrips() + { + ChatMessage message = new(); + Assert.Null(message.RawRepresentation); + + AdditionalPropertiesDictionary props = []; + + message.AdditionalProperties = props; + Assert.Same(props, message.AdditionalProperties); + + message.AdditionalProperties = props; + Assert.Same(props, message.AdditionalProperties); + + message.AdditionalProperties = null; + Assert.Null(message.AdditionalProperties); + + message.AdditionalProperties = props; + Assert.Same(props, message.AdditionalProperties); + } + + [Fact] + public void ItCanBeSerializeAndDeserialized() + { + // Arrange + IList items = + [ + new TextContent("content-1") + { + ModelId = "model-1", + AdditionalProperties = new() { ["metadata-key-1"] = "metadata-value-1" } + }, + new ImageContent(new Uri("https://fake-random-test-host:123"), "mime-type/2") + { + ModelId = "model-2", + AdditionalProperties = new() { ["metadata-key-2"] = "metadata-value-2" } + }, + new DataContent(new BinaryData(new[] { 1, 2, 3 }, options: TestJsonSerializerContext.Default.Options), "mime-type/3") + { + ModelId = "model-3", + AdditionalProperties = new() { ["metadata-key-3"] = "metadata-value-3" } + }, + new AudioContent(new BinaryData(new[] { 3, 2, 1 }, options: TestJsonSerializerContext.Default.Options), "mime-type/4") + { + ModelId = "model-4", + AdditionalProperties = new() { ["metadata-key-4"] = "metadata-value-4" } + }, + new ImageContent(new BinaryData(new[] { 2, 1, 3 }, options: TestJsonSerializerContext.Default.Options), "mime-type/5") + { + ModelId = "model-5", + AdditionalProperties = new() { ["metadata-key-5"] = "metadata-value-5" } + }, + new TextContent("content-6") + { + ModelId = "model-6", + AdditionalProperties = new() { ["metadata-key-6"] = "metadata-value-6" } + }, + new FunctionCallContent("function-id", "plugin-name-function-name", new Dictionary { ["parameter"] = "argument" }), + new FunctionResultContent(new FunctionCallContent("function-id", "plugin-name-function-name"), "function-result"), + ]; + + // Act + var chatMessageJson = JsonSerializer.Serialize(new ChatMessage(ChatRole.User, contents: items) + { + Text = "content-1-override", // Override the content of the first text content item that has the "content-1" content + AuthorName = "Fred", + AdditionalProperties = new() { ["message-metadata-key-1"] = "message-metadata-value-1" }, + }, TestJsonSerializerContext.Default.Options); + + var deserializedMessage = JsonSerializer.Deserialize(chatMessageJson, TestJsonSerializerContext.Default.Options)!; + + // Assert + Assert.Equal("Fred", deserializedMessage.AuthorName); + Assert.Equal("user", deserializedMessage.Role.Value); + Assert.NotNull(deserializedMessage.AdditionalProperties); + Assert.Single(deserializedMessage.AdditionalProperties); + Assert.Equal("message-metadata-value-1", deserializedMessage.AdditionalProperties["message-metadata-key-1"]?.ToString()); + + Assert.NotNull(deserializedMessage.Contents); + Assert.Equal(items.Count, deserializedMessage.Contents.Count); + + var textContent = deserializedMessage.Contents[0] as TextContent; + Assert.NotNull(textContent); + Assert.Equal("content-1-override", textContent.Text); + Assert.Equal("model-1", textContent.ModelId); + Assert.NotNull(textContent.AdditionalProperties); + Assert.Single(textContent.AdditionalProperties); + Assert.Equal("metadata-value-1", textContent.AdditionalProperties["metadata-key-1"]?.ToString()); + + var imageContent = deserializedMessage.Contents[1] as ImageContent; + Assert.NotNull(imageContent); + Assert.Equal("https://fake-random-test-host:123/", imageContent.Uri); + Assert.Equal("model-2", imageContent.ModelId); + Assert.Equal("mime-type/2", imageContent.MediaType); + Assert.NotNull(imageContent.AdditionalProperties); + Assert.Single(imageContent.AdditionalProperties); + Assert.Equal("metadata-value-2", imageContent.AdditionalProperties["metadata-key-2"]?.ToString()); + + var dataContent = deserializedMessage.Contents[2] as DataContent; + Assert.NotNull(dataContent); + Assert.True(dataContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 1, 2, 3 }, TestJsonSerializerContext.Default.Options))); + Assert.Equal("model-3", dataContent.ModelId); + Assert.Equal("mime-type/3", dataContent.MediaType); + Assert.NotNull(dataContent.AdditionalProperties); + Assert.Single(dataContent.AdditionalProperties); + Assert.Equal("metadata-value-3", dataContent.AdditionalProperties["metadata-key-3"]?.ToString()); + + var audioContent = deserializedMessage.Contents[3] as AudioContent; + Assert.NotNull(audioContent); + Assert.True(audioContent.Data!.Value.Span.SequenceEqual(new BinaryData(new[] { 3, 2, 1 }, TestJsonSerializerContext.Default.Options))); + Assert.Equal("model-4", audioContent.ModelId); + Assert.Equal("mime-type/4", audioContent.MediaType); + Assert.NotNull(audioContent.AdditionalProperties); + Assert.Single(audioContent.AdditionalProperties); + Assert.Equal("metadata-value-4", audioContent.AdditionalProperties["metadata-key-4"]?.ToString()); + + imageContent = deserializedMessage.Contents[4] as ImageContent; + Assert.NotNull(imageContent); + Assert.True(imageContent.Data?.Span.SequenceEqual(new BinaryData(new[] { 2, 1, 3 }, TestJsonSerializerContext.Default.Options))); + Assert.Equal("model-5", imageContent.ModelId); + Assert.Equal("mime-type/5", imageContent.MediaType); + Assert.NotNull(imageContent.AdditionalProperties); + Assert.Single(imageContent.AdditionalProperties); + Assert.Equal("metadata-value-5", imageContent.AdditionalProperties["metadata-key-5"]?.ToString()); + + textContent = deserializedMessage.Contents[5] as TextContent; + Assert.NotNull(textContent); + Assert.Equal("content-6", textContent.Text); + Assert.Equal("model-6", textContent.ModelId); + Assert.NotNull(textContent.AdditionalProperties); + Assert.Single(textContent.AdditionalProperties); + Assert.Equal("metadata-value-6", textContent.AdditionalProperties["metadata-key-6"]?.ToString()); + + var functionCallContent = deserializedMessage.Contents[6] as FunctionCallContent; + Assert.NotNull(functionCallContent); + Assert.Equal("plugin-name-function-name", functionCallContent.Name); + Assert.Equal("function-id", functionCallContent.CallId); + Assert.NotNull(functionCallContent.Arguments); + Assert.Single(functionCallContent.Arguments); + Assert.Equal("argument", functionCallContent.Arguments["parameter"]?.ToString()); + + var functionResultContent = deserializedMessage.Contents[7] as FunctionResultContent; + Assert.NotNull(functionResultContent); + Assert.Equal("function-result", functionResultContent.Result?.ToString()); + Assert.Equal("function-id", functionResultContent.CallId); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs new file mode 100644 index 00000000000..2e769ff6d7e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatOptionsTests.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatOptionsTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + ChatOptions options = new(); + Assert.Null(options.Temperature); + Assert.Null(options.MaxOutputTokens); + Assert.Null(options.TopP); + Assert.Null(options.FrequencyPenalty); + Assert.Null(options.PresencePenalty); + Assert.Null(options.ResponseFormat); + Assert.Null(options.ModelId); + Assert.Null(options.StopSequences); + Assert.Same(ChatToolMode.Auto, options.ToolMode); + Assert.Null(options.Tools); + Assert.Null(options.AdditionalProperties); + + ChatOptions clone = options.Clone(); + Assert.Null(clone.Temperature); + Assert.Null(clone.MaxOutputTokens); + Assert.Null(clone.TopP); + Assert.Null(clone.FrequencyPenalty); + Assert.Null(clone.PresencePenalty); + Assert.Null(clone.ResponseFormat); + Assert.Null(clone.ModelId); + Assert.Null(clone.StopSequences); + Assert.Same(ChatToolMode.Auto, clone.ToolMode); + Assert.Null(clone.Tools); + Assert.Null(clone.AdditionalProperties); + } + + [Fact] + public void Properties_Roundtrip() + { + ChatOptions options = new(); + + List stopSequences = + [ + "stop1", + "stop2", + ]; + + List tools = + [ + AIFunctionFactory.Create(() => 42), + AIFunctionFactory.Create(() => 43), + ]; + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.Temperature = 0.1f; + options.MaxOutputTokens = 2; + options.TopP = 0.3f; + options.FrequencyPenalty = 0.4f; + options.PresencePenalty = 0.5f; + options.ResponseFormat = ChatResponseFormat.Json; + options.ModelId = "modelId"; + options.StopSequences = stopSequences; + options.ToolMode = ChatToolMode.RequireAny; + options.Tools = tools; + options.AdditionalProperties = additionalProps; + + Assert.Equal(0.1f, options.Temperature); + Assert.Equal(2, options.MaxOutputTokens); + Assert.Equal(0.3f, options.TopP); + Assert.Equal(0.4f, options.FrequencyPenalty); + Assert.Equal(0.5f, options.PresencePenalty); + Assert.Same(ChatResponseFormat.Json, options.ResponseFormat); + Assert.Equal("modelId", options.ModelId); + Assert.Same(stopSequences, options.StopSequences); + Assert.Same(ChatToolMode.RequireAny, options.ToolMode); + Assert.Same(tools, options.Tools); + Assert.Same(additionalProps, options.AdditionalProperties); + + ChatOptions clone = options.Clone(); + Assert.Equal(0.1f, clone.Temperature); + Assert.Equal(2, clone.MaxOutputTokens); + Assert.Equal(0.3f, clone.TopP); + Assert.Equal(0.4f, clone.FrequencyPenalty); + Assert.Equal(0.5f, clone.PresencePenalty); + Assert.Same(ChatResponseFormat.Json, clone.ResponseFormat); + Assert.Equal("modelId", clone.ModelId); + Assert.Equal(stopSequences, clone.StopSequences); + Assert.Same(ChatToolMode.RequireAny, clone.ToolMode); + Assert.Equal(tools, clone.Tools); + Assert.Equal(additionalProps, clone.AdditionalProperties); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatOptions options = new(); + + List stopSequences = + [ + "stop1", + "stop2", + ]; + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.Temperature = 0.1f; + options.MaxOutputTokens = 2; + options.TopP = 0.3f; + options.FrequencyPenalty = 0.4f; + options.PresencePenalty = 0.5f; + options.ResponseFormat = ChatResponseFormat.Json; + options.ModelId = "modelId"; + options.StopSequences = stopSequences; + options.ToolMode = ChatToolMode.RequireAny; + options.Tools = + [ + AIFunctionFactory.Create(() => 42), + AIFunctionFactory.Create(() => 43), + ]; + options.AdditionalProperties = additionalProps; + + string json = JsonSerializer.Serialize(options, TestJsonSerializerContext.Default.ChatOptions); + + ChatOptions? deserialized = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatOptions); + Assert.NotNull(deserialized); + + Assert.Equal(0.1f, deserialized.Temperature); + Assert.Equal(2, deserialized.MaxOutputTokens); + Assert.Equal(0.3f, deserialized.TopP); + Assert.Equal(0.4f, deserialized.FrequencyPenalty); + Assert.Equal(0.5f, deserialized.PresencePenalty); + Assert.Equal(ChatResponseFormat.Json, deserialized.ResponseFormat); + Assert.NotSame(ChatResponseFormat.Json, deserialized.ResponseFormat); + Assert.Equal("modelId", deserialized.ModelId); + Assert.NotSame(stopSequences, deserialized.StopSequences); + Assert.Equal(stopSequences, deserialized.StopSequences); + Assert.Equal(ChatToolMode.RequireAny, deserialized.ToolMode); + Assert.Null(deserialized.Tools); + + Assert.NotNull(deserialized.AdditionalProperties); + Assert.Single(deserialized.AdditionalProperties); + Assert.True(deserialized.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseFormatTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseFormatTests.cs new file mode 100644 index 00000000000..f4a63f34e05 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatResponseFormatTests.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatResponseFormatTests +{ + [Fact] + public void Singletons_Idempotent() + { + Assert.Same(ChatResponseFormat.Text, ChatResponseFormat.Text); + Assert.Same(ChatResponseFormat.Json, ChatResponseFormat.Json); + } + + [Fact] + public void Constructor_InvalidArgs_Throws() + { + Assert.Throws(() => new ChatResponseFormatJson(null, "name")); + Assert.Throws(() => new ChatResponseFormatJson(null, null, "description")); + Assert.Throws(() => new ChatResponseFormatJson(null, "name", "description")); + } + + [Fact] + public void Constructor_PropsDefaulted() + { + ChatResponseFormatJson f = new(null); + Assert.Null(f.Schema); + Assert.Null(f.SchemaName); + Assert.Null(f.SchemaDescription); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + ChatResponseFormatJson f = new("{}", "name", "description"); + Assert.Equal("{}", f.Schema); + Assert.Equal("name", f.SchemaName); + Assert.Equal("description", f.SchemaDescription); + } + + [Fact] + public void Equality_ComparersProduceExpectedResults() + { + Assert.True(ChatResponseFormat.Text == ChatResponseFormat.Text); + Assert.True(ChatResponseFormat.Text.Equals(ChatResponseFormat.Text)); + Assert.Equal(ChatResponseFormat.Text.GetHashCode(), ChatResponseFormat.Text.GetHashCode()); + Assert.False(ChatResponseFormat.Text.Equals(ChatResponseFormat.Json)); + Assert.False(ChatResponseFormat.Text.Equals(new ChatResponseFormatJson(null))); + Assert.False(ChatResponseFormat.Text.Equals(new ChatResponseFormatJson("{}"))); + + Assert.True(ChatResponseFormat.Json == ChatResponseFormat.Json); + Assert.True(ChatResponseFormat.Json.Equals(ChatResponseFormat.Json)); + Assert.False(ChatResponseFormat.Json.Equals(ChatResponseFormat.Text)); + Assert.False(ChatResponseFormat.Json.Equals(new ChatResponseFormatJson("{}"))); + + Assert.True(ChatResponseFormat.Json.Equals(new ChatResponseFormatJson(null))); + Assert.Equal(ChatResponseFormat.Json.GetHashCode(), new ChatResponseFormatJson(null).GetHashCode()); + + Assert.True(new ChatResponseFormatJson("{}").Equals(new ChatResponseFormatJson("{}"))); + Assert.Equal(new ChatResponseFormatJson("{}").GetHashCode(), new ChatResponseFormatJson("{}").GetHashCode()); + + Assert.False(new ChatResponseFormatJson("""{ "prop": 42 }""").Equals(new ChatResponseFormatJson("""{ "prop": 43 }"""))); + Assert.NotEqual(new ChatResponseFormatJson("""{ "prop": 42 }""").GetHashCode(), new ChatResponseFormatJson("""{ "prop": 43 }""").GetHashCode()); // technically not guaranteed + + Assert.False(new ChatResponseFormatJson("""{ "prop": 42 }""").Equals(new ChatResponseFormatJson("""{ "PROP": 42 }"""))); + Assert.NotEqual(new ChatResponseFormatJson("""{ "prop": 42 }""").GetHashCode(), new ChatResponseFormatJson("""{ "PROP": 42 }""").GetHashCode()); // technically not guaranteed + + Assert.True(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name", "description"))); + Assert.False(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name", "description2"))); + Assert.False(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name2", "description"))); + Assert.False(new ChatResponseFormatJson("{}", "name", "description").Equals(new ChatResponseFormatJson("{}", "name2", "description2"))); + + Assert.Equal(new ChatResponseFormatJson("{}", "name", "description").GetHashCode(), new ChatResponseFormatJson("{}", "name", "description").GetHashCode()); + } + + [Fact] + public void Serialization_TextRoundtrips() + { + string json = JsonSerializer.Serialize(ChatResponseFormat.Text, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal("""{"$type":"text"}""", json); + + ChatResponseFormat? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal(ChatResponseFormat.Text, result); + } + + [Fact] + public void Serialization_JsonRoundtrips() + { + string json = JsonSerializer.Serialize(ChatResponseFormat.Json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal("""{"$type":"json"}""", json); + + ChatResponseFormat? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal(ChatResponseFormat.Json, result); + } + + [Fact] + public void Serialization_ForJsonSchemaRoundtrips() + { + string json = JsonSerializer.Serialize(ChatResponseFormat.ForJsonSchema("[1,2,3]", "name", "description"), TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal("""{"$type":"json","schema":"[1,2,3]","schemaName":"name","schemaDescription":"description"}""", json); + + ChatResponseFormat? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatResponseFormat); + Assert.Equal(ChatResponseFormat.ForJsonSchema("[1,2,3]", "name", "description"), result); + Assert.Equal("[1,2,3]", (result as ChatResponseFormatJson)?.Schema); + Assert.Equal("name", (result as ChatResponseFormatJson)?.SchemaName); + Assert.Equal("description", (result as ChatResponseFormatJson)?.SchemaDescription); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatRoleTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatRoleTests.cs new file mode 100644 index 00000000000..7761aa2fdc3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatRoleTests.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatRoleTests +{ + [Fact] + public void Constructor_Value_Roundtrips() + { + Assert.Equal("abc", new ChatRole("abc").Value); + } + + [Fact] + public void Constructor_NullOrWhiteSpace_Throws() + { + Assert.Throws(() => new ChatRole(null!)); + Assert.Throws(() => new ChatRole(" ")); + } + + [Fact] + public void Equality_UsesOrdinalIgnoreCaseComparison() + { + Assert.True(new ChatRole("abc").Equals(new ChatRole("ABC"))); + Assert.True(new ChatRole("abc").Equals((object)new ChatRole("ABC"))); + Assert.True(new ChatRole("abc") == new ChatRole("ABC")); + Assert.False(new ChatRole("abc") != new ChatRole("ABC")); + + Assert.False(new ChatRole("abc").Equals(new ChatRole("def"))); + Assert.False(new ChatRole("abc").Equals((object)new ChatRole("def"))); + Assert.False(new ChatRole("abc").Equals(null)); + Assert.False(new ChatRole("abc").Equals("abc")); + Assert.False(new ChatRole("abc") == new ChatRole("def")); + Assert.True(new ChatRole("abc") != new ChatRole("def")); + + Assert.Equal(new ChatRole("abc").GetHashCode(), new ChatRole("abc").GetHashCode()); + Assert.Equal(new ChatRole("abc").GetHashCode(), new ChatRole("ABC").GetHashCode()); + Assert.NotEqual(new ChatRole("abc").GetHashCode(), new ChatRole("def").GetHashCode()); // not guaranteed + } + + [Fact] + public void Singletons_UseKnownValues() + { + Assert.Equal("assistant", ChatRole.Assistant.Value); + Assert.Equal("system", ChatRole.System.Value); + Assert.Equal("tool", ChatRole.Tool.Value); + Assert.Equal("user", ChatRole.User.Value); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + ChatRole role = new("abc"); + string? json = JsonSerializer.Serialize(role, TestJsonSerializerContext.Default.ChatRole); + Assert.Equal("\"abc\"", json); + + ChatRole? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatRole); + Assert.Equal(role, result); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs new file mode 100644 index 00000000000..7cdda8ef975 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatToolModeTests +{ + [Fact] + public void Singletons_Idempotent() + { + Assert.Same(ChatToolMode.Auto, ChatToolMode.Auto); + Assert.Same(ChatToolMode.RequireAny, ChatToolMode.RequireAny); + } + + [Fact] + public void Equality_ComparersProduceExpectedResults() + { + Assert.True(ChatToolMode.Auto == ChatToolMode.Auto); + Assert.True(ChatToolMode.Auto.Equals(ChatToolMode.Auto)); + Assert.False(ChatToolMode.Auto.Equals(ChatToolMode.RequireAny)); + Assert.False(ChatToolMode.Auto.Equals(new RequiredChatToolMode(null))); + Assert.False(ChatToolMode.Auto.Equals(new RequiredChatToolMode("func"))); + Assert.Equal(ChatToolMode.Auto.GetHashCode(), ChatToolMode.Auto.GetHashCode()); + + Assert.True(ChatToolMode.RequireAny == ChatToolMode.RequireAny); + Assert.True(ChatToolMode.RequireAny.Equals(ChatToolMode.RequireAny)); + Assert.False(ChatToolMode.RequireAny.Equals(ChatToolMode.Auto)); + Assert.False(ChatToolMode.RequireAny.Equals(new RequiredChatToolMode("func"))); + + Assert.True(ChatToolMode.RequireAny.Equals(new RequiredChatToolMode(null))); + Assert.Equal(ChatToolMode.RequireAny.GetHashCode(), new RequiredChatToolMode(null).GetHashCode()); + Assert.Equal(ChatToolMode.RequireAny.GetHashCode(), ChatToolMode.RequireAny.GetHashCode()); + + Assert.True(new RequiredChatToolMode("func").Equals(new RequiredChatToolMode("func"))); + Assert.Equal(new RequiredChatToolMode("func").GetHashCode(), new RequiredChatToolMode("func").GetHashCode()); + + Assert.False(new RequiredChatToolMode("func1").Equals(new RequiredChatToolMode("func2"))); + Assert.NotEqual(new RequiredChatToolMode("func1").GetHashCode(), new RequiredChatToolMode("func2").GetHashCode()); // technically not guaranteed + + Assert.False(new RequiredChatToolMode("func1").Equals(new RequiredChatToolMode("FUNC1"))); + Assert.NotEqual(new RequiredChatToolMode("func1").GetHashCode(), new RequiredChatToolMode("FUNC1").GetHashCode()); // technically not guaranteed + } + + [Fact] + public void Serialization_AutoRoundtrips() + { + string json = JsonSerializer.Serialize(ChatToolMode.Auto, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal("""{"$type":"auto"}""", json); + + ChatToolMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal(ChatToolMode.Auto, result); + } + + [Fact] + public void Serialization_RequireAnyRoundtrips() + { + string json = JsonSerializer.Serialize(ChatToolMode.RequireAny, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal("""{"$type":"required"}""", json); + + ChatToolMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal(ChatToolMode.RequireAny, result); + } + + [Fact] + public void Serialization_RequireSpecificRoundtrips() + { + string json = JsonSerializer.Serialize(ChatToolMode.RequireSpecific("myFunc"), TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal("""{"$type":"required","requiredFunctionName":"myFunc"}""", json); + + ChatToolMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatToolMode); + Assert.Equal(ChatToolMode.RequireSpecific("myFunc"), result); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs new file mode 100644 index 00000000000..51c82c7dcb7 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/DelegatingChatClientTests.cs @@ -0,0 +1,166 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DelegatingChatClientTests +{ + [Fact] + public void RequiresInnerChatClient() + { + Assert.Throws(() => new NoOpDelegatingChatClient(null!)); + } + + [Fact] + public void MetadataDefaultsToInnerClient() + { + using var inner = new TestChatClient(); + using var delegating = new NoOpDelegatingChatClient(inner); + + Assert.Same(inner.Metadata, delegating.Metadata); + } + + [Fact] + public async Task ChatAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedChatContents = new List(); + var expectedChatOptions = new ChatOptions(); + var expectedCancellationToken = CancellationToken.None; + var expectedResult = new TaskCompletionSource(); + var expectedCompletion = new ChatCompletion([]); + using var inner = new TestChatClient + { + CompleteAsyncCallback = (chatContents, options, cancellationToken) => + { + Assert.Same(expectedChatContents, chatContents); + Assert.Same(expectedChatOptions, options); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var resultTask = delegating.CompleteAsync(expectedChatContents, expectedChatOptions, expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(expectedCompletion); + Assert.True(resultTask.IsCompleted); + Assert.Same(expectedCompletion, await resultTask); + } + + [Fact] + public async Task ChatStreamingAsyncDefaultsToInnerClientAsync() + { + // Arrange + var expectedChatContents = new List(); + var expectedChatOptions = new ChatOptions(); + var expectedCancellationToken = CancellationToken.None; + StreamingChatCompletionUpdate[] expectedResults = + [ + new() { Role = ChatRole.User, Text = "Message 1" }, + new() { Role = ChatRole.User, Text = "Message 2" } + ]; + + using var inner = new TestChatClient + { + CompleteStreamingAsyncCallback = (chatContents, options, cancellationToken) => + { + Assert.Same(expectedChatContents, chatContents); + Assert.Same(expectedChatOptions, options); + Assert.Equal(expectedCancellationToken, cancellationToken); + return YieldAsync(expectedResults); + } + }; + + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var resultAsyncEnumerable = delegating.CompleteStreamingAsync(expectedChatContents, expectedChatOptions, expectedCancellationToken); + + // Assert + var enumerator = resultAsyncEnumerable.GetAsyncEnumerator(); + Assert.True(await enumerator.MoveNextAsync()); + Assert.Same(expectedResults[0], enumerator.Current); + Assert.True(await enumerator.MoveNextAsync()); + Assert.Same(expectedResults[1], enumerator.Current); + Assert.False(await enumerator.MoveNextAsync()); + } + + [Fact] + public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull() + { + // Arrange + using var inner = new TestChatClient(); + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var client = delegating.GetService(); + + // Assert + Assert.Same(delegating, client); + } + + [Fact] + public void GetServiceDelegatesToInnerIfKeyIsNotNull() + { + // Arrange + var expectedParam = new object(); + var expectedKey = new object(); + using var expectedResult = new TestChatClient(); + using var inner = new TestChatClient + { + GetServiceCallback = (_, _) => expectedResult + }; + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var client = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, client); + } + + [Fact] + public void GetServiceDelegatesToInnerIfNotCompatibleWithRequest() + { + // Arrange + var expectedParam = new object(); + var expectedResult = TimeZoneInfo.Local; + var expectedKey = new object(); + using var inner = new TestChatClient + { + GetServiceCallback = (type, key) => type == expectedResult.GetType() && key == expectedKey + ? expectedResult + : throw new InvalidOperationException("Unexpected call") + }; + using var delegating = new NoOpDelegatingChatClient(inner); + + // Act + var tzi = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, tzi); + } + + private static async IAsyncEnumerable YieldAsync(IEnumerable input) + { + await Task.Yield(); + foreach (var item in input) + { + yield return item; + } + } + + private sealed class NoOpDelegatingChatClient(IChatClient innerClient) + : DelegatingChatClient(innerClient); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateTests.cs new file mode 100644 index 00000000000..988727b1159 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/StreamingChatCompletionUpdateTests.cs @@ -0,0 +1,220 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class StreamingChatCompletionUpdateTests +{ + [Fact] + public void Constructor_PropsDefaulted() + { + StreamingChatCompletionUpdate update = new(); + Assert.Null(update.AuthorName); + Assert.Null(update.Role); + Assert.Null(update.Text); + Assert.Empty(update.Contents); + Assert.Null(update.RawRepresentation); + Assert.Null(update.AdditionalProperties); + Assert.Null(update.CompletionId); + Assert.Null(update.CreatedAt); + Assert.Null(update.FinishReason); + Assert.Equal(0, update.ChoiceIndex); + Assert.Equal(string.Empty, update.ToString()); + } + + [Fact] + public void Properties_Roundtrip() + { + StreamingChatCompletionUpdate update = new(); + + Assert.Null(update.AuthorName); + update.AuthorName = "author"; + Assert.Equal("author", update.AuthorName); + + Assert.Null(update.Role); + update.Role = ChatRole.Assistant; + Assert.Equal(ChatRole.Assistant, update.Role); + + Assert.Empty(update.Contents); + update.Contents.Add(new TextContent("text")); + Assert.Single(update.Contents); + Assert.Equal("text", update.Text); + Assert.Same(update.Contents, update.Contents); + IList newList = [new TextContent("text")]; + update.Contents = newList; + Assert.Same(newList, update.Contents); + update.Contents = null; + Assert.NotNull(update.Contents); + Assert.Empty(update.Contents); + + Assert.Null(update.Text); + update.Text = "text"; + Assert.Equal("text", update.Text); + + Assert.Null(update.RawRepresentation); + object raw = new(); + update.RawRepresentation = raw; + Assert.Same(raw, update.RawRepresentation); + + Assert.Null(update.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { ["key"] = "value" }; + update.AdditionalProperties = props; + Assert.Same(props, update.AdditionalProperties); + + Assert.Null(update.CompletionId); + update.CompletionId = "id"; + Assert.Equal("id", update.CompletionId); + + Assert.Null(update.CreatedAt); + update.CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), update.CreatedAt); + + Assert.Equal(0, update.ChoiceIndex); + update.ChoiceIndex = 42; + Assert.Equal(42, update.ChoiceIndex); + + Assert.Null(update.FinishReason); + update.FinishReason = ChatFinishReason.ContentFilter; + Assert.Equal(ChatFinishReason.ContentFilter, update.FinishReason); + } + + [Fact] + public void Text_GetSet_UsesFirstTextContent() + { + StreamingChatCompletionUpdate update = new() + { + Role = ChatRole.User, + Contents = + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + new TextContent("text-1"), + new TextContent("text-2"), + new FunctionResultContent(new FunctionCallContent("callId1", "fc2"), "result"), + ], + }; + + TextContent textContent = Assert.IsType(update.Contents[3]); + Assert.Equal("text-1", textContent.Text); + Assert.Equal("text-1", update.Text); + Assert.Equal("text-1", update.ToString()); + + update.Text = "text-3"; + Assert.Equal("text-3", update.Text); + Assert.Equal("text-3", update.Text); + Assert.Same(textContent, update.Contents[3]); + Assert.Equal("text-3", update.ToString()); + } + + [Fact] + public void Text_Set_AddsTextMessageToEmptyList() + { + StreamingChatCompletionUpdate update = new() + { + Role = ChatRole.User, + }; + Assert.Empty(update.Contents); + + update.Text = "text-1"; + Assert.Equal("text-1", update.Text); + + Assert.Single(update.Contents); + TextContent textContent = Assert.IsType(update.Contents[0]); + Assert.Equal("text-1", textContent.Text); + } + + [Fact] + public void Text_Set_AddsTextMessageToListWithNoText() + { + StreamingChatCompletionUpdate update = new() + { + Contents = + [ + new AudioContent("http://localhost/audio"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + ] + }; + Assert.Equal(3, update.Contents.Count); + + update.Text = "text-1"; + Assert.Equal("text-1", update.Text); + Assert.Equal(4, update.Contents.Count); + + update.Text = "text-2"; + Assert.Equal("text-2", update.Text); + Assert.Equal(4, update.Contents.Count); + + update.Contents.RemoveAt(3); + Assert.Equal(3, update.Contents.Count); + + update.Text = "text-3"; + Assert.Equal("text-3", update.Text); + Assert.Equal(4, update.Contents.Count); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + StreamingChatCompletionUpdate original = new() + { + AuthorName = "author", + Role = ChatRole.Assistant, + Contents = + [ + new TextContent("text-1"), + new ImageContent("http://localhost/image"), + new FunctionCallContent("callId1", "fc1"), + new DataContent("data"u8.ToArray()), + new TextContent("text-2"), + ], + RawRepresentation = new object(), + CompletionId = "id", + CreatedAt = new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), + FinishReason = ChatFinishReason.ContentFilter, + AdditionalProperties = new() { ["key"] = "value" }, + ChoiceIndex = 42, + }; + + string json = JsonSerializer.Serialize(original, TestJsonSerializerContext.Default.StreamingChatCompletionUpdate); + + StreamingChatCompletionUpdate? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.StreamingChatCompletionUpdate); + + Assert.NotNull(result); + Assert.Equal(5, result.Contents.Count); + + Assert.IsType(result.Contents[0]); + Assert.Equal("text-1", ((TextContent)result.Contents[0]).Text); + + Assert.IsType(result.Contents[1]); + Assert.Equal("http://localhost/image", ((ImageContent)result.Contents[1]).Uri); + + Assert.IsType(result.Contents[2]); + Assert.Equal("fc1", ((FunctionCallContent)result.Contents[2]).Name); + + Assert.IsType(result.Contents[3]); + Assert.Equal("data"u8.ToArray(), ((DataContent)result.Contents[3]).Data?.ToArray()); + + Assert.IsType(result.Contents[4]); + Assert.Equal("text-2", ((TextContent)result.Contents[4]).Text); + + Assert.Equal("author", result.AuthorName); + Assert.Equal(ChatRole.Assistant, result.Role); + Assert.Equal("id", result.CompletionId); + Assert.Equal(new DateTimeOffset(2022, 1, 1, 0, 0, 0, TimeSpan.Zero), result.CreatedAt); + Assert.Equal(ChatFinishReason.ContentFilter, result.FinishReason); + Assert.Equal(42, result.ChoiceIndex); + + Assert.NotNull(result.AdditionalProperties); + Assert.Single(result.AdditionalProperties); + Assert.True(result.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentTests.cs new file mode 100644 index 00000000000..ece02f017bb --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AIContentTests.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIContentTests +{ + [Fact] + public void Constructor_PropsDefault() + { + DerivedAIContent c = new(); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + DerivedAIContent c = new(); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + } + + private sealed class DerivedAIContent : AIContent; +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AudioContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AudioContentTests.cs new file mode 100644 index 00000000000..7aff849e8a1 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/AudioContentTests.cs @@ -0,0 +1,6 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public sealed class AudioContentTests : DataContentTests; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs new file mode 100644 index 00000000000..18aae8c0497 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests.cs @@ -0,0 +1,6 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public sealed class DataContentTests : DataContentTests; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests{T}.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests{T}.cs new file mode 100644 index 00000000000..ea3017cf7ea --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/DataContentTests{T}.cs @@ -0,0 +1,249 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Reflection; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public abstract class DataContentTests + where T : DataContent +{ + private static T Create(params object?[] args) + { + try + { + return (T)Activator.CreateInstance(typeof(T), args)!; + } + catch (TargetInvocationException e) + { + throw e.InnerException!; + } + } + + public T CreateDataContent(Uri uri, string? mediaType = null) => Create(uri, mediaType)!; + +#pragma warning disable S3997 // String URI overloads should call "System.Uri" overloads + public T CreateDataContent(string uriString, string? mediaType = null) => Create(uriString, mediaType)!; +#pragma warning restore S3997 + + public T CreateDataContent(ReadOnlyMemory data, string? mediaType = null) => Create(data, mediaType)!; + + [Theory] + + // Invalid URI + [InlineData("", typeof(ArgumentException))] + [InlineData("invalid", typeof(UriFormatException))] + + // Format errors + [InlineData("data", typeof(UriFormatException))] // data missing colon + [InlineData("data:", typeof(UriFormatException))] // data missing comma + [InlineData("data:something,", typeof(UriFormatException))] // mime type without subtype + [InlineData("data:something;else,data", typeof(UriFormatException))] // mime type without subtype + [InlineData("data:type/subtype;;parameter=value;else,", typeof(UriFormatException))] // parameter without value + [InlineData("data:type/subtype;parameter=va=lue;else,", typeof(UriFormatException))] // parameter with multiple = + [InlineData("data:type/subtype;=value;else,", typeof(UriFormatException))] // empty parameter name + [InlineData("", typeof(UriFormatException))] // multiple slashes in media type + + // Base64 Validation Errors + [InlineData("data:text;base64,something!", typeof(UriFormatException))] // Invalid base64 due to invalid character '!' + [InlineData("data:text/plain;base64,U29tZQ==\t", typeof(UriFormatException))] // Invalid base64 due to tab character + [InlineData("data:text/plain;base64,U29tZQ==\r", typeof(UriFormatException))] // Invalid base64 due to carriage return character + [InlineData("data:text/plain;base64,U29tZQ==\n", typeof(UriFormatException))] // Invalid base64 due to line feed character + [InlineData("data:text/plain;base64,U29t\r\nZQ==", typeof(UriFormatException))] // Invalid base64 due to carriage return and line feed characters + [InlineData("data:text/plain;base64,U29", typeof(UriFormatException))] // Invalid base64 due to missing padding + [InlineData("data:text/plain;base64,U29tZQ", typeof(UriFormatException))] // Invalid base64 due to missing padding + [InlineData("data:text/plain;base64,U29tZQ=", typeof(UriFormatException))] // Invalid base64 due to missing padding + public void Ctor_InvalidUri_Throws(string path, Type exception) + { + Assert.Throws(exception, () => CreateDataContent(path)); + } + + [Theory] + [InlineData("type")] + [InlineData("type//subtype")] + [InlineData("type/subtype/")] + [InlineData("type/subtype;key=")] + [InlineData("type/subtype;=value")] + [InlineData("type/subtype;key=value;another=")] + public void Ctor_InvalidMediaType_Throws(string mediaType) + { + Assert.Throws(() => CreateDataContent("http://localhost/test", mediaType)); + } + + [Theory] + [InlineData("type/subtype")] + [InlineData("type/subtype;key=value")] + [InlineData("type/subtype;key=value;another=value")] + [InlineData("type/subtype;key=value;another=value;yet_another=value")] + public void Ctor_ValidMediaType_Roundtrips(string mediaType) + { + T content = CreateDataContent("http://localhost/test", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent("data:,", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent("data:text/plain,", mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent(new Uri("data:text/plain,"), mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent(new byte[] { 0, 1, 2 }, mediaType); + Assert.Equal(mediaType, content.MediaType); + + content = CreateDataContent(content.Uri); + Assert.Equal(mediaType, content.MediaType); + } + + [Fact] + public void Ctor_NoMediaType_Roundtrips() + { + T content; + + foreach (string url in new[] { "http://localhost/test", "about:something", "file://c:\\path" }) + { + content = CreateDataContent(url); + Assert.Equal(url, content.Uri); + Assert.Null(content.MediaType); + Assert.Null(content.Data); + } + + content = CreateDataContent("data:,something"); + Assert.Equal("data:,something", content.Uri); + Assert.Null(content.MediaType); + Assert.Equal("something"u8.ToArray(), content.Data!.Value.ToArray()); + + content = CreateDataContent("data:,Hello+%3C%3E"); + Assert.Equal("data:,Hello+%3C%3E", content.Uri); + Assert.Null(content.MediaType); + Assert.Equal("Hello <>"u8.ToArray(), content.Data!.Value.ToArray()); + } + + [Fact] + public void Serialize_MatchesExpectedJson() + { + Assert.Equal( + """{"uri":"data:,"}""", + JsonSerializer.Serialize(CreateDataContent("data:,"), TestJsonSerializerContext.Default.Options)); + + Assert.Equal( + """{"uri":"http://localhost/"}""", + JsonSerializer.Serialize(CreateDataContent(new Uri("http://localhost/")), TestJsonSerializerContext.Default.Options)); + + Assert.Equal( + """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + JsonSerializer.Serialize(CreateDataContent( + uriString: "data:application/octet-stream;base64,AQIDBA=="), TestJsonSerializerContext.Default.Options)); + + Assert.Equal( + """{"uri":"data:application/octet-stream;base64,AQIDBA==","mediaType":"application/octet-stream"}""", + JsonSerializer.Serialize(CreateDataContent( + new ReadOnlyMemory([0x01, 0x02, 0x03, 0x04]), "application/octet-stream"), + TestJsonSerializerContext.Default.Options)); + } + + [Theory] + [InlineData("{}")] + [InlineData("""{ "mediaType":"text/plain" }""")] + public void Deserialize_MissingUriString_Throws(string json) + { + Assert.Throws(() => JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options)!); + } + + [Fact] + public void Deserialize_MatchesExpectedData() + { + // Data + MimeType only + var content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"data:;base64,AQIDBA=="}""", TestJsonSerializerContext.Default.Options)!; + + Assert.Equal("data:application/octet-stream;base64,AQIDBA==", content.Uri); + Assert.NotNull(content.Data); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data!.Value.ToArray()); + Assert.Equal("application/octet-stream", content.MediaType); + Assert.True(content.ContainsData); + + // Uri referenced content-only + content = JsonSerializer.Deserialize("""{"mediaType":"application/octet-stream","uri":"http://localhost/"}""", TestJsonSerializerContext.Default.Options)!; + + Assert.Null(content.Data); + Assert.Equal("http://localhost/", content.Uri); + Assert.Equal("application/octet-stream", content.MediaType); + Assert.False(content.ContainsData); + + // Using extra metadata + content = JsonSerializer.Deserialize(""" + { + "uri": "data:;base64,AQIDBA==", + "modelId": "gpt-4", + "additionalProperties": + { + "key": "value" + }, + "mediaType": "text/plain" + } + """, TestJsonSerializerContext.Default.Options)!; + + Assert.Equal("data:text/plain;base64,AQIDBA==", content.Uri); + Assert.NotNull(content.Data); + Assert.Equal([0x01, 0x02, 0x03, 0x04], content.Data!.Value.ToArray()); + Assert.Equal("text/plain", content.MediaType); + Assert.True(content.ContainsData); + Assert.Equal("gpt-4", content.ModelId); + Assert.Equal("value", content.AdditionalProperties!["key"]!.ToString()); + } + + [Theory] + [InlineData( + """{"uri": "data:;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + [InlineData( + """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType": "text/plain"}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + [InlineData( // Does not support non-readable content + """{"uri": "data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", "unexpected": true}""", + """{"uri":"data:text/plain;base64,AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=","mediaType":"text/plain"}""")] + [InlineData( // Uri comes before mimetype + """{"mediaType": "text/plain", "uri": "http://localhost/" }""", + """{"uri":"http://localhost/","mediaType":"text/plain"}""")] + public void Serialize_Deserialize_Roundtrips(string serialized, string expectedToString) + { + var content = JsonSerializer.Deserialize(serialized, TestJsonSerializerContext.Default.Options)!; + var reSerialization = JsonSerializer.Serialize(content, TestJsonSerializerContext.Default.Options); + Assert.Equal(expectedToString, reSerialization); + } + + [Theory] + [InlineData("application/json")] + [InlineData("application/octet-stream")] + [InlineData("application/pdf")] + [InlineData("application/xml")] + [InlineData("audio/mpeg")] + [InlineData("audio/ogg")] + [InlineData("audio/wav")] + [InlineData("image/apng")] + [InlineData("image/avif")] + [InlineData("image/bmp")] + [InlineData("image/gif")] + [InlineData("image/jpeg")] + [InlineData("image/png")] + [InlineData("image/svg+xml")] + [InlineData("image/tiff")] + [InlineData("image/webp")] + [InlineData("text/css")] + [InlineData("text/csv")] + [InlineData("text/html")] + [InlineData("text/javascript")] + [InlineData("text/plain")] + [InlineData("text/plain;charset=UTF-8")] + [InlineData("text/xml")] + [InlineData("custom/mediatypethatdoesntexists")] + public void MediaType_Roundtrips(string mediaType) + { + DataContent c = new("data:,", mediaType); + Assert.Equal(mediaType, c.MediaType); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs new file mode 100644 index 00000000000..791bb4cc0e7 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionCallContentTests..cs @@ -0,0 +1,302 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +#if NET +using System.Runtime.ExceptionServices; +#endif +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class FunctionCallContentTests +{ + [Fact] + public void Constructor_PropsDefault() + { + FunctionCallContent c = new("callId1", "name"); + + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.Equal("callId1", c.CallId); + Assert.Equal("name", c.Name); + + Assert.Null(c.Arguments); + Assert.Null(c.Exception); + } + + [Fact] + public void Constructor_ArgumentsRoundtrip() + { + Dictionary args = []; + + FunctionCallContent c = new("id", "name", args); + + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.Equal("name", c.Name); + Assert.Equal("id", c.CallId); + Assert.Same(args, c.Arguments); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + FunctionCallContent c = new("callId1", "name"); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + + Assert.Equal("callId1", c.CallId); + c.CallId = "id"; + Assert.Equal("id", c.CallId); + + Assert.Null(c.Arguments); + AdditionalPropertiesDictionary args = new() { { "key", "value" } }; + c.Arguments = args; + Assert.Same(args, c.Arguments); + + Assert.Null(c.Exception); + Exception e = new(); + c.Exception = e; + Assert.Same(e, c.Exception); + } + + [Fact] + public void ItShouldBeSerializableAndDeserializableWithException() + { + // Arrange + var ex = new InvalidOperationException("hello", new NullReferenceException("bye")); +#if NET + ExceptionDispatchInfo.SetRemoteStackTrace(ex, "stack trace"); +#endif + var sut = new FunctionCallContent("callId1", "functionName") { Exception = ex }; + + // Act + var json = JsonSerializer.SerializeToNode(sut, TestJsonSerializerContext.Default.Options); + var deserializedSut = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options); + + // Assert + JsonObject jsonEx = Assert.IsType(json!["exception"]); + Assert.Equal(4, jsonEx.Count); + Assert.Equal("System.InvalidOperationException", (string?)jsonEx["className"]); + Assert.Equal("hello", (string?)jsonEx["message"]); +#if NET + Assert.StartsWith("stack trace", (string?)jsonEx["stackTraceString"]); +#endif + JsonObject jsonExInner = Assert.IsType(jsonEx["innerException"]); + Assert.Equal(4, jsonExInner.Count); + Assert.Equal("System.NullReferenceException", (string?)jsonExInner["className"]); + Assert.Equal("bye", (string?)jsonExInner["message"]); + Assert.Null(jsonExInner["innerException"]); + Assert.Null(jsonExInner["stackTraceString"]); + + Assert.NotNull(deserializedSut); + Assert.IsType(deserializedSut.Exception); + Assert.Equal("hello", deserializedSut.Exception.Message); +#if NET + Assert.StartsWith("stack trace", deserializedSut.Exception.StackTrace); +#endif + + Assert.IsType(deserializedSut.Exception.InnerException); + Assert.Equal("bye", deserializedSut.Exception.InnerException.Message); + Assert.Null(deserializedSut.Exception.InnerException.StackTrace); + Assert.Null(deserializedSut.Exception.InnerException.InnerException); + } + + [Fact] + public async Task AIFunctionFactory_ObjectValues_Converted() + { + Dictionary arguments = new() + { + ["a"] = new DayOfWeek[] { DayOfWeek.Monday, DayOfWeek.Tuesday, DayOfWeek.Wednesday }, + ["b"] = 123.4M, + ["c"] = "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + ["d"] = new ReadOnlyDictionary((new Dictionary + { + ["p1"] = "42", + ["p2"] = "43", + })), + }; + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public async Task AIFunctionFactory_JsonElementValues_ValuesDeserialized() + { + Dictionary arguments = JsonSerializer.Deserialize>(""" + { + "a": ["Monday", "Tuesday", "Wednesday"], + "b": 123.4, + "c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + "d": { + "property1": "42", + "property2": "43", + "property3": "44" + } + } + """, TestJsonSerializerContext.Default.Options)!; + Assert.All(arguments.Values, v => Assert.IsType(v)); + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public void AIFunctionFactory_WhenTypesUnknownByContext_Throws() + { + var ex = Assert.Throws(() => AIFunctionFactory.Create((CustomType arg) => { }, TestJsonSerializerContext.Default.Options)); + Assert.Contains("JsonTypeInfo metadata", ex.Message); + Assert.Contains(nameof(CustomType), ex.Message); + + ex = Assert.Throws(() => AIFunctionFactory.Create(() => new CustomType(), TestJsonSerializerContext.Default.Options)); + Assert.Contains("JsonTypeInfo metadata", ex.Message); + Assert.Contains(nameof(CustomType), ex.Message); + } + + [Fact] + public async Task AIFunctionFactory_JsonDocumentValues_ValuesDeserialized() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": ["Monday", "Tuesday", "Wednesday"], + "b": 123.4, + "c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + "d": { + "property1": "42", + "property2": "43", + "property3": "44" + } + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public async Task AIFunctionFactory_JsonNodeValues_ValuesDeserialized() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": ["Monday", "Tuesday", "Wednesday"], + "b": 123.4, + "c": "072c2d93-7cf6-4d0d-aebc-acc51e6ee7ee", + "d": { + "property1": "42", + "property2": "43", + "property3": "44" + } + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + AIFunction function = AIFunctionFactory.Create((DayOfWeek[] a, double b, Guid c, Dictionary d) => b, TestJsonSerializerContext.Default.Options); + var result = await function.InvokeAsync(arguments); + AssertExtensions.EqualFunctionCallResults(123.4, result); + } + + [Fact] + public async Task TypelessAIFunction_JsonDocumentValues_AcceptsArguments() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": "string", + "b": 123.4, + "c": true, + "d": false, + "e": ["Monday", "Tuesday", "Wednesday"], + "f": null + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments); + Assert.Same(result, arguments); + } + + [Fact] + public async Task TypelessAIFunction_JsonElementValues_AcceptsArguments() + { + Dictionary arguments = JsonSerializer.Deserialize>(""" + { + "a": "string", + "b": 123.4, + "c": true, + "d": false, + "e": ["Monday", "Tuesday", "Wednesday"], + "f": null + } + """, TestJsonSerializerContext.Default.Options)!; + + var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments); + Assert.Same(result, arguments); + } + + [Fact] + public async Task TypelessAIFunction_JsonNodeValues_AcceptsArguments() + { + var arguments = JsonSerializer.Deserialize>(""" + { + "a": "string", + "b": 123.4, + "c": true, + "d": false, + "e": ["Monday", "Tuesday", "Wednesday"], + "f": null + } + """, TestJsonSerializerContext.Default.Options)!.ToDictionary(k => k.Key, k => (object?)k.Value); + + var result = await NetTypelessAIFunction.Instance.InvokeAsync(arguments); + Assert.Same(result, arguments); + } + + private sealed class CustomType; + + private sealed class NetTypelessAIFunction : AIFunction + { + public static NetTypelessAIFunction Instance { get; } = new NetTypelessAIFunction(); + + public override AIFunctionMetadata Metadata => new("NetTypeless") + { + Description = "AIFunction with parameters that lack .NET types", + Parameters = + [ + new AIFunctionParameterMetadata("a"), + new AIFunctionParameterMetadata("b"), + new AIFunctionParameterMetadata("c"), + new AIFunctionParameterMetadata("d"), + new AIFunctionParameterMetadata("e"), + new AIFunctionParameterMetadata("f"), + ] + }; + + protected override Task InvokeCoreAsync(IEnumerable>? arguments, CancellationToken cancellationToken) => + Task.FromResult(arguments); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionResultContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionResultContentTests.cs new file mode 100644 index 00000000000..a24120ca9a9 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/FunctionResultContentTests.cs @@ -0,0 +1,120 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class FunctionResultContentTests +{ + [Fact] + public void Constructor_PropsDefault() + { + FunctionResultContent c = new("callId1", "functionName"); + Assert.Equal("callId1", c.CallId); + Assert.Equal("functionName", c.Name); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Null(c.Result); + Assert.Null(c.Exception); + } + + [Fact] + public void Constructor_String_PropsRoundtrip() + { + Exception e = new(); + + FunctionResultContent c = new("id", "name", "result", e); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Equal("name", c.Name); + Assert.Equal("id", c.CallId); + Assert.Equal("result", c.Result); + Assert.Same(e, c.Exception); + } + + [Fact] + public void Constructor_FunctionCallContent_PropsRoundtrip() + { + Exception e = new(); + + FunctionResultContent c = new(new FunctionCallContent("id", "name"), "result", e); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Equal("id", c.CallId); + Assert.Equal("result", c.Result); + Assert.Same(e, c.Exception); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + FunctionResultContent c = new("callId1", "functionName"); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + + Assert.Equal("callId1", c.CallId); + c.CallId = "id"; + Assert.Equal("id", c.CallId); + + Assert.Null(c.Result); + c.Result = "result"; + Assert.Equal("result", c.Result); + + Assert.Null(c.Exception); + Exception e = new(); + c.Exception = e; + Assert.Same(e, c.Exception); + } + + [Fact] + public void ItShouldBeSerializableAndDeserializable() + { + // Arrange + var sut = new FunctionResultContent(new FunctionCallContent("id", "p1-f1"), "result"); + + // Act + var json = JsonSerializer.Serialize(sut, TestJsonSerializerContext.Default.Options); + + var deserializedSut = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.NotNull(deserializedSut); + Assert.Equal(sut.Name, deserializedSut.Name); + Assert.Equal(sut.CallId, deserializedSut.CallId); + Assert.Equal(sut.Result, deserializedSut.Result?.ToString()); + } + + [Fact] + public void ItShouldBeSerializableAndDeserializableWithException() + { + // Arrange + var sut = new FunctionResultContent("callId1", "functionName") { Exception = new InvalidOperationException("hello") }; + + // Act + var json = JsonSerializer.Serialize(sut, TestJsonSerializerContext.Default.Options); + var deserializedSut = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.NotNull(deserializedSut); + Assert.IsType(deserializedSut.Exception); + Assert.Contains("hello", deserializedSut.Exception.Message); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/ImageContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/ImageContentTests.cs new file mode 100644 index 00000000000..7b088e3ebf3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/ImageContentTests.cs @@ -0,0 +1,6 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public sealed class ImageContentTests : DataContentTests; diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/TextContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/TextContentTests.cs new file mode 100644 index 00000000000..d1ba5e83bc9 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/TextContentTests.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class TextContentTests +{ + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData("text")] + public void Constructor_String_PropsDefault(string? text) + { + TextContent c = new(text); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + Assert.Equal(text, c.Text); + } + + [Fact] + public void Constructor_PropsRoundtrip() + { + TextContent c = new(null); + + Assert.Null(c.RawRepresentation); + object raw = new(); + c.RawRepresentation = raw; + Assert.Same(raw, c.RawRepresentation); + + Assert.Null(c.ModelId); + c.ModelId = "modelId"; + Assert.Equal("modelId", c.ModelId); + + Assert.Null(c.AdditionalProperties); + AdditionalPropertiesDictionary props = new() { { "key", "value" } }; + c.AdditionalProperties = props; + Assert.Same(props, c.AdditionalProperties); + + Assert.Null(c.Text); + c.Text = "text"; + Assert.Equal("text", c.Text); + Assert.Equal("text", c.ToString()); + + c.Text = null; + Assert.Null(c.Text); + Assert.Equal(string.Empty, c.ToString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UsageContentTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UsageContentTests.cs new file mode 100644 index 00000000000..109bdc8120e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Contents/UsageContentTests.cs @@ -0,0 +1,62 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class UsageContentTests +{ + [Fact] + public void Constructor_InvalidArg_Throws() + { + Assert.Throws("details", () => new UsageContent(null!)); + } + + [Fact] + public void Constructor_Parameterless_PropsDefault() + { + UsageContent c = new(); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.NotNull(c.Details); + Assert.Same(c.Details, c.Details); + Assert.Null(c.Details.InputTokenCount); + Assert.Null(c.Details.OutputTokenCount); + Assert.Null(c.Details.TotalTokenCount); + Assert.Null(c.Details.AdditionalProperties); + } + + [Fact] + public void Constructor_UsageDetails_PropsRoundtrip() + { + UsageDetails details = new(); + + UsageContent c = new(details); + Assert.Null(c.RawRepresentation); + Assert.Null(c.ModelId); + Assert.Null(c.AdditionalProperties); + + Assert.Same(details, c.Details); + + UsageDetails details2 = new(); + c.Details = details2; + Assert.Same(details2, c.Details); + } + + [Fact] + public void Details_SetNull_Throws() + { + UsageContent c = new(); + + UsageDetails d = c.Details; + Assert.NotNull(d); + + Assert.Throws("value", () => c.Details = null!); + + Assert.Same(d, c.Details); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..91640e62f4f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/DelegatingEmbeddingGeneratorTests.cs @@ -0,0 +1,118 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DelegatingEmbeddingGeneratorTests +{ + [Fact] + public void RequiresInnerService() + { + Assert.Throws(() => new NoOpDelegatingEmbeddingGenerator(null!)); + } + + [Fact] + public void MetadataDefaultsToInnerService() + { + using var inner = new TestEmbeddingGenerator(); + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + Assert.Same(inner.Metadata, delegating.Metadata); + } + + [Fact] + public async Task GenerateEmbeddingsDefaultsToInnerServiceAsync() + { + // Arrange + var expectedInput = new List(); + using var cts = new CancellationTokenSource(); + var expectedCancellationToken = cts.Token; + var expectedResult = new TaskCompletionSource>>(); + var expectedEmbedding = new GeneratedEmbeddings>([new(new float[] { 1.0f, 2.0f, 3.0f })]); + using var inner = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (input, options, cancellationToken) => + { + Assert.Same(expectedInput, input); + Assert.Equal(expectedCancellationToken, cancellationToken); + return expectedResult.Task; + } + }; + + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var resultTask = delegating.GenerateAsync(expectedInput, options: null, expectedCancellationToken); + + // Assert + Assert.False(resultTask.IsCompleted); + expectedResult.SetResult(expectedEmbedding); + Assert.True(resultTask.IsCompleted); + Assert.Same(expectedEmbedding, await resultTask); + } + + [Fact] + public void GetServiceReturnsSelfIfCompatibleWithRequestAndKeyIsNull() + { + // Arrange + using var inner = new TestEmbeddingGenerator(); + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var service = delegating.GetService>>(); + + // Assert + Assert.Same(delegating, service); + } + + [Fact] + public void GetServiceDelegatesToInnerIfKeyIsNotNull() + { + // Arrange + var expectedParam = new object(); + var expectedKey = new object(); + using var expectedResult = new TestEmbeddingGenerator(); + using var inner = new TestEmbeddingGenerator + { + GetServiceCallback = (_, _) => expectedResult + }; + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var service = delegating.GetService>>(expectedKey); + + // Assert + Assert.Same(expectedResult, service); + } + + [Fact] + public void GetServiceDelegatesToInnerIfNotCompatibleWithRequest() + { + // Arrange + var expectedParam = new object(); + var expectedResult = TimeZoneInfo.Local; + var expectedKey = new object(); + using var inner = new TestEmbeddingGenerator + { + GetServiceCallback = (type, key) => type == expectedResult.GetType() && key == expectedKey + ? expectedResult + : throw new InvalidOperationException("Unexpected call") + }; + using var delegating = new NoOpDelegatingEmbeddingGenerator(inner); + + // Act + var service = delegating.GetService(expectedKey); + + // Assert + Assert.Same(expectedResult, service); + } + + private sealed class NoOpDelegatingEmbeddingGenerator(IEmbeddingGenerator> innerGenerator) : + DelegatingEmbeddingGenerator>(innerGenerator); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs new file mode 100644 index 00000000000..e9dd45959c7 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGenerationOptionsTests.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGenerationOptionsTests +{ + [Fact] + public void Constructor_Parameterless_PropsDefaulted() + { + EmbeddingGenerationOptions options = new(); + Assert.Null(options.ModelId); + Assert.Null(options.AdditionalProperties); + + EmbeddingGenerationOptions clone = options.Clone(); + Assert.Null(clone.ModelId); + Assert.Null(clone.AdditionalProperties); + } + + [Fact] + public void Properties_Roundtrip() + { + EmbeddingGenerationOptions options = new(); + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.ModelId = "modelId"; + options.AdditionalProperties = additionalProps; + + Assert.Equal("modelId", options.ModelId); + Assert.Same(additionalProps, options.AdditionalProperties); + + EmbeddingGenerationOptions clone = options.Clone(); + Assert.Equal("modelId", clone.ModelId); + Assert.Equal(additionalProps, clone.AdditionalProperties); + } + + [Fact] + public void JsonSerialization_Roundtrips() + { + EmbeddingGenerationOptions options = new(); + + AdditionalPropertiesDictionary additionalProps = new() + { + ["key"] = "value", + }; + + options.ModelId = "model"; + options.AdditionalProperties = additionalProps; + + string json = JsonSerializer.Serialize(options, TestJsonSerializerContext.Default.EmbeddingGenerationOptions); + + EmbeddingGenerationOptions? deserialized = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.EmbeddingGenerationOptions); + Assert.NotNull(deserialized); + + Assert.Equal("model", deserialized.ModelId); + + Assert.NotNull(deserialized.AdditionalProperties); + Assert.Single(deserialized.AdditionalProperties); + Assert.True(deserialized.AdditionalProperties.TryGetValue("key", out object? value)); + Assert.IsType(value); + Assert.Equal("value", ((JsonElement)value!).GetString()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs new file mode 100644 index 00000000000..827ed04c712 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorExtensionsTests.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGeneratorExtensionsTests +{ + [Fact] + public async Task GenerateAsync_InvalidArgs_ThrowsAsync() + { + await Assert.ThrowsAsync("generator", () => ((TestEmbeddingGenerator)null!).GenerateAsync("hello")); + } + + [Fact] + public async Task GenerateAsync_ReturnsSingleEmbeddingAsync() + { + Embedding result = new(new float[] { 1f, 2f, 3f }); + + using TestEmbeddingGenerator service = new() + { + GenerateAsyncCallback = (values, options, cancellationToken) => + Task.FromResult>>([result]) + }; + + Assert.Same(result, (await service.GenerateAsync("hello"))[0]); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs new file mode 100644 index 00000000000..b3cd0d59abb --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingGeneratorMetadataTests.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGeneratorMetadataTests +{ + [Fact] + public void Constructor_NullValues_AllowedAndRoundtrip() + { + EmbeddingGeneratorMetadata metadata = new(null, null, null, null); + Assert.Null(metadata.ProviderName); + Assert.Null(metadata.ProviderUri); + Assert.Null(metadata.ModelId); + Assert.Null(metadata.Dimensions); + } + + [Fact] + public void Constructor_Value_Roundtrips() + { + var uri = new Uri("https://example.com"); + EmbeddingGeneratorMetadata metadata = new("providerName", uri, "theModel", 42); + Assert.Equal("providerName", metadata.ProviderName); + Assert.Same(uri, metadata.ProviderUri); + Assert.Equal("theModel", metadata.ModelId); + Assert.Equal(42, metadata.Dimensions); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs new file mode 100644 index 00000000000..45fcce8ba63 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/EmbeddingTests.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Runtime.InteropServices; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingTests +{ + [Fact] + public void Embedding_Ctor_Roundtrips() + { + float[] floats = [1f, 2f, 3f]; + UsageDetails usage = new(); + AdditionalPropertiesDictionary props = []; + var createdAt = DateTimeOffset.Parse("2022-01-01T00:00:00Z"); + const string Model = "text-embedding-3-small"; + + Embedding e = new(floats) + { + CreatedAt = createdAt, + ModelId = Model, + AdditionalProperties = props, + }; + + Assert.Equal(floats, e.Vector.ToArray()); + Assert.Equal(Model, e.ModelId); + Assert.Same(props, e.AdditionalProperties); + Assert.Equal(createdAt, e.CreatedAt); + + Assert.True(MemoryMarshal.TryGetArray(e.Vector, out ArraySegment array)); + Assert.Same(floats, array.Array); + } + +#if NET + [Fact] + public void Embedding_Half_SerializationRoundtrips() + { + Half[] halfs = [(Half)1f, (Half)2f, (Half)3f]; + Embedding e = new(halfs); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"halves","vector":[1,2,3]}""", json); + + Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } +#endif + + [Fact] + public void Embedding_Single_SerializationRoundtrips() + { + float[] floats = [1f, 2f, 3f]; + Embedding e = new(floats); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"floats","vector":[1,2,3]}""", json); + + Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } + + [Fact] + public void Embedding_Double_SerializationRoundtrips() + { + double[] floats = [1f, 2f, 3f]; + Embedding e = new(floats); + + string json = JsonSerializer.Serialize(e, TestJsonSerializerContext.Default.Embedding); + Assert.Equal("""{"$type":"doubles","vector":[1,2,3]}""", json); + + Embedding result = Assert.IsType>(JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.Embedding)); + Assert.Equal(e.Vector.ToArray(), result.Vector.ToArray()); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs new file mode 100644 index 00000000000..4ebd9465ca8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Embeddings/GeneratedEmbeddingsTests.cs @@ -0,0 +1,246 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using Xunit; + +#pragma warning disable xUnit2013 // Do not use equality check to check for collection size. +#pragma warning disable xUnit2017 // Do not use Contains() to check if a value exists in a collection + +namespace Microsoft.Extensions.AI; + +public class GeneratedEmbeddingsTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("embeddings", () => new GeneratedEmbeddings>(null!)); + Assert.Throws("capacity", () => new GeneratedEmbeddings>(-1)); + } + + [Fact] + public void Ctor_ValidArgs_NoExceptions() + { + GeneratedEmbeddings>[] instances = + [ + [], + new(0), + new(42), + new([]) + ]; + + foreach (var instance in instances) + { + Assert.Empty(instance); + + Assert.False(((ICollection>)instance).IsReadOnly); + Assert.Equal(0, instance.Count); + + Assert.False(instance.Contains(new Embedding(new float[] { 1, 2, 3 }))); + Assert.False(instance.Contains(null!)); + + Assert.Equal(-1, instance.IndexOf(new Embedding(new float[] { 1, 2, 3 }))); + Assert.Equal(-1, instance.IndexOf(null!)); + + instance.CopyTo(Array.Empty>(), 0); + + Assert.Throws(() => instance[0]); + Assert.Throws(() => instance[-1]); + } + } + + [Fact] + public void Ctor_RoundtripsEnumerable() + { + List> embeddings = + [ + new(new float[] { 1, 2, 3 }), + new(new float[] { 4, 5, 6 }), + ]; + + var generatedEmbeddings = new GeneratedEmbeddings>(embeddings); + + Assert.Equal(embeddings, generatedEmbeddings); + Assert.Equal(2, generatedEmbeddings.Count); + + Assert.Same(embeddings[0], generatedEmbeddings[0]); + Assert.Same(embeddings[1], generatedEmbeddings[1]); + + Assert.Equal(0, generatedEmbeddings.IndexOf(embeddings[0])); + Assert.Equal(1, generatedEmbeddings.IndexOf(embeddings[1])); + + Assert.True(generatedEmbeddings.Contains(embeddings[0])); + Assert.True(generatedEmbeddings.Contains(embeddings[1])); + + Assert.False(generatedEmbeddings.Contains(null!)); + Assert.Equal(-1, generatedEmbeddings.IndexOf(null!)); + + Assert.Throws(() => generatedEmbeddings[-1]); + Assert.Throws(() => generatedEmbeddings[2]); + + Assert.True(embeddings.SequenceEqual(generatedEmbeddings)); + + var e = new Embedding(new float[] { 7, 8, 9 }); + generatedEmbeddings.Add(e); + Assert.Equal(3, generatedEmbeddings.Count); + Assert.Same(e, generatedEmbeddings[2]); + } + + [Fact] + public void Properties_Roundtrip() + { + GeneratedEmbeddings> embeddings = []; + + Assert.Null(embeddings.Usage); + + UsageDetails usage = new(); + embeddings.Usage = usage; + Assert.Same(usage, embeddings.Usage); + embeddings.Usage = null; + Assert.Null(embeddings.Usage); + + Assert.Null(embeddings.AdditionalProperties); + AdditionalPropertiesDictionary props = []; + embeddings.AdditionalProperties = props; + Assert.Same(props, embeddings.AdditionalProperties); + embeddings.AdditionalProperties = null; + Assert.Null(embeddings.AdditionalProperties); + } + + [Fact] + public void Add() + { + GeneratedEmbeddings> embeddings = []; + var e = new Embedding(new float[] { 1, 2, 3 }); + + embeddings.Add(e); + Assert.Equal(1, embeddings.Count); + Assert.Same(e, embeddings[0]); + } + + [Fact] + public void AddRange() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + + Assert.Equal(2, embeddings.Count); + Assert.Same(e1, embeddings[0]); + Assert.Same(e2, embeddings[1]); + } + + [Fact] + public void Clear() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + embeddings.Clear(); + Assert.Equal(0, embeddings.Count); + Assert.Empty(embeddings); + } + + [Fact] + public void Remove() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + Assert.True(embeddings.Remove(e1)); + Assert.Equal(1, embeddings.Count); + Assert.Same(e2, embeddings[0]); + + Assert.False(embeddings.Remove(e1)); + Assert.Equal(1, embeddings.Count); + Assert.Same(e2, embeddings[0]); + + Assert.True(embeddings.Remove(e2)); + Assert.Equal(0, embeddings.Count); + } + + [Fact] + public void RemoveAt() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + embeddings.RemoveAt(0); + Assert.Equal(1, embeddings.Count); + Assert.Same(e2, embeddings[0]); + + embeddings.RemoveAt(0); + Assert.Equal(0, embeddings.Count); + } + + [Fact] + public void Insert() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + var e3 = new Embedding(new float[] { 7, 8, 9 }); + embeddings.Insert(1, e3); + Assert.Equal(3, embeddings.Count); + Assert.Same(e3, embeddings[1]); + Assert.Same(e2, embeddings[2]); + } + + [Fact] + public void Indexer() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + var e3 = new Embedding(new float[] { 7, 8, 9 }); + embeddings[1] = e3; + Assert.Equal(2, embeddings.Count); + Assert.Same(e1, embeddings[0]); + Assert.Same(e3, embeddings[1]); + } + + [Fact] + public void Indexer_InvalidIndex_Throws() + { + GeneratedEmbeddings> embeddings = []; + + var e1 = new Embedding(new float[] { 1, 2, 3 }); + var e2 = new Embedding(new float[] { 4, 5, 6 }); + + embeddings.AddRange(new[] { e1, e2 }); + Assert.Equal(2, embeddings.Count); + + Assert.Throws(() => embeddings[-1]); + Assert.Throws(() => embeddings[2]); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionMetadataTests.cs new file mode 100644 index 00000000000..a1aa48bd115 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionMetadataTests.cs @@ -0,0 +1,97 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionMetadataTests +{ + [Fact] + public void Constructor_InvalidArg_Throws() + { + Assert.Throws("name", () => new AIFunctionMetadata((string)null!)); + Assert.Throws("name", () => new AIFunctionMetadata(" \t ")); + Assert.Throws("metadata", () => new AIFunctionMetadata((AIFunctionMetadata)null!)); + } + + [Fact] + public void Constructor_String_PropsDefaulted() + { + AIFunctionMetadata f = new("name"); + Assert.Equal("name", f.Name); + Assert.Empty(f.Description); + Assert.Empty(f.Parameters); + + Assert.NotNull(f.ReturnParameter); + Assert.Null(f.ReturnParameter.Schema); + Assert.Null(f.ReturnParameter.ParameterType); + Assert.Null(f.ReturnParameter.Description); + + Assert.NotNull(f.AdditionalProperties); + Assert.Empty(f.AdditionalProperties); + Assert.Same(f.AdditionalProperties, new AIFunctionMetadata("name2").AdditionalProperties); + } + + [Fact] + public void Constructor_Copy_PropsPropagated() + { + AIFunctionMetadata f1 = new("name") + { + Description = "description", + Parameters = [new AIFunctionParameterMetadata("param")], + ReturnParameter = new AIFunctionReturnParameterMetadata(), + AdditionalProperties = new Dictionary { { "key", "value" } }, + }; + + AIFunctionMetadata f2 = new(f1); + Assert.Equal(f1.Name, f2.Name); + Assert.Equal(f1.Description, f2.Description); + Assert.Same(f1.Parameters, f2.Parameters); + Assert.Same(f1.ReturnParameter, f2.ReturnParameter); + Assert.Same(f1.AdditionalProperties, f2.AdditionalProperties); + } + + [Fact] + public void Props_InvalidArg_Throws() + { + Assert.Throws("value", () => new AIFunctionMetadata("name") { Name = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { Parameters = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { ReturnParameter = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { AdditionalProperties = null! }); + } + + [Fact] + public void Description_NullNormalizedToEmpty() + { + AIFunctionMetadata f = new("name") { Description = null }; + Assert.Equal("", f.Description); + } + + [Fact] + public void GetParameter_EmptyCollection_ReturnsNull() + { + Assert.Null(new AIFunctionMetadata("name").GetParameter("test")); + } + + [Fact] + public void GetParameter_ByName_ReturnsParameter() + { + AIFunctionMetadata f = new("name") + { + Parameters = + [ + new AIFunctionParameterMetadata("param0"), + new AIFunctionParameterMetadata("param1"), + new AIFunctionParameterMetadata("param2"), + ] + }; + + Assert.Same(f.Parameters[0], f.GetParameter("param0")); + Assert.Same(f.Parameters[1], f.GetParameter("param1")); + Assert.Same(f.Parameters[2], f.GetParameter("param2")); + Assert.Null(f.GetParameter("param3")); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionParameterMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionParameterMetadataTests.cs new file mode 100644 index 00000000000..23c33ecf07a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionParameterMetadataTests.cs @@ -0,0 +1,91 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionParameterMetadataTests +{ + [Fact] + public void Constructor_InvalidArg_Throws() + { + Assert.Throws("name", () => new AIFunctionParameterMetadata((string)null!)); + Assert.Throws("name", () => new AIFunctionParameterMetadata(" ")); + Assert.Throws("metadata", () => new AIFunctionParameterMetadata((AIFunctionParameterMetadata)null!)); + } + + [Fact] + public void Constructor_String_PropsDefaulted() + { + AIFunctionParameterMetadata p = new("name"); + Assert.Equal("name", p.Name); + Assert.Null(p.Description); + Assert.Null(p.DefaultValue); + Assert.False(p.IsRequired); + Assert.Null(p.ParameterType); + Assert.Null(p.Schema); + } + + [Fact] + public void Constructor_Copy_PropsPropagated() + { + AIFunctionParameterMetadata p1 = new("name") + { + Description = "description", + HasDefaultValue = true, + DefaultValue = 42, + IsRequired = true, + ParameterType = typeof(int), + Schema = JsonDocument.Parse("""{"type":"integer"}"""), + }; + + AIFunctionParameterMetadata p2 = new(p1); + + Assert.Equal(p1.Name, p2.Name); + Assert.Equal(p1.Description, p2.Description); + Assert.Equal(p1.DefaultValue, p2.DefaultValue); + Assert.Equal(p1.IsRequired, p2.IsRequired); + Assert.Equal(p1.ParameterType, p2.ParameterType); + Assert.Equal(p1.Schema, p2.Schema); + } + + [Fact] + public void Constructor_Copy_PropsPropagatedAndOverwritten() + { + AIFunctionParameterMetadata p1 = new("name") + { + Description = "description", + HasDefaultValue = true, + DefaultValue = 42, + IsRequired = true, + ParameterType = typeof(int), + Schema = JsonDocument.Parse("""{"type":"integer"}"""), + }; + + AIFunctionParameterMetadata p2 = new(p1) + { + Description = "description2", + HasDefaultValue = true, + DefaultValue = 43, + IsRequired = false, + ParameterType = typeof(long), + Schema = JsonDocument.Parse("""{"type":"number"}"""), + }; + + Assert.Equal("description2", p2.Description); + Assert.True(p2.HasDefaultValue); + Assert.Equal(43, p2.DefaultValue); + Assert.False(p2.IsRequired); + Assert.Equal(typeof(long), p2.ParameterType); + } + + [Fact] + public void Props_InvalidArg_Throws() + { + Assert.Throws("value", () => new AIFunctionMetadata("name") { Name = null! }); + Assert.Throws("value", () => new AIFunctionMetadata("name") { Name = "\r\n\t " }); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionReturnParameterMetadataTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionReturnParameterMetadataTests.cs new file mode 100644 index 00000000000..bb5bbeec03a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionReturnParameterMetadataTests.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionReturnParameterMetadataTests +{ + [Fact] + public void Constructor_PropsDefaulted() + { + AIFunctionReturnParameterMetadata p = new(); + Assert.Null(p.Description); + Assert.Null(p.ParameterType); + Assert.Null(p.Schema); + } + + [Fact] + public void Constructor_Copy_PropsPropagated() + { + AIFunctionReturnParameterMetadata p1 = new() + { + Description = "description", + ParameterType = typeof(int), + Schema = JsonDocument.Parse("""{"type":"integer"}"""), + }; + + AIFunctionReturnParameterMetadata p2 = new(p1); + Assert.Equal(p1.Description, p2.Description); + Assert.Equal(p1.ParameterType, p2.ParameterType); + Assert.Equal(p1.Schema, p2.Schema); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs new file mode 100644 index 00000000000..df143e8b97e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Functions/AIFunctionTests.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionTests +{ + [Fact] + public async Task InvokeAsync_UsesDefaultEmptyCollectionForNullArgsAsync() + { + DerivedAIFunction f = new(); + + using CancellationTokenSource cts = new(); + var result1 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!; + + Assert.NotNull(result1.Item1); + Assert.Empty(result1.Item1); + Assert.Equal(cts.Token, result1.Item2); + + var result2 = ((IEnumerable>, CancellationToken))(await f.InvokeAsync(null, cts.Token))!; + Assert.Same(result1.Item1, result2.Item1); + } + + [Fact] + public void ToString_ReturnsName() + { + DerivedAIFunction f = new(); + Assert.Equal("name", f.ToString()); + } + + private sealed class DerivedAIFunction : AIFunction + { + public override AIFunctionMetadata Metadata => new("name"); + + protected override Task InvokeCoreAsync(IEnumerable> arguments, CancellationToken cancellationToken) + { + Assert.NotNull(arguments); + return Task.FromResult((arguments, cancellationToken)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj new file mode 100644 index 00000000000..0d4d5fbfa96 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/Microsoft.Extensions.AI.Abstractions.Tests.csproj @@ -0,0 +1,24 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.Abstractions. + + + + $(NoWarn);CA1063;CA1861;CA2201;VSTHRD003 + true + + + + true + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs new file mode 100644 index 00000000000..55f4c486483 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestChatClient.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +public sealed class TestChatClient : IChatClient +{ + public IServiceProvider? Services { get; set; } + + public ChatClientMetadata Metadata { get; set; } = new(); + + public Func, ChatOptions?, CancellationToken, Task>? CompleteAsyncCallback { get; set; } + + public Func, ChatOptions?, CancellationToken, IAsyncEnumerable>? CompleteStreamingAsyncCallback { get; set; } + + public Func? GetServiceCallback { get; set; } + + public Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => CompleteAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + + public IAsyncEnumerable CompleteStreamingAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + => CompleteStreamingAsyncCallback!.Invoke(chatMessages, options, cancellationToken); + + public TService? GetService(object? key = null) + where TService : class + => (TService?)GetServiceCallback!(typeof(TService), key); + + void IDisposable.Dispose() + { + // No resources need disposing. + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs new file mode 100644 index 00000000000..83680a2be10 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestEmbeddingGenerator.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +public sealed class TestEmbeddingGenerator : IEmbeddingGenerator> +{ + public EmbeddingGeneratorMetadata Metadata { get; } = new(); + + public Func, EmbeddingGenerationOptions?, CancellationToken, Task>>>? GenerateAsyncCallback { get; set; } + + public Func? GetServiceCallback { get; set; } + + public Task>> GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + => GenerateAsyncCallback!.Invoke(values, options, cancellationToken); + + public TService? GetService(object? key = null) + where TService : class + => (TService?)GetServiceCallback!(typeof(TService), key); + + void IDisposable.Dispose() + { + // No resources to dispose + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs new file mode 100644 index 00000000000..5a3e966c17b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/TestJsonSerializerContext.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +[JsonSourceGenerationOptions( + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + UseStringEnumConverter = true)] +[JsonSerializable(typeof(ChatCompletion))] +[JsonSerializable(typeof(StreamingChatCompletionUpdate))] +[JsonSerializable(typeof(ChatOptions))] +[JsonSerializable(typeof(EmbeddingGenerationOptions))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(int[]))] // Used in ChatMessageContentTests +[JsonSerializable(typeof(Embedding))] // Used in EmbeddingTests +[JsonSerializable(typeof(Dictionary))] // Used in Content tests +[JsonSerializable(typeof(Dictionary))] // Used in Content tests +[JsonSerializable(typeof(Dictionary))] // Used in Content tests +[JsonSerializable(typeof(ReadOnlyDictionary))] // Used in Content tests +[JsonSerializable(typeof(DayOfWeek[]))] // Used in Content tests +[JsonSerializable(typeof(Guid))] // Used in Content tests +[JsonSerializable(typeof(decimal))] // Used in Content tests +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientIntegrationTests.cs new file mode 100644 index 00000000000..29aef62fd77 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientIntegrationTests.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading.Tasks; +using Microsoft.TestUtilities; + +namespace Microsoft.Extensions.AI; + +public class AzureAIInferenceChatClientIntegrationTests : ChatClientIntegrationTests +{ + protected override IChatClient? CreateChatClient() => + IntegrationTestHelpers.GetChatCompletionsClient() + ?.AsChatClient(Environment.GetEnvironmentVariable("AZURE_AI_INFERENCE_CHAT_MODEL") ?? "gpt-4o-mini"); + + public override Task CompleteStreamingAsync_UsageDataAvailable() => + throw new SkipTestException("Azure.AI.Inference library doesn't currently surface streaming usage data."); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs new file mode 100644 index 00000000000..fd4bd11a96f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/AzureAIInferenceChatClientTests.cs @@ -0,0 +1,536 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Azure; +using Azure.AI.Inference; +using Azure.Core.Pipeline; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class AzureAIInferenceChatClientTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("chatCompletionsClient", () => new AzureAIInferenceChatClient(null!, "model")); + + ChatCompletionsClient client = new(new("http://somewhere"), new AzureKeyCredential("key")); + Assert.Throws("modelId", () => new AzureAIInferenceChatClient(client, " ")); + } + + [Fact] + public void AsChatClient_InvalidArgs_Throws() + { + Assert.Throws("chatCompletionsClient", () => ((ChatCompletionsClient)null!).AsChatClient("model")); + + ChatCompletionsClient client = new(new("http://somewhere"), new AzureKeyCredential("key")); + Assert.Throws("modelId", () => client.AsChatClient(" ")); + } + + [Fact] + public void AsChatClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + ChatCompletionsClient client = new(endpoint, new AzureKeyCredential("key")); + + IChatClient chatClient = client.AsChatClient(model); + Assert.Equal("AzureAIInference", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public void GetService_SuccessfullyReturnsUnderlyingClient() + { + ChatCompletionsClient client = new(new("http://localhost"), new AzureKeyCredential("key")); + IChatClient chatClient = client.AsChatClient("model"); + + Assert.Same(chatClient, chatClient.GetService()); + Assert.Same(chatClient, chatClient.GetService()); + + Assert.Same(client, chatClient.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(chatClient); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(client, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public async Task BasicRequestResponse_NonStreaming() + { + const string Input = """ + {"messages":[{"content":[{"text":"hello","type":"text"}],"role":"user"}],"max_tokens":10,"temperature":0.5,"model":"gpt-4o-mini"} + """; + + const string Output = """ + { + "id": "chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", + "object": "chat.completion", + "created": 1727888631, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 8, + "completion_tokens": 9, + "total_tokens": 17, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("hello", new() + { + MaxOutputTokens = 10, + Temperature = 0.5f, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.CompletionId); + Assert.Equal("Hello! How can I assist you today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(8, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.OutputTokenCount); + Assert.Equal(17, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task BasicRequestResponse_Streaming() + { + const string Input = """ + {"messages":[{"content":[{"text":"hello","type":"text"}],"role":"user"}],"max_tokens":20,"temperature":0.5,"stream":true,"model":"gpt-4o-mini"} + """; + + const string Output = """ + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":8,"completion_tokens":9,"total_tokens":17,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("hello", new() + { + MaxOutputTokens = 20, + Temperature = 0.5f, + })) + { + updates.Add(update); + } + + Assert.Equal("Hello! How can I assist you today?", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_889_370); + Assert.Equal(12, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.Equal(i < 10 ? 1 : 0, updates[i].Contents.Count); + Assert.Equal(i < 10 ? null : ChatFinishReason.Stop, updates[i].FinishReason); + } + } + + [Fact] + public async Task MultipleMessages_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "content": "You are a really nice friend.", + "role": "system" + }, + { + "content": [ + { + "text": "hello!", + "type": "text" + } + ], + "role": "user" + }, + { + "content": "hi, how are you?", + "role": "assistant" + }, + { + "content": [ + { + "text": "i\u0027m good. how are you?", + "type": "text" + } + ], + "role": "user" + } + ], + "temperature": 0.25, + "stop": [ + "great" + ], + "presence_penalty": 0.5, + "frequency_penalty": 0.75, + "model": "gpt-4o-mini", + "seed": 42 + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", + "object": "chat.completion", + "created": 1727894187, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I’m doing well, thank you! What’s on your mind today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 42, + "completion_tokens": 15, + "total_tokens": 57, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List messages = + [ + new(ChatRole.System, "You are a really nice friend."), + new(ChatRole.User, "hello!"), + new(ChatRole.Assistant, "hi, how are you?"), + new(ChatRole.User, "i'm good. how are you?"), + ]; + + var response = await client.CompleteAsync(messages, new() + { + Temperature = 0.25f, + FrequencyPenalty = 0.75f, + PresencePenalty = 0.5f, + StopSequences = ["great"], + AdditionalProperties = new() { ["seed"] = 42L }, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.CompletionId); + Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(42, response.Usage.InputTokenCount); + Assert.Equal(15, response.Usage.OutputTokenCount); + Assert.Equal(57, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task FunctionCallContent_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "content": [ + { + "text": "How old is Alice?", + "type": "text" + } + ], + "role": "user" + } + ], + "model": "gpt-4o-mini", + "tools": [ + { + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "required": ["personName"], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + } + }, + "type": "function" + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADydKhrSKEBWJ8gy0KCIU74rN3Hmk", + "object": "chat.completion", + "created": 1727894702, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_8qbINM045wlmKZt9bVJgwAym", + "type": "function", + "function": { + "name": "GetPersonAge", + "arguments": "{\"personName\":\"Alice\"}" + } + } + ], + "refusal": null + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 61, + "completion_tokens": 16, + "total_tokens": 77, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Null(response.Message.Text); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); + Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(61, response.Usage.InputTokenCount); + Assert.Equal(16, response.Usage.OutputTokenCount); + Assert.Equal(77, response.Usage.TotalTokenCount); + + Assert.Single(response.Choices); + Assert.Single(response.Message.Contents); + FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + } + + [Fact] + public async Task FunctionCallContent_Streaming() + { + const string Input = """ + { + "messages": [ + { + "content": [ + { + "text": "How old is Alice?", + "type": "text" + } + ], + "role": "user" + } + ], + "stream": true, + "model": "gpt-4o-mini", + "tools": [ + { + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "required": ["personName"], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + } + }, + "type": "function" + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_F9ZaqPWo69u0urxAhVt8meDW","type":"function","function":{"name":"GetPersonAge","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"person"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Alice"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":61,"completion_tokens":16,"total_tokens":77,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + })) + { + updates.Add(update); + } + + Assert.Equal("", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_895_263); + Assert.Equal(10, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.Equal(i < 7 ? null : ChatFinishReason.ToolCalls, updates[i].FinishReason); + } + + FunctionCallContent fcc = Assert.IsType(Assert.Single(updates[updates.Count - 1].Contents)); + Assert.Equal("call_F9ZaqPWo69u0urxAhVt8meDW", fcc.CallId); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + } + + private static IChatClient CreateChatClient(HttpClient httpClient, string modelId) => + new ChatCompletionsClient( + new("http://somewhere"), + new AzureKeyCredential("key"), + new ChatCompletionsClientOptions { Transport = new HttpClientTransport(httpClient) }) + .AsChatClient(modelId); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs new file mode 100644 index 00000000000..4c4086e1157 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/IntegrationTestHelpers.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Azure; +using Azure.AI.Inference; + +namespace Microsoft.Extensions.AI; + +/// Shared utility methods for integration tests. +internal static class IntegrationTestHelpers +{ + /// Gets an to use for testing, or null if the associated tests should be disabled. + public static ChatCompletionsClient? GetChatCompletionsClient() + { + string? apiKey = + Environment.GetEnvironmentVariable("AZURE_AI_INFERENCE_APIKEY") ?? + Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + + if (apiKey is not null) + { + string? endpoint = + Environment.GetEnvironmentVariable("AZURE_AI_INFERENCE_ENDPOINT") ?? + "https://api.openai.com/v1"; + + return new(new Uri(endpoint), new AzureKeyCredential(apiKey)); + } + + return null; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/Microsoft.Extensions.AI.AzureAIInference.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/Microsoft.Extensions.AI.AzureAIInference.Tests.csproj new file mode 100644 index 00000000000..d992413109b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.AzureAIInference.Tests/Microsoft.Extensions.AI.AzureAIInference.Tests.csproj @@ -0,0 +1,22 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.AzureAIInference + + + + true + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs new file mode 100644 index 00000000000..f538d1476b0 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/BinaryEmbedding.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +internal sealed class BinaryEmbedding : Embedding +{ + public BinaryEmbedding(ReadOnlyMemory bits) + { + Bits = bits; + } + + public ReadOnlyMemory Bits { get; } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs new file mode 100644 index 00000000000..c2aaa0d086d --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingChatClient.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable SA1402 // File may only contain a single type + +namespace Microsoft.Extensions.AI; + +internal sealed class CallCountingChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient) +{ + private int _callCount; + + public int CallCount => _callCount; + + public override Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _callCount); + return base.CompleteAsync(chatMessages, options, cancellationToken); + } + + public override IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _callCount); + return base.CompleteStreamingAsync(chatMessages, options, cancellationToken); + } +} + +internal static class CallCountingChatClientBuilderExtensions +{ + public static ChatClientBuilder UseCallCounting(this ChatClientBuilder builder) => + builder.Use(innerClient => new CallCountingChatClient(innerClient)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingEmbeddingGenerator.cs new file mode 100644 index 00000000000..2930f94b6db --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/CallCountingEmbeddingGenerator.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1204 // Static elements should appear before instance elements +#pragma warning disable SA1402 // File may only contain a single type + +namespace Microsoft.Extensions.AI; + +internal sealed class CallCountingEmbeddingGenerator(IEmbeddingGenerator> innerGenerator) + : DelegatingEmbeddingGenerator>(innerGenerator) +{ + private int _callCount; + + public int CallCount => _callCount; + + public override Task>> GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default) + { + Interlocked.Increment(ref _callCount); + return base.GenerateAsync(values, options, cancellationToken); + } +} + +internal static class CallCountingEmbeddingGeneratorBuilderExtensions +{ + public static EmbeddingGeneratorBuilder> UseCallCounting( + this EmbeddingGeneratorBuilder> builder) => + builder.Use(innerGenerator => new CallCountingEmbeddingGenerator(innerGenerator)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs new file mode 100644 index 00000000000..50257544430 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -0,0 +1,650 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.TestUtilities; +using OpenTelemetry.Trace; +using Xunit; + +#pragma warning disable CA2000 // Dispose objects before losing scope +#pragma warning disable CA2214 // Do not call overridable methods in constructors + +namespace Microsoft.Extensions.AI; + +public abstract class ChatClientIntegrationTests : IDisposable +{ + private readonly IChatClient? _chatClient; + + protected ChatClientIntegrationTests() + { + _chatClient = CreateChatClient(); + } + + public void Dispose() + { + _chatClient?.Dispose(); + GC.SuppressFinalize(this); + } + + protected abstract IChatClient? CreateChatClient(); + + [ConditionalFact] + public virtual async Task CompleteAsync_SingleRequestMessage() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync("What's the biggest animal?"); + + Assert.Contains("whale", response.Message.Text, StringComparison.OrdinalIgnoreCase); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_MultipleRequestMessages() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync( + [ + new(ChatRole.User, "Pick a city, any city"), + new(ChatRole.Assistant, "Seattle"), + new(ChatRole.User, "And another one"), + new(ChatRole.Assistant, "Jakarta"), + new(ChatRole.User, "What continent are they each in?"), + ]); + + Assert.Single(response.Choices); + Assert.Contains("America", response.Message.Text); + Assert.Contains("Asia", response.Message.Text); + } + + [ConditionalFact] + public virtual async Task CompleteStreamingAsync_SingleStreamingResponseChoice() + { + SkipIfNotEnabled(); + + IList chatHistory = + [ + new(ChatRole.User, "Quote, word for word, Neil Armstrong's famous words.") + ]; + + StringBuilder sb = new(); + await foreach (var chunk in _chatClient.CompleteStreamingAsync(chatHistory)) + { + sb.Append(chunk.Text); + } + + string responseText = sb.ToString(); + Assert.Contains("one small step", responseText, StringComparison.OrdinalIgnoreCase); + Assert.Contains("one giant leap", responseText, StringComparison.OrdinalIgnoreCase); + + // The input list is left unaugmented. + Assert.Single(chatHistory); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_UsageDataAvailable() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync("Explain in 10 words how AI works"); + + Assert.Single(response.Choices); + Assert.True(response.Usage?.InputTokenCount > 1); + Assert.True(response.Usage?.OutputTokenCount > 1); + Assert.Equal(response.Usage?.InputTokenCount + response.Usage?.OutputTokenCount, response.Usage?.TotalTokenCount); + } + + [ConditionalFact] + public virtual async Task CompleteStreamingAsync_UsageDataAvailable() + { + SkipIfNotEnabled(); + + var response = _chatClient.CompleteStreamingAsync("Explain in 10 words how AI works"); + + List chunks = []; + await foreach (var chunk in response) + { + chunks.Add(chunk); + } + + Assert.True(chunks.Count > 1); + + UsageContent usage = chunks.SelectMany(c => c.Contents).OfType().Single(); + Assert.True(usage.Details.InputTokenCount > 1); + Assert.True(usage.Details.OutputTokenCount > 1); + Assert.Equal(usage.Details.InputTokenCount + usage.Details.OutputTokenCount, usage.Details.TotalTokenCount); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_Parameterless() + { + SkipIfNotEnabled(); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + int secretNumber = 42; + + var response = await chatClient.CompleteAsync("What is the current secret number?", new() + { + Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] + }); + + Assert.Single(response.Choices); + Assert.Contains(secretNumber.ToString(), response.Message.Text); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_WithParameters_NonStreaming() + { + SkipIfNotEnabled(); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + var response = await chatClient.CompleteAsync("What is the result of SecretComputation on 42 and 84?", new() + { + Tools = [AIFunctionFactory.Create((int a, int b) => a * b, "SecretComputation")] + }); + + Assert.Single(response.Choices); + Assert.Contains("3528", response.Message.Text); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_AutomaticallyInvokeFunction_WithParameters_Streaming() + { + SkipIfNotEnabled(); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + var response = chatClient.CompleteStreamingAsync("What is the result of SecretComputation on 42 and 84?", new() + { + Tools = [AIFunctionFactory.Create((int a, int b) => a * b, "SecretComputation")] + }); + + StringBuilder sb = new(); + await foreach (var chunk in response) + { + sb.Append(chunk.Text); + } + + Assert.Contains("3528", sb.ToString()); + } + + protected virtual bool SupportsParallelFunctionCalling => true; + + [ConditionalFact] + public virtual async Task FunctionInvocation_SupportsMultipleParallelRequests() + { + SkipIfNotEnabled(); + if (!SupportsParallelFunctionCalling) + { + throw new SkipTestException("Parallel function calling is not supported by this chat client"); + } + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + // The service/model isn't guaranteed to request two calls to GetPersonAge in the same turn, but it's common that it will. + var response = await chatClient.CompleteAsync("How much older is Elsa than Anna? Return the age difference as a single number.", new() + { + Tools = [AIFunctionFactory.Create((string personName) => + { + return personName switch + { + "Elsa" => 21, + "Anna" => 18, + _ => 30, + }; + }, "GetPersonAge")] + }); + + Assert.True( + Regex.IsMatch(response.Message.Text ?? "", @"\b(3|three)\b", RegexOptions.IgnoreCase), + $"Doesn't contain three: {response.Message.Text}"); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_RequireAny() + { + SkipIfNotEnabled(); + + int callCount = 0; + var tool = AIFunctionFactory.Create(() => + { + callCount++; + return 123; + }, "GetSecretNumber"); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + var response = await chatClient.CompleteAsync("Are birds real?", new() + { + Tools = [tool], + ToolMode = ChatToolMode.RequireAny, + }); + + Assert.Single(response.Choices); + Assert.True(callCount >= 1); + } + + [ConditionalFact] + public virtual async Task FunctionInvocation_RequireSpecific() + { + SkipIfNotEnabled(); + + bool shieldsUp = false; + var getSecretNumberTool = AIFunctionFactory.Create(() => 123, "GetSecretNumber"); + var shieldsUpTool = AIFunctionFactory.Create(() => shieldsUp = true, "ShieldsUp"); + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + + // Even though the user doesn't ask for the shields to be activated, verify that the tool is invoked + var response = await chatClient.CompleteAsync("What's the current secret number?", new() + { + Tools = [getSecretNumberTool, shieldsUpTool], + ToolMode = ChatToolMode.RequireSpecific(shieldsUpTool.Metadata.Name), + }); + + Assert.True(shieldsUp); + } + + [ConditionalFact] + public virtual async Task Caching_OutputVariesWithoutCaching() + { + SkipIfNotEnabled(); + + var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); + var firstResponse = await _chatClient.CompleteAsync([message]); + Assert.Single(firstResponse.Choices); + + var secondResponse = await _chatClient.CompleteAsync([message]); + Assert.NotEqual(firstResponse.Message.Text, secondResponse.Message.Text); + } + + [ConditionalFact] + public virtual async Task Caching_SamePromptResultsInCacheHit_NonStreaming() + { + SkipIfNotEnabled(); + + using var chatClient = new DistributedCachingChatClient( + _chatClient, + new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))); + + var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); + var firstResponse = await chatClient.CompleteAsync([message]); + Assert.Single(firstResponse.Choices); + + // No matter what it said before, we should see identical output due to caching + for (int i = 0; i < 3; i++) + { + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Equal(firstResponse.Message.Text, secondResponse.Message.Text); + } + + // ... but if the conversation differs, we should see different output + message.Text += "!"; + var thirdResponse = await chatClient.CompleteAsync([message]); + Assert.NotEqual(firstResponse.Message.Text, thirdResponse.Message.Text); + } + + [ConditionalFact] + public virtual async Task Caching_SamePromptResultsInCacheHit_Streaming() + { + SkipIfNotEnabled(); + + using var chatClient = new DistributedCachingChatClient( + _chatClient, + new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))); + + var message = new ChatMessage(ChatRole.User, "Pick a random number, uniformly distributed between 1 and 1000000"); + StringBuilder orig = new(); + await foreach (var update in chatClient.CompleteStreamingAsync([message])) + { + orig.Append(update.Text); + } + + // No matter what it said before, we should see identical output due to caching + for (int i = 0; i < 3; i++) + { + StringBuilder second = new(); + await foreach (var update in chatClient.CompleteStreamingAsync([message])) + { + second.Append(update.Text); + } + + Assert.Equal(orig.ToString(), second.ToString()); + } + + // ... but if the conversation differs, we should see different output + message.Text += "!"; + StringBuilder third = new(); + await foreach (var update in chatClient.CompleteStreamingAsync([message])) + { + third.Append(update.Text); + } + + Assert.NotEqual(orig.ToString(), third.ToString()); + } + + [ConditionalFact] + public virtual async Task Caching_BeforeFunctionInvocation_AvoidsExtraCalls() + { + SkipIfNotEnabled(); + + int functionCallCount = 0; + var getTemperature = AIFunctionFactory.Create([Description("Gets the current temperature")] () => + { + functionCallCount++; + return $"{100 + functionCallCount} degrees celsius"; + }, "GetTemperature"); + + // First call executes the function and calls the LLM + using var chatClient = new ChatClientBuilder() + .UseChatOptions(_ => new() { Tools = [getTemperature] }) + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseFunctionInvocation() + .UseCallCounting() + .Use(CreateChatClient()!); + + var llmCallCount = chatClient.GetService(); + var message = new ChatMessage(ChatRole.User, "What is the temperature?"); + var response = await chatClient.CompleteAsync([message]); + Assert.Contains("101", response.Message.Text); + + // First LLM call tells us to call the function, second deals with the result + Assert.Equal(2, llmCallCount!.CallCount); + + // Second call doesn't execute the function or call the LLM, but rather just returns the cached result + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(1, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + } + + [ConditionalFact] + public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputUnchangedAsync() + { + SkipIfNotEnabled(); + + // This means that if the function call produces the same result, we can avoid calling the LLM + // whereas if the function call produces a different result, we do call the LLM + + var functionCallCount = 0; + var getTemperature = AIFunctionFactory.Create([Description("Gets the current temperature")] () => + { + functionCallCount++; + return "58 degrees celsius"; + }, "GetTemperature"); + + // First call executes the function and calls the LLM + using var chatClient = new ChatClientBuilder() + .UseChatOptions(_ => new() { Tools = [getTemperature] }) + .UseFunctionInvocation() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseCallCounting() + .Use(CreateChatClient()!); + + var llmCallCount = chatClient.GetService(); + var message = new ChatMessage(ChatRole.User, "What is the temperature?"); + var response = await chatClient.CompleteAsync([message]); + Assert.Contains("58", response.Message.Text); + + // First LLM call tells us to call the function, second deals with the result + Assert.Equal(1, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + + // Second time, the calls to the LLM don't happen, but the function is called again + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Equal(response.Message.Text, secondResponse.Message.Text); + Assert.Equal(2, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + } + + [ConditionalFact] + public virtual async Task Caching_AfterFunctionInvocation_FunctionOutputChangedAsync() + { + SkipIfNotEnabled(); + + // This means that if the function call produces the same result, we can avoid calling the LLM + // whereas if the function call produces a different result, we do call the LLM + + var functionCallCount = 0; + var getTemperature = AIFunctionFactory.Create([Description("Gets the current temperature")] () => + { + functionCallCount++; + return $"{80 + functionCallCount} degrees celsius"; + }, "GetTemperature"); + + // First call executes the function and calls the LLM + using var chatClient = new ChatClientBuilder() + .UseChatOptions(_ => new() { Tools = [getTemperature] }) + .UseFunctionInvocation() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseCallCounting() + .Use(CreateChatClient()!); + + var llmCallCount = chatClient.GetService(); + var message = new ChatMessage(ChatRole.User, "What is the temperature?"); + var response = await chatClient.CompleteAsync([message]); + Assert.Contains("81", response.Message.Text); + + // First LLM call tells us to call the function, second deals with the result + Assert.Equal(1, functionCallCount); + Assert.Equal(2, llmCallCount!.CallCount); + + // Second time, the first call to the LLM don't happen, but the function is called again, + // and since its output now differs, we no longer hit the cache so the second LLM call does happen + var secondResponse = await chatClient.CompleteAsync([message]); + Assert.Contains("82", secondResponse.Message.Text); + Assert.Equal(2, functionCallCount); + Assert.Equal(3, llmCallCount!.CallCount); + } + + [ConditionalFact] + public virtual async Task Logging_LogsCalls_NonStreaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new LoggingChatClient(CreateChatClient()!, logger); + + await chatClient.CompleteAsync([new(ChatRole.User, "What's the biggest animal?")]); + + Assert.Collection(logger.Entries, + entry => Assert.Contains("What\\u0027s the biggest animal?", entry.Message), + entry => Assert.Contains("whale", entry.Message)); + } + + [ConditionalFact] + public virtual async Task Logging_LogsCalls_Streaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new LoggingChatClient(CreateChatClient()!, logger); + + await foreach (var update in chatClient.CompleteStreamingAsync("What's the biggest animal?")) + { + // Do nothing with the updates + } + + Assert.Contains(logger.Entries, e => e.Message.Contains("What\\u0027s the biggest animal?")); + Assert.Contains(logger.Entries, e => e.Message.Contains("whale")); + } + + [ConditionalFact] + public virtual async Task Logging_LogsFunctionCalls_NonStreaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new FunctionInvokingChatClient( + new LoggingChatClient(CreateChatClient()!, logger)); + + int secretNumber = 42; + await chatClient.CompleteAsync( + "What is the current secret number?", + new ChatOptions { Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] }); + + Assert.Collection(logger.Entries, + entry => Assert.Contains("What is the current secret number?", entry.Message), + entry => Assert.Contains("\"name\":\"GetSecretNumber\"", entry.Message), + entry => Assert.Contains($"\"result\":{secretNumber}", entry.Message), + entry => Assert.Contains(secretNumber.ToString(), entry.Message)); + } + + [ConditionalFact] + public virtual async Task Logging_LogsFunctionCalls_Streaming() + { + SkipIfNotEnabled(); + + CapturingLogger logger = new(); + + using var chatClient = + new FunctionInvokingChatClient( + new LoggingChatClient(CreateChatClient()!, logger)); + + int secretNumber = 42; + await foreach (var update in chatClient.CompleteStreamingAsync( + "What is the current secret number?", + new ChatOptions { Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")] })) + { + // Do nothing with the updates + } + + Assert.Contains(logger.Entries, e => e.Message.Contains("What is the current secret number?")); + Assert.Contains(logger.Entries, e => e.Message.Contains("\"name\":\"GetSecretNumber\"")); + Assert.Contains(logger.Entries, e => e.Message.Contains($"\"result\":{secretNumber}")); + } + + [ConditionalFact] + public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics() + { + SkipIfNotEnabled(); + + var sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + var chatClient = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, instance => { instance.EnableSensitiveData = true; }) + .Use(CreateChatClient()!); + + var response = await chatClient.CompleteAsync([new(ChatRole.User, "What's the biggest animal?")]); + + var activity = Assert.Single(activities); + Assert.StartsWith("chat.completions", activity.DisplayName); + Assert.StartsWith("http", (string)activity.GetTagItem("server.address")!); + Assert.Equal(chatClient.Metadata.ProviderUri?.Port, (int)activity.GetTagItem("server.port")!); + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + Assert.NotEqual(0, (int)activity.GetTagItem("gen_ai.response.input_tokens")!); + Assert.NotEqual(0, (int)activity.GetTagItem("gen_ai.response.output_tokens")!); + + Assert.Collection(activity.Events, + evt => + { + Assert.Equal("gen_ai.content.prompt", evt.Name); + Assert.Equal("""[{"role": "user", "content": "What\u0027s the biggest animal?"}]""", evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.prompt").Value); + }, + evt => + { + Assert.Equal("gen_ai.content.completion", evt.Name); + Assert.Contains("whale", (string)evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.completion").Value!); + }); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutput() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + Who is described in the following sentence? + Jimbo Smith is a 35-year-old software developer from Cardiff, Wales. + """); + + Assert.Equal("Jimbo Smith", response.Result.FullName); + Assert.Equal(35, response.Result.AgeInYears); + Assert.Contains("Cardiff", response.Result.HomeTown); + Assert.Equal(JobType.Programmer, response.Result.Job); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutput_WithFunctions() + { + SkipIfNotEnabled(); + + var expectedPerson = new Person + { + FullName = "Jimbo Smith", + AgeInYears = 35, + HomeTown = "Cardiff", + Job = JobType.Programmer, + }; + + using var chatClient = new FunctionInvokingChatClient(_chatClient); + var response = await chatClient.CompleteAsync( + "Who is person with ID 123?", new ChatOptions + { + Tools = [AIFunctionFactory.Create((int personId) => + { + Assert.Equal(123, personId); + return expectedPerson; + }, "GetPersonById")] + }); + + Assert.NotSame(expectedPerson, response.Result); + Assert.Equal(expectedPerson.FullName, response.Result.FullName); + Assert.Equal(expectedPerson.AgeInYears, response.Result.AgeInYears); + Assert.Equal(expectedPerson.HomeTown, response.Result.HomeTown); + Assert.Equal(expectedPerson.Job, response.Result.Job); + } + + private class Person + { +#pragma warning disable S1144, S3459 // Unassigned members should be removed + public string? FullName { get; set; } + public int AgeInYears { get; set; } + public string? HomeTown { get; set; } + public JobType Job { get; set; } +#pragma warning restore S1144, S3459 // Unused private types or members should be removed + } + + private enum JobType + { + Surgeon, + PopStar, + Programmer, + Unknown, + } + + [MemberNotNull(nameof(_chatClient))] + protected void SkipIfNotEnabled() + { + if (_chatClient is null) + { + throw new SkipTestException("Client is not enabled."); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs new file mode 100644 index 00000000000..252427836e8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/EmbeddingGeneratorIntegrationTests.cs @@ -0,0 +1,215 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +#if NET +using System.Numerics.Tensors; +#endif +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.TestUtilities; +using OpenTelemetry.Trace; +using Xunit; + +#pragma warning disable CA2214 // Do not call overridable methods in constructors +#pragma warning disable S3967 // Multidimensional arrays should not be used + +namespace Microsoft.Extensions.AI; + +public abstract class EmbeddingGeneratorIntegrationTests : IDisposable +{ + private readonly IEmbeddingGenerator>? _embeddingGenerator; + + protected EmbeddingGeneratorIntegrationTests() + { + _embeddingGenerator = CreateEmbeddingGenerator(); + } + + public void Dispose() + { + _embeddingGenerator?.Dispose(); + GC.SuppressFinalize(this); + } + + protected abstract IEmbeddingGenerator>? CreateEmbeddingGenerator(); + + [ConditionalFact] + public virtual async Task GenerateEmbedding_CreatesEmbeddingSuccessfully() + { + SkipIfNotEnabled(); + + var embeddings = await _embeddingGenerator.GenerateAsync("Using AI with .NET"); + + Assert.NotNull(embeddings.Usage); + Assert.NotNull(embeddings.Usage.InputTokenCount); + Assert.NotNull(embeddings.Usage.TotalTokenCount); + Assert.Single(embeddings); + Assert.Equal(_embeddingGenerator.Metadata.ModelId, embeddings[0].ModelId); + Assert.NotEmpty(embeddings[0].Vector.ToArray()); + } + + [ConditionalFact] + public virtual async Task GenerateEmbeddings_CreatesEmbeddingsSuccessfully() + { + SkipIfNotEnabled(); + + var embeddings = await _embeddingGenerator.GenerateAsync([ + "Red", + "White", + "Blue", + ]); + + Assert.Equal(3, embeddings.Count); + Assert.NotNull(embeddings.Usage); + Assert.NotNull(embeddings.Usage.InputTokenCount); + Assert.NotNull(embeddings.Usage.TotalTokenCount); + Assert.All(embeddings, embedding => + { + Assert.Equal(_embeddingGenerator.Metadata.ModelId, embedding.ModelId); + Assert.NotEmpty(embedding.Vector.ToArray()); + }); + } + + [ConditionalFact] + public virtual async Task Caching_SameOutputsForSameInput() + { + SkipIfNotEnabled(); + + using var generator = new EmbeddingGeneratorBuilder>() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .UseCallCounting() + .Use(CreateEmbeddingGenerator()!); + + string input = "Red, White, and Blue"; + var embedding1 = await generator.GenerateAsync(input); + var embedding2 = await generator.GenerateAsync(input); + var embedding3 = await generator.GenerateAsync(input + "... and Green"); + var embedding4 = await generator.GenerateAsync(input); + + var callCounter = generator.GetService(); + Assert.NotNull(callCounter); + + Assert.Equal(2, callCounter.CallCount); + } + + [ConditionalFact] + public virtual async Task OpenTelemetry_CanEmitTracesAndMetrics() + { + SkipIfNotEnabled(); + + string sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + var embeddingGenerator = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry(sourceName) + .Use(CreateEmbeddingGenerator()!); + + _ = await embeddingGenerator.GenerateAsync("Hello, world!"); + + Assert.Single(activities); + var activity = activities.Single(); + Assert.StartsWith("embedding", activity.DisplayName); + Assert.StartsWith("http", (string)activity.GetTagItem("server.address")!); + Assert.Equal(embeddingGenerator.Metadata.ProviderUri?.Port, (int)activity.GetTagItem("server.port")!); + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + Assert.NotEqual(0, (int)activity.GetTagItem("gen_ai.response.input_tokens")!); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } + +#if NET + [ConditionalFact] + public async Task Quantization_Binary_EmbeddingsCompareSuccessfully() + { + SkipIfNotEnabled(); + + using IEmbeddingGenerator generator = + new QuantizationEmbeddingGenerator( + CreateEmbeddingGenerator()!); + + var embeddings = await generator.GenerateAsync(["dog", "cat", "fork", "spoon"]); + Assert.Equal(4, embeddings.Count); + + long[,] distances = new long[embeddings.Count, embeddings.Count]; + for (int i = 0; i < embeddings.Count; i++) + { + for (int j = 0; j < embeddings.Count; j++) + { + distances[i, j] = TensorPrimitives.HammingBitDistance(embeddings[i].Bits.Span, embeddings[j].Bits.Span); + } + } + + for (int i = 0; i < embeddings.Count; i++) + { + Assert.Equal(0, distances[i, i]); + } + + Assert.True(distances[0, 1] < distances[0, 2]); + Assert.True(distances[0, 1] < distances[0, 3]); + Assert.True(distances[0, 1] < distances[1, 2]); + Assert.True(distances[0, 1] < distances[1, 3]); + + Assert.True(distances[2, 3] < distances[0, 2]); + Assert.True(distances[2, 3] < distances[0, 3]); + Assert.True(distances[2, 3] < distances[1, 2]); + Assert.True(distances[2, 3] < distances[1, 3]); + } + + [ConditionalFact] + public async Task Quantization_Half_EmbeddingsCompareSuccessfully() + { + SkipIfNotEnabled(); + + using IEmbeddingGenerator> generator = + new QuantizationEmbeddingGenerator( + CreateEmbeddingGenerator()!); + + var embeddings = await generator.GenerateAsync(["dog", "cat", "fork", "spoon"]); + Assert.Equal(4, embeddings.Count); + + var distances = new Half[embeddings.Count, embeddings.Count]; + for (int i = 0; i < embeddings.Count; i++) + { + for (int j = 0; j < embeddings.Count; j++) + { + distances[i, j] = TensorPrimitives.CosineSimilarity(embeddings[i].Vector.Span, embeddings[j].Vector.Span); + } + } + + for (int i = 0; i < embeddings.Count; i++) + { + Assert.Equal(1.0, (double)distances[i, i], 0.001); + } + + Assert.True(distances[0, 1] > distances[0, 2]); + Assert.True(distances[0, 1] > distances[0, 3]); + Assert.True(distances[0, 1] > distances[1, 2]); + Assert.True(distances[0, 1] > distances[1, 3]); + + Assert.True(distances[2, 3] > distances[0, 2]); + Assert.True(distances[2, 3] > distances[0, 3]); + Assert.True(distances[2, 3] > distances[1, 2]); + Assert.True(distances[2, 3] > distances[1, 3]); + } +#endif + + [MemberNotNull(nameof(_embeddingGenerator))] + protected void SkipIfNotEnabled() + { + if (_embeddingGenerator is null) + { + throw new SkipTestException("Generator is not enabled."); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj new file mode 100644 index 00000000000..e38ccd3268b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/Microsoft.Extensions.AI.Integration.Tests.csproj @@ -0,0 +1,37 @@ + + + Microsoft.Extensions.AI + Opt-in integration tests for Microsoft.Extensions.AI. + + + + $(NoWarn);CA1063;CA1861;SA1130;VSTHRD003 + true + + + + true + true + true + + + + + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs new file mode 100644 index 00000000000..150c984ff86 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/PromptBasedFunctionCallingChatClient.cs @@ -0,0 +1,228 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; +using System.Threading; +using System.Threading.Tasks; + +#pragma warning disable SA1402 // File may only contain a single type +#pragma warning disable S1144 // Unused private types or members should be removed +#pragma warning disable S3459 // Unassigned members should be removed + +namespace Microsoft.Extensions.AI; + +// This isn't a feature we're planning to ship, but demonstrates how custom clients can +// layer in non-trivial functionality. In this case we're able to upgrade non-function-calling models +// to behaving as if they do support function calling. +// +// In practice: +// - For llama3:8b or mistral:7b, this works fairly reliably, at least when it only needs to +// make a single function call with a constrained set of args. +// - For smaller models like phi3:mini, it works only on a more occasional basis (e.g., if there's +// only one function defined, and it takes no arguments, but is very hit-and-miss beyond that). + +internal sealed class PromptBasedFunctionCallingChatClient(IChatClient innerClient) + : DelegatingChatClient(innerClient) +{ + private const string MessageIntro = "You are an AI model with function calling capabilities. Call one or more functions if they are relevant to the user's query."; + + private static readonly JsonSerializerOptions _jsonOptions = new(JsonSerializerDefaults.Web) + { + WriteIndented = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + + public override async Task CompleteAsync(IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + // Our goal is to convert tools into a prompt describing them, then to detect tool calls in the + // response and convert those into FunctionCallContent. + if (options?.Tools is { Count: > 0 }) + { + AddOrUpdateToolPrompt(chatMessages, options.Tools); + options = options.Clone(); + options.Tools = null; + + options.StopSequences ??= []; + if (!options.StopSequences.Contains("")) + { + options.StopSequences.Add(""); + } + + // Since the point of this client is to avoid relying on the underlying model having + // native tool call support, we have to replace any "tool" or "toolcall" messages with + // "user" or "assistant" ones. + foreach (var message in chatMessages) + { + for (var itemIndex = 0; itemIndex < message.Contents.Count; itemIndex++) + { + if (message.Contents[itemIndex] is FunctionResultContent frc) + { + var toolCallResultJson = JsonSerializer.Serialize(new ToolCallResult { Id = frc.CallId, Result = frc.Result }, _jsonOptions); + message.Role = ChatRole.User; + message.Contents[itemIndex] = new TextContent( + $"{toolCallResultJson}"); + } + else if (message.Contents[itemIndex] is FunctionCallContent fcc) + { + var toolCallJson = JsonSerializer.Serialize(new { fcc.CallId, fcc.Name, fcc.Arguments }, _jsonOptions); + message.Role = ChatRole.Assistant; + message.Contents[itemIndex] = new TextContent( + $"{toolCallJson}"); + } + } + } + } + + var result = await base.CompleteAsync(chatMessages, options, cancellationToken); + + if (result.Choices.FirstOrDefault()?.Text is { } content && content.IndexOf("", StringComparison.Ordinal) is int startPos + && startPos >= 0) + { + var message = result.Choices.First(); + var contentItem = message.Contents.SingleOrDefault(); + content = content.Substring(startPos); + + foreach (var toolCallJson in content.Split([""], StringSplitOptions.None)) + { + var toolCall = toolCallJson.Trim(); + if (toolCall.Length == 0) + { + continue; + } + + var endPos = toolCall.IndexOf(" 0) + { + toolCall = toolCall.Substring(0, endPos); + try + { + var toolCallParsed = JsonSerializer.Deserialize(toolCall, _jsonOptions); + if (!string.IsNullOrEmpty(toolCallParsed?.Name)) + { + if (toolCallParsed!.Arguments is not null) + { + ParseArguments(toolCallParsed.Arguments); + } + + var id = Guid.NewGuid().ToString().Substring(0, 6); + message.Contents.Add(new FunctionCallContent(id, toolCallParsed.Name!, toolCallParsed.Arguments is { } args ? new ReadOnlyDictionary(args) : null)); + + if (contentItem is not null) + { + message.Contents.Remove(contentItem); + } + } + } + catch (JsonException) + { + // Ignore invalid tool calls + } + } + } + } + + return result; + } + + private static void ParseArguments(IDictionary arguments) + { + // This is a simple implementation. A more robust answer is to use other schema information given by + // the AIFunction here, as for example is done in OpenAIChatClient. + foreach (var kvp in arguments.ToArray()) + { + if (kvp.Value is JsonElement jsonElement) + { + arguments[kvp.Key] = jsonElement.ValueKind switch + { + JsonValueKind.String => jsonElement.GetString(), + JsonValueKind.Number => jsonElement.GetDouble(), + JsonValueKind.True => true, + JsonValueKind.False => false, + _ => jsonElement.ToString() + }; + } + } + } + + private static void AddOrUpdateToolPrompt(IList chatMessages, IList tools) + { + var existingToolPrompt = chatMessages.FirstOrDefault(c => c.Text?.StartsWith(MessageIntro, StringComparison.Ordinal) is true); + if (existingToolPrompt is null) + { + existingToolPrompt = new ChatMessage(ChatRole.System, (string?)null); + chatMessages.Insert(0, existingToolPrompt); + } + + var toolDescriptorsJson = JsonSerializer.Serialize(tools.OfType().Select(ToToolDescriptor), _jsonOptions); + existingToolPrompt.Text = $$""" + {{MessageIntro}} + + For each function call, return a JSON object with the function name and arguments within XML tags + as follows: + + {"name": "tool_name", "arguments": { argname1: argval1, argname2: argval2, ... } } + + Note that the contents of MUST be a valid JSON object, with no other text. + + Once you receive the result as a JSON object within XML tags, use it to + answer the user's question without repeating the same tool call. + + Here are the available tools: + {{toolDescriptorsJson}} + """; + } + + private static ToolDescriptor ToToolDescriptor(AIFunction tool) => new() + { + Name = tool.Metadata.Name, + Description = tool.Metadata.Description, + Arguments = tool.Metadata.Parameters.ToDictionary( + p => p.Name, + p => new ToolParameterDescriptor + { + Type = p.ParameterType?.Name, + Description = p.Description, + Enum = p.ParameterType?.IsEnum == true ? Enum.GetNames(p.ParameterType) : null, + Required = p.IsRequired, + }), + }; + + private sealed class ToolDescriptor + { + public string? Name { get; set; } + public string? Description { get; set; } + public IDictionary? Arguments { get; set; } + } + + private sealed class ToolParameterDescriptor + { + public string? Type { get; set; } + public string? Description { get; set; } + public bool? Required { get; set; } + public string[]? Enum { get; set; } + } + + private sealed class ToolCall + { + public string? Id { get; set; } + public string? Name { get; set; } + public IDictionary? Arguments { get; set; } + } + + private sealed class ToolCallResult + { + public string? Id { get; set; } + public object? Result { get; set; } + } +} + +public static class PromptBasedFunctionCallingChatClientExtensions +{ + public static ChatClientBuilder UsePromptBasedFunctionCalling(this ChatClientBuilder builder) + => builder.Use(innerClient => new PromptBasedFunctionCallingChatClient(innerClient)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs new file mode 100644 index 00000000000..90032f16434 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/QuantizationEmbeddingGenerator.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +#if NET +using System.Numerics.Tensors; +#endif +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.AI; + +internal sealed class QuantizationEmbeddingGenerator : + IEmbeddingGenerator +#if NET + , IEmbeddingGenerator> +#endif +{ + private readonly IEmbeddingGenerator> _floatService; + + public QuantizationEmbeddingGenerator(IEmbeddingGenerator> floatService) + { + _floatService = floatService; + } + + public EmbeddingGeneratorMetadata Metadata => _floatService.Metadata; + + void IDisposable.Dispose() => _floatService.Dispose(); + + public TService? GetService(object? key = null) + where TService : class => + key is null && this is TService ? (TService?)(object)this : + _floatService.GetService(key); + + async Task> IEmbeddingGenerator.GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken) + { + var embeddings = await _floatService.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + return new(from e in embeddings select QuantizeToBinary(e)) + { + Usage = embeddings.Usage, + AdditionalProperties = embeddings.AdditionalProperties, + }; + } + + private static BinaryEmbedding QuantizeToBinary(Embedding embedding) + { + ReadOnlySpan vector = embedding.Vector.Span; + + var result = new byte[(int)Math.Ceiling(vector.Length / 8.0)]; + for (int i = 0; i < vector.Length; i++) + { + if (vector[i] > 0) + { + result[i / 8] |= (byte)(1 << (i % 8)); + } + } + + return new(result) + { + CreatedAt = embedding.CreatedAt, + ModelId = embedding.ModelId, + AdditionalProperties = embedding.AdditionalProperties, + }; + } + +#if NET + async Task>> IEmbeddingGenerator>.GenerateAsync( + IEnumerable values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken) + { + var embeddings = await _floatService.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false); + return new(from e in embeddings select QuantizeToHalf(e)) + { + Usage = embeddings.Usage, + AdditionalProperties = embeddings.AdditionalProperties, + }; + } + + private static Embedding QuantizeToHalf(Embedding embedding) + { + ReadOnlySpan vector = embedding.Vector.Span; + var result = new Half[vector.Length]; + TensorPrimitives.ConvertToHalf(vector, result); + return new(result) + { + CreatedAt = embedding.CreatedAt, + ModelId = embedding.ModelId, + AdditionalProperties = embedding.AdditionalProperties, + }; + } +#endif +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs new file mode 100644 index 00000000000..0c436f7ccb5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ReducingChatClientTests.cs @@ -0,0 +1,201 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.ML.Tokenizers; +using Microsoft.Shared.Diagnostics; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long +#pragma warning disable SA1402 // File may only contain a single type +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + +namespace Microsoft.Extensions.AI; + +/// Provides an example of a custom for reducing chat message lists. +public class ReducingChatClientTests +{ + private static readonly Tokenizer _gpt4oTokenizer = TiktokenTokenizer.CreateForModel("gpt-4o"); + + [Fact] + public async Task Reduction_LimitsMessagesBasedOnTokenLimit() + { + using var innerClient = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Equal(2, messages.Count); + Assert.Collection(messages, + m => Assert.StartsWith("Golden retrievers are quite active", m.Text, StringComparison.Ordinal), + m => Assert.StartsWith("Are they good with kids?", m.Text, StringComparison.Ordinal)); + return Task.FromResult(new ChatCompletion([])); + } + }; + + using var client = new ChatClientBuilder() + .UseChatReducer(new TokenCountingChatReducer(_gpt4oTokenizer, 40)) + .Use(innerClient); + + List messages = + [ + new ChatMessage(ChatRole.User, "Hi there! Can you tell me about golden retrievers?"), + new ChatMessage(ChatRole.Assistant, "Of course! Golden retrievers are known for their friendly and tolerant attitudes. They're great family pets and are very intelligent and easy to train."), + new ChatMessage(ChatRole.User, "What kind of exercise do they need?"), + new ChatMessage(ChatRole.Assistant, "Golden retrievers are quite active and need regular exercise. Daily walks, playtime, and activities like fetching or swimming are great for them."), + new ChatMessage(ChatRole.User, "Are they good with kids?"), + ]; + + await client.CompleteAsync(messages); + + Assert.Equal(5, messages.Count); + } +} + +/// Provides an example of a chat client for reducing the size of a message list. +public sealed class ReducingChatClient : DelegatingChatClient +{ + private readonly IChatReducer _reducer; + private readonly bool _inPlace; + + /// Initializes a new instance of the class. + /// The inner client. + /// The reducer to be used by this instance. + /// + /// true if the should perform any modifications directly on the supplied list of messages; + /// false if it should instead create a new list when reduction is necessary. + /// + public ReducingChatClient(IChatClient innerClient, IChatReducer reducer, bool inPlace = false) + : base(innerClient) + { + _reducer = Throw.IfNull(reducer); + _inPlace = inPlace; + } + + /// + public override async Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + + return await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable CompleteStreamingAsync( + IList chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + chatMessages = await GetChatMessagesToPropagate(chatMessages, cancellationToken).ConfigureAwait(false); + + await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + { + yield return update; + } + } + + /// Runs the reducer and gets the chat message list to forward to the inner client. + private async Task> GetChatMessagesToPropagate(IList chatMessages, CancellationToken cancellationToken) => + await _reducer.ReduceAsync(chatMessages, _inPlace, cancellationToken).ConfigureAwait(false) ?? + chatMessages; +} + +/// Represents a reducer capable of shrinking the size of a list of chat messages. +public interface IChatReducer +{ + /// Reduces the size of a list of chat messages. + /// The messages. + /// true if the reducer should modify the provided list; false if a new list should be returned. + /// The to monitor for cancellation requests. The default is . + /// The new list of messages, or null if no reduction need be performed or was true. + Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken); +} + +/// Provides extensions for configuring instances. +public static class ReducingChatClientExtensions +{ + public static ChatClientBuilder UseChatReducer(this ChatClientBuilder builder, IChatReducer reducer, bool inPlace = false) + { + _ = Throw.IfNull(builder); + _ = Throw.IfNull(reducer); + + return builder.Use(innerClient => new ReducingChatClient(innerClient, reducer, inPlace)); + } +} + +/// An that culls the oldest messages once a certain token threshold is reached. +public sealed class TokenCountingChatReducer : IChatReducer +{ + private readonly Tokenizer _tokenizer; + private readonly int _tokenLimit; + + public TokenCountingChatReducer(Tokenizer tokenizer, int tokenLimit) + { + _tokenizer = Throw.IfNull(tokenizer); + _tokenLimit = Throw.IfLessThan(tokenLimit, 1); + } + + public async Task?> ReduceAsync(IList chatMessages, bool inPlace, CancellationToken cancellationToken) + { + _ = Throw.IfNull(chatMessages); + + if (chatMessages.Count > 1) + { + int totalCount = CountTokens(chatMessages[chatMessages.Count - 1]); + + if (inPlace) + { + for (int i = chatMessages.Count - 2; i >= 0; i--) + { + totalCount += CountTokens(chatMessages[i]); + if (totalCount > _tokenLimit) + { + if (chatMessages is List list) + { + list.RemoveRange(0, i + 1); + } + else + { + for (int j = i; j >= 0; j--) + { + chatMessages.RemoveAt(j); + } + } + + break; + } + } + } + else + { + for (int i = chatMessages.Count - 2; i >= 0; i--) + { + totalCount += CountTokens(chatMessages[i]); + if (totalCount > _tokenLimit) + { + return chatMessages.Skip(i + 1).ToList(); + } + } + } + } + + return null; + } + + private int CountTokens(ChatMessage message) + { + int sum = 0; + foreach (AIContent content in message.Contents) + { + if ((content as TextContent)?.Text is string text) + { + sum += _tokenizer.CountTokens(text); + } + } + + return sum; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/VerbatimHttpHandler.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/VerbatimHttpHandler.cs new file mode 100644 index 00000000000..14ba68feb7a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/VerbatimHttpHandler.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net.Http; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +/// +/// An that checks the request body against an expected one +/// and sends back an expected response. +/// +public sealed class VerbatimHttpHandler(string expectedInput, string sentOutput) : HttpMessageHandler +{ + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + Assert.NotNull(request.Content); + + string? input = await request.Content +#if NET + .ReadAsStringAsync(cancellationToken).ConfigureAwait(false); +#else + .ReadAsStringAsync().ConfigureAwait(false); +#endif + + Assert.NotNull(input); + Assert.Equal(RemoveWhiteSpace(expectedInput), RemoveWhiteSpace(input)); + + return new() { Content = new StringContent(sentOutput) }; + } + + public static string? RemoveWhiteSpace(string? text) => + text is null ? null : + Regex.Replace(text, @"\s*", string.Empty); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/IntegrationTestHelpers.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/IntegrationTestHelpers.cs new file mode 100644 index 00000000000..d25d750ce37 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/IntegrationTestHelpers.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +#pragma warning disable S125 // Sections of code should not be commented out + +/// Shared utility methods for integration tests. +internal static class IntegrationTestHelpers +{ + /// Gets a to use for testing, or null if the associated tests should be disabled. + public static Uri? GetOllamaUri() + { + // return new Uri("http://localhost:11434"); + return null; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/Microsoft.Extensions.AI.Ollama.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/Microsoft.Extensions.AI.Ollama.Tests.csproj new file mode 100644 index 00000000000..5db789e3b6b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/Microsoft.Extensions.AI.Ollama.Tests.csproj @@ -0,0 +1,22 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.Ollama + + + + true + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs new file mode 100644 index 00000000000..891378c0e86 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientIntegrationTests.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.TestUtilities; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class OllamaChatClientIntegrationTests : ChatClientIntegrationTests +{ + protected override IChatClient? CreateChatClient() => + IntegrationTestHelpers.GetOllamaUri() is Uri endpoint ? + new OllamaChatClient(endpoint, "llama3.1") : + null; + + public override Task FunctionInvocation_AutomaticallyInvokeFunction_WithParameters_Streaming() => + throw new SkipTestException("Ollama does not currently support function invocation with streaming."); + + public override Task Logging_LogsFunctionCalls_Streaming() => + throw new SkipTestException("Ollama does not currently support function invocation with streaming."); + + public override Task FunctionInvocation_RequireAny() => + throw new SkipTestException("Ollama does not currently support requiring function invocation."); + + public override Task FunctionInvocation_RequireSpecific() => + throw new SkipTestException("Ollama does not currently support requiring function invocation."); + + [ConditionalFact] + public async Task PromptBasedFunctionCalling_NoArgs() + { + SkipIfNotEnabled(); + + using var chatClient = new ChatClientBuilder() + .UseFunctionInvocation() + .UsePromptBasedFunctionCalling() + .Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient)) + .Use(CreateChatClient()!); + + var secretNumber = 42; + var response = await chatClient.CompleteAsync("What is the current secret number? Answer with digits only.", new ChatOptions + { + ModelId = "llama3:8b", + Tools = [AIFunctionFactory.Create(() => secretNumber, "GetSecretNumber")], + Temperature = 0, + AdditionalProperties = new() { ["seed"] = 0L }, + }); + + Assert.Single(response.Choices); + Assert.Contains(secretNumber.ToString(), response.Message.Text); + } + + [ConditionalFact] + public async Task PromptBasedFunctionCalling_WithArgs() + { + SkipIfNotEnabled(); + + using var chatClient = new ChatClientBuilder() + .UseFunctionInvocation() + .UsePromptBasedFunctionCalling() + .Use(innerClient => new AssertNoToolsDefinedChatClient(innerClient)) + .Use(CreateChatClient()!); + + var stockPriceTool = AIFunctionFactory.Create([Description("Returns the stock price for a given ticker symbol")] ( + [Description("The ticker symbol")] string symbol, + [Description("The currency code such as USD or JPY")] string currency) => + { + Assert.Equal("MSFT", symbol); + Assert.Equal("GBP", currency); + return 999; + }, "GetStockPrice"); + + var didCallIrrelevantTool = false; + var irrelevantTool = AIFunctionFactory.Create(() => { didCallIrrelevantTool = true; return 123; }, "GetSecretNumber"); + + var response = await chatClient.CompleteAsync("What's the stock price for Microsoft in British pounds?", new ChatOptions + { + Tools = [stockPriceTool, irrelevantTool], + Temperature = 0, + AdditionalProperties = new() { ["seed"] = 0L }, + }); + + Assert.Single(response.Choices); + Assert.Contains("999", response.Message.Text); + Assert.False(didCallIrrelevantTool); + } + + private sealed class AssertNoToolsDefinedChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient) + { + public override Task CompleteAsync( + IList chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Assert.Null(options?.Tools); + return base.CompleteAsync(chatMessages, options, cancellationToken); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs new file mode 100644 index 00000000000..b09947337ed --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaChatClientTests.cs @@ -0,0 +1,464 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OllamaChatClientTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("endpoint", () => new OllamaChatClient(null!)); + Assert.Throws("modelId", () => new OllamaChatClient(new("http://localhost"), " ")); + } + + [Fact] + public void GetService_SuccessfullyReturnsUnderlyingClient() + { + using OllamaChatClient client = new(new("http://localhost")); + + Assert.Same(client, client.GetService()); + Assert.Same(client, client.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(client); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(client, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public void AsChatClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + using IChatClient chatClient = new OllamaChatClient(endpoint, model); + Assert.Equal("ollama", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public async Task BasicRequestResponse_NonStreaming() + { + const string Input = """ + { + "model":"llama3.1", + "messages":[{"role":"user","content":"hello"}], + "stream":false, + "options":{"num_predict":10,"temperature":0.5} + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T15:46:10.5248793Z", + "message": { + "role": "assistant", + "content": "Hello! How are you today? Is there something" + }, + "done_reason": "length", + "done": true, + "total_duration": 22186844400, + "load_duration": 17947219100, + "prompt_eval_count": 11, + "prompt_eval_duration": 1953805000, + "eval_count": 10, + "eval_duration": 2277274000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using OllamaChatClient client = new(new("http://localhost:11434"), "llama3.1", httpClient); + var response = await client.CompleteAsync("hello", new() + { + MaxOutputTokens = 10, + Temperature = 0.5f, + }); + Assert.NotNull(response); + + Assert.Equal("Hello! How are you today? Is there something", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T15:46:10.5248793Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Length, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(11, response.Usage.InputTokenCount); + Assert.Equal(10, response.Usage.OutputTokenCount); + Assert.Equal(21, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task BasicRequestResponse_Streaming() + { + const string Input = """ + { + "model":"llama3.1", + "messages":[{"role":"user","content":"hello"}], + "stream":true, + "options":{"num_predict":20,"temperature":0.5} + } + """; + + const string Output = """ + {"model":"llama3.1","created_at":"2024-10-01T16:15:20.4965315Z","message":{"role":"assistant","content":"Hello"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:20.763058Z","message":{"role":"assistant","content":"!"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:20.9751134Z","message":{"role":"assistant","content":" How"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.1788125Z","message":{"role":"assistant","content":" are"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.3883171Z","message":{"role":"assistant","content":" you"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.5912498Z","message":{"role":"assistant","content":" today"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:21.7968039Z","message":{"role":"assistant","content":"?"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.0034152Z","message":{"role":"assistant","content":" Is"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.1931196Z","message":{"role":"assistant","content":" there"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.3827484Z","message":{"role":"assistant","content":" something"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.5659027Z","message":{"role":"assistant","content":" I"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.7488871Z","message":{"role":"assistant","content":" can"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:22.9339881Z","message":{"role":"assistant","content":" help"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.1201564Z","message":{"role":"assistant","content":" you"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.303447Z","message":{"role":"assistant","content":" with"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.4964909Z","message":{"role":"assistant","content":" or"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.6837816Z","message":{"role":"assistant","content":" would"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:23.8723142Z","message":{"role":"assistant","content":" you"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:24.064613Z","message":{"role":"assistant","content":" like"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:24.2504498Z","message":{"role":"assistant","content":" to"},"done":false} + {"model":"llama3.1","created_at":"2024-10-01T16:15:24.2514508Z","message":{"role":"assistant","content":""},"done_reason":"length", "done":true,"total_duration":11912402900,"load_duration":6824559200,"prompt_eval_count":11,"prompt_eval_duration":1329601000,"eval_count":20,"eval_duration":3754262000} + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), "llama3.1", httpClient); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("hello", new() + { + MaxOutputTokens = 20, + Temperature = 0.5f, + })) + { + updates.Add(update); + } + + Assert.Equal(21, updates.Count); + + DateTimeOffset[] createdAts = Regex.Matches(Output, @"2024.*?Z").Cast().Select(m => DateTimeOffset.Parse(m.Value)).ToArray(); + + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal(i < updates.Count - 1 ? 1 : 2, updates[i].Contents.Count); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.All(updates[i].Contents, u => Assert.Equal("llama3.1", u.ModelId)); + Assert.Equal(createdAts[i], updates[i].CreatedAt); + Assert.Equal(i < updates.Count - 1 ? null : ChatFinishReason.Length, updates[i].FinishReason); + } + + Assert.Equal("Hello! How are you today? Is there something I can help you with or would you like to", string.Concat(updates.Select(u => u.Text))); + Assert.Equal(2, updates[updates.Count - 1].Contents.Count); + Assert.IsType(updates[updates.Count - 1].Contents[0]); + UsageContent usage = Assert.IsType(updates[updates.Count - 1].Contents[1]); + Assert.Equal(11, usage.Details.InputTokenCount); + Assert.Equal(20, usage.Details.OutputTokenCount); + Assert.Equal(31, usage.Details.TotalTokenCount); + } + + [Fact] + public async Task MultipleMessages_NonStreaming() + { + const string Input = """ + { + "model": "llama3.1", + "messages": [ + { + "role": "user", + "content": "hello!" + }, + { + "role": "assistant", + "content": "hi, how are you?" + }, + { + "role": "user", + "content": "i\u0027m good. how are you?" + } + ], + "stream": false, + "options": { + "frequency_penalty": 0.75, + "presence_penalty": 0.5, + "seed": 42, + "stop": ["great"], + "temperature": 0.25 + } + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T17:18:46.308987Z", + "message": { + "role": "assistant", + "content": "I'm just a computer program, so I don't have feelings or emotions like humans do, but I'm functioning properly and ready to help with any questions or tasks you may have! How about we chat about something in particular or just shoot the breeze? Your choice!" + }, + "done_reason": "stop", + "done": true, + "total_duration": 23229369000, + "load_duration": 7724086300, + "prompt_eval_count": 36, + "prompt_eval_duration": 4245660000, + "eval_count": 55, + "eval_duration": 11256470000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), httpClient: httpClient); + + List messages = + [ + new(ChatRole.User, "hello!"), + new(ChatRole.Assistant, "hi, how are you?"), + new(ChatRole.User, "i'm good. how are you?"), + ]; + + var response = await client.CompleteAsync(messages, new() + { + ModelId = "llama3.1", + Temperature = 0.25f, + FrequencyPenalty = 0.75f, + PresencePenalty = 0.5f, + StopSequences = ["great"], + AdditionalProperties = new() { ["seed"] = 42 }, + }); + Assert.NotNull(response); + + Assert.Equal( + VerbatimHttpHandler.RemoveWhiteSpace(""" + I'm just a computer program, so I don't have feelings or emotions like humans do, + but I'm functioning properly and ready to help with any questions or tasks you may have! + How about we chat about something in particular or just shoot the breeze ? Your choice! + """), + VerbatimHttpHandler.RemoveWhiteSpace(response.Message.Text)); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T17:18:46.308987Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(36, response.Usage.InputTokenCount); + Assert.Equal(55, response.Usage.OutputTokenCount); + Assert.Equal(91, response.Usage.TotalTokenCount); + } + + [Fact] + public async Task FunctionCallContent_NonStreaming() + { + const string Input = """ + { + "model": "llama3.1", + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + } + ], + "stream": false, + "tools": [ + { + "type": "function", + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + }, + "required": ["personName"] + } + } + } + ] + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T18:48:30.2669578Z", + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "GetPersonAge", + "arguments": { + "personName": "Alice" + } + } + } + ] + }, + "done_reason": "stop", + "done": true, + "total_duration": 27351311300, + "load_duration": 8041538400, + "prompt_eval_count": 170, + "prompt_eval_duration": 16078776000, + "eval_count": 19, + "eval_duration": 3227962000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler) { Timeout = Timeout.InfiniteTimeSpan }; + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), "llama3.1", httpClient) + { + ToolCallJsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + var response = await client.CompleteAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Null(response.Message.Text); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T18:48:30.2669578Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(170, response.Usage.InputTokenCount); + Assert.Equal(19, response.Usage.OutputTokenCount); + Assert.Equal(189, response.Usage.TotalTokenCount); + + Assert.Single(response.Choices); + Assert.Single(response.Message.Contents); + FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + } + + [Fact] + public async Task FunctionResultContent_NonStreaming() + { + const string Input = """ + { + "model": "llama3.1", + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + }, + { + "role": "assistant", + "content": "{\u0022call_id\u0022:\u0022abcd1234\u0022,\u0022name\u0022:\u0022GetPersonAge\u0022,\u0022arguments\u0022:{\u0022personName\u0022:\u0022Alice\u0022}}" + }, + { + "role": "tool", + "content": "{\u0022call_id\u0022:\u0022abcd1234\u0022,\u0022result\u0022:42}" + } + ], + "stream": false, + "tools": [ + { + "type": "function", + "function": { + "name": "GetPersonAge", + "description": "Gets the age of the specified person.", + "parameters": { + "type": "object", + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + }, + "required": ["personName"] + } + } + } + ] + } + """; + + const string Output = """ + { + "model": "llama3.1", + "created_at": "2024-10-01T20:57:20.157266Z", + "message": { + "role": "assistant", + "content": "Alice is 42 years old." + }, + "done_reason": "stop", + "done": true, + "total_duration": 20320666000, + "load_duration": 8159642600, + "prompt_eval_count": 106, + "prompt_eval_duration": 10846727000, + "eval_count": 8, + "eval_duration": 1307842000 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler) { Timeout = Timeout.InfiniteTimeSpan }; + using IChatClient client = new OllamaChatClient(new("http://localhost:11434"), "llama3.1", httpClient) + { + ToolCallJsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + var response = await client.CompleteAsync( + [ + new(ChatRole.User, "How old is Alice?"), + new(ChatRole.Assistant, [new FunctionCallContent("abcd1234", "GetPersonAge", new Dictionary { ["personName"] = "Alice" })]), + new(ChatRole.Tool, [new FunctionResultContent("abcd1234", "GetPersonAge", 42)]), + ], + new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Equal("Alice is 42 years old.", response.Message.Text); + Assert.Equal("llama3.1", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.Parse("2024-10-01T20:57:20.157266Z"), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(106, response.Usage.InputTokenCount); + Assert.Equal(8, response.Usage.OutputTokenCount); + Assert.Equal(114, response.Usage.TotalTokenCount); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorIntegrationTests.cs new file mode 100644 index 00000000000..4333cbde636 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorIntegrationTests.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +public class OllamaEmbeddingGeneratorIntegrationTests : EmbeddingGeneratorIntegrationTests +{ + protected override IEmbeddingGenerator>? CreateEmbeddingGenerator() => + IntegrationTestHelpers.GetOllamaUri() is Uri endpoint ? + new OllamaEmbeddingGenerator(endpoint, "all-minilm") : + null; +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..205398c9a1c --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/OllamaEmbeddingGeneratorTests.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OllamaEmbeddingGeneratorTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("endpoint", () => new OllamaEmbeddingGenerator(null!)); + Assert.Throws("modelId", () => new OllamaEmbeddingGenerator(new("http://localhost"), " ")); + } + + [Fact] + public void GetService_SuccessfullyReturnsUnderlyingClient() + { + using OllamaEmbeddingGenerator generator = new(new("http://localhost")); + + Assert.Same(generator, generator.GetService()); + Assert.Same(generator, generator.GetService>>()); + + using IEmbeddingGenerator> pipeline = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(generator); + + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + + Assert.Same(generator, pipeline.GetService()); + Assert.IsType>>(pipeline.GetService>>()); + } + + [Fact] + public void AsEmbeddingGenerator_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + using IEmbeddingGenerator> chatClient = new OllamaEmbeddingGenerator(endpoint, model); + Assert.Equal("ollama", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public async Task GetEmbeddingsAsync_ExpectedRequestResponse() + { + const string Input = """ + {"model":"all-minilm","input":["hello, world!","red, white, blue"]} + """; + + const string Output = """ + { + "model":"all-minilm", + "embeddings":[ + [-0.038159743,0.032830726,-0.005602915,0.014363416,-0.04031945,-0.11662117,0.031710647,0.0019634133,-0.042558126,0.02925818,0.04254404,0.032178584,0.029820565,0.010947956,-0.05383333,-0.05031401,-0.023460664,0.010746779,-0.13776828,0.003972192,0.029283607,0.06673441,-0.015434976,0.048401773,-0.088160664,-0.012700827,0.04134059,0.0408592,-0.050058633,-0.058048956,0.048720006,0.068883754,0.0588242,0.008813041,-0.016036017,0.08514798,-0.07813561,-0.07740018,0.020856613,0.016228318,0.032506905,-0.053466275,-0.06220645,-0.024293836,0.0073994277,0.02410873,0.006477103,0.051144805,0.072868116,0.03460658,-0.0547553,-0.05937917,-0.007205277,0.020145971,0.035794333,0.005588114,0.010732389,-0.052755248,0.01006711,-0.008716047,-0.062840104,0.038445882,-0.013913384,0.07341423,0.09004691,-0.07995187,-0.016410379,0.044806693,-0.06886798,-0.03302609,-0.015488586,0.0112944925,0.03645402,0.06637969,-0.054364193,0.008732196,0.012049053,-0.038111813,0.006928739,0.05113517,0.07739711,-0.12295967,0.016389083,0.049567502,0.03162499,-0.039604694,0.0016613991,0.009564599,-0.03268798,-0.033994347,-0.13328508,0.0072719813,-0.010261588,0.038570367,-0.093384996,-0.041716397,0.069951184,-0.02632818,-0.149702,0.13445856,0.037486482,0.052814852,0.045044158,0.018727085,0.05445453,0.01727433,-0.032474063,0.046129994,-0.046679277,-0.03058037,-0.0181755,-0.048695795,0.033057086,-0.0038555008,0.050006237,-0.05828653,-0.010029618,0.01062073,-0.040105496,-0.0015263702,0.060846698,-0.04557025,0.049251337,0.026121102,0.019804202,-0.0016694543,0.059516467,-6.525171e-33,0.06351319,0.0030810465,0.028928237,0.17336167,0.0029677018,0.027755935,-0.09513812,-0.031182382,0.026697554,-0.0107956175,0.023849761,0.02378595,-0.03121345,0.049473017,-0.02506533,0.101713106,-0.079133175,-0.0032418896,0.04290832,0.094838716,-0.06652884,0.0062877694,0.02221229,0.0700068,-0.007469806,-0.0017550732,0.027011596,-0.075321496,0.114022695,0.0085597,-0.023766534,-0.04693697,0.014437173,0.01987886,-0.0046902793,0.0013660098,-0.034307938,-0.054156985,-0.09417741,-0.028919358,-0.018871028,0.04574328,0.047602862,-0.0031305805,-0.033291575,-0.0135114025,0.051019657,0.031115327,0.015239397,0.05413997,-0.085031144,0.013366392,-0.04757861,0.07102588,-0.013105953,-0.0023799809,0.050322797,-0.041649505,-0.014187793,0.0324716,0.005401626,0.091307014,0.0044665188,-0.018263677,-0.015284639,-0.04634121,0.038754962,0.014709013,0.052040145,0.0017918312,-0.014979437,0.027103048,0.03117813,0.023749126,-0.004567645,0.03617759,0.06680814,-0.001835277,0.021281,-0.057563916,0.019137124,0.031450257,-0.018432263,-0.040860977,0.10391725,0.011970765,-0.014854915,-0.10521159,-0.012288272,-0.00041675335,-0.09510029,0.058300544,0.042590536,-0.025064372,-0.09454636,4.0064686e-33,0.13224861,0.0053342036,-0.033114634,-0.09096768,-0.031561732,-0.03395822,-0.07202013,0.12591493,-0.08332582,0.052816514,0.001065021,0.022002738,0.1040207,0.013038866,0.04092958,0.018689224,0.1142518,0.024801003,0.014596161,0.006195551,-0.011214642,-0.035760444,-0.037979998,0.011274433,-0.051305123,0.007884909,0.06734877,0.0033462204,-0.09284879,0.037033774,-0.022331867,0.039951596,-0.030730229,-0.011403805,-0.014458028,0.024968812,-0.097553216,-0.03536226,-0.037567392,-0.010149212,-0.06387594,0.025570663,0.02060328,0.037549157,-0.104355134,-0.02837097,-0.052078977,0.0128349,-0.05123587,-0.029060647,-0.09632806,-0.042301137,0.067175224,-0.030890828,-0.010358077,0.027408795,-0.028092034,0.010337195,0.04303845,0.022324203,0.00797792,0.056084383,0.040727936,0.092925824,0.01653155,-0.053750493,0.00046004262,0.050728552,0.04253214,-0.029197674,0.00926312,-0.010662153,-0.037244495,0.002277273,-0.030296732,0.07459592,0.002572513,-0.017561244,0.0028881067,0.03841156,0.007247727,0.045637112,0.039992437,0.014227117,-0.014297474,0.05854321,0.03632371,0.05527864,-0.02007574,-0.08043163,-0.030238612,-0.014929122,0.022335418,0.011954643,-0.06906099,-1.8807288e-8,-0.07850291,0.046684187,-0.023935271,0.063510746,0.024001691,0.0014455577,-0.09078209,-0.066868275,-0.0801402,0.005480386,0.053663295,0.10483363,-0.066864185,0.015531167,0.06711155,0.07081655,-0.031996343,0.020819444,-0.021926524,-0.0073062326,-0.010652819,0.0041180425,0.033138428,-0.0789938,0.03876969,-0.075220205,-0.015715994,0.0059789424,0.005140016,-0.06150612,0.041992374,0.09544083,-0.043187104,0.014401576,-0.10615426,-0.027936764,0.011047429,0.069572434,0.06690283,-0.074798405,-0.07852024,0.04276141,-0.034642085,-0.106051244,-0.03581038,0.051521253,0.06865896,-0.04999753,0.0154549,-0.06452052,-0.07598782,0.02603005,0.074413665,-0.012398757,0.13330704,0.07475513,0.051348723,0.02098748,-0.02679416,0.08896129,0.039944872,-0.041040305,0.031930625,0.018114654], + [0.007228383,-0.021804843,-0.07494023,-0.021707121,-0.021184582,0.09326986,0.10764054,-0.01918113,0.007439991,0.01367952,-0.034187328,-0.044076536,0.016042138,0.007507193,-0.016432272,0.025345335,0.010598066,-0.03832474,-0.14418823,-0.033625234,0.013156937,-0.0048872638,-0.08534306,-0.00003228713,-0.08900276,-0.00008128615,0.010332802,0.053303026,-0.050233904,-0.0879366,-0.064243905,-0.017168961,0.1284308,-0.015268303,-0.049664143,-0.07491954,0.021887481,0.015997978,-0.07967111,0.08744341,-0.039261423,-0.09904984,0.02936398,0.042995434,0.057036504,0.09063012,0.0000012311281,0.06120768,-0.050825767,-0.014443322,0.02879051,-0.002343813,-0.10176559,0.104563184,0.031316753,0.08251861,-0.041213628,-0.0217945,0.0649965,-0.011131547,0.018417398,-0.014460508,-0.05108664,0.11330918,0.01863208,0.006442521,-0.039408617,-0.03609412,-0.009156692,-0.0031261789,-0.010928502,-0.021108521,0.037411734,0.012443921,0.018142054,-0.0362644,0.058286663,-0.02733258,-0.052172586,-0.08320095,-0.07089281,-0.0970049,-0.048587535,0.055343032,0.048351917,0.06892102,-0.039993215,0.06344781,-0.084417015,0.003692423,-0.059397053,0.08186814,0.0029228176,-0.010551637,-0.058019258,0.092128515,0.06862907,-0.06558893,0.021121018,0.079212844,0.09616225,0.0045106052,0.039712362,-0.053576704,0.035097837,-0.04251009,-0.013761404,0.011582285,0.02387105,0.009042205,0.054141942,-0.051263757,-0.07984356,-0.020198742,-0.051623948,-0.0013434993,-0.05825417,-0.0026240738,0.0050159167,-0.06320204,0.07872169,-0.04051374,0.04671058,-0.05804034,-0.07103668,-0.07507343,0.015222599,-3.0948323e-33,0.0076309564,-0.06283016,0.024291662,0.12532257,0.013917241,0.04869009,-0.037988827,-0.035241846,-0.041410565,-0.033772282,0.018835608,0.081035286,-0.049912665,0.044602085,0.030495265,-0.009206943,0.027668765,0.011651487,-0.10254086,0.054472663,-0.06514106,0.12192646,0.048823033,-0.015688669,0.010323047,-0.02821445,-0.030832449,-0.035029083,-0.010604268,0.0014445938,0.08670387,0.01997448,0.0101131955,0.036524937,-0.033489946,-0.026745271,-0.04709222,0.015197909,0.018787097,-0.009976326,-0.0016434817,-0.024719588,-0.09179337,0.09343157,0.029579962,-0.015174558,0.071250066,0.010549244,0.010716396,0.05435638,-0.06391847,-0.031383075,0.007916095,0.012391228,-0.012053197,-0.017409964,0.013742709,0.0594159,-0.033767693,0.04505938,-0.0017214329,0.12797962,0.03223919,-0.054756388,0.025249248,-0.02273578,-0.04701282,-0.018718086,0.009820931,-0.06267794,-0.012644738,0.0068301614,0.093209736,-0.027372226,-0.09436381,0.003861504,0.054960024,-0.058553983,-0.042971537,-0.008994571,-0.08225824,-0.013560626,-0.01880568,0.0995795,-0.040887516,-0.0036491079,-0.010253542,-0.031025425,-0.006957114,-0.038943008,-0.090270124,-0.031345647,0.029613726,-0.099465184,-0.07469079,7.844707e-34,0.024241973,0.03597121,-0.049776066,0.05084303,0.006059542,-0.020719761,0.019962702,0.092246406,0.069408394,0.062306542,0.013837189,0.054749023,0.05090263,0.04100415,-0.02573441,0.09535842,0.036858294,0.059478357,0.0070162765,0.038462427,-0.053635903,0.05912332,-0.037887845,-0.0012995935,-0.068758026,0.0671618,0.029407106,-0.061569903,-0.07481879,-0.01849014,0.014240046,-0.08064838,0.028351007,0.08456427,0.016858438,0.02053254,0.06171099,-0.028964644,-0.047633287,0.08802184,0.0017116248,0.019451816,0.03419083,0.07152118,-0.027244413,-0.04888475,-0.10314279,0.07628554,-0.045991484,-0.023299307,-0.021448445,0.04111079,-0.036342163,-0.010670482,0.01950527,-0.0648448,-0.033299454,0.05782628,0.030278979,0.079154804,-0.03679649,0.031728156,-0.034912236,0.08817754,0.059208114,-0.02319613,-0.027045371,-0.018559752,-0.051946763,-0.010635224,0.048839167,-0.043925915,-0.028300019,-0.0039419765,0.044211324,-0.067469835,-0.027534118,0.005051618,-0.034172326,0.080007285,-0.01931061,-0.005759926,0.08765162,0.08372951,-0.093784876,0.011837292,0.019019455,0.047941882,0.05504541,-0.12475821,0.012822803,0.12833545,0.08005919,0.019278418,-0.025834465,-1.9763878e-8,0.05211108,0.024891146,-0.0015623684,0.0040500895,0.015101377,-0.0031462535,0.014759316,-0.041329216,-0.029255627,0.048599463,0.062482737,0.018376771,-0.066601776,0.014752581,0.07968402,-0.015090815,-0.12100162,-0.0014005995,0.0134423375,-0.0065814927,-0.01188529,-0.01107086,-0.059613306,0.030120188,0.0418596,-0.009260598,0.028435009,0.024893047,0.031339604,0.09501834,0.027570697,0.0636991,-0.056108754,-0.0329521,-0.114633024,-0.00981398,-0.060992315,0.027551433,0.0069592255,-0.059862003,0.0008075791,0.001507554,-0.028574942,-0.011227367,0.0056030746,-0.041190825,-0.09364463,-0.04459479,-0.055058934,-0.029972456,-0.028642913,-0.015199684,0.007875299,-0.034083385,0.02143902,-0.017395096,0.027429376,0.013198211,0.005065835,0.037760753,0.08974973,0.07598824,0.0050444477,0.014734193] + ], + "total_duration":375551700, + "load_duration":354411900, + "prompt_eval_count":9 + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IEmbeddingGenerator> generator = new OllamaEmbeddingGenerator(new("http://localhost:11434"), "all-minilm", httpClient); + + var response = await generator.GenerateAsync([ + "hello, world!", + "red, white, blue", + ]); + Assert.NotNull(response); + Assert.Equal(2, response.Count); + + Assert.NotNull(response.Usage); + Assert.Equal(9, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.TotalTokenCount); + + foreach (Embedding e in response) + { + Assert.Equal("all-minilm", e.ModelId); + Assert.NotNull(e.CreatedAt); + Assert.Equal(384, e.Vector.Length); + Assert.Contains(e.Vector.ToArray(), f => !f.Equals(0)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/TestJsonSerializerContext.cs new file mode 100644 index 00000000000..49560a9c451 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Ollama.Tests/TestJsonSerializerContext.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +[JsonSerializable(typeof(string))] +[JsonSerializable(typeof(int))] +[JsonSerializable(typeof(IDictionary))] +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/IntegrationTestHelpers.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/IntegrationTestHelpers.cs new file mode 100644 index 00000000000..da60e62061f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/IntegrationTestHelpers.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using Azure.AI.OpenAI; +using OpenAI; + +namespace Microsoft.Extensions.AI; + +/// Shared utility methods for integration tests. +internal static class IntegrationTestHelpers +{ + /// Gets an to use for testing, or null if the associated tests should be disabled. + public static OpenAIClient? GetOpenAIClient() + { + string? apiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + + if (apiKey is not null) + { + if (string.Equals(Environment.GetEnvironmentVariable("OPENAI_MODE"), "AzureOpenAI", StringComparison.OrdinalIgnoreCase)) + { + var endpoint = Environment.GetEnvironmentVariable("OPENAI_ENDPOINT") + ?? throw new InvalidOperationException("To use AzureOpenAI, set a value for OPENAI_ENDPOINT"); + return new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey)); + } + else + { + return new OpenAIClient(apiKey); + } + } + + return null; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/Microsoft.Extensions.AI.OpenAI.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/Microsoft.Extensions.AI.OpenAI.Tests.csproj new file mode 100644 index 00000000000..0ef40e12df3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/Microsoft.Extensions.AI.OpenAI.Tests.csproj @@ -0,0 +1,26 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI.OpenAI + + + + true + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs new file mode 100644 index 00000000000..c82e1abc860 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientIntegrationTests.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +public class OpenAIChatClientIntegrationTests : ChatClientIntegrationTests +{ + protected override IChatClient? CreateChatClient() => + IntegrationTestHelpers.GetOpenAIClient() + ?.AsChatClient(Environment.GetEnvironmentVariable("OPENAI_CHAT_MODEL") ?? "gpt-4o-mini"); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs new file mode 100644 index 00000000000..f19a19f3ce8 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIChatClientTests.cs @@ -0,0 +1,608 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.ComponentModel; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Azure.AI.OpenAI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using OpenAI; +using OpenAI.Chat; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OpenAIChatClientTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => new OpenAIChatClient(null!, "model")); + Assert.Throws("chatClient", () => new OpenAIChatClient(null!)); + + OpenAIClient openAIClient = new("key"); + Assert.Throws("modelId", () => new OpenAIChatClient(openAIClient, null!)); + Assert.Throws("modelId", () => new OpenAIChatClient(openAIClient, "")); + Assert.Throws("modelId", () => new OpenAIChatClient(openAIClient, " ")); + } + + [Fact] + public void AsChatClient_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => ((OpenAIClient)null!).AsChatClient("model")); + Assert.Throws("chatClient", () => ((ChatClient)null!).AsChatClient()); + + OpenAIClient client = new("key"); + Assert.Throws("modelId", () => client.AsChatClient(null!)); + Assert.Throws("modelId", () => client.AsChatClient(" ")); + } + + [Fact] + public void AsChatClient_OpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + OpenAIClient client = new(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint }); + + IChatClient chatClient = client.AsChatClient(model); + Assert.Equal("openai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + + chatClient = client.GetChatClient(model).AsChatClient(); + Assert.Equal("openai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public void AsChatClient_AzureOpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + AzureOpenAIClient client = new(endpoint, new ApiKeyCredential("key")); + + IChatClient chatClient = client.AsChatClient(model); + Assert.Equal("azureopenai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + + chatClient = client.GetChatClient(model).AsChatClient(); + Assert.Equal("azureopenai", chatClient.Metadata.ProviderName); + Assert.Equal(endpoint, chatClient.Metadata.ProviderUri); + Assert.Equal(model, chatClient.Metadata.ModelId); + } + + [Fact] + public void GetService_OpenAIClient_SuccessfullyReturnsUnderlyingClient() + { + OpenAIClient openAIClient = new(new ApiKeyCredential("key")); + IChatClient chatClient = openAIClient.AsChatClient("model"); + + Assert.Same(chatClient, chatClient.GetService()); + Assert.Same(chatClient, chatClient.GetService()); + + Assert.Same(openAIClient, chatClient.GetService()); + + Assert.NotNull(chatClient.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(chatClient); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public void GetService_ChatClient_SuccessfullyReturnsUnderlyingClient() + { + ChatClient openAIClient = new OpenAIClient(new ApiKeyCredential("key")).GetChatClient("model"); + IChatClient chatClient = openAIClient.AsChatClient(); + + Assert.Same(chatClient, chatClient.GetService()); + Assert.Same(openAIClient, chatClient.GetService()); + + using IChatClient pipeline = new ChatClientBuilder() + .UseFunctionInvocation() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(chatClient); + + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + Assert.NotNull(pipeline.GetService()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType(pipeline.GetService()); + } + + [Fact] + public async Task BasicRequestResponse_NonStreaming() + { + const string Input = """ + {"messages":[{"role":"user","content":"hello"}],"model":"gpt-4o-mini","max_completion_tokens":10,"temperature":0.5} + """; + + const string Output = """ + { + "id": "chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", + "object": "chat.completion", + "created": 1727888631, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 8, + "completion_tokens": 9, + "total_tokens": 17, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("hello", new() + { + MaxOutputTokens = 10, + Temperature = 0.5f, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADx3PvAnCwJg0woha4pYsBTi3ZpOI", response.CompletionId); + Assert.Equal("Hello! How can I assist you today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_888_631), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(8, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.OutputTokenCount); + Assert.Equal(17, response.Usage.TotalTokenCount); + Assert.NotNull(response.Usage.AdditionalProperties); + + Assert.NotNull(response.AdditionalProperties); + Assert.Equal("fp_f85bea6784", response.AdditionalProperties[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + } + + [Fact] + public async Task BasicRequestResponse_Streaming() + { + const string Input = """ + {"messages":[{"role":"user","content":"hello"}],"model":"gpt-4o-mini","max_completion_tokens":20,"stream":true,"stream_options":{"include_usage":true},"temperature":0.5} + """; + + const string Output = """ + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + + data: {"id":"chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK","object":"chat.completion.chunk","created":1727889370,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":8,"completion_tokens":9,"total_tokens":17,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("hello", new() + { + MaxOutputTokens = 20, + Temperature = 0.5f, + })) + { + updates.Add(update); + } + + Assert.Equal("Hello! How can I assist you today?", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_889_370); + Assert.Equal(12, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADxFKtX6xIwdWRN42QvBj2u1RZpCK", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.NotNull(updates[i].AdditionalProperties); + Assert.Equal("fp_f85bea6784", updates[i].AdditionalProperties![nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + Assert.Equal(i == 10 ? 0 : 1, updates[i].Contents.Count); + Assert.Equal(i < 10 ? null : ChatFinishReason.Stop, updates[i].FinishReason); + } + + UsageContent usage = updates.SelectMany(u => u.Contents).OfType().Single(); + Assert.Equal(8, usage.Details.InputTokenCount); + Assert.Equal(9, usage.Details.OutputTokenCount); + Assert.Equal(17, usage.Details.TotalTokenCount); + Assert.NotNull(usage.Details.AdditionalProperties); + Assert.Equal(new Dictionary { [nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)] = 0 }, usage.Details.AdditionalProperties[nameof(ChatTokenUsage.OutputTokenDetails)]); + } + + [Fact] + public async Task MultipleMessages_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "role": "system", + "content": "You are a really nice friend." + }, + { + "role": "user", + "content": "hello!" + }, + { + "role": "assistant", + "content": "hi, how are you?" + }, + { + "role": "user", + "content": "i\u0027m good. how are you?" + } + ], + "model": "gpt-4o-mini", + "frequency_penalty": 0.75, + "presence_penalty": 0.5, + "seed":42, + "stop": [ + "great" + ], + "temperature": 0.25 + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", + "object": "chat.completion", + "created": 1727894187, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I’m doing well, thank you! What’s on your mind today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 42, + "completion_tokens": 15, + "total_tokens": 57, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List messages = + [ + new(ChatRole.System, "You are a really nice friend."), + new(ChatRole.User, "hello!"), + new(ChatRole.Assistant, "hi, how are you?"), + new(ChatRole.User, "i'm good. how are you?"), + ]; + + var response = await client.CompleteAsync(messages, new() + { + Temperature = 0.25f, + FrequencyPenalty = 0.75f, + PresencePenalty = 0.5f, + StopSequences = ["great"], + AdditionalProperties = new() { ["seed"] = 42 }, + }); + Assert.NotNull(response); + + Assert.Equal("chatcmpl-ADyV17bXeSm5rzUx3n46O7m3M0o3P", response.CompletionId); + Assert.Equal("I’m doing well, thank you! What’s on your mind today?", response.Message.Text); + Assert.Single(response.Message.Contents); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_187), response.CreatedAt); + Assert.Equal(ChatFinishReason.Stop, response.FinishReason); + + Assert.NotNull(response.Usage); + Assert.Equal(42, response.Usage.InputTokenCount); + Assert.Equal(15, response.Usage.OutputTokenCount); + Assert.Equal(57, response.Usage.TotalTokenCount); + Assert.NotNull(response.Usage.AdditionalProperties); + + Assert.NotNull(response.AdditionalProperties); + Assert.Equal("fp_f85bea6784", response.AdditionalProperties[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + } + + [Fact] + public async Task FunctionCallContent_NonStreaming() + { + const string Input = """ + { + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + } + ], + "model": "gpt-4o-mini", + "tools": [ + { + "type": "function", + "function": { + "description": "Gets the age of the specified person.", + "name": "GetPersonAge", + "parameters": { + "type": "object", + "required": [ + "personName" + ], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + }, + "strict": false + } + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + { + "id": "chatcmpl-ADydKhrSKEBWJ8gy0KCIU74rN3Hmk", + "object": "chat.completion", + "created": 1727894702, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_8qbINM045wlmKZt9bVJgwAym", + "type": "function", + "function": { + "name": "GetPersonAge", + "arguments": "{\"personName\":\"Alice\"}" + } + } + ], + "refusal": null + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 61, + "completion_tokens": 16, + "total_tokens": 77, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_f85bea6784" + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + var response = await client.CompleteAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + }); + Assert.NotNull(response); + + Assert.Null(response.Message.Text); + Assert.Equal("gpt-4o-mini-2024-07-18", response.ModelId); + Assert.Equal(ChatRole.Assistant, response.Message.Role); + Assert.Equal(DateTimeOffset.FromUnixTimeSeconds(1_727_894_702), response.CreatedAt); + Assert.Equal(ChatFinishReason.ToolCalls, response.FinishReason); + Assert.NotNull(response.Usage); + Assert.Equal(61, response.Usage.InputTokenCount); + Assert.Equal(16, response.Usage.OutputTokenCount); + Assert.Equal(77, response.Usage.TotalTokenCount); + + Assert.Single(response.Choices); + Assert.Single(response.Message.Contents); + FunctionCallContent fcc = Assert.IsType(response.Message.Contents[0]); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + + Assert.NotNull(response.AdditionalProperties); + Assert.Equal("fp_f85bea6784", response.AdditionalProperties[nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + } + + [Fact] + public async Task FunctionCallContent_Streaming() + { + const string Input = """ + { + "messages": [ + { + "role": "user", + "content": "How old is Alice?" + } + ], + "model": "gpt-4o-mini", + "stream": true, + "stream_options": { + "include_usage": true + }, + "tools": [ + { + "type": "function", + "function": { + "description": "Gets the age of the specified person.", + "name": "GetPersonAge", + "parameters": { + "type": "object", + "required": [ + "personName" + ], + "properties": { + "personName": { + "description": "The person whose age is being requested", + "type": "string" + } + } + }, + "strict": false + } + } + ], + "tool_choice": "auto" + } + """; + + const string Output = """ + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_F9ZaqPWo69u0urxAhVt8meDW","type":"function","function":{"name":"GetPersonAge","arguments":""}}],"refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"person"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Name"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Alice"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}],"usage":null} + + data: {"id":"chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl","object":"chat.completion.chunk","created":1727895263,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_f85bea6784","choices":[],"usage":{"prompt_tokens":61,"completion_tokens":16,"total_tokens":77,"prompt_tokens_details":{"cached_tokens":0},"completion_tokens_details":{"reasoning_tokens":0}}} + + data: [DONE] + + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IChatClient client = CreateChatClient(httpClient, "gpt-4o-mini"); + + List updates = []; + await foreach (var update in client.CompleteStreamingAsync("How old is Alice?", new() + { + Tools = [AIFunctionFactory.Create(([Description("The person whose age is being requested")] string personName) => 42, "GetPersonAge", "Gets the age of the specified person.")], + })) + { + updates.Add(update); + } + + Assert.Equal("", string.Concat(updates.Select(u => u.Text))); + + var createdAt = DateTimeOffset.FromUnixTimeSeconds(1_727_895_263); + Assert.Equal(10, updates.Count); + for (int i = 0; i < updates.Count; i++) + { + Assert.Equal("chatcmpl-ADymNiWWeqCJqHNFXiI1QtRcLuXcl", updates[i].CompletionId); + Assert.Equal(createdAt, updates[i].CreatedAt); + Assert.All(updates[i].Contents, u => Assert.Equal("gpt-4o-mini-2024-07-18", u.ModelId)); + Assert.Equal(ChatRole.Assistant, updates[i].Role); + Assert.NotNull(updates[i].AdditionalProperties); + Assert.Equal("fp_f85bea6784", updates[i].AdditionalProperties![nameof(OpenAI.Chat.ChatCompletion.SystemFingerprint)]); + Assert.Equal(i < 7 ? null : ChatFinishReason.ToolCalls, updates[i].FinishReason); + } + + FunctionCallContent fcc = Assert.IsType(Assert.Single(updates[updates.Count - 1].Contents)); + Assert.Equal("call_F9ZaqPWo69u0urxAhVt8meDW", fcc.CallId); + Assert.Equal("GetPersonAge", fcc.Name); + AssertExtensions.EqualFunctionCallParameters(new Dictionary { ["personName"] = "Alice" }, fcc.Arguments); + + UsageContent usage = updates.SelectMany(u => u.Contents).OfType().Single(); + Assert.Equal(61, usage.Details.InputTokenCount); + Assert.Equal(16, usage.Details.OutputTokenCount); + Assert.Equal(77, usage.Details.TotalTokenCount); + Assert.NotNull(usage.Details.AdditionalProperties); + Assert.Equal(new Dictionary { [nameof(ChatOutputTokenUsageDetails.ReasoningTokenCount)] = 0 }, usage.Details.AdditionalProperties[nameof(ChatTokenUsage.OutputTokenDetails)]); + } + + private static IChatClient CreateChatClient(HttpClient httpClient, string modelId) => + new OpenAIClient(new ApiKeyCredential("apikey"), new OpenAIClientOptions { Transport = new HttpClientPipelineTransport(httpClient) }) + .AsChatClient(modelId); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorIntegrationTests.cs new file mode 100644 index 00000000000..38283e2687b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorIntegrationTests.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; + +namespace Microsoft.Extensions.AI; + +public class OpenAIEmbeddingGeneratorIntegrationTests : EmbeddingGeneratorIntegrationTests +{ + protected override IEmbeddingGenerator>? CreateEmbeddingGenerator() => + IntegrationTestHelpers.GetOpenAIClient() + ?.AsEmbeddingGenerator(Environment.GetEnvironmentVariable("OPENAI_EMBEDDING_MODEL") ?? "text-embedding-3-small"); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..d08cf295a4b --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.OpenAI.Tests/OpenAIEmbeddingGeneratorTests.cs @@ -0,0 +1,187 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Net.Http; +using System.Threading.Tasks; +using Azure.AI.OpenAI; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using OpenAI; +using OpenAI.Embeddings; +using Xunit; + +#pragma warning disable S103 // Lines should not be too long + +namespace Microsoft.Extensions.AI; + +public class OpenAIEmbeddingGeneratorTests +{ + [Fact] + public void Ctor_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => new OpenAIEmbeddingGenerator(null!, "model")); + Assert.Throws("embeddingClient", () => new OpenAIEmbeddingGenerator(null!)); + + OpenAIClient openAIClient = new("key"); + Assert.Throws("modelId", () => new OpenAIEmbeddingGenerator(openAIClient, null!)); + Assert.Throws("modelId", () => new OpenAIEmbeddingGenerator(openAIClient, "")); + Assert.Throws("modelId", () => new OpenAIEmbeddingGenerator(openAIClient, " ")); + } + + [Fact] + public void AsEmbeddingGenerator_InvalidArgs_Throws() + { + Assert.Throws("openAIClient", () => ((OpenAIClient)null!).AsEmbeddingGenerator("model")); + Assert.Throws("embeddingClient", () => ((EmbeddingClient)null!).AsEmbeddingGenerator()); + + OpenAIClient client = new("key"); + Assert.Throws("modelId", () => client.AsEmbeddingGenerator(null!)); + Assert.Throws("modelId", () => client.AsEmbeddingGenerator(" ")); + } + + [Fact] + public void AsEmbeddingGenerator_OpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + OpenAIClient client = new(new ApiKeyCredential("key"), new OpenAIClientOptions { Endpoint = endpoint }); + + IEmbeddingGenerator> embeddingGenerator = client.AsEmbeddingGenerator(model); + Assert.Equal("openai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + + embeddingGenerator = client.GetEmbeddingClient(model).AsEmbeddingGenerator(); + Assert.Equal("openai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + } + + [Fact] + public void AsEmbeddingGenerator_AzureOpenAIClient_ProducesExpectedMetadata() + { + Uri endpoint = new("http://localhost/some/endpoint"); + string model = "amazingModel"; + + AzureOpenAIClient client = new(endpoint, new ApiKeyCredential("key")); + + IEmbeddingGenerator> embeddingGenerator = client.AsEmbeddingGenerator(model); + Assert.Equal("azureopenai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + + embeddingGenerator = client.GetEmbeddingClient(model).AsEmbeddingGenerator(); + Assert.Equal("azureopenai", embeddingGenerator.Metadata.ProviderName); + Assert.Equal(endpoint, embeddingGenerator.Metadata.ProviderUri); + Assert.Equal(model, embeddingGenerator.Metadata.ModelId); + } + + [Fact] + public void GetService_OpenAIClient_SuccessfullyReturnsUnderlyingClient() + { + OpenAIClient openAIClient = new(new ApiKeyCredential("key")); + IEmbeddingGenerator> embeddingGenerator = openAIClient.AsEmbeddingGenerator("model"); + + Assert.Same(embeddingGenerator, embeddingGenerator.GetService>>()); + Assert.Same(embeddingGenerator, embeddingGenerator.GetService()); + + Assert.Same(openAIClient, embeddingGenerator.GetService()); + + Assert.NotNull(embeddingGenerator.GetService()); + + using IEmbeddingGenerator> pipeline = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(embeddingGenerator); + + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType>>(pipeline.GetService>>()); + } + + [Fact] + public void GetService_EmbeddingClient_SuccessfullyReturnsUnderlyingClient() + { + EmbeddingClient openAIClient = new OpenAIClient(new ApiKeyCredential("key")).GetEmbeddingClient("model"); + IEmbeddingGenerator> embeddingGenerator = openAIClient.AsEmbeddingGenerator(); + + Assert.Same(embeddingGenerator, embeddingGenerator.GetService>>()); + Assert.Same(openAIClient, embeddingGenerator.GetService()); + + using IEmbeddingGenerator> pipeline = new EmbeddingGeneratorBuilder>() + .UseOpenTelemetry() + .UseDistributedCache(new MemoryDistributedCache(Options.Options.Create(new MemoryDistributedCacheOptions()))) + .Use(embeddingGenerator); + + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + Assert.NotNull(pipeline.GetService>>()); + + Assert.Same(openAIClient, pipeline.GetService()); + Assert.IsType>>(pipeline.GetService>>()); + } + + [Fact] + public async Task GetEmbeddingsAsync_ExpectedRequestResponse() + { + const string Input = """ + {"input":["hello, world!","red, white, blue"],"model":"text-embedding-3-small","encoding_format":"base64"} + """; + + const string Output = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": "qjH+vMcj07wP1+U7kbwjOv4cwLyL3iy9DkgpvCkBQD0bthW98o6SvMMwmTrQRQa9r7b1uy4tuLzssJs7jZspPe0JG70KJy89ae4fPNLUwjytoHk9BX/1OlXCfTzc07M8JAMIPU7cibsUJiC8pTNGPWUbJztfwW69oNwOPQIQ+rwm60M7oAfOvDMAsTxb+fM77WIaPIverDqcu5S84f+rvFyr8rxqoB686/4cPVnj9ztLHw29mJqaPAhH8Lz/db86qga/PGhnYD1WST28YgWru1AdRTz/db899PIPPBzBE720ie47ujymPbh/Kb0scLs8V1Q7PGIFqzwVMR48xp+UOhNGYTxfwW67CaDvvOeEI7tgc228uQNoPXrLBztd2TI9HRqTvLuVJbytoPm8YVMsOvi6irzweJY7/WpBvI5NKL040ym95ccmPAfj8rxJCZG9bsGYvJkpVzszp7G8wOxcu6/ZN7xXrTo7Q90YvGTtZjz/SgA8RWxVPL/hXjynl8O8ZzGjvHK0Uj0dRVI954QjvaqKfTxmUeS8Abf6O0RhV7tr+R098rnRPAju8DtoiiK95SCmvGV0pjwQMOW9wJPdPPutxDxYivi8NLKvPI3pKj3UDYE9Fg5cvQsyrTz+HEC9uuMmPMEaHbzJ4E8778YXvVDERb2cFBS9tsIsPLU7bT3+R/+8b55WPLhRaTzsgls9Nb2tuhNG4btlzSW9Y7cpvO1iGr0lh0a8u8BkvadJQj24f6k9J51CvbAPdbwCEHq8CicvvIKROr0ESbg7GMvYPE6OCLxS2sG7/WrBPOzbWj3uP1i9TVXKPPJg0rtp7h87TSqLPCmowLxrfdy8XbbwPG06WT33jEo9uxlkvcQN17tAmVy8h72yPEdMFLz4Ewo7BPs2va35eLynScI8WpV2PENW2bwQBSa9lSufu32+wTwl4MU8vohfvRyT07ylCIe8dHHPPPg+ST0Ooag8EsIiO9F7w7ylM0Y7dfgOPADaPLwX7hq7iG8xPDW9Lb1Q8oU98twTPYDUvTomwIQ8akcfvUhXkj3mK6Q8syXxvAMb+DwfMI87bsGYPGUbJ71GHtS8XbbwvFQ+P70f14+7Uq+CPSXgxbvHfFK9icgwPQsEbbwm60O9EpRiPDjTKb3uFJm7p/BCPazDuzxh+iy8Xj2wvBqrl71a7nU9guq5PYNDOb1X2Pk8raD5u+bSpLsMD2u7C9ktPVS6gDzyjhI9vl2gPNO0AT0/vJ68XQTyvMMCWbubYhU9rzK3vLhRaToSlOK6qYIAvQAovrsa1la8CEdwPKOkCT1jEKm8Y7epvOv+HLsoJII704ZBPXbVTDubjVQ8aRnfOvspBr2imYs8MDi2vPFVVDxSrwK9hac2PYverLyxGnO9nqNQvfVLD71UEP+8tDDvurN+8Lzkbqc6tsKsu5WvXTtDKxo72b03PdDshryvXfY81JE/vLYbLL2Fp7Y7JbUGPEQ2GLyagla7fAxDPaVhhrxu7Ne7wzAZPOxXHDx5nUe9s35wPHcOizx1fM26FTGePAsEbbzzQBE9zCQMPW6TWDygucy8zPZLPM2oSjzfmy48EF4lvUttDj3NL4q8WIp4PRoEFzxKFA89uKpou9H3BDvK6009a33cPLq15rzv8VY9AQX8O1gxebzjCqo7EeJjPaA1DrxoZ2C65tIkvS0iOjxln2W8o0sKPMPXGb3Ak908cxhQvR8wDzzN1gq8DnNovMZGFbwUJiA9moJWPBl9VzkVA148TrlHO/nFCL1f7y68xe2VPIROtzvCJRu88YMUvaUzRj1qR5+7e6jFPGyrHL3/SgC9GMtYPJcT27yqMX688YOUO32+QT18iAS9cdeUPFbN+zvlx6a83d6xOzQLL7sZJNi8mSnXOuqan7uqin09CievvPw0hLyuq/c866Udu4T1t7wBXnu7zQFKvE5gyDxhUyw8qzx8vIrTLr0Kq+26TgdJPWmVoDzOiIk8aDwhPVug9Lq6iie9iSEwvOKxqjwMiyy7E59gPepMnjth+iw9ntGQOyDijbw76SW9i96sO7qKJ7ybYhU8R/6Su+GmLLzsgtu7inovPRG3pLwZUpi7YzvoucrAjjwOSKm8uuOmvLbt67wKUu68XCc0vbd0Kz0LXWy8lHmgPAAoPjxRpAS99oHMvOlBoDprUh09teLtOxoEl7z0mRA89tpLvVQQ/zyjdkk9ZZ/lvHLikrw76SW82LI5vXyIBLzVnL06NyGrPPXPzTta7nW8FTEePSVcB73FGFU9SFcSPbzL4rtXrbo84lirvcd8Urw9/yG9+63EvPdhCz2rPPw8PPQjvbXibbuo+0C8oWtLPWVG5juL3qw71Zw9PMUY1Tk3yKu8WWq3vLnYKL25A+i8zH2LvMW/1bxDr1g8Cqvtu3pPRr0FrbU8vVKiO0LSGj1b+fM7Why2ux1FUjwhv0s89lYNPUbFVLzJ4M88t/hpvdpvNj0EzfY7gC29u0HyW7yv2Tc8dSPOvNhZurzrpR28jUIqPM0vijxyDdK8iBYyvZ0fkrxalXa9JeBFPO/GF71dBHK8X8FuPKnY/jpQmQY9S5jNPGBz7TrpQaA87/FWvUHyWzwCEPq78HiWOhfuGr0ltYY9I/iJPamCgLwLBO28jZupu38ivzuIbzG8Cfnuu0dMlLypKQG7BzxyvR5QULwCEHo8k8ehPUXoFjzPvka9MDi2vPsphjwjfMi854QjvcW/VbzO4Yg7Li04vL/h3jsaL9a5iG8xuybrwzz3YYu8Gw8VvVGkBD1UugA99MRPuCjLArzvxhc8XICzPFyrcr0gDU296h7eu8jV0TxNKos8lSufuqT9CD1oDmE8sqGyu2PiaLz6osY5YjBqPBAFJrwIlfG8PlihOBE74zzzQJG8r112vJPHobyrPPw7YawrPb5doLqtzrk7qHcCPVIoQzz5l0i81UM+vFd/eryaVxc9xA3XO/6YgbweJZG7W840PF0Ecj19ZUI8x1GTOtb1vDyDnLg8yxkOvOywGz0kqgg8fTqDvKlUQL3Bnlu992ELvZPHobybCZa82LK5vf2NgzwnnUK8YMzsPKOkiTxDr9g6la/duz3/IbusR/q8lmFcvFbN+zztCRu95nklPVKBwjwEJnY6V9j5PPK50bz6okY7R6UTPPnFiDwCafk8N8grO/gTCr1iiWm8AhB6vHHXlLyV3Z08vtZgPMDsXDsck9O7mdBXvRLCojzkbqe8XxpuvDSyLzu0MO87cxhQvd3eMbxtDxo9JKqIvB8CT72zrDC7s37wPHvWhbuXQZs8UlYDu7ef6rzsV5y8IkYLvUo/Tjz+R/88PrGgujSyrzxsBJy8P7yeO7f46byfKpA8cFDVPLygIzsdGpO77LCbvLSJ7rtgzOy7sA91O0hXkrwhO408XKvyvMUYVT2mPsQ8d+DKu9lkuLy+iF89xZSWPJFjpDwIlfE8bC9bPBE7Y7z/+f08W6B0PAc8crhmquO7RvOUPDybJLwlXAe9cuKSvMPXGbxK5s48sZY0O+4UmT1/Ij+8oNyOvPIH07tNKos8yTnPO2RpKDwRO+O7vl2gvKSvB7xGmpW7nD9TPZpXFzyXQRs9InHKurhR6bwb4VS8iiwuO3pPxrxeD3A8CfluO//OPr0MaOq8r112vAwP6zynHgM9T+cHPJuNVLzLRE07EmkjvWHX6rzBGh285G4nPe6Y17sCafm8//n9PJkpVzv9P4K7IWbMPCtlvTxHKVK8JNXHO/uCBblAFZ48xyPTvGaqY7wXlRs9EDDlPHcOizyNQiq9W3W1O7iq6LxwqdQ69MRPvSJGC7n3CIy8HOxSvSjLAryU0p87QJncvEoUjzsi7Qu9U4xAOwn5brzfm668Wu71uu002rw/Y588o6SJPFfY+Tyfg4+8u5WlPMDBnTzVnD08ljadu3sBxbzfm668n4OPO9VDvrz0mZC8kFimPNiyOT134Mo8vquhvDA4Njyjz0i7zVpJu1rudbwmksQ794xKuhN0ITz/zj68Vvu7unBQ1bv8NAS97FecOyxwOzs1ZC68AIG9PKLyCryvtvU8ntEQPBkkWD2xwfO7QfLbOhqIVTykVog7lSufvKOkiTwpqEA9/RFCvKxHejx3tYu74woqPMS0VzoMtuu8ViZ7PL8PH72+L2C81JE/vN3eMTwoywK9z5OHOx4lkTwGBrW8c5QRu4khMDyvBPc8nR8SvdlkuLw0si+9S8aNvCkBwLsXwFo7Od4nPbo8pryp2P68GfkYPKpfvjrsV5w6zuEIvbHB8zxnMSM9C9mtu1nj97zjYym8XFJzPAiVcTyNm6m7X5YvPJ8qED1l+OS8WTx3vGKJ6bt+F0G9jk2oPAR0dzwIR/A8umdlvNLUwjzI1dE7yuvNvBdnW7zdhTI9xkaVPCVcB70Mtus7G7aVPDchK7xuwRi8oDWOu/SZkLxOuUe8c5QRPLBo9Dz/+f07zS+KvNBFBr1n2CO8TKNLO4ZZNbym5US5HsyRvGi1YTwxnDO71vW8PM3WCr3E4he816e7O7QFML2asBa8jZspPSVcBzvjvCi9ZGmoPHV8zbyyobK830KvOgw9q7xzZtG7R6WTPMpnjzxj4mg8mrAWPS+GN7xoZ2C8tsKsOVMIAj1fli89Zc0lO00qCzz+R/87XKvyvLxy4zy52Cg9YjBqvW9F1zybjVS8mwmWvLvA5DymugU9DOQrPJWvXbvT38C8TrnHvLbt67sgiQ49e32GPPTETzv7goW7cKnUOoOcuLpG85S8CoCuO7ef6rkaqxe90tTCPJ8qkDvuuxk8FFFfPK9ddrtAbh08roC4PAnOrztV8D08jemquwR09ziL3iy7xkaVumVG5rygNQ69CfnuPGBzbTyE9Tc9Z9ijPK8yNzxgoa084woqu1F2RLwN76m7hrI0vf7xgLwaXRY6JmeFO68ytzrrpR29XbZwPYI4uzvkFai8qHcCPRCJ5DxKFI+7dHHPPE65xzxvnta8BPs2vWaq4zwrvjy8tDDvvEq7D7076SU9q+N8PAsyLTxb+XM9xZQWPP7ufzxsXZu6BEk4vGXNJbwBXvu8xA3XO8lcEbuuJzk8GEeavGnun7sMPSs9ITsNu1yr8roj+Ik8To6IvKjQgbwIwzG8wqlZvDfIK7xln2W8B+Pyu1HPw7sBjDs9Ba01PGSU57w/Yx867FecPFdUu7w2b6w7X5avvA8l57ypKQE9oGBNPeyC27vGytM828i1PP9KAD2/4V68eZ1HvDHqtDvR94Q6UwgCPLMlcbz+w0C8HwJPu/I1k7yZ/pe8aLXhPHYDDT28oKO8p2wEvdVDvrxh+qy8WDF5vJBYpjpaR3U8vgQhPNItwrsJoG88UaQEu3e1C7yagtY6HOzSOw9+5ryYTBk9q+N8POMKqrwoywI9DLZrPCN8SDxYivi8b3MXPf/OvruvBHc8M6exvA3vKbxz7RA8Fdieu4rTrrwFVDa8Vvu7PF0Ecjs6N6e8BzzyPP/Ovrv2rww9t59qvEoUDz3HUZO7UJkGPRigmbz/+X28qjH+u3jACbxlzaW7DA9rvFLawbwLBO2547yoO1t1NTr1pI68Vs37PAI+Ojx8s8O8xnHUvPg+yTwLBO26ybUQPfUoTTw76SU8i96sPKWMRbwUqt46pj7EPGX4ZL3ILtG8AV77vM0BSjzKZ488CByxvIWnNjyIFrI83CwzPN2FsjzHUZO8rzK3O+iPIbyGCzQ98NGVuxpdlrxhrKs8hQC2vFWXvjsCaXm8oRJMPHyIBLz+HMA8W/nzvHkZCb0pqMC87m0YPCu+vDsM5Ks8VnR8vG0Pmrt0yk48y3KNvKcegzwGMXS9xZQWPDYWrTxxAtQ7IWZMPU4Hybw89CO8/eaCPPMSUTxuk9i8WAY6vGfYozsQMGW8Li24vI+mJzxKFI88HwJPPFru9btRz8O6L9+2u29F1zwC5bq7RGHXvMtyjbr5bIm7V626uxsPlTv1KE29UB3FPMwkDDupggC8SQkRvH4XQT1cJ7Q8nvzPvKsRvTu9+SI8JbUGuiP4iTx460i99JkQPNF7Qz26Dma8u+4kvHO/0LyzfvA8EIlkPUPdmLpmUWS8uxnku8f4E72ruL27BzxyvKeXwz1plSC8gpG6vEQ2mLvtYho91Zy9vLvA5DtnXGK7sZY0uyu+PLwXlZu8GquXvE2uSb0ezBG8wn6au470KD1Abh28YMzsvPQdT7xKP867Xg/wO81aSb0IarK7SY1PO5EKJTsMi6y8cH4VvcXtlbwdGhM8xTsXPQvZLbxgzOw7Pf8hPRsPlbzDMJm8ZGmoPM1aSb0HEbO8PPQjvX5wwDwQXiW9wlDaO7SJ7jxFE9a8FTEePG5omTvPkwc8vtZgux9bzrmwD3W8U2EBPAVUNj0hlIw7comTPAEF/DvKwI68YKGtPJ78Tz1boHQ9sOS1vHiSSTlVG307HsyRPHEwFDxQmQY8CaBvvB0aE70PfuY8+neHvHOUET3ssBu7+tCGPJl3WDx4wAk9d1yMPOqanzwGBjW8ZialPB7MEby1O+07J0RDu4yQq7xpGV88ZXQmPc3WCruRCqU8Xbbwu+0JG7kXGVq8SY1PvKblxDv/oH68r7Z1OynWgDklh0a8E/hfPBCJZL31/Y08sD21vA9+Zjy6DmY82WQ4PAJp+TxHTJQ8JKoIvUBunbwgDc26BzxyvVUb/bz+w8A8Wu51u8guUbyHZLM8Iu0LvJqCVj3nhKO96kwevVDyBb3UDYG79zNLO7KhMj1IgtE83NOzO0f+krw89CM9z5OHuz+OXj2TxyE8wOzcPP91v7zUZgA8DyVnvILqOTzn3aI8j/+mO8xPyzt1UQ48+R4IvQnOrzt1I067QtKau9vINb1+7AE8sA/1uy7UOLzpQSC8dqoNPSnWgDsJoO+8ANo8vfDRlbwefpC89wgMPI1CKrrYsrm78mBSvFFLBb1Pa0a8s1MxPHbVzLw+WCG9kbyjvNt6tLwfMA+8HwLPvGO3qTyyobK8DcFpPInIsLwXGdq7nBSUPGdc4ryTx6G8T+eHPBxolDvIqhK8rqv3u1fY+Tz3M0s9qNCBO/GDlL2N6Sq9XKtyPFMIgrw0Cy+7Y7epPLJzcrz/+X28la/du8MC2bwTn+C5YSXsvDneJzz/SoC8H9ePvHMY0Lx0nw+9lSsfvS3Jujz/SgC94rEqvQwP67zd3rE83NOzPKvj/DyYmpo8h2SzvF8abjye0ZC8vSRivCKfijs/vJ48NAuvvFIoQzzFGFU9dtVMPa2g+TtpGd88Uv2DO3kZiTwA2rw79f2Nu1ugdDx0nw+8di7MvIrTrjz08g+8j6anvGH6LLxQ8oW8LBc8Pf0/Ajxl+OQ8SQkRPYrTrrzyNRM8GquXu9ItQjz1Sw87C9mtuxXYnrwDl7m87Y1ZO2ChrbyhQIy4EsIiPWpHHz0inwo7teJtPJ0fEroHPPK7fp4APV/B7rwwODa8L4Y3OiaSxLsBBfw7RI8XvP5H/zxVlz68n1VPvEBuHbwTzSA8fOEDvV49sDs2b6y8mf6XPMVm1jvjvCg8ETvjPEQ2GLxK5s47Q92YuxOfYLyod4K8EDDlPHAlFj1zGFC8pWGGPE65R7wBMzy8nJjSvLoO5rwwkbU7Eu3hvLOsMDyyobI6YHNtPKs8fLzXp7s6AV57PV49MLsVMR68+4KFPIkhMLxeaG87mXdYulyAMzzQRQY9ljadu3YDDby7GWS7phOFPEJ5mzq6tea6Eu1hPJjzmTz+R388di5MvJn+F7wi7Qs8K768PFnj9zu5MSi8Gl2WvJfomzxHd1O8vw8fvONjqbxuaBk980ARPSNRiTwLMi272Fk6vDGcs7z60Ia8vX1hOzvppbuKLK48jZspvZkpV7pWJns7G7YVPdPfwLyruL08FFHfu7ZprbwT+N84+1TFPGpHn7y9JOI8xe2Vu08SR7zs29o8/RFCPCbAhDzfQi89OpCmvL194boeJZE8kQqlvES6VjrzEtE7eGeKu2kZX71rfdw8D6wmu6Y+xLzJXJE8DnPovJrbVbvkFai8KX0Bvfr7RbuXbNq8Gw+VPRCJ5LyA1D28uQPoPLygo7xENpi8/RHCvEOv2DwRtyS9o0uKPNshNbvmeSU8IyPJvCedQjy7GWQ8Wkf1vGKJ6bztYho8vHLju5cT2zzKZw+88jWTvFb7uznYCzm8" + }, + { + "object": "embedding", + "index": 1, + "embedding": "eyfbu150UDkC6hQ9ip9oPG7jWDw3AOm8DQlcvFiY5Lt3Z6W8BLPPOV0uOz3FlQk8h5AYvH6Aobv0z/E8nOQRvHI8H7rQA+s8F6X9vPplyDzuZ1u8T2cTvAUeoDt0v0Q9/xx5vOhqlT1EgXu8zfQavTK0CDxRxX08v3MIPAY29bzIpFm8bGAzvQkkazxCciu8mjyxvIK0rDx6mzC7Eqg3O8H2rTz9vo482RNiPUYRB7xaQMU80h8hu8kPqrtyPB+8dvxUvfplSD21bJY8oQ8YPZbCEDvxegw9bTJzvYNlEj0h2q+9mw5xPQ5P8TyWwpA7rmvvO2Go27xw2tO6luNqO2pEfTztTwa7KnbRvAbw37vkEU89uKAhPGfvF7u6I8c8DPGGvB1gjzxU2K48+oqDPLCo/zsskoc8PUclvXCUvjzOpQC9qxaKO1iY5LyT9XS9ZNzmvI74Lr03azk93CYTvFJVCTzd+FK8lwgmvcMzPr00q4O9k46FvEx5HbyIqO083xSJvC7PFzy/lOK7HPW+PF2ikDxeAHu9QnIrvSz59rl/UmG8ZNzmu2b4nD3V31Y5aXK9O/2+jrxljUw8y9jkPGuvTTxX5/48u44XPXFFpDwAiEm8lcuVvX6h+zwe7Lm8SUUSPHmkNTu9Eb08cP8OvYgcw7xU2C49Wm4FPeV8H72AA8c7eH/6vBI0Yj3L2GQ8/0G0PHg5ZTvHjAS9fNhAPcE8wzws2By6RWAhvWTcZjz+1uM8H1eKvHdnJT0TWR29KcVrPdu7wrvMQzW9VhW/Ozo09LvFtuM8OlmvPO5GAT3eHY68zTqwvIhiWLs1w1i9sGJqPaurOb0s2Jy8Z++XOwAU9Lggb988vnyNvVfGpLypKBS8IouVO60NBb26r/G6w+0ovbVslrz+kE68MQOjOxdf6DvoRdo8Z4RHPCvhIT3e7009P4Q1PQ0JXDyD8Ty8/ZnTuhu4Lj3X1lG9sVnlvMxDNb3wySY9cUWkPNZKJ73qyP+8rS7fPNhBojwpxes8kt0fPM7rlbwYEE68zoBFvdrExzsMzEu9BflkvF0uu7zNFfW8UyfJPPSJ3LrEBf68+6JYvef/xDpAe7C8f5h2vPqKA7xUTAS9eDllPVK8eL0+GeW7654gPQuGNr3/+x69YajbPAehRTyc5BE8pfQIPMGwGL2QoA87iGJYPYXoN7s4sc69f1JhPdYEkjxgkIa6uxpCvHtMljtYvR88uCzMPBeEo7wm1/U8GBDOvBkHybwyG3i7aeaSvQzMyzy3e2a9xZUJvVSSmTu7SII8x4yEPKAYHTxUTIQ8lcsVO5x5QT3VDRe963llO4K0rLqI1i07DX0xvQv6CznrniA9nL9WPTvl2Tw6WS+8NcPYvEL+VbzZfrK9NDcuO4wBNL0jXVW980PHvNZKJz1Oti09StG8vIZTiDwu8PE8zP0fO9340juv1j890vFgvMFqAz2kHui7PNxUPQehxTzjGlQ9vcunPL+U4jyfrUw8R+NGPHQF2jtSdmO8mYtLvF50ULyT1Bo9ONaJPC1kx7woznC83xQJvUdv8byEXA29keaku6Qe6Ly+fA29kKAPOxLuzLxjxJG9JnCGur58jTws2Jy8CkmmO3pVm7uwqH87Eu7Mu/SJXL0IUis9MFI9vGnmEr1Oti09Z+8XvH1DkbwcaZS8NDcuvT0BkLyPNT89Haakuza607wv5+w81KLGO80VdT3MiUq8J4hbPHHRzrwr4aG8PSJqvJOOBT3t2zC8eBgLvXchkLymOp66y9jkPDdG/jw2ulO983GHPDvl2Tt+Ooy9NwDpOzZ0Pr3xegw7bhGZvEpd57s5YjS9Gk1evIbfMjxBwcW8NnQ+PMlVPzxR6ji9M8zdPImHk7wQsby8u0gCPXtMFr22YxE9Wm4FPaXPzbygGJ093bK9OuYtBTxyXfk8iYeTvNH65byk/Q29QO+FvKbGyLxCcqs9nL/WvPtcQ72XTjs8kt2fuhaNKDxqRH08KX9WPbmXnDtXDDo96GoVPVw3QL0eeGS8ayOjvAIL7zywQZC9at0NvUMjET1Q8707eTDgvIio7Tv60Jg87kYBOw50LLx7BgE96qclPUXsSz0nQkY5aDUtvQF/RD1bZQC73fjSPHgYCzyPNT+9q315vbMvhjsvodc8tEdbPGcQ8jz8U768cYs5PIwBtL38x5M9PtPPvIex8jzfFIk9vsIivLsaQj2/uZ072y8YvSV5C7uoA9k8JA67PO5nWzvS8eC8av7nuxSWrbybpwE9f5h2vG3sXTmoA1k9sjiLvTBSPbxc8Sq9UpuePB+dHz2/cwg9BWS1vCrqJr2M3Pg86LAqPS/GEj3oRdq8GiyEvACISbuiJ+28FFAYuzBSvTzwDzy8K5uMvE5wmDpd6CW6dkJqPGlyvTwF2Iq9f1JhPSHarzwDdr88JXkLu4ADxzx5pDW7zqUAvdAoJj24wXs8doj/PH46jD2/2vc893fSuyxtTL0YnPg7IWbaPOiwqrxLDk27ZxDyPBpymbwW0z08M/odPTufRL1AVvU849Q+vBGDfD3JDyq6Z6kCPL9OzTz0rpe8FtM9vaDqXLx+W2Y7jHWJPGXT4TwJ3lW9M4bIPPCDkTwoZwE9XH1VOmksqLxLPI08cNrTvCyz4bz+Srm8kiO1vDP6nbvIpNk8MrSIvPe95zoTWR29SYsnPYC9MT2F6De93qm4PCbX9bqqhv47yky6PENE67x/DEw8JdYAvUdvcbywh6W8//ueO8fSmTyjTCi9yky6O/qr3TzvGEE8wqcTPeDmSDyuJVo8ip/ou1HqOLxOtq28y5LPuxk1Cb0Ddr+7c+2EvKQeaL1SVQk8XS47PGTcZjwdpiQ8uFqMO0QaDD1XxqS8mLmLuuSFJDz1xmy8PvgKvJAHf7yC+kE8VapuvetYC7tHCAI8oidtPOiwqjyoSW68xCo5vfzobTzz2HY88/0xPNkT4rty9om8RexLu9SiRrsVaG081gSSO5IjtTsOLpc72sTHPGCQBj0QJRI9BCclPI1sBDzCyO07QHuwvOYthTz4tGK5QHuwvWfvFz2CQNc8PviKPO8YwTuQoA89fjoMPBnBs7zGZ8m8uiPHvMdeRLx+gKE8keaku0wziDzZWfe8I4KQPJ0qpzs4sc47dyEQPEQaDDzVmcE8//uePJcIJjztTwa9ogaTOftcwztU2K48opvCuyz5drzqM1C7iYcTvfDJJjxXxiQ9o0wovO1PBrwqvGa7dSoVPbI4izvnuS88zzGrPH3POzzHXkQ9PSJqOXCUPryW4+o8ELE8PNZKp7z+Sjm8foChPPIGtzyTaUq8JA47vBiceDw3a7m6jWyEOmksKDwH59q5GMo4veALBL0SqDe7IaxvvBD3Ubxn7xc9+dkdPSBOBTxHCAI8mYvLOydCxjw5HB88zTqwvJXs77w9AZA9CxvmvIeQGL2rffm8JXkLPKqGfjyoSe464d1DPPd3UrpO/EK8qxYKvUuCojwhZlq8EPfRPKaAs7xKF9K85i0FvEYRhzyPNT88m6cBvdSiRjxnqQI9uOY2vcBFSLx4OeW7BxUbPCz59rt+W2Y7SWZsPGzUCLzE5KM7sIclvIdr3buoSW47AK0EPImHE7wgToU8IdovO7FZ5bxbzO+8uMF7PGayB7z6ioO8zzErPEcIgrxSm568FJYtvNf7jDyrffm8KaQRPcoGpTwleQu8EWKiPHPthLz44qI8pEOjvWh7QjzpPNU8lcuVPHCUPr3n/8Q8bNQIu0WmNr1Erzs95VfkPCeIW7vT0Aa7656gudH65bxw/w49ZrKHPHsn27sIUiu8mEU2vdUNF7wBf8Q809CGPFtlgDo1fcO85i2FPEcIAjwL+os653OavOu1AL2EN9K8H52fPKzoybuMdYk8T2cTO8lVPzyK5X07iNYtvD74ijzT0IY8RIF7vLLENbyZi8s8KwJ8vAne1TvGZ8k71gSSumJZwTybp4G8656gPG8IFL27SAI9arjSvKVbeDxljcy83fjSuxu4Lr2DZRK9G0TZvLFZ5bxR6ji8NPEYPbI4izyAvTE9riVaPCCUGrw0Ny48f1LhuzIb+DolBTY8UH9ou/4EpLyAvTG9CFIrvCBOBTlkIvy8WJhkvHIXZLkf47Q8GQfJvBpNXr1pcr07c8jJO2nmkrxOcJi8sy8GuzjWibu2Pta8WQO1PFPhs7z7XEO8pEMjvb9OzTz4bs08EWKiu0YyYbzeHQ695D+PPKVbeDzvGEG9B6HFO0uCojws+Xa7JQW2OpRgRbxjCqc8Sw7NPDTxmLwjXVW8sRNQvFPhszzM/Z88rVMavZPUGj06WS+8JpHgO3etursdx369uZccvKplJDws+Xa8fzGHPB1gj7yqZaQ887ecPBNZHbzoi2+7NwDpPMxDtbzfWh49H+O0PO+kaztI2kE8/xz5PImHE73fNWO8T60ovIPxPDvR2Yu8XH3VvMcYr7wfnR+9fUORPIdr3Tyn6wO9nkL8vM2uhTzGIbS66u26vE2/MrxFYKE8iwo5vLSNcLy+wiK9GTUJPK10dLzrniC8qkBpvPxTPrwzQLO8illTvFi9H7yMATS7ayOjO14Ae7z19Cy87dswPKbGyDzujJa93EdtPdsB2LYT5Ue9RhEHPKurubxm+By9+mVIvIy7HrxZj987yOpuvUdv8TvgCwS8TDMIO9xsqLsL+gs8BWS1PFRMBD1yXXm86GoVvK+QqjxRXg46TZHyu2ayhzx7TJa8uKAhPLyFkjsV3MI7niGiPGNQvDxgkIa887ccPUmLJ7yZsIa8KDnBvHgYi7yMR0m82ukCvRuK7junUvO8aeYSPXtt8LqXCKa84kgUPd5jIzxlRze93xQJPNNcMT2v1j889GiCPKRkfbxz7YQ8b06pO8cYL7xg9/U8yQ+qPGlyvbzfNWO8vZ3nPBGD/DtB5gC7yKRZPPTPcbz6q928bleuPI74rrzVDRe9CQORvMmb1Dzv0qs8DBLhu4dr3bta1fQ8aeYSvRD3UTugpMe8CxvmPP9BNDzHjAQ742DpOzXD2Dz4bk28c1T0Onxka7zEBf48uiNHvGayBz1pcj29NcPYvDnu3jz5kwg9WkBFvL58jTx/mHY8wTzDPDZ0Pru/uZ08PQGQPOFRmby4oKE8JktLPIx1iTsppBG9dyGQvHfzT7wzhki44KAzPSOCkDzv0iu8lGBFO2VHNzyKxKM72EEiPYtQzryT9fQ8UDnTPEx5nTzuZ9s8QO8FvG8IlDx7J9s6MUk4O9k4nbx7TBa7G7iuvCzYHDocr6k8/7UJPY2ymTwVIlg8KjC8OvSuFz2iJ+28cCBpvE0qAzw41ok7sgrLvPjiojyG37K6lwimvKcxGTwRHI28y5LPO/mTiDx82MC5VJIZPWkH7TwPusG8YhOsvH1DkbzUx4E8TQXIvO+ka7zKwI+8w+2oPNLxYLzxegy9zEM1PDo0dDxIINc8FdxCO46E2TwPRmw9+ooDvMmb1LwBf0S8CQMRvEXsS7zPvdU80qvLPLfvO7wbuK68iBzDO0cpXL2WndU7dXCqvOTLubytLl88LokCvZj/IDw0q4M8G7guvNkTYrq5UQe7vcunvIrEI7xuERm9RexLvAdbsDwLQCE7uVEHPYjWrbuM3Pi8g2WSO3R5L7x4XiC8vKZsu9Sixros+fa8UH/ouxxpFL3wyaa72sRHu2YZ9zuiJ2274o4pOjkcnzyagka7za4FvYrEozwCMCo7cJQ+vfqKAzzJ4em8fNhAPUB7sLylz80833v4vOU2ir1ty4M8UV4OPXQF2jyu30S9EjRivBVo7TwXX2g70ANrvEJyq7wQJRK99jE9O7c10brUxwE9SUUSPS4VLbzBsJg7FHHyPMz9n7latJo8bleuvBpN3jsF+WS8Ye7wO4nNKL0TWZ08iRM+vOn2v7sB8xm9jY3ePJ/zYbkLG+a7ZvicvGxgM73L2OS761iLPKcxmTrX+ww8J0JGu1MnyTtJZuw7pIm4PJbCED29V1K9PFCqPLBBkLxhYka8hXTiPEB7MDzrniA7h5CYvIR9ZzzARcg7TZHyu4sKOb1in9Y7nL9WO6gD2TxSduO8UaQjPQO81Lxw/w69KwL8O4FJ3D2XTju8SE6XPGDWGz0K1VC8YhMsvObCtDyndy49BCclu68cVbxemYu8sGLqOksOzTzj1L47ISBFvLly4Ttk3Oa8RhGHNwzxBj0v5+y7ogaTPA+6QbxiE6w8ubj2PDixzrstZEe9jbKZPPd30rwqMDw8TQXIPFurlTxx0c68jLsePfSJ3LuXTru8yeHpu6Ewcjx5D4a8BvBfvN8Uibs9R6W8lsIQvaEw8rvVUyw8SJQsPebCNDwu8PE8GMo4OxAlkjwJmMA8KaQRvdYlbDwNNxy9ouHXPDffDrxwZv46AK0EPJqCRrpWz6k8/0E0POAs3rxmsoe7zTqwO5mLyzyP7ym7wTzDvFB/aLx5D4a7doj/O67fxDtsO/g7uq9xvMWViTtC/tU7PhnlvIEogjxxRSQ9SJSsPIJA1zyBKAI9ockCPYC9MbxBTXC83xSJvPFVUb1n75c8uiNHOxdf6Drt27A8/FM+vJOvXz3a6QI8UaQjuvqKgzyOhNm831oevF+xYLxjCic8sn6gPDdrOTs3Rv66cP+Ou5785rycBew8J0JGPJOOBbw9Imq8q335O3MOX7xemQs8PtNPPE1L3Tx5dnU4A+EPPLrdsTzfFIm7LJIHPB4yz7zbAdi8FWjtu1h3Cj0oznA8kv55PKgDWbxIINc8xdsePa8cVbzmlHQ8IJSavAgMlrx4XiA8z3dAu2PEET3xm+a75//EvK2Zr7xbqxU8zP2fvOSFJD1xRSS7k44FvPzHkzz5+ne8+tAYvd5jIz1GMuE8yxSAO3KCNDyRuOS8wzO+vObCNDwzQLO7isQjva1TGrz6ioM79GgCPF66Zbx1KpW8qW6pu4RcDTzcJhO9SJQsO5G45LsAiMm8lRErvJqCxjzQbju7w3nTuTclpDywqP88ysCPvAF/xLxfa0u88cChPBjKODyaPLE8k69fvGFiRrvuRgG9ATmvvJEsOr21+EC9KX/WOrmXnDwDAuo8yky6PI1sBDvztxy8PviKPKInbbzbdS276mGQO2Kf1rwn/DC8ZrIHPBRxcj0z+h264d1DPdG0ULxvTqm5bDt4vToTmjuGJcg7tmMRO9YEEr3oJAC9THmdPKn607vcJhM8Zj6yvHR5r7ywYmq83fjSO5mLyzshIEU8EWKiuu9eVjw75dk7fzGHvNl+sjwJJOs8YllBPAtheztz7QQ92lDyvDEDozzEKrk7KnZRvG8pbjsdYI+7yky6OfWAVzzjYGk7NX3DOzrNhDyeIaI8joTZvFcMOryYRba8G7iuu893QDw9RyW7za6FvDUJ7rva6YK9D7rBPD1o/zxCLJa65TaKvHsGAT2g6ly8+tCYu+wqy7xeAHu8vZ1nPBv+QzwfVwo8CMYAvM+91TzKTDq8Ueo4u2uvzTsBf8Q8p+uDvKofDz12tj+8wP+yOlkDtTwYyji6ZdPhPGv14rwqdtE8YPf1vLIKy7yFLs28ouFXvO1PBj15pDU83xQJPdfWUTz8x5O64kgUPBQKA72eIaK6A3a/OyzYnLoYnPg4XMNqPdxsqLsKSaY7pfSIvBoshLupKJS8G0TZOu/SqzzFcE47cvaJPA19Mb14dQC8sVllvJmwhjycBey8cvaJOmSWUbvRtFC8WtX0O2r+57twIGm8yeFpvFuG2rzCyO08PUelPK5rbzouFS29uCxMPQAUdDqtma88wqeTu5gge7zH8/O7l067PJdOO7uKxCO8/xx5vKt9+TztTwa8OhOaO+Q/Dzw33w49CZhAvSubjDydttG8IdovPIADR7stHrI7ATmvvOAs3rzL2OQ69K4XvNccZ7zlV2S8c+0EPfNDxzydKqc6LLPhO8YhtDyJhxM9H1eKOaNMKLtOcBg9HPU+PTsrbzvT0Ia8BG26PB2mpDp7TJa8wP8yPVvM77t0ea86eTBgvFurFT1C/tW7CkkmvKOSPT2aPDG9lGDFPAhSq7u5UYc8l5TQPFh3ijz9vg68lGBFO4/vKTxViZS7eQ8GPTNAs7xmsoe8o0yoPJfaZbwlvyA8IazvO0XsS717TJY8flvmOgHFWbyWnVW8mdFgvJbCkDynDF68" + } + ], + "model": "text-embedding-3-small", + "usage": { + "prompt_tokens": 9, + "total_tokens": 9 + } + } + """; + + using VerbatimHttpHandler handler = new(Input, Output); + using HttpClient httpClient = new(handler); + using IEmbeddingGenerator> generator = new OpenAIClient(new ApiKeyCredential("apikey"), new OpenAIClientOptions + { + Transport = new HttpClientPipelineTransport(httpClient), + }).AsEmbeddingGenerator("text-embedding-3-small"); + + var response = await generator.GenerateAsync([ + "hello, world!", + "red, white, blue", + ]); + Assert.NotNull(response); + Assert.Equal(2, response.Count); + + Assert.NotNull(response.Usage); + Assert.Equal(9, response.Usage.InputTokenCount); + Assert.Equal(9, response.Usage.TotalTokenCount); + + foreach (Embedding e in response) + { + Assert.Equal("text-embedding-3-small", e.ModelId); + Assert.NotNull(e.CreatedAt); + Assert.Equal(1536, e.Vector.Length); + Assert.Contains(e.Vector.ToArray(), f => !f.Equals(0)); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientBuilderTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientBuilderTest.cs new file mode 100644 index 00000000000..ba1c85d700a --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientBuilderTest.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientBuilderTest +{ + [Fact] + public void PassesServiceProviderToFactories() + { + var expectedServiceProvider = new ServiceCollection().BuildServiceProvider(); + using TestChatClient expectedResult = new(); + var builder = new ChatClientBuilder(expectedServiceProvider); + + builder.Use((serviceProvider, innerClient) => + { + Assert.Same(expectedServiceProvider, serviceProvider); + return expectedResult; + }); + + using TestChatClient innerClient = new(); + Assert.Equal(expectedResult, builder.Use(innerClient: innerClient)); + } + + [Fact] + public void BuildsPipelineInOrderAdded() + { + // Arrange + using TestChatClient expectedInnerClient = new(); + var builder = new ChatClientBuilder(); + + builder.Use(next => new InnerClientCapturingChatClient("First", next)); + builder.Use(next => new InnerClientCapturingChatClient("Second", next)); + builder.Use(next => new InnerClientCapturingChatClient("Third", next)); + + // Act + var first = (InnerClientCapturingChatClient)builder.Use(expectedInnerClient); + + // Assert + Assert.Equal("First", first.Name); + var second = (InnerClientCapturingChatClient)first.InnerClient; + Assert.Equal("Second", second.Name); + var third = (InnerClientCapturingChatClient)second.InnerClient; + Assert.Equal("Third", third.Name); + Assert.Same(expectedInnerClient, third.InnerClient); + } + + [Fact] + public void DoesNotAcceptNullInnerService() + { + Assert.Throws(() => new ChatClientBuilder().Use((IChatClient)null!)); + } + + [Fact] + public void DoesNotAcceptNullFactories() + { + ChatClientBuilder builder = new(); + Assert.Throws(() => builder.Use((Func)null!)); + Assert.Throws(() => builder.Use((Func)null!)); + } + + [Fact] + public void DoesNotAllowFactoriesToReturnNull() + { + ChatClientBuilder builder = new(); + builder.Use(_ => null!); + var ex = Assert.Throws(() => builder.Use(new TestChatClient())); + Assert.Contains("entry at index 0", ex.Message); + } + + private sealed class InnerClientCapturingChatClient(string name, IChatClient innerClient) : DelegatingChatClient(innerClient) + { +#pragma warning disable S3604 // False positive: Member initializer values should not be redundant + public string Name { get; } = name; +#pragma warning restore S3604 + public new IChatClient InnerClient => base.InnerClient; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs new file mode 100644 index 00000000000..0e776b4fee5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -0,0 +1,256 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Text.Json; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ChatClientStructuredOutputExtensionsTests +{ + [Fact] + public async Task SuccessUsage() + { + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]) + { + CompletionId = "test", + CreatedAt = DateTimeOffset.UtcNow, + ModelId = "someModel", + RawRepresentation = new object(), + Usage = new(), + }; + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + var responseFormat = Assert.IsType(options!.ResponseFormat); + Assert.Null(responseFormat.Schema); + Assert.Null(responseFormat.SchemaName); + Assert.Null(responseFormat.SchemaDescription); + + // The inner client receives a trailing "system" message with the schema instruction + Assert.Collection(messages, + message => Assert.Equal("Hello", message.Text), + message => + { + Assert.Equal(ChatRole.System, message.Role); + Assert.Contains("Respond with a JSON value", message.Text); + Assert.Contains("https://json-schema.org/draft/2020-12/schema", message.Text); + foreach (Species v in Enum.GetValues(typeof(Species))) + { + Assert.Contains(v.ToString(), message.Text); // All enum values are described as strings + } + }); + + return Task.FromResult(expectedCompletion); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + Assert.Equal(expectedCompletion.CompletionId, response.CompletionId); + Assert.Equal(expectedCompletion.CreatedAt, response.CreatedAt); + Assert.Equal(expectedCompletion.ModelId, response.ModelId); + Assert.Same(expectedCompletion.RawRepresentation, response.RawRepresentation); + Assert.Same(expectedCompletion.Usage, response.Usage); + + // TryGetResult returns the same value + Assert.True(response.TryGetResult(out var tryGetResultOutput)); + Assert.Same(response.Result, tryGetResultOutput); + + // Doesn't mutate history (or at least, reverts any changes) + Assert.Equal("Hello", Assert.Single(chatHistory).Text); + } + + [Fact] + public async Task FailureUsage_InvalidJson() + { + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, "This is not valid JSON")]); + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion), + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + var ex = Assert.Throws(() => response.Result); + Assert.Contains("invalid", ex.Message); + + Assert.False(response.TryGetResult(out var tryGetResult)); + Assert.Null(tryGetResult); + } + + [Fact] + public async Task FailureUsage_NullJson() + { + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, "null")]); + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion), + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + var ex = Assert.Throws(() => response.Result); + Assert.Equal("The deserialized response is null", ex.Message); + + Assert.False(response.TryGetResult(out var tryGetResult)); + Assert.Null(tryGetResult); + } + + [Fact] + public async Task FailureUsage_NoJsonInResponse() + { + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, [new ImageContent("https://example.com")])]); + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion), + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + var ex = Assert.Throws(() => response.Result); + Assert.Equal("The response did not contain text to be deserialized", ex.Message); + + Assert.False(response.TryGetResult(out var tryGetResult)); + Assert.Null(tryGetResult); + } + + [Fact] + public async Task CanUseNativeStructuredOutput() + { + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + var responseFormat = Assert.IsType(options!.ResponseFormat); + Assert.Equal(nameof(Animal), responseFormat.SchemaName); + Assert.Equal("Some test description", responseFormat.SchemaDescription); + Assert.Contains("https://json-schema.org/draft/2020-12/schema", responseFormat.Schema); + foreach (Species v in Enum.GetValues(typeof(Species))) + { + Assert.Contains(v.ToString(), responseFormat.Schema); // All enum values are described as strings + } + + // The chat history isn't mutated any further, since native structured output is used instead of a prompt + Assert.Equal("Hello", Assert.Single(messages).Text); + + return Task.FromResult(expectedCompletion); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory, useNativeJsonSchema: true); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + + // TryGetResult returns the same value + Assert.True(response.TryGetResult(out var tryGetResultOutput)); + Assert.Same(response.Result, tryGetResultOutput); + + // History remains unmutated + Assert.Equal("Hello", Assert.Single(chatHistory).Text); + } + + [Fact] + public async Task CanSpecifyCustomJsonSerializationOptions() + { + var jso = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + }; + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult, jso))]); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Collection(messages, + message => Assert.Equal("Hello", message.Text), + message => + { + Assert.Equal(ChatRole.System, message.Role); + Assert.Contains("Respond with a JSON value", message.Text); + Assert.Contains("https://json-schema.org/draft/2020-12/schema", message.Text); + Assert.DoesNotContain(nameof(Animal.FullName), message.Text); // The JSO uses snake_case + Assert.Contains("full_name", message.Text); // The JSO uses snake_case + Assert.DoesNotContain(nameof(Species.Tiger), message.Text); // The JSO doesn't use enum-to-string conversion + }); + + return Task.FromResult(expectedCompletion); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory, jso); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + } + + [Fact] + public async Task HandlesBackendReturningMultipleObjects() + { + // A very common failure mode for GPT 3.5 Turbo is that instead of returning a single top-level JSON object, + // it may return multiple, particularly when function calling is involved. + // See https://community.openai.com/t/2-json-objects-returned-when-using-function-calling-and-json-mode/574348 + // Fortunately we can work around this without breaking any cases of valid output. + + var expectedResult = new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }; + var resultDuplicatedJson = JsonSerializer.Serialize(expectedResult) + Environment.NewLine + JsonSerializer.Serialize(expectedResult); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + return Task.FromResult(new ChatCompletion([new ChatMessage(ChatRole.Assistant, resultDuplicatedJson)])); + }, + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory); + + // The completion contains the deserialized result and other completion properties + Assert.Equal(1, response.Result.Id); + Assert.Equal("Tigger", response.Result.FullName); + Assert.Equal(Species.Tiger, response.Result.Species); + } + + [Description("Some test description")] + private class Animal + { + public int Id { get; set; } + public string? FullName { get; set; } + public Species Species { get; set; } + } + + private enum Species + { + Bear, + Tiger, + Walrus, + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs new file mode 100644 index 00000000000..a27761c99ec --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ConfigureOptionsChatClientTests.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class ConfigureOptionsChatClientTests +{ + [Fact] + public void ConfigureOptionsChatClient_InvalidArgs_Throws() + { + Assert.Throws("innerClient", () => new ConfigureOptionsChatClient(null!, _ => new ChatOptions())); + Assert.Throws("configureOptions", () => new ConfigureOptionsChatClient(new TestChatClient(), null!)); + } + + [Fact] + public void UseChatOptions_InvalidArgs_Throws() + { + var builder = new ChatClientBuilder(); + Assert.Throws("configureOptions", () => builder.UseChatOptions(null!)); + } + + [Fact] + public async Task ConfigureOptions_ReturnedInstancePassedToNextClient() + { + ChatOptions providedOptions = new(); + ChatOptions returnedOptions = new(); + ChatCompletion expectedCompletion = new(Array.Empty()); + var expectedUpdates = Enumerable.Range(0, 3).Select(i => new StreamingChatCompletionUpdate()).ToArray(); + using CancellationTokenSource cts = new(); + + using IChatClient innerClient = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Same(returnedOptions, options); + Assert.Equal(cts.Token, cancellationToken); + return Task.FromResult(expectedCompletion); + }, + + CompleteStreamingAsyncCallback = (messages, options, cancellationToken) => + { + Assert.Same(returnedOptions, options); + Assert.Equal(cts.Token, cancellationToken); + return YieldUpdates(expectedUpdates); + }, + }; + + using var client = new ChatClientBuilder() + .UseChatOptions(options => + { + Assert.Same(providedOptions, options); + return returnedOptions; + }) + .Use(innerClient); + + var completion = await client.CompleteAsync(Array.Empty(), providedOptions, cts.Token); + Assert.Same(expectedCompletion, completion); + + int i = 0; + await using var e = client.CompleteStreamingAsync(Array.Empty(), providedOptions, cts.Token).GetAsyncEnumerator(); + while (i < expectedUpdates.Length) + { + Assert.True(await e.MoveNextAsync()); + Assert.Same(expectedUpdates[i++], e.Current); + } + + Assert.False(await e.MoveNextAsync()); + + static async IAsyncEnumerable YieldUpdates(StreamingChatCompletionUpdate[] updates) + { + foreach (var update in updates) + { + await Task.Yield(); + yield return update; + } + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/CustomAIContentJsonContext.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/CustomAIContentJsonContext.cs new file mode 100644 index 00000000000..650a8fdd162 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/CustomAIContentJsonContext.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +[JsonSerializable(typeof(DistributedCachingChatClientTest.CustomAIContent1))] +[JsonSerializable(typeof(DistributedCachingChatClientTest.CustomAIContent2))] +internal sealed partial class CustomAIContentJsonContext : JsonSerializerContext; diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs new file mode 100644 index 00000000000..9bbfbea98c3 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DependencyInjectionPatterns.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DependencyInjectionPatterns +{ + private IServiceCollection ServiceCollection { get; } = new ServiceCollection(); + + [Fact] + public void CanRegisterScopedUsingGenericType() + { + // Arrange/Act + ServiceCollection.AddChatClient(builder => builder + .UseScopedMiddleware() + .Use(new TestChatClient())); + + // Assert + var services = ServiceCollection.BuildServiceProvider(); + using var scope1 = services.CreateScope(); + using var scope2 = services.CreateScope(); + + var instance1 = scope1.ServiceProvider.GetRequiredService(); + var instance1Copy = scope1.ServiceProvider.GetRequiredService(); + var instance2 = scope2.ServiceProvider.GetRequiredService(); + + // Each scope gets a distinct outer *AND* inner client + var outer1 = Assert.IsType(instance1); + var outer2 = Assert.IsType(instance2); + var inner1 = Assert.IsType(((ScopedChatClient)instance1).InnerClient); + var inner2 = Assert.IsType(((ScopedChatClient)instance2).InnerClient); + + Assert.NotSame(outer1.Services, outer2.Services); + Assert.NotSame(instance1, instance2); + Assert.NotSame(inner1, inner2); + Assert.Same(instance1, instance1Copy); // From the same scope + } + + [Fact] + public void CanRegisterScopedUsingFactory() + { + // Arrange/Act + ServiceCollection.AddChatClient(builder => + { + builder.UseScopedMiddleware(); + return builder.Use(new TestChatClient { Services = builder.Services }); + }); + + // Assert + var services = ServiceCollection.BuildServiceProvider(); + using var scope1 = services.CreateScope(); + using var scope2 = services.CreateScope(); + + var instance1 = scope1.ServiceProvider.GetRequiredService(); + var instance2 = scope2.ServiceProvider.GetRequiredService(); + + // Each scope gets a distinct outer *AND* inner client + var outer1 = Assert.IsType(instance1); + var outer2 = Assert.IsType(instance2); + var inner1 = Assert.IsType(((ScopedChatClient)instance1).InnerClient); + var inner2 = Assert.IsType(((ScopedChatClient)instance2).InnerClient); + + Assert.Same(outer1.Services, inner1.Services); + Assert.Same(outer2.Services, inner2.Services); + Assert.NotSame(outer1.Services, outer2.Services); + } + + [Fact] + public void CanRegisterScopedUsingSharedInstance() + { + // Arrange/Act + using var singleton = new TestChatClient(); + ServiceCollection.AddChatClient(builder => + { + builder.UseScopedMiddleware(); + return builder.Use(singleton); + }); + + // Assert + var services = ServiceCollection.BuildServiceProvider(); + using var scope1 = services.CreateScope(); + using var scope2 = services.CreateScope(); + var instance1 = scope1.ServiceProvider.GetRequiredService(); + var instance2 = scope2.ServiceProvider.GetRequiredService(); + + // Each scope gets a distinct outer instance, but the same inner client + Assert.IsType(instance1); + Assert.IsType(instance2); + Assert.Same(singleton, ((ScopedChatClient)instance1).InnerClient); + Assert.Same(singleton, ((ScopedChatClient)instance2).InnerClient); + } + + public class ScopedChatClient(IServiceProvider services, IChatClient inner) : DelegatingChatClient(inner) + { + public new IChatClient InnerClient => base.InnerClient; + public IServiceProvider Services => services; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs new file mode 100644 index 00000000000..35ced372eb2 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/DistributedCachingChatClientTest.cs @@ -0,0 +1,703 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DistributedCachingChatClientTest +{ + private readonly TestInMemoryCacheStorage _storage = new(); + + [Fact] + public async Task CachesSuccessResultsAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + var expectedCompletion = new ChatCompletion([ + new(new ChatRole("fakeRole"), "This is some content") + { + AdditionalProperties = new() { ["a"] = "b" }, + Contents = [new FunctionCallContent("someCallId", "functionName", new Dictionary + { + ["arg1"] = "value1", + ["arg2"] = 123, + ["arg3"] = 123.4, + ["arg4"] = true, + ["arg5"] = false, + ["arg6"] = null + })] + } + ]) + { + CompletionId = "someId", + Usage = new() + { + InputTokenCount = 123, + OutputTokenCount = 456, + TotalTokenCount = 99999, + }, + CreatedAt = DateTimeOffset.UtcNow, + ModelId = "someModel", + AdditionalProperties = new() { ["key1"] = "value1", ["key2"] = 123 } + }; + + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + innerCallCount++; + return Task.FromResult(expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Make the initial request and do a quick sanity check + var result1 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + Assert.Same(expectedCompletion, result1); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Equal(1, innerCallCount); + AssertCompletionsEqual(expectedCompletion, result2); + + // Act/Assert 2: Cache misses do not return cached results + await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some modified input")]); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task AllowsConcurrentCallsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async delegate + { + innerCallCount++; + await completionTcs.Task; + return new ChatCompletion([new(ChatRole.Assistant, "Hello")]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act 1: Concurrent calls before resolution are passed into the inner client + var result1 = outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + var result2 = outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert 1 + Assert.Equal(2, innerCallCount); + Assert.False(result1.IsCompleted); + Assert.False(result2.IsCompleted); + completionTcs.SetResult(true); + Assert.Equal("Hello", (await result1).Message.Text); + Assert.Equal("Hello", (await result2).Message.Text); + + // Act 2: Subsequent calls after completion are resolved from the cache + var result3 = outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + Assert.Equal(2, innerCallCount); + Assert.Equal("Hello", (await result3).Message.Text); + } + + [Fact] + public async Task DoesNotCacheExceptionResultsAsync() + { + // Arrange + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + innerCallCount++; + throw new InvalidTimeZoneException("some failure"); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + var input = new ChatMessage(ChatRole.User, "abc"); + var ex1 = await Assert.ThrowsAsync(() => outer.CompleteAsync([input])); + Assert.Equal("some failure", ex1.Message); + Assert.Equal(1, innerCallCount); + + // Act + var ex2 = await Assert.ThrowsAsync(() => outer.CompleteAsync([input])); + + // Assert + Assert.NotSame(ex1, ex2); + Assert.Equal("some failure", ex2.Message); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task DoesNotCacheCanceledResultsAsync() + { + // Arrange + var innerCallCount = 0; + var resolutionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async delegate + { + innerCallCount++; + if (innerCallCount == 1) + { + await resolutionTcs.Task; + } + + return new ChatCompletion([new(ChatRole.Assistant, "A good result")]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // First call gets cancelled + var input = new ChatMessage(ChatRole.User, "abc"); + var result1 = outer.CompleteAsync([input]); + Assert.False(result1.IsCompleted); + Assert.Equal(1, innerCallCount); + resolutionTcs.SetCanceled(); + await Assert.ThrowsAsync(() => result1); + Assert.True(result1.IsCanceled); + + // Act/Assert: Second call can succeed + var result2 = await outer.CompleteAsync([input]); + Assert.Equal(2, innerCallCount); + Assert.Equal("A good result", result2.Message.Text); + } + + [Fact] + public async Task StreamingCachesSuccessResultsAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + List expectedCompletion = + [ + new() + { + Role = new ChatRole("fakeRole1"), + ChoiceIndex = 3, + AdditionalProperties = new() { ["a"] = "b" }, + Contents = [new TextContent("Chunk1")] + }, + new() + { + Role = new ChatRole("fakeRole2"), + Text = "Chunk2", + Contents = + [ + new FunctionCallContent("someCallId", "someFn", new Dictionary { ["arg1"] = "value1" }), + new UsageContent(new() { InputTokenCount = 123, OutputTokenCount = 456, TotalTokenCount = 99999 }), + ] + } + ]; + + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync(expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Make the initial request and do a quick sanity check + var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + await AssertCompletionsEqualAsync(expectedCompletion, result1); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Equal(1, innerCallCount); + await AssertCompletionsEqualAsync(expectedCompletion, result2); + + // Act/Assert 2: Cache misses do not return cached results + await ToListAsync(outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some modified input")])); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task StreamingCoalescesConsecutiveTextChunksAsync() + { + // Arrange + List expectedCompletion = + [ + new() { Role = ChatRole.Assistant, Text = "This" }, + new() { Role = ChatRole.Assistant, Text = " becomes one chunk" }, + new() { Role = ChatRole.Assistant, Contents = [new FunctionCallContent("callId1", "separator")] }, + new() { Role = ChatRole.Assistant, Text = "... and this" }, + new() { Role = ChatRole.Assistant, Text = " becomes another" }, + new() { Role = ChatRole.Assistant, Text = " one." }, + ]; + + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate { return ToAsyncEnumerableAsync(expectedCompletion); } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + await ToListAsync(result1); + + // Act + var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Collection(await ToListAsync(result2), + c => Assert.Equal("This becomes one chunk", c.Text), + c => Assert.IsType(Assert.Single(c.Contents)), + c => Assert.Equal("... and this becomes another one.", c.Text)); + } + + [Fact] + public async Task StreamingAllowsConcurrentCallsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + List expectedCompletion = + [ + new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + new() { Role = ChatRole.System, Text = "Chunk 2" }, + ]; + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync(completionTcs.Task, expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act 1: Concurrent calls before resolution are passed into the inner client + var result1 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + var result2 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert 1 + Assert.NotSame(result1, result2); + var result1Assertion = AssertCompletionsEqualAsync(expectedCompletion, result1); + var result2Assertion = AssertCompletionsEqualAsync(expectedCompletion, result2); + Assert.False(result1Assertion.IsCompleted); + Assert.False(result2Assertion.IsCompleted); + completionTcs.SetResult(true); + await result1Assertion; + await result2Assertion; + Assert.Equal(2, innerCallCount); + + // Act 2: Subsequent calls after completion are resolved from the cache + var result3 = outer.CompleteStreamingAsync([new ChatMessage(ChatRole.User, "some input")]); + await AssertCompletionsEqualAsync(expectedCompletion, result3); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task StreamingDoesNotCacheExceptionResultsAsync() + { + // Arrange + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync(Task.CompletedTask, + [ + () => new() { Role = ChatRole.Assistant, Text = "Chunk 1" }, + () => throw new InvalidTimeZoneException("some failure"), + ]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + var input = new ChatMessage(ChatRole.User, "abc"); + var result1 = outer.CompleteStreamingAsync([input]); + var ex1 = await Assert.ThrowsAsync(() => ToListAsync(result1)); + Assert.Equal("some failure", ex1.Message); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = outer.CompleteStreamingAsync([input]); + var ex2 = await Assert.ThrowsAsync(() => ToListAsync(result2)); + + // Assert + Assert.NotSame(ex1, ex2); + Assert.Equal("some failure", ex2.Message); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task StreamingDoesNotCacheCanceledResultsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteStreamingAsyncCallback = delegate + { + innerCallCount++; + return ToAsyncEnumerableAsync( + innerCallCount == 1 ? completionTcs.Task : Task.CompletedTask, + [() => new() { Role = ChatRole.Assistant, Text = "A good result" }]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // First call gets cancelled + var input = new ChatMessage(ChatRole.User, "abc"); + var result1 = outer.CompleteStreamingAsync([input]); + var result1Assertion = ToListAsync(result1); + Assert.False(result1Assertion.IsCompleted); + completionTcs.SetCanceled(); + await Assert.ThrowsAsync(() => result1Assertion); + Assert.True(result1Assertion.IsCanceled); + Assert.Equal(1, innerCallCount); + + // Act/Assert: Second call can succeed + var result2 = await ToListAsync(outer.CompleteStreamingAsync([input])); + Assert.Equal("A good result", result2[0].Text); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task CacheKeyDoesNotVaryByChatOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async (_, options, _) => + { + innerCallCount++; + await Task.Yield(); + return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act: Call with two different ChatOptions + var result1 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 1" } } + }); + var result2 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 2" } } + }); + + // Assert: Same result + Assert.Equal(1, innerCallCount); + Assert.Equal("value 1", result1.Message.Text); + Assert.Equal("value 1", result2.Message.Text); + } + + [Fact] + public async Task SubclassCanOverrideCacheKeyToVaryByChatOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = async (_, options, _) => + { + innerCallCount++; + await Task.Yield(); + return new([new(ChatRole.Assistant, options!.AdditionalProperties!["someKey"]!.ToString())]); + } + }; + using var outer = new CachingChatClientWithCustomKey(testClient, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options + }; + + // Act: Call with two different ChatOptions + var result1 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 1" } } + }); + var result2 = await outer.CompleteAsync([], new ChatOptions + { + AdditionalProperties = new() { { "someKey", "value 2" } } + }); + + // Assert: Different results + Assert.Equal(2, innerCallCount); + Assert.Equal("value 1", result1.Message.Text); + Assert.Equal("value 2", result2.Message.Text); + } + + [Fact] + public async Task CanCacheCustomContentTypesAsync() + { + // Arrange + var expectedCompletion = new ChatCompletion([ + new(new ChatRole("fakeRole"), + [ + new CustomAIContent1("Hello", DateTime.Now), + new CustomAIContent2("Goodbye", 42), + ]) + ]); + + var serializerOptions = new JsonSerializerOptions(TestJsonSerializerContext.Default.Options); + serializerOptions.TypeInfoResolver = serializerOptions.TypeInfoResolver!.WithAddedModifier(typeInfo => + { + if (typeInfo.Type == typeof(AIContent)) + { + foreach (var t in new Type[] { typeof(CustomAIContent1), typeof(CustomAIContent2) }) + { + typeInfo.PolymorphismOptions!.DerivedTypes.Add(new JsonDerivedType(t, t.Name)); + } + } + }); + serializerOptions.TypeInfoResolverChain.Add(CustomAIContentJsonContext.Default); + + var innerCallCount = 0; + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + innerCallCount++; + return Task.FromResult(expectedCompletion); + } + }; + using var outer = new DistributedCachingChatClient(testClient, _storage) + { + JsonSerializerOptions = serializerOptions + }; + + // Make the initial request and do a quick sanity check + var result1 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + AssertCompletionsEqual(expectedCompletion, result1); + + // Act + var result2 = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.Equal(1, innerCallCount); + AssertCompletionsEqual(expectedCompletion, result2); + Assert.NotSame(result2.Message.Contents[0], expectedCompletion.Message.Contents[0]); + Assert.NotSame(result2.Message.Contents[1], expectedCompletion.Message.Contents[1]); + } + + [Fact] + public async Task CanResolveIDistributedCacheFromDI() + { + // Arrange + var services = new ServiceCollection() + .AddSingleton(_storage) + .BuildServiceProvider(); + using var testClient = new TestChatClient + { + CompleteAsyncCallback = delegate + { + return Task.FromResult(new ChatCompletion([ + new(ChatRole.Assistant, [new TextContent("Hey")])])); + } + }; + using var outer = new ChatClientBuilder(services) + .UseDistributedCache(configure: options => + { + options.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(testClient); + + // Act: Make a request that should populate the cache + Assert.Empty(_storage.Keys); + var result = await outer.CompleteAsync([new ChatMessage(ChatRole.User, "some input")]); + + // Assert + Assert.NotNull(result); + Assert.Single(_storage.Keys); + } + + private static async Task> ToListAsync(IAsyncEnumerable values) + { + var result = new List(); + await foreach (var v in values) + { + result.Add(v); + } + + return result; + } + + private static IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values) + => ToAsyncEnumerableAsync(Task.CompletedTask, values); + + private static IAsyncEnumerable ToAsyncEnumerableAsync(Task preTask, IEnumerable valueFactories) + => ToAsyncEnumerableAsync(preTask, valueFactories.Select>(v => () => v)); + + private static async IAsyncEnumerable ToAsyncEnumerableAsync(Task preTask, IEnumerable> values) + { + await preTask; + + foreach (var value in values) + { + await Task.Yield(); + yield return value(); + } + } + + private static void AssertCompletionsEqual(ChatCompletion expected, ChatCompletion actual) + { + Assert.Equal(expected.CompletionId, actual.CompletionId); + Assert.Equal(expected.Usage?.InputTokenCount, actual.Usage?.InputTokenCount); + Assert.Equal(expected.Usage?.OutputTokenCount, actual.Usage?.OutputTokenCount); + Assert.Equal(expected.Usage?.TotalTokenCount, actual.Usage?.TotalTokenCount); + Assert.Equal(expected.CreatedAt, actual.CreatedAt); + Assert.Equal(expected.ModelId, actual.ModelId); + Assert.Equal( + JsonSerializer.Serialize(expected.AdditionalProperties, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actual.AdditionalProperties, TestJsonSerializerContext.Default.Options)); + Assert.Equal(expected.Choices.Count, actual.Choices.Count); + + for (var i = 0; i < expected.Choices.Count; i++) + { + Assert.IsType(expected.Choices[i].GetType(), actual.Choices[i]); + Assert.Equal(expected.Choices[i].Role, actual.Choices[i].Role); + Assert.Equal(expected.Choices[i].Text, actual.Choices[i].Text); + Assert.Equal(expected.Choices[i].Contents.Count, actual.Choices[i].Contents.Count); + + for (var itemIndex = 0; itemIndex < expected.Choices[i].Contents.Count; itemIndex++) + { + var expectedItem = expected.Choices[i].Contents[itemIndex]; + var actualItem = actual.Choices[i].Contents[itemIndex]; + Assert.Equal(expectedItem.ModelId, actualItem.ModelId); + Assert.IsType(expectedItem.GetType(), actualItem); + + if (expectedItem is FunctionCallContent expectedFcc) + { + var actualFcc = (FunctionCallContent)actualItem; + Assert.Equal(expectedFcc.Name, actualFcc.Name); + Assert.Equal(expectedFcc.CallId, actualFcc.CallId); + + // The correct JSON-round-tripping of AIContent/AIContent is not + // the responsibility of CachingChatClient, so not testing that here. + Assert.Equal( + JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); + } + } + } + } + + private static async Task AssertCompletionsEqualAsync(IReadOnlyList expected, IAsyncEnumerable actual) + { + var actualEnumerator = actual.GetAsyncEnumerator(); + + foreach (var expectedItem in expected) + { + Assert.True(await actualEnumerator.MoveNextAsync()); + + var actualItem = actualEnumerator.Current; + Assert.Equal(expectedItem.Text, actualItem.Text); + Assert.Equal(expectedItem.ChoiceIndex, actualItem.ChoiceIndex); + Assert.Equal(expectedItem.Role, actualItem.Role); + Assert.Equal(expectedItem.Contents.Count, actualItem.Contents.Count); + + for (var itemIndex = 0; itemIndex < expectedItem.Contents.Count; itemIndex++) + { + var expectedItemItem = expectedItem.Contents[itemIndex]; + var actualItemItem = actualItem.Contents[itemIndex]; + Assert.IsType(expectedItemItem.GetType(), actualItemItem); + + if (expectedItemItem is FunctionCallContent expectedFcc) + { + var actualFcc = (FunctionCallContent)actualItemItem; + Assert.Equal(expectedFcc.Name, actualFcc.Name); + Assert.Equal(expectedFcc.CallId, actualFcc.CallId); + + // The correct JSON-round-tripping of AIContent/AIContent is not + // the responsibility of CachingChatClient, so not testing that here. + Assert.Equal( + JsonSerializer.Serialize(expectedFcc.Arguments, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actualFcc.Arguments, TestJsonSerializerContext.Default.Options)); + } + else if (expectedItemItem is UsageContent expectedUsage) + { + var actualUsage = (UsageContent)actualItemItem; + Assert.Equal(expectedUsage.Details.InputTokenCount, actualUsage.Details.InputTokenCount); + Assert.Equal(expectedUsage.Details.OutputTokenCount, actualUsage.Details.OutputTokenCount); + Assert.Equal(expectedUsage.Details.TotalTokenCount, actualUsage.Details.TotalTokenCount); + } + } + } + + Assert.False(await actualEnumerator.MoveNextAsync()); + } + + private sealed class CachingChatClientWithCustomKey(IChatClient innerClient, IDistributedCache storage) + : DistributedCachingChatClient(innerClient, storage) + { + protected override string GetCacheKey(bool streaming, IList chatMessages, ChatOptions? options) + { + var baseKey = base.GetCacheKey(streaming, chatMessages, options); + return baseKey + options?.AdditionalProperties?["someKey"]?.ToString(); + } + } + + public class CustomAIContent1(string text, DateTime date) : AIContent + { + public string Text => text; + public DateTime Date => date; + } + + public class CustomAIContent2(string text, int number) : AIContent + { + public string Text => text; + public int Number => number; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs new file mode 100644 index 00000000000..8ad0c6d7944 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -0,0 +1,352 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class FunctionInvokingChatClientTests +{ + [Fact] + public async Task SupportsSingleFunctionCallPerRequestAsync() + { + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentInvocation) + { + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create((int i) => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + ] + }; + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func1"), + new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 34 } }), + new FunctionCallContent("callId3", "Func2", arguments: new Dictionary { { "i", 56 } }), + ]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId1", "Func1", result: "Result 1"), + new FunctionResultContent("callId2", "Func2", result: "Result 2: 34"), + new FunctionResultContent("callId3", "Func2", result: "Result 2: 56"), + ]), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId4", "Func2", arguments: new Dictionary { { "i", 78 } }), + new FunctionCallContent("callId5", "Func1")]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId4", "Func2", result: "Result 2: 78"), + new FunctionResultContent("callId5", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation })); + } + + [Fact] + public async Task ParallelFunctionCallsInvokedConcurrentlyByDefaultAsync() + { + using var barrier = new Barrier(2); + + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create((string arg) => + { + barrier.SignalAndWait(); + return arg + arg; + }, "Func"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), + new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), + ]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId1", "Func", result: "hellohello"), + new FunctionResultContent("callId2", "Func", result: "worldworld"), + ]), + new ChatMessage(ChatRole.Assistant, "done"), + ]); + } + + [Fact] + public async Task ConcurrentInvocationOfParallelCallsCanBeDisabledAsync() + { + int activeCount = 0; + + var options = new ChatOptions + { + Tools = [ + AIFunctionFactory.Create(async (string arg) => + { + Interlocked.Increment(ref activeCount); + await Task.Delay(100); + Assert.Equal(1, activeCount); + Interlocked.Decrement(ref activeCount); + return arg + arg; + }, "Func"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [ + new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), + new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), + ]), + new ChatMessage(ChatRole.Tool, [ + new FunctionResultContent("callId1", "Func", result: "hellohello"), + new FunctionResultContent("callId2", "Func", result: "worldworld"), + ]), + new ChatMessage(ChatRole.Assistant, "done"), + ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = false })); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunctionCallingMessages) + { + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ] + }; + +#pragma warning disable SA1118 // Parameter should not span multiple lines + var finalChat = await InvokeAndAssertAsync( + options, + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], + expected: keepFunctionCallingMessages ? + null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, "world") + ], + configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); +#pragma warning restore SA1118 + + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunctionCallingMessages) + { + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(() => "Result 1", "Func1"), + AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), + AIFunctionFactory.Create((int i) => { }, "VoidReturn"), + ] + }; + +#pragma warning disable SA1118 // Parameter should not span multiple lines + var finalChat = await InvokeAndAssertAsync(options, + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } }), new TextContent("more")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], + expected: keepFunctionCallingMessages ? + null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Assistant, "more"), + new ChatMessage(ChatRole.Assistant, "world"), + ], + configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); +#pragma warning restore SA1118 + + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedErrors) + { + var options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.Create(string () => throw new InvalidOperationException("Oh no!"), "Func1"), + ] + }; + + await InvokeAndAssertAsync(options, [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: detailedErrors ? "Error: Function failed. Exception: Oh no!" : "Error: Function failed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors })); + } + + [Fact] + public async Task RejectsMultipleChoicesAsync() + { + var func1 = AIFunctionFactory.Create(() => "Some result 1", "Func1"); + var func2 = AIFunctionFactory.Create(() => "Some result 2", "Func2"); + + using var innerClient = new TestChatClient + { + CompleteAsyncCallback = async (chatContents, options, cancellationToken) => + { + await Task.Yield(); + + return new ChatCompletion( + [ + new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Metadata.Name)]), + new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Metadata.Name)]), + ]); + } + }; + + IChatClient service = new ChatClientBuilder().UseFunctionInvocation().Use(innerClient); + + List chat = [new ChatMessage(ChatRole.User, "hello")]; + var ex = await Assert.ThrowsAsync( + () => service.CompleteAsync(chat, new ChatOptions { Tools = [func1, func2] })); + + Assert.Contains("only accepts a single choice", ex.Message); + Assert.Single(chat); // It didn't add anything to the chat history + } + + private static async Task> InvokeAndAssertAsync( + ChatOptions options, + List plan, + List? expected = null, + Func? configurePipeline = null) + { + Assert.NotEmpty(plan); + + configurePipeline ??= static b => b.UseFunctionInvocation(); + + using CancellationTokenSource cts = new(); + List chat = [plan[0]]; + int i = 0; + + using var innerClient = new TestChatClient + { + CompleteAsyncCallback = async (contents, actualOptions, actualCancellationToken) => + { + Assert.Same(chat, contents); + Assert.Equal(cts.Token, actualCancellationToken); + + await Task.Yield(); + + return new ChatCompletion([plan[contents.Count]]); + } + }; + + IChatClient service = configurePipeline(new ChatClientBuilder()).Use(innerClient); + + var result = await service.CompleteAsync(chat, options, cts.Token); + chat.Add(result.Message); + + expected ??= plan; + Assert.NotNull(result); + Assert.Equal(expected.Count, chat.Count); + for (; i < expected.Count; i++) + { + var expectedMessage = expected[i]; + var chatMessage = chat[i]; + + Assert.Equal(expectedMessage.Role, chatMessage.Role); + Assert.Equal(expectedMessage.Text, chatMessage.Text); + Assert.Equal(expectedMessage.GetType(), chatMessage.GetType()); + + Assert.Equal(expectedMessage.Contents.Count, chatMessage.Contents.Count); + for (int j = 0; j < expectedMessage.Contents.Count; j++) + { + var expectedItem = expectedMessage.Contents[j]; + var chatItem = chatMessage.Contents[j]; + + Assert.Equal(expectedItem.GetType(), chatItem.GetType()); + Assert.Equal(expectedItem.ToString(), chatItem.ToString()); + if (expectedItem is FunctionCallContent expectedFunctionCall) + { + var chatFunctionCall = (FunctionCallContent)chatItem; + Assert.Equal(expectedFunctionCall.Name, chatFunctionCall.Name); + AssertExtensions.EqualFunctionCallParameters(expectedFunctionCall.Arguments, chatFunctionCall.Arguments); + } + else if (expectedItem is FunctionResultContent expectedFunctionResult) + { + var chatFunctionResult = (FunctionResultContent)chatItem; + AssertExtensions.EqualFunctionCallResults(expectedFunctionResult.Result, chatFunctionResult.Result); + } + } + } + + return chat; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs new file mode 100644 index 00000000000..feb91ac925e --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/LoggingChatClientTests.cs @@ -0,0 +1,121 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class LoggingChatClientTests +{ + [Fact] + public void LoggingChatClient_InvalidArgs_Throws() + { + Assert.Throws("innerClient", () => new LoggingChatClient(null!, NullLogger.Instance)); + Assert.Throws("logger", () => new LoggingChatClient(new TestChatClient(), null!)); + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level) + { + using CapturingLoggerProvider clp = new(); + + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + using IChatClient innerClient = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + return Task.FromResult(new ChatCompletion([new(ChatRole.Assistant, "blue whale")])); + }, + }; + + using IChatClient client = new ChatClientBuilder(services) + .UseLogging() + .Use(innerClient); + + await client.CompleteAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions { FrequencyPenalty = 3.0f }); + + if (level is LogLevel.Trace) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteAsync invoked:") && entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteAsync completed:") && entry.Message.Contains("blue whale"))); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteAsync invoked.") && !entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteAsync completed.") && !entry.Message.Contains("blue whale"))); + } + else + { + Assert.Empty(clp.Logger.Entries); + } + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task CompleteStreamAsync_LogsStartUpdateCompletion(LogLevel level) + { + CapturingLogger logger = new(level); + + using IChatClient innerClient = new TestChatClient + { + CompleteStreamingAsyncCallback = (messages, options, cancellationToken) => GetUpdatesAsync() + }; + + static async IAsyncEnumerable GetUpdatesAsync() + { + await Task.Yield(); + yield return new StreamingChatCompletionUpdate { Role = ChatRole.Assistant, Text = "blue " }; + yield return new StreamingChatCompletionUpdate { Role = ChatRole.Assistant, Text = "whale" }; + } + + using IChatClient client = new ChatClientBuilder() + .UseLogging(logger) + .Use(innerClient); + + await foreach (var update in client.CompleteStreamingAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions { FrequencyPenalty = 3.0f })) + { + // nop + } + + if (level is LogLevel.Trace) + { + Assert.Collection(logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync invoked:") && entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update:") && entry.Message.Contains("blue")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update:") && entry.Message.Contains("whale")), + entry => Assert.Contains("CompleteStreamingAsync completed.", entry.Message)); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(logger.Entries, + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync invoked.") && !entry.Message.Contains("biggest animal")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update.") && !entry.Message.Contains("blue")), + entry => Assert.True(entry.Message.Contains("CompleteStreamingAsync received update.") && !entry.Message.Contains("whale")), + entry => Assert.Contains("CompleteStreamingAsync completed.", entry.Message)); + } + else + { + Assert.Empty(logger.Entries); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs new file mode 100644 index 00000000000..d0056b21b91 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/OpenTelemetryChatClientTests.cs @@ -0,0 +1,220 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using OpenTelemetry.Trace; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class OpenTelemetryChatClientTests +{ + [Fact] + public async Task ExpectedInformationLogged_NonStreaming_Async() + { + var sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + using var innerClient = new TestChatClient + { + Metadata = new("testservice", new Uri("http://localhost:12345/something"), "amazingmodel"), + CompleteAsyncCallback = async (messages, options, cancellationToken) => + { + await Task.Yield(); + return new ChatCompletion([new ChatMessage(ChatRole.Assistant, "blue whale")]) + { + CompletionId = "id123", + FinishReason = ChatFinishReason.Stop, + Usage = new UsageDetails + { + InputTokenCount = 10, + OutputTokenCount = 20, + TotalTokenCount = 42, + }, + }; + } + }; + + var chatClient = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, instance => + { + instance.EnableSensitiveData = true; + instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(innerClient); + + await chatClient.CompleteAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions + { + FrequencyPenalty = 3.0f, + MaxOutputTokens = 123, + ModelId = "replacementmodel", + TopP = 4.0f, + PresencePenalty = 5.0f, + ResponseFormat = ChatResponseFormat.Json, + Temperature = 6.0f, + StopSequences = ["hello", "world"], + AdditionalProperties = new() { ["top_k"] = 7.0f }, + }); + + var activity = Assert.Single(activities); + + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + + Assert.Equal("http://localhost:12345/something", activity.GetTagItem("server.address")); + Assert.Equal(12345, (int)activity.GetTagItem("server.port")!); + + Assert.Equal("chat.completions replacementmodel", activity.DisplayName); + Assert.Equal("testservice", activity.GetTagItem("gen_ai.system")); + + Assert.Equal("replacementmodel", activity.GetTagItem("gen_ai.request.model")); + Assert.Equal(3.0f, activity.GetTagItem("gen_ai.request.frequency_penalty")); + Assert.Equal(4.0f, activity.GetTagItem("gen_ai.request.top_p")); + Assert.Equal(5.0f, activity.GetTagItem("gen_ai.request.presence_penalty")); + Assert.Equal(6.0f, activity.GetTagItem("gen_ai.request.temperature")); + Assert.Equal(7.0, activity.GetTagItem("gen_ai.request.top_k")); + Assert.Equal(123, activity.GetTagItem("gen_ai.request.max_tokens")); + Assert.Equal("""["hello", "world"]""", activity.GetTagItem("gen_ai.request.stop_sequences")); + + Assert.Equal("id123", activity.GetTagItem("gen_ai.response.id")); + Assert.Equal("""["stop"]""", activity.GetTagItem("gen_ai.response.finish_reasons")); + Assert.Equal(10, activity.GetTagItem("gen_ai.response.input_tokens")); + Assert.Equal(20, activity.GetTagItem("gen_ai.response.output_tokens")); + + Assert.Collection(activity.Events, + evt => + { + Assert.Equal("gen_ai.content.prompt", evt.Name); + Assert.Equal("""[{"role": "user", "content": "What\u0027s the biggest animal?"}]""", evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.prompt").Value); + }, + evt => + { + Assert.Equal("gen_ai.content.completion", evt.Name); + Assert.Contains("whale", (string)evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.completion").Value!); + }); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } + + [Fact] + public async Task ExpectedInformationLogged_Streaming_Async() + { + var sourceName = Guid.NewGuid().ToString(); + var activities = new List(); + using var tracerProvider = OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build(); + + async static IAsyncEnumerable CallbackAsync( + IList messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + await Task.Yield(); + yield return new StreamingChatCompletionUpdate + { + Role = ChatRole.Assistant, + Text = "blue ", + CompletionId = "id123", + }; + await Task.Yield(); + yield return new StreamingChatCompletionUpdate + { + Role = ChatRole.Assistant, + Text = "whale", + FinishReason = ChatFinishReason.Stop, + }; + yield return new StreamingChatCompletionUpdate + { + Contents = [new UsageContent(new() + { + InputTokenCount = 10, + OutputTokenCount = 20, + TotalTokenCount = 42, + })], + }; + } + + using var innerClient = new TestChatClient + { + Metadata = new("testservice", new Uri("http://localhost:12345/something"), "amazingmodel"), + CompleteStreamingAsyncCallback = CallbackAsync, + }; + + var chatClient = new ChatClientBuilder() + .UseOpenTelemetry(sourceName, instance => + { + instance.EnableSensitiveData = true; + instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(innerClient); + + await foreach (var update in chatClient.CompleteStreamingAsync( + [new(ChatRole.User, "What's the biggest animal?")], + new ChatOptions + { + FrequencyPenalty = 3.0f, + MaxOutputTokens = 123, + ModelId = "replacementmodel", + TopP = 4.0f, + PresencePenalty = 5.0f, + ResponseFormat = ChatResponseFormat.Json, + Temperature = 6.0f, + StopSequences = ["hello", "world"], + AdditionalProperties = new() { ["top_k"] = 7.0 }, + })) + { + // Drain the stream. + } + + var activity = Assert.Single(activities); + + Assert.NotNull(activity.Id); + Assert.NotEmpty(activity.Id); + + Assert.Equal("http://localhost:12345/something", activity.GetTagItem("server.address")); + Assert.Equal(12345, (int)activity.GetTagItem("server.port")!); + + Assert.Equal("chat.completions replacementmodel", activity.DisplayName); + Assert.Equal("testservice", activity.GetTagItem("gen_ai.system")); + + Assert.Equal("replacementmodel", activity.GetTagItem("gen_ai.request.model")); + Assert.Equal(3.0f, activity.GetTagItem("gen_ai.request.frequency_penalty")); + Assert.Equal(4.0f, activity.GetTagItem("gen_ai.request.top_p")); + Assert.Equal(5.0f, activity.GetTagItem("gen_ai.request.presence_penalty")); + Assert.Equal(6.0f, activity.GetTagItem("gen_ai.request.temperature")); + Assert.Equal(7.0, activity.GetTagItem("gen_ai.request.top_k")); + Assert.Equal(123, activity.GetTagItem("gen_ai.request.max_tokens")); + Assert.Equal("""["hello", "world"]""", activity.GetTagItem("gen_ai.request.stop_sequences")); + + Assert.Equal("id123", activity.GetTagItem("gen_ai.response.id")); + Assert.Equal("""["stop"]""", activity.GetTagItem("gen_ai.response.finish_reasons")); + Assert.Equal(10, activity.GetTagItem("gen_ai.response.input_tokens")); + Assert.Equal(20, activity.GetTagItem("gen_ai.response.output_tokens")); + + Assert.Collection(activity.Events, + evt => + { + Assert.Equal("gen_ai.content.prompt", evt.Name); + Assert.Equal("""[{"role": "user", "content": "What\u0027s the biggest animal?"}]""", evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.prompt").Value); + }, + evt => + { + Assert.Equal("gen_ai.content.completion", evt.Name); + Assert.Contains("whale", (string)evt.Tags.FirstOrDefault(t => t.Key == "gen_ai.completion").Value!); + }); + + Assert.True(activity.Duration.TotalMilliseconds > 0); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ScopedChatClientExtensions.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ScopedChatClientExtensions.cs new file mode 100644 index 00000000000..d9ad92dc266 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ScopedChatClientExtensions.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.Extensions.AI; + +public static class ScopedChatClientExtensions +{ + public static ChatClientBuilder UseScopedMiddleware(this ChatClientBuilder builder) + => builder.Use((services, inner) + => new DependencyInjectionPatterns.ScopedChatClient(services, inner)); +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs new file mode 100644 index 00000000000..2b4370222c6 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/DistributedCachingEmbeddingGeneratorTest.cs @@ -0,0 +1,348 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class DistributedCachingEmbeddingGeneratorTest +{ + private readonly TestInMemoryCacheStorage _storage = new(); + private readonly Embedding _expectedEmbedding = new(new float[] { 1.0f, 2.0f, 3.0f }) + { + CreatedAt = DateTimeOffset.Parse("2024-08-01T00:00:00Z"), + ModelId = "someModel", + AdditionalProperties = new() { ["a"] = "b" }, + }; + + [Fact] + public async Task CachesSuccessResultsAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + var innerCallCount = 0; + using var testGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + innerCallCount++; + return Task.FromResult>>([_expectedEmbedding]); + }, + }; + using var outer = new DistributedCachingEmbeddingGenerator>(testGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Make the initial request and do a quick sanity check + var result1 = await outer.GenerateAsync("abc"); + Assert.Single(result1); + AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); + Assert.Equal(1, innerCallCount); + + // Act + var result2 = await outer.GenerateAsync("abc"); + + // Assert + Assert.Single(result2); + Assert.Equal(1, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + + // Act/Assert 2: Cache misses do not return cached results + await outer.GenerateAsync(["def"]); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task SupportsPartiallyCachedBatchesAsync() + { + // Arrange + + // Verify that all the expected properties will round-trip through the cache, + // even if this involves serialization + var innerCallCount = 0; + Embedding[] expected = Enumerable.Range(0, 10).Select(i => + new Embedding(new[] { 1.0f, 2.0f, 3.0f }) + { + CreatedAt = DateTimeOffset.Parse("2024-08-01T00:00:00Z") + TimeSpan.FromHours(i), + ModelId = $"someModel{i}", + AdditionalProperties = new() { [$"a{i}"] = $"b{i}" }, + }).ToArray(); + using var testGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + innerCallCount++; + Assert.Equal(innerCallCount == 1 ? 4 : 6, values.Count()); + return Task.FromResult>>(new(values.Select(i => expected[int.Parse(i)]))); + }, + }; + using var outer = new DistributedCachingEmbeddingGenerator>(testGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Make initial requests for some of the values + var results = await outer.GenerateAsync(["0", "4", "5", "8"]); + Assert.Equal(1, innerCallCount); + Assert.Equal(4, results.Count); + AssertEmbeddingsEqual(expected[0], results[0]); + AssertEmbeddingsEqual(expected[4], results[1]); + AssertEmbeddingsEqual(expected[5], results[2]); + AssertEmbeddingsEqual(expected[8], results[3]); + + // Act/Assert + results = await outer.GenerateAsync(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]); + Assert.Equal(2, innerCallCount); + for (int i = 0; i < 10; i++) + { + AssertEmbeddingsEqual(expected[i], results[i]); + } + + results = await outer.GenerateAsync(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]); + Assert.Equal(2, innerCallCount); + for (int i = 0; i < 10; i++) + { + AssertEmbeddingsEqual(expected[i], results[i]); + } + } + + [Fact] + public async Task AllowsConcurrentCallsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + await completionTcs.Task; + return [_expectedEmbedding]; + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Act 1: Concurrent calls before resolution are passed into the inner client + var result1 = outer.GenerateAsync("abc"); + var result2 = outer.GenerateAsync("abc"); + + // Assert 1 + Assert.Equal(2, innerCallCount); + Assert.False(result1.IsCompleted); + Assert.False(result2.IsCompleted); + completionTcs.SetResult(true); + AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]); + AssertEmbeddingsEqual(_expectedEmbedding, (await result2)[0]); + + // Act 2: Subsequent calls after completion are resolved from the cache + var result3 = await outer.GenerateAsync("abc"); + Assert.Equal(2, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, (await result1)[0]); + } + + [Fact] + public async Task DoesNotCacheExceptionResultsAsync() + { + // Arrange + var innerCallCount = 0; + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (value, options, cancellationToken) => + { + innerCallCount++; + throw new InvalidTimeZoneException("some failure"); + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + var ex1 = await Assert.ThrowsAsync(() => outer.GenerateAsync("abc")); + Assert.Equal("some failure", ex1.Message); + Assert.Equal(1, innerCallCount); + + // Act + var ex2 = await Assert.ThrowsAsync(() => outer.GenerateAsync("abc")); + + // Assert + Assert.NotSame(ex1, ex2); + Assert.Equal("some failure", ex2.Message); + Assert.Equal(2, innerCallCount); + } + + [Fact] + public async Task DoesNotCacheCanceledResultsAsync() + { + // Arrange + var innerCallCount = 0; + var resolutionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + if (innerCallCount == 1) + { + await resolutionTcs.Task; + } + + return [_expectedEmbedding]; + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // First call gets cancelled + var result1 = outer.GenerateAsync("abc"); + Assert.False(result1.IsCompleted); + Assert.Equal(1, innerCallCount); + resolutionTcs.SetCanceled(); + await Assert.ThrowsAnyAsync(() => result1); + Assert.True(result1.IsCanceled); + + // Act/Assert: Second call can succeed + var result2 = await outer.GenerateAsync("abc"); + Assert.Single(result2); + Assert.Equal(2, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + } + + [Fact] + public async Task CacheKeyDoesNotVaryByEmbeddingOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + await Task.Yield(); + return [_expectedEmbedding]; + } + }; + using var outer = new DistributedCachingEmbeddingGenerator>(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Act: Call with two different options + var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 1" } + }); + var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 2" } + }); + + // Assert: Same result + Assert.Single(result1); + Assert.Single(result2); + Assert.Equal(1, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + } + + [Fact] + public async Task SubclassCanOverrideCacheKeyToVaryByOptionsAsync() + { + // Arrange + var innerCallCount = 0; + var completionTcs = new TaskCompletionSource(); + using var innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = async (value, options, cancellationToken) => + { + innerCallCount++; + await Task.Yield(); + return [_expectedEmbedding]; + } + }; + using var outer = new CachingEmbeddingGeneratorWithCustomKey(innerGenerator, _storage) + { + JsonSerializerOptions = TestJsonSerializerContext.Default.Options, + }; + + // Act: Call with two different options + var result1 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 1" } + }); + var result2 = await outer.GenerateAsync("abc", new EmbeddingGenerationOptions + { + AdditionalProperties = new() { ["someKey"] = "value 2" } + }); + + // Assert: Different results + Assert.Single(result1); + Assert.Single(result2); + Assert.Equal(2, innerCallCount); + AssertEmbeddingsEqual(_expectedEmbedding, result1[0]); + AssertEmbeddingsEqual(_expectedEmbedding, result2[0]); + } + + [Fact] + public async Task CanResolveIDistributedCacheFromDI() + { + // Arrange + var services = new ServiceCollection() + .AddSingleton(_storage) + .BuildServiceProvider(); + using var testGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + return Task.FromResult>>([_expectedEmbedding]); + }, + }; + using var outer = new EmbeddingGeneratorBuilder>(services) + .UseDistributedCache(configure: instance => + { + instance.JsonSerializerOptions = TestJsonSerializerContext.Default.Options; + }) + .Use(testGenerator); + + // Act: Make a request that should populate the cache + Assert.Empty(_storage.Keys); + var result = await outer.GenerateAsync("abc"); + + // Assert + Assert.NotNull(result); + Assert.Single(_storage.Keys); + } + + private static void AssertEmbeddingsEqual(Embedding expected, Embedding actual) + { + Assert.Equal(expected.CreatedAt, actual.CreatedAt); + Assert.Equal(expected.ModelId, actual.ModelId); + Assert.Equal(expected.Vector.ToArray(), actual.Vector.ToArray()); + Assert.Equal( + JsonSerializer.Serialize(expected.AdditionalProperties, TestJsonSerializerContext.Default.Options), + JsonSerializer.Serialize(actual.AdditionalProperties, TestJsonSerializerContext.Default.Options)); + } + + private sealed class CachingEmbeddingGeneratorWithCustomKey(IEmbeddingGenerator> innerGenerator, IDistributedCache storage) + : DistributedCachingEmbeddingGenerator>(innerGenerator, storage) + { + protected override string GetCacheKey(string value, EmbeddingGenerationOptions? options) => + base.GetCacheKey(value, options) + options?.AdditionalProperties?["someKey"]?.ToString(); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/EmbeddingGeneratorBuilderTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/EmbeddingGeneratorBuilderTests.cs new file mode 100644 index 00000000000..357168c3b65 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/EmbeddingGeneratorBuilderTests.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class EmbeddingGeneratorBuilderTests +{ + [Fact] + public void PassesServiceProviderToFactories() + { + var expectedServiceProvider = new ServiceCollection().BuildServiceProvider(); + using var expectedResult = new TestEmbeddingGenerator(); + var builder = new EmbeddingGeneratorBuilder>(expectedServiceProvider); + + builder.Use((serviceProvider, innerClient) => + { + Assert.Same(expectedServiceProvider, serviceProvider); + return expectedResult; + }); + + using var innerGenerator = new TestEmbeddingGenerator(); + Assert.Equal(expectedResult, builder.Use(innerGenerator)); + } + + [Fact] + public void BuildsPipelineInOrderAdded() + { + // Arrange + using var expectedInnerService = new TestEmbeddingGenerator(); + var builder = new EmbeddingGeneratorBuilder>(); + + builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("First", next)); + builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("Second", next)); + builder.Use(next => new InnerServiceCapturingEmbeddingGenerator("Third", next)); + + // Act + var first = (InnerServiceCapturingEmbeddingGenerator)builder.Use(expectedInnerService); + + // Assert + Assert.Equal("First", first.Name); + var second = (InnerServiceCapturingEmbeddingGenerator)first.InnerGenerator; + Assert.Equal("Second", second.Name); + var third = (InnerServiceCapturingEmbeddingGenerator)second.InnerGenerator; + Assert.Equal("Third", third.Name); + Assert.Same(expectedInnerService, third.InnerGenerator); + } + + [Fact] + public void DoesNotAcceptNullInnerService() + { + Assert.Throws(() => new EmbeddingGeneratorBuilder>().Use((IEmbeddingGenerator>)null!)); + } + + [Fact] + public void DoesNotAcceptNullFactories() + { + var builder = new EmbeddingGeneratorBuilder>(); + Assert.Throws(() => builder.Use((Func>, IEmbeddingGenerator>>)null!)); + Assert.Throws(() => builder.Use((Func>, IEmbeddingGenerator>>)null!)); + } + + [Fact] + public void DoesNotAllowFactoriesToReturnNull() + { + var builder = new EmbeddingGeneratorBuilder>(); + builder.Use(_ => null!); + var ex = Assert.Throws(() => builder.Use(new TestEmbeddingGenerator())); + Assert.Contains("entry at index 0", ex.Message); + } + + private sealed class InnerServiceCapturingEmbeddingGenerator(string name, IEmbeddingGenerator> innerGenerator) : + DelegatingEmbeddingGenerator>(innerGenerator) + { +#pragma warning disable S3604 // False positive: Member initializer values should not be redundant + public string Name { get; } = name; +#pragma warning restore S3604 + public new IEmbeddingGenerator> InnerGenerator => base.InnerGenerator; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/LoggingEmbeddingGeneratorTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/LoggingEmbeddingGeneratorTests.cs new file mode 100644 index 00000000000..e231e8995fe --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Embeddings/LoggingEmbeddingGeneratorTests.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class LoggingEmbeddingGeneratorTests +{ + [Fact] + public void LoggingEmbeddingGenerator_InvalidArgs_Throws() + { + Assert.Throws("innerGenerator", () => new LoggingEmbeddingGenerator>(null!, NullLogger.Instance)); + Assert.Throws("logger", () => new LoggingEmbeddingGenerator>(new TestEmbeddingGenerator(), null!)); + } + + [Theory] + [InlineData(LogLevel.Trace)] + [InlineData(LogLevel.Debug)] + [InlineData(LogLevel.Information)] + public async Task CompleteAsync_LogsStartAndCompletion(LogLevel level) + { + using CapturingLoggerProvider clp = new(); + + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); + var services = c.BuildServiceProvider(); + + using IEmbeddingGenerator> innerGenerator = new TestEmbeddingGenerator + { + GenerateAsyncCallback = (values, options, cancellationToken) => + { + return Task.FromResult(new GeneratedEmbeddings>([new Embedding(new float[] { 1f, 2f, 3f })])); + }, + }; + + using IEmbeddingGenerator> generator = new EmbeddingGeneratorBuilder>(services) + .UseLogging() + .Use(innerGenerator); + + await generator.GenerateAsync("Blue whale"); + + if (level is LogLevel.Trace) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("GenerateAsync invoked:") && entry.Message.Contains("Blue whale")), + entry => Assert.Contains("GenerateAsync generated 1 embedding(s).", entry.Message)); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("GenerateAsync invoked.") && !entry.Message.Contains("Blue whale")), + entry => Assert.Contains("GenerateAsync generated 1 embedding(s).", entry.Message)); + } + else + { + Assert.Empty(clp.Logger.Entries); + } + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs new file mode 100644 index 00000000000..41ed51cd2a2 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs @@ -0,0 +1,186 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Extensions.AI; + +public class AIFunctionFactoryTest +{ + [Fact] + public void InvalidArguments_Throw() + { + Delegate nullDelegate = null!; + Assert.Throws(() => AIFunctionFactory.Create(nullDelegate)); + Assert.Throws(() => AIFunctionFactory.Create((MethodInfo)null!)); + Assert.Throws(() => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, null)); + Assert.Throws(() => AIFunctionFactory.Create(typeof(List<>).GetMethod("Add")!, new List())); + } + + [Fact] + public async Task Parameters_MappedByName_Async() + { + AIFunction func; + + func = AIFunctionFactory.Create((string a) => a + " " + a); + AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync([new KeyValuePair("a", "test")])); + + func = AIFunctionFactory.Create((string a, string b) => b + " " + a); + AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync([new KeyValuePair("b", "hello"), new KeyValuePair("a", "world")])); + + func = AIFunctionFactory.Create((int a, long b) => a + b); + AssertExtensions.EqualFunctionCallResults(3L, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)])); + } + + [Fact] + public async Task Parameters_DefaultValuesAreUsedButOverridable_Async() + { + AIFunction func = AIFunctionFactory.Create((string a = "test") => a + " " + a); + AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync()); + AssertExtensions.EqualFunctionCallResults("hello hello", await func.InvokeAsync([new KeyValuePair("a", "hello")])); + } + + [Fact] + public async Task Parameters_AIFunctionContextMappedByType_Async() + { + using var cts = new CancellationTokenSource(); + CancellationToken written; + AIFunction func; + + // As the only parameter + written = default; + func = AIFunctionFactory.Create((AIFunctionContext ctx) => + { + Assert.NotNull(ctx); + written = ctx.CancellationToken; + }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(cancellationToken: cts.Token)); + Assert.Equal(cts.Token, written); + + // As the last + written = default; + func = AIFunctionFactory.Create((int somethingFirst, AIFunctionContext ctx) => + { + Assert.NotNull(ctx); + written = ctx.CancellationToken; + }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(new Dictionary { ["somethingFirst"] = 1, ["ctx"] = new AIFunctionContext() }, cts.Token)); + Assert.Equal(cts.Token, written); + + // As the first + written = default; + func = AIFunctionFactory.Create((AIFunctionContext ctx, int somethingAfter = 0) => + { + Assert.NotNull(ctx); + written = ctx.CancellationToken; + }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync(cancellationToken: cts.Token)); + Assert.Equal(cts.Token, written); + } + + [Fact] + public async Task Returns_AsyncReturnTypesSupported_Async() + { + AIFunction func; + + func = AIFunctionFactory.Create(Task (string a) => Task.FromResult(a + " " + a)); + AssertExtensions.EqualFunctionCallResults("test test", await func.InvokeAsync([new KeyValuePair("a", "test")])); + + func = AIFunctionFactory.Create(ValueTask (string a, string b) => new ValueTask(b + " " + a)); + AssertExtensions.EqualFunctionCallResults("hello world", await func.InvokeAsync([new KeyValuePair("b", "hello"), new KeyValuePair("a", "world")])); + + long result = 0; + func = AIFunctionFactory.Create(async Task (int a, long b) => { result = a + b; await Task.Yield(); }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)])); + Assert.Equal(3, result); + + result = 0; + func = AIFunctionFactory.Create(async ValueTask (int a, long b) => { result = a + b; await Task.Yield(); }); + AssertExtensions.EqualFunctionCallResults(null, await func.InvokeAsync([new KeyValuePair("a", 1), new KeyValuePair("b", 2L)])); + Assert.Equal(3, result); + + func = AIFunctionFactory.Create((int count) => SimpleIAsyncEnumerable(count)); + AssertExtensions.EqualFunctionCallResults(new int[] { 0, 1, 2, 3, 4 }, await func.InvokeAsync([new("count", 5)])); + + static async IAsyncEnumerable SimpleIAsyncEnumerable(int count) + { + for (int i = 0; i < count; i++) + { + await Task.Yield(); + yield return i; + } + } + + func = AIFunctionFactory.Create(() => (IAsyncEnumerable)new ThrowingAsyncEnumerable()); + await Assert.ThrowsAsync(() => func.InvokeAsync()); + } + + private sealed class ThrowingAsyncEnumerable : IAsyncEnumerable + { +#pragma warning disable S3717 // Track use of "NotImplementedException" + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => throw new NotImplementedException(); +#pragma warning restore S3717 // Track use of "NotImplementedException" + } + + [Fact] + public void Metadata_DerivedFromLambda() + { + AIFunction func; + + func = AIFunctionFactory.Create(() => "test"); + Assert.Contains("Metadata_DerivedFromLambda", func.Metadata.Name); + Assert.Empty(func.Metadata.Description); + Assert.Empty(func.Metadata.Parameters); + Assert.Equal(typeof(string), func.Metadata.ReturnParameter.ParameterType); + + func = AIFunctionFactory.Create((string a) => a + " " + a); + Assert.Contains("Metadata_DerivedFromLambda", func.Metadata.Name); + Assert.Empty(func.Metadata.Description); + Assert.Single(func.Metadata.Parameters); + + func = AIFunctionFactory.Create( + [Description("This is a test function")] ([Description("This is A")] string a, [Description("This is B")] string b) => b + " " + a); + Assert.Contains("Metadata_DerivedFromLambda", func.Metadata.Name); + Assert.Equal("This is a test function", func.Metadata.Description); + Assert.Collection(func.Metadata.Parameters, + p => Assert.Equal("This is A", p.Description), + p => Assert.Equal("This is B", p.Description)); + } + + [Fact] + public void AIFunctionFactoryCreateOptions_ValuesPropagateToAIFunction() + { + IReadOnlyList parameterMetadata = [new AIFunctionParameterMetadata("a")]; + AIFunctionReturnParameterMetadata returnParameterMetadata = new() { ParameterType = typeof(string) }; + IReadOnlyDictionary metadata = new Dictionary { ["a"] = "b" }; + + var options = new AIFunctionFactoryCreateOptions + { + Name = "test name", + Description = "test description", + Parameters = parameterMetadata, + ReturnParameter = returnParameterMetadata, + AdditionalProperties = metadata, + }; + + Assert.Equal("test name", options.Name); + Assert.Equal("test description", options.Description); + Assert.Same(parameterMetadata, options.Parameters); + Assert.Same(returnParameterMetadata, options.ReturnParameter); + Assert.Same(metadata, options.AdditionalProperties); + + AIFunction func = AIFunctionFactory.Create(() => { }, options); + + Assert.Equal("test name", func.Metadata.Name); + Assert.Equal("test description", func.Metadata.Description); + Assert.Equal(parameterMetadata, func.Metadata.Parameters); + Assert.Equal(returnParameterMetadata, func.Metadata.ReturnParameter); + Assert.Equal(metadata, func.Metadata.AdditionalProperties); + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj b/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj new file mode 100644 index 00000000000..b3d5e8048f5 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/Microsoft.Extensions.AI.Tests.csproj @@ -0,0 +1,32 @@ + + + Microsoft.Extensions.AI + Unit tests for Microsoft.Extensions.AI. + + + + $(NoWarn);CA1063;CA1861;SA1130;VSTHRD003 + true + + + + true + + + + + + + + + + + + + + + + + + + diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/TestInMemoryCacheStorage.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/TestInMemoryCacheStorage.cs new file mode 100644 index 00000000000..8ab2cd0cbb0 --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/TestInMemoryCacheStorage.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Distributed; + +namespace Microsoft.Extensions.AI; + +internal sealed class TestInMemoryCacheStorage : IDistributedCache +{ + private readonly ConcurrentDictionary _storage = new(); + + public ICollection Keys => _storage.Keys; + + public byte[]? Get(string key) + => _storage.TryGetValue(key, out var value) ? value : null; + + public Task GetAsync(string key, CancellationToken token = default) + => Task.FromResult(Get(key)); + + public void Refresh(string key) + { + // In memory, nothing to refresh + } + + public Task RefreshAsync(string key, CancellationToken token = default) + => Task.CompletedTask; + + public void Remove(string key) + => _storage.TryRemove(key, out _); + + public Task RemoveAsync(string key, CancellationToken token = default) + { + Remove(key); + return Task.CompletedTask; + } + + public void Set(string key, byte[] value, DistributedCacheEntryOptions options) + { + _storage[key] = value; + } + + public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default) + { + Set(key, value, options); + return Task.CompletedTask; + } +} diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs new file mode 100644 index 00000000000..e376da86dad --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/TestJsonSerializerContext.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; + +namespace Microsoft.Extensions.AI; + +// These types are directly serialized by DistributedCachingChatClient +[JsonSerializable(typeof(ChatCompletion))] +[JsonSerializable(typeof(IList))] +[JsonSerializable(typeof(IReadOnlyList))] + +// These types are specific to the tests in this project +[JsonSerializable(typeof(bool))] +[JsonSerializable(typeof(double))] +[JsonSerializable(typeof(JsonElement))] +[JsonSerializable(typeof(Embedding))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(DayOfWeek[]))] +[JsonSerializable(typeof(Guid))] +internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs b/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs index a27876703e7..e007d95860a 100644 --- a/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs +++ b/test/TestUtilities/XUnit/ConditionalFactDiscoverer.cs @@ -24,6 +24,7 @@ protected override IXunitTestCase CreateTestCase(ITestFrameworkDiscoveryOptions var skipReason = testMethod.EvaluateSkipConditions(); return skipReason != null ? new SkippedTestCase(skipReason, _diagnosticMessageSink, discoveryOptions.MethodDisplayOrDefault(), TestMethodDisplayOptions.None, testMethod) - : base.CreateTestCase(discoveryOptions, testMethod, factAttribute); + : new SkippedFactTestCase(DiagnosticMessageSink, discoveryOptions.MethodDisplayOrDefault(), + discoveryOptions.MethodDisplayOptionsOrDefault(), testMethod); // Test case skippable at runtime. } } diff --git a/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs b/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs index 846038f8786..b1e53b8ed77 100644 --- a/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs +++ b/test/TestUtilities/XUnit/ConditionalTheoryDiscoverer.cs @@ -3,7 +3,6 @@ // Borrowed from https://github.com/dotnet/aspnetcore/blob/95ed45c67/src/Testing/src/xunit/ -using System; using System.Collections.Generic; using Xunit.Abstractions; using Xunit.Sdk; diff --git a/test/TestUtilities/XUnit/SkipTestException.cs b/test/TestUtilities/XUnit/SkipTestException.cs new file mode 100644 index 00000000000..70f7d53c7d8 --- /dev/null +++ b/test/TestUtilities/XUnit/SkipTestException.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// Borrowed from https://github.com/dotnet/aspnetcore/blob/95ed45c67/src/Testing/src/xunit/ + +using System; + +namespace Microsoft.TestUtilities; + +public class SkipTestException : Exception +{ + public SkipTestException(string reason) + : base(reason) + { + } +} diff --git a/test/TestUtilities/XUnit/SkippedFactTestCase.cs b/test/TestUtilities/XUnit/SkippedFactTestCase.cs new file mode 100644 index 00000000000..79ace15ea6e --- /dev/null +++ b/test/TestUtilities/XUnit/SkippedFactTestCase.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace Microsoft.TestUtilities; + +public class SkippedFactTestCase : XunitTestCase +{ + [Obsolete("Called by the de-serializer; should only be called by deriving classes for de-serialization purposes", error: true)] + public SkippedFactTestCase() + { + } + + public SkippedFactTestCase( + IMessageSink diagnosticMessageSink, TestMethodDisplay defaultMethodDisplay, TestMethodDisplayOptions defaultMethodDisplayOptions, + ITestMethod testMethod, object[]? testMethodArguments = null) + : base(diagnosticMessageSink, defaultMethodDisplay, defaultMethodDisplayOptions, testMethod, testMethodArguments) + { + } + + public override async Task RunAsync(IMessageSink diagnosticMessageSink, + IMessageBus messageBus, + object[] constructorArguments, + ExceptionAggregator aggregator, + CancellationTokenSource cancellationTokenSource) + { + using SkippedTestMessageBus skipMessageBus = new(messageBus); + var result = await base.RunAsync(diagnosticMessageSink, skipMessageBus, constructorArguments, aggregator, cancellationTokenSource); + if (skipMessageBus.SkippedTestCount > 0) + { + result.Failed -= skipMessageBus.SkippedTestCount; + result.Skipped += skipMessageBus.SkippedTestCount; + } + + return result; + } +} diff --git a/test/TestUtilities/XUnit/SkippedTestMessageBus.cs b/test/TestUtilities/XUnit/SkippedTestMessageBus.cs new file mode 100644 index 00000000000..230586852b8 --- /dev/null +++ b/test/TestUtilities/XUnit/SkippedTestMessageBus.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Xunit.Abstractions; +using Xunit.Sdk; + +namespace Microsoft.TestUtilities; + +/// Implements message bus to communicate tests skipped via SkipTestException. +public sealed class SkippedTestMessageBus : IMessageBus +{ + private readonly IMessageBus _innerBus; + + public SkippedTestMessageBus(IMessageBus innerBus) + { + _innerBus = innerBus; + } + + public int SkippedTestCount { get; private set; } + + public void Dispose() + { + // nothing to dispose + } + + public bool QueueMessage(IMessageSinkMessage message) + { + var testFailed = message as ITestFailed; + + if (testFailed != null) + { + var exceptionType = testFailed.ExceptionTypes.FirstOrDefault(); + if (exceptionType == typeof(SkipTestException).FullName) + { + SkippedTestCount++; + return _innerBus.QueueMessage(new TestSkipped(testFailed.Test, testFailed.Messages.FirstOrDefault())); + } + } + + // Nothing we care about, send it on its way + return _innerBus.QueueMessage(message); + } +} From e5bbd336e678188d587be1cbae051d872492f2e6 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 8 Oct 2024 13:00:44 -0400 Subject: [PATCH 3/3] Temporarily work around trimming-related warnings --- .../ChatClientStructuredOutputExtensions.cs | 2 ++ .../Functions/AIFunctionFactory.cs | 10 ++++++++++ src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs | 4 ++-- .../Microsoft.Extensions.AI.csproj | 5 +++++ 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 2a8b794c50e..5d16440a8fa 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -68,6 +68,8 @@ public static Task> CompleteAsync( /// The type of structured output to request. [RequiresDynamicCode("JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. " + "Use System.Text.Json source generation for native AOT applications.")] + [RequiresUnreferencedCode("JSON serialization and deserialization might require types that cannot be statically analyzed and might need runtime code generation. " + + "Use System.Text.Json source generation for native AOT applications.")] public static Task> CompleteAsync( this IChatClient chatClient, string chatMessage, diff --git a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs index 0fff0cd64fa..c562db8ca3a 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs @@ -44,6 +44,8 @@ class AIFunctionFactory /// The method to be represented via the created . /// Metadata to use to override defaults inferred from . /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] public static AIFunction Create(Delegate method, AIFunctionFactoryCreateOptions options) { _ = Throw.IfNull(method); @@ -67,6 +69,8 @@ public static AIFunction Create(Delegate method, string? name, string? descripti /// The name to use for the . /// The description to use for the . /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] public static AIFunction Create(Delegate method, JsonSerializerOptions options, string? name = null, string? description = null) { _ = Throw.IfNull(method); @@ -100,6 +104,8 @@ public static AIFunction Create(MethodInfo method, object? target = null) /// /// Metadata to use to override defaults inferred from . /// The created for invoking . + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] public static AIFunction Create(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options) { _ = Throw.IfNull(method); @@ -130,6 +136,8 @@ class ReflectionAIFunction : AIFunction /// This should be if and only if is a static method. /// /// Function creation options. + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] public ReflectionAIFunction(MethodInfo method, object? target, AIFunctionFactoryCreateOptions options) { _ = Throw.IfNull(method); @@ -376,6 +384,8 @@ static bool IsAsyncMethod(MethodInfo method) /// /// Gets a delegate for handling the result value of a method, converting it into the to return from the invocation. /// + [RequiresUnreferencedCode("Reflection is used to access types from the supplied MethodInfo.")] + [RequiresDynamicCode("Reflection is used to access types from the supplied MethodInfo.")] private static Type GetReturnMarshaler(MethodInfo method, out Func> marshaler) { // Handle each known return type for the method diff --git a/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs b/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs index 71edc9404b6..06317f570a2 100644 --- a/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs +++ b/src/Libraries/Microsoft.Extensions.AI/JsonDefaults.cs @@ -28,9 +28,9 @@ private static JsonSerializerOptions CreateDefaultOptions() var options = new JsonSerializerOptions(JsonSerializerDefaults.Web) { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, -#pragma warning disable IL3050 +#pragma warning disable IL3050, IL2026 // only used when reflection-based serialization is enabled TypeInfoResolver = new DefaultJsonTypeInfoResolver(), -#pragma warning restore IL3050 +#pragma warning restore IL3050, IL2026 }; options.MakeReadOnly(); diff --git a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj index 8e389b61652..39b33458d0c 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj +++ b/src/Libraries/Microsoft.Extensions.AI/Microsoft.Extensions.AI.csproj @@ -18,6 +18,11 @@ true + + + $(NoWarn);IL2026 + + true true