Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for arrays, enums and primitive types #5522

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
IList<ChatMessage> chatMessages,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regardless of whether we take this change, I'd be in favor of removing this constraint since it isn't a good predictor of the JSON shape of the type. We can instead just fail at runtime depending on the value of the corresponding JsonTypeInfo.Kind property.

CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, chatMessages, JsonDefaults.Options, options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends a user chat text message to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -61,8 +60,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
string chatMessage,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends a user chat text message to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -84,8 +82,7 @@ public static Task<ChatCompletion<T>> CompleteAsync<T>(
JsonSerializerOptions serializerOptions,
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class =>
CancellationToken cancellationToken = default) =>
CompleteAsync<T>(chatClient, [new ChatMessage(ChatRole.User, chatMessage)], serializerOptions, options, useNativeJsonSchema, cancellationToken);

/// <summary>Sends chat messages to the model, requesting a response matching the type <typeparamref name="T"/>.</summary>
Expand All @@ -112,15 +109,14 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
ChatOptions? options = null,
bool? useNativeJsonSchema = null,
CancellationToken cancellationToken = default)
where T : class
{
_ = Throw.IfNull(chatClient);
_ = Throw.IfNull(chatMessages);
_ = Throw.IfNull(serializerOptions);

serializerOptions.MakeReadOnly();

var schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(T), new()
var exporterOptions = new JsonSchemaExporterOptions
{
TreatNullObliviousAsNonNullable = true,
TransformSchemaNode = static (context, node) =>
Expand All @@ -136,7 +132,23 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(

return node;
},
});
};

var schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(T), exporterOptions);
var isObject = schemaNode.TryGetPropertyValue("type", out var schemaType) &&
schemaType?.GetValueKind() == JsonValueKind.String &&
schemaType.GetValue<string>() is { } type &&
type.Equals("object", System.StringComparison.Ordinal);

var wrapped = false;

// We wrap regardless of native structured output, since it also applies to Azure Inference
if (!isObject)
{
schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(Payload<T>), exporterOptions);
wrapped = true;
}

schemaNode.Insert(0, "$schema", "https://json-schema.org/draft/2020-12/schema");
schemaNode.Add("additionalProperties", false);
var schema = JsonSerializer.Serialize(schemaNode, JsonDefaults.Options.GetTypeInfo(typeof(JsonNode)));
Expand Down Expand Up @@ -176,6 +188,13 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
try
{
var result = await chatClient.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
if (wrapped)
{
// We don't initialize the dictionary unless we need it, to avoid unnecessary allocations.
result.AdditionalProperties ??= [];
result.AdditionalProperties["$wrapped"] = true;
}

return new ChatCompletion<T>(result, serializerOptions);
}
finally
Expand All @@ -186,4 +205,6 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
}
}
}

private sealed record Payload<TValue>(TValue Data);
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
#pragma warning restore CA1031 // Do not catch general exception types
}

private static T? DeserializeFirstTopLevelObject(string json, JsonTypeInfo<T> typeInfo)
private static TValue? DeserializeFirstTopLevelObject<TValue>(string json, JsonTypeInfo<TValue> typeInfo)
{
// We need to deserialize only the first top-level object as a workaround for a common LLM backend
// issue. GPT 3.5 Turbo commonly returns multiple top-level objects after doing a function call.
Expand Down Expand Up @@ -125,8 +125,24 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
return default;
}

T? deserialized = default;
var wrapped = AdditionalProperties?.TryGetValue("$wrapped", out var isWrappedValue) == true && isWrappedValue is bool isWrappedBool && isWrappedBool;

// If there's an exception here, we want it to propagate, since the Result property is meant to throw directly
var deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));

if (wrapped)
{
var result = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<Payload<T>>)_serializerOptions.GetTypeInfo(typeof(Payload<T>)));
kzu marked this conversation as resolved.
Show resolved Hide resolved
if (result != null)
{
deserialized = result.Data;
}
}
else
{
deserialized = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));
}

if (deserialized is null)
{
failureReason = FailureReason.DeserializationProducedNull;
Expand All @@ -139,6 +155,8 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
return deserialized;
}

private sealed record Payload<TValue>(TValue Data);

private enum FailureReason
{
ResultDidNotContainJson,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
Expand Down Expand Up @@ -579,17 +580,91 @@ public virtual async Task CompleteAsync_StructuredOutput()
{
SkipIfNotEnabled();

// For openai, we can use the native JSON schema support
var response = await _chatClient.CompleteAsync<Person>("""
Who is described in the following sentence?
Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
""");
""",
useNativeJsonSchema: _chatClient.Metadata.ProviderName == "openai");

Assert.Equal("Jimbo Smith", response.Result.FullName);
Assert.Equal(35, response.Result.AgeInYears);
Assert.Contains("Cardiff", response.Result.HomeTown);
Assert.Equal(JobType.Programmer, response.Result.Job);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputArray()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<Person[]>("""
Who are described in the following sentence?
Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
Josh Simpson is a 25-year-old software developer from Newport, Wales.
""",
useNativeJsonSchema: _chatClient.Metadata.ProviderName == "openai");

Assert.Equal(2, response.Result.Length);
Assert.Contains(response.Result, x => x.FullName == "Jimbo Smith");
Assert.Contains(response.Result, x => x.FullName == "Josh Simpson");
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputInteger()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<int>("""
As of today (october 2024), Jimbo Smith is a 35-year-old software developer from Cardiff, Wales.
Which year was he born in?
""",
useNativeJsonSchema: _chatClient.Metadata.ProviderName == "openai");

Assert.Equal(1989, response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputString()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<string>("""
The software developer, Jimbo Smith, is a 35-year-old software developer from Cardiff, Wales.
What's his full name?
""",
useNativeJsonSchema: _chatClient.Metadata.ProviderName == "openai");

Assert.Equal("Jimbo Smith", response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputBool()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<bool>("""
The software developer, Jimbo Smith, is a 35-year-old software developer from Cardiff, Wales.
Is he a medical doctor?
""",
useNativeJsonSchema: _chatClient.Metadata.ProviderName == "openai");

Assert.False(response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutputEnum()
{
SkipIfNotEnabled();

var response = await _chatClient.CompleteAsync<Architecture>("""
I'm using a Macbook Pro with an M2 chip. What architecture am I using?
""",
useNativeJsonSchema: _chatClient.Metadata.ProviderName == "openai");

Assert.Equal(Architecture.Arm64, response.Result);
}

[ConditionalFact]
public virtual async Task CompleteAsync_StructuredOutput_WithFunctions()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,34 @@ public async Task CanUseNativeStructuredOutputWithSanitizedTypeName()
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

[Fact]
public async Task CanUseNativeStructuredOutputWithArray()
{
var expectedResult = new[] { new Animal { Id = 1, FullName = "Tigger", Species = Species.Tiger } };
var payload = new { data = expectedResult };
var expectedCompletion = new ChatCompletion([new ChatMessage(ChatRole.Assistant, JsonSerializer.Serialize(payload))]);

using var client = new TestChatClient
{
CompleteAsyncCallback = (messages, options, cancellationToken) => Task.FromResult(expectedCompletion)
};

var chatHistory = new List<ChatMessage> { new(ChatRole.User, "Hello") };
var response = await client.CompleteAsync<Animal[]>(chatHistory, useNativeJsonSchema: true);

// The completion contains the deserialized result and other completion properties
Assert.Single(response.Result!);
Assert.Equal("Tigger", response.Result[0].FullName);
Assert.Equal(Species.Tiger, response.Result[0].Species);

// TryGetResult returns the same value
Assert.True(response.TryGetResult(out var tryGetResultOutput));
Assert.Same(response.Result, tryGetResultOutput);

// History remains unmutated
Assert.Equal("Hello", Assert.Single(chatHistory).Text);
}

[Fact]
public async Task CanSpecifyCustomJsonSerializationOptions()
{
Expand Down
Loading