Skip to content

Commit

Permalink
fix: add config options and error handling to MistralAI component (#6131
Browse files Browse the repository at this point in the history
)

* ✨ (mistral.py): Add new input parameters to MistralAIModelComponent for better customization and control over the Mistral model configuration
♻️ (mistral.py): Refactor build_model method to improve readability and maintainability by using try-except block for error handling and updating parameter names for better clarity

* [autofix.ci] apply automated fixes

* ♻️ (mistral.py): refactor MistralAIModelComponent class to improve code readability by formatting the IntInput and BoolInput sections for better organization and clarity.

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
Cristhianzl and autofix-ci[bot] authored Feb 12, 2025
1 parent f7db8ee commit f471540
Showing 1 changed file with 62 additions and 34 deletions.
96 changes: 62 additions & 34 deletions src/backend/base/langflow/components/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,41 +46,69 @@ class MistralAIModelComponent(LCModelComponent):
display_name="Mistral API Key",
info="The Mistral API Key to use for the Mistral model.",
advanced=False,
required=True,
value="MISTRAL_API_KEY",
),
FloatInput(
name="temperature",
display_name="Temperature",
advanced=False,
value=0.5,
),
IntInput(
name="max_retries",
display_name="Max Retries",
advanced=True,
value=5,
),
IntInput(
name="timeout",
display_name="Timeout",
advanced=True,
value=60,
),
IntInput(
name="max_concurrent_requests",
display_name="Max Concurrent Requests",
advanced=True,
value=3,
),
FloatInput(
name="top_p",
display_name="Top P",
advanced=True,
value=1,
),
IntInput(
name="random_seed",
display_name="Random Seed",
value=1,
advanced=True,
),
BoolInput(
name="safe_mode",
display_name="Safe Mode",
advanced=True,
value=False,
),
FloatInput(name="temperature", display_name="Temperature", advanced=False, value=0.5),
IntInput(name="max_retries", display_name="Max Retries", advanced=True, value=5),
IntInput(name="timeout", display_name="Timeout", advanced=True, value=60),
IntInput(name="max_concurrent_requests", display_name="Max Concurrent Requests", advanced=True, value=3),
FloatInput(name="top_p", display_name="Top P", advanced=True, value=1),
IntInput(name="random_seed", display_name="Random Seed", value=1, advanced=True),
BoolInput(name="safe_mode", display_name="Safe Mode", advanced=True),
]

def build_model(self) -> LanguageModel: # type: ignore[type-var]
mistral_api_key = self.api_key
temperature = self.temperature
model_name = self.model_name
max_tokens = self.max_tokens
mistral_api_base = self.mistral_api_base or "https://api.mistral.ai/v1"
max_retries = self.max_retries
timeout = self.timeout
max_concurrent_requests = self.max_concurrent_requests
top_p = self.top_p
random_seed = self.random_seed
safe_mode = self.safe_mode

api_key = SecretStr(mistral_api_key).get_secret_value() if mistral_api_key else None

return ChatMistralAI(
max_tokens=max_tokens or None,
model_name=model_name,
endpoint=mistral_api_base,
api_key=api_key,
temperature=temperature,
max_retries=max_retries,
timeout=timeout,
max_concurrent_requests=max_concurrent_requests,
top_p=top_p,
random_seed=random_seed,
safe_mode=safe_mode,
)
try:
return ChatMistralAI(
model_name=self.model_name,
mistral_api_key=SecretStr(self.api_key).get_secret_value() if self.api_key else None,
endpoint=self.mistral_api_base or "https://api.mistral.ai/v1",
max_tokens=self.max_tokens or None,
temperature=self.temperature,
max_retries=self.max_retries,
timeout=self.timeout,
max_concurrent_requests=self.max_concurrent_requests,
top_p=self.top_p,
random_seed=self.random_seed,
safe_mode=self.safe_mode,
streaming=self.stream,
)
except Exception as e:
msg = "Could not connect to MistralAI API."
raise ValueError(msg) from e

0 comments on commit f471540

Please sign in to comment.