Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure the Ollama clients validate HTTP status codes. #5821

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages,
JsonContext.Default.OllamaChatRequest,
cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
{
await OllamaUtilities.ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

var response = (await httpResponse.Content.ReadFromJsonAsync(
JsonContext.Default.OllamaChatResponse,
cancellationToken).ConfigureAwait(false))!;
Expand Down Expand Up @@ -117,6 +122,12 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
Content = JsonContent.Create(ToOllamaChatRequest(chatMessages, options, stream: true), JsonContext.Default.OllamaChatRequest)
};
using var httpResponse = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
{
await OllamaUtilities.ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

using var httpResponseStream = await httpResponse.Content
#if NET
.ReadAsStreamAsync(cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
JsonContext.Default.OllamaEmbeddingRequest,
cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
{
await OllamaUtilities.ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

var response = (await httpResponse.Content.ReadFromJsonAsync(
JsonContext.Default.OllamaEmbeddingResponse,
cancellationToken).ConfigureAwait(false))!;
Expand Down
38 changes: 38 additions & 0 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -31,4 +35,38 @@ public static void TransferNanosecondsTime<TResponse>(TResponse response, Func<T
}
}
}

[DoesNotReturn]
public static async ValueTask ThrowUnsuccessfulOllamaResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
{
Debug.Assert(!response.IsSuccessStatusCode, "must only be invoked for unsuccessful responses.");

// Read the entire response content into a string.
string errorContent =
#if NET
await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
#else
await response.Content.ReadAsStringAsync().ConfigureAwait(false);
#endif

// The response content *could* be JSON formatted, try to extract the error field.

#pragma warning disable CA1031 // Do not catch general exception types
try
{
using JsonDocument document = JsonDocument.Parse(errorContent);
if (document.RootElement.TryGetProperty("error", out JsonElement errorElement) &&
errorElement.ValueKind is JsonValueKind.String)
{
errorContent = errorElement.GetString()!;
}
}
catch
{
// Ignore JSON parsing errors.
}
#pragma warning restore CA1031 // Do not catch general exception types

throw new InvalidOperationException($"Ollama error: {errorContent}");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ public async Task PromptBasedFunctionCalling_WithArgs()
Assert.False(didCallIrrelevantTool);
}

[ConditionalFact]
public async Task InvalidModelParameter_ThrowsInvalidOperationException()
{
SkipIfNotEnabled();

var endpoint = IntegrationTestHelpers.GetOllamaUri();
Assert.NotNull(endpoint);

using var chatClient = new OllamaChatClient(endpoint, modelId: "inexistent-model");

InvalidOperationException ex;
ex = await Assert.ThrowsAsync<InvalidOperationException>(() => chatClient.CompleteAsync("Hello, world!"));
Assert.Contains("inexistent-model", ex.Message);

ex = await Assert.ThrowsAsync<InvalidOperationException>(() => chatClient.CompleteStreamingAsync("Hello, world!").ToChatCompletionAsync());
Assert.Contains("inexistent-model", ex.Message);
}

private sealed class AssertNoToolsDefinedChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient)
{
public override Task<ChatCompletion> CompleteAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Threading.Tasks;
using Microsoft.TestUtilities;
using Xunit;

namespace Microsoft.Extensions.AI;

Expand All @@ -11,4 +14,19 @@ public class OllamaEmbeddingGeneratorIntegrationTests : EmbeddingGeneratorIntegr
IntegrationTestHelpers.GetOllamaUri() is Uri endpoint ?
new OllamaEmbeddingGenerator(endpoint, "all-minilm") :
null;

[ConditionalFact]
public async Task InvalidModelParameter_ThrowsInvalidOperationException()
{
SkipIfNotEnabled();

var endpoint = IntegrationTestHelpers.GetOllamaUri();
Assert.NotNull(endpoint);

using var generator = new OllamaEmbeddingGenerator(endpoint, modelId: "inexistent-model");

InvalidOperationException ex;
ex = await Assert.ThrowsAsync<InvalidOperationException>(() => generator.GenerateAsync(["Hello, world!"]));
Assert.Contains("inexistent-model", ex.Message);
}
}