diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs index 8b76682f8c8..0f847dbb296 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatClientStructuredOutputExtensions.cs @@ -1,10 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Collections.Generic; using System.ComponentModel; using System.Reflection; using System.Text.Json; +using System.Text.Json.Nodes; using System.Threading; using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; @@ -44,8 +46,7 @@ public static Task> CompleteAsync( IList chatMessages, ChatOptions? options = null, bool? useNativeJsonSchema = null, - CancellationToken cancellationToken = default) - where T : class => + CancellationToken cancellationToken = default) => CompleteAsync(chatClient, chatMessages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken); /// Sends a user chat text message to the model, requesting a response matching the type . @@ -65,8 +66,7 @@ public static Task> CompleteAsync( string chatMessage, ChatOptions? options = null, bool? useNativeJsonSchema = null, - CancellationToken cancellationToken = default) - where T : class => + CancellationToken cancellationToken = default) => CompleteAsync(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], options, useNativeJsonSchema, cancellationToken); /// Sends a user chat text message to the model, requesting a response matching the type . @@ -88,8 +88,7 @@ public static Task> CompleteAsync( JsonSerializerOptions serializerOptions, ChatOptions? options = null, bool? useNativeJsonSchema = null, - CancellationToken cancellationToken = default) - where T : class => + CancellationToken cancellationToken = default) => CompleteAsync(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], serializerOptions, options, useNativeJsonSchema, cancellationToken); /// Sends chat messages to the model, requesting a response matching the type . @@ -116,7 +115,6 @@ public static async Task> CompleteAsync( ChatOptions? options = null, bool? useNativeJsonSchema = null, CancellationToken cancellationToken = default) - where T : class { _ = Throw.IfNull(chatClient); _ = Throw.IfNull(chatMessages); @@ -124,12 +122,33 @@ public static async Task> CompleteAsync( serializerOptions.MakeReadOnly(); - var schemaNode = AIJsonUtilities.CreateJsonSchema( + var schemaElement = AIJsonUtilities.CreateJsonSchema( type: typeof(T), serializerOptions: serializerOptions, inferenceOptions: _inferenceOptions); - var schema = JsonSerializer.Serialize(schemaNode, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement))); + bool isWrappedInObject; + string schema; + if (SchemaRepresentsObject(schemaElement)) + { + // For object-representing schemas, we can use them as-is + isWrappedInObject = false; + schema = JsonSerializer.Serialize(schemaElement, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement))); + } + else + { + // For non-object-representing schemas, we wrap them in an object schema, because all + // the real LLM providers today require an object schema as the root. This is currently + // true even for providers that support native structured output. + isWrappedInObject = true; + schema = JsonSerializer.Serialize(new JsonObject + { + { "$schema", "https://json-schema.org/draft/2020-12/schema" }, + { "type", "object" }, + { "properties", new JsonObject { { "data", JsonElementToJsonNode(schemaElement) } } }, + { "additionalProperties", false }, + }, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonObject))); + } ChatMessage? promptAugmentation = null; options = (options ?? new()).Clone(); @@ -152,7 +171,7 @@ public static async Task> CompleteAsync( // When not using native structured output, augment the chat messages with a schema prompt #pragma warning disable SA1118 // Parameter should not span multiple lines - promptAugmentation = new ChatMessage(ChatRole.System, $$""" + promptAugmentation = new ChatMessage(ChatRole.User, $$""" Respond with a JSON value conforming to the following schema: ``` {{schema}} @@ -166,7 +185,7 @@ public static async Task> CompleteAsync( try { var result = await chatClient.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false); - return new ChatCompletion(result, serializerOptions); + return new ChatCompletion(result, serializerOptions) { IsWrappedInObject = isWrappedInObject }; } finally { @@ -176,4 +195,32 @@ public static async Task> CompleteAsync( } } } + + private static bool SchemaRepresentsObject(JsonElement schemaElement) + { + if (schemaElement.ValueKind is JsonValueKind.Object) + { + foreach (var property in schemaElement.EnumerateObject()) + { + if (property.NameEquals("type"u8)) + { + return property.Value.ValueKind == JsonValueKind.String + && property.Value.ValueEquals("object"u8); + } + } + } + + return false; + } + + private static JsonNode? JsonElementToJsonNode(JsonElement element) + { + return element.ValueKind switch + { + JsonValueKind.Null => null, + JsonValueKind.Array => JsonArray.Create(element), + JsonValueKind.Object => JsonObject.Create(element), + _ => JsonValue.Create(element) + }; + } } diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs index 344a01d2c22..7166f04e744 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/ChatCompletion{T}.cs @@ -57,6 +57,7 @@ public T Result { FailureReason.ResultDidNotContainJson => throw new InvalidOperationException("The response did not contain text to be deserialized"), FailureReason.DeserializationProducedNull => throw new InvalidOperationException("The deserialized response is null"), + FailureReason.ResultDidNotContainDataProperty => throw new InvalidOperationException("The response did not contain the expected 'data' property"), _ => result!, }; } @@ -103,6 +104,12 @@ public bool TryGetResult([NotNullWhen(true)] out T? result) } } + /// + /// Gets or sets a value indicating whether the JSON schema has an extra object wrapper. + /// This is required for any non-JSON-object-typed values such as numbers, enum values, or arrays. + /// + internal bool IsWrappedInObject { get; set; } + private string? GetResultAsJson() { var choice = Choices.Count == 1 ? Choices[0] : null; @@ -125,8 +132,25 @@ public bool TryGetResult([NotNullWhen(true)] out T? result) return default; } + T? deserialized = default; + // If there's an exception here, we want it to propagate, since the Result property is meant to throw directly - var deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo)_serializerOptions.GetTypeInfo(typeof(T))); + + if (IsWrappedInObject) + { + if (JsonDocument.Parse(json!).RootElement.TryGetProperty("data", out var data)) + { + json = data.GetRawText(); + } + else + { + failureReason = FailureReason.ResultDidNotContainDataProperty; + return default; + } + } + + deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo)_serializerOptions.GetTypeInfo(typeof(T))); + if (deserialized is null) { failureReason = FailureReason.DeserializationProducedNull; @@ -143,5 +167,6 @@ private enum FailureReason { ResultDidNotContainJson, DeserializationProducedNull, + ResultDidNotContainDataProperty, } } diff --git a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs b/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs index 4ad0603d311..46fe45342f2 100644 --- a/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs +++ b/src/Libraries/Microsoft.Extensions.AI/Utilities/AIJsonUtilities.Schema.cs @@ -231,6 +231,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) const string DescriptionPropertyName = "description"; const string NotPropertyName = "not"; const string TypePropertyName = "type"; + const string PatternPropertyName = "pattern"; const string EnumPropertyName = "enum"; const string PropertiesPropertyName = "properties"; const string AdditionalPropertiesPropertyName = "additionalProperties"; @@ -281,7 +282,20 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema) if (ctx.Path.IsEmpty) { - // We are at the root-level schema node, append parameter-specific metadata + // We are at the root-level schema node, update/append parameter-specific metadata + + // Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand + // schemas with "type": [...], and only understand "type" being a single value. + // STJ represents .NET integer types as ["string", "integer"], which will then lead to an error. + if (TypeIsArrayContainingInteger(schema)) + { + // We don't want to emit any array for "type". In this case we know it contains "integer" + // so reduce the type to that alone, assuming it's the most specific type. + // This makes schemas for Int32 (etc) work with Ollama + JsonObject obj = ConvertSchemaToObject(ref schema); + obj[TypePropertyName] = "integer"; + _ = obj.Remove(PatternPropertyName); + } if (!string.IsNullOrWhiteSpace(key.Description)) { @@ -340,6 +354,22 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema) } } + private static bool TypeIsArrayContainingInteger(JsonNode schema) + { + if (schema["type"] is JsonArray typeArray) + { + foreach (var entry in typeArray) + { + if (entry?.GetValueKind() == JsonValueKind.String && entry.GetValue() == "integer") + { + return true; + } + } + } + + return false; + } + private static JsonElement ParseJsonElement(ReadOnlySpan utf8Json) { Utf8JsonReader reader = new(utf8Json); diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs index 266f7ec45e9..db482d26804 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/AIJsonUtilitiesTests.cs @@ -127,10 +127,26 @@ public static void ResolveParameterJsonSchema_ReturnsExpectedValue() JsonElement resolvedSchema; resolvedSchema = AIJsonUtilities.ResolveParameterJsonSchema(param, metadata, options); Assert.True(JsonElement.DeepEquals(generatedSchema, resolvedSchema)); + } - options = new(options) { NumberHandling = JsonNumberHandling.AllowReadingFromString }; - resolvedSchema = AIJsonUtilities.ResolveParameterJsonSchema(param, metadata, options); - Assert.False(JsonElement.DeepEquals(generatedSchema, resolvedSchema)); + [Fact] + public static void ResolveParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString() + { + JsonElement expected = JsonDocument.Parse(""" + { + "type": "integer" + } + """).RootElement; + + JsonSerializerOptions options = new(JsonSerializerOptions.Default) { NumberHandling = JsonNumberHandling.AllowReadingFromString }; + AIFunction func = AIFunctionFactory.Create((int a, int? b, long c, short d) => { }, serializerOptions: options); + + AIFunctionMetadata metadata = func.Metadata; + foreach (var param in metadata.Parameters) + { + JsonElement actualSchema = Assert.IsType(param.Schema); + Assert.True(JsonElement.DeepEquals(expected, actualSchema)); + } } [Description("The type")] diff --git a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs index 634e4a19f9e..3f5ce32fc37 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Integration.Tests/ChatClientIntegrationTests.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Runtime.InteropServices; using System.Text; using System.Text.RegularExpressions; using System.Threading.Tasks; @@ -569,7 +570,7 @@ public virtual async Task CompleteAsync_StructuredOutput() var response = await _chatClient.CompleteAsync(""" Who is described in the following sentence? - Jimbo Smith is a 35-year-old software developer from Cardiff, Wales. + Jimbo Smith is a 35-year-old programmer from Cardiff, Wales. """); Assert.Equal("Jimbo Smith", response.Result.FullName); @@ -578,6 +579,86 @@ Who is described in the following sentence? Assert.Equal(JobType.Programmer, response.Result.Job); } + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutputArray() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + Who are described in the following sentence? + Jimbo Smith is a 35-year-old software developer from Cardiff, Wales. + Josh Simpson is a 25-year-old software developer from Newport, Wales. + """); + + Assert.Equal(2, response.Result.Length); + Assert.Contains(response.Result, x => x.FullName == "Jimbo Smith"); + Assert.Contains(response.Result, x => x.FullName == "Josh Simpson"); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutputInteger() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + There were 14 abstractions for AI programming, which was too many. + To fix this we added another one. How many are there now? + """); + + Assert.Equal(15, response.Result); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutputString() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + The software developer, Jimbo Smith, is a 35-year-old from Cardiff, Wales. + What's his full name? + """); + + Assert.Equal("Jimbo Smith", response.Result); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutputBool_True() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + Jimbo Smith is a 35-year-old software developer from Cardiff, Wales. + Is there at least one software developer from Cardiff? + """); + + Assert.True(response.Result); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutputBool_False() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + Jimbo Smith is a 35-year-old software developer from Cardiff, Wales. + Can we be sure that he is a medical doctor? + """); + + Assert.False(response.Result); + } + + [ConditionalFact] + public virtual async Task CompleteAsync_StructuredOutputEnum() + { + SkipIfNotEnabled(); + + var response = await _chatClient.CompleteAsync(""" + I'm using a Macbook Pro with an M2 chip. What architecture am I using? + """); + + Assert.Equal(Architecture.Arm64, response.Result); + } + [ConditionalFact] public virtual async Task CompleteAsync_StructuredOutput_WithFunctions() { diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs index eea22abfacb..acb6142935e 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/ChatClientStructuredOutputExtensionsTests.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.ComponentModel; using System.Text.Json; +using System.Text.RegularExpressions; using System.Threading.Tasks; using Xunit; @@ -34,12 +35,12 @@ public async Task SuccessUsage() Assert.Null(responseFormat.SchemaName); Assert.Null(responseFormat.SchemaDescription); - // The inner client receives a trailing "system" message with the schema instruction + // The inner client receives a trailing "user" message with the schema instruction Assert.Collection(messages, message => Assert.Equal("Hello", message.Text), message => { - Assert.Equal(ChatRole.System, message.Role); + Assert.Equal(ChatRole.User, message.Role); Assert.Contains("Respond with a JSON value", message.Text); Assert.Contains("https://json-schema.org/draft/2020-12/schema", message.Text); foreach (Species v in Enum.GetValues(typeof(Species))) @@ -73,6 +74,39 @@ public async Task SuccessUsage() Assert.Equal("Hello", Assert.Single(chatHistory).Text); } + [Fact] + public async Task WrapsNonObjectValuesInDataProperty() + { + var expectedResult = new { data = 123 }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(expectedResult))]); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => + { + var suppliedSchemaMatch = Regex.Match(messages[1].Text!, "```(.*?)```", RegexOptions.Singleline); + Assert.True(suppliedSchemaMatch.Success); + Assert.Equal(""" + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "data": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "integer" + } + }, + "additionalProperties": false + } + """, suppliedSchemaMatch.Groups[1].Value.Trim()); + return Task.FromResult(expectedCompletion); + }, + }; + + var response = await client.CompleteAsync("Hello"); + Assert.Equal(123, response.Result); + } + [Fact] public async Task FailureUsage_InvalidJson() { @@ -206,6 +240,34 @@ public async Task CanUseNativeStructuredOutputWithSanitizedTypeName() Assert.Equal("Hello", Assert.Single(chatHistory).Text); } + [Fact] + public async Task CanUseNativeStructuredOutputWithArray() + { + var expectedResult = new[] { new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger } }; + var payload = new { data = expectedResult }; + var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(payload))]); + + using var client = new TestChatClient + { + CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion) + }; + + var chatHistory = new List { new(ChatRole.User, "Hello") }; + var response = await client.CompleteAsync(chatHistory, useNativeJsonSchema: true); + + // The completion contains the deserialized result and other completion properties + Assert.Single(response.Result!); + Assert.Equal("Tigger", response.Result[0].FullName); + Assert.Equal(Species.Tiger, response.Result[0].Species); + + // TryGetResult returns the same value + Assert.True(response.TryGetResult(out var tryGetResultOutput)); + Assert.Same(response.Result, tryGetResultOutput); + + // History remains unmutated + Assert.Equal("Hello", Assert.Single(chatHistory).Text); + } + [Fact] public async Task CanSpecifyCustomJsonSerializationOptions() { @@ -224,7 +286,7 @@ public async Task CanSpecifyCustomJsonSerializationOptions() message => Assert.Equal("Hello", message.Text), message => { - Assert.Equal(ChatRole.System, message.Role); + Assert.Equal(ChatRole.User, message.Role); Assert.Contains("Respond with a JSON value", message.Text); Assert.Contains("https://json-schema.org/draft/2020-12/schema", message.Text); Assert.DoesNotContain(nameof(Animal.FullName), message.Text); // The JSO uses snake_case