Skip to content

Commit

Permalink
Ensure the Ollama clients validate HTTP status codes. (#5821)
Browse files Browse the repository at this point in the history
  • Loading branch information
eiriktsarpalis authored Jan 28, 2025
1 parent fc84cf9 commit 67bed79
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages,
JsonContext.Default.OllamaChatRequest,
cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
{
await OllamaUtilities.ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

var response = (await httpResponse.Content.ReadFromJsonAsync(
JsonContext.Default.OllamaChatResponse,
cancellationToken).ConfigureAwait(false))!;
Expand Down Expand Up @@ -117,6 +122,12 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
Content = JsonContent.Create(ToOllamaChatRequest(chatMessages, options, stream: true), JsonContext.Default.OllamaChatRequest)
};
using var httpResponse = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
{
await OllamaUtilities.ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

using var httpResponseStream = await httpResponse.Content
#if NET
.ReadAsStreamAsync(cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
JsonContext.Default.OllamaEmbeddingRequest,
cancellationToken).ConfigureAwait(false);

if (!httpResponse.IsSuccessStatusCode)
{
await OllamaUtilities.ThrowUnsuccessfulOllamaResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
}

var response = (await httpResponse.Content.ReadFromJsonAsync(
JsonContext.Default.OllamaEmbeddingResponse,
cancellationToken).ConfigureAwait(false))!;
Expand Down
38 changes: 38 additions & 0 deletions src/Libraries/Microsoft.Extensions.AI.Ollama/OllamaUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Net.Http;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Extensions.AI;

Expand Down Expand Up @@ -31,4 +35,38 @@ public static void TransferNanosecondsTime<TResponse>(TResponse response, Func<T
}
}
}

[DoesNotReturn]
public static async ValueTask ThrowUnsuccessfulOllamaResponseAsync(HttpResponseMessage response, CancellationToken cancellationToken)
{
Debug.Assert(!response.IsSuccessStatusCode, "must only be invoked for unsuccessful responses.");

// Read the entire response content into a string.
string errorContent =
#if NET
await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
#else
await response.Content.ReadAsStringAsync().ConfigureAwait(false);
#endif

// The response content *could* be JSON formatted, try to extract the error field.

#pragma warning disable CA1031 // Do not catch general exception types
try
{
using JsonDocument document = JsonDocument.Parse(errorContent);
if (document.RootElement.TryGetProperty("error", out JsonElement errorElement) &&
errorElement.ValueKind is JsonValueKind.String)
{
errorContent = errorElement.GetString()!;
}
}
catch
{
// Ignore JSON parsing errors.
}
#pragma warning restore CA1031 // Do not catch general exception types

throw new InvalidOperationException($"Ollama error: {errorContent}");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ public async Task PromptBasedFunctionCalling_WithArgs()
Assert.False(didCallIrrelevantTool);
}

[ConditionalFact]
public async Task InvalidModelParameter_ThrowsInvalidOperationException()
{
SkipIfNotEnabled();

var endpoint = IntegrationTestHelpers.GetOllamaUri();
Assert.NotNull(endpoint);

using var chatClient = new OllamaChatClient(endpoint, modelId: "inexistent-model");

InvalidOperationException ex;
ex = await Assert.ThrowsAsync<InvalidOperationException>(() => chatClient.CompleteAsync("Hello, world!"));
Assert.Contains("inexistent-model", ex.Message);

ex = await Assert.ThrowsAsync<InvalidOperationException>(() => chatClient.CompleteStreamingAsync("Hello, world!").ToChatCompletionAsync());
Assert.Contains("inexistent-model", ex.Message);
}

private sealed class AssertNoToolsDefinedChatClient(IChatClient innerClient) : DelegatingChatClient(innerClient)
{
public override Task<ChatCompletion> CompleteAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Threading.Tasks;
using Microsoft.TestUtilities;
using Xunit;

namespace Microsoft.Extensions.AI;

Expand All @@ -11,4 +14,19 @@ public class OllamaEmbeddingGeneratorIntegrationTests : EmbeddingGeneratorIntegr
IntegrationTestHelpers.GetOllamaUri() is Uri endpoint ?
new OllamaEmbeddingGenerator(endpoint, "all-minilm") :
null;

[ConditionalFact]
public async Task InvalidModelParameter_ThrowsInvalidOperationException()
{
SkipIfNotEnabled();

var endpoint = IntegrationTestHelpers.GetOllamaUri();
Assert.NotNull(endpoint);

using var generator = new OllamaEmbeddingGenerator(endpoint, modelId: "inexistent-model");

InvalidOperationException ex;
ex = await Assert.ThrowsAsync<InvalidOperationException>(() => generator.GenerateAsync(["Hello, world!"]));
Assert.Contains("inexistent-model", ex.Message);
}
}

0 comments on commit 67bed79

Please sign in to comment.