Skip to content

Commit

Permalink
Structured output improvements (continuation of PR 5522) (#5560)
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveSandersonMS authored Oct 24, 2024
1 parent cb16d5d commit 46d5e57
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Reflection;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -44,8 +46,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, chatMessages, AIJsonUtilities.DefaultOptions, options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends a user chat text message to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -65,8 +66,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
string chatMessage,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends a user chat text message to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -88,8 +88,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
JsonSerializerOptions serializerOptions,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], serializerOptions, options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends chat messages to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -116,20 +115,40 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class
{
_ = Throw.IfNull(chatClient);
_ = Throw.IfNull(chatMessages);
_ = Throw.IfNull(serializerOptions);

serializerOptions.MakeReadOnly();

var schemaNode = AIJsonUtilities.CreateJsonSchema(
var schemaElement = AIJsonUtilities.CreateJsonSchema(
type: typeof(T),
serializerOptions: serializerOptions,
inferenceOptions: _inferenceOptions);

var schema = JsonSerializer.Serialize(schemaNode, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)));
bool isWrappedInObject;
string schema;
if (SchemaRepresentsObject(schemaElement))
{
// For object-representing schemas, we can use them as-is
isWrappedInObject = false;
schema = JsonSerializer.Serialize(schemaElement, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonElement)));
}
else
{
// For non-object-representing schemas, we wrap them in an object schema, because all
// the real LLM providers today require an object schema as the root. This is currently
// true even for providers that support native structured output.
isWrappedInObject = true;
schema = JsonSerializer.Serialize(new JsonObject
{
{ "$schema", "https://json-schema.org/draft/2020-12/schema" },
{ "type", "object" },
{ "properties", new JsonObject { { "data", JsonElementToJsonNode(schemaElement) } } },
{ "additionalProperties", false },
}, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonObject)));
}

ChatMessage? promptAugmentation = null;
options = (options ?? new()).Clone();
Expand All @@ -152,7 +171,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(

// When not using native structured output, augment the chat messages with a schema prompt
#pragma warning disable SA1118 // Parameter should not span multiple lines
promptAugmentation = new ChatMessage(ChatRole.System, $$"""
promptAugmentation = new ChatMessage(ChatRole.User, $$"""
Respond with a JSON value conforming to the following schema:
```
{{schema}}
Expand All @@ -166,7 +185,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
try
{
var result = await chatClient.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
return new ChatCompletion<T>(result, serializerOptions);
return new ChatCompletion<T>(result, serializerOptions) { IsWrappedInObject = isWrappedInObject };
}
finally
{
Expand All @@ -176,4 +195,32 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
}
}
}

private static bool SchemaRepresentsObject(JsonElement schemaElement)
{
if (schemaElement.ValueKind is JsonValueKind.Object)
{
foreach (var property in schemaElement.EnumerateObject())
{
if (property.NameEquals("type"u8))
{
return property.Value.ValueKind == JsonValueKind.String
&& property.Value.ValueEquals("object"u8);
}
}
}

return false;
}

private static JsonNode? JsonElementToJsonNode(JsonElement element)
{
return element.ValueKind switch
{
JsonValueKind.Null => null,
JsonValueKind.Array => JsonArray.Create(element),
JsonValueKind.Object => JsonObject.Create(element),
_ => JsonValue.Create(element)
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public T Result
{
FailureReason.ResultDidNotContainJson => throw new InvalidOperationException("The response did not contain text to be deserialized"),
FailureReason.DeserializationProducedNull => throw new InvalidOperationException("The deserialized response is null"),
FailureReason.ResultDidNotContainDataProperty => throw new InvalidOperationException("The response did not contain the expected 'data' property"),
_ => result!,
};
}
Expand Down Expand Up @@ -103,6 +104,12 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
}
}

/// <summary>
/// Gets or sets a value indicating whether the JSON schema has an extra object wrapper.
/// This is required for any non-JSON-object-typed values such as numbers, enum values, or arrays.
/// </summary>
internal bool IsWrappedInObject { get; set; }

private string? GetResultAsJson()
{
var choice = Choices.Count == 1 ? Choices[0] : null;
Expand All @@ -125,8 +132,25 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
return default;
}

T? deserialized = default;

// If there's an exception here, we want it to propagate, since the Result property is meant to throw directly
var deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));

if (IsWrappedInObject)
{
if (JsonDocument.Parse(json!).RootElement.TryGetProperty("data", out var data))
{
json = data.GetRawText();
}
else
{
failureReason = FailureReason.ResultDidNotContainDataProperty;
return default;
}
}

deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));

if (deserialized is null)
{
failureReason = FailureReason.DeserializationProducedNull;
Expand All @@ -143,5 +167,6 @@ private enum FailureReason
{
ResultDidNotContainJson,
DeserializationProducedNull,
ResultDidNotContainDataProperty,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)
const string DescriptionPropertyName = "description";
const string NotPropertyName = "not";
const string TypePropertyName = "type";
const string PatternPropertyName = "pattern";
const string EnumPropertyName = "enum";
const string PropertiesPropertyName = "properties";
const string AdditionalPropertiesPropertyName = "additionalProperties";
Expand Down Expand Up @@ -281,7 +282,20 @@ JsonNode TransformSchemaNode(JsonSchemaExporterContext ctx, JsonNode schema)

if (ctx.Path.IsEmpty)
{
// We are at the root-level schema node, append parameter-specific metadata
// We are at the root-level schema node, update/append parameter-specific metadata

// Some consumers of the JSON schema, including Ollama as of v0.3.13, don't understand
// schemas with "type": [...], and only understand "type" being a single value.
// STJ represents .NET integer types as ["string", "integer"], which will then lead to an error.
if (TypeIsArrayContainingInteger(schema))
{
// We don't want to emit any array for "type". In this case we know it contains "integer"
// so reduce the type to that alone, assuming it's the most specific type.
// This makes schemas for Int32 (etc) work with Ollama
JsonObject obj = ConvertSchemaToObject(ref schema);
obj[TypePropertyName] = "integer";
_ = obj.Remove(PatternPropertyName);
}

if (!string.IsNullOrWhiteSpace(key.Description))
{
Expand Down Expand Up @@ -340,6 +354,22 @@ static JsonObject ConvertSchemaToObject(ref JsonNode schema)
}
}

private static bool TypeIsArrayContainingInteger(JsonNode schema)
{
if (schema["type"] is JsonArray typeArray)
{
foreach (var entry in typeArray)
{
if (entry?.GetValueKind() == JsonValueKind.String && entry.GetValue<string>() == "integer")
{
return true;
}
}
}

return false;
}

private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
{
Utf8JsonReader reader = new(utf8Json);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,26 @@ public static void ResolveParameterJsonSchema_ReturnsExpectedValue()
JsonElement resolvedSchema;
resolvedSchema = AIJsonUtilities.ResolveParameterJsonSchema(param, metadata, options);
Assert.True(JsonElement.DeepEquals(generatedSchema, resolvedSchema));
}

options = new(options) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
resolvedSchema = AIJsonUtilities.ResolveParameterJsonSchema(param, metadata, options);
Assert.False(JsonElement.DeepEquals(generatedSchema, resolvedSchema));
[Fact]
public static void ResolveParameterJsonSchema_TreatsIntegralTypesAsInteger_EvenWithAllowReadingFromString()
{
JsonElement expected = JsonDocument.Parse("""
{
"type": "integer"
}
""").RootElement;

JsonSerializerOptions options = new(JsonSerializerOptions.Default) { NumberHandling = JsonNumberHandling.AllowReadingFromString };
AIFunction func = AIFunctionFactory.Create((int a, int? b, long c, short d) => { }, serializerOptions: options);

AIFunctionMetadata metadata = func.Metadata;
foreach (var param in metadata.Parameters)
{
JsonElement actualSchema = Assert.IsType<JsonElement>(param.Schema);
Assert.True(JsonElement.DeepEquals(expected, actualSchema));
}
}

[Description("The type")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
Expand Down Expand Up @@ -569,7 +570,7 @@ public virtual async Task CompleteAsync_StructuredOutput()

var response = await _chatClient.CompleteAsync<Person>("""
Who is described in the following sentence?
Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
Jimbo Smith is a 35-year-old programmer from Cardiff, Wales.
""");

Assert.Equal("Jimbo Smith", response.Result.FullName);
Expand All @@ -578,6 +579,86 @@ Who is described in the following sentence?
Assert.Equal(JobType.Programmer, response.Result.Job);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputArray()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<Person[]>("""
Who are described in the following sentence?
Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
Josh Simpson is a 25-year-old software developer from Newport, Wales.
""");

Assert.Equal(2, response.Result.Length);
Assert.Contains(response.Result, x => x.FullName == "Jimbo Smith");
Assert.Contains(response.Result, x => x.FullName == "Josh Simpson");
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputInteger()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<int>("""
There were 14 abstractions for AI programming, which was too many.
To fix this we added another one. How many are there now?
""");

Assert.Equal(15, response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputString()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<string>("""
The software developer, Jimbo Smith, is a 35-year-old from Cardiff, Wales.
What's his full name?
""");

Assert.Equal("Jimbo Smith", response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputBool_True()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<bool>("""
Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
Is there at least one software developer from Cardiff?
""");

Assert.True(response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputBool_False()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<bool>("""
Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
Can we be sure that he is a medical doctor?
""");

Assert.False(response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputEnum()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<Architecture>("""
I'm using a Macbook Pro with an M2 chip. What architecture am I using?
""");

Assert.Equal(Architecture.Arm64, response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutput_WithFunctions()
{
Expand Down
Loading

0 comments on commit 46d5e57

Please sign in to comment.