Skip to content

Commit

Permalink
Fix streaming function calling (#5718)
Browse files Browse the repository at this point in the history
* Fix streaming function calling

* Rename test
  • Loading branch information
SteveSandersonMS authored Dec 3, 2024
1 parent e8efa1f commit 6734c8f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
// If there were any, remove them from the update. We do this before yielding the update so
// that we're not modifying an instance already provided back to the caller.
int addedFccs = functionCallContents.Count - preFccCount;
if (addedFccs > preFccCount)
if (addedFccs > 0)
{
update.Contents = addedFccs == update.Contents.Count ?
[] : update.Contents.Where(c => c is not FunctionCallContent).ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,58 @@ async Task InvokeAsync(Func<Task> work)
}
}

[Fact]
public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls()
{
var options = new ChatOptions
{
Tools = [AIFunctionFactory.Create((string text) => $"Result for {text}", "Func1")]
};

var messages = new List<ChatMessage>
{
new(ChatRole.User, "Hello"),
};

using var innerClient = new TestChatClient
{
CompleteStreamingAsyncCallback = (chatContents, chatOptions, cancellationToken) =>
{
// If the conversation is just starting, issue two consecutive updates with function calls
// Otherwise just end the conversation
return chatContents.Last().Text == "Hello"
? YieldAsync(
new StreamingChatCompletionUpdate { Contents = [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["text"] = "Input 1" })] },
new StreamingChatCompletionUpdate { Contents = [new FunctionCallContent("callId2", "Func1", new Dictionary<string, object?> { ["text"] = "Input 2" })] })
: YieldAsync(
new StreamingChatCompletionUpdate { Contents = [new TextContent("OK bye")] });
}
};

using var client = new FunctionInvokingChatClient(innerClient);

var updates = new List<StreamingChatCompletionUpdate>();
await foreach (var update in client.CompleteStreamingAsync(messages, options, CancellationToken.None))
{
updates.Add(update);
}

// Message history should now include the FCCs and FRCs
Assert.Collection(messages,
m => Assert.Equal("Hello", Assert.IsType<TextContent>(Assert.Single(m.Contents)).Text),
m => Assert.Collection(m.Contents,
c => Assert.Equal("Input 1", Assert.IsType<FunctionCallContent>(c).Arguments!["text"]),
c => Assert.Equal("Input 2", Assert.IsType<FunctionCallContent>(c).Arguments!["text"])),
m => Assert.Collection(m.Contents,
c => Assert.Equal("Result for Input 1", Assert.IsType<FunctionResultContent>(c).Result?.ToString()),
c => Assert.Equal("Result for Input 2", Assert.IsType<FunctionResultContent>(c).Result?.ToString())));

// The returned updates should *not* include the FCCs and FRCs
var allUpdateContents = updates.SelectMany(updates => updates.Contents).ToList();
var singleUpdateContent = Assert.IsType<TextContent>(Assert.Single(allUpdateContents));
Assert.Equal("OK bye", singleUpdateContent.Text);
}

private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
ChatOptions options,
List<ChatMessage> plan,
Expand Down

0 comments on commit 6734c8f

Please sign in to comment.