Commit

feat: Quantized models

Anush008 committed Apr 16, 2024
1 parent 9bad443 commit 49edfa3
Showing 7 changed files with 97 additions and 54 deletions.
12 changes: 6 additions & 6 deletions fastembed/common/model_management.py
@@ -2,14 +2,16 @@
 import shutil
 import tarfile
 from pathlib import Path
-from typing import List, Optional, Dict, Any
+from typing import List, Literal, Optional, Dict, Any, Tuple
 
 import requests
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import RepositoryNotFoundError
 from tqdm import tqdm
 from loguru import logger
 
+SOURCE = Literal["hf", "gcs"]
+
 
 def locate_model_file(model_dir: Path, file_names: List[str]) -> Path:
     """
@@ -118,8 +120,6 @@ def download_files_from_huggingface(
         return snapshot_download(
             repo_id=hf_source_repo,
             allow_patterns=[
-                "*.onnx",
-                "*.onnx_data",
                 "config.json",
                 "tokenizer.json",
                 "tokenizer_config.json",
@@ -200,7 +200,7 @@ def retrieve_model_gcs(cls, model_name: str, source_url: str, cache_dir: str) -> Path:
         return model_dir
 
     @classmethod
-    def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
+    def download_repo_files(cls, model: Dict[str, Any], cache_dir: Path) -> Tuple[Path, SOURCE]:
         """
         Downloads a model from HuggingFace Hub or Google Cloud Storage.
@@ -232,14 +232,14 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
             try:
                 return Path(
                     cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir))
-                )
+                ), "hf"
             except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
                 logger.error(
                     f"Could not download model from HuggingFace: {e}"
                     "Falling back to other sources."
                 )
 
         if url_source:
-            return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir))
+            return cls.retrieve_model_gcs(model["model"], url_source, str(cache_dir)), "gcs"
 
         raise ValueError(f"Could not download model {model['model']} from any source.")
35 changes: 31 additions & 4 deletions fastembed/common/onnx_model.py
@@ -6,11 +6,13 @@
 import numpy as np
 import onnxruntime as ort
 
-from fastembed.common.model_management import locate_model_file
+from fastembed.common.model_management import SOURCE, locate_model_file
 from fastembed.common.models import load_tokenizer
 from fastembed.common.utils import iter_batch
 from fastembed.parallel_processor import ParallelWorkerPool, Worker
 
+from huggingface_hub import hf_hub_download
+
 # Holds type of the embedding result
 T = TypeVar("T")

@@ -34,8 +36,33 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         """
         return onnx_input
 
-    def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: int) -> None:
-        model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"])
+    def load_onnx_model(
+        self,
+        model_dir: Path,
+        threads: Optional[int],
+        cache_dir: Path,
+        model_description: dict,
+        source: SOURCE,
+    ) -> None:
+        if source == "gcs":
+            model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"])
+        elif source == "hf":
+            # For HuggingFace sources, the model file is conditionally downloaded
+            repo_id = model_description["sources"]["hf"]
+            model_file = model_description["model_file"]
+
+            # Some models require additional repo files.
+            # For eg: intfloat/multilingual-e5-large requires the model.onnx_data file.
+            # These can be specified within the "additional_files" option when describing the model properties
+            if additional_files := model_description.get("additional_files"):
+                for file in additional_files:
+                    hf_hub_download(repo_id=repo_id, filename=file, cache_dir=str(cache_dir))
+
+            model_path = hf_hub_download(
+                repo_id=repo_id, filename=model_file, cache_dir=str(cache_dir)
+            )
+        else:
+            raise ValueError(f"Unknown source: {source}")
 
         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
         onnx_providers = ["CPUExecutionProvider"]
@@ -50,7 +77,7 @@ def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: int) -> None:
             so.intra_op_num_threads = threads
             so.inter_op_num_threads = threads
 
-        self.tokenizer = load_tokenizer(model_dir=model_dir, max_length=max_length)
+        self.tokenizer = load_tokenizer(model_dir=model_dir)
         self.model = ort.InferenceSession(
             str(model_path), providers=onnx_providers, sess_options=so
         )
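A self-contained sketch of the per-file HuggingFace path introduced above. hf_hub_download is the real huggingface_hub API; the repo id and file names mirror the multilingual-e5-large entry changed later in this commit and stand in for any registry entry:

from huggingface_hub import hf_hub_download

cache_dir = "model_cache"  # hypothetical local cache
description = {
    "sources": {"hf": "qdrant/multilingual-e5-large-onnx"},
    "model_file": "model.onnx",
    "additional_files": ["model.onnx_data"],  # external ONNX weights
}

repo_id = description["sources"]["hf"]

# Pull companion files first so ONNX Runtime finds them next to the graph.
for extra in description.get("additional_files", []):
    hf_hub_download(repo_id=repo_id, filename=extra, cache_dir=cache_dir)

# Fetch only the requested .onnx variant instead of snapshotting the repo.
model_path = hf_hub_download(
    repo_id=repo_id, filename=description["model_file"], cache_dir=cache_dir
)
print(model_path)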
26 changes: 16 additions & 10 deletions fastembed/sparse/splade_pp.py
@@ -15,6 +15,7 @@
         "sources": {
             "hf": "Qdrant/SPLADE_PP_en_v1",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "prithivida/Splade_PP_en_v1",
@@ -24,6 +25,7 @@
         "sources": {
             "hf": "Qdrant/SPLADE_PP_en_v1",
         },
+        "model_file": "model.onnx",
     },
 ]
 
@@ -76,15 +78,19 @@ def __init__(
         """
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
-        self.model_name = model_name
-        self._model_description = self._get_model_description(model_name)
-
-        self._cache_dir = define_cache_dir(cache_dir)
-        self._model_dir = self.download_model(self._model_description, self._cache_dir)
-        self._max_length = 512
-
-        self.load_onnx_model(self._model_dir, self.threads, self._max_length)
+        model_description = self._get_model_description(model_name)
+        cache_dir = define_cache_dir(cache_dir)
+
+        model_dir, source = self.download_repo_files(model_description, cache_dir)
+
+        self.load_onnx_model(
+            model_dir,
+            threads,
+            cache_dir,
+            model_description,
+            source,
+        )
 
     def embed(
         self,
@@ -110,7 +116,7 @@
         """
         yield from self._embed_documents(
             model_name=self.model_name,
-            cache_dir=str(self._cache_dir),
+            cache_dir=str(self.cache_dir),
             documents=documents,
             batch_size=batch_size,
             parallel=parallel,
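For context, a usage sketch of the SPLADE model whose registry entries gained the model_file key, assuming the public SparseTextEmbedding wrapper from the released fastembed package:

from fastembed import SparseTextEmbedding

# The registry entry above pins the exact ONNX file (model.onnx) to fetch from HF.
model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")

for emb in model.embed(["fastembed now pins model files per registry entry"]):
    # SPLADE output is sparse: vocabulary indices paired with activation weights.
    print(emb.indices[:5], emb.values[:5])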
3 changes: 3 additions & 0 deletions fastembed/text/e5_onnx_embedding.py
@@ -15,6 +15,8 @@
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
             "hf": "qdrant/multilingual-e5-large-onnx",
         },
+        "model_file": "model.onnx",
+        "additional_files": ["model.onnx_data"],
     },
     {
         "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
@@ -24,6 +26,7 @@
         "sources": {
             "hf": "xenova/paraphrase-multilingual-mpnet-base-v2",
         },
+        "model_file": "onnx/model.onnx",
     },
 ]
 
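The multilingual-e5-large graph stores its tensors in ONNX external data, which is why its entry lists model.onnx_data under additional_files. A sketch of why both files must land in the same directory; the local layout shown is hypothetical:

import onnxruntime as ort

# Hypothetical layout after download; both files in one directory:
#   model_cache/model.onnx       <- graph, references external weight storage
#   model_cache/model.onnx_data  <- tensor data, resolved relative to the graph
session = ort.InferenceSession(
    "model_cache/model.onnx", providers=["CPUExecutionProvider"]
)
print([inp.name for inp in session.get_inputs()])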
2 changes: 2 additions & 0 deletions fastembed/text/jina_onnx_embedding.py
@@ -13,13 +13,15 @@
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.52,
         "sources": {"hf": "xenova/jina-embeddings-v2-base-en"},
+        "model_file": "onnx/model.onnx",
     },
     {
         "model": "jinaai/jina-embeddings-v2-small-en",
         "dim": 512,
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.12,
         "sources": {"hf": "xenova/jina-embeddings-v2-small-en"},
+        "model_file": "onnx/model.onnx",
     },
 ]
 
70 changes: 36 additions & 34 deletions fastembed/text/onnx_embedding.py
@@ -16,6 +16,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "BAAI/bge-base-en-v1.5",
@@ -26,6 +27,7 @@
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
             "hf": "qdrant/bge-base-en-v1.5-onnx-q",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "BAAI/bge-large-en-v1.5",
@@ -35,6 +37,7 @@
         "sources": {
             "hf": "qdrant/bge-large-en-v1.5-onnx",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "BAAI/bge-small-en",
@@ -44,18 +47,8 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
         },
+        "model_file": "onnx/model.onnx",
     },
-    # {
-    # "model": "BAAI/bge-small-en",
-    # "dim": 384,
-    # "description": "Fast English model",
-    # "size_in_GB": 0.2,
-    # "hf_sources": [],
-    # "compressed_url_sources": [
-    # "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz",
-    # "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz"
-    # ]
-    # },
     {
         "model": "BAAI/bge-small-en-v1.5",
         "dim": 384,
@@ -64,6 +57,7 @@
         "sources": {
             "hf": "qdrant/bge-small-en-v1.5-onnx-q",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "BAAI/bge-small-zh-v1.5",
@@ -73,6 +67,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
         },
+        "model_file": "onnx/model.onnx",
     },
     {
         "model": "sentence-transformers/all-MiniLM-L6-v2",
@@ -83,6 +78,7 @@
             "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
             "hf": "qdrant/all-MiniLM-L6-v2-onnx",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
@@ -92,6 +88,7 @@
         "sources": {
             "hf": "qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "nomic-ai/nomic-embed-text-v1",
@@ -101,6 +98,7 @@
         "sources": {
             "hf": "nomic-ai/nomic-embed-text-v1",
         },
+        "model_file": "onnx/model.onnx",
     },
     {
         "model": "nomic-ai/nomic-embed-text-v1.5",
@@ -110,6 +108,17 @@
         "sources": {
             "hf": "nomic-ai/nomic-embed-text-v1.5",
         },
+        "model_file": "onnx/model.onnx",
+    },
+    {
+        "model": "nomic-ai/nomic-embed-text-v1.5-Q",
+        "dim": 768,
+        "description": "Quantized 8192 context length english model",
+        "size_in_GB": 0.13,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1.5",
+        },
+        "model_file": "onnx/model_quantized.onnx",
     },
     {
         "model": "thenlper/gte-large",
@@ -119,20 +128,8 @@
         "sources": {
             "hf": "qdrant/gte-large-onnx",
         },
+        "model_file": "model_optimized.onnx",
     },
-    # {
-    # "model": "sentence-transformers/all-MiniLM-L6-v2",
-    # "dim": 384,
-    # "description": "Sentence Transformer model, MiniLM-L6-v2",
-    # "size_in_GB": 0.09,
-    # "hf_sources": [
-    # "qdrant/all-MiniLM-L6-v2-onnx"
-    # ],
-    # "compressed_url_sources": [
-    # "https://storage.googleapis.com/qdrant-fastembed/fast-all-MiniLM-L6-v2.tar.gz",
-    # "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz"
-    # ]
-    # }
     {
         "model": "mixedbread-ai/mxbai-embed-large-v1",
         "dim": 1024,
@@ -141,6 +138,7 @@
         "sources": {
             "hf": "mixedbread-ai/mxbai-embed-large-v1",
         },
+        "model_file": "onnx/model.onnx",
     },
 ]

@@ -178,15 +176,19 @@ def __init__(
         """
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
-        self.model_name = model_name
-        self._model_description = self._get_model_description(model_name)
-
-        self._cache_dir = define_cache_dir(cache_dir)
-        self._model_dir = self.download_model(self._model_description, self._cache_dir)
-        self._max_length = 512
-
-        self.load_onnx_model(self._model_dir, self.threads, self._max_length)
+        model_description = self._get_model_description(model_name)
+        cache_dir = define_cache_dir(cache_dir)
+
+        model_dir, source = self.download_repo_files(model_description, cache_dir)
+
+        self.load_onnx_model(
+            model_dir,
+            threads,
+            cache_dir,
+            model_description,
+            source,
+        )
 
     def embed(
         self,
@@ -212,7 +214,7 @@
         """
         yield from self._embed_documents(
             model_name=self.model_name,
-            cache_dir=str(self._cache_dir),
+            cache_dir=str(self.cache_dir),
            documents=documents,
             batch_size=batch_size,
             parallel=parallel,
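A usage sketch for the quantized entry added above, assuming fastembed's public TextEmbedding wrapper; it reuses the nomic-embed-text-v1.5 HF repo while model_file selects the quantized graph:

from fastembed import TextEmbedding

# Same HF repo as nomic-embed-text-v1.5; model_file picks the quantized
# onnx/model_quantized.onnx graph, a 0.13 GB download per the entry above.
model = TextEmbedding(model_name="nomic-ai/nomic-embed-text-v1.5-Q")

embeddings = list(model.embed(["quantized models shrink the download"]))
print(len(embeddings[0]))  # 768, matching the entry's dim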
3 changes: 3 additions & 0 deletions tests/test_text_onnx_embeddings.py
@@ -28,6 +28,9 @@
     ),
     "thenlper/gte-large": np.array([-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]),
     "mixedbread-ai/mxbai-embed-large-v1": np.array([0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]),
+    "nomic-ai/nomic-embed-text-v1.5-Q": np.array(
+        [-0.04514299, -0.00462366, -0.18909897, -0.0071826, 0.00678478]
+    ),
 }
 
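The canonical values above are typically compared against the leading dimensions of a freshly computed embedding. A hedged sketch of such a check; the probe text, tolerance, and test name are assumptions, not the repository's actual test:

import numpy as np
from fastembed import TextEmbedding

CANONICAL_VECTOR_VALUES = {
    "nomic-ai/nomic-embed-text-v1.5-Q": np.array(
        [-0.04514299, -0.00462366, -0.18909897, -0.0071826, 0.00678478]
    ),
}


def test_quantized_canonical_values():
    model_name = "nomic-ai/nomic-embed-text-v1.5-Q"
    model = TextEmbedding(model_name=model_name)
    # Embed a fixed probe document and compare the first few dimensions.
    embedding = next(iter(model.embed(["hello world"])))
    expected = CANONICAL_VECTOR_VALUES[model_name]
    assert np.allclose(embedding[: expected.shape[0]], expected, atol=1e-3)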
