new: allow users to override providers (#214)
* new: add gpu support, allow users to override providers

* fix: update poetry.lock

* fix: fix type hint for 3.8

* [readme] Remove similar work

* [README] Add GPU support for FastEmbed library

* [README]  Add device check

* fix: revert changes to pyproject and lock, update readme

* Update poetry.lock

* new: add type alias for providers, add explicit providers to embeddings

---------

Co-authored-by: Nirant Kasliwal <[email protected]>
Co-authored-by: Nirant <[email protected]>
3 people authored May 3, 2024
1 parent da603b8 commit d8c5920
Showing 7 changed files with 81 additions and 17 deletions.
39 changes: 32 additions & 7 deletions README.md
@@ -14,12 +14,18 @@ The default text embedding (`TextEmbedding`) model is Flag Embedding, presented

## 🚀 Installation

To install the FastEmbed library, pip works:
To install the FastEmbed library, pip works best. You can install it with or without GPU support:

```bash
pip install fastembed
```

### ⚡️ With GPU

```bash
pip install fastembed-gpu
```

## 📖 Quickstart

```python
@@ -42,6 +48,23 @@ embeddings_list = list(embedding_model.embed(documents))
len(embeddings_list[0]) # Vector of 384 dimensions
```

### ⚡️ FastEmbed on a GPU

FastEmbed supports running on GPU devices via the `fastembed-gpu` package.
Make sure the CPU-only `fastembed` package is not installed alongside it, as the two can interfere with each other.

```bash
pip install fastembed-gpu
```

```python
from fastembed import TextEmbedding

embedding_model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5", providers=["CUDAExecutionProvider"])
print("The model BAAI/bge-small-en-v1.5 is ready to use on a GPU.")

```
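
As a quick sanity check, you can ask ONNX Runtime directly which execution providers are available before constructing a model. This is a minimal sketch using `ort.get_available_providers()`, the same call `load_onnx_model` uses internally in the diff below:

```python
import onnxruntime as ort

# With fastembed-gpu installed correctly, this list includes
# "CUDAExecutionProvider" alongside "CPUExecutionProvider".
print(ort.get_available_providers())
```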

## Usage with Qdrant

Installation with Qdrant Client in Python:
@@ -50,7 +73,13 @@ Installation with Qdrant Client in Python:
pip install qdrant-client[fastembed]
```

You might have to use ```pip install 'qdrant-client[fastembed]'``` on zsh.
or

```bash
pip install qdrant-client[fastembed-gpu]
```

On zsh you might have to quote the extras: ```pip install 'qdrant-client[fastembed]'```.

```python
from qdrant_client import QdrantClient
@@ -85,8 +114,4 @@ search_result = client.query(
query_text="This is a query document"
)
print(search_result)
```

#### Similar Work

Ilyas M. wrote about using [FlagEmbeddings with Optimum](https://twitter.com/IlysMoutawwakil/status/1705215192425288017) over CUDA.
```
3 changes: 3 additions & 0 deletions fastembed/common/__init__.py
@@ -0,0 +1,3 @@
from fastembed.common.onnx_model import OnnxProvider

__all__ = ["OnnxProvider"]
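
The exported `OnnxProvider` alias (defined in `onnx_model.py` below as `Union[str, Tuple[str, Dict[Any, Any]]]`) means a provider can be passed either as a bare name or as a `(name, options)` pair, mirroring what ONNX Runtime itself accepts. A minimal sketch of both forms; the `device_id` option here is ONNX Runtime's CUDA provider option, shown as an illustrative assumption:

```python
from fastembed import TextEmbedding

# Bare provider name (requires a CUDA-enabled onnxruntime build):
model = TextEmbedding(providers=["CUDAExecutionProvider"])

# (name, options) pair; the options dict is forwarded to onnxruntime:
model = TextEmbedding(
    providers=[("CUDAExecutionProvider", {"device_id": 0})]
)
```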
28 changes: 26 additions & 2 deletions fastembed/common/onnx_model.py
@@ -1,7 +1,19 @@
import os
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
Dict,
Generic,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
Sequence,
)

import numpy as np
import onnxruntime as ort
@@ -14,6 +26,8 @@
# Holds type of the embedding result
T = TypeVar("T")

OnnxProvider = Union[str, Tuple[str, Dict[Any, Any]]]


class OnnxModel(Generic[T]):
@classmethod
@@ -39,11 +53,21 @@ def load_onnx_model(
model_dir: Path,
model_file: str,
threads: Optional[int],
providers: Optional[Sequence[OnnxProvider]] = None,
) -> None:
model_path = model_dir / model_file

# List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
onnx_providers = ["CPUExecutionProvider"]

onnx_providers = ["CPUExecutionProvider"] if providers is None else list(providers)
available_providers = ort.get_available_providers()
for provider in onnx_providers:
# check that the requested providers are available in this build
provider_name = provider if isinstance(provider, str) else provider[0]
if provider_name not in available_providers:
raise ValueError(
f"Provider {provider_name} is not available. Available providers: {available_providers}"
)

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
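With this availability check in place, requesting a provider that the installed `onnxruntime` build does not support fails fast instead of silently falling back to CPU. A sketch of the expected behavior on a CPU-only install:

```python
from fastembed import TextEmbedding

try:
    TextEmbedding(providers=["CUDAExecutionProvider"])
except ValueError as e:
    # "Provider CUDAExecutionProvider is not available.
    #  Available providers: ['CPUExecutionProvider', ...]"
    print(e)
```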
8 changes: 6 additions & 2 deletions fastembed/sparse/sparse_text_embedding.py
@@ -1,5 +1,6 @@
from typing import List, Type, Dict, Any, Union, Iterable, Optional
from typing import List, Type, Dict, Any, Union, Iterable, Optional, Sequence

from fastembed.common import OnnxProvider
from fastembed.sparse.sparse_embedding_base import SparseTextEmbeddingBase, SparseEmbedding
from fastembed.sparse.splade_pp import SpladePP

@@ -42,14 +43,17 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
**kwargs,
):
super().__init__(model_name, cache_dir, threads, **kwargs)

for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
if any(model_name.lower() == model["model"].lower() for model in supported_models):
self.model = EMBEDDING_MODEL_TYPE(model_name, cache_dir, threads, **kwargs)
self.model = EMBEDDING_MODEL_TYPE(
model_name, cache_dir, threads, providers=providers, **kwargs
)
return

raise ValueError(
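`SparseTextEmbedding` now forwards `providers` to the underlying model in the same way. A minimal sketch, assuming the top-level `SparseTextEmbedding` export and a registered SPLADE model name (check `list_supported_models()` for the current list):

```python
from fastembed import SparseTextEmbedding

# The model name is illustrative; pick one from
# SparseTextEmbedding.list_supported_models().
sparse_model = SparseTextEmbedding(
    model_name="prithivida/Splade_PP_en_v1",
    providers=["CPUExecutionProvider"],
)
```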
6 changes: 4 additions & 2 deletions fastembed/sparse/splade_pp.py
@@ -1,8 +1,8 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type, Sequence

import numpy as np

from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxProvider
from fastembed.common.utils import define_cache_dir
from fastembed.sparse.sparse_embedding_base import SparseEmbedding, SparseTextEmbeddingBase

@@ -63,6 +63,7 @@ def __init__(
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
**kwargs,
):
"""
@@ -90,6 +91,7 @@
model_dir=model_dir,
model_file=model_description["model_file"],
threads=threads,
providers=providers,
)

def embed(
6 changes: 4 additions & 2 deletions fastembed/text/onnx_embedding.py
@@ -1,8 +1,8 @@
from typing import Dict, Optional, Tuple, Union, Iterable, Type, List, Any
from typing import Dict, Optional, Tuple, Union, Iterable, Type, List, Any, Sequence

import numpy as np

from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker
from fastembed.common.onnx_model import OnnxModel, EmbeddingWorker, OnnxProvider
from fastembed.common.models import normalize
from fastembed.common.utils import define_cache_dir
from fastembed.text.text_embedding_base import TextEmbeddingBase
@@ -211,6 +211,7 @@ def __init__(
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
**kwargs,
):
"""
@@ -237,6 +238,7 @@
model_dir=model_dir,
model_file=model_description["model_file"],
threads=threads,
providers=providers,
)

def embed(
8 changes: 6 additions & 2 deletions fastembed/text/text_embedding.py
@@ -1,7 +1,8 @@
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from typing import Any, Dict, Iterable, List, Optional, Type, Union, Sequence

import numpy as np

from fastembed.common import OnnxProvider
from fastembed.text.e5_onnx_embedding import E5OnnxEmbedding
from fastembed.text.jina_onnx_embedding import JinaOnnxEmbedding
from fastembed.text.onnx_embedding import OnnxTextEmbedding
@@ -49,14 +50,17 @@ def __init__(
model_name: str = "BAAI/bge-small-en-v1.5",
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence[OnnxProvider]] = None,
**kwargs,
):
super().__init__(model_name, cache_dir, threads, **kwargs)

for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
supported_models = EMBEDDING_MODEL_TYPE.list_supported_models()
if any(model_name.lower() == model["model"].lower() for model in supported_models):
self.model = EMBEDDING_MODEL_TYPE(model_name, cache_dir, threads, **kwargs)
self.model = EMBEDDING_MODEL_TYPE(
model_name, cache_dir, threads, providers=providers, **kwargs
)
return

raise ValueError(
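As with `SparseTextEmbedding`, the registry loop here matches `model_name` case-insensitively and instantiates the first model class that supports it, now passing `providers` through explicitly. A small usage sketch tying the pieces together:

```python
from fastembed import TextEmbedding

# Case-insensitive registry lookup; providers are forwarded to the
# ONNX session created by the selected model class.
model = TextEmbedding(
    model_name="baai/bge-small-en-v1.5",
    providers=["CPUExecutionProvider"],
)
embeddings = list(model.embed(["hello world"]))
```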
