From f4715407b8d117ac3c285be9577974eec37f99d7 Mon Sep 17 00:00:00 2001 From: Cristhian Zanforlin Lousa Date: Wed, 12 Feb 2025 11:10:23 -0300 Subject: [PATCH] fix: add config options and error handling to MistralAI component (#6131) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ (mistral.py): Add new input parameters to MistralAIModelComponent for better customization and control over the Mistral model configuration ♻️ (mistral.py): Refactor build_model method to improve readability and maintainability by using try-except block for error handling and updating parameter names for better clarity * [autofix.ci] apply automated fixes * ♻️ (mistral.py): refactor MistralAIModelComponent class to improve code readability by formatting the IntInput and BoolInput sections for better organization and clarity. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../langflow/components/models/mistral.py | 96 ++++++++++++------- 1 file changed, 62 insertions(+), 34 deletions(-) diff --git a/src/backend/base/langflow/components/models/mistral.py b/src/backend/base/langflow/components/models/mistral.py index d5fb967f0b18..ed66ab529263 100644 --- a/src/backend/base/langflow/components/models/mistral.py +++ b/src/backend/base/langflow/components/models/mistral.py @@ -46,41 +46,69 @@ class MistralAIModelComponent(LCModelComponent): display_name="Mistral API Key", info="The Mistral API Key to use for the Mistral model.", advanced=False, + required=True, + value="MISTRAL_API_KEY", + ), + FloatInput( + name="temperature", + display_name="Temperature", + advanced=False, + value=0.5, + ), + IntInput( + name="max_retries", + display_name="Max Retries", + advanced=True, + value=5, + ), + IntInput( + name="timeout", + display_name="Timeout", + advanced=True, + value=60, + ), + IntInput( + name="max_concurrent_requests", + display_name="Max Concurrent Requests", + advanced=True, + value=3, + ), + FloatInput( + name="top_p", + display_name="Top P", + advanced=True, + value=1, + ), + IntInput( + name="random_seed", + display_name="Random Seed", + value=1, + advanced=True, + ), + BoolInput( + name="safe_mode", + display_name="Safe Mode", + advanced=True, + value=False, ), - FloatInput(name="temperature", display_name="Temperature", advanced=False, value=0.5), - IntInput(name="max_retries", display_name="Max Retries", advanced=True, value=5), - IntInput(name="timeout", display_name="Timeout", advanced=True, value=60), - IntInput(name="max_concurrent_requests", display_name="Max Concurrent Requests", advanced=True, value=3), - FloatInput(name="top_p", display_name="Top P", advanced=True, value=1), - IntInput(name="random_seed", display_name="Random Seed", value=1, advanced=True), - BoolInput(name="safe_mode", display_name="Safe Mode", advanced=True), ] def build_model(self) -> LanguageModel: # type: ignore[type-var] - mistral_api_key = self.api_key - temperature = self.temperature - model_name = self.model_name - max_tokens = self.max_tokens - mistral_api_base = self.mistral_api_base or "https://api.mistral.ai/v1" - max_retries = self.max_retries - timeout = self.timeout - max_concurrent_requests = self.max_concurrent_requests - top_p = self.top_p - random_seed = self.random_seed - safe_mode = self.safe_mode - - api_key = SecretStr(mistral_api_key).get_secret_value() if mistral_api_key else None - - return ChatMistralAI( - max_tokens=max_tokens or None, - model_name=model_name, - endpoint=mistral_api_base, - api_key=api_key, - temperature=temperature, - max_retries=max_retries, - timeout=timeout, - max_concurrent_requests=max_concurrent_requests, - top_p=top_p, - random_seed=random_seed, - safe_mode=safe_mode, - ) + try: + return ChatMistralAI( + model_name=self.model_name, + mistral_api_key=SecretStr(self.api_key).get_secret_value() if self.api_key else None, + endpoint=self.mistral_api_base or "https://api.mistral.ai/v1", + max_tokens=self.max_tokens or None, + temperature=self.temperature, + max_retries=self.max_retries, + timeout=self.timeout, + max_concurrent_requests=self.max_concurrent_requests, + top_p=self.top_p, + random_seed=self.random_seed, + safe_mode=self.safe_mode, + streaming=self.stream, + ) + except Exception as e: + msg = "Could not connect to MistralAI API." + raise ValueError(msg) from e