diff --git a/docs/docs/integrations/chat/mistralai.ipynb b/docs/docs/integrations/chat/mistralai.ipynb index 12faf385a399e..106d51a700a5c 100644 --- a/docs/docs/integrations/chat/mistralai.ipynb +++ b/docs/docs/integrations/chat/mistralai.ipynb @@ -48,7 +48,7 @@ "source": [ "import getpass\n", "\n", - "mistral_api_key = getpass.getpass()" + "api_key = getpass.getpass()" ] }, { @@ -81,8 +81,8 @@ }, "outputs": [], "source": [ - "# If mistral_api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n", - "chat = ChatMistralAI(mistral_api_key=mistral_api_key)" + "# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.\n", + "chat = ChatMistralAI(api_key=api_key)" ] }, { diff --git a/docs/docs/integrations/text_embedding/mistralai.ipynb b/docs/docs/integrations/text_embedding/mistralai.ipynb index 55b15875bbd70..e8e89b5ede587 100644 --- a/docs/docs/integrations/text_embedding/mistralai.ipynb +++ b/docs/docs/integrations/text_embedding/mistralai.ipynb @@ -45,7 +45,7 @@ "metadata": {}, "outputs": [], "source": [ - "embedding = MistralAIEmbeddings(mistral_api_key=\"your-api-key\")" + "embedding = MistralAIEmbeddings(api_key=\"your-api-key\")" ] }, { diff --git a/docs/docs/modules/model_io/chat/quick_start.ipynb b/docs/docs/modules/model_io/chat/quick_start.ipynb index a70f62192ecbb..5b827fa44f8fc 100644 --- a/docs/docs/modules/model_io/chat/quick_start.ipynb +++ b/docs/docs/modules/model_io/chat/quick_start.ipynb @@ -54,8 +54,8 @@ " Dict[str, Any]: """Get the default parameters for calling the API.""" diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py index e58f7d3692ea8..4a6c28b753d57 100644 --- a/libs/partners/mistralai/langchain_mistralai/embeddings.py +++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py @@ -29,15 +29,16 @@ class MistralAIEmbeddings(BaseModel, Embeddings): .. code-block:: python from langchain_mistralai import MistralAIEmbeddings + mistral = MistralAIEmbeddings( model="mistral-embed", - mistral_api_key="my-api-key" + api_key="my-api-key" ) """ client: httpx.Client = Field(default=None) #: :meta private: async_client: httpx.AsyncClient = Field(default=None) #: :meta private: - mistral_api_key: Optional[SecretStr] = None + mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") endpoint: str = "https://api.mistral.ai/v1/" max_retries: int = 5 timeout: int = 120 @@ -49,6 +50,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings): class Config: extra = Extra.forbid arbitrary_types_allowed = True + allow_population_by_field_name = True @root_validator() def validate_environment(cls, values: Dict) -> Dict: diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index f7aa3a749ab30..2ee2565e54676 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -1,7 +1,7 @@ """Test MistralAI Chat API wrapper.""" import os -from typing import Any, AsyncGenerator, Dict, Generator +from typing import Any, AsyncGenerator, Dict, Generator, cast from unittest.mock import patch import pytest @@ -13,6 +13,7 @@ HumanMessage, SystemMessage, ) +from langchain_core.pydantic_v1 import SecretStr from langchain_mistralai.chat_models import ( # type: ignore[import] ChatMistralAI, @@ -31,7 +32,11 @@ def test_mistralai_initialization() -> None: """Test ChatMistralAI initialization.""" # Verify that ChatMistralAI can be initialized using a secret key provided # as a parameter rather than an environment variable. - ChatMistralAI(model="test", mistral_api_key="test") + for model in [ + ChatMistralAI(model="test", mistral_api_key="test"), + ChatMistralAI(model="test", api_key="test"), + ]: + assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test" @pytest.mark.parametrize( diff --git a/libs/partners/mistralai/tests/unit_tests/test_embeddings.py b/libs/partners/mistralai/tests/unit_tests/test_embeddings.py index 14055af4ed7d5..d1599fce375e1 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_embeddings.py +++ b/libs/partners/mistralai/tests/unit_tests/test_embeddings.py @@ -1,4 +1,7 @@ import os +from typing import cast + +from langchain_core.pydantic_v1 import SecretStr from langchain_mistralai import MistralAIEmbeddings @@ -6,5 +9,9 @@ def test_mistral_init() -> None: - embeddings = MistralAIEmbeddings() - assert embeddings.model == "mistral-embed" + for model in [ + MistralAIEmbeddings(model="mistral-embed", mistral_api_key="test"), + MistralAIEmbeddings(model="mistral-embed", api_key="test"), + ]: + assert model.model == "mistral-embed" + assert cast(SecretStr, model.mistral_api_key).get_secret_value() == "test"