
.Net: Add CachedContent Property to GeminiPromptExecutionSettings #10268

Merged · 7 commits · Jan 28, 2025
Add tests
davidpene committed Jan 28, 2025
commit ba7164dda300f17417a989cbd1120170bfbe22e8
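
For orientation before the diff: a minimal sketch of how the new setting is meant to be used, mirroring the integration test added in this commit. The cachedContents resource name and the chatService variable are illustrative, not taken from the PR.

// Sketch only: pass an existing Vertex AI cachedContents resource name through
// the new CachedContent property; the connector attaches it to the request.
var settings = new GeminiPromptExecutionSettings
{
    CachedContent = "projects/my-project/locations/us-central1/cachedContents/1234567890" // illustrative
};

var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("Answer using the cached context.");

// chatService is assumed to be an IChatCompletionService from the Google connector.
var reply = await chatService.GetChatMessageContentAsync(chatHistory, settings);
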
@@ -470,6 +470,44 @@ public void AddChatMessageToRequest()
            c => Equals(message.Role, c.Role));
    }

    [Fact]
    public void CachedContentFromPromptReturnsAsExpected()
    {
        // Arrange
        var prompt = "prompt-example";
        var executionSettings = new GeminiPromptExecutionSettings
        {
            CachedContent = "xyz/abc"
        };

        // Act
        var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);

        // Assert
        Assert.NotNull(request.Configuration);
        Assert.Equal(executionSettings.CachedContent, request.CachedContent);
    }

    [Fact]
    public void CachedContentFromChatHistoryReturnsAsExpected()
    {
        // Arrange
        ChatHistory chatHistory = [];
        chatHistory.AddUserMessage("user-message");
        chatHistory.AddAssistantMessage("assist-message");
        chatHistory.AddUserMessage("user-message2");
        var executionSettings = new GeminiPromptExecutionSettings
        {
            CachedContent = "xyz/abc"
        };

        // Act
        var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

        // Assert
        Assert.Equal(executionSettings.CachedContent, request.CachedContent);
    }

    private sealed class DummyContent(object? innerContent, string? modelId = null, IReadOnlyDictionary<string, object?>? metadata = null) :
        KernelContent(innerContent, modelId, metadata);
}
@@ -42,6 +42,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings
/// Range is 0.0 to 1.0.
/// </summary>
[JsonPropertyName("temperature")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? Temperature
{
get => this._temperature;
@@ -57,6 +58,7 @@ public double? Temperature
/// The higher the TopP, the more diverse the completion.
/// </summary>
[JsonPropertyName("top_p")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? TopP
{
get => this._topP;
@@ -72,6 +74,7 @@ public double? TopP
/// The TopK property represents the maximum value of a collection or dataset.
/// </summary>
[JsonPropertyName("top_k")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? TopK
{
get => this._topK;
@@ -86,6 +89,7 @@ public int? TopK
/// The maximum number of tokens to generate in the completion.
/// </summary>
[JsonPropertyName("max_tokens")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? MaxTokens
{
get => this._maxTokens;
@@ -100,6 +104,7 @@ public int? MaxTokens
/// The count of candidates. Possible values range from 1 to 8.
/// </summary>
[JsonPropertyName("candidate_count")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? CandidateCount
{
get => this._candidateCount;
@@ -115,6 +120,7 @@ public int? CandidateCount
/// Maximum number of stop sequences is 5.
/// </summary>
[JsonPropertyName("stop_sequences")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<string>? StopSequences
{
get => this._stopSequences;
@@ -129,6 +135,7 @@ public IList<string>? StopSequences
/// Represents a list of safety settings.
/// </summary>
[JsonPropertyName("safety_settings")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<GeminiSafetySetting>? SafetySettings
{
get => this._safetySettings;
@@ -181,6 +188,7 @@ public GeminiToolCallBehavior? ToolCallBehavior
/// if enabled, audio timestamp will be included in the request to the model.
/// </summary>
[JsonPropertyName("audio_timestamp")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? AudioTimestamp
{
get => this._audioTimestamp;
@@ -199,6 +207,7 @@ public bool? AudioTimestamp
/// 3. text/x.enum: For classification tasks, output an enum value as defined in the response schema.
/// </summary>
[JsonPropertyName("response_mimetype")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ResponseMimeType
{
get => this._responseMimeType;
@@ -241,6 +250,7 @@ public object? ResponseSchema
/// Format: projects/{project}/locations/{location}/cachedContents/{cachedContent}
/// </summary>
[JsonPropertyName("cached_content")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? CachedContent
{
get => this._cachedContent;
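
Alongside the new property, this file adds [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] to the nullable settings, so values that were never set are omitted from the serialized settings instead of being written as null. A small sketch of the effect (not part of the commit; the resource name is illustrative):

var settings = new GeminiPromptExecutionSettings
{
    CachedContent = "projects/p/locations/l/cachedContents/c" // illustrative
};

// With WhenWritingNull, the output contains "cached_content" but no
// "temperature": null, "top_p": null, ... entries for the unset properties.
string json = JsonSerializer.Serialize(settings);
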
@@ -3,11 +3,15 @@
using System;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Json;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;
using Newtonsoft.Json.Linq;
using xRetry;
using Xunit;
using Xunit.Abstractions;
@@ -135,6 +139,57 @@ public async Task ChatGenerationWithSystemMessagesAsync(ServiceType serviceType)
Assert.Contains("Roger", response.Content, StringComparison.OrdinalIgnoreCase);
}

[RetryTheory]
[InlineData(ServiceType.VertexAI)]
public async Task ChatGenerationWithCachedContentAsync(ServiceType serviceType)
{
// Arrange
var chatHistory = new ChatHistory();
chatHistory.AddUserMessage("Finish this sentence: He knew the sea’s...");

// Setup initial cached content
var cachedContentJson = File.ReadAllText(Path.Combine("Resources", "gemini_cached_content.json"));
var cachedContentName = string.Empty;

using (var httpClient = new HttpClient()
{
DefaultRequestHeaders = { Authorization = new("Bearer", this.VertexAIGetBearerKey()) }
})
{
using (var content = new StringContent(cachedContentJson, Encoding.UTF8, "application/json"))
{
using (var httpResponse = await httpClient.PostAsync(
new Uri($"https://{this.VertexAIGetLocation()}-aiplatform.googleapis.com/v1beta1/projects/{this.VertexAIGetProjectId()}/locations/{this.VertexAIGetLocation()}/cachedContents"),
content))
{
httpResponse.EnsureSuccessStatusCode();

var responseString = await httpResponse.Content.ReadAsStringAsync();
var responseJson = JObject.Parse(responseString);

cachedContentName = responseJson?["name"]?.ToString();

Assert.NotNull(cachedContentName);
}
}
}

var sut = this.GetChatService(serviceType, isBeta: true);

// Act
var response = await sut.GetChatMessageContentAsync(
chatHistory,
new GeminiPromptExecutionSettings
{
CachedContent = cachedContentName
});

// Assert
Assert.NotNull(response.Content);
this.Output.WriteLine(response.Content);
Assert.Contains("capriciousness", response.Content, StringComparison.OrdinalIgnoreCase);
}

[RetryTheory]
[InlineData(ServiceType.GoogleAI, Skip = "This test is for manual verification.")]
[InlineData(ServiceType.VertexAI, Skip = "This test is for manual verification.")]
16 changes: 10 additions & 6 deletions dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;
@@ -20,16 +22,18 @@ public abstract class TestsBase(ITestOutputHelper output)

    protected ITestOutputHelper Output { get; } = output;

-   protected IChatCompletionService GetChatService(ServiceType serviceType) => serviceType switch
+   protected IChatCompletionService GetChatService(ServiceType serviceType, bool isBeta = false) => serviceType switch
    {
        ServiceType.GoogleAI => new GoogleAIGeminiChatCompletionService(
            this.GoogleAIGetGeminiModel(),
-           this.GoogleAIGetApiKey()),
+           this.GoogleAIGetApiKey(),
+           isBeta ? GoogleAIVersion.V1_Beta : GoogleAIVersion.V1),
        ServiceType.VertexAI => new VertexAIGeminiChatCompletionService(
            modelId: this.VertexAIGetGeminiModel(),
            bearerKey: this.VertexAIGetBearerKey(),
            location: this.VertexAIGetLocation(),
-           projectId: this.VertexAIGetProjectId()),
+           projectId: this.VertexAIGetProjectId(),
+           isBeta ? VertexAIVersion.V1_Beta : VertexAIVersion.V1),
        _ => throw new ArgumentOutOfRangeException(nameof(serviceType), serviceType, null)
    };

@@ -72,7 +76,7 @@ public enum ServiceType
    private string VertexAIGetGeminiModel() => this._configuration.GetSection("VertexAI:Gemini:ModelId").Get<string>()!;
    private string VertexAIGetGeminiVisionModel() => this._configuration.GetSection("VertexAI:Gemini:VisionModelId").Get<string>()!;
    private string VertexAIGetEmbeddingModel() => this._configuration.GetSection("VertexAI:EmbeddingModelId").Get<string>()!;
-   private string VertexAIGetBearerKey() => this._configuration.GetSection("VertexAI:BearerKey").Get<string>()!;
-   private string VertexAIGetLocation() => this._configuration.GetSection("VertexAI:Location").Get<string>()!;
-   private string VertexAIGetProjectId() => this._configuration.GetSection("VertexAI:ProjectId").Get<string>()!;
+   internal string VertexAIGetBearerKey() => this._configuration.GetSection("VertexAI:BearerKey").Get<string>()!;
+   internal string VertexAIGetLocation() => this._configuration.GetSection("VertexAI:Location").Get<string>()!;
+   internal string VertexAIGetProjectId() => this._configuration.GetSection("VertexAI:ProjectId").Get<string>()!;
}
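
The new isBeta flag routes the connector to the v1beta1 API version, presumably because the cached-content flow targets the same v1beta1 endpoints the integration test calls directly. A short sketch of the two call shapes (illustrative only, based on the helper above):

// Default (v1) behaviour is unchanged:
var chat = this.GetChatService(ServiceType.GoogleAI);

// The cached-content test opts into the v1beta1 surface:
var betaChat = this.GetChatService(ServiceType.VertexAI, isBeta: true);
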
6 changes: 6 additions & 0 deletions dotnet/src/IntegrationTests/IntegrationTests.csproj
@@ -195,4 +195,10 @@
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>

<ItemGroup>
<EmbeddedResource Include="Resources/gemini_cached_content.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</EmbeddedResource>
</ItemGroup>
</Project>
