Commit
refactor: use model_file for GCS
Anush008 committed Apr 17, 2024
1 parent b3ba505 commit 1f52bf9
Showing 3 changed files with 17 additions and 32 deletions.
28 changes: 6 additions & 22 deletions fastembed/common/model_management.py
@@ -1,33 +1,17 @@
+from enum import Enum
 import os
 import shutil
 import tarfile
 from pathlib import Path
-from typing import List, Literal, Optional, Dict, Any, Tuple
+from typing import List, Optional, Dict, Any, Tuple
 
 import requests
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import RepositoryNotFoundError
 from tqdm import tqdm
 from loguru import logger
 
-SOURCE = Literal["hf", "gcs"]
-
-
-def locate_model_file(model_dir: Path, file_names: List[str]) -> Path:
-    """
-    Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used
-    by Optimum and Qdrant
-    """
-    if not model_dir.is_dir():
-        raise ValueError(f"Provided model path '{model_dir}' is not a directory.")
-
-    for file_name in file_names:
-        file_paths = [path for path in model_dir.rglob(file_name) if path.is_file()]
-
-        if file_paths:
-            return file_paths[0]
-
-    raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}")
+Source = Enum("Source", ["HF", "GCS"])
 
 
 class ModelManagement:
@@ -200,7 +184,7 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) ->
         return model_dir
 
     @classmethod
-    def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Path, SOURCE]:
+    def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Path, Source]:
         """
         Downloads a model from HuggingFace Hub or Google Cloud Storage.
@@ -232,14 +216,14 @@ def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Pa
         try:
             return Path(
                 cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir))
-            ), "hf"
+            ), Source.HF
         except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
             logger.error(
                 f"Could not download model from HuggingFace: {e}"
                 "Falling back to other sources."
             )
 
         if url_source:
-            return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)), "gcs"
+            return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)), Source.GCS
 
         raise ValueError(f"Could not download model {model['model']} from any source.")
13 changes: 7 additions & 6 deletions fastembed/common/onnx_model.py
@@ -6,7 +6,7 @@
 import numpy as np
 import onnxruntime as ort
 
-from fastembed.common.model_management import SOURCE, locate_model_file
+from fastembed.common.model_management import Source
 from fastembed.common.models import load_tokenizer
 from fastembed.common.utils import iter_batch
 from fastembed.parallel_processor import ParallelWorkerPool, Worker
@@ -42,14 +42,15 @@ def load_onnx_model(
         threads: Optional[int],
         cache_dir: Path,
         model_description: dict,
-        source: SOURCE,
+        source: Source,
     ) -> None:
-        if source == "gcs":
-            model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"])
-        elif source == "hf":
+        model_file = model_description["model_file"]
+
+        if source == Source.GCS:
+            model_path = model_dir.joinpath(model_file)
+        elif source == Source.HF:
             # For HuggingFace sources, the model file is conditionally downloaded
             repo_id = model_description["sources"]["hf"]
-            model_file = model_description["model_file"]
 
             # Some models require additional repo files.
             # For eg: intfloat/multilingual-e5-large requires the model.onnx_data file.
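With `model_file` read from the model description, the GCS branch now builds the path with a direct join rather than the removed `locate_model_file` search. A rough before/after sketch, assuming a hypothetical extracted-tarball directory (the real `model_dir` comes from `retrieve_model_gcs`):

from pathlib import Path

model_dir = Path("/tmp/fastembed_cache/fast-bge-small-en")  # hypothetical path, for illustration
model_file = "model_optimized.onnx"                         # taken from the model description

# New behavior: one exact path; a wrong name fails loudly when the file is loaded.
model_path = model_dir.joinpath(model_file)

# Old behavior: recursively search the tree and take the first candidate found.
candidates = [p for p in model_dir.rglob("model_optimized.onnx") if p.is_file()]

The join is cheaper and unambiguous, but it shifts responsibility to the registry: each GCS entry's `model_file` must now name the file actually shipped in the tarball, which is what the onnx_embedding.py updates below handle.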
8 changes: 4 additions & 4 deletions fastembed/text/onnx_embedding.py
@@ -16,7 +16,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
         },
-        "model_file": "model.onnx",
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "BAAI/bge-base-en-v1.5",
@@ -47,7 +47,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
         },
-        "model_file": "onnx/model.onnx",
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "BAAI/bge-small-en-v1.5",
@@ -67,7 +67,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
         },
-        "model_file": "onnx/model.onnx",
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "sentence-transformers/all-MiniLM-L6-v2",
@@ -176,7 +176,7 @@ def __init__(
         """
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
-
+
         model_description = self._get_model_description(model_name)
         cache_dir = define_cache_dir(cache_dir)
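Putting it together, a GCS-backed registry entry after this commit looks like the sketch below. The URL and `model_file` values are copied from the diff above; the model name is assumed for illustration. With a GCS source, `model_file` must name the exact file inside the tarball, since nothing searches for it anymore:

description = {
    "model": "BAAI/bge-small-en",  # assumed entry name, for illustration only
    "sources": {
        "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
    },
    "model_file": "model_optimized.onnx",
}

# load_onnx_model then resolves the ONNX file directly:
# model_path = model_dir.joinpath(description["model_file"])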
