From 8eb9eccdeeec1e16fb2edc71778b3a1be86c8651 Mon Sep 17 00:00:00 2001
From: plaguss
Date: Wed, 14 Aug 2024 16:18:42 +0200
Subject: [PATCH 1/2] Update mistralai client to version 1.*.*

---
 pyproject.toml                 |  2 +-
 src/distilabel/llms/mistral.py | 15 ++++++++-------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c4617885f..da50920dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,7 +79,7 @@ hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
 instructor = ["instructor >= 1.2.3"]
 litellm = ["litellm >= 1.30.0"]
 llama-cpp = ["llama-cpp-python >= 0.2.0"]
-mistralai = ["mistralai >= 0.1.0"]
+mistralai = ["mistralai >= 1.0.0"]
 ollama = ["ollama >= 0.1.7"]
 openai = ["openai >= 1.0.0"]
 outlines = ["outlines >= 0.0.40"]
diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py
index ed1c3af7d..73b5fc13a 100644
--- a/src/distilabel/llms/mistral.py
+++ b/src/distilabel/llms/mistral.py
@@ -26,7 +26,7 @@
 )
 
 if TYPE_CHECKING:
-    from mistralai.async_client import MistralAsyncClient
+    from mistralai import Mistral
 
 
 _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY"
@@ -50,7 +50,7 @@ class MistralLLM(AsyncLLM):
             `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
         _api_key_env_var: the name of the environment variable to use for the API key. It is meant to
             be used internally.
-        _aclient: the `MistralAsyncClient` to use for the Mistral API. It is meant to be used internally.
+        _aclient: the `Mistral` to use for the Mistral API. It is meant to be used internally.
             Set in the `load` method.
 
     Runtime parameters:
@@ -126,14 +126,14 @@ class User(BaseModel):
 
     _num_generations_param_supported = False
     _api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
-    _aclient: Optional["MistralAsyncClient"] = PrivateAttr(...)
+    _aclient: Optional["Mistral"] = PrivateAttr(...)
 
     def load(self) -> None:
-        """Loads the `MistralAsyncClient` client to benefit from async requests."""
+        """Loads the `Mistral` client to benefit from async requests."""
         super().load()
 
         try:
-            from mistralai.async_client import MistralAsyncClient
+            from mistralai import Mistral
         except ImportError as ie:
             raise ImportError(
                 "MistralAI Python client is not installed. Please install it using"
@@ -146,7 +146,7 @@ def load(self) -> None:
                 f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
             )
 
-        self._aclient = MistralAsyncClient(
+        self._aclient = Mistral(
             api_key=self.api_key.get_secret_value(),
             endpoint=self.endpoint,
             max_retries=self.max_retries,  # type: ignore
@@ -218,7 +218,8 @@ async def agenerate(  # type: ignore
             # We need to check instructor and see if we can create a PR.
             completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
         else:
-            completion = await self._aclient.chat(**kwargs)  # type: ignore
+            # completion = await self._aclient.chat(**kwargs)  # type: ignore
+            completion = await self._aclient.chat.complete_async(**kwargs)  # type: ignore
 
         if structured_output:
             generations.append(completion.model_dump_json())

From 06d8dff01167b7cf25ebb7e8ac020f201fa645f3 Mon Sep 17 00:00:00 2001
From: plaguss
Date: Wed, 14 Aug 2024 16:19:01 +0200
Subject: [PATCH 2/2] Update tests for new mistral client

---
 tests/unit/llms/test_mistral.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/llms/test_mistral.py
index 5bb233748..89f8e9649 100644
--- a/tests/unit/llms/test_mistral.py
+++ b/tests/unit/llms/test_mistral.py
@@ -97,7 +97,9 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
         mocked_completion = Mock(
             choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
         )
-        llm._aclient.chat = AsyncMock(return_value=mocked_completion)
+        llm._aclient.chat = Mock(
+            complete_async=AsyncMock(return_value=mocked_completion)
+        )
 
         nest_asyncio.apply()
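For reference, a minimal sketch of the `mistralai >= 1.0.0` call pattern these patches move to. The model name and prompt below are placeholders, and the example assumes the v1 response keeps the `choices[0].message.content` shape that the mocked test relies on; only the `Mistral` constructor, the `MISTRAL_API_KEY` variable, and the `chat.complete_async` entry point are taken from the hunks above.

```python
import asyncio
import os

from mistralai import Mistral  # requires mistralai >= 1.0.0, as pinned in pyproject.toml


async def main() -> None:
    # The unified `Mistral` client replaces the old `MistralAsyncClient`.
    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    # Async chat completions go through `chat.complete_async`, the same entry
    # point the patched `agenerate` now awaits.
    completion = await client.chat.complete_async(
        model="mistral-tiny",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(completion.choices[0].message.content)


asyncio.run(main())
```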