Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structured output improvements (continuation of PR 5522) #5560

Merged
merged 14 commits into from
Oct 24, 2024
Merged
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -44,8 +46,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, chatMessages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends a user chat text message to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -65,8 +66,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
string chatMessage,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends a user chat text message to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -88,8 +88,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
JsonSerializerOptions serializerOptions,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], serializerOptions, options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends chat messages to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -116,20 +115,35 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class
{
_ = Throw.IfNull(chatClient);
_ = Throw.IfNull(chatMessages);
_ = Throw.IfNull(serializerOptions);

serializerOptions.MakeReadOnly();

var schemaNode = AIJsonUtilities.CreateJsonSchema(
var schemaElement = AIJsonUtilities.CreateJsonSchema(
type: typeof(T),
serializerOptions: serializerOptions,
inferenceOptions: _inferenceOptions);

var schema = JsonSerializer.Serialize(schemaNode, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)));
var isWrappedInObject = false;
if (!SchemaRepresentsObject(schemaElement))
{
// For non-object-representing schemas, we wrap them in an object schema, because all
// the real LLM providers today require an object schema as the root. This is currently
// true even for providers that support native structured output.
isWrappedInObject = true;
schemaElement = JsonSerializer.SerializeToElement(new JsonObject
SteveSandersonMS marked this conversation as resolved.
Show resolved Hide resolved
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
{
{ "$schema", "https://json-schema.org/draft/2020-12/schema" },
{ "type", "object" },
{ "properties", new JsonObject { { "data", JsonElementToJsonNode(schemaElement) } } },
{ "additionalProperties", false },
}, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonObject)))!;
}

var schema = JsonSerializer.Serialize(schemaElement, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)));

ChatMessage? promptAugmentation = null;
options = (options ?? new()).Clone();
Expand All @@ -152,7 +166,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(

// When not using native structured output, augment the chat messages with a schema prompt
#pragma warning disable SA1118 // Parameter should not span multiple lines
promptAugmentation = new ChatMessage(ChatRole.System, $$"""
promptAugmentation = new ChatMessage(ChatRole.User, $$"""
Respond with a JSON value conforming to the following schema:
```
{{schema}}
Expand All @@ -166,7 +180,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
try
{
var result = await chatClient.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
return new ChatCompletion<T>(result, serializerOptions);
return new ChatCompletion<T>(result, serializerOptions) { IsWrappedInObject = isWrappedInObject };
}
finally
{
Expand All @@ -176,4 +190,31 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
}
}
}

/// <summary>Determines whether a JSON schema describes a JSON object at its root.</summary>
/// <param name="schemaElement">The schema to inspect.</param>
/// <returns>
/// <see langword="true"/> when <paramref name="schemaElement"/> is a JSON object whose first
/// <c>type</c> member is the string <c>"object"</c>; otherwise <see langword="false"/>.
/// </returns>
private static bool SchemaRepresentsObject(JsonElement schemaElement)
{
    // Anything that is not a JSON object cannot carry a "type" member.
    if (schemaElement.ValueKind != JsonValueKind.Object)
    {
        return false;
    }

    foreach (JsonProperty member in schemaElement.EnumerateObject())
    {
        if (!member.NameEquals("type"u8))
        {
            continue;
        }

        // The first "type" member decides the outcome; a non-string value (e.g. an array) is not an object schema.
        JsonElement typeValue = member.Value;
        return typeValue.ValueKind == JsonValueKind.String
            && typeValue.ValueEquals("object"u8);
    }

    return false;
}

/// <summary>Converts a <see cref="JsonElement"/> into a <see cref="JsonNode"/> of the matching kind.</summary>
/// <param name="element">The element to convert.</param>
/// <returns>
/// A <see cref="JsonArray"/> for array elements, a <see cref="JsonObject"/> for object elements, and the
/// result of <see cref="JsonValue.Create(JsonElement, JsonNodeOptions?)"/> for every other kind.
/// The result is nullable because the underlying factory methods are.
/// </returns>
private static JsonNode? JsonElementToJsonNode(JsonElement element)
{
    // Dispatch on the element kind so containers become their mutable node counterparts.
    return element.ValueKind switch
    {
        JsonValueKind.Array => JsonArray.Create(element),
        JsonValueKind.Object => JsonObject.Create(element),
        _ => JsonValue.Create(element)
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
}
}

/// <summary>
/// Gets or sets a value indicating whether the JSON schema has an extra object wrapper.
/// This is required for any non-JSON-object-typed values such as numbers, enum values, or arrays.
/// </summary>
/// <remarks>
/// When <see langword="true"/>, the result is deserialized from the <c>data</c> property of the
/// returned JSON rather than from the JSON root.
/// </remarks>
internal bool IsWrappedInObject { get; set; }

private string? GetResultAsJson()
{
var choice = Choices.Count == 1 ? Choices[0] : null;
Expand All @@ -125,8 +131,23 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
return default;
}

T? deserialized = default;

// If there's an exception here, we want it to propagate, since the Result property is meant to throw directly
var deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));

if (IsWrappedInObject)
{
var doc = JsonDocument.Parse(json!);
if (doc.RootElement.TryGetProperty("data", out var data))
{
deserialized = DeserializeFirstTopLevelObject(data.GetRawText(), (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));
}
}
else
{
deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));
}

if (deserialized is null)
{
failureReason = FailureReason.DeserializationProducedNull;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
const string DescriptionPropertyName = "description";
const string NotPropertyName = "not";
const string TypePropertyName = "type";
const string PatternPropertyName = "pattern";
const string EnumPropertyName = "enum";
const string PropertiesPropertyName = "properties";
const string AdditionalPropertiesPropertyName = "additionalProperties";
Expand Down Expand Up @@ -281,7 +282,20 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)

if (ctx.Path.IsEmpty)
{
// We are at the root-level schema node, append parameter-specific metadata
// We are at the root-level schema node, update/append parameter-specific metadata

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
if (TypeIsArrayContainingInteger(schema))
{
// We don't want to emit any array for "type". In this case we know it contains "integer"
// so reduce the type to that alone, assuming it's the most specific type.
// This makes schemas for Int32 (etc) work with Ollama
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = "integer";
_ = obj.Remove(PatternPropertyName);
}

if (!string.IsNullOrWhiteSpace(key.Description))
{
Expand Down Expand Up @@ -340,6 +354,22 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema)
}
}

/// <summary>
/// Determines whether a schema node declares its <c>type</c> as an array that includes <c>"integer"</c>.
/// </summary>
/// <param name="schema">The schema node to inspect. It may be any kind of node, not only an object.</param>
/// <returns>
/// <see langword="true"/> when <paramref name="schema"/> is a JSON object whose <c>type</c> member is an
/// array containing the string <c>"integer"</c>; otherwise <see langword="false"/>.
/// </returns>
private static bool TypeIsArrayContainingInteger(JsonNode schema)
{
    // The string indexer on JsonNode throws InvalidOperationException for any node that is not a
    // JsonObject (for example a boolean schema), so check the node type before looking up "type".
    if (schema is not JsonObject schemaObject ||
        !schemaObject.TryGetPropertyValue("type", out JsonNode? typeNode) ||
        typeNode is not JsonArray typeArray)
    {
        return false;
    }

    foreach (JsonNode? entry in typeArray)
    {
        if (entry?.GetValueKind() == JsonValueKind.String && entry.GetValue<string>() == "integer")
        {
            return true;
        }
    }

    return false;
}

private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
{
Utf8JsonReader reader = new(utf8Json);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,26 @@ public static void ResolveParameterJsonSchema_ReturnsExpectedValue()
JsonElement resolvedSchema;
resolvedSchema = AIJsonUtilities.ResolveParameterJsonSchema(param, metadata, options);
Assert.True(JsonElement.DeepEquals(generatedSchema, resolvedSchema));
}

options = new(options) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
resolvedSchema = AIJsonUtilities.ResolveParameterJsonSchema(param, metadata, options);
Assert.False(JsonElement.DeepEquals(generatedSchema, resolvedSchema));
// Verifies that every integral parameter gets the schema {"type": "integer"} even when the
// serializer options allow numbers to be read from strings.
[Fact]
public static void ResolveParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
{
    // JsonDocument is disposable; scope it with `using` so it is released when the test ends.
    using JsonDocument expectedDocument = JsonDocument.Parse("""
        {
          "type": "integer"
        }
        """);
    JsonElement expected = expectedDocument.RootElement;

    JsonSerializerOptions options = new(JsonSerializerOptions.Default) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
    AIFunction func = AIFunctionFactory.Create((int a, int? b, long c, short d) => { }, serializerOptions: options);

    AIFunctionMetadata metadata = func.Metadata;
    foreach (var param in metadata.Parameters)
    {
        JsonElement actualSchema = Assert.IsType<JsonElement>(param.Schema);
        Assert.True(JsonElement.DeepEquals(expected, actualSchema));
    }
}

[Description("The type")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
Expand Down Expand Up @@ -569,7 +570,7 @@ public virtual async Task CompleteAsync_StructuredOutput()

var response = await _chatClient.CompleteAsync<Person>("""
Who is described in the following sentence?
Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
Jimbo Smith is a 35-year-old programmer from Cardiff, Wales.
""");

Assert.Equal("Jimbo Smith", response.Result.FullName);
Expand All @@ -578,6 +579,86 @@ Who is described in the following sentence?
Assert.Equal(JobType.Programmer, response.Result.Job);
}

// Requests an array of Person values and checks that both people named in the prompt come back.
[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputArray()
{
    SkipIfNotEnabled();

    const string Prompt = """
        Who are described in the following sentence?
        Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
        Josh Simpson is a 25-year-old software developer from Newport, Wales.
        """;

    var completion = await _chatClient.CompleteAsync<Person[]>(Prompt);

    Assert.Equal(2, completion.Result.Length);
    Assert.Contains(completion.Result, person => person.FullName == "Jimbo Smith");
    Assert.Contains(completion.Result, person => person.FullName == "Josh Simpson");
}

// Requests an int result and checks the number the prompt leads to.
[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputInteger()
{
    SkipIfNotEnabled();

    const string Prompt = """
        There were 14 abstractions for AI programming, which was too many.
        To fix this we added another one. How many are there now?
        """;

    var completion = await _chatClient.CompleteAsync<int>(Prompt);

    Assert.Equal(15, completion.Result);
}

// Requests a string result and checks that the person's full name is returned.
[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputString()
{
    SkipIfNotEnabled();

    const string Prompt = """
        The software developer, Jimbo Smith, is a 35-year-old from Cardiff, Wales.
        What's his full name?
        """;

    var completion = await _chatClient.CompleteAsync<string>(Prompt);

    Assert.Equal("Jimbo Smith", completion.Result);
}

// Requests a bool result for a question whose expected answer is true.
[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputBool_True()
{
    SkipIfNotEnabled();

    const string Prompt = """
        Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
        Is there at least one software developer from Cardiff?
        """;

    var completion = await _chatClient.CompleteAsync<bool>(Prompt);

    Assert.True(completion.Result);
}

// Requests a bool result for a question whose expected answer is false.
[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputBool_False()
{
    SkipIfNotEnabled();

    const string Prompt = """
        Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
        Can we be sure that he is a medical doctor?
        """;

    var completion = await _chatClient.CompleteAsync<bool>(Prompt);

    Assert.False(completion.Result);
}

// Requests an Architecture enum result and checks that Arm64 is returned.
[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputEnum()
{
    SkipIfNotEnabled();

    const string Prompt = """
        I'm using a Macbook Pro with an M2 chip. What architecture am I using?
        """;

    var completion = await _chatClient.CompleteAsync<Architecture>(Prompt);

    Assert.Equal(Architecture.Arm64, completion.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutput_WithFunctions()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,34 @@ public async Task CanUseNativeStructuredOutputWithSanitizedTypeName()
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

// Checks that an array result delivered inside a { "data": [...] } payload is deserialized correctly.
[Fact]
public async Task CanUseNativeStructuredOutputWithArray()
{
    // The fake reply carries the array under a "data" property.
    Animal[] expectedResult = [new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger }];
    string replyJson = JsonSerializer.Serialize(new { data = expectedResult });
    var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, replyJson)]);

    using var client = new TestChatClient
    {
        CompleteAsyncCallback = (_, _, _) => Task.FromResult(expectedCompletion)
    };

    List<ChatMessage> chatHistory = [new ChatMessage(ChatRole.User, "Hello")];
    var completion = await client.CompleteAsync<Animal[]>(chatHistory, useNativeJsonSchema: true);

    // The completion exposes the deserialized array with its single element intact
    Animal onlyAnimal = Assert.Single(completion.Result!);
    Assert.Equal("Tigger", onlyAnimal.FullName);
    Assert.Equal(Species.Tiger, onlyAnimal.Species);

    // TryGetResult hands back the same instance as Result
    Assert.True(completion.TryGetResult(out var tryGetResultOutput));
    Assert.Same(completion.Result, tryGetResultOutput);

    // The caller's history must not have been mutated
    Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

[Fact]
public async Task CanSpecifyCustomJsonSerializationOptions()
{
Expand Down