diff --git a/docs/data_ingestion.md b/docs/data_ingestion.md
index dd934fa30a..16bc60f2fd 100644
--- a/docs/data_ingestion.md
+++ b/docs/data_ingestion.md
@@ -36,7 +36,7 @@ If needed, you can modify the chunking algorithm in `scripts/prepdocslib/textspl
 
 To upload more PDFs, put them in the data/ folder and run `./scripts/prepdocs.sh` or `./scripts/prepdocs.ps1`.
 
-A [recent change](https://github.com/Azure-Samples/azure-search-openai-demo/pull/835) added checks to see what's been uploaded before. The prepdocs script now writes an .md5 file with an MD5 hash of each file that gets uploaded. Whenever the prepdocs script is re-run, that hash is checked against the current hash and the file is skipped if it hasn't changed.
+The script checks for existing documents by comparing the MD5 hash of each local file to the hash of the corresponding file in Blob Storage. If the hashes differ, it uploads the new file to Blob Storage and updates the index; if they match, it skips the file.
 
 ### Removing documents
 
diff --git a/scripts/prepdocs.ps1 b/scripts/prepdocs.ps1
index dad1b0c090..edc7934182 100755
--- a/scripts/prepdocs.ps1
+++ b/scripts/prepdocs.ps1
@@ -27,6 +27,7 @@ if ($env:AZURE_USE_AUTHENTICATION) {
 if ($env:AZURE_SEARCH_ANALYZER_NAME) {
   $searchAnalyzerNameArg = "--searchanalyzername $env:AZURE_SEARCH_ANALYZER_NAME"
 }
+
 if ($env:AZURE_VISION_ENDPOINT) {
   $visionEndpointArg = "--visionendpoint $env:AZURE_VISION_ENDPOINT"
 }
diff --git a/scripts/prepdocs.py b/scripts/prepdocs.py
index 1bccfffa0a..504ac1f1bd 100644
--- a/scripts/prepdocs.py
+++ b/scripts/prepdocs.py
@@ -23,6 +23,7 @@
 from prepdocslib.jsonparser import JsonParser
 from prepdocslib.listfilestrategy import (
     ADLSGen2ListFileStrategy,
+    BlobListFileStrategy,
     ListFileStrategy,
     LocalListFileStrategy,
 )
@@ -132,6 +133,11 @@ async def setup_file_strategy(credential: AsyncTokenCredential, args: Any) -> St
             credential=adls_gen2_creds,
             verbose=args.verbose,
         )
+    elif args.blobstoragehashcheck:
+        print("Using Blob Storage Account files to get hashes of existing files")
+        list_file_strategy = BlobListFileStrategy(
+            path_pattern=args.files, blob_manager=blob_manager, verbose=args.verbose
+        )
     else:
         print(f"Using local files in {args.files}")
         list_file_strategy = LocalListFileStrategy(path_pattern=args.files, verbose=args.verbose)
@@ -253,6 +259,12 @@ async def main(strategy: Strategy, credential: AsyncTokenCredential, args: Any):
     parser.add_argument(
         "--datalakestorageaccount", required=False, help="Optional. Azure Data Lake Storage Gen2 Account name"
     )
+    parser.add_argument(
+        "--blobstoragehashcheck",
+        action="store_true",
+        required=False,
+        help="Optional. Compare file hashes against the files already uploaded to Azure Blob Storage, rather than against local .md5 files.",
+    )
     parser.add_argument(
         "--datalakefilesystem",
         required=False,
diff --git a/scripts/prepdocslib/blobmanager.py b/scripts/prepdocslib/blobmanager.py
index 2521b62eac..3ba8f37549 100644
--- a/scripts/prepdocslib/blobmanager.py
+++ b/scripts/prepdocslib/blobmanager.py
@@ -1,3 +1,4 @@
+import binascii
 import datetime
 import io
 import os
@@ -15,7 +16,7 @@
 from PIL import Image, ImageDraw, ImageFont
 from pypdf import PdfReader
 
-from .listfilestrategy import File
+from .file import File
 
 
 class BlobManager:
@@ -158,6 +159,18 @@ async def remove_blob(self, path: Optional[str] = None):
                     print(f"\tRemoving blob {blob_path}")
                 await container_client.delete_blob(blob_path)
 
+    async def get_blob_hash(self, blob_name: str):
+        async with BlobServiceClient(
+            account_url=self.endpoint, credential=self.credential
+        ) as service_client, service_client.get_blob_client(self.container, blob_name) as blob_client:
+            if not await blob_client.exists():
+                return None
+
+            blob_properties = await blob_client.get_blob_properties()
+            blob_hash_raw_bytes = blob_properties.content_settings.content_md5
+            hex_hash = binascii.hexlify(blob_hash_raw_bytes)
+            return hex_hash.decode("utf-8")
+
     @classmethod
     def sourcepage_from_file_page(cls, filename, page=0) -> str:
         if os.path.splitext(filename)[1].lower() == ".pdf":
diff --git a/scripts/prepdocslib/embeddings.py b/scripts/prepdocslib/embeddings.py
index d6193d1ce7..f917c1c17c 100644
--- a/scripts/prepdocslib/embeddings.py
+++ b/scripts/prepdocslib/embeddings.py
@@ -45,7 +45,7 @@ async def create_client(self) -> AsyncOpenAI:
 
     def before_retry_sleep(self, retry_state):
         if self.verbose:
-            print("Rate limited on the OpenAI embeddings API, sleeping before retrying...")
+            print("\tRate limited on the OpenAI embeddings API, sleeping before retrying...")
 
     def calculate_token_length(self, text: str):
         encoding = tiktoken.encoding_for_model(self.open_ai_model_name)
@@ -83,6 +83,8 @@ def split_text_into_batches(self, texts: List[str]) -> List[EmbeddingBatch]:
         return batches
 
     async def create_embedding_batch(self, texts: List[str]) -> List[List[float]]:
+        if self.verbose:
+            print(f"Generating embeddings for {len(texts)} sections in batches...")
         batches = self.split_text_into_batches(texts)
         embeddings = []
         client = await self.create_client()
@@ -97,7 +99,7 @@ async def create_embedding_batch(self, texts: List[str]) -> List[List[float]]:
             emb_response = await client.embeddings.create(model=self.open_ai_model_name, input=batch.texts)
             embeddings.extend([data.embedding for data in emb_response.data])
             if self.verbose:
-                print(f"Batch Completed. Batch size {len(batch.texts)} Token count {batch.token_length}")
+                print(f"\tBatch Completed. Batch size: {len(batch.texts)} Token count: {batch.token_length}")
 
         return embeddings
 
diff --git a/scripts/prepdocslib/file.py b/scripts/prepdocslib/file.py
new file mode 100644
index 0000000000..92b2380e62
--- /dev/null
+++ b/scripts/prepdocslib/file.py
@@ -0,0 +1,33 @@
+import base64
+import os
+import re
+from typing import IO, Optional
+
+
+class File:
+    """
+    Represents a file stored either locally or in a data lake storage account
+    This file might contain access control information about which users or groups can access it
+    """
+
+    def __init__(self, content: IO, acls: Optional[dict[str, list]] = None):
+        self.content = content
+        self.acls = acls or {}
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    def filename(self):
+        return os.path.basename(self.content.name)
+
+    def filename_to_id(self):
+        filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename())
+        filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii")
+        return f"file-{filename_ascii}-{filename_hash}"
+
+    def close(self):
+        if self.content:
+            self.content.close()
diff --git a/scripts/prepdocslib/listfilestrategy.py b/scripts/prepdocslib/listfilestrategy.py
index d0b24876f1..a291919eb8 100644
--- a/scripts/prepdocslib/listfilestrategy.py
+++ b/scripts/prepdocslib/listfilestrategy.py
@@ -1,42 +1,17 @@
-import base64
 import hashlib
 import os
-import re
 import tempfile
 from abc import ABC
 from glob import glob
-from typing import IO, AsyncGenerator, Dict, List, Optional, Union
+from typing import AsyncGenerator, Dict, List, Union
 
 from azure.core.credentials_async import AsyncTokenCredential
 from azure.storage.filedatalake.aio import (
     DataLakeServiceClient,
 )
 
-
-class File:
-    """
-    Represents a file stored either locally or in a data lake storage account
-    This file might contain access control information about which users or groups can access it
-    """
-
-    def __init__(self, content: IO, acls: Optional[dict[str, list]] = None):
-        self.content = content
-        self.acls = acls or {}
-
-    def filename(self):
-        return os.path.basename(self.content.name)
-
-    def file_extension(self):
-        return os.path.splitext(self.content.name)[1]
-
-    def filename_to_id(self):
-        filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename())
-        filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii")
-        return f"file-{filename_ascii}-{filename_hash}"
-
-    def close(self):
-        if self.content:
-            self.content.close()
+from .blobmanager import BlobManager
+from .file import File
 
 
 class ListFileStrategy(ABC):
@@ -106,6 +81,55 @@ def check_md5(self, path: str) -> bool:
         return False
 
+
+class BlobListFileStrategy(ListFileStrategy):
+    """
+    Concrete strategy for listing remote files that are located in a blob storage account
+    """
+
+    def __init__(self, path_pattern: str, blob_manager: BlobManager, verbose: bool = False):
+        self.path_pattern = path_pattern
+        self.blob_manager = blob_manager
+        self.verbose = verbose
+
+    async def list_paths(self) -> AsyncGenerator[str, None]:
+        async for p in self._list_paths(self.path_pattern):
+            yield p
+
+    async def _list_paths(self, path_pattern: str) -> AsyncGenerator[str, None]:
+        for path in glob(path_pattern):
+            if os.path.isdir(path):
+                async for p in self._list_paths(f"{path}/*"):
+                    yield p
+            else:
+                # Only list files, not directories
+                yield path
+
+    async def list(self) -> AsyncGenerator[File, None]:
+        async for path in self.list_paths():
+            if not await self.check_md5(path):
+                yield File(content=open(path, mode="rb"))
+
+    async def check_md5(self, path: str) -> bool:
+        # if filename ends in .md5 skip
+        if path.endswith(".md5"):
+            return True
+
+        # get hash from local file
+        with open(path, "rb") as file:
+            existing_hash = hashlib.md5(file.read()).hexdigest()
+
+        # get hash from blob storage
+        blob_hash = await self.blob_manager.get_blob_hash(os.path.basename(path))
+
+        # compare hashes from local and blob storage
+        if blob_hash and blob_hash.strip() == existing_hash.strip():
+            if self.verbose:
+                print(f"Skipping {path}, no changes detected.")
+            return True
+
+        return False
+
 
 class ADLSGen2ListFileStrategy(ListFileStrategy):
     """
     Concrete strategy for listing files that are located in a data lake storage account
diff --git a/scripts/prepdocslib/searchmanager.py b/scripts/prepdocslib/searchmanager.py
index 4f9e14bdc3..3558a676cc 100644
--- a/scripts/prepdocslib/searchmanager.py
+++ b/scripts/prepdocslib/searchmanager.py
@@ -61,7 +61,7 @@ def __init__(
 
     async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] = None):
         if self.search_info.verbose:
-            print(f"Ensuring search index {self.search_info.index_name} exists")
+            print(f"Ensuring search index '{self.search_info.index_name}' exists...")
 
         async with self.search_info.create_search_index_client() as search_index_client:
             fields = [
@@ -174,11 +174,11 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
             )
             if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]:
                 if self.search_info.verbose:
-                    print(f"Creating {self.search_info.index_name} search index")
+                    print(f"\tCreating '{self.search_info.index_name}' search index")
                 await search_index_client.create_index(index)
             else:
                 if self.search_info.verbose:
-                    print(f"Search index {self.search_info.index_name} already exists")
+                    print(f"\tSearch index '{self.search_info.index_name}' already exists")
 
     async def update_content(
         self,
@@ -221,11 +221,21 @@ async def update_content(
                     for i, (document, section) in enumerate(zip(documents, batch)):
                         document["imageEmbedding"] = image_embeddings[section.split_page.page_num]
 
+                # Remove any existing documents with the same sourcefile before uploading new ones;
+                # this ensures we don't keep outdated documents in the index
+                await self.remove_content(path=batch[0].content.filename())
+                if self.search_info.verbose:
+                    print(
+                        f"Uploading {len(documents)} sections from '{batch[0].content.filename()}' to search index '{self.search_info.index_name}'"
+                    )
                 await search_client.upload_documents(documents)
 
     async def remove_content(self, path: Optional[str] = None):
         if self.search_info.verbose:
-            print(f"Removing sections from '{path or ''}' from search index '{self.search_info.index_name}'")
+            print(
+                f"Potentially removing sections from '{path or ''}' from search index '{self.search_info.index_name}'..."
+            )
+        total_removed = 0
         async with self.search_info.create_search_client() as search_client:
             while True:
                 filter = None if path is None else f"sourcefile eq '{os.path.basename(path)}'"
@@ -235,7 +245,8 @@
                 removed_docs = await search_client.delete_documents(
                     documents=[{"id": document["id"]} async for document in result]
                 )
-                if self.search_info.verbose:
-                    print(f"\tRemoved {len(removed_docs)} sections from index")
+                total_removed += len(removed_docs)
                 # It can take a few seconds for search results to reflect changes, so wait a bit
                 await asyncio.sleep(2)
+        if self.search_info.verbose:
+            print(f"\tRemoved {total_removed} sections from index")
diff --git a/tests/test_blob_manager.py b/tests/test_blob_manager.py
index 218856f5d2..474c294a5e 100644
--- a/tests/test_blob_manager.py
+++ b/tests/test_blob_manager.py
@@ -123,6 +123,35 @@ async def mock_delete_blob(self, name, *args, **kwargs):
     await blob_manager.remove_blob()
 
 
+@pytest.mark.asyncio
+@pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
+async def test_get_blob_hash(monkeypatch, mock_env, blob_manager):
+    blob_name = "test_blob"
+
+    # Set up mocks used by get_blob_hash
+    async def mock_exists(*args, **kwargs):
+        return True
+
+    monkeypatch.setattr("azure.storage.blob.aio.BlobClient.exists", mock_exists)
+
+    async def mock_get_blob_properties(*args, **kwargs):
+        class MockBlobProperties:
+            class MockContentSettings:
+                content_md5 = b"\x14\x0c\xdd\x8f\xd2\x74\x3d\x3b\xf1\xd1\xe2\x43\x01\xe4\xa0\x11"
+
+            content_settings = MockContentSettings()
+
+        return MockBlobProperties()
+
+    monkeypatch.setattr("azure.storage.blob.aio.BlobClient.get_blob_properties", mock_get_blob_properties)
+
+    blob_hash = await blob_manager.get_blob_hash(blob_name)
+
+    # The expected hash is the hex encoding of the mock content MD5
+    expected_hash = "140cdd8fd2743d3bf1d1e24301e4a011"
+    assert blob_hash == expected_hash
+
+
 @pytest.mark.asyncio
 @pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
 async def test_create_container_upon_upload(monkeypatch, mock_env, blob_manager):