Check blob hash #942

Open · wants to merge 10 commits into base: main

2 changes: 1 addition & 1 deletion docs/data_ingestion.md
@@ -36,7 +36,7 @@ If needed, you can modify the chunking algorithm in `scripts/prepdocslib/textsplitter.py`.

To upload more PDFs, put them in the data/ folder and run `./scripts/prepdocs.sh` or `./scripts/prepdocs.ps1`.

A [recent change](https://github.com/Azure-Samples/azure-search-openai-demo/pull/835) added checks to see what's been uploaded before. The prepdocs script now writes an .md5 file with an MD5 hash of each file that gets uploaded. Whenever the prepdocs script is re-run, that hash is checked against the current hash and the file is skipped if it hasn't changed.
The script now also compares the MD5 hash of each local file to the hash stored for the corresponding file in blob storage. If the hashes differ, it uploads the new file to blob storage and updates the index; if they match, it skips the file.

### Removing documents

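For readers of the docs change above, here is a minimal sketch of the comparison this PR performs. Names such as `needs_upload` are illustrative only; the actual logic lives in `BlobListFileStrategy.check_md5` and `BlobManager.get_blob_hash` further down in this diff.

```python
import hashlib
import os


async def needs_upload(local_path: str, blob_manager) -> bool:
    """Return True if the local file is new or changed compared to its uploaded copy."""
    with open(local_path, "rb") as f:
        local_hash = hashlib.md5(f.read()).hexdigest()
    # get_blob_hash returns None when no blob with that name exists yet
    blob_hash = await blob_manager.get_blob_hash(os.path.basename(local_path))
    return blob_hash is None or blob_hash.strip() != local_hash
```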
1 change: 1 addition & 0 deletions scripts/prepdocs.ps1
@@ -27,6 +27,7 @@ if ($env:AZURE_USE_AUTHENTICATION) {
if ($env:AZURE_SEARCH_ANALYZER_NAME) {
$searchAnalyzerNameArg = "--searchanalyzername $env:AZURE_SEARCH_ANALYZER_NAME"
}

if ($env:AZURE_VISION_ENDPOINT) {
$visionEndpointArg = "--visionendpoint $env:AZURE_VISION_ENDPOINT"
}
12 changes: 12 additions & 0 deletions scripts/prepdocs.py
@@ -23,6 +23,7 @@
from prepdocslib.jsonparser import JsonParser
from prepdocslib.listfilestrategy import (
ADLSGen2ListFileStrategy,
BlobListFileStrategy,
ListFileStrategy,
LocalListFileStrategy,
)
@@ -132,6 +133,11 @@ async def setup_file_strategy(credential: AsyncTokenCredential, args: Any) -> St
credential=adls_gen2_creds,
verbose=args.verbose,
)
elif args.blobstoragehashcheck:
print("Using Blob Storage Account files to get hashes of existing files")
list_file_strategy = BlobListFileStrategy(
path_pattern=args.files, blob_manager=blob_manager, verbose=args.verbose
)
else:
print(f"Using local files in {args.files}")
list_file_strategy = LocalListFileStrategy(path_pattern=args.files, verbose=args.verbose)
@@ -253,6 +259,12 @@ async def main(strategy: Strategy, credential: AsyncTokenCredential, args: Any):
parser.add_argument(
"--datalakestorageaccount", required=False, help="Optional. Azure Data Lake Storage Gen2 Account name"
)
parser.add_argument(
"--blobstoragehashcheck",
action="store_true",
required=False,
help="Optional. Use files from this Azure Blob Storage account for hash comparisons, rather than using local files.",
)
parser.add_argument(
"--datalakefilesystem",
required=False,
15 changes: 14 additions & 1 deletion scripts/prepdocslib/blobmanager.py
@@ -1,3 +1,4 @@
import binascii
import datetime
import io
import os
@@ -15,7 +16,7 @@
from PIL import Image, ImageDraw, ImageFont
from pypdf import PdfReader

from .listfilestrategy import File
from .file import File


class BlobManager:
@@ -158,6 +159,18 @@ async def remove_blob(self, path: Optional[str] = None):
print(f"\tRemoving blob {blob_path}")
await container_client.delete_blob(blob_path)

async def get_blob_hash(self, blob_name: str):
async with BlobServiceClient(
account_url=self.endpoint, credential=self.credential
) as service_client, service_client.get_blob_client(self.container, blob_name) as blob_client:
if not await blob_client.exists():
return None

blob_properties = await blob_client.get_blob_properties()
blob_hash_raw_bytes = blob_properties.content_settings.content_md5
hex_hash = binascii.hexlify(blob_hash_raw_bytes)
return hex_hash.decode("utf-8")

@classmethod
def sourcepage_from_file_page(cls, filename, page=0) -> str:
if os.path.splitext(filename)[1].lower() == ".pdf":
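A side note on `get_blob_hash`: `content_md5` holds the raw 16-byte digest, so hex-encoding it produces the same string that `hashlib.md5(...).hexdigest()` yields for the local file. A standalone illustration:

```python
import binascii
import hashlib

data = b"hello"
raw_digest = hashlib.md5(data).digest()  # 16 raw bytes, the same format as content_md5
print(binascii.hexlify(raw_digest).decode("utf-8"))  # 5d41402abc4b2a76b9719d911017c592
print(hashlib.md5(data).hexdigest())                 # same string, so the comparison works
```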
6 changes: 4 additions & 2 deletions scripts/prepdocslib/embeddings.py
@@ -45,7 +45,7 @@ async def create_client(self) -> AsyncOpenAI:

def before_retry_sleep(self, retry_state):
if self.verbose:
print("Rate limited on the OpenAI embeddings API, sleeping before retrying...")
print("\tRate limited on the OpenAI embeddings API, sleeping before retrying...")

def calculate_token_length(self, text: str):
encoding = tiktoken.encoding_for_model(self.open_ai_model_name)
@@ -83,6 +83,8 @@ def split_text_into_batches(self, texts: List[str]) -> List[EmbeddingBatch]:
return batches

async def create_embedding_batch(self, texts: List[str]) -> List[List[float]]:
if self.verbose:
print(f"Generating embeddings for {len(texts)} sections in batches...")
batches = self.split_text_into_batches(texts)
embeddings = []
client = await self.create_client()
@@ -97,7 +99,7 @@ async def create_embedding_batch(self, texts: List[str]) -> List[List[float]]:
emb_response = await client.embeddings.create(model=self.open_ai_model_name, input=batch.texts)
embeddings.extend([data.embedding for data in emb_response.data])
if self.verbose:
print(f"Batch Completed. Batch size {len(batch.texts)} Token count {batch.token_length}")
print(f"\tBatch Completed. Batch size: {len(batch.texts)} Token count: {batch.token_length}")

return embeddings

33 changes: 33 additions & 0 deletions scripts/prepdocslib/file.py
@@ -0,0 +1,33 @@
import base64
import os
import re
from typing import IO, Optional


class File:
"""
Represents a file stored either locally or in a data lake storage account
This file might contain access control information about which users or groups can access it
"""

def __init__(self, content: IO, acls: Optional[dict[str, list]] = None):
self.content = content
self.acls = acls or {}

def __enter__(self):
return self

def __exit__(self, *args):
self.close()

def filename(self):
return os.path.basename(self.content.name)

def filename_to_id(self):
filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename())
filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii")
return f"file-{filename_ascii}-{filename_hash}"

def close(self):
if self.content:
self.content.close()
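Compared with the version removed from `listfilestrategy.py` below, the relocated class gains context-manager support (`__enter__`/`__exit__`). A minimal usage sketch; the file path is illustrative:

```python
from prepdocslib.file import File

# Opening via a context manager guarantees the underlying handle is closed
with File(content=open("data/example.pdf", "rb")) as f:
    print(f.filename())        # example.pdf
    print(f.filename_to_id())  # file-example_pdf-6578616D706C652E706466
```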
79 changes: 51 additions & 28 deletions scripts/prepdocslib/listfilestrategy.py
@@ -1,42 +1,16 @@
import base64
import hashlib
import os
import re
import tempfile
from abc import ABC
from glob import glob
from typing import IO, AsyncGenerator, Dict, List, Optional, Union
from typing import AsyncGenerator, Dict, List, Union

from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.filedatalake.aio import (
DataLakeServiceClient,
)


class File:
"""
Represents a file stored either locally or in a data lake storage account
This file might contain access control information about which users or groups can access it
"""

def __init__(self, content: IO, acls: Optional[dict[str, list]] = None):
self.content = content
self.acls = acls or {}

def filename(self):
return os.path.basename(self.content.name)

def file_extension(self):
return os.path.splitext(self.content.name)[1]

def filename_to_id(self):
filename_ascii = re.sub("[^0-9a-zA-Z_-]", "_", self.filename())
filename_hash = base64.b16encode(self.filename().encode("utf-8")).decode("ascii")
return f"file-{filename_ascii}-{filename_hash}"

def close(self):
if self.content:
self.content.close()
from .blobmanager import BlobManager


class ListFileStrategy(ABC):
@@ -106,6 +80,55 @@ def check_md5(self, path: str) -> bool:
return False


class BlobListFileStrategy(ListFileStrategy):
"""
Concrete strategy for listing local files and skipping any whose MD5 hash matches the corresponding blob already uploaded to a blob storage account
"""

def __init__(self, path_pattern: str, blob_manager: BlobManager, verbose: bool = False):
self.path_pattern = path_pattern
self.blob_manager = blob_manager
self.verbose = verbose

async def list_paths(self) -> AsyncGenerator[str, None]:
async for p in self._list_paths(self.path_pattern):
yield p

async def _list_paths(self, path_pattern: str) -> AsyncGenerator[str, None]:
for path in glob(path_pattern):
if os.path.isdir(path):
async for p in self._list_paths(f"{path}/*"):
yield p
else:
# Only list files, not directories
yield path

async def list(self) -> AsyncGenerator[File, None]:
async for path in self.list_paths():
if not await self.check_md5(path):
yield File(content=open(path, mode="rb"))

async def check_md5(self, path: str) -> bool:
# if filename ends in .md5 skip
if path.endswith(".md5"):
return True

# get hash from local file
with open(path, "rb") as file:
existing_hash = hashlib.md5(file.read()).hexdigest()

# get hash from blob storage
blob_hash = await self.blob_manager.get_blob_hash(os.path.basename(path))

# compare hashes from local and blob storage
if blob_hash and blob_hash.strip() == existing_hash.strip():
if self.verbose:
print(f"Skipping {path}, no changes detected.")
return True

return False


class ADLSGen2ListFileStrategy(ListFileStrategy):
"""
Concrete strategy for listing files that are located in a data lake storage account
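A hedged sketch of how the new strategy would be consumed. The construction mirrors the `elif args.blobstoragehashcheck` branch added in `prepdocs.py` above; the iteration pattern is an assumption based on how the other list-file strategies are used.

```python
# Illustrative only: list local files under data/, skipping any whose MD5
# matches the hash already stored on the corresponding blob.
list_file_strategy = BlobListFileStrategy(
    path_pattern="data/*", blob_manager=blob_manager, verbose=True
)

async for file in list_file_strategy.list():
    with file:
        print(f"{file.filename()} is new or changed and will be re-ingested")
```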
23 changes: 17 additions & 6 deletions scripts/prepdocslib/searchmanager.py
@@ -61,7 +61,7 @@ def __init__(

async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] = None):
if self.search_info.verbose:
print(f"Ensuring search index {self.search_info.index_name} exists")
print(f"Ensuring search index '{self.search_info.index_name}' exists...")

async with self.search_info.create_search_index_client() as search_index_client:
fields = [
@@ -174,11 +174,11 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
)
if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]:
if self.search_info.verbose:
print(f"Creating {self.search_info.index_name} search index")
print(f"\tCreating '{self.search_info.index_name}' search index")
await search_index_client.create_index(index)
else:
if self.search_info.verbose:
print(f"Search index {self.search_info.index_name} already exists")
print(f"\tSearch index '{self.search_info.index_name}' already exists")

async def update_content(
self,
@@ -221,11 +221,21 @@ async def update_content(
for i, (document, section) in enumerate(zip(documents, batch)):
document["imageEmbedding"] = image_embeddings[section.split_page.page_num]

# Remove any existing documents with the same sourcefile before uploading new ones,
# so the index doesn't keep outdated sections for this file
await self.remove_content(path=batch[0].content.filename())
if self.search_info.verbose:
print(
f"Uploading {len(documents)} sections from '{batch[0].content.filename()}' to search index '{self.search_info.index_name}'"
)
await search_client.upload_documents(documents)

async def remove_content(self, path: Optional[str] = None):
if self.search_info.verbose:
print(f"Removing sections from '{path or '<all>'}' from search index '{self.search_info.index_name}'")
print(
f"Potentially removing sections from '{path or '<all>'}' from search index '{self.search_info.index_name}'..."
)
total_removed = 0
async with self.search_info.create_search_client() as search_client:
while True:
filter = None if path is None else f"sourcefile eq '{os.path.basename(path)}'"
@@ -235,7 +245,8 @@ async def remove_content(self, path: Optional[str] = None):
removed_docs = await search_client.delete_documents(
documents=[{"id": document["id"]} async for document in result]
)
if self.search_info.verbose:
print(f"\tRemoved {len(removed_docs)} sections from index")
total_removed += len(removed_docs)
# It can take a few seconds for search results to reflect changes, so wait a bit
await asyncio.sleep(2)
if self.search_info.verbose:
print(f"\tRemoved {total_removed} sections from index")
29 changes: 29 additions & 0 deletions tests/test_blob_manager.py
@@ -123,6 +123,35 @@ async def mock_delete_blob(self, name, *args, **kwargs):
await blob_manager.remove_blob()


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
async def test_get_blob_hash(monkeypatch, mock_env, blob_manager):
blob_name = "test_blob"

# Set up mocks used by get_blob_hash
async def mock_exists(*args, **kwargs):
return True

monkeypatch.setattr("azure.storage.blob.aio.BlobClient.exists", mock_exists)

async def mock_get_blob_properties(*args, **kwargs):
class MockBlobProperties:
class MockContentSettings:
content_md5 = b"\x14\x0c\xdd\x8f\xd2\x74\x3d\x3b\xf1\xd1\xe2\x43\x01\xe4\xa0\x11"

content_settings = MockContentSettings()

return MockBlobProperties()

monkeypatch.setattr("azure.storage.blob.aio.BlobClient.get_blob_properties", mock_get_blob_properties)

blob_hash = await blob_manager.get_blob_hash(blob_name)

# The expected hash is the hex encoding of the mock content MD5
expected_hash = "140cdd8fd2743d3bf1d1e24301e4a011"
assert blob_hash == expected_hash


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info.minor < 10, reason="requires Python 3.10 or higher")
async def test_create_container_upon_upload(monkeypatch, mock_env, blob_manager):