feat: Quantized models (#201)
* feat: Quantized models

* refactor: use model_file for GCS

* refactoring: refactor model downloading (#209)

* refactoring: refactor model downloading

* refactor: update docstring

Co-authored-by: Anush <[email protected]>

* Update fastembed/common/model_management.py

Co-authored-by: George <[email protected]>

* fix: model_file for Snowflake models

---------

Co-authored-by: George <[email protected]>
Anush008 and joein authored Apr 26, 2024
1 parent 466886a commit ab7a99a
Showing 7 changed files with 85 additions and 70 deletions.
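
The common thread across these files: each supported-model entry now names the exact ONNX file to load, instead of the library scanning directories for candidate weights. A minimal sketch of the description shape the diffs below converge on — the field values are copied from entries added in this commit, and the elided "model" key is left out rather than guessed:

model_description = {
    # The "model" key is not visible in this view; the remaining fields are
    # copied from the multilingual-e5-large entry added below.
    "sources": {
        "hf": "qdrant/multilingual-e5-large-onnx",
        "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
    },
    "model_file": "model.onnx",               # exact ONNX file to load
    "additional_files": ["model.onnx_data"],  # optional extras, e.g. external weights
}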
48 changes: 21 additions & 27 deletions fastembed/common/model_management.py
@@ -11,23 +11,6 @@
 from loguru import logger
 
 
-def locate_model_file(model_dir: Path, file_names: List[str]) -> Path:
-    """
-    Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used
-    by Optimum and Qdrant
-    """
-    if not model_dir.is_dir():
-        raise ValueError(f"Provided model path '{model_dir}' is not a directory.")
-
-    for file_name in file_names:
-        file_paths = [path for path in model_dir.rglob(file_name) if path.is_file()]
-
-        if file_paths:
-            return file_paths[0]
-
-    raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}")
-
-
 class ModelManagement:
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Any]]:
@@ -104,27 +87,33 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
 
     @classmethod
     def download_files_from_huggingface(
-        cls, hf_source_repo: str, cache_dir: Optional[str] = None
+        cls,
+        hf_source_repo: str,
+        cache_dir: Optional[str] = None,
+        extra_patterns: Optional[List[str]] = None,
     ) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
+            extra_patterns (Optional[List[str]]): extra patterns to allow in the snapshot download, typically
+                includes the required model files.
         Returns:
             Path: The path to the model directory.
         """
+        allow_patterns = [
+            "config.json",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "special_tokens_map.json",
+        ]
+        if extra_patterns is not None:
+            allow_patterns.extend(extra_patterns)
+
         return snapshot_download(
             repo_id=hf_source_repo,
-            allow_patterns=[
-                "*.onnx",
-                "*.onnx_data",
-                "config.json",
-                "tokenizer.json",
-                "tokenizer_config.json",
-                "special_tokens_map.json",
-            ],
+            allow_patterns=allow_patterns,
             cache_dir=cache_dir,
         )
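
For orientation, a hedged usage sketch of the refactored downloader — the repo id, cache path, and patterns here are illustrative; in the library they come from a model description's "model_file" and "additional_files":

from fastembed.common.model_management import ModelManagement

# Illustrative call, not verbatim library usage: only the tokenizer/config
# files plus the explicitly requested weight files are downloaded.
model_dir = ModelManagement.download_files_from_huggingface(
    hf_source_repo="Qdrant/SPLADE_PP_en_v1",
    cache_dir="/tmp/fastembed_cache",
    extra_patterns=["model.onnx"],
)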

@@ -229,9 +218,14 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
         url_source = model.get("sources", {}).get("url")
 
         if hf_source:
+            extra_patterns = [model["model_file"]]
+            extra_patterns.extend(model.get("additional_files", []))
+
             try:
                 return Path(
-                    cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir))
+                    cls.download_files_from_huggingface(
+                        hf_source, cache_dir=str(cache_dir), extra_patterns=extra_patterns
+                    )
                 )
             except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
                 logger.error(
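
Worked through with a concrete entry, the pattern assembly above behaves like this (values taken from the multilingual-e5-large entry added later in this diff):

# Worked example of the extra_patterns assembly in download_model.
model = {
    "model_file": "model.onnx",
    "additional_files": ["model.onnx_data"],
}
extra_patterns = [model["model_file"]]
extra_patterns.extend(model.get("additional_files", []))
assert extra_patterns == ["model.onnx", "model.onnx_data"]
# Entries without "additional_files" simply yield [model["model_file"]].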
13 changes: 9 additions & 4 deletions fastembed/common/onnx_model.py
@@ -6,11 +6,11 @@
 import numpy as np
 import onnxruntime as ort
 
-from fastembed.common.model_management import locate_model_file
 from fastembed.common.models import load_tokenizer
 from fastembed.common.utils import iter_batch
 from fastembed.parallel_processor import ParallelWorkerPool, Worker
 
+
 # Holds type of the embedding result
 T = TypeVar("T")
 
@@ -34,8 +34,13 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str,
         """
         return onnx_input
 
-    def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: int) -> None:
-        model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"])
+    def load_onnx_model(
+        self,
+        model_dir: Path,
+        model_file: str,
+        threads: Optional[int],
+    ) -> None:
+        model_path = model_dir / model_file
 
         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
         onnx_providers = ["CPUExecutionProvider"]
@@ -50,7 +55,7 @@ def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: i
             so.intra_op_num_threads = threads
             so.inter_op_num_threads = threads
 
-        self.tokenizer = load_tokenizer(model_dir=model_dir, max_length=max_length)
+        self.tokenizer = load_tokenizer(model_dir=model_dir)
         self.model = ort.InferenceSession(
             str(model_path), providers=onnx_providers, sess_options=so
         )
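
Because the lookup is now a plain path join, nested "model_file" values such as "onnx/model.onnx" resolve into subdirectories with no globbing involved. A small self-contained sketch (the cache path is hypothetical):

from pathlib import Path

model_dir = Path("/cache/xenova--jina-embeddings-v2-small-en")  # hypothetical cache path
assert model_dir / "model.onnx" == Path("/cache/xenova--jina-embeddings-v2-small-en/model.onnx")
assert (model_dir / "onnx/model.onnx").parent.name == "onnx"  # nested files resolve naturally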
18 changes: 11 additions & 7 deletions fastembed/sparse/splade_pp.py
@@ -15,6 +15,7 @@
         "sources": {
             "hf": "Qdrant/SPLADE_PP_en_v1",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "prithivida/Splade_PP_en_v1",
@@ -24,6 +25,7 @@
         "sources": {
             "hf": "Qdrant/SPLADE_PP_en_v1",
         },
+        "model_file": "model.onnx",
     },
 ]
 
@@ -77,14 +79,16 @@ def __init__(
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
-        self.model_name = model_name
-        self._model_description = self._get_model_description(model_name)
+        model_description = self._get_model_description(model_name)
+        cache_dir = define_cache_dir(cache_dir)
 
-        self._cache_dir = define_cache_dir(cache_dir)
-        self._model_dir = self.download_model(self._model_description, self._cache_dir)
-        self._max_length = 512
+        model_dir = self.download_model(model_description, cache_dir)
 
-        self.load_onnx_model(self._model_dir, self.threads, self._max_length)
+        self.load_onnx_model(
+            model_dir=model_dir,
+            model_file=model_description["model_file"],
+            threads=threads,
+        )
 
     def embed(
         self,
@@ -110,7 +114,7 @@
         """
         yield from self._embed_documents(
             model_name=self.model_name,
-            cache_dir=str(self._cache_dir),
+            cache_dir=str(self.cache_dir),
             documents=documents,
             batch_size=batch_size,
             parallel=parallel,
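
None of this changes the public call pattern for sparse models. A hedged end-to-end sketch, assuming SparseTextEmbedding is the public entry point for these SPLADE models (as in fastembed 0.2.x) and that results expose indices/values attributes:

from fastembed import SparseTextEmbedding  # assumed public entry point

model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
embeddings = list(model.embed(["model files are now resolved explicitly"]))
print(embeddings[0].indices[:5], embeddings[0].values[:5])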
3 changes: 3 additions & 0 deletions fastembed/text/e5_onnx_embedding.py
@@ -15,6 +15,8 @@
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
             "hf": "qdrant/multilingual-e5-large-onnx",
         },
+        "model_file": "model.onnx",
+        "additional_files": ["model.onnx_data"],
     },
     {
         "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
@@ -24,6 +26,7 @@
         "sources": {
             "hf": "xenova/paraphrase-multilingual-mpnet-base-v2",
         },
+        "model_file": "onnx/model.onnx",
     },
 ]
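
The "additional_files" entry matters for the e5-large model because its ONNX graph stores weights externally in model.onnx_data, which onnxruntime resolves relative to the graph file; both must therefore land in the same snapshot directory. A hedged sketch of loading such a model (the cache path is hypothetical):

from pathlib import Path
import onnxruntime as ort

# Expected layout after download (illustrative):
#   <model_dir>/model.onnx       <- graph, selected via "model_file"
#   <model_dir>/model.onnx_data  <- external weights, via "additional_files"
model_dir = Path("/cache/qdrant--multilingual-e5-large-onnx")  # hypothetical
session = ort.InferenceSession(
    str(model_dir / "model.onnx"), providers=["CPUExecutionProvider"]
)  # onnxruntime picks up model.onnx_data from the same directory automatically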

2 changes: 2 additions & 0 deletions fastembed/text/jina_onnx_embedding.py
@@ -13,13 +13,15 @@
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.52,
         "sources": {"hf": "xenova/jina-embeddings-v2-base-en"},
+        "model_file": "onnx/model.onnx",
     },
     {
         "model": "jinaai/jina-embeddings-v2-small-en",
         "dim": 512,
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.12,
         "sources": {"hf": "xenova/jina-embeddings-v2-small-en"},
+        "model_file": "onnx/model.onnx",
     },
 ]

(The remaining changed files are not expanded in this view.)
