From b3ba505ebdffc7c8adfff061031d5d605cf9056b Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 17 Apr 2024 01:46:50 +0530 Subject: [PATCH 1/5] feat: Quantized models --- fastembed/common/model_management.py | 12 ++--- fastembed/common/onnx_model.py | 35 ++++++++++++-- fastembed/sparse/splade_pp.py | 26 ++++++---- fastembed/text/e5_onnx_embedding.py | 3 ++ fastembed/text/jina_onnx_embedding.py | 2 + fastembed/text/onnx_embedding.py | 70 ++++++++++++++------------- tests/test_text_onnx_embeddings.py | 3 ++ 7 files changed, 97 insertions(+), 54 deletions(-) diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 6ccce324..6b21b46f 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -2,7 +2,7 @@ import shutil import tarfile from pathlib import Path -from typing import List, Optional, Dict, Any +from typing import List, Literal, Optional, Dict, Any, Tuple import requests from huggingface_hub import snapshot_download @@ -10,6 +10,8 @@ from tqdm import tqdm from loguru import logger +SOURCE = Literal["hf", "gcs"] + def locate_model_file(model_dir: Path, file_names: List[str]) -> Path: """ @@ -118,8 +120,6 @@ def download_files_from_huggingface( return snapshot_download( repo_id=hf_source_repo, allow_patterns=[ - "*.onnx", - "*.onnx_data", "config.json", "tokenizer.json", "tokenizer_config.json", @@ -200,7 +200,7 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> return model_dir @classmethod - def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path: + def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Path, SOURCE]: """ Downloads a model from HuggingFace Hub or Google Cloud Storage. @@ -232,7 +232,7 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path: try: return Path( cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir)) - ) + ), "hf" except (EnvironmentError, RepositoryNotFoundError, ValueError) as e: logger.error( f"Could not download model from HuggingFace: {e}" @@ -240,6 +240,6 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path: ) if url_source: - return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)) + return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)), "gcs" raise ValueError(f"Could not download model {model['model']} from any source.") diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py index 190393ff..0483bd81 100644 --- a/fastembed/common/onnx_model.py +++ b/fastembed/common/onnx_model.py @@ -6,11 +6,13 @@ import numpy as np import onnxruntime as ort -from fastembed.common.model_management import locate_model_file +from fastembed.common.model_management import SOURCE, locate_model_file from fastembed.common.models import load_tokenizer from fastembed.common.utils import iter_batch from fastembed.parallel_processor import ParallelWorkerPool, Worker +from huggingface_hub import hf_hub_download + # Holds type of the embedding result T = TypeVar("T") @@ -34,8 +36,33 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, """ return onnx_input - def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: int) -> None: - model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"]) + def load_onnx_model( + self, + model_dir: Path, + threads: Optional[int], + cache_dir: Path, + model_description: dict, + source: SOURCE, + ) -> None: + if 
source == "gcs": + model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"]) + elif source == "hf": + # For HuggingFace sources, the model file is conditionally downloaded + repo_id = model_description["sources"]["hf"] + model_file = model_description["model_file"] + + # Some models require additional repo files. + # For eg: intfloat/multilingual-e5-large requires the model.onnx_data file. + # These can be specified within the "additional_files" option when describing the model properties + if additional_files := model_description.get("additional_files"): + for file in additional_files: + hf_hub_download(repo_id=repo_id, filename=file, cache_dir=str(cache_dir)) + + model_path = hf_hub_download( + repo_id=repo_id, filename=model_file, cache_dir=str(cache_dir) + ) + else: + raise ValueError(f"Unknown source: {source}") # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers onnx_providers = ["CPUExecutionProvider"] @@ -50,7 +77,7 @@ def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: i so.intra_op_num_threads = threads so.inter_op_num_threads = threads - self.tokenizer = load_tokenizer(model_dir=model_dir, max_length=max_length) + self.tokenizer = load_tokenizer(model_dir=model_dir) self.model = ort.InferenceSession( str(model_path), providers=onnx_providers, sess_options=so ) diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index 92f3da88..892597cb 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -15,6 +15,7 @@ "sources": { "hf": "Qdrant/SPLADE_PP_en_v1", }, + "model_file": "model.onnx", }, { "model": "prithivida/Splade_PP_en_v1", @@ -24,6 +25,7 @@ "sources": { "hf": "Qdrant/SPLADE_PP_en_v1", }, + "model_file": "model.onnx", }, ] @@ -76,15 +78,19 @@ def __init__( """ super().__init__(model_name, cache_dir, threads, **kwargs) - - self.model_name = model_name - self._model_description = self._get_model_description(model_name) - - self._cache_dir = define_cache_dir(cache_dir) - self._model_dir = self.download_model(self._model_description, self._cache_dir) - self._max_length = 512 - - self.load_onnx_model(self._model_dir, self.threads, self._max_length) + + model_description = self._get_model_description(model_name) + cache_dir = define_cache_dir(cache_dir) + + model_dir, source = self.download_repo_files(model_description, cache_dir) + + self.load_onnx_model( + model_dir, + threads, + cache_dir, + model_description, + source, + ) def embed( self, @@ -110,7 +116,7 @@ def embed( """ yield from self._embed_documents( model_name=self.model_name, - cache_dir=str(self._cache_dir), + cache_dir=str(self.cache_dir), documents=documents, batch_size=batch_size, parallel=parallel, diff --git a/fastembed/text/e5_onnx_embedding.py b/fastembed/text/e5_onnx_embedding.py index 9c37174c..8df199d6 100644 --- a/fastembed/text/e5_onnx_embedding.py +++ b/fastembed/text/e5_onnx_embedding.py @@ -15,6 +15,8 @@ "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", "hf": "qdrant/multilingual-e5-large-onnx", }, + "model_file": "model.onnx", + "additional_files": ["model.onnx_data"], }, { "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", @@ -24,6 +26,7 @@ "sources": { "hf": "xenova/paraphrase-multilingual-mpnet-base-v2", }, + "model_file": "onnx/model.onnx", }, ] diff --git a/fastembed/text/jina_onnx_embedding.py b/fastembed/text/jina_onnx_embedding.py index 4d2ed96e..c0e7acf6 100644 --- a/fastembed/text/jina_onnx_embedding.py +++ 
b/fastembed/text/jina_onnx_embedding.py @@ -13,6 +13,7 @@ "description": "English embedding model supporting 8192 sequence length", "size_in_GB": 0.52, "sources": {"hf": "xenova/jina-embeddings-v2-base-en"}, + "model_file": "onnx/model.onnx", }, { "model": "jinaai/jina-embeddings-v2-small-en", @@ -20,6 +21,7 @@ "description": "English embedding model supporting 8192 sequence length", "size_in_GB": 0.12, "sources": {"hf": "xenova/jina-embeddings-v2-small-en"}, + "model_file": "onnx/model.onnx", }, ] diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index c82bdfe5..a7be011c 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -16,6 +16,7 @@ "sources": { "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz", }, + "model_file": "model.onnx", }, { "model": "BAAI/bge-base-en-v1.5", @@ -26,6 +27,7 @@ "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz", "hf": "qdrant/bge-base-en-v1.5-onnx-q", }, + "model_file": "model_optimized.onnx", }, { "model": "BAAI/bge-large-en-v1.5", @@ -35,6 +37,7 @@ "sources": { "hf": "qdrant/bge-large-en-v1.5-onnx", }, + "model_file": "model.onnx", }, { "model": "BAAI/bge-small-en", @@ -44,18 +47,8 @@ "sources": { "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz", }, + "model_file": "onnx/model.onnx", }, - # { - # "model": "BAAI/bge-small-en", - # "dim": 384, - # "description": "Fast English model", - # "size_in_GB": 0.2, - # "hf_sources": [], - # "compressed_url_sources": [ - # "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz", - # "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz" - # ] - # }, { "model": "BAAI/bge-small-en-v1.5", "dim": 384, @@ -64,6 +57,7 @@ "sources": { "hf": "qdrant/bge-small-en-v1.5-onnx-q", }, + "model_file": "model_optimized.onnx", }, { "model": "BAAI/bge-small-zh-v1.5", @@ -73,6 +67,7 @@ "sources": { "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz", }, + "model_file": "onnx/model.onnx", }, { "model": "sentence-transformers/all-MiniLM-L6-v2", @@ -83,6 +78,7 @@ "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz", "hf": "qdrant/all-MiniLM-L6-v2-onnx", }, + "model_file": "model.onnx", }, { "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", @@ -92,6 +88,7 @@ "sources": { "hf": "qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q", }, + "model_file": "model_optimized.onnx", }, { "model": "nomic-ai/nomic-embed-text-v1", @@ -101,6 +98,7 @@ "sources": { "hf": "nomic-ai/nomic-embed-text-v1", }, + "model_file": "onnx/model.onnx", }, { "model": "nomic-ai/nomic-embed-text-v1.5", @@ -110,6 +108,17 @@ "sources": { "hf": "nomic-ai/nomic-embed-text-v1.5", }, + "model_file": "onnx/model.onnx", + }, + { + "model": "nomic-ai/nomic-embed-text-v1.5-Q", + "dim": 768, + "description": "Quantized 8192 context length english model", + "size_in_GB": 0.13, + "sources": { + "hf": "nomic-ai/nomic-embed-text-v1.5", + }, + "model_file": "onnx/model_quantized.onnx", }, { "model": "thenlper/gte-large", @@ -119,20 +128,8 @@ "sources": { "hf": "qdrant/gte-large-onnx", }, + "model_file": "model.onnx", }, - # { - # "model": "sentence-transformers/all-MiniLM-L6-v2", - # "dim": 384, - # "description": "Sentence Transformer model, MiniLM-L6-v2", - # "size_in_GB": 0.09, - # "hf_sources": [ - # "qdrant/all-MiniLM-L6-v2-onnx" - # ], - # "compressed_url_sources": [ - # 
"https://storage.googleapis.com/qdrant-fastembed/fast-all-MiniLM-L6-v2.tar.gz", - # "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz" - # ] - # } { "model": "mixedbread-ai/mxbai-embed-large-v1", "dim": 1024, @@ -141,6 +138,7 @@ "sources": { "hf": "mixedbread-ai/mxbai-embed-large-v1", }, + "model_file": "onnx/model.onnx", }, ] @@ -178,15 +176,19 @@ def __init__( """ super().__init__(model_name, cache_dir, threads, **kwargs) - - self.model_name = model_name - self._model_description = self._get_model_description(model_name) - - self._cache_dir = define_cache_dir(cache_dir) - self._model_dir = self.download_model(self._model_description, self._cache_dir) - self._max_length = 512 - - self.load_onnx_model(self._model_dir, self.threads, self._max_length) + + model_description = self._get_model_description(model_name) + cache_dir = define_cache_dir(cache_dir) + + model_dir, source = self.download_repo_files(model_description, cache_dir) + + self.load_onnx_model( + model_dir, + threads, + cache_dir, + model_description, + source, + ) def embed( self, @@ -212,7 +214,7 @@ def embed( """ yield from self._embed_documents( model_name=self.model_name, - cache_dir=str(self._cache_dir), + cache_dir=str(self.cache_dir), documents=documents, batch_size=batch_size, parallel=parallel, diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py index be710525..994b09ea 100644 --- a/tests/test_text_onnx_embeddings.py +++ b/tests/test_text_onnx_embeddings.py @@ -28,6 +28,9 @@ ), "thenlper/gte-large": np.array([-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]), "mixedbread-ai/mxbai-embed-large-v1": np.array([0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]), + "nomic-ai/nomic-embed-text-v1.5-Q": np.array( + [-0.01554983, 0.0129992 , -0.17909265, -0.01062993, 0.00512859] + ), } From 1f52bf9767020fadafc703781710218f3ca1e802 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 17 Apr 2024 14:18:26 +0530 Subject: [PATCH 2/5] refactor: use model_file for GCS --- fastembed/common/model_management.py | 28 ++++++---------------------- fastembed/common/onnx_model.py | 13 +++++++------ fastembed/text/onnx_embedding.py | 8 ++++---- 3 files changed, 17 insertions(+), 32 deletions(-) diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 6b21b46f..7adcb81b 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -1,8 +1,9 @@ +from enum import Enum import os import shutil import tarfile from pathlib import Path -from typing import List, Literal, Optional, Dict, Any, Tuple +from typing import List, Optional, Dict, Any, Tuple import requests from huggingface_hub import snapshot_download @@ -10,24 +11,7 @@ from tqdm import tqdm from loguru import logger -SOURCE = Literal["hf", "gcs"] - - -def locate_model_file(model_dir: Path, file_names: List[str]) -> Path: - """ - Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used - by Optimum and Qdrant - """ - if not model_dir.is_dir(): - raise ValueError(f"Provided model path '{model_dir}' is not a directory.") - - for file_name in file_names: - file_paths = [path for path in model_dir.rglob(file_name) if path.is_file()] - - if file_paths: - return file_paths[0] - - raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}") +Source = Enum("Source", ["HF", "GCS"]) class ModelManagement: @@ -200,7 +184,7 @@ def retrieve_model_gcs(cls, 
model_name: str, source_url: str, cache_dir: str) -> return model_dir @classmethod - def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Path, SOURCE]: + def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Path, Source]: """ Downloads a model from HuggingFace Hub or Google Cloud Storage. @@ -232,7 +216,7 @@ def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Pa try: return Path( cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir)) - ), "hf" + ), Source.HF except (EnvironmentError, RepositoryNotFoundError, ValueError) as e: logger.error( f"Could not download model from HuggingFace: {e}" @@ -240,6 +224,6 @@ def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Pa ) if url_source: - return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)), "gcs" + return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)), Source.GCS raise ValueError(f"Could not download model {model['model']} from any source.") diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py index 0483bd81..9462ec0b 100644 --- a/fastembed/common/onnx_model.py +++ b/fastembed/common/onnx_model.py @@ -6,7 +6,7 @@ import numpy as np import onnxruntime as ort -from fastembed.common.model_management import SOURCE, locate_model_file +from fastembed.common.model_management import Source from fastembed.common.models import load_tokenizer from fastembed.common.utils import iter_batch from fastembed.parallel_processor import ParallelWorkerPool, Worker @@ -42,14 +42,15 @@ def load_onnx_model( threads: Optional[int], cache_dir: Path, model_description: dict, - source: SOURCE, + source: Source, ) -> None: - if source == "gcs": - model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"]) - elif source == "hf": + model_file = model_description["model_file"] + + if source == Source.GCS: + model_path = model_dir.joinpath(model_file) + elif source == Source.HF: # For HuggingFace sources, the model file is conditionally downloaded repo_id = model_description["sources"]["hf"] - model_file = model_description["model_file"] # Some models require additional repo files. # For eg: intfloat/multilingual-e5-large requires the model.onnx_data file. 
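For quick reference, the source dispatch this hunk settles on can be read as the standalone sketch below; `resolve_model_path` is a hypothetical helper name used only for illustration, and it simply mirrors the `Source.GCS` / `Source.HF` branches and the `additional_files` handling shown in the diff above.

```python
from enum import Enum
from pathlib import Path

from huggingface_hub import hf_hub_download

Source = Enum("Source", ["HF", "GCS"])


def resolve_model_path(
    model_dir: Path, cache_dir: Path, model_description: dict, source: Source
) -> Path:
    """Locate the ONNX file for a model, mirroring the branching in the hunk above."""
    model_file = model_description["model_file"]
    if source == Source.GCS:
        # GCS archives are already extracted into model_dir, so the file is on disk.
        return model_dir / model_file
    if source == Source.HF:
        repo_id = model_description["sources"]["hf"]
        # Fetch companion files first, e.g. model.onnx_data for multilingual-e5-large.
        for extra in model_description.get("additional_files", []):
            hf_hub_download(repo_id=repo_id, filename=extra, cache_dir=str(cache_dir))
        return Path(
            hf_hub_download(repo_id=repo_id, filename=model_file, cache_dir=str(cache_dir))
        )
    raise ValueError(f"Unknown source: {source}")
```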
diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index a7be011c..5fc447bb 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -16,7 +16,7 @@ "sources": { "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz", }, - "model_file": "model.onnx", + "model_file": "model_optimized.onnx", }, { "model": "BAAI/bge-base-en-v1.5", @@ -47,7 +47,7 @@ "sources": { "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz", }, - "model_file": "onnx/model.onnx", + "model_file": "model_optimized.onnx", }, { "model": "BAAI/bge-small-en-v1.5", @@ -67,7 +67,7 @@ "sources": { "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz", }, - "model_file": "onnx/model.onnx", + "model_file": "model_optimized.onnx", }, { "model": "sentence-transformers/all-MiniLM-L6-v2", @@ -176,7 +176,7 @@ def __init__( """ super().__init__(model_name, cache_dir, threads, **kwargs) - + model_description = self._get_model_description(model_name) cache_dir = define_cache_dir(cache_dir) From 792794248a4e71197ab12dfe56f069f6d16dcb73 Mon Sep 17 00:00:00 2001 From: George Date: Fri, 19 Apr 2024 18:47:16 +0200 Subject: [PATCH 3/5] refactoring: refactor model downloading (#209) * refactoring: refactor model downloading * refactor: update docstring Co-authored-by: Anush --- fastembed/common/model_management.py | 41 ++++++++++++++++++---------- fastembed/common/onnx_model.py | 27 ++---------------- fastembed/sparse/splade_pp.py | 12 ++++---- fastembed/text/onnx_embedding.py | 11 +++----- 4 files changed, 37 insertions(+), 54 deletions(-) diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 7adcb81b..1ff7492f 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -1,9 +1,8 @@ -from enum import Enum import os import shutil import tarfile from pathlib import Path -from typing import List, Optional, Dict, Any, Tuple +from typing import List, Optional, Dict, Any import requests from huggingface_hub import snapshot_download @@ -11,8 +10,6 @@ from tqdm import tqdm from loguru import logger -Source = Enum("Source", ["HF", "GCS"]) - class ModelManagement: @classmethod @@ -90,25 +87,33 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool @classmethod def download_files_from_huggingface( - cls, hf_source_repo: str, cache_dir: Optional[str] = None + cls, + hf_source_repo: str, + cache_dir: Optional[str] = None, + extra_patterns: Optional[List[str]] = None, ) -> str: """ Downloads a model from HuggingFace Hub. Args: hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx". cache_dir (Optional[str]): The path to the cache directory. + extra_patterns (Optional[List[str]]): extra patterns to allow in the snapshot download, typically + includes the required model files. Returns: Path: The path to the model directory. 
""" + allow_patterns = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + ] + if extra_patterns is not None: + allow_patterns.extend(extra_patterns) return snapshot_download( repo_id=hf_source_repo, - allow_patterns=[ - "config.json", - "tokenizer.json", - "tokenizer_config.json", - "special_tokens_map.json", - ], + allow_patterns=allow_patterns, cache_dir=cache_dir, ) @@ -184,7 +189,7 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> return model_dir @classmethod - def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Path, Source]: + def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path: """ Downloads a model from HuggingFace Hub or Google Cloud Storage. @@ -213,10 +218,16 @@ def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Pa url_source = model.get("sources", {}).get("url") if hf_source: + extra_patterns = [] + extra_patterns.extend([model["model_file"]]) + extra_patterns.extend(model.get("additional_files", [])) + try: return Path( - cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir)) - ), Source.HF + cls.download_files_from_huggingface( + hf_source, cache_dir=str(cache_dir), extra_patterns=extra_patterns + ) + ) except (EnvironmentError, RepositoryNotFoundError, ValueError) as e: logger.error( f"Could not download model from HuggingFace: {e}" @@ -224,6 +235,6 @@ def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Pa ) if url_source: - return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)), Source.GCS + return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)) raise ValueError(f"Could not download model {model['model']} from any source.") diff --git a/fastembed/common/onnx_model.py b/fastembed/common/onnx_model.py index 9462ec0b..9ddd50be 100644 --- a/fastembed/common/onnx_model.py +++ b/fastembed/common/onnx_model.py @@ -6,12 +6,10 @@ import numpy as np import onnxruntime as ort -from fastembed.common.model_management import Source from fastembed.common.models import load_tokenizer from fastembed.common.utils import iter_batch from fastembed.parallel_processor import ParallelWorkerPool, Worker -from huggingface_hub import hf_hub_download # Holds type of the embedding result T = TypeVar("T") @@ -39,31 +37,10 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, def load_onnx_model( self, model_dir: Path, + model_file: str, threads: Optional[int], - cache_dir: Path, - model_description: dict, - source: Source, ) -> None: - model_file = model_description["model_file"] - - if source == Source.GCS: - model_path = model_dir.joinpath(model_file) - elif source == Source.HF: - # For HuggingFace sources, the model file is conditionally downloaded - repo_id = model_description["sources"]["hf"] - - # Some models require additional repo files. - # For eg: intfloat/multilingual-e5-large requires the model.onnx_data file. 
- # These can be specified within the "additional_files" option when describing the model properties - if additional_files := model_description.get("additional_files"): - for file in additional_files: - hf_hub_download(repo_id=repo_id, filename=file, cache_dir=str(cache_dir)) - - model_path = hf_hub_download( - repo_id=repo_id, filename=model_file, cache_dir=str(cache_dir) - ) - else: - raise ValueError(f"Unknown source: {source}") + model_path = model_dir / model_file # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers onnx_providers = ["CPUExecutionProvider"] diff --git a/fastembed/sparse/splade_pp.py b/fastembed/sparse/splade_pp.py index 892597cb..f59114ca 100644 --- a/fastembed/sparse/splade_pp.py +++ b/fastembed/sparse/splade_pp.py @@ -78,18 +78,16 @@ def __init__( """ super().__init__(model_name, cache_dir, threads, **kwargs) - + model_description = self._get_model_description(model_name) cache_dir = define_cache_dir(cache_dir) - model_dir, source = self.download_repo_files(model_description, cache_dir) + model_dir = self.download_model(model_description, cache_dir) self.load_onnx_model( - model_dir, - threads, - cache_dir, - model_description, - source, + model_dir=model_dir, + model_file=model_description["model_file"], + threads=threads, ) def embed( diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index 5fc447bb..b44ce74c 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -179,15 +179,12 @@ def __init__( model_description = self._get_model_description(model_name) cache_dir = define_cache_dir(cache_dir) - - model_dir, source = self.download_repo_files(model_description, cache_dir) + model_dir = self.download_model(model_description, cache_dir) self.load_onnx_model( - model_dir, - threads, - cache_dir, - model_description, - source, + model_dir=model_dir, + model_file=model_description["model_file"], + threads=threads, ) def embed( From 778adaaaf2bde5e0c1860d57dc17e6b27d087737 Mon Sep 17 00:00:00 2001 From: Anush Date: Fri, 19 Apr 2024 22:20:29 +0530 Subject: [PATCH 4/5] Update fastembed/common/model_management.py Co-authored-by: George --- fastembed/common/model_management.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fastembed/common/model_management.py b/fastembed/common/model_management.py index 1ff7492f..f7a75be3 100644 --- a/fastembed/common/model_management.py +++ b/fastembed/common/model_management.py @@ -218,8 +218,7 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path: url_source = model.get("sources", {}).get("url") if hf_source: - extra_patterns = [] - extra_patterns.extend([model["model_file"]]) + extra_patterns = [model["model_file"]] extra_patterns.extend(model.get("additional_files", [])) try: From 0b7ff9ed355fdbf71cb5000a3a9009182d411575 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Fri, 19 Apr 2024 22:25:07 +0530 Subject: [PATCH 5/5] fix: model_file for Snowflake models --- fastembed/text/onnx_embedding.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py index c5836494..059cce20 100644 --- a/fastembed/text/onnx_embedding.py +++ b/fastembed/text/onnx_embedding.py @@ -148,6 +148,7 @@ "sources": { "hf": "snowflake/snowflake-arctic-embed-xs", }, + "model_file": "onnx/model.onnx", }, { "model": "snowflake/snowflake-arctic-embed-s", @@ -157,6 +158,7 @@ "sources": { "hf": "snowflake/snowflake-arctic-embed-s", }, + "model_file": "onnx/model.onnx", }, { 
"model": "snowflake/snowflake-arctic-embed-m", @@ -166,6 +168,7 @@ "sources": { "hf": "Snowflake/snowflake-arctic-embed-m", }, + "model_file": "onnx/model.onnx", }, { "model": "snowflake/snowflake-arctic-embed-m-long", @@ -175,6 +178,7 @@ "sources": { "hf": "snowflake/snowflake-arctic-embed-m-long", }, + "model_file": "onnx/model.onnx", }, { "model": "snowflake/snowflake-arctic-embed-l", @@ -184,6 +188,7 @@ "sources": { "hf": "snowflake/snowflake-arctic-embed-l", }, + "model_file": "onnx/model.onnx", }, ]