Skip to content

Commit

Permalink
add image embedding to azure index (#271)
Browse files Browse the repository at this point in the history
## Purpose
<!-- Describe the intention of the changes being proposed. What problem
does it solve or functionality does it add? -->
* ...

## Does this introduce a breaking change?
<!-- Mark one with an "x". -->
```
[ ] Yes
[ ] No
```

## Pull Request Type
What kind of change does this Pull Request introduce?

<!-- Please check the one that applies to this PR using "x". -->
```
[ ] Bugfix
[ ] Feature
[ ] Code style update (formatting, local variables)
[ ] Refactoring (no functional changes, no api changes)
[ ] Documentation content changes
[ ] Other... Please describe:
```

## How to Test
*  Get the code

```
git clone [repo-address]
cd [repo-name]
git checkout [branch-name]
npm install
```

* Test the code
<!-- Add steps to run the tests suite and/or manually test -->
```
```

## What to Check
Verify that the following are valid
* ...

## Other Information
<!-- Add any other helpful information that may be needed here. -->

---------

Co-authored-by: David Pine <[email protected]>
  • Loading branch information
LittleLittleCloud and IEvangelist authored Feb 6, 2024
1 parent 9543b54 commit f5b168f
Show file tree
Hide file tree
Showing 37 changed files with 739 additions and 100 deletions.
4 changes: 2 additions & 2 deletions app/backend/Extensions/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv
return sp.GetRequiredService<BlobServiceClient>().GetBlobContainerClient(azureStorageContainer);
});

services.AddSingleton<IDocumentService, AzureDocumentService>(sp =>
services.AddSingleton<ISearchService, AzureSearchService>(sp =>
{
var config = sp.GetRequiredService<IConfiguration>();
var azureSearchServiceEndpoint = config["AzureSearchServiceEndpoint"];
Expand All @@ -39,7 +39,7 @@ internal static IServiceCollection AddAzureServices(this IServiceCollection serv
var searchClient = new SearchClient(
new Uri(azureSearchServiceEndpoint), azureSearchIndex, s_azureCredential);

return new AzureDocumentService(searchClient);
return new AzureSearchService(searchClient);
});

services.AddSingleton<DocumentAnalysisClient>(sp =>
Expand Down
5 changes: 3 additions & 2 deletions app/backend/Services/ReadRetrieveReadChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ namespace MinimalApi.Services;

public class ReadRetrieveReadChatService
{
private readonly IDocumentService _searchClient;
private readonly ISearchService _searchClient;
private readonly IKernel _kernel;
private readonly IConfiguration _configuration;

public ReadRetrieveReadChatService(
IDocumentService searchClient,
ISearchService searchClient,
OpenAIClient client,
IConfiguration configuration)
{
Expand Down Expand Up @@ -158,6 +158,7 @@ Return the follow-up question as a json string list.
}
return new ApproachResponse(
DataPoints: documentContentList,
Images: null,
Answer: ans,
Thoughts: thoughts,
CitationBaseUrl: _configuration.ToCitationBaseUrl());
Expand Down
1 change: 1 addition & 0 deletions app/frontend/Services/ApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ private async Task<AnswerResult<TRequest>> PostRequestAsync<TRequest>(
$"HTTP {(int)response.StatusCode} : {response.ReasonPhrase ?? "☹️ Unknown error..."}",
null,
[],
null,
"Unable to retrieve valid response from the server.");

return result with
Expand Down
12 changes: 11 additions & 1 deletion app/functions/EmbedFunctions/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,17 @@ uri is not null
var documentClient = provider.GetRequiredService<DocumentAnalysisClient>();
var logger = provider.GetRequiredService<ILogger<AzureSearchEmbedService>>();

return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, logger);
return new AzureSearchEmbedService(
openAIClient: openAIClient,
embeddingModelName: embeddingModelName,
searchClient: searchClient,
searchIndexName: searchIndexName,
searchIndexClient: searchIndexClient,
documentAnalysisClient: documentClient,
corpusContainerClient: blobContainerClient,
computerVisionService: null,
includeImageEmbeddingsField: false,
logger: logger);
});
})
.ConfigureFunctionsWorkerDefaults()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ internal async Task EmbedBlobAsync(Stream blobStream, string blobName)
var embeddingType = GetEmbeddingType();
var embedService = embedServiceFactory.GetEmbedService(embeddingType);

var result = await embedService.EmbedBlobAsync(blobStream, blobName);
var result = await embedService.EmbedPDFBlobAsync(blobStream, blobName);

var status = result switch
{
Expand Down
19 changes: 0 additions & 19 deletions app/functions/EmbedFunctions/Services/IEmbedService.cs

This file was deleted.

18 changes: 17 additions & 1 deletion app/functions/EmbedFunctions/Services/MilvusEmbedService.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
// Copyright (c) Microsoft. All rights reserved.


namespace EmbedFunctions.Services;

internal sealed class MilvusEmbedService : IEmbedService
{
public Task<bool> EmbedBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException();
/// <summary>Not implemented for the Milvus backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Milvus backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task<bool> EmbedImageBlobAsync(Stream imageStream, string imageName, CancellationToken ct = default) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Milvus backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task<bool> EmbedPDFBlobAsync(Stream blobStream, string blobName) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Milvus backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task EnsureSearchIndexAsync(string searchIndexName, CancellationToken ct = default) =>
    throw new NotImplementedException();
}
18 changes: 17 additions & 1 deletion app/functions/EmbedFunctions/Services/PineconeEmbedService.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
// Copyright (c) Microsoft. All rights reserved.


namespace EmbedFunctions.Services;

internal sealed class PineconeEmbedService : IEmbedService
{
public Task<bool> EmbedBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException(
/// <summary>Not implemented for the Pinecone backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Pinecone backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task<bool> EmbedImageBlobAsync(Stream imageStream, string imageName, CancellationToken ct = default) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Pinecone backend.</summary>
/// <exception cref="NotImplementedException">Always thrown, with an explanatory message.</exception>
public Task<bool> EmbedPDFBlobAsync(Stream blobStream, string blobName)
{
    throw new NotImplementedException("Pinecone embedding isn't implemented.");
}

/// <summary>Not implemented for the Pinecone backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task EnsureSearchIndexAsync(string searchIndexName, CancellationToken ct = default) =>
    throw new NotImplementedException();
}
18 changes: 17 additions & 1 deletion app/functions/EmbedFunctions/Services/QdrantEmbedService.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
// Copyright (c) Microsoft. All rights reserved.


namespace EmbedFunctions.Services;

internal sealed class QdrantEmbedService : IEmbedService
{
public Task<bool> EmbedBlobAsync(Stream blobStream, string blobName) => throw new NotImplementedException();
/// <summary>Not implemented for the Qdrant backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Qdrant backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task<bool> EmbedImageBlobAsync(Stream imageStream, string imageName, CancellationToken ct = default) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Qdrant backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task<bool> EmbedPDFBlobAsync(Stream blobStream, string blobName) =>
    throw new NotImplementedException();

/// <summary>Not implemented for the Qdrant backend.</summary>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public Task EnsureSearchIndexAsync(string searchIndexName, CancellationToken ct = default) =>
    throw new NotImplementedException();
}
1 change: 1 addition & 0 deletions app/prepdocs/PrepareDocs/AppOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ internal record class AppOptions(
bool Remove,
bool RemoveAll,
string? FormRecognizerServiceEndpoint,
string? ComputerVisionServiceEndpoint,
bool Verbose,
IConsole Console) : AppConsole(Console);

Expand Down
2 changes: 1 addition & 1 deletion app/prepdocs/PrepareDocs/PrepareDocs.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\functions\EmbedFunctions\EmbedFunctions.csproj" />
<ProjectReference Include="..\..\shared\Shared\Shared.csproj" />
</ItemGroup>

</Project>
32 changes: 27 additions & 5 deletions app/prepdocs/PrepareDocs/Program.Clients.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.


using EmbedFunctions.Services;

internal static partial class Program
{
private static BlobContainerClient? s_corpusContainerClient;
Expand Down Expand Up @@ -30,8 +27,19 @@ private static Task<AzureSearchEmbedService> GetAzureSearchEmbedService(AppOptio
var openAIClient = await GetAzureOpenAIClientAsync(o);
var embeddingModelName = o.EmbeddingModelName ?? throw new ArgumentNullException(nameof(o.EmbeddingModelName));
var searchIndexName = o.SearchIndexName ?? throw new ArgumentNullException(nameof(o.SearchIndexName));

return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, null);
var computerVisionService = await GetComputerVisionServiceAsync(o);

return new AzureSearchEmbedService(
openAIClient: openAIClient,
embeddingModelName: embeddingModelName,
searchClient: searchClient,
searchIndexName: searchIndexName,
searchIndexClient: searchIndexClient,
documentAnalysisClient: documentClient,
corpusContainerClient: blobContainerClient,
computerVisionService: computerVisionService,
includeImageEmbeddingsField: computerVisionService != null,
logger: null);
});

private static Task<BlobContainerClient> GetCorpusBlobContainerClientAsync(AppOptions options) =>
Expand Down Expand Up @@ -139,6 +147,20 @@ private static Task<SearchClient> GetSearchClientAsync(AppOptions options) =>
return s_searchClient;
});

/// <summary>
/// Creates the optional Computer Vision service through <c>GetLazyClientAsync</c>.
/// </summary>
/// <param name="options">Parsed options; only <c>ComputerVisionServiceEndpoint</c> is read here.</param>
/// <returns>
/// A task producing an <c>AzureComputerVisionService</c>, or <c>null</c> when the endpoint
/// option is null or empty.
/// </returns>
private static Task<IComputerVisionService?> GetComputerVisionServiceAsync(AppOptions options)
{
    // NOTE(review): this passes s_openAILock, the same lock GetAzureOpenAIClientAsync uses,
    // and the result is not stored in a static field, so each call builds a new HttpClient.
    return GetLazyClientAsync<IComputerVisionService?>(options, s_openAILock, async appOptions =>
    {
        // Nothing asynchronous happens below; presumably this await exists only so the
        // async lambda contains an await.
        await Task.CompletedTask;

        if (appOptions.ComputerVisionServiceEndpoint is not { Length: > 0 } visionEndpoint)
        {
            return null;
        }

        return new AzureComputerVisionService(new HttpClient(), visionEndpoint, DefaultCredential);
    });
}

private static Task<OpenAIClient> GetAzureOpenAIClientAsync(AppOptions options) =>
GetLazyClientAsync<OpenAIClient>(options, s_openAILock, async o =>
{
Expand Down
26 changes: 21 additions & 5 deletions app/prepdocs/PrepareDocs/Program.Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ internal static partial class Program
private static readonly Option<string> s_formRecognizerServiceEndpoint =
new(name: "--formrecognizerendpoint", description: "Optional. The Azure Form Recognizer service endpoint which will be used to extract text, tables and layout from the documents (must exist already)");

// Optional --computervisionendpoint command-line option. GetParsedAppOptions copies its parsed
// value into AppOptions.ComputerVisionServiceEndpoint.
private static readonly Option<string> s_computerVisionServiceEndpoint =
new(name: "--computervisionendpoint", description: "Optional. The Azure Computer Vision service endpoint which will be used to vectorize image and query");

private static readonly Option<bool> s_verbose =
new(aliases: new[] { "--verbose", "-v" }, description: "Verbose output");

Expand All @@ -49,11 +52,23 @@ internal static partial class Program
Prepare documents by extracting content from PDFs, splitting content into sections,
uploading to blob storage, and indexing in a search index.
""")
{
s_files, s_category, s_skipBlobs, s_storageEndpoint,
s_container, s_tenantId, s_searchService, s_searchIndexName, s_azureOpenAIService, s_embeddingModelName,
s_remove, s_removeAll, s_formRecognizerServiceEndpoint, s_verbose
};
{
s_files,
s_category,
s_skipBlobs,
s_storageEndpoint,
s_container,
s_tenantId,
s_searchService,
s_searchIndexName,
s_azureOpenAIService,
s_embeddingModelName,
s_remove,
s_removeAll,
s_formRecognizerServiceEndpoint,
s_computerVisionServiceEndpoint,
s_verbose,
};

private static AppOptions GetParsedAppOptions(InvocationContext context) => new(
Files: context.ParseResult.GetValueForArgument(s_files),
Expand All @@ -69,6 +84,7 @@ internal static partial class Program
Remove: context.ParseResult.GetValueForOption(s_remove),
RemoveAll: context.ParseResult.GetValueForOption(s_removeAll),
FormRecognizerServiceEndpoint: context.ParseResult.GetValueForOption(s_formRecognizerServiceEndpoint),
ComputerVisionServiceEndpoint: context.ParseResult.GetValueForOption(s_computerVisionServiceEndpoint),
Verbose: context.ParseResult.GetValueForOption(s_verbose),
Console: context.Console);
}
14 changes: 10 additions & 4 deletions app/prepdocs/PrepareDocs/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using EmbedFunctions.Services;
using System.Diagnostics;

s_rootCommand.SetHandler(
async (context) =>
Expand All @@ -15,7 +15,6 @@
{
var searchIndexName = options.SearchIndexName ?? throw new ArgumentNullException(nameof(options.SearchIndexName));
var embedService = await GetAzureSearchEmbedService(options);

await embedService.EnsureSearchIndexAsync(options.SearchIndexName);

Matcher matcher = new();
Expand Down Expand Up @@ -190,19 +189,26 @@ static async ValueTask UploadBlobsAndCreateIndexAsync(
// revert stream position
stream.Position = 0;

await embeddingService.EmbedBlobAsync(stream, documentName);
await embeddingService.EmbedPDFBlobAsync(stream, documentName);
}
finally
{
File.Delete(tempFileName);
}
}
}
// If the file is an image (.png, .jpg or .jpeg), hand it to EmbedImageBlobAsync, which is
// expected to upload it to blob storage and embed it — TODO confirm against the embed service.
else if (Path.GetExtension(fileName).Equals(".png", StringComparison.OrdinalIgnoreCase) ||
    Path.GetExtension(fileName).Equals(".jpg", StringComparison.OrdinalIgnoreCase) ||
    Path.GetExtension(fileName).Equals(".jpeg", StringComparison.OrdinalIgnoreCase))
{
    // Own the stream here so the file handle is released even if embedding throws;
    // the stream returned by File.OpenRead was previously never disposed.
    await using var imageStream = File.OpenRead(fileName);
    await embeddingService.EmbedImageBlobAsync(imageStream, fileName);
}
else
{
var blobName = BlobNameFromFilePage(fileName);
await UploadBlobAsync(fileName, blobName, container);
await embeddingService.EmbedBlobAsync(File.OpenRead(fileName), blobName);
await embeddingService.EmbedPDFBlobAsync(File.OpenRead(fileName), blobName);
}
}

Expand Down
6 changes: 5 additions & 1 deletion app/shared/Shared/Models/ApproachResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
namespace Shared.Models;

/// <summary>A piece of supporting text content, carried as a title and its content.</summary>
public record SupportingContentRecord(string Title, string Content);

/// <summary>A supporting image, carried as a title and a URL.</summary>
public record SupportingImageRecord(string Title, string Url);

public record ApproachResponse(
string Answer,
string? Thoughts,
SupportingContentRecord[] DataPoints, // title, content
SupportingContentRecord[]? DataPoints, // title, content
SupportingImageRecord[]? Images, // title, url
string CitationBaseUrl,
string? Error = null);
4 changes: 3 additions & 1 deletion app/shared/Shared/Services/AzureComputerVisionService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
using System.Text.Json;
using Azure.Core;

public class AzureComputerVisionService(HttpClient client, string endPoint, TokenCredential tokenCredential)
public class AzureComputerVisionService(HttpClient client, string endPoint, TokenCredential tokenCredential) : IComputerVisionService
{
public int Dimension => 1024;

// add virtual keyword to make it mockable
public async Task<ImageEmbeddingResponse> VectorizeImageAsync(string imagePathOrUrl, CancellationToken ct = default)
{
Expand Down
Loading

0 comments on commit f5b168f

Please sign in to comment.