diff --git a/.gitignore b/.gitignore
index 42967a7ed..d8337200a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -77,4 +77,3 @@ venv.bak/
 # Other
 *.log
 *.swp
-.DS_Store
\ No newline at end of file
diff --git a/src/distilabel/models/embeddings/__init__.py b/src/distilabel/models/embeddings/__init__.py
index 917729874..573ba7226 100644
--- a/src/distilabel/models/embeddings/__init__.py
+++ b/src/distilabel/models/embeddings/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.embeddings.llamacpp import LlamaCppEmbeddings
 from distilabel.models.embeddings.sentence_transformers import (
     SentenceTransformerEmbeddings,
 )
@@ -22,4 +23,5 @@
     "Embeddings",
     "SentenceTransformerEmbeddings",
     "vLLMEmbeddings",
+    "LlamaCppEmbeddings",
 ]
diff --git a/src/distilabel/models/embeddings/llamacpp.py b/src/distilabel/models/embeddings/llamacpp.py
new file mode 100644
index 000000000..6596bb45e
--- /dev/null
+++ b/src/distilabel/models/embeddings/llamacpp.py
@@ -0,0 +1,237 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from pydantic import Field, PrivateAttr
+
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
+
+if TYPE_CHECKING:
+    from llama_cpp import Llama
+
+
+class LlamaCppEmbeddings(Embeddings, CudaDevicePlacementMixin):
+    """`LlamaCpp` library implementation for embedding generation.
+
+    Attributes:
+        model: the name of the GGUF quantized model, compatible with the
+            installed version of the `llama.cpp` Python bindings.
+        model_path: the path to the GGUF quantized model, compatible with the
+            installed version of the `llama.cpp` Python bindings.
+        repo_id: the Hugging Face Hub repository id.
+        verbose: whether to print verbose output. Defaults to `False`.
+        n_gpu_layers: the number of layers to run on the GPU. Defaults to `-1` (use the GPU if available).
+        disable_cuda_device_placement: whether to disable CUDA device placement. Defaults to `True`.
+        normalize_embeddings: whether to normalize the embeddings. Defaults to `False`.
+        seed: RNG seed, `-1` for random.
+        n_ctx: text context size, `0` to take it from the model.
+        n_batch: maximum batch size for prompt processing.
+        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
+            `Llama` class of the `llama_cpp` library. Defaults to `{}`.
+
+    Runtime parameters:
+        - `n_gpu_layers`: the number of layers to use for the GPU. Defaults to `-1`.
+        - `verbose`: whether to print verbose output. Defaults to `False`.
+        - `normalize_embeddings`: whether to normalize the embeddings. Defaults to `False`.
+        - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to the
+          `Llama` class of the `llama_cpp` library. Defaults to `{}`.
+
+    References:
+        - [Offline inference embeddings](https://llama-cpp-python.readthedocs.io/en/stable/#embeddings)
+
+    Examples:
+        Generate sentence embeddings using a local model:
+
+        ```python
+        from pathlib import Path
+        from distilabel.models.embeddings import LlamaCppEmbeddings
+
+        # You can follow along with this example by downloading the model with the
+        # following command, which will save it to the `Downloads` folder:
+        # curl -L -o ~/Downloads/all-MiniLM-L6-v2-Q2_K.gguf https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-Q2_K.gguf
+
+        model_path = "Downloads/"
+        model = "all-MiniLM-L6-v2-Q2_K.gguf"
+        embeddings = LlamaCppEmbeddings(
+            model=model,
+            model_path=str(Path.home() / model_path),
+        )
+
+        embeddings.load()
+
+        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
+        print(results)
+        embeddings.unload()
+        ```
+
+        Generate sentence embeddings using a Hugging Face Hub model:
+
+        ```python
+        from distilabel.models.embeddings import LlamaCppEmbeddings
+
+        # You need to set an environment variable to download a private model to the local machine.
+
+        repo_id = "second-state/All-MiniLM-L6-v2-Embedding-GGUF"
+        model = "all-MiniLM-L6-v2-Q2_K.gguf"
+        embeddings = LlamaCppEmbeddings(model=model, repo_id=repo_id)
+
+        embeddings.load()
+
+        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
+        print(results)
+        embeddings.unload()
+        # [
+        #   [-0.05447685346007347, -0.01623094454407692, ...],
+        #   [4.4889533455716446e-05, 0.044016145169734955, ...],
+        # ]
+        ```
+
+        Generate sentence embeddings on CPU:
+
+        ```python
+        from pathlib import Path
+        from distilabel.models.embeddings import LlamaCppEmbeddings
+
+        # You can follow along with this example by downloading the model with the
+        # following command, which will save it to the `Downloads` folder:
+        # curl -L -o ~/Downloads/all-MiniLM-L6-v2-Q2_K.gguf https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-Q2_K.gguf
+
+        model_path = "Downloads/"
+        model = "all-MiniLM-L6-v2-Q2_K.gguf"
+        embeddings = LlamaCppEmbeddings(
+            model=model,
+            model_path=str(Path.home() / model_path),
+            n_gpu_layers=0,
+            disable_cuda_device_placement=True,
+        )
+
+        embeddings.load()
+
+        results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
+        print(results)
+        embeddings.unload()
+        # [
+        #   [-0.05447685346007347, -0.01623094454407692, ...],
+        #   [4.4889533455716446e-05, 0.044016145169734955, ...],
+        # ]
+        ```
+    """
+
+    model: str = Field(
+        description="The name of the model to use for embeddings.",
+    )
+
+    model_path: RuntimeParameter[str] = Field(
+        default=None,
+        description="The path to the GGUF quantized model, compatible with the installed version of the `llama.cpp` Python bindings.",
+    )
+
+    repo_id: RuntimeParameter[str] = Field(
+        default=None, description="The Hugging Face Hub repository id.", exclude=True
+    )
+
+    n_gpu_layers: RuntimeParameter[int] = Field(
+        default=-1,
+        description="The number of layers that will be loaded in the GPU.",
+    )
+
+    n_ctx: int = 512
+    n_batch: int = 512
+    seed: int = 4294967295
+
+    normalize_embeddings: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to normalize the embeddings.",
+    )
+    verbose: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to print verbose output from the llama.cpp library.",
+    )
+    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
+        default_factory=dict,
+        description="Additional dictionary of keyword arguments that will be passed to the"
+        " `Llama` class of `llama_cpp` library. See all the supported arguments at: "
+        "https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__",
+    )
+    _model: Optional["Llama"] = PrivateAttr(...)
+
+    def load(self) -> None:
+        """Loads the `gguf` model using either the path or the Hugging Face Hub repository id."""
+        super().load()
+        CudaDevicePlacementMixin.load(self)
+
+        try:
+            from llama_cpp import Llama
+        except ImportError as ie:
+            raise ImportError(
+                "`llama-cpp-python` package is not installed. Please install it using"
+                " `pip install llama-cpp-python`."
+            ) from ie
+
+        if self.repo_id is not None:
+            # Use the repo_id to download the model from the Hugging Face Hub
+            from huggingface_hub.utils import validate_repo_id
+
+            validate_repo_id(self.repo_id)
+            self._model = Llama.from_pretrained(
+                repo_id=self.repo_id,
+                filename=self.model,
+                n_gpu_layers=self.n_gpu_layers,
+                seed=self.seed,
+                n_ctx=self.n_ctx,
+                n_batch=self.n_batch,
+                verbose=self.verbose,
+                embedding=True,
+                **self.extra_kwargs,
+            )
+        elif self.model_path is not None:
+            self._model = Llama(
+                model_path=str(Path(self.model_path) / self.model),
+                n_gpu_layers=self.n_gpu_layers,
+                seed=self.seed,
+                n_ctx=self.n_ctx,
+                n_batch=self.n_batch,
+                verbose=self.verbose,
+                embedding=True,
+                **self.extra_kwargs,
+            )
+        else:
+            raise ValueError("Either 'model_path' or 'repo_id' must be provided")
+
+    def unload(self) -> None:
+        """Unloads the `gguf` model."""
+        CudaDevicePlacementMixin.unload(self)
+        self._model.close()
+        super().unload()
+
+    @property
+    def model_name(self) -> str:
+        """Returns the name of the model."""
+        return self.model
+
+    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
+        """Generates embeddings for the provided inputs.
+
+        Args:
+            inputs: a list of texts for which an embedding has to be generated.
+
+        Returns:
+            The generated embeddings.
+        """
+        return self._model.embed(inputs, normalize=self.normalize_embeddings)
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 9aa4ea336..32f70133a 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import atexit
+import os
 from typing import TYPE_CHECKING, Any, Dict, List, Union
+from urllib.request import urlretrieve
 
 import pytest
 from pydantic import PrivateAttr
@@ -126,3 +129,35 @@ class DummyTaskOfflineBatchGeneration(DummyTask):
 @pytest.fixture
 def dummy_llm() -> AsyncLLM:
     return DummyAsyncLLM()
+
+
+@pytest.fixture(scope="session")
+def local_llamacpp_model_path(tmp_path_factory):
+    """
+    Session-scoped fixture that provides the local model path for LlamaCpp testing.
+
+    Downloads a small test model to a temporary directory.
+    The model is downloaded once per test session and cleaned up after all tests.
+
+    Args:
+        tmp_path_factory: Pytest fixture providing a temporary directory factory.
+
+    Returns:
+        str: The path to the directory containing the downloaded LlamaCpp model file.
+ """ + model_name = "all-MiniLM-L6-v2-Q2_K.gguf" + model_url = f"https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/{model_name}" + tmp_path = tmp_path_factory.getbasetemp() + model_path = tmp_path / model_name + + if not model_path.exists(): + urlretrieve(model_url, model_path) + + def cleanup(): + if model_path.exists(): + os.remove(model_path) + + # Register the cleanup function to be called at exit + atexit.register(cleanup) + + return str(tmp_path) diff --git a/tests/unit/models/embeddings/test_llamacpp.py b/tests/unit/models/embeddings/test_llamacpp.py new file mode 100644 index 000000000..b219ac779 --- /dev/null +++ b/tests/unit/models/embeddings/test_llamacpp.py @@ -0,0 +1,185 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import pytest + +from distilabel.models.embeddings import LlamaCppEmbeddings + + +class TestLlamaCppEmbeddings: + @pytest.fixture(autouse=True) + def setup_embeddings(self, local_llamacpp_model_path): + """ + Fixture to set up embeddings for each test, considering CPU usage. + """ + self.model_name = "all-MiniLM-L6-v2-Q2_K.gguf" + self.repo_id = "second-state/All-MiniLM-L6-v2-Embedding-GGUF" + self.disable_cuda_device_placement = True + self.n_gpu_layers = 0 + self.embeddings = LlamaCppEmbeddings( + model=self.model_name, + model_path=local_llamacpp_model_path, + n_gpu_layers=self.n_gpu_layers, + disable_cuda_device_placement=self.disable_cuda_device_placement, + ) + + self.embeddings.load() + + @pytest.fixture + def test_inputs(self): + """ + Fixture that provides a list of test input strings. + + Returns: + list: A list of strings to be used as test inputs for embeddings. + """ + return [ + "Hello, how are you?", + "What a nice day!", + "I hear that llamas are very popular now.", + ] + + def test_model_name(self) -> None: + """ + Test if the model name is correctly set. + """ + assert self.embeddings.model_name == self.model_name + + def test_encode(self, test_inputs) -> None: + """ + Test if the model can generate embeddings. + """ + results = self.embeddings.encode(inputs=test_inputs) + + for result in results: + assert len(result) == 384 + + def test_load_model_from_local(self, test_inputs): + """ + Test if the model can be loaded from a local file and generate embeddings. + + Args: + local_llamacpp_model_path (str): Fixture providing the local model path. + """ + + results = self.embeddings.encode(inputs=test_inputs) + + for result in results: + assert len(result) == 384 + + def test_load_model_from_repo(self, test_inputs): + """ + Test if the model can be loaded from a Hugging Face repository. 
+ """ + embeddings = LlamaCppEmbeddings( + repo_id=self.repo_id, + model=self.model_name, + normalize_embeddings=True, + n_gpu_layers=self.n_gpu_layers, + disable_cuda_device_placement=self.disable_cuda_device_placement, + ) + embeddings.load() + results = embeddings.encode(inputs=test_inputs) + + for result in results: + assert len(result) == 384 + + def test_normalize_embeddings(self, test_inputs): + """ + Test if embeddings are normalized when normalize_embeddings is True. + """ + + embeddings = LlamaCppEmbeddings( + repo_id=self.repo_id, + model=self.model_name, + normalize_embeddings=True, + n_gpu_layers=self.n_gpu_layers, + disable_cuda_device_placement=self.disable_cuda_device_placement, + ) + embeddings.load() + results = embeddings.encode(inputs=test_inputs) + + for result in results: + # Check if the embedding is normalized (L2 norm should be close to 1) + norm = np.linalg.norm(result) + assert np.isclose( + norm, 1.0, atol=1e-6 + ), f"Norm is {norm}, expected close to 1.0" + + def test_normalize_embeddings_false(self, test_inputs): + """ + Test if embeddings are not normalized when normalize_embeddings is False. + """ + + results = self.embeddings.encode(inputs=test_inputs) + + for result in results: + # Check if the embedding is not normalized (L2 norm should not be close to 1) + norm = np.linalg.norm(result) + assert not np.isclose( + norm, 1.0, atol=1e-6 + ), f"Norm is {norm}, expected not close to 1.0" + + # Additional check: ensure that at least one embedding has a norm significantly different from 1 + norms = [np.linalg.norm(result) for result in results] + assert any( + not np.isclose(norm, 1.0, atol=0.1) for norm in norms + ), "Expected at least one embedding with norm not close to 1.0" + + def test_encode_batch(self) -> None: + """ + Test if the model can generate embeddings for batches of inputs. + """ + # Test with different batch sizes + batch_sizes = [1, 2, 5, 10] + for batch_size in batch_sizes: + inputs = [f"This is test sentence {i}" for i in range(batch_size)] + results = self.embeddings.encode(inputs=inputs) + + assert ( + len(results) == batch_size + ), f"Expected {batch_size} results, got {len(results)}" + for result in results: + assert ( + len(result) == 384 + ), f"Expected embedding dimension 384, got {len(result)}" + + # Test with a large batch to ensure it doesn't cause issues + large_batch = ["Large batch test" for _ in range(100)] + large_results = self.embeddings.encode(inputs=large_batch) + assert ( + len(large_results) == 100 + ), f"Expected 100 results for large batch, got {len(large_results)}" + + def test_encode_batch_consistency(self) -> None: + """ + Test if the model produces consistent embeddings for the same input in different batch sizes. + + Args: + local_llamacpp_model_path (str): Fixture providing the local model path. + """ + input_text = "This is a test sentence for consistency" + + # Generate embedding individually + single_result = self.embeddings.encode([input_text])[0] + + # Generate embedding as part of a batch + batch_result = self.embeddings.encode([input_text, "Another sentence"])[0] + + # Compare the embeddings + assert np.allclose( + single_result, batch_result, atol=1e-5 + ), "Embeddings are not consistent between single and batch processing"