Skip to content

Commit

Permalink
Use low level JSON API to manipulate the wrapper node
Browse files Browse the repository at this point in the history
Rather than relying on the type system, since a source-generated serializer options would not be able to deal with it.
  • Loading branch information
kzu committed Oct 16, 2024
1 parent 779b573 commit 4db07a5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(

serializerOptions.MakeReadOnly();

var exporterOptions = new JsonSchemaExporterOptions
var schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(T), new()
{
TreatNullObliviousAsNonNullable = true,
TransformSchemaNode = static (context, node) =>
Expand All @@ -132,9 +132,8 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(

return node;
},
};
});

var schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(T), exporterOptions);
var isObject = schemaNode.TryGetPropertyValue("type", out var schemaType) &&
schemaType?.GetValueKind() == JsonValueKind.String &&
schemaType.GetValue<string>() is { } type &&
Expand All @@ -145,7 +144,11 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
// We wrap regardless of native structured output, since it also applies to Azure Inference
if (!isObject)
{
schemaNode = (JsonObject)serializerOptions.GetJsonSchemaAsNode(typeof(Payload<T>), exporterOptions);
schemaNode = new JsonObject
{
{ "type", "object" },
{ "properties", new JsonObject { { "data", schemaNode } } },
};
wrapped = true;
}

Expand Down Expand Up @@ -205,6 +208,4 @@ public static async Task<ChatCompletion<T>> CompleteAsync<T>(
}
}
}

private sealed record Payload<TValue>(TValue Data);
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
#pragma warning restore CA1031 // Do not catch general exception types
}

private static TValue? DeserializeFirstTopLevelObject<TValue>(string json, JsonTypeInfo<TValue> typeInfo)
private static T? DeserializeFirstTopLevelObject(string json, JsonTypeInfo<T> typeInfo)
{
// We need to deserialize only the first top-level object as a workaround for a common LLM backend
// issue. GPT 3.5 Turbo commonly returns multiple top-level objects after doing a function call.
Expand Down Expand Up @@ -132,10 +132,10 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)

if (wrapped)
{
var result = DeserializeFirstTopLevelObject(json!, (JsonTypeInfo<Payload<T>>)_serializerOptions.GetTypeInfo(typeof(Payload<T>)));
if (result != null)
var doc = JsonDocument.Parse(json!);
if (doc.RootElement.TryGetProperty("data", out var data))
{
deserialized = result.Data;
deserialized = DeserializeFirstTopLevelObject(data.GetRawText(), (JsonTypeInfo<T>)_serializerOptions.GetTypeInfo(typeof(T)));
}
}
else
Expand All @@ -155,8 +155,6 @@ public bool TryGetResult([NotNullWhen(true)] out T? result)
return deserialized;
}

private sealed record Payload<TValue>(TValue Data);

private enum FailureReason
{
ResultDidNotContainJson,
Expand Down

0 comments on commit 4db07a5

Please sign in to comment.