# .Net: Add CachedContent Property to GeminiPromptExecutionSettings (#10268)

### Motivation and Context

**Why is this change required?**  
This change introduces a new `CachedContent` property on the
`GeminiPromptExecutionSettings` class. The addition enables context
caching, a feature that optimizes the handling of repeated static
content in requests with high input token counts.

**What problem does it solve?**  
Repeatedly processing substantial, static context data (e.g., lengthy
documents, audio files, or video files) in requests can be
resource-intensive and costly. Context caching addresses this by
allowing the reuse of shared context across multiple requests, improving
performance and cost-efficiency.

**What scenario does it contribute to?**  
Context caching is particularly well-suited for scenarios where a
substantial initial context is repeatedly referenced by shorter
requests. Use cases include:
- **Chatbots with extensive system instructions**: Allows reuse of
complex system configurations across multiple user interactions.
- **Repetitive analysis of lengthy video files**: Enables efficient
analysis by caching video-related metadata or transcriptions.
- **Recurring queries against large document sets**: Improves efficiency
for workflows requiring repeated access to large document collections.
- **Frequent code repository analysis or bug fixing**: Facilitates the
reuse of large codebases as cached context for debugging or analysis
tasks.

*No open issues are linked to this change.*

---

### Description

This PR adds a `CachedContent` property to the
`GeminiPromptExecutionSettings` class. The property allows users to
reference cached context items such as text blocks, audio files, or
video files in prompt requests.

**Key features:**  
- The minimum size of a context cache is 32,768 tokens.  
- Cached content is retained for a default duration of 60 minutes, with
the option to configure expiration.
- Cached content is billed efficiently: the initial creation call is
charged at the standard rate, while subsequent references to the cache
are billed at a reduced rate.

By enabling context caching, users can optimize cost and resource usage
for workflows that repeatedly reference substantial context data, as the
sketch below illustrates.
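
For illustration, here is a minimal usage sketch. It is a hedged example, not code from this PR: the project, location, cache, and model names are hypothetical placeholders, and the cache itself must already exist (context caches are created separately through the Vertex AI `cachedContents` API).

```csharp
using System.Threading.Tasks;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;

string accessToken = "<your-access-token>"; // placeholder

// Reference an existing context cache by its full resource name:
// projects/{project}/locations/{location}/cachedContents/{cachedContent}
var executionSettings = new GeminiPromptExecutionSettings
{
    CachedContent = "projects/my-project/locations/us-central1/cachedContents/my-cache" // hypothetical
};

var service = new VertexAIGeminiChatCompletionService(
    "gemini-1.5-pro",                         // model id (placeholder)
    () => new ValueTask<string>(accessToken), // bearer token provider
    "us-central1",                            // location
    "my-project");                            // project id

// Subsequent requests reuse the cached context instead of resending it.
var reply = await service.GetChatMessageContentAsync(
    "Summarize the cached document.", executionSettings);
```

Note that the connector must target the `v1beta1` endpoint for this field to take effect (see the Notes section below).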

---

### Contribution Checklist

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

---

### Demo

<img width="1465" alt="Screenshot 2025-01-23 at 11 46 17 PM"
src="https://github.com/user-attachments/assets/9d070c69-a482-4a40-a912-2b17ddf4c7ab"
/>
<img width="841" alt="Screenshot 2025-01-23 at 11 50 29 PM"
src="https://github.com/user-attachments/assets/aac961ec-edf8-4458-a815-57e46f2a5789"
/>


---

### Notes

1. `CachedContent` is only available via the `v1beta1` endpoint (see the sketch after these notes)
2. [Overview of context caching in
Gemini](https://cloud.google.com/vertex-ai/generative-ai/docs/context-cache/context-cache-overview?hl=en)
3. [cachedContent field in the
body](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/projects.locations.endpoints/generateContent)
for the `v1beta1` endpoint
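
For reference, a sketch of opting into the beta surface when constructing the service. This is an assumption-laden example: the `apiVersion` parameter name is inferred from the connector's constructor shape rather than confirmed by this PR, and the model, project, and location values are placeholders. `VertexAIVersion.V1_Beta` maps to the `v1beta1` sub-link added to `ClientBase` below.

```csharp
using System.Threading.Tasks;
using Microsoft.SemanticKernel.Connectors.Google;

string accessToken = "<your-access-token>"; // placeholder

// Assumes the constructor exposes an apiVersion parameter (defaulting to V1).
var service = new VertexAIGeminiChatCompletionService(
    "gemini-1.5-pro",
    () => new ValueTask<string>(accessToken),
    "us-central1",
    "my-project",
    apiVersion: VertexAIVersion.V1_Beta);
```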

---------

Co-authored-by: Roger Barreto <[email protected]>
davidpene and RogerBarreto authored Jan 28, 2025
1 parent 471d9a8 commit 6fbbb44
Showing 11 changed files with 294 additions and 10 deletions.
@@ -470,6 +470,44 @@ public void AddChatMessageToRequest()
c => Equals(message.Role, c.Role));
}

[Fact]
public void CachedContentFromPromptReturnsAsExpected()
{
// Arrange
var prompt = "prompt-example";
var executionSettings = new GeminiPromptExecutionSettings
{
CachedContent = "xyz/abc"
};

// Act
var request = GeminiRequest.FromPromptAndExecutionSettings(prompt, executionSettings);

// Assert
Assert.NotNull(request.Configuration);
Assert.Equal(executionSettings.CachedContent, request.CachedContent);
}

[Fact]
public void CachedContentFromChatHistoryReturnsAsExpected()
{
// Arrange
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("user-message");
chatHistory.AddAssistantMessage("assist-message");
chatHistory.AddUserMessage("user-message2");
var executionSettings = new GeminiPromptExecutionSettings
{
CachedContent = "xyz/abc"
};

// Act
var request = GeminiRequest.FromChatHistoryAndExecutionSettings(chatHistory, executionSettings);

// Assert
Assert.Equal(executionSettings.CachedContent, request.CachedContent);
}

private sealed class DummyContent(object? innerContent, string? modelId = null, IReadOnlyDictionary<string, object?>? metadata = null) :
KernelContent(innerContent, modelId, metadata);
}
@@ -1,13 +1,34 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.IO;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;
using Microsoft.SemanticKernel.Services;
using Xunit;

namespace SemanticKernel.Connectors.Google.UnitTests.Services;

public sealed class GoogleAIGeminiChatCompletionServiceTests
public sealed class GoogleAIGeminiChatCompletionServiceTests : IDisposable
{
private readonly HttpMessageHandlerStub _messageHandlerStub;
private readonly HttpClient _httpClient;

public GoogleAIGeminiChatCompletionServiceTests()
{
this._messageHandlerStub = new()
{
ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
{
Content = new StringContent(File.ReadAllText("./TestData/completion_one_response.json"))
}
};
this._httpClient = new HttpClient(this._messageHandlerStub, false);
}

[Fact]
public void AttributesShouldContainModelId()
{
@@ -18,4 +39,39 @@ public void AttributesShouldContainModelId()
// Assert
Assert.Equal(model, service.Attributes[AIServiceExtensions.ModelIdKey]);
}

[Theory]
[InlineData(null)]
[InlineData("content")]
[InlineData("")]
public async Task RequestCachedContentWorksCorrectlyAsync(string? cachedContent)
{
// Arrange
string model = "fake-model";
var sut = new GoogleAIGeminiChatCompletionService(model, "key", httpClient: this._httpClient);

// Act
var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { CachedContent = cachedContent });

// Assert
Assert.NotNull(result);
Assert.NotNull(this._messageHandlerStub.RequestContent);

var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent);
if (cachedContent is not null)
{
Assert.Contains($"\"cachedContent\":\"{cachedContent}\"", requestBody);
}
else
{
// When no cached content is provided, it should not be included in the request body
Assert.DoesNotContain("cachedContent", requestBody);
}
}

public void Dispose()
{
this._httpClient.Dispose();
this._messageHandlerStub.Dispose();
}
}
@@ -1,14 +1,34 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.IO;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Google;
using Microsoft.SemanticKernel.Services;
using Xunit;

namespace SemanticKernel.Connectors.Google.UnitTests.Services;

public sealed class VertexAIGeminiChatCompletionServiceTests
public sealed class VertexAIGeminiChatCompletionServiceTests : IDisposable
{
private readonly HttpMessageHandlerStub _messageHandlerStub;
private readonly HttpClient _httpClient;

public VertexAIGeminiChatCompletionServiceTests()
{
this._messageHandlerStub = new()
{
ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
{
Content = new StringContent(File.ReadAllText("./TestData/completion_one_response.json"))
}
};
this._httpClient = new HttpClient(this._messageHandlerStub, false);
}

[Fact]
public void AttributesShouldContainModelIdBearerAsString()
{
@@ -30,4 +50,39 @@ public void AttributesShouldContainModelIdBearerAsFunc()
// Assert
Assert.Equal(model, service.Attributes[AIServiceExtensions.ModelIdKey]);
}

[Theory]
[InlineData(null)]
[InlineData("content")]
[InlineData("")]
public async Task RequestCachedContentWorksCorrectlyAsync(string? cachedContent)
{
// Arrange
string model = "fake-model";
var sut = new VertexAIGeminiChatCompletionService(model, () => new ValueTask<string>("key"), "location", "project", httpClient: this._httpClient);

// Act
var result = await sut.GetChatMessageContentAsync("my prompt", new GeminiPromptExecutionSettings { CachedContent = cachedContent });

// Assert
Assert.NotNull(result);
Assert.NotNull(this._messageHandlerStub.RequestContent);

var requestBody = UTF8Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent);
if (cachedContent is not null)
{
Assert.Contains($"\"cachedContent\":\"{cachedContent}\"", requestBody);
}
else
{
// When no cached content is provided, it should not be included in the request body
Assert.DoesNotContain("cachedContent", requestBody);
}
}

public void Dispose()
{
this._httpClient.Dispose();
this._messageHandlerStub.Dispose();
}
}
1 change: 1 addition & 0 deletions dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs
@@ -112,6 +112,7 @@ protected static string GetApiVersionSubLink(VertexAIVersion apiVersion)
=> apiVersion switch
{
VertexAIVersion.V1 => "v1",
VertexAIVersion.V1_Beta => "v1beta1",
_ => throw new NotSupportedException($"Vertex API version {apiVersion} is not supported.")
};
}
@@ -42,6 +42,10 @@ internal sealed class GeminiRequest
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public GeminiContent? SystemInstruction { get; set; }

[JsonPropertyName("cachedContent")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? CachedContent { get; set; }

public void AddFunction(GeminiFunction function)
{
// NOTE: Currently Gemini only supports one tool i.e. function calling.
@@ -67,6 +71,7 @@ public static GeminiRequest FromPromptAndExecutionSettings(
GeminiRequest obj = CreateGeminiRequest(prompt);
AddSafetySettings(executionSettings, obj);
AddConfiguration(executionSettings, obj);
AddAdditionalBodyFields(executionSettings, obj);
return obj;
}

@@ -83,6 +88,7 @@ public static GeminiRequest FromChatHistoryAndExecutionSettings(
GeminiRequest obj = CreateGeminiRequest(chatHistory);
AddSafetySettings(executionSettings, obj);
AddConfiguration(executionSettings, obj);
AddAdditionalBodyFields(executionSettings, obj);
return obj;
}

@@ -318,6 +324,11 @@ private static void AddSafetySettings(GeminiPromptExecutionSettings executionSet
=> new GeminiSafetySetting(s.Category, s.Threshold)).ToList();
}

private static void AddAdditionalBodyFields(GeminiPromptExecutionSettings executionSettings, GeminiRequest request)
{
request.CachedContent = executionSettings.CachedContent;
}

internal sealed class ConfigurationElement
{
[JsonPropertyName("temperature")]
@@ -27,6 +27,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings
private bool? _audioTimestamp;
private string? _responseMimeType;
private object? _responseSchema;
private string? _cachedContent;
private IList<GeminiSafetySetting>? _safetySettings;
private GeminiToolCallBehavior? _toolCallBehavior;

@@ -41,6 +42,7 @@ public sealed class GeminiPromptExecutionSettings : PromptExecutionSettings
/// Range is 0.0 to 1.0.
/// </summary>
[JsonPropertyName("temperature")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? Temperature
{
get => this._temperature;
@@ -56,6 +58,7 @@ public double? Temperature
/// The higher the TopP, the more diverse the completion.
/// </summary>
[JsonPropertyName("top_p")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public double? TopP
{
get => this._topP;
@@ -71,6 +74,7 @@ public double? TopP
/// The TopK property represents the maximum value of a collection or dataset.
/// </summary>
[JsonPropertyName("top_k")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? TopK
{
get => this._topK;
@@ -85,6 +89,7 @@ public int? TopK
/// The maximum number of tokens to generate in the completion.
/// </summary>
[JsonPropertyName("max_tokens")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? MaxTokens
{
get => this._maxTokens;
@@ -99,6 +104,7 @@ public int? MaxTokens
/// The count of candidates. Possible values range from 1 to 8.
/// </summary>
[JsonPropertyName("candidate_count")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public int? CandidateCount
{
get => this._candidateCount;
@@ -114,6 +120,7 @@ public int? CandidateCount
/// Maximum number of stop sequences is 5.
/// </summary>
[JsonPropertyName("stop_sequences")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<string>? StopSequences
{
get => this._stopSequences;
@@ -128,6 +135,7 @@ public IList<string>? StopSequences
/// Represents a list of safety settings.
/// </summary>
[JsonPropertyName("safety_settings")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public IList<GeminiSafetySetting>? SafetySettings
{
get => this._safetySettings;
@@ -180,6 +188,7 @@ public GeminiToolCallBehavior? ToolCallBehavior
/// if enabled, audio timestamp will be included in the request to the model.
/// </summary>
[JsonPropertyName("audio_timestamp")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public bool? AudioTimestamp
{
get => this._audioTimestamp;
@@ -198,6 +207,7 @@ public bool? AudioTimestamp
/// 3. text/x.enum: For classification tasks, output an enum value as defined in the response schema.
/// </summary>
[JsonPropertyName("response_mimetype")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ResponseMimeType
{
get => this._responseMimeType;
@@ -234,6 +244,23 @@ public object? ResponseSchema
}
}

/// <summary>
/// Optional. The name of the cached content used as context to serve the prediction.
/// Note: only used in explicit caching, where users can have control over caching (e.g. what content to cache) and enjoy guaranteed cost savings.
/// Format: projects/{project}/locations/{location}/cachedContents/{cachedContent}
/// </summary>
[JsonPropertyName("cached_content")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? CachedContent
{
get => this._cachedContent;
set
{
this.ThrowIfFrozen();
this._cachedContent = value;
}
}

/// <inheritdoc />
public override void Freeze()
{
7 changes: 6 additions & 1 deletion dotnet/src/Connectors/Connectors.Google/VertexAIVersion.cs
@@ -12,5 +12,10 @@ public enum VertexAIVersion
/// <summary>
/// Represents the V1 version of the Vertex AI API.
/// </summary>
V1
V1,

/// <summary>
/// Represents the V1-beta version of the Vertex AI API.
/// </summary>
V1_Beta
}