-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* wip: init image embeddings * new: add clip * new: fix clip text embedding * fix: fix image parallel * fix: add test images * fix: add PIL * fix: fix generics * fix: fix sparse worker * fix: fix image test path * fix: replace models repo * new: follow-up for onnx providers and local_files_only option * fix: add types, refactor a bit * refactoring: move onnxprovider type alias to types * fix: fix type alias import
- Loading branch information
Showing
28 changed files
with
920 additions
and
120 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,15 @@ | ||
import importlib.metadata | ||
|
||
from fastembed.image import ImageEmbedding | ||
from fastembed.text import TextEmbedding | ||
from fastembed.sparse import SparseTextEmbedding, SparseEmbedding | ||
|
||
|
||
try: | ||
version = importlib.metadata.version("fastembed") | ||
except importlib.metadata.PackageNotFoundError as _: | ||
version = importlib.metadata.version("fastembed-gpu") | ||
|
||
__version__ = version | ||
__all__ = ["TextEmbedding", "SparseTextEmbedding", "SparseEmbedding"] | ||
__all__ = ["TextEmbedding", "SparseTextEmbedding", "SparseEmbedding", "ImageEmbedding"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from fastembed.common.onnx_model import OnnxProvider | ||
from fastembed.common.types import OnnxProvider, ImageInput, PathInput | ||
|
||
__all__ = ["OnnxProvider"] | ||
__all__ = ["OnnxProvider", "ImageInput", "PathInput"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import os | ||
import sys | ||
from typing import Union, Iterable, Tuple, Dict, Any | ||
|
||
if sys.version_info >= (3, 10): | ||
from typing import TypeAlias | ||
else: | ||
from typing_extensions import TypeAlias | ||
|
||
|
||
PathInput: TypeAlias = Union[str, os.PathLike] | ||
ImageInput: TypeAlias = Union[PathInput, Iterable[PathInput]] | ||
|
||
OnnxProvider: TypeAlias = Union[str, Tuple[str, Dict[Any, Any]]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from fastembed.image.image_embedding import ImageEmbedding | ||
|
||
|
||
__all__ = ["ImageEmbedding"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import Any, Dict, Iterable, List, Optional, Type, Sequence | ||
|
||
import numpy as np | ||
|
||
from fastembed.common import ImageInput, OnnxProvider | ||
from fastembed.image.image_embedding_base import ImageEmbeddingBase | ||
from fastembed.image.onnx_embedding import OnnxImageEmbedding | ||
|
||
|
||
class ImageEmbedding(ImageEmbeddingBase): | ||
EMBEDDINGS_REGISTRY: List[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding] | ||
|
||
@classmethod | ||
def list_supported_models(cls) -> List[Dict[str, Any]]: | ||
""" | ||
Lists the supported models. | ||
Returns: | ||
List[Dict[str, Any]]: A list of dictionaries containing the model information. | ||
Example: | ||
``` | ||
[ | ||
{ | ||
"model": "Qdrant/clip-ViT-B-32-vision", | ||
"dim": 512, | ||
"description": "CLIP vision encoder based on ViT-B/32", | ||
"size_in_GB": 0.33, | ||
"sources": { | ||
"hf": "Qdrant/clip-ViT-B-32-vision", | ||
}, | ||
"model_file": "model.onnx", | ||
} | ||
] | ||
``` | ||
""" | ||
result = [] | ||
for embedding in cls.EMBEDDINGS_REGISTRY: | ||
result.extend(embedding.list_supported_models()) | ||
return result | ||
|
||
def __init__( | ||
self, | ||
model_name: str, | ||
cache_dir: Optional[str] = None, | ||
threads: Optional[int] = None, | ||
providers: Optional[Sequence[OnnxProvider]] = None, | ||
**kwargs, | ||
): | ||
super().__init__(model_name, cache_dir, threads, **kwargs) | ||
|
||
for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: | ||
supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() | ||
if any(model_name.lower() == model["model"].lower() for model in supported_models): | ||
self.model = EMBEDDING_MODEL_TYPE( | ||
model_name, cache_dir, threads, providers=providers, **kwargs | ||
) | ||
return | ||
|
||
raise ValueError( | ||
f"Model {model_name} is not supported in TextEmbedding." | ||
"Please check the supported models using `TextEmbedding.list_supported_models()`" | ||
) | ||
|
||
def embed( | ||
self, | ||
images: ImageInput, | ||
batch_size: int = 16, | ||
parallel: Optional[int] = None, | ||
**kwargs, | ||
) -> Iterable[np.ndarray]: | ||
""" | ||
Encode a list of documents into list of embeddings. | ||
We use mean pooling with attention so that the model can handle variable-length inputs. | ||
Args: | ||
images: Iterator of image paths or single image path to embed | ||
batch_size: Batch size for encoding -- higher values will use more memory, but be faster | ||
parallel: | ||
If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. | ||
If 0, use all available cores. | ||
If None, don't use data-parallel processing, use default onnxruntime threading instead. | ||
Returns: | ||
List of embeddings, one per document | ||
""" | ||
yield from self.model.embed(images, batch_size, parallel, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Iterable, Optional | ||
|
||
import numpy as np | ||
|
||
from fastembed.common.model_management import ModelManagement | ||
from fastembed.common.types import ImageInput | ||
|
||
|
||
class ImageEmbeddingBase(ModelManagement): | ||
def __init__( | ||
self, | ||
model_name: str, | ||
cache_dir: Optional[str] = None, | ||
threads: Optional[int] = None, | ||
**kwargs, | ||
): | ||
self.model_name = model_name | ||
self.cache_dir = cache_dir | ||
self.threads = threads | ||
self._local_files_only = kwargs.pop("local_files_only", False) | ||
|
||
def embed( | ||
self, | ||
images: ImageInput, | ||
batch_size: int = 16, | ||
parallel: Optional[int] = None, | ||
**kwargs, | ||
) -> Iterable[np.ndarray]: | ||
""" | ||
Embeds a list of images into a list of embeddings. | ||
Args: | ||
images - The list of image paths to preprocess and embed. | ||
**kwargs: Additional keyword argument to pass to the embed method. | ||
Yields: | ||
Iterable[np.ndarray]: The embeddings. | ||
""" | ||
raise NotImplementedError() |
Oops, something went wrong.