Skip to content

Commit

Permalink
feat: llmclient test
Browse files Browse the repository at this point in the history
  • Loading branch information
swatDong committed Jan 5, 2024
1 parent 3999d88 commit 4084915
Showing 1 changed file with 93 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.Teams.AI.AI.Validators;
using Microsoft.Teams.AI.Exceptions;
using Microsoft.Teams.AI.State;
using Microsoft.Teams.AI.Tests.TestUtils;
using Moq;

namespace Microsoft.Teams.AI.Tests.AITests
Expand Down Expand Up @@ -146,6 +147,28 @@ public async Task Test_CompletePromptAsync_PromptResponse_Success()
Assert.Equal(2, ((List<ChatMessage>)memory.Values[options.HistoryVariable]).Count);
}

[Fact]
public async Task Test_CompletePromptAsync_PromptResponse_Exception()
{
    // Arrange: a completion model mock with no setup configured, so completing
    // the prompt fails; the client is expected to report this as an Error
    // response rather than letting an exception escape.
    var modelMock = new Mock<IPromptCompletionModel>();
    var template = new PromptTemplate(
        "prompt",
        new(new() { })
    );
    LLMClientOptions<object> options = new(modelMock.Object, template);
    LLMClient<object> client = new(options, null);
    TestMemory memory = new();

    // Act
    var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");

    // Assert: failure is surfaced via Status/Error on the response object.
    Assert.NotNull(response);
    Assert.Equal(PromptResponseStatus.Error, response.Status);
    Assert.NotNull(response.Error);
}

[Fact]
public async Task Test_CompletePromptAsync_PromptResponse_Repair()
{
Expand All @@ -158,9 +181,10 @@ public async Task Test_CompletePromptAsync_PromptResponse_Repair()
var validator = new TestValidator();
LLMClientOptions<object> options = new(promptCompletionModel, promptTemplate)
{
LogRepairs = true,
Validator = validator
};
LLMClient<object> client = new(options, null);
LLMClient<object> client = new(options, new TestLoggerFactory());
TestMemory memory = new();
promptCompletionModel.Results.Enqueue(new()
{
Expand Down Expand Up @@ -202,6 +226,69 @@ public async Task Test_CompletePromptAsync_PromptResponse_Repair()
Assert.Equal(2, ((List<ChatMessage>)memory.Values[options.HistoryVariable]).Count);
}

[Fact]
public async Task Test_CompletePromptAsync_PromptResponse_RepairNotSuccess()
{
    // Arrange: the first two completions are judged invalid, triggering the
    // repair loop; the third model result is an error, so the repair flow
    // ends in failure and that error is surfaced to the caller.
    var model = new TestPromptCompletionModel();
    var template = new PromptTemplate(
        "prompt",
        new(new() { })
    );
    var validator = new TestValidator();
    LLMClientOptions<object> options = new(model, template)
    {
        LogRepairs = true,
        Validator = validator
    };
    LLMClient<object> client = new(options, new TestLoggerFactory());
    TestMemory memory = new();

    // Queued model results, consumed in order: success, success, then error.
    model.Results.Enqueue(new PromptResponse
    {
        Status = PromptResponseStatus.Success,
        Message = new ChatMessage(ChatRole.Assistant)
        {
            Content = "welcome"
        }
    });
    model.Results.Enqueue(new PromptResponse
    {
        Status = PromptResponseStatus.Success,
        Message = new ChatMessage(ChatRole.Assistant)
        {
            Content = "welcome-repair"
        }
    });
    model.Results.Enqueue(new PromptResponse
    {
        Status = PromptResponseStatus.Error,
        Error = new("test")
    });

    // Queued validation verdicts, consumed in order: invalid, invalid, valid.
    validator.Results.Enqueue(new Validation
    {
        Valid = false
    });
    validator.Results.Enqueue(new Validation
    {
        Valid = false
    });
    validator.Results.Enqueue(new Validation
    {
        Valid = true
    });

    // Act
    var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");

    // Assert: the error from the final model result is what the caller sees,
    // and memory holds only the input variable (count of 1 confirms no other
    // values, e.g. history, were written on this failure path).
    Assert.NotNull(response);
    Assert.Equal(PromptResponseStatus.Error, response.Status);
    Assert.NotNull(response.Error);
    Assert.Equal("test", response.Error.Message);
    Assert.Equal(1, memory.Values.Count);
    Assert.Equal("hello", memory.Values[options.InputVariable]);
}

[Fact]
public async Task Test_CompletePromptAsync_PromptResponse_Repair_ExceedMaxRepairAttempts()
{
Expand All @@ -214,10 +301,11 @@ public async Task Test_CompletePromptAsync_PromptResponse_Repair_ExceedMaxRepair
var validator = new TestValidator();
LLMClientOptions<object> options = new(promptCompletionModel, promptTemplate)
{
LogRepairs = true,
Validator = validator,
MaxRepairAttempts = 1
};
LLMClient<object> client = new(options, null);
LLMClient<object> client = new(options, new TestLoggerFactory());
TestMemory memory = new();
promptCompletionModel.Results.Enqueue(new()
{
Expand Down Expand Up @@ -268,7 +356,7 @@ public async Task Test_CompletePromptAsync_PromptResponse_Repair_ExceedMaxRepair
Assert.Equal("hello", memory.Values[options.InputVariable]);
}

private class TestMemory : IMemory
private sealed class TestMemory : IMemory
{
public Dictionary<string, object> Values { get; set; } = new Dictionary<string, object>();

Expand All @@ -293,7 +381,7 @@ public void SetValue(string path, object value)
}
}

private class TestPromptCompletionModel : IPromptCompletionModel
private sealed class TestPromptCompletionModel : IPromptCompletionModel
{
public Queue<PromptResponse> Results { get; set; } = new Queue<PromptResponse>();

Expand All @@ -303,7 +391,7 @@ public Task<PromptResponse> CompletePromptAsync(ITurnContext turnContext, IMemor
}
}

private class TestValidator : IPromptResponseValidator
private sealed class TestValidator : IPromptResponseValidator
{

public Queue<Validation> Results { get; set; } = new Queue<Validation>();
Expand Down

0 comments on commit 4084915

Please sign in to comment.