diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index 70da2a22..380782f7 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -96,7 +96,7 @@ def normalize( def resize( - image: Image, + image: Image.Image, size: Union[int, tuple[int, int]], resample: Image.Resampling = Image.Resampling.BILINEAR, ) -> Image: diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 2b943dbb..494fa0d0 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -116,7 +116,7 @@ def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]): @staticmethod def _get_resize(transforms: list[Transform], config: dict[str, Any]): mode = config.get("image_processor_type", "CLIPImageProcessor") - if mode == "CLIPImageProcessor": + if mode == "CLIPImageProcessor" or mode == "SiglipImageProcessor": if config.get("do_resize", False): size = config["size"] if "shortest_edge" in size: @@ -161,7 +161,7 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]): @staticmethod def _get_center_crop(transforms: list[Transform], config: dict[str, Any]): mode = config.get("image_processor_type", "CLIPImageProcessor") - if mode == "CLIPImageProcessor": + if mode == "CLIPImageProcessor" or mode == "SiglipImageProcessor": if config.get("do_center_crop", False): crop_size = config["crop_size"] if isinstance(crop_size, int): diff --git a/fastembed/late_interaction/colbert.py b/fastembed/late_interaction/colbert.py index 4d65fc29..27de10ad 100644 --- a/fastembed/late_interaction/colbert.py +++ b/fastembed/late_interaction/colbert.py @@ -7,7 +7,7 @@ from fastembed.common import OnnxProvider from fastembed.common.onnx_model import OnnxOutputContext from fastembed.common.utils import define_cache_dir -from fastembed.late_interaction.late_interaction_embedding_base import ( +from fastembed.late_interaction.late_interaction_text_embedding_base import ( LateInteractionTextEmbeddingBase, ) from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker diff --git a/fastembed/late_interaction/colpali.py b/fastembed/late_interaction/colpali.py new file mode 100644 index 00000000..63c5eb90 --- /dev/null +++ b/fastembed/late_interaction/colpali.py @@ -0,0 +1,364 @@ +from typing import Any, Iterable, Optional, Sequence, Union, List, Dict, Type + +import numpy as np +from sys import maxsize +from fastembed.common import OnnxProvider +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.image.onnx_image_model import OnnxImageModel +from fastembed.late_interaction.late_interaction_image_embedding_base import ( + LateInteractionImageEmbeddingBase, +) +from PIL import Image +from fastembed.text.onnx_text_model import OnnxTextModel, TextEmbeddingWorker +import contextlib +from fastembed.common import ImageInput +from fastembed.common.preprocessor_utils import load_preprocessor + + +supported_colpali_models = [ + { + "model": "akshayballal/colpali-v1.2-merged", + "dim": 128, + "description": "Late interaction model for visual document retrieval, Multimodal (image, text), ColBERT-compatible multi-vector output, 2024.", + "license": "mit", + "size_in_GB": 6.08, + "sources": { + "hf": "akshayballal/colpali-v1.2-merged-onnx", + }, + "additional_files": [ + "model.onnx_data", + "tokenizer.json", + "tokenizer_config.json", + "config.json", + ], + "model_file": "model.onnx", + } +] + + +class ColPali( + LateInteractionImageEmbeddingBase,
OnnxTextModel[np.ndarray], OnnxImageModel[np.ndarray] +): + DOCUMENT_MARKER_TOKEN_ID = 2 + QUERY_PREFIX = "Query: " + BOS_TOKEN = "<bos>" + PAD_TOKEN = "<pad>" + QUERY_MARKER_TOKEN_ID = [2, 9413] + image_placeholder_size = (3, 448, 448) + EMPTY_TEXT_PLACEHOLDER = np.array([257152] * 1024 + [2, 50721, 573, 2416, 235265, 108]) + EVEN_ATTENTION_MASK = np.array([1] * 1030) + + def _post_process_onnx_output( + self, + output: OnnxOutputContext, + ) -> Iterable[np.ndarray]: + """ + Post-process the ONNX model output to convert it into a usable format. + + Args: + output (OnnxOutputContext): The raw output from the ONNX model. + + Returns: + Iterable[np.ndarray]: Post-processed output as NumPy arrays. + """ + return output.model_output.astype(np.float32) + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """ + Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + """ + return supported_colpali_models + + def _preprocess_queries(self, documents: list[str]) -> list[str]: + """ + Preprocess the input text queries by adding special tokens and padding. + + Args: + documents (list[str]): List of text queries. + + Returns: + list[str]: Preprocessed text queries. + """ + texts_query: list[str] = [] + + for query in documents: + query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10 + query += "\n" + + texts_query.append(query) + return texts_query + + def _preprocess_images_input( + self, inputs: list[ImageInput], **kwargs: Any + ) -> dict[str, np.ndarray]: + """ + Preprocess the input images for ONNX model inference. + + Args: + inputs (list[ImageInput]): List of image inputs. + **kwargs: Additional preprocessing arguments. + + Returns: + dict[str, np.ndarray]: Preprocessed image inputs as a dictionary. + """ + with contextlib.ExitStack(): + image_files = [ + Image.open(image) if not isinstance(image, Image.Image) else image + for image in inputs + ] + encoded = self.processor(image_files) + onnx_input = self._build_onnx_input(encoded) + onnx_input = self._preprocess_image_input(onnx_input, **kwargs) + return onnx_input + + def embed( + self, + inputs: Union[ImageInput, str], + batch_size: int = 16, + parallel: Optional[int] = None, + is_doc: bool = False, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Generate embeddings for the given input, either images or text. + + Args: + inputs (Union[ImageInput, str]): Input data (images or text). + batch_size (int, optional): Batch size for embedding. Defaults to 16. + parallel (Optional[int], optional): Number of parallel threads. Defaults to None. + is_doc (bool, optional): If True, treat the inputs as text queries and run the text pipeline; otherwise treat them as images. Defaults to False. + **kwargs: Additional arguments for embedding. + + Yields: + np.ndarray: Multi-vector embedding, one array per input.
+ """ + if is_doc: + yield from self._embed_documents( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + documents=inputs, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + else: + # onnx_input = self._preprocess_images_input(inputs, **kwargs) + yield from self._embed_images( + model_name=self.model_name, + cache_dir=str(self.cache_dir), + images=inputs, + batch_size=batch_size, + parallel=parallel, + providers=self.providers, + cuda=self.cuda, + device_ids=self.device_ids, + **kwargs, + ) + + def onnx_embed(self, inputs: Union[ImageInput, str], **kwargs) -> OnnxOutputContext: + """ + Embed inputs using the ONNX model. + + Args: + inputs (Union[ImageInput, str]): Input data (images or text). + **kwargs: Additional arguments for embedding. + + Returns: + OnnxOutputContext: Embedding output context. + """ + if isinstance(inputs[0], str): + return self.onnx_embed_text(inputs, **kwargs) + else: + return self.onnx_embed_image(inputs, **kwargs) + + def onnx_embed_image(self, images: List[ImageInput], **kwargs) -> OnnxOutputContext: + """ + Embed images using the ONNX model. + + Args: + images (List[ImageInput]): List of image inputs. + **kwargs: Additional arguments for embedding. + + Returns: + OnnxOutputContext: Embedding output context for images. + """ + with contextlib.ExitStack(): + image_files = [ + Image.open(image) if not isinstance(image, Image.Image) else image + for image in images + ] + encoded = self.processor(image_files) + onnx_input = self._build_onnx_input(encoded) + onnx_input = self._preprocess_onnx_image_input(onnx_input) + model_output = self.model.run(None, onnx_input) + embeddings = model_output[0].reshape(len(images), -1, self.model_description["dim"]) + return OnnxOutputContext(model_output=embeddings) + + def onnx_embed_text( + self, + documents: List[str], + **kwargs, + ) -> OnnxOutputContext: + """ + Embed text using the ONNX model. + + Args: + documents (List[str]): List of text documents. + **kwargs: Additional arguments for embedding. + + Returns: + OnnxOutputContext: Embedding output context for text. + """ + documents = self._preprocess_queries(documents) + encoded = self.tokenize(documents, **kwargs) + input_ids = np.array([self.QUERY_MARKER_TOKEN_ID + e.ids[2:] for e in encoded]) + + attention_mask = np.array([e.attention_mask for e in encoded]) + onnx_input = {"input_ids": np.array(input_ids, dtype=np.int64)} + onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs) + onnx_input["attention_mask"] = attention_mask + model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) + return OnnxOutputContext( + model_output=model_output[0], + attention_mask=onnx_input.get("attention_mask", attention_mask), + input_ids=onnx_input.get("input_ids", input_ids), + ) + + def _preprocess_onnx_image_input( + self, onnx_input: Dict[str, np.ndarray], **kwargs + ) -> Dict[str, np.ndarray]: + """ + Add placeholders for text input when processing image data for ONNX. + + Args: + onnx_input (Dict[str, np.ndarray]): Preprocessed image inputs. + **kwargs: Additional arguments. + + Returns: + Dict[str, np.ndarray]: ONNX input with text placeholders. 
+ """ + onnx_input["input_ids"] = np.array( + [self.EMPTY_TEXT_PLACEHOLDER for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array( + [self.EVEN_ATTENTION_MASK for _ in onnx_input["input_ids"]] + ) + return onnx_input + + def _preprocess_onnx_text_input( + self, onnx_input: Dict[str, np.ndarray], **kwargs + ) -> Dict[str, np.ndarray]: + """ + Add placeholders for image input when processing text data for ONNX. + + Args: + onnx_input (Dict[str, np.ndarray]): Preprocessed text inputs. + **kwargs: Additional arguments. + + Returns: + Dict[str, np.ndarray]: ONNX input with image placeholders. + """ + empty_image_placeholder = np.zeros(self.image_placeholder_size, dtype=np.float32) + onnx_input["pixel_values"] = np.array( + [empty_image_placeholder for _ in onnx_input["input_ids"]] + ) + onnx_input["attention_mask"] = np.array([[1] for _ in onnx_input["input_ids"]]) + return onnx_input + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + device_id: Optional[int] = None, + **kwargs, + ): + """ + Initialize the ColPali model. + + Args: + model_name (str): Name of the model to use. + cache_dir (Optional[str], optional): Directory for caching model files. Defaults to None. + threads (Optional[int], optional): Number of threads for inference. Defaults to None. + providers (Optional[Sequence[OnnxProvider]], optional): ONNX providers for model execution. Defaults to None. + cuda (bool, optional): Whether to use CUDA for inference. Defaults to False. + device_ids (Optional[list[int]], optional): List of CUDA device IDs. Defaults to None. + lazy_load (bool, optional): Whether to lazily load the model. Defaults to False. + device_id (Optional[int], optional): Specific device ID to use. Defaults to None. + **kwargs: Additional arguments for model initialization. + """ + super().__init__(model_name, cache_dir, threads, **kwargs) + self.model_description = self._get_model_description(model_name) + self._model_dir = self.download_model( + self.model_description, self.cache_dir, local_files_only=self._local_files_only + ) + self.providers = providers + self.lazy_load = lazy_load + self.cuda = cuda + self.device_ids = device_ids + if device_id is not None: + self.device_id = device_id + elif self.device_ids is not None: + self.device_id = self.device_ids[0] + else: + self.device_id = None + if not self.lazy_load: + self.load_onnx_model() + + self.processor = load_preprocessor(model_dir=self._model_dir) + + def load_onnx_model(self) -> None: + """ + Load the ONNX model for inference. + """ + self._load_onnx_model( + model_dir=self._model_dir, + model_file=self.model_description["model_file"], + threads=self.threads, + providers=self.providers, + cuda=self.cuda, + device_id=self.device_id, + ) + self.tokenizer.enable_truncation(max_length=maxsize) + + @classmethod + def _get_worker_class(cls) -> Type[TextEmbeddingWorker]: + """ + Get the worker class for text/image embedding. + + Returns: + Type[TextEmbeddingWorker]: The worker class. + """ + return ColPaliEmbeddingWorker + + +class ColPaliEmbeddingWorker(TextEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs) -> ColPali: + """ + Initialize the ColPali embedding worker. + + Args: + model_name (str): Name of the model to use. + cache_dir (str): Directory for caching model files. 
+ **kwargs: Additional arguments for initialization. + + Returns: + ColPali: Initialized ColPali model instance. + """ + return ColPali( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/fastembed/late_interaction/late_interaction_image_embedding.py b/fastembed/late_interaction/late_interaction_image_embedding.py new file mode 100644 index 00000000..748d861e --- /dev/null +++ b/fastembed/late_interaction/late_interaction_image_embedding.py @@ -0,0 +1,113 @@ +from typing import Any, Iterable, Optional, Sequence, Type, Union + +import numpy as np + +from fastembed.common import OnnxProvider +from fastembed.late_interaction.colpali import ColPali +from fastembed.late_interaction.late_interaction_image_embedding_base import ( + LateInteractionImageEmbeddingBase, +) + + +class LateInteractionImageEmbedding(LateInteractionImageEmbeddingBase): + EMBEDDINGS_REGISTRY: list[Type[LateInteractionImageEmbeddingBase]] = [ColPali] + + @classmethod + def list_supported_models(cls) -> list[dict[str, Any]]: + """ + Lists the supported models. + + Returns: + list[dict[str, Any]]: A list of dictionaries containing the model information. + + Example: + ``` + [ + { + "model": "akshayballal/colpali-v1.2-merged", + "dim": 128, + "description": "Late interaction model for visual document retrieval, Multimodal (image, text), ColBERT-compatible multi-vector output, 2024.", + "license": "mit", + "size_in_GB": 6.08, + "sources": { + "hf": "akshayballal/colpali-v1.2-merged-onnx", + }, + "model_file": "model.onnx", + }, + ] + ``` + """ + result = [] + for embedding in cls.EMBEDDINGS_REGISTRY: + result.extend(embedding.list_supported_models()) + return result + + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + providers: Optional[Sequence[OnnxProvider]] = None, + cuda: bool = False, + device_ids: Optional[list[int]] = None, + lazy_load: bool = False, + **kwargs, + ): + super().__init__(model_name, cache_dir, threads, **kwargs) + for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY: + supported_models = EMBEDDING_MODEL_TYPE.list_supported_models() + if any(model_name.lower() == model["model"].lower() for model in supported_models): + self.model = EMBEDDING_MODEL_TYPE( + model_name, + cache_dir, + threads=threads, + providers=providers, + cuda=cuda, + device_ids=device_ids, + lazy_load=lazy_load, + **kwargs, + ) + return + + raise ValueError( + f"Model {model_name} is not supported in LateInteractionImageEmbedding. " + "Please check the supported models using `LateInteractionImageEmbedding.list_supported_models()`" + ) + + def embed( + self, + documents: Union[str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + **kwargs, + ) -> Iterable[np.ndarray]: + """ + Encode a list of images into a list of late interaction (multi-vector) embeddings. + + Args: + documents: Iterator of images (file paths or PIL images) or a single image to embed + batch_size: Batch size for encoding -- higher values will use more memory, but be faster + parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + + Returns: + List of embeddings, one per document + """ + yield from self.model.embed(documents, batch_size, parallel, **kwargs) + + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds queries + + Args: + query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g.
list of queries. + + Returns: + Iterable[np.ndarray]: The embeddings. + """ + + # This is model-specific, so that different models can have specialized implementations + yield from self.model.query_embed(query, **kwargs) diff --git a/fastembed/late_interaction/late_interaction_image_embedding_base.py b/fastembed/late_interaction/late_interaction_image_embedding_base.py new file mode 100644 index 00000000..e4ae8996 --- /dev/null +++ b/fastembed/late_interaction/late_interaction_image_embedding_base.py @@ -0,0 +1,62 @@ +from typing import Iterable, Optional, Union + +import numpy as np + +from fastembed.common.model_management import ModelManagement +from fastembed.common.types import ImageInput + + +class LateInteractionImageEmbeddingBase(ModelManagement): + def __init__( + self, + model_name: str, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + **kwargs, + ): + self.model_name = model_name + self.cache_dir = cache_dir + self.threads = threads + self._local_files_only = kwargs.pop("local_files_only", False) + + def embed( + self, + images: Union[ImageInput, Iterable[ImageInput], str, Iterable[str]], + batch_size: int = 256, + parallel: Optional[int] = None, + is_doc: bool = False, + **kwargs, + ) -> Iterable[np.ndarray]: + raise NotImplementedError() + + def image_embed(self, images: Iterable[ImageInput], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds a list of image passages into a list of embeddings. + + Args: + images (Iterable[str]): The list of images to embed. + **kwargs: Additional keyword argument to pass to the embed method. + + Yields: + Iterable[np.ndarray]: The embeddings. + """ + + # This is model-specific, so that different models can have specialized implementations + yield from self.embed(images, **kwargs) + + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: + """ + Embeds queries + + Args: + query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries. + + Returns: + Iterable[np.ndarray]: The embeddings. 
+ """ + + # This is model-specific, so that different models can have specialized implementations + if isinstance(query, str): + yield from self.embed([query], is_doc=True, **kwargs) + if isinstance(query, Iterable): + yield from self.embed(query, is_doc=True, **kwargs) diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 58c88411..dc7a719c 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -5,7 +5,7 @@ from fastembed.common import OnnxProvider from fastembed.late_interaction.colbert import Colbert from fastembed.late_interaction.jina_colbert import JinaColbert -from fastembed.late_interaction.late_interaction_embedding_base import ( +from fastembed.late_interaction.late_interaction_text_embedding_base import ( LateInteractionTextEmbeddingBase, ) diff --git a/fastembed/late_interaction/late_interaction_embedding_base.py b/fastembed/late_interaction/late_interaction_text_embedding_base.py similarity index 94% rename from fastembed/late_interaction/late_interaction_embedding_base.py rename to fastembed/late_interaction/late_interaction_text_embedding_base.py index 64fba498..2a587e01 100644 --- a/fastembed/late_interaction/late_interaction_embedding_base.py +++ b/fastembed/late_interaction/late_interaction_text_embedding_base.py @@ -42,9 +42,7 @@ def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]: # This is model-specific, so that different models can have specialized implementations yield from self.embed(texts, **kwargs) - def query_embed( - self, query: Union[str, Iterable[str]], **kwargs - ) -> Iterable[np.ndarray]: + def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]: """ Embeds queries diff --git a/fastembed/text/onnx_text_model.py b/fastembed/text/onnx_text_model.py index 95301985..ba3e1516 100644 --- a/fastembed/text/onnx_text_model.py +++ b/fastembed/text/onnx_text_model.py @@ -79,7 +79,6 @@ def onnx_embed( onnx_input["token_type_ids"] = np.array( [np.zeros(len(e), dtype=np.int64) for e in input_ids], dtype=np.int64 ) - onnx_input = self._preprocess_onnx_input(onnx_input, **kwargs) model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input) diff --git a/tests/test_image_onnx_embeddings.py b/tests/test_image_onnx_embeddings.py index 78194caf..5fbf8e3f 100644 --- a/tests/test_image_onnx_embeddings.py +++ b/tests/test_image_onnx_embeddings.py @@ -44,7 +44,6 @@ def test_embedding(): embeddings = list(model.embed(images)) embeddings = np.stack(embeddings, axis=0) assert embeddings.shape == (len(images), dim) - canonical_vector = CANONICAL_VECTOR_VALUES[model_desc["model"]] assert np.allclose( diff --git a/tests/test_late_interaction_image_embeddings.py b/tests/test_late_interaction_image_embeddings.py new file mode 100644 index 00000000..7d048e36 --- /dev/null +++ b/tests/test_late_interaction_image_embeddings.py @@ -0,0 +1,156 @@ +import os + +import numpy as np +import pytest + +from fastembed.late_interaction.late_interaction_image_embedding import ( + LateInteractionImageEmbedding, +) +from tests.utils import delete_model_cache +from tests.config import TEST_MISC_DIR +from PIL import Image + +# vectors are abridged and rounded for brevity +CANONICAL_COLUMN_VALUES = { + "akshayballal/colpali-v1.2-merged": np.array( + [ + [ + [0.015, 0.051, 0.059, 0.026, -0.061, -0.027, -0.014], + [-0.22, -0.111, 0.046, 0.081, -0.048, -0.052, 
-0.086], + [-0.184, -0.131, 0.004, 0.062, -0.038, -0.059, -0.127], + [-0.209, -0.113, 0.015, 0.059, -0.035, -0.035, -0.072], + [-0.031, -0.044, 0.092, -0.005, 0.006, -0.057, -0.061], + [-0.18, -0.039, 0.031, 0.003, 0.083, -0.041, 0.088], + [-0.091, 0.023, 0.116, -0.02, 0.039, -0.064, -0.026], + ] + ] + ), +} + +CANONICAL_QUERY_VALUES = { + "akshayballal/colpali-v1.2-merged": np.array( + [ + [0.158, -0.02, 0.1, -0.023, 0.045, 0.031, 0.071], + [-0.074, -0.111, 0.065, -0.0, -0.089, -0.003, -0.099], + [-0.034, -0.014, 0.174, -0.063, -0.09, -0.036, 0.064], + [-0.07, -0.014, 0.186, -0.013, -0.021, -0.062, 0.107], + [-0.085, 0.025, 0.179, -0.101, 0.036, -0.089, 0.098], + [-0.058, 0.031, 0.18, -0.078, 0.023, -0.119, 0.131], + [-0.067, 0.038, 0.188, -0.079, -0.001, -0.123, 0.127], + [-0.063, 0.037, 0.204, -0.069, 0.003, -0.118, 0.134], + [-0.054, 0.036, 0.212, -0.072, -0.001, -0.117, 0.133], + [-0.044, 0.03, 0.218, -0.077, -0.003, -0.107, 0.139], + [-0.037, 0.033, 0.22, -0.088, 0.0, -0.095, 0.146], + [-0.031, 0.041, 0.213, -0.092, 0.001, -0.088, 0.147], + [-0.026, 0.047, 0.204, -0.089, -0.002, -0.084, 0.144], + [-0.027, 0.051, 0.199, -0.084, -0.007, -0.083, 0.14], + [-0.031, 0.056, 0.19, -0.082, -0.011, -0.086, 0.135], + [-0.008, 0.108, 0.144, -0.095, -0.018, -0.086, 0.085], + ] + ), +} + +queries = ["hello world", "flag embedding"] +images = [ + TEST_MISC_DIR / "image.jpeg", + str(TEST_MISC_DIR / "small_image.jpeg"), + Image.open((TEST_MISC_DIR / "small_image.jpeg")), +] + + +def test_batch_embedding(): + is_ci = os.getenv("CI") + docs_to_embed = images + + for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionImageEmbedding(model_name=model_name) + result = list(model.embed(docs_to_embed, batch_size=2)) + + for value in result: + batch_size, token_num, abridged_dim = expected_result.shape + assert np.allclose(value[:token_num, :abridged_dim], expected_result, atol=1e-3) + break + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding(): + is_ci = os.getenv("CI") + + docs_to_embed = images + + for model_name, expected_result in CANONICAL_COLUMN_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionImageEmbedding(model_name=model_name) + result = next(iter(model.embed(docs_to_embed, batch_size=6))) + batch_size, token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_single_embedding_query(): + is_ci = os.getenv("CI") + + queries_to_embed = queries + + for model_name, expected_result in CANONICAL_QUERY_VALUES.items(): + print("evaluating", model_name) + model = LateInteractionImageEmbedding(model_name=model_name) + result = next(iter(model.query_embed(queries_to_embed))) + token_num, abridged_dim = expected_result.shape + assert np.allclose(result[:token_num, :abridged_dim], expected_result, atol=2e-3) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +def test_parallel_processing(): + is_ci = os.getenv("CI") + + model = LateInteractionImageEmbedding(model_name="akshayballal/colpali-v1.2-merged") + + token_dim = 128 + docs = ["hello world", "flag embedding"] * 100 + embeddings = list(model.embed(docs, batch_size=10, parallel=2)) + embeddings = np.stack(embeddings, axis=0) + + embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None)) + embeddings_2 = np.stack(embeddings_2, axis=0) + + embeddings_3 = 
list(model.embed(docs, batch_size=10, parallel=0)) + embeddings_3 = np.stack(embeddings_3, axis=0) + + assert embeddings.shape[0] == len(docs) and embeddings.shape[-1] == token_dim + assert np.allclose(embeddings, embeddings_2, atol=1e-3) + assert np.allclose(embeddings, embeddings_3, atol=1e-3) + + if is_ci: + delete_model_cache(model.model._model_dir) + + +@pytest.mark.parametrize( + "model_name", + ["akshayballal/colpali-v1.2-merged"], +) +def test_lazy_load(model_name): + is_ci = os.getenv("CI") + + model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) + assert not hasattr(model.model, "model") + + docs = ["hello world", "flag embedding"] + list(model.embed(docs)) + assert hasattr(model.model, "model") + + model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) + list(model.query_embed(docs)) + + model = LateInteractionImageEmbedding(model_name=model_name, lazy_load=True) + list(model.embed(docs)) + + if is_ci: + delete_model_cache(model.model._model_dir) diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_text_embeddings.py similarity index 100% rename from tests/test_late_interaction_embeddings.py rename to tests/test_late_interaction_text_embeddings.py
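Usage note (not part of the patch): a minimal sketch of how the API added in this PR is exercised, mirroring the new tests above. The image paths and the query string are illustrative placeholders, and the `maxsim_score` helper is an assumption about how the multi-vector outputs would typically be scored (ColBERT-style late interaction), not something this diff ships.

```python
import numpy as np

from fastembed.late_interaction.late_interaction_image_embedding import (
    LateInteractionImageEmbedding,
)

# Downloads ~6 GB of ONNX weights on first use (see supported_colpali_models above).
model = LateInteractionImageEmbedding(model_name="akshayballal/colpali-v1.2-merged")

# Hypothetical inputs: file paths (or PIL images) for documents, plain strings for queries.
image_embeddings = list(model.embed(["page_1.jpeg", "page_2.jpeg"], batch_size=2))
query_embeddings = list(model.query_embed("what is late interaction?"))

# Each item is a multi-vector embedding of shape (num_tokens, 128) for this model.
print(image_embeddings[0].shape, query_embeddings[0].shape)


def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """ColBERT-style late interaction: for each query token take the similarity of its
    best-matching document token, then sum over query tokens."""
    similarities = query_emb @ doc_emb.T  # (query_tokens, doc_tokens)
    return float(similarities.max(axis=1).sum())


# Rank the embedded pages against the query.
scores = [maxsim_score(query_embeddings[0], doc) for doc in image_embeddings]
print(scores)
```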