Update mistrallm #904

Merged (2 commits) on Aug 15, 2024
pyproject.toml (1 addition, 1 deletion)

@@ -79,7 +79,7 @@ hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
 instructor = ["instructor >= 1.2.3"]
 litellm = ["litellm >= 1.30.0"]
 llama-cpp = ["llama-cpp-python >= 0.2.0"]
-mistralai = ["mistralai >= 0.1.0"]
+mistralai = ["mistralai >= 1.0.0"]
 ollama = ["ollama >= 0.1.7"]
 openai = ["openai >= 1.0.0"]
 outlines = ["outlines >= 0.0.40"]
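The only packaging change is the lower bound of the optional `mistralai` extra, which now requires the 1.x SDK. As an illustrative sketch (not part of the PR), an environment that installed `distilabel[mistralai]` could sanity-check the constraint like this:

# Illustrative only: confirm the installed mistralai package meets the new
# ">= 1.0.0" floor declared in pyproject.toml.
from importlib.metadata import version

major = int(version("mistralai").split(".")[0])
assert major >= 1, "distilabel's mistralai extra now requires mistralai >= 1.0.0"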
src/distilabel/llms/mistral.py (8 additions, 7 deletions)

@@ -26,7 +26,7 @@
 )

 if TYPE_CHECKING:
-    from mistralai.async_client import MistralAsyncClient
+    from mistralai import Mistral


 _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY"
@@ -50,7 +50,7 @@ class MistralLLM(AsyncLLM):
         `InstructorStructuredOutputType` from `distilabel.steps.tasks.structured_outputs.instructor`.
         _api_key_env_var: the name of the environment variable to use for the API key. It is meant to
             be used internally.
-        _aclient: the `MistralAsyncClient` to use for the Mistral API. It is meant to be used internally.
+        _aclient: the `Mistral` to use for the Mistral API. It is meant to be used internally.
             Set in the `load` method.

     Runtime parameters:
@@ -126,14 +126,14 @@ class User(BaseModel):
     _num_generations_param_supported = False

     _api_key_env_var: str = PrivateAttr(_MISTRALAI_API_KEY_ENV_VAR_NAME)
-    _aclient: Optional["MistralAsyncClient"] = PrivateAttr(...)
+    _aclient: Optional["Mistral"] = PrivateAttr(...)

     def load(self) -> None:
-        """Loads the `MistralAsyncClient` client to benefit from async requests."""
+        """Loads the `Mistral` client to benefit from async requests."""
         super().load()

         try:
-            from mistralai.async_client import MistralAsyncClient
+            from mistralai import Mistral
         except ImportError as ie:
             raise ImportError(
                 "MistralAI Python client is not installed. Please install it using"
@@ -146,7 +146,7 @@ def load(self) -> None:
                 f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
             )

-        self._aclient = MistralAsyncClient(
+        self._aclient = Mistral(
             api_key=self.api_key.get_secret_value(),
             endpoint=self.endpoint,
             max_retries=self.max_retries,  # type: ignore
@@ -218,7 +218,8 @@ async def agenerate(  # type: ignore
             # We need to check instructor and see if we can create a PR.
             completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
         else:
-            completion = await self._aclient.chat(**kwargs)  # type: ignore
+            # completion = await self._aclient.chat(**kwargs)  # type: ignore
+            completion = await self._aclient.chat.complete_async(**kwargs)  # type: ignore

         if structured_output:
             generations.append(completion.model_dump_json())
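For context, this diff tracks the mistralai 0.x to 1.x migration: the async client is no longer a separate `MistralAsyncClient`, and async chat completions go through `Mistral.chat.complete_async`. A minimal standalone sketch of that call pattern (outside distilabel; the model name and prompt are placeholders) looks like this:

import asyncio
import os

from mistralai import Mistral  # requires mistralai >= 1.0.0


async def main() -> None:
    # The 1.x SDK ships a single `Mistral` client; async chat completions are
    # exposed as `chat.complete_async` instead of the old `MistralAsyncClient.chat`.
    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
    response = await client.chat.complete_async(
        model="mistral-small-latest",  # placeholder model name
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())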
tests/unit/llms/test_mistral.py (3 additions, 1 deletion)

@@ -97,7 +97,9 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
         mocked_completion = Mock(
             choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
         )
-        llm._aclient.chat = AsyncMock(return_value=mocked_completion)
+        llm._aclient.chat = Mock(
+            complete_async=AsyncMock(return_value=mocked_completion)
+        )

         nest_asyncio.apply()
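The test change mirrors the new call shape: because the code under test now awaits `chat.complete_async(...)`, `chat` itself must be a plain `Mock` whose `complete_async` attribute is an `AsyncMock`. A self-contained sketch of the pattern (hypothetical names, not the distilabel test):

import asyncio
from unittest.mock import AsyncMock, Mock

# `chat` is reached by plain attribute access, so a regular Mock suffices;
# only the awaited `complete_async` call needs an AsyncMock.
client = Mock()
client.chat = Mock(complete_async=AsyncMock(return_value="mocked completion"))


async def call_llm() -> str:
    return await client.chat.complete_async(model="any", messages=[])


assert asyncio.run(call_llm()) == "mocked completion"
client.chat.complete_async.assert_awaited_once()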