feat: Quantized models #201

Merged 7 commits on Apr 26, 2024
Changes from 3 commits
49 changes: 22 additions & 27 deletions fastembed/common/model_management.py
@@ -11,23 +11,6 @@
 from loguru import logger
 
 
-def locate_model_file(model_dir: Path, file_names: List[str]) -> Path:
-    """
-    Find model path for both TransformerJS style `onnx` subdirectory structure and direct model weights structure used
-    by Optimum and Qdrant
-    """
-    if not model_dir.is_dir():
-        raise ValueError(f"Provided model path '{model_dir}' is not a directory.")
-
-    for file_name in file_names:
-        file_paths = [path for path in model_dir.rglob(file_name) if path.is_file()]
-
-        if file_paths:
-            return file_paths[0]
-
-    raise ValueError(f"Could not find either of {', '.join(file_names)} in {model_dir}")
-
-
 class ModelManagement:
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Any]]:
@@ -104,27 +87,33 @@ def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool
 
     @classmethod
     def download_files_from_huggingface(
-        cls, hf_source_repo: str, cache_dir: Optional[str] = None
+        cls,
+        hf_source_repo: str,
+        cache_dir: Optional[str] = None,
+        extra_patterns: Optional[List[str]] = None,
     ) -> str:
         """
         Downloads a model from HuggingFace Hub.
         Args:
             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
             cache_dir (Optional[str]): The path to the cache directory.
+            extra_patterns (Optional[List[str]]): extra patterns to allow in the snapshot download, typically
+                includes the required model files.
         Returns:
             Path: The path to the model directory.
         """
+        allow_patterns = [
+            "config.json",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "special_tokens_map.json",
+        ]
+        if extra_patterns is not None:
+            allow_patterns.extend(extra_patterns)
 
         return snapshot_download(
            repo_id=hf_source_repo,
-            allow_patterns=[
-                "*.onnx",
-                "*.onnx_data",
-                "config.json",
-                "tokenizer.json",
-                "tokenizer_config.json",
-                "special_tokens_map.json",
-            ],
+            allow_patterns=allow_patterns,
            cache_dir=cache_dir,
         )

@@ -229,9 +218,15 @@ def download_model(cls, model: Dict[str, Any], cache_dir: Path) -> Path:
         url_source = model.get("sources", {}).get("url")
 
         if hf_source:
+            extra_patterns = []
+            extra_patterns.extend([model["model_file"]])
+            extra_patterns.extend(model.get("additional_files", []))
+
             try:
                 return Path(
-                    cls.download_files_from_huggingface(hf_source, cache_dir=str(cache_dir))
+                    cls.download_files_from_huggingface(
+                        hf_source, cache_dir=str(cache_dir), extra_patterns=extra_patterns
+                    )
                 )
             except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
                 logger.error(
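For reference, a minimal sketch of what this change does at download time: the registry entry's "model_file" and "additional_files" become extra allow-patterns for huggingface_hub's snapshot_download, so only the one requested ONNX file (plus tokenizer/config files) is fetched instead of every *.onnx in the repo. snapshot_download is the real huggingface_hub API; the model dict below is an illustrative entry, not copied from the registry.

# Sketch of the new download flow, assuming an illustrative registry entry.
from huggingface_hub import snapshot_download

model = {
    "sources": {"hf": "nomic-ai/nomic-embed-text-v1.5"},
    "model_file": "onnx/model_quantized.onnx",
}

extra_patterns = [model["model_file"], *model.get("additional_files", [])]
allow_patterns = [
    "config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    *extra_patterns,
]

model_dir = snapshot_download(
    repo_id=model["sources"]["hf"],
    allow_patterns=allow_patterns,  # fetch only what this model needs
)
print(model_dir)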
13 changes: 9 additions & 4 deletions fastembed/common/onnx_model.py
@@ -6,11 +6,11 @@
 import numpy as np
 import onnxruntime as ort
 
-from fastembed.common.model_management import locate_model_file
 from fastembed.common.models import load_tokenizer
 from fastembed.common.utils import iter_batch
 from fastembed.parallel_processor import ParallelWorkerPool, Worker
 
+
 # Holds type of the embedding result
 T = TypeVar("T")
 
@@ -34,8 +34,13 @@ def _preprocess_onnx_input(self, onnx_input: Dict[str, np.ndarray]) -> Dict[str,
         """
         return onnx_input
 
-    def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: int) -> None:
-        model_path = locate_model_file(model_dir, ["model.onnx", "model_optimized.onnx"])
+    def load_onnx_model(
+        self,
+        model_dir: Path,
+        model_file: str,
+        threads: Optional[int],
+    ) -> None:
+        model_path = model_dir / model_file
 
         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
         onnx_providers = ["CPUExecutionProvider"]
@@ -50,7 +55,7 @@ def load_onnx_model(self, model_dir: Path, threads: Optional[int], max_length: i
             so.intra_op_num_threads = threads
             so.inter_op_num_threads = threads
 
-        self.tokenizer = load_tokenizer(model_dir=model_dir, max_length=max_length)
+        self.tokenizer = load_tokenizer(model_dir=model_dir)
         self.model = ort.InferenceSession(
             str(model_path), providers=onnx_providers, sess_options=so
         )
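The net effect: directory scanning via rglob is gone, and the caller states the exact ONNX file. A quick sketch of the resulting lookup, using only the onnxruntime API; the cache path and file name below are hypothetical example values.

# Sketch: the ONNX file is joined directly onto the model directory
# instead of being discovered by locate_model_file. Paths are hypothetical.
from pathlib import Path

import onnxruntime as ort

model_dir = Path("~/.cache/fastembed/models--qdrant--gte-large-onnx").expanduser()
model_file = "model.onnx"  # value comes from the registry entry's "model_file"

model_path = model_dir / model_file
session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
print([inp.name for inp in session.get_inputs()])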
18 changes: 11 additions & 7 deletions fastembed/sparse/splade_pp.py
@@ -15,6 +15,7 @@
         "sources": {
             "hf": "Qdrant/SPLADE_PP_en_v1",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "prithivida/Splade_PP_en_v1",
@@ -24,6 +25,7 @@
         "sources": {
             "hf": "Qdrant/SPLADE_PP_en_v1",
         },
+        "model_file": "model.onnx",
     },
 ]

@@ -77,14 +79,16 @@ def __init__(
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
-        self.model_name = model_name
-        self._model_description = self._get_model_description(model_name)
+        model_description = self._get_model_description(model_name)
+        cache_dir = define_cache_dir(cache_dir)
 
-        self._cache_dir = define_cache_dir(cache_dir)
-        self._model_dir = self.download_model(self._model_description, self._cache_dir)
-        self._max_length = 512
+        model_dir = self.download_model(model_description, cache_dir)
 
-        self.load_onnx_model(self._model_dir, self.threads, self._max_length)
+        self.load_onnx_model(
+            model_dir=model_dir,
+            model_file=model_description["model_file"],
+            threads=threads,
+        )
 
     def embed(
         self,
@@ -110,7 +114,7 @@
         """
         yield from self._embed_documents(
             model_name=self.model_name,
-            cache_dir=str(self._cache_dir),
+            cache_dir=str(self.cache_dir),
             documents=documents,
             batch_size=batch_size,
             parallel=parallel,
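Since SPLADE++ now goes through the same "model_file" plumbing, nothing changes at the call site. A minimal sketch, assuming the public SparseTextEmbedding API; the printed values are illustrative.

# Sketch: the sparse model resolves "model_file" from its registry entry
# internally, so usage stays a one-liner.
from fastembed import SparseTextEmbedding

model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")
sparse = next(iter(model.embed(["fastembed is light and fast"])))
print(sparse.indices[:5], sparse.values[:5])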
3 changes: 3 additions & 0 deletions fastembed/text/e5_onnx_embedding.py
@@ -15,6 +15,8 @@
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz",
             "hf": "qdrant/multilingual-e5-large-onnx",
         },
+        "model_file": "model.onnx",
+        "additional_files": ["model.onnx_data"],
     },
     {
         "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
@@ -24,6 +26,7 @@
         "sources": {
             "hf": "xenova/paraphrase-multilingual-mpnet-base-v2",
         },
+        "model_file": "onnx/model.onnx",
     },
 ]
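The "additional_files" entry matters here because multilingual-e5-large stores its weights as ONNX external data: model.onnx holds the graph and references model.onnx_data, which must sit next to it or session creation fails. A small sanity-check sketch; the cache path is hypothetical.

# Sketch: verify both the graph and its external-data file were downloaded.
from pathlib import Path

model_dir = Path("~/.cache/fastembed/models--qdrant--multilingual-e5-large-onnx").expanduser()

assert (model_dir / "model.onnx").is_file()
assert (model_dir / "model.onnx_data").is_file(), "external weights missing"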
2 changes: 2 additions & 0 deletions fastembed/text/jina_onnx_embedding.py
@@ -13,13 +13,15 @@
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.52,
         "sources": {"hf": "xenova/jina-embeddings-v2-base-en"},
+        "model_file": "onnx/model.onnx",
     },
     {
         "model": "jinaai/jina-embeddings-v2-small-en",
         "dim": 512,
         "description": "English embedding model supporting 8192 sequence length",
         "size_in_GB": 0.12,
         "sources": {"hf": "xenova/jina-embeddings-v2-small-en"},
+        "model_file": "onnx/model.onnx",
     },
 ]
63 changes: 31 additions & 32 deletions fastembed/text/onnx_embedding.py
@@ -16,6 +16,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "BAAI/bge-base-en-v1.5",
@@ -26,6 +27,7 @@
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
             "hf": "qdrant/bge-base-en-v1.5-onnx-q",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "BAAI/bge-large-en-v1.5",
@@ -35,6 +37,7 @@
         "sources": {
             "hf": "qdrant/bge-large-en-v1.5-onnx",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "BAAI/bge-small-en",
@@ -44,18 +47,8 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz",
         },
+        "model_file": "model_optimized.onnx",
     },
-    # {
-    #     "model": "BAAI/bge-small-en",
-    #     "dim": 384,
-    #     "description": "Fast English model",
-    #     "size_in_GB": 0.2,
-    #     "hf_sources": [],
-    #     "compressed_url_sources": [
-    #         "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz",
-    #         "https://storage.googleapis.com/qdrant-fastembed/BAAI-bge-small-en.tar.gz"
-    #     ]
-    # },
     {
         "model": "BAAI/bge-small-en-v1.5",
         "dim": 384,
@@ -64,6 +57,7 @@
         "sources": {
             "hf": "qdrant/bge-small-en-v1.5-onnx-q",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "BAAI/bge-small-zh-v1.5",
@@ -73,6 +67,7 @@
         "sources": {
             "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "sentence-transformers/all-MiniLM-L6-v2",
@@ -83,6 +78,7 @@
             "url": "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
             "hf": "qdrant/all-MiniLM-L6-v2-onnx",
         },
+        "model_file": "model.onnx",
     },
     {
         "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
@@ -92,6 +88,7 @@
         "sources": {
             "hf": "qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q",
         },
+        "model_file": "model_optimized.onnx",
     },
     {
         "model": "nomic-ai/nomic-embed-text-v1",
@@ -101,6 +98,7 @@
         "sources": {
             "hf": "nomic-ai/nomic-embed-text-v1",
         },
+        "model_file": "onnx/model.onnx",
     },
     {
         "model": "nomic-ai/nomic-embed-text-v1.5",
@@ -110,6 +108,17 @@
         "sources": {
             "hf": "nomic-ai/nomic-embed-text-v1.5",
         },
+        "model_file": "onnx/model.onnx",
+    },
+    {
+        "model": "nomic-ai/nomic-embed-text-v1.5-Q",
+        "dim": 768,
+        "description": "Quantized 8192 context length english model",
+        "size_in_GB": 0.13,
+        "sources": {
+            "hf": "nomic-ai/nomic-embed-text-v1.5",
+        },
+        "model_file": "onnx/model_quantized.onnx",
     },
     {
         "model": "thenlper/gte-large",
@@ -119,20 +128,8 @@
         "sources": {
             "hf": "qdrant/gte-large-onnx",
         },
+        "model_file": "model.onnx",
     },
-    # {
-    #     "model": "sentence-transformers/all-MiniLM-L6-v2",
-    #     "dim": 384,
-    #     "description": "Sentence Transformer model, MiniLM-L6-v2",
-    #     "size_in_GB": 0.09,
-    #     "hf_sources": [
-    #         "qdrant/all-MiniLM-L6-v2-onnx"
-    #     ],
-    #     "compressed_url_sources": [
-    #         "https://storage.googleapis.com/qdrant-fastembed/fast-all-MiniLM-L6-v2.tar.gz",
-    #         "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz"
-    #     ]
-    # }
     {
         "model": "mixedbread-ai/mxbai-embed-large-v1",
         "dim": 1024,
@@ -141,6 +138,7 @@
         "sources": {
             "hf": "mixedbread-ai/mxbai-embed-large-v1",
         },
+        "model_file": "onnx/model.onnx",
     },
 ]

@@ -179,14 +177,15 @@ def __init__(
 
         super().__init__(model_name, cache_dir, threads, **kwargs)
 
-        self.model_name = model_name
-        self._model_description = self._get_model_description(model_name)
+        model_description = self._get_model_description(model_name)
+        cache_dir = define_cache_dir(cache_dir)
+        model_dir = self.download_model(model_description, cache_dir)
 
-        self._cache_dir = define_cache_dir(cache_dir)
-        self._model_dir = self.download_model(self._model_description, self._cache_dir)
-        self._max_length = 512
-
-        self.load_onnx_model(self._model_dir, self.threads, self._max_length)
+        self.load_onnx_model(
+            model_dir=model_dir,
+            model_file=model_description["model_file"],
+            threads=threads,
+        )
 
     def embed(
         self,
@@ -212,7 +211,7 @@
         """
         yield from self._embed_documents(
             model_name=self.model_name,
-            cache_dir=str(self._cache_dir),
+            cache_dir=str(self.cache_dir),
             documents=documents,
             batch_size=batch_size,
             parallel=parallel,
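With "model_file" on every entry, a quantized variant is just a registry entry pointing at a different file in the same HF repo, as the new nomic-ai/nomic-embed-text-v1.5-Q entry shows. A sketch of inspecting this through the public API; the printed output is illustrative.

# Sketch: every supported-model entry now carries "model_file"; the
# quantized nomic variant reuses the v1.5 repo with a different file.
from fastembed import TextEmbedding

for entry in TextEmbedding.list_supported_models():
    if entry["model"].startswith("nomic-ai/"):
        print(entry["model"], "->", entry["model_file"])
# nomic-ai/nomic-embed-text-v1 -> onnx/model.onnx
# nomic-ai/nomic-embed-text-v1.5 -> onnx/model.onnx
# nomic-ai/nomic-embed-text-v1.5-Q -> onnx/model_quantized.onnx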
3 changes: 3 additions & 0 deletions tests/test_text_onnx_embeddings.py
@@ -28,6 +28,9 @@
     ),
     "thenlper/gte-large": np.array([-0.01920587, 0.00113156, -0.00708992, -0.00632304, -0.04025577]),
     "mixedbread-ai/mxbai-embed-large-v1": np.array([0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]),
+    "nomic-ai/nomic-embed-text-v1.5-Q": np.array(
+        [-0.01554983, 0.0129992, -0.17909265, -0.01062993, 0.00512859]
+    ),
 }
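These canonical values pin down the first few dimensions of a known embedding so quantization regressions surface immediately. A hedged sketch of the comparison such a test performs; the input text and tolerance here are illustrative, not taken from the suite.

# Sketch: embed a fixed document and compare the leading dimensions against
# the canonical vector.
import numpy as np
from fastembed import TextEmbedding

expected = np.array([-0.01554983, 0.0129992, -0.17909265, -0.01062993, 0.00512859])

model = TextEmbedding(model_name="nomic-ai/nomic-embed-text-v1.5-Q")
embedding = next(iter(model.embed(["hello world"])))
assert np.allclose(embedding[: expected.size], expected, atol=1e-3)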

