Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net Simplify configuration by ServiceId on Multi Model Scenarios. #6416

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public async Task RunAsync()

await RunByServiceIdAsync(kernel, "AzureOpenAIChat");
await RunByModelIdAsync(kernel, TestConfiguration.OpenAI.ChatModelId);
await RunByFirstModelIdAsync(kernel, "gpt-4-1106-preview", TestConfiguration.AzureOpenAI.ChatModelId, TestConfiguration.OpenAI.ChatModelId);
await RunByFirstModelIdAsync(kernel, ["gpt-4-1106-preview", TestConfiguration.AzureOpenAI.ChatModelId, TestConfiguration.OpenAI.ChatModelId]);
await RunByFirstServiceIdAsync(kernel, ["NotFound", "AzureOpenAIChat", "OpenAIChat"]);
}

private async Task RunByServiceIdAsync(Kernel kernel, string serviceId)
Expand All @@ -37,12 +38,21 @@ private async Task RunByServiceIdAsync(Kernel kernel, string serviceId)

var prompt = "Hello AI, what can you do for me?";

KernelArguments arguments = [];
arguments.ExecutionSettings = new Dictionary<string, PromptExecutionSettings>()
{
{ serviceId, new PromptExecutionSettings() }
};
var result = await kernel.InvokePromptAsync(prompt, arguments);
var result = await kernel.InvokePromptAsync(prompt, new(new PromptExecutionSettings { ServiceId = serviceId }));

Console.WriteLine(result.GetValue<string>());
}

private async Task RunByFirstServiceIdAsync(Kernel kernel, string[] serviceIds)
{
    Console.WriteLine($"======== Service Ids: {string.Join(", ", serviceIds)} ========");

    const string prompt = "Hello AI, what can you do for me?";

    // One settings entry per candidate service id; the kernel resolves the first service it can satisfy.
    var settingsPerService = serviceIds.Select(id => new PromptExecutionSettings { ServiceId = id });
    var function = kernel.CreateFunctionFromPrompt(prompt, settingsPerService);

    var result = await kernel.InvokeAsync(function);

    Console.WriteLine(result.GetValue<string>());
}

Expand All @@ -61,20 +71,13 @@ private async Task RunByModelIdAsync(Kernel kernel, string modelId)
Console.WriteLine(result.GetValue<string>());
}

private async Task RunByFirstModelIdAsync(Kernel kernel, params string[] modelIds)
private async Task RunByFirstModelIdAsync(Kernel kernel, string[] modelIds)
{
Console.WriteLine($"======== Model Ids: {string.Join(", ", modelIds)} ========");

var prompt = "Hello AI, what can you do for me?";

var modelSettings = new Dictionary<string, PromptExecutionSettings>();
foreach (var modelId in modelIds)
{
modelSettings.Add(modelId, new PromptExecutionSettings() { ModelId = modelId });
}
var promptConfig = new PromptTemplateConfig(prompt) { Name = "HelloAI", ExecutionSettings = modelSettings };

var function = kernel.CreateFunctionFromPrompt(promptConfig);
var function = kernel.CreateFunctionFromPrompt(prompt, modelIds.Select((modelId, index) => new PromptExecutionSettings { ServiceId = $"service-{index}", ModelId = modelId }));

var result = await kernel.InvokeAsync(function);
Console.WriteLine(result.GetValue<string>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ public async Task<FunctionResult> ExecuteFlowAsync(
}

var executor = new FlowExecutor(this._kernelBuilder, this._flowStatusProvider, this._globalPluginCollection, this._config);
return await executor.ExecuteFlowAsync(flow, sessionId, input, kernelArguments ?? new KernelArguments(null)).ConfigureAwait(false);
return await executor.ExecuteFlowAsync(flow, sessionId, input, kernelArguments ?? new KernelArguments()).ConfigureAwait(false);
RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ public class PromptExecutionSettings
/// </remarks>
public static string DefaultServiceId => "default";

/// <summary>
/// Service identifier.
/// This identifies the service these settings are configured for e.g., openai, ollama, huggingface, etc.
RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
[JsonPropertyName("service_id")]
public string? ServiceId
RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
{
get => this._serviceId;

set
{
this.ThrowIfFrozen();
this._serviceId = value;
}
}

/// <summary>
/// Model identifier.
/// This identifies the AI model these settings are configured for e.g., gpt-4, gpt-3.5-turbo
Expand Down Expand Up @@ -93,6 +109,7 @@ public virtual PromptExecutionSettings Clone()
return new()
{
ModelId = this.ModelId,
ServiceId = this.ServiceId,
ExtensionData = this.ExtensionData is not null ? new Dictionary<string, object>(this.ExtensionData) : null
};
}
Expand All @@ -113,6 +130,7 @@ protected void ThrowIfFrozen()

private string? _modelId;
private IDictionary<string, object>? _extensionData;
private string? _serviceId;

#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public sealed class KernelArguments : IDictionary<string, object?>, IReadOnlyDic
{
/// <summary>Dictionary of name/values for all the arguments in the instance.</summary>
private readonly Dictionary<string, object?> _arguments;
private IReadOnlyDictionary<string, PromptExecutionSettings>? _executionSettings;

/// <summary>
/// Initializes a new instance of the <see cref="KernelArguments"/> class with the specified AI execution settings.
Expand All @@ -36,12 +37,32 @@ public KernelArguments()
/// </summary>
/// <param name="executionSettings">The prompt execution settings.</param>
public KernelArguments(PromptExecutionSettings? executionSettings)
: this(executionSettings is null ? null : [executionSettings])
{
this._arguments = new(StringComparer.OrdinalIgnoreCase);
}

/// <summary>
/// Initializes a new instance of the <see cref="KernelArguments"/> class with the specified AI execution settings.
/// </summary>
/// <param name="executionSettings">The prompt execution settings.</param>
public KernelArguments(IReadOnlyCollection<PromptExecutionSettings>? executionSettings)
{
this._arguments = new(StringComparer.OrdinalIgnoreCase);
if (executionSettings is not null)
{
this.ExecutionSettings = new Dictionary<string, PromptExecutionSettings>() { { PromptExecutionSettings.DefaultServiceId, executionSettings } };
var newExecutionSettings = new Dictionary<string, PromptExecutionSettings>(executionSettings.Count);
foreach (var settings in executionSettings)
{
var targetServiceId = settings.ServiceId ?? PromptExecutionSettings.DefaultServiceId;
if (newExecutionSettings.ContainsKey(targetServiceId))
{
throw new ArgumentException("When adding multiple execution settings, the service id needs to be provided and be unique for each.");
RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
}

newExecutionSettings[targetServiceId] = settings;
}

this.ExecutionSettings = newExecutionSettings;
}
}

Expand All @@ -65,7 +86,34 @@ public KernelArguments(IDictionary<string, object?> source, Dictionary<string, P
/// <summary>
/// Gets or sets the prompt execution settings.
RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
public IReadOnlyDictionary<string, PromptExecutionSettings>? ExecutionSettings { get; set; }
public IReadOnlyDictionary<string, PromptExecutionSettings>? ExecutionSettings
{
get => this._executionSettings;
set
{
this._executionSettings = value;

if (this._executionSettings is null ||
this._executionSettings.Count == 0)
{
return;
}

foreach (var kv in this._executionSettings)
{
// Ensures that if a service id is not specified and is not default, it is set to the current service id.
if (kv.Key != kv.Value.ServiceId)
{
if (!string.IsNullOrWhiteSpace(kv.Value.ServiceId))
{
throw new ArgumentException($"Service id '{kv.Value.ServiceId}' must match the key '{kv.Key}'.", nameof(this.ExecutionSettings));
}

kv.Value.ServiceId = kv.Key;
}
}
RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// <summary>
/// Gets the number of arguments contained in the <see cref="KernelArguments"/>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,25 @@ public Dictionary<string, PromptExecutionSettings> ExecutionSettings
{
Verify.NotNull(value);
this._executionSettings = value;

RogerBarreto marked this conversation as resolved.
Show resolved Hide resolved
if (value.Count == 0)
{
return;
}

foreach (var kv in value)
{
// Ensures that if a service id is not specified and is not default, it is set to the current service id.
if (kv.Key != kv.Value.ServiceId)
{
if (!string.IsNullOrWhiteSpace(kv.Value.ServiceId))
{
throw new ArgumentException($"Service id '{kv.Value.ServiceId}' must match the key '{kv.Key}'.", nameof(this.ExecutionSettings));
}

kv.Value.ServiceId = kv.Key;
}
}
}
}

Expand Down Expand Up @@ -224,13 +243,19 @@ public void AddExecutionSettings(PromptExecutionSettings settings, string? servi
{
Verify.NotNull(settings);

var key = serviceId ?? PromptExecutionSettings.DefaultServiceId;
var key = serviceId ?? settings.ServiceId ?? PromptExecutionSettings.DefaultServiceId;

// To avoid any reference changes to the settings object, clone it before changing service id.
var clonedSettings = settings.Clone();

// Overwrite the service id if provided in the method.
clonedSettings.ServiceId = key;
if (this.ExecutionSettings.ContainsKey(key))
{
throw new ArgumentException($"Execution settings for service id '{key}' already exists.", nameof(serviceId));
}

this.ExecutionSettings[key] = settings;
this.ExecutionSettings[key] = clonedSettings;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -107,6 +108,37 @@ public static KernelFunction CreateFromPrompt(
string? templateFormat = null,
IPromptTemplateFactory? promptTemplateFactory = null,
ILoggerFactory? loggerFactory = null) =>
KernelFunctionFromPrompt.Create(
promptTemplate,
CreateSettingsDictionary(executionSettings is null ? null : [executionSettings]),
functionName,
description,
templateFormat,
promptTemplateFactory,
loggerFactory);

/// <summary>
/// Creates a <see cref="KernelFunction"/> instance for a prompt specified via a prompt template.
/// </summary>
/// <param name="promptTemplate">Prompt template for the function.</param>
/// <param name="executionSettings">Default execution settings to use when invoking this prompt function.</param>
/// <param name="functionName">The name to use for the function. If null, it will default to a randomly generated name.</param>
/// <param name="description">The description to use for the function.</param>
/// <param name="templateFormat">The template format of <paramref name="promptTemplate"/>. This must be provided if <paramref name="promptTemplateFactory"/> is not null.</param>
/// <param name="promptTemplateFactory">
/// The <see cref="IPromptTemplateFactory"/> to use when interpreting the <paramref name="promptTemplate"/> into a <see cref="IPromptTemplate"/>.
/// If null, a default factory will be used.
/// </param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
/// <returns>The created <see cref="KernelFunction"/> for invoking the prompt.</returns>
public static KernelFunction CreateFromPrompt(
    string promptTemplate,
    IEnumerable<PromptExecutionSettings>? executionSettings,
    string? functionName = null,
    string? description = null,
    string? templateFormat = null,
    IPromptTemplateFactory? promptTemplateFactory = null,
    ILoggerFactory? loggerFactory = null)
{
    // Key each settings entry by its ServiceId (falling back to the default id) before delegating.
    var settingsByServiceId = CreateSettingsDictionary(executionSettings);

    return KernelFunctionFromPrompt.Create(
        promptTemplate,
        settingsByServiceId,
        functionName,
        description,
        templateFormat,
        promptTemplateFactory,
        loggerFactory);
}

/// <summary>
Expand Down Expand Up @@ -141,10 +173,6 @@ public static KernelFunction CreateFromPrompt(
/// Wraps the specified settings into a dictionary with the default service ID as the key.
/// </summary>
[return: NotNullIfNotNull(nameof(settings))]
private static Dictionary<string, PromptExecutionSettings>? CreateSettingsDictionary(PromptExecutionSettings? settings) =>
settings is null ? null :
new Dictionary<string, PromptExecutionSettings>(1)
{
{ PromptExecutionSettings.DefaultServiceId, settings },
};
private static Dictionary<string, PromptExecutionSettings>? CreateSettingsDictionary(IEnumerable<PromptExecutionSettings>? settings) =>
settings?.ToDictionary(s => s.ServiceId ?? PromptExecutionSettings.DefaultServiceId, s => s);
}
36 changes: 36 additions & 0 deletions dotnet/src/SemanticKernel.Core/KernelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,42 @@ public static KernelFunction CreateFunctionFromPrompt(
kernel.LoggerFactory);
}

/// <summary>
/// Creates a <see cref="KernelFunction"/> instance for a prompt specified via a prompt template.
/// </summary>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="promptTemplate">Prompt template for the function.</param>
/// <param name="executionSettings">List of execution settings to use when invoking this prompt function.</param>
/// <param name="functionName">The name to use for the function. If null, it will default to a randomly generated name.</param>
/// <param name="description">The description to use for the function.</param>
/// <param name="templateFormat">The template format of <paramref name="promptTemplate"/>. This must be provided if <paramref name="promptTemplateFactory"/> is not null.</param>
/// <param name="promptTemplateFactory">
/// The <see cref="IPromptTemplateFactory"/> to use when interpreting the <paramref name="promptTemplate"/> into a <see cref="IPromptTemplate"/>.
/// If null, a default factory will be used.
/// </param>
/// <returns>The created <see cref="KernelFunction"/> for invoking the prompt.</returns>
public static KernelFunction CreateFunctionFromPrompt(
    this Kernel kernel,
    string promptTemplate,
    IEnumerable<PromptExecutionSettings>? executionSettings,
    string? functionName = null,
    string? description = null,
    string? templateFormat = null,
    IPromptTemplateFactory? promptTemplateFactory = null)
{
    Verify.NotNull(kernel);
    Verify.NotNull(promptTemplate);

    // Thin convenience wrapper: forwards to the factory, supplying the kernel's logger factory.
    return KernelFunctionFactory.CreateFromPrompt(promptTemplate, executionSettings, functionName, description, templateFormat, promptTemplateFactory, kernel.LoggerFactory);
}

/// <summary>
/// Creates a <see cref="KernelFunction"/> instance for a prompt specified via a prompt template configuration.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ public void PromptExecutionSettingsCloneWorksAsExpected()
// Arrange
string configPayload = """
{
"model_id": "gpt-3",
"service_id": "service-1",
"max_tokens": 60,
"temperature": 0.5,
"top_p": 0.0,
Expand All @@ -30,6 +32,36 @@ public void PromptExecutionSettingsCloneWorksAsExpected()
Assert.NotNull(clone);
Assert.Equal(executionSettings.ModelId, clone.ModelId);
Assert.Equivalent(executionSettings.ExtensionData, clone.ExtensionData);
Assert.Equal(executionSettings.ServiceId, clone.ServiceId);
}

[Fact]
public void PromptExecutionSettingsSerializationWorksAsExpected()
{
    // Arrange: payload covers the strongly-typed properties (model_id, service_id)
    // plus extension-data entries that have no dedicated property.
    string configPayload = """
        {
            "model_id": "gpt-3",
            "service_id": "service-1",
            "max_tokens": 60,
            "temperature": 0.5,
            "top_p": 0.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0
        }
        """;

    // Act
    var executionSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(configPayload);

    // Assert: every field of the payload round-trips.
    Assert.NotNull(executionSettings);
    Assert.Equal("gpt-3", executionSettings.ModelId);
    Assert.Equal("service-1", executionSettings.ServiceId);
    Assert.Equal(60, ((JsonElement)executionSettings.ExtensionData!["max_tokens"]).GetInt32());
    Assert.Equal(0.5, ((JsonElement)executionSettings.ExtensionData!["temperature"]).GetDouble());
    Assert.Equal(0.0, ((JsonElement)executionSettings.ExtensionData!["top_p"]).GetDouble());
    Assert.Equal(0.0, ((JsonElement)executionSettings.ExtensionData!["presence_penalty"]).GetDouble());
    // Fix: frequency_penalty was present in the payload but never asserted (copy-paste omission).
    Assert.Equal(0.0, ((JsonElement)executionSettings.ExtensionData!["frequency_penalty"]).GetDouble());
}

[Fact]
Expand Down
Loading
Loading