diff --git a/cookbook/models/anthropic/async/basic.py b/cookbook/models/anthropic/async_basic.py
similarity index 100%
rename from cookbook/models/anthropic/async/basic.py
rename to cookbook/models/anthropic/async_basic.py
diff --git a/cookbook/models/anthropic/async/basic_stream.py b/cookbook/models/anthropic/async_basic_stream.py
similarity index 100%
rename from cookbook/models/anthropic/async/basic_stream.py
rename to cookbook/models/anthropic/async_basic_stream.py
diff --git a/cookbook/models/anthropic/async/tool_use.py b/cookbook/models/anthropic/async_tool_use.py
similarity index 100%
rename from cookbook/models/anthropic/async/tool_use.py
rename to cookbook/models/anthropic/async_tool_use.py
diff --git a/cookbook/models/groq/async/__init__.py b/cookbook/models/groq/async/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/cookbook/models/groq/async/basic.py b/cookbook/models/groq/async_basic.py
similarity index 100%
rename from cookbook/models/groq/async/basic.py
rename to cookbook/models/groq/async_basic.py
diff --git a/cookbook/models/groq/async/basic_stream.py b/cookbook/models/groq/async_basic_stream.py
similarity index 66%
rename from cookbook/models/groq/async/basic_stream.py
rename to cookbook/models/groq/async_basic_stream.py
index 45571ed93..af39178f6 100644
--- a/cookbook/models/groq/async/basic_stream.py
+++ b/cookbook/models/groq/async_basic_stream.py
@@ -3,12 +3,12 @@ from agno.agent import Agent
 from agno.models.groq import Groq
 
-assistant = Agent(
+agent = Agent(
     model=Groq(id="llama-3.3-70b-versatile"),
     description="You help people with their health and fitness goals.",
     instructions=["Recipes should be under 5 ingredients"],
 )
 
-# -*- Print a response to the cli
+# -*- Print a response to the terminal
 asyncio.run(
-    assistant.aprint_response("Share a breakfast recipe.", markdown=True, stream=True)
+    agent.aprint_response("Share a breakfast recipe.", markdown=True, stream=True)
 )
diff --git a/cookbook/models/groq/async_tool_use.py b/cookbook/models/groq/async_tool_use.py
new file mode 100644
index 000000000..886ac1481
--- /dev/null
+++ b/cookbook/models/groq/async_tool_use.py
@@ -0,0 +1,28 @@
+"""Please install dependencies using:
+pip install groq duckduckgo-search newspaper4k lxml_html_clean agno
+"""
+import asyncio
+
+from agno.agent import Agent
+from agno.models.groq import Groq
+from agno.tools.duckduckgo import DuckDuckGoTools
+from agno.tools.newspaper4k import Newspaper4kTools
+
+agent = Agent(
+    model=Groq(id="llama-3.3-70b-versatile"),
+    tools=[DuckDuckGoTools(), Newspaper4kTools()],
+    description="You are a senior NYT researcher writing an article on a topic.",
+    instructions=[
+        "For a given topic, search for the top 5 links.",
+        "Then read each URL and extract the article text; if a URL isn't available, ignore it.",
+        "Analyse and prepare an NYT-worthy article based on the information.",
+    ],
+    markdown=True,
+    show_tool_calls=True,
+    add_datetime_to_instructions=True,
+)
+
+# -*- Print a response to the terminal
+asyncio.run(
+    agent.aprint_response("Simulation theory", stream=True)
+)
diff --git a/cookbook/models/groq/web_search.py b/cookbook/models/groq/web_search.py
deleted file mode 100644
index 01dddd748..000000000
--- a/cookbook/models/groq/web_search.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from agno.agent import Agent
-from agno.models.groq import Groq
-from agno.tools.duckduckgo import DuckDuckGoTools
-
-# Initialize the agent with the Groq model and tools for DuckDuckGo
-agent = Agent(
-    model=Groq(id="llama-3.3-70b-versatile"),
-    description="You are an enthusiastic news reporter with a flair for storytelling!",
-    tools=[DuckDuckGoTools()],  # Add DuckDuckGo tool to search the web
-    show_tool_calls=True,  # Shows tool calls in the response, set to False to hide
-    markdown=True,  # Format responses in markdown
-)
-
-# Prompt the agent to fetch a breaking news story from New York
-agent.print_response("Tell me about a breaking news story from New York.", stream=True)
diff --git a/cookbook/models/ollama/README.md b/cookbook/models/ollama/README.md
index a66150232..908f09e5f 100644
--- a/cookbook/models/ollama/README.md
+++ b/cookbook/models/ollama/README.md
@@ -42,7 +42,7 @@ python cookbook/models/ollama/basic.py
 - DuckDuckGo Search
 
 ```shell
-python cookbook/models/ollama/web_search.py
+python cookbook/models/ollama/tool_use.py
 ```
 
 ### 6. Run Agent that returns structured output
@@ -63,7 +63,7 @@ python cookbook/models/ollama/storage.py
 python cookbook/models/ollama/knowledge.py
 ```
 
-### 9. Run Agent that uses memory 
+### 9. Run Agent that uses memory
 
 ```shell
 python cookbook/models/ollama/memory.py
diff --git a/libs/agno/agno/agent/agent.py b/libs/agno/agno/agent/agent.py
index 5f4344055..f4a03bf87 100644
--- a/libs/agno/agno/agent/agent.py
+++ b/libs/agno/agno/agent/agent.py
@@ -867,7 +867,7 @@ def run(
                             **kwargs,
                         )
                         return next(resp)
-            except AgentRunException as e:
+            except Exception as e:
                 logger.warning(f"Attempt {attempt + 1}/{num_attempts} failed: {str(e)}")
                 if isinstance(e, StopAgentRun):
                     raise e
@@ -1267,7 +1267,7 @@ async def arun(
                             **kwargs,
                         )
                         return await resp.__anext__()
-            except AgentRunException as e:
+            except Exception as e:
                 logger.warning(f"Attempt {attempt + 1}/{num_attempts} failed: {str(e)}")
                 if isinstance(e, StopAgentRun):
                     raise e
diff --git a/libs/agno/agno/document/reader/csv_reader.py b/libs/agno/agno/document/reader/csv_reader.py
index bd2dcc8fd..afdf4aa52 100644
--- a/libs/agno/agno/document/reader/csv_reader.py
+++ b/libs/agno/agno/document/reader/csv_reader.py
@@ -1,6 +1,7 @@
 import csv
 import io
 import os
+from time import sleep
 from pathlib import Path
 from typing import IO, Any, List, Union
 from urllib.parse import urlparse
@@ -63,7 +64,18 @@ def read(self, url: str) -> List[Document]:
             raise ImportError("`httpx` not installed")
 
         logger.info(f"Reading: {url}")
-        response = httpx.get(url)
+        # Retry the request up to 3 times with exponential backoff
+        for attempt in range(3):
+            try:
+                response = httpx.get(url)
+                break
+            except httpx.RequestError as e:
+                if attempt == 2:  # Last attempt
+                    logger.error(f"Failed to fetch CSV after 3 attempts: {e}")
+                    raise
+                wait_time = 2 ** attempt  # Exponential backoff: 1s, then 2s
+                logger.warning(f"Request failed, retrying in {wait_time} seconds...")
+                sleep(wait_time)
 
         try:
             response.raise_for_status()
diff --git a/libs/agno/agno/document/reader/pdf_reader.py b/libs/agno/agno/document/reader/pdf_reader.py
index af8c5cb4d..229e2bbc5 100644
--- a/libs/agno/agno/document/reader/pdf_reader.py
+++ b/libs/agno/agno/document/reader/pdf_reader.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from time import sleep
 from typing import IO, Any, List, Union
 
 from agno.document.base import Document
@@ -59,7 +60,18 @@ def read(self, url: str) -> List[Document]:
             raise ImportError("`httpx` not installed. Please install it via `pip install httpx`.")
 
         logger.info(f"Reading: {url}")
-        response = httpx.get(url)
+        # Retry the request up to 3 times with exponential backoff
+        for attempt in range(3):
+            try:
+                response = httpx.get(url)
+                break
+            except httpx.RequestError as e:
+                if attempt == 2:  # Last attempt
+                    logger.error(f"Failed to fetch PDF after 3 attempts: {e}")
+                    raise
+                wait_time = 2 ** attempt  # Exponential backoff: 1s, then 2s
+                logger.warning(f"Request failed, retrying in {wait_time} seconds...")
+                sleep(wait_time)
 
         try:
             response.raise_for_status()
diff --git a/libs/agno/agno/document/reader/url_reader.py b/libs/agno/agno/document/reader/url_reader.py
index f8d02eb13..a4a32fcb6 100644
--- a/libs/agno/agno/document/reader/url_reader.py
+++ b/libs/agno/agno/document/reader/url_reader.py
@@ -1,5 +1,6 @@
 from typing import List
 from urllib.parse import urlparse
+from time import sleep
 
 from agno.document.base import Document
 from agno.document.reader.base import Reader
@@ -19,7 +20,18 @@ def read(self, url: str) -> List[Document]:
             raise ImportError("`httpx` not installed. Please install it via `pip install httpx`.")
 
         logger.info(f"Reading: {url}")
-        response = httpx.get(url)
+        # Retry the request up to 3 times with exponential backoff
+        for attempt in range(3):
+            try:
+                response = httpx.get(url)
+                break
+            except httpx.RequestError as e:
+                if attempt == 2:  # Last attempt
+                    logger.error(f"Failed to fetch URL after 3 attempts: {e}")
+                    raise
+                wait_time = 2 ** attempt  # Exponential backoff: 1s, then 2s
+                logger.warning(f"Request failed, retrying in {wait_time} seconds...")
+                sleep(wait_time)
 
         try:
             logger.debug(f"Status: {response.status_code}")
diff --git a/libs/agno/agno/models/anthropic/claude.py b/libs/agno/agno/models/anthropic/claude.py
index 6ba49562a..faa62704d 100644
--- a/libs/agno/agno/models/anthropic/claude.py
+++ b/libs/agno/agno/models/anthropic/claude.py
@@ -12,6 +12,8 @@
 try:
     from anthropic import Anthropic as AnthropicClient
     from anthropic import AsyncAnthropic as AsyncAnthropicClient
+    from anthropic import APIConnectionError, RateLimitError, APIStatusError
+
     from anthropic.types import (
         ContentBlockDeltaEvent,
         MessageStopEvent,
@@ -290,15 +292,33 @@ def invoke(self, messages: List[Message]) -> AnthropicMessage:
 
         Returns:
             AnthropicMessage: The response from the model.
- """ - chat_messages, system_message = _format_messages(messages) - request_kwargs = self._prepare_request_kwargs(system_message) - return self.get_client().messages.create( - model=self.id, - messages=chat_messages, # type: ignore - **request_kwargs, - ) + Raises: + APIConnectionError: If there are network connectivity issues + RateLimitError: If the API rate limit is exceeded + APIStatusError: For other API-related errors + """ + try: + chat_messages, system_message = _format_messages(messages) + request_kwargs = self._prepare_request_kwargs(system_message) + + return self.get_client().messages.create( + model=self.id, + messages=chat_messages, # type: ignore + **request_kwargs, + ) + except APIConnectionError as e: + logger.error(f"Connection error while calling Claude API: {str(e)}") + raise + except RateLimitError as e: + logger.warning(f"Rate limit exceeded: {str(e)}") + raise + except APIStatusError as e: + logger.error(f"Claude API error (status {e.status_code}): {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error calling Claude API: {str(e)}") + raise def invoke_stream(self, messages: List[Message]) -> Any: """ @@ -313,11 +333,24 @@ def invoke_stream(self, messages: List[Message]) -> Any: chat_messages, system_message = _format_messages(messages) request_kwargs = self._prepare_request_kwargs(system_message) - return self.get_client().messages.stream( - model=self.id, - messages=chat_messages, # type: ignore - **request_kwargs, - ).__enter__() + try: + return self.get_client().messages.stream( + model=self.id, + messages=chat_messages, # type: ignore + **request_kwargs, + ).__enter__() + except APIConnectionError as e: + logger.error(f"Connection error while calling Claude API: {str(e)}") + raise + except RateLimitError as e: + logger.warning(f"Rate limit exceeded: {str(e)}") + raise + except APIStatusError as e: + logger.error(f"Claude API error (status {e.status_code}): {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error calling Claude API: {str(e)}") + raise async def ainvoke(self, messages: List[Message]) -> AnthropicMessage: """ @@ -328,15 +361,33 @@ async def ainvoke(self, messages: List[Message]) -> AnthropicMessage: Returns: AnthropicMessage: The response from the model. 
- """ - chat_messages, system_message = _format_messages(messages) - request_kwargs = self._prepare_request_kwargs(system_message) - return await self.get_async_client().messages.create( - model=self.id, - messages=chat_messages, # type: ignore - **request_kwargs, - ) + Raises: + APIConnectionError: If there are network connectivity issues + RateLimitError: If the API rate limit is exceeded + APIStatusError: For other API-related errors + """ + try: + chat_messages, system_message = _format_messages(messages) + request_kwargs = self._prepare_request_kwargs(system_message) + + return await self.get_async_client().messages.create( + model=self.id, + messages=chat_messages, # type: ignore + **request_kwargs, + ) + except APIConnectionError as e: + logger.error(f"Connection error while calling Claude API: {str(e)}") + raise + except RateLimitError as e: + logger.warning(f"Rate limit exceeded: {str(e)}") + raise + except APIStatusError as e: + logger.error(f"Claude API error (status {e.status_code}): {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error calling Claude API: {str(e)}") + raise async def ainvoke_stream(self, messages: List[Message]) -> Any: """ @@ -348,15 +399,27 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any: Returns: Any: The streamed response from the model. """ - chat_messages, system_message = _format_messages(messages) - request_kwargs = self._prepare_request_kwargs(system_message) - - return await self.get_async_client().messages.create( - model=self.id, - messages=chat_messages, # type: ignore - stream=True, - **request_kwargs, - ) + try: + chat_messages, system_message = _format_messages(messages) + request_kwargs = self._prepare_request_kwargs(system_message) + + return await self.get_async_client().messages.stream( + model=self.id, + messages=chat_messages, # type: ignore + **request_kwargs, + ).__aenter__() + except APIConnectionError as e: + logger.error(f"Connection error while calling Claude API: {str(e)}") + raise + except RateLimitError as e: + logger.warning(f"Rate limit exceeded: {str(e)}") + raise + except APIStatusError as e: + logger.error(f"Claude API error (status {e.status_code}): {str(e)}") + raise + except Exception as e: + logger.error(f"Unexpected error calling Claude API: {str(e)}") + raise # Overwrite the default from the base model def format_function_call_results( diff --git a/libs/agno/agno/models/aws/bedrock.py b/libs/agno/agno/models/aws/bedrock.py index 645f9f770..f88f62708 100644 --- a/libs/agno/agno/models/aws/bedrock.py +++ b/libs/agno/agno/models/aws/bedrock.py @@ -1,26 +1,22 @@ -import json +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional from agno.aws.api_client import AwsApiClient # type: ignore -from agno.models.base import Model, StreamData +from agno.models.base import Model from agno.models.message import Message -from agno.models.response import ModelResponse, ModelResponseEvent +from agno.models.response import ModelProviderResponse from agno.utils.log import logger -from agno.utils.timer import Timer -from agno.utils.tools import ( - get_function_call_for_tool_call, -) try: from boto3 import session # noqa: F401 except ImportError: - logger.error("`boto3` not installed") + logger.error("`boto3` not installed. Please install it via `pip install boto3`.") raise @dataclass -class AwsBedrock(Model): +class AwsBedrock(Model, ABC): """ AWS Bedrock model. 
@@ -28,18 +24,15 @@ class AwsBedrock(Model): aws_region (Optional[str]): The AWS region to use. aws_profile (Optional[str]): The AWS profile to use. aws_client (Optional[AwsApiClient]): The AWS client to use. - _bedrock_client (Optional[Any]): The Bedrock client to use. - _bedrock_runtime_client (Optional[Any]): The Bedrock runtime client to use. """ aws_region: Optional[str] = None aws_profile: Optional[str] = None aws_client: Optional[AwsApiClient] = None - _bedrock_client: Optional[Any] = None _bedrock_runtime_client: Optional[Any] = None - def get_aws_region(self) -> Optional[str]: + def _get_aws_region(self) -> Optional[str]: # Priority 1: Use aws_region from model if self.aws_region is not None: return self.aws_region @@ -54,7 +47,7 @@ def get_aws_region(self) -> Optional[str]: self.aws_region = aws_region_env return self.aws_region - def get_aws_profile(self) -> Optional[str]: + def _get_aws_profile(self) -> Optional[str]: # Priority 1: Use aws_region from resource if self.aws_profile is not None: return self.aws_profile @@ -69,479 +62,70 @@ def get_aws_profile(self) -> Optional[str]: self.aws_profile = aws_profile_env return self.aws_profile - def get_aws_client(self) -> AwsApiClient: + def _get_aws_client(self) -> AwsApiClient: if self.aws_client is not None: return self.aws_client - self.aws_client = AwsApiClient(aws_region=self.get_aws_region(), aws_profile=self.get_aws_profile()) + self.aws_client = AwsApiClient(aws_region=self._get_aws_region(), aws_profile=self._get_aws_profile()) return self.aws_client - @property - def bedrock_runtime_client(self): + def get_client(self): if self._bedrock_runtime_client is not None: return self._bedrock_runtime_client - boto3_session: session = self.get_aws_client().boto3_session + boto3_session: session = self._get_aws_client().boto3_session self._bedrock_runtime_client = boto3_session.client(service_name="bedrock-runtime") return self._bedrock_runtime_client - def invoke(self, body: Dict[str, Any]) -> Dict[str, Any]: + def invoke(self, messages: List[Message]) -> Dict[str, Any]: """ Invoke the Bedrock API. Args: - body (Dict[str, Any]): The request body. + messages (List[Message]): The messages to include in the request. Returns: Dict[str, Any]: The response from the Bedrock API. """ - return self.bedrock_runtime_client.converse(**body) + body = self.format_messages(messages) + try: + return self.get_client().converse(**body) + except Exception as e: + logger.error(f"Unexpected error calling Bedrock API: {str(e)}") + raise - def invoke_stream(self, body: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + def invoke_stream(self, messages: List[Message]) -> Iterator[Dict[str, Any]]: """ Invoke the Bedrock API with streaming. Args: - body (Dict[str, Any]): The request body. + messages (List[Message]): The messages to include in the request. Returns: Iterator[Dict[str, Any]]: The streamed response. 
""" - response = self.bedrock_runtime_client.converse_stream(**body) + body = self.format_messages(messages) + response = self.get_client().converse_stream(**body) stream = response.get("stream") if stream: for event in stream: yield event - def create_assistant_message(self, request_body: Dict[str, Any]) -> Message: - raise NotImplementedError("Please use a subclass of AwsBedrock") - - def get_request_body(self, messages: List[Message]) -> Dict[str, Any]: + @abstractmethod + def format_messages(self, messages: List[Message]) -> Dict[str, Any]: raise NotImplementedError("Please use a subclass of AwsBedrock") - def parse_response_message(self, response: Dict[str, Any]) -> Dict[str, Any]: + @abstractmethod + def parse_model_provider_response(self, response: Dict[str, Any]) -> ModelProviderResponse: raise NotImplementedError("Please use a subclass of AwsBedrock") - def _create_tool_calls( - self, stop_reason: str, parsed_response: Dict[str, Any] - ) -> Tuple[List[str], List[Dict[str, Any]]]: - tool_ids: List[str] = [] - tool_calls: List[Dict[str, Any]] = [] - - if stop_reason == "tool_use": - tool_requests = parsed_response.get("tool_requests") - if tool_requests: - for tool in tool_requests: - if "toolUse" in tool: - tool_use = tool["toolUse"] - tool_id = tool_use["toolUseId"] - tool_name = tool_use["name"] - tool_args = tool_use["input"] - - tool_ids.append(tool_id) - tool_calls.append( - { - "type": "function", - "function": { - "name": tool_name, - "arguments": json.dumps(tool_args), - }, - } - ) - - return tool_ids, tool_calls - - def _handle_tool_calls( - self, assistant_message: Message, messages: List[Message], model_response: ModelResponse, tool_ids - ) -> Optional[ModelResponse]: - """ - Handle tool calls in the assistant message. - - Args: - assistant_message (Message): The assistant message. - messages (List[Message]): The list of messages. - model_response (ModelResponse): The model response. - - Returns: - Optional[ModelResponse]: The model response after handling tool calls. 
- """ - # -*- Parse and run function call - if assistant_message.tool_calls is not None: - if model_response.tool_calls is None: - model_response.tool_calls = [] - - # Remove the tool call from the response content - model_response.content = "" - tool_role: str = "tool" - function_calls_to_run: List[Any] = [] - function_call_results: List[Message] = [] - for tool_call in assistant_message.tool_calls: - _tool_call_id = tool_call.get("id") - _function_call = get_function_call_for_tool_call(tool_call, self._functions) - if _function_call is None: - messages.append( - Message( - role="tool", - tool_call_id=_tool_call_id, - content="Could not find function to call.", - ) - ) - continue - if _function_call.error is not None: - messages.append( - Message( - role="tool", - tool_call_id=_tool_call_id, - content=_function_call.error, - ) - ) - continue - function_calls_to_run.append(_function_call) - - if self.show_tool_calls: - model_response.content += "\nRunning:" - for _f in function_calls_to_run: - model_response.content += f"\n - {_f.get_call_str()}" - model_response.content += "\n\n" - - for function_call_response in self.run_function_calls( - function_calls=function_calls_to_run, function_call_results=function_call_results, tool_role=tool_role - ): - if ( - function_call_response.event == ModelResponseEvent.tool_call_completed.value - and function_call_response.tool_calls is not None - ): - model_response.tool_calls.extend(function_call_response.tool_calls) - - if len(function_call_results) > 0: - fc_responses: List = [] - - for _fc_message_index, _fc_message in enumerate(function_call_results): - tool_result = { - "toolUseId": tool_ids[_fc_message_index], - "content": [{"json": json.dumps(_fc_message.content)}], - } - tool_result_message = {"role": "user", "content": json.dumps([{"toolResult": tool_result}])} - fc_responses.append(tool_result_message) - - logger.debug(f"Tool call responses: {fc_responses}") - messages.append(Message(role="user", content=json.dumps(fc_responses))) - - return model_response - return None - - def _update_metrics(self, assistant_message, parsed_response, response_timer): - """ - Update usage metrics in assistant_message and self.metrics based on the parsed_response. - - Args: - assistant_message: The assistant's message object where individual metrics are stored. - parsed_response: The parsed response containing usage metrics. - response_timer: Timer object that has the elapsed time of the response. 
- """ - # Add response time to metrics - assistant_message.metrics["time"] = response_timer.elapsed - if "response_times" not in self.metrics: - self.metrics["response_times"] = [] - self.metrics["response_times"].append(response_timer.elapsed) - - # Add token usage to metrics - usage = parsed_response.get("usage", {}) - prompt_tokens = usage.get("inputTokens") - completion_tokens = usage.get("outputTokens") - total_tokens = usage.get("totalTokens") - - if prompt_tokens is not None: - assistant_message.metrics["prompt_tokens"] = prompt_tokens - self.metrics["prompt_tokens"] = self.metrics.get("prompt_tokens", 0) + prompt_tokens - - if completion_tokens is not None: - assistant_message.metrics["completion_tokens"] = completion_tokens - self.metrics["completion_tokens"] = self.metrics.get("completion_tokens", 0) + completion_tokens - - if total_tokens is not None: - assistant_message.metrics["total_tokens"] = total_tokens - self.metrics["total_tokens"] = self.metrics.get("total_tokens", 0) + total_tokens - - def response(self, messages: List[Message]) -> ModelResponse: - """ - Generate a response from the Bedrock API. - - Args: - messages (List[Message]): The messages to include in the request. - - Returns: - ModelResponse: The response from the Bedrock API. - """ - logger.debug("---------- Bedrock Response Start ----------") - self._log_messages(messages) - model_response = ModelResponse() - - # Invoke the Bedrock API - response_timer = Timer() - response_timer.start() - body = self.get_request_body(messages) - response: Dict[str, Any] = self.invoke(body=body) - response_timer.stop() - - # Parse response - parsed_response = self.parse_response_message(response) - logger.debug(f"Parsed response: {parsed_response}") - stop_reason = parsed_response["stop_reason"] - - # Create assistant message - assistant_message = self.create_assistant_message(parsed_response) - - # Update usage metrics using the new function - self._update_metrics(assistant_message, parsed_response, response_timer) - - # Add assistant message to messages - messages.append(assistant_message) - assistant_message.log() - - # Create tool calls if needed - tool_ids, tool_calls = self._create_tool_calls(stop_reason, parsed_response) - - # Handle tool calls - if stop_reason == "tool_use" and tool_calls: - assistant_message.content = parsed_response["tool_requests"][0]["text"] - assistant_message.tool_calls = tool_calls - - # Run tool calls - if self._handle_tool_calls(assistant_message, messages, model_response, tool_ids): - response_after_tool_calls = self.response(messages=messages) - if response_after_tool_calls.content is not None: - if model_response.content is None: - model_response.content = "" - model_response.content += response_after_tool_calls.content - return model_response - - # Add assistant message content to model response - if assistant_message.content is not None: - model_response.content = assistant_message.get_content_string() - - logger.debug("---------- AWS Response End ----------") - return model_response - - def _handle_stream_tool_calls(self, assistant_message: Message, messages: List[Message], tool_ids: List[str]): - """ - Handle tool calls in the assistant message. - - Args: - assistant_message (Message): The assistant message. - messages (List[Message]): The list of messages. - tool_ids (List[str]): The list of tool IDs. 
- """ - tool_role: str = "tool" - function_calls_to_run: List[Any] = [] - function_call_results: List[Message] = [] - for tool_call in assistant_message.tool_calls or []: - _tool_call_id = tool_call.get("id") - _function_call = get_function_call_for_tool_call(tool_call, self._functions) - if _function_call is None: - messages.append( - Message( - role="tool", - tool_call_id=_tool_call_id, - content="Could not find function to call.", - ) - ) - continue - if _function_call.error is not None: - messages.append( - Message( - role="tool", - tool_call_id=_tool_call_id, - content=_function_call.error, - ) - ) - continue - function_calls_to_run.append(_function_call) - - if self.show_tool_calls: - yield ModelResponse(content="\nRunning:") - for _f in function_calls_to_run: - yield ModelResponse(content=f"\n - {_f.get_call_str()}") - yield ModelResponse(content="\n\n") - - for _ in self.run_function_calls( - function_calls=function_calls_to_run, function_call_results=function_call_results, tool_role=tool_role - ): - pass - - if len(function_call_results) > 0: - fc_responses: List = [] - - for _fc_message_index, _fc_message in enumerate(function_call_results): - tool_result = { - "toolUseId": tool_ids[_fc_message_index], - "content": [{"json": json.dumps(_fc_message.content)}], - } - tool_result_message = {"role": "user", "content": json.dumps([{"toolResult": tool_result}])} - fc_responses.append(tool_result_message) - - logger.debug(f"Tool call responses: {fc_responses}") - messages.append(Message(role="user", content=json.dumps(fc_responses))) - - def _update_stream_metrics(self, stream_data: StreamData, assistant_message: Message): - """ - Update the metrics for the streaming response. - - Args: - stream_data (StreamData): The streaming data - assistant_message (Message): The assistant message. 
- """ - assistant_message.metrics["time"] = stream_data.response_timer.elapsed - if stream_data.time_to_first_token is not None: - assistant_message.metrics["time_to_first_token"] = stream_data.time_to_first_token - - if "response_times" not in self.metrics: - self.metrics["response_times"] = [] - self.metrics["response_times"].append(stream_data.response_timer.elapsed) - if stream_data.time_to_first_token is not None: - if "time_to_first_token" not in self.metrics: - self.metrics["time_to_first_token"] = [] - self.metrics["time_to_first_token"].append(stream_data.time_to_first_token) - if stream_data.completion_tokens > 0: - if "tokens_per_second" not in self.metrics: - self.metrics["tokens_per_second"] = [] - self.metrics["tokens_per_second"].append( - f"{stream_data.completion_tokens / stream_data.response_timer.elapsed:.4f}" - ) - - assistant_message.metrics["prompt_tokens"] = stream_data.response_prompt_tokens - assistant_message.metrics["input_tokens"] = stream_data.response_prompt_tokens - self.metrics["prompt_tokens"] = self.metrics.get("prompt_tokens", 0) + stream_data.response_prompt_tokens - self.metrics["input_tokens"] = self.metrics.get("input_tokens", 0) + stream_data.response_prompt_tokens - - assistant_message.metrics["completion_tokens"] = stream_data.response_completion_tokens - assistant_message.metrics["output_tokens"] = stream_data.response_completion_tokens - self.metrics["completion_tokens"] = ( - self.metrics.get("completion_tokens", 0) + stream_data.response_completion_tokens - ) - self.metrics["output_tokens"] = self.metrics.get("output_tokens", 0) + stream_data.response_completion_tokens - - assistant_message.metrics["total_tokens"] = stream_data.response_total_tokens - self.metrics["total_tokens"] = self.metrics.get("total_tokens", 0) + stream_data.response_total_tokens - - def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]: - """ - Stream the response from the Bedrock API. - - Args: - messages (List[Message]): The messages to include in the request. - - Returns: - Iterator[str]: The streamed response. 
- """ - logger.debug("---------- Bedrock Response Start ----------") - self._log_messages(messages) - - stream_data: StreamData = StreamData() - stream_data.response_timer.start() - - tool_use: Dict[str, Any] = {} - tool_ids: List[str] = [] - tool_calls: List[Dict[str, Any]] = [] - stop_reason: Optional[str] = None - content: List[Dict[str, Any]] = [] - - request_body = self.get_request_body(messages) - response = self.invoke_stream(body=request_body) - - # Process the streaming response - for chunk in response: - if "contentBlockStart" in chunk: - tool = chunk["contentBlockStart"]["start"].get("toolUse") - if tool: - tool_use["toolUseId"] = tool["toolUseId"] - tool_use["name"] = tool["name"] - - elif "contentBlockDelta" in chunk: - delta = chunk["contentBlockDelta"]["delta"] - if "toolUse" in delta: - if "input" not in tool_use: - tool_use["input"] = "" - tool_use["input"] += delta["toolUse"]["input"] - elif "text" in delta: - stream_data.response_content += delta["text"] - stream_data.completion_tokens += 1 - if stream_data.completion_tokens == 1: - stream_data.time_to_first_token = stream_data.response_timer.elapsed - logger.debug(f"Time to first token: {stream_data.time_to_first_token:.4f}s") - yield ModelResponse(content=delta["text"]) # Yield text content as it's received - - elif "contentBlockStop" in chunk: - if "input" in tool_use: - # Finish collecting tool use input - try: - tool_use["input"] = json.loads(tool_use["input"]) - except json.JSONDecodeError as e: - logger.error(f"Failed to parse tool input as JSON: {e}") - tool_use["input"] = {} - content.append({"toolUse": tool_use}) - tool_ids.append(tool_use["toolUseId"]) - # Prepare the tool call - tool_call = { - "type": "function", - "function": { - "name": tool_use["name"], - "arguments": json.dumps(tool_use["input"]), - }, - } - tool_calls.append(tool_call) - tool_use = {} - else: - # Finish collecting text content - content.append({"text": stream_data.response_content}) - - elif "messageStop" in chunk: - stop_reason = chunk["messageStop"]["stopReason"] - logger.debug(f"Stop reason: {stop_reason}") - - elif "metadata" in chunk: - metadata = chunk["metadata"] - if "usage" in metadata: - stream_data.response_prompt_tokens = metadata["usage"]["inputTokens"] - stream_data.response_total_tokens = metadata["usage"]["totalTokens"] - stream_data.completion_tokens = metadata["usage"]["outputTokens"] - - stream_data.response_timer.stop() - - # Create assistant message - if stream_data.response_content != "": - assistant_message = Message(role="assistant", content=stream_data.response_content, tool_calls=tool_calls) - - if stream_data.completion_tokens > 0: - logger.debug( - f"Time per output token: {stream_data.response_timer.elapsed / stream_data.completion_tokens:.4f}s" - ) - logger.debug( - f"Throughput: {stream_data.completion_tokens / stream_data.response_timer.elapsed:.4f} tokens/s" - ) - - # Update metrics - self._update_stream_metrics(stream_data, assistant_message) - - # Add assistant message to messages - messages.append(assistant_message) - assistant_message.log() - - # Handle tool calls if any - if tool_calls: - yield from self._handle_stream_tool_calls(assistant_message, messages, tool_ids) - yield from self.response_stream(messages=messages) - - logger.debug("---------- Bedrock Response End ----------") - async def ainvoke(self, *args, **kwargs) -> Any: raise NotImplementedError(f"Async not supported on {self.name}.") async def ainvoke_stream(self, *args, **kwargs) -> Any: raise NotImplementedError(f"Async not supported 
on {self.name}.") - async def aresponse(self, messages: List[Message]) -> ModelResponse: - raise NotImplementedError(f"Async not supported on {self.name}.") - - async def aresponse_stream(self, messages: List[Message]) -> ModelResponse: - raise NotImplementedError(f"Async not supported on {self.name}.") + def parse_model_provider_response_stream( + self, response: Any + ) -> Iterator[ModelProviderResponse]: + pass \ No newline at end of file diff --git a/libs/agno/agno/models/aws/claude.py b/libs/agno/agno/models/aws/claude.py index 86677396c..a31de6f2d 100644 --- a/libs/agno/agno/models/aws/claude.py +++ b/libs/agno/agno/models/aws/claude.py @@ -1,9 +1,18 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional +import json +from typing import Any, Dict, List, Optional, Iterator from agno.models.aws.bedrock import AwsBedrock +from agno.models.base import MessageData from agno.models.message import Message +from agno.models.response import ModelProviderResponse, ModelResponse +from agno.utils.log import logger +@dataclass +class BedrockResponseUsage: + input_tokens: int = 0 + output_tokens: int = 0 + total_tokens: int = 0 @dataclass class Claude(AwsBedrock): @@ -67,13 +76,10 @@ def api_kwargs(self) -> Dict[str, Any]: _request_params.update(self.request_params) return _request_params - def get_tools(self) -> Optional[Dict[str, Any]]: + def _format_tools(self) -> Optional[Dict[str, Any]]: """ Refactors the tools in a format accepted by the Bedrock API. """ - if not self._functions: - return None - tools = [] for f_name, function in self._functions.items(): properties = {} @@ -106,9 +112,9 @@ def get_tools(self) -> Optional[Dict[str, Any]]: return {"tools": tools} - def get_request_body(self, messages: List[Message]) -> Dict[str, Any]: + def format_messages(self, messages: List[Message]) -> Dict[str, Any]: """ - Get the request body for the Bedrock API. + Create the request body for the Bedrock API. Args: messages (List[Message]): The messages to include in the request. @@ -145,13 +151,13 @@ def get_request_body(self, messages: List[Message]) -> Dict[str, Any]: if inference_config: request_body["inferenceConfig"] = inference_config # type: ignore - if self.tools: - tools = self.get_tools() + if self._functions: + tools = self._format_tools() request_body["toolConfig"] = tools # type: ignore return request_body - def parse_response_message(self, response: Dict[str, Any]) -> Dict[str, Any]: + def parse_model_provider_response(self, response: Dict[str, Any]) -> ModelProviderResponse: """ Parse the response from the Bedrock API. @@ -159,70 +165,168 @@ def parse_response_message(self, response: Dict[str, Any]) -> Dict[str, Any]: response (Dict[str, Any]): The response from the Bedrock API. Returns: - Dict[str, Any]: The parsed response. + ModelProviderResponse: The parsed response. 
""" - res = {} + provider_response = ModelProviderResponse() + + # Extract message from output if "output" in response and "message" in response["output"]: message = response["output"]["message"] - role = message.get("role") - content = message.get("content", []) - - # Extract text content if it's a list of dictionaries - if isinstance(content, list) and content and isinstance(content[0], dict): - content = [item.get("text", "") for item in content if "text" in item] - content = "\n".join(content) # Join multiple text items if present - - res = { - "content": content, - "usage": { - "inputTokens": response.get("usage", {}).get("inputTokens"), - "outputTokens": response.get("usage", {}).get("outputTokens"), - "totalTokens": response.get("usage", {}).get("totalTokens"), - }, - "metrics": {"latencyMs": response.get("metrics", {}).get("latencyMs")}, - "role": role, - } + # Add role + if "role" in message: + provider_response.role = message["role"] + + # Extract and join text content from content list + if "content" in message: + content = message["content"] + if isinstance(content, list) and content: + text_content = [item.get("text", "") for item in content if "text" in item] + provider_response.content = "\n".join(text_content) + + # Add usage metrics + if "usage" in response: + # This ensures that the usage can be parsed upstream + provider_response.response_usage = BedrockResponseUsage( + input_tokens=response.get("usage", {}).get("inputTokens", 0), + output_tokens=response.get("usage", {}).get("outputTokens", 0), + total_tokens=response.get("usage", {}).get("totalTokens", 0), + ) + + # If we have a stop reason, it works a bit differently stop_reason = None if "stopReason" in response: stop_reason = response["stopReason"] - res["stop_reason"] = stop_reason if stop_reason else None - res["tool_requests"] = None - - if stop_reason == "tool_use": + if stop_reason and stop_reason == "tool_use": tool_requests = response["output"]["message"]["content"] - res["tool_requests"] = tool_requests - - return res - def create_assistant_message(self, parsed_response: Dict[str, Any]) -> Message: + tool_ids = [] + tool_calls = [] + if tool_requests: + for tool in tool_requests: + if "toolUse" in tool: + tool_use = tool["toolUse"] + tool_id = tool_use["toolUseId"] + tool_name = tool_use["name"] + tool_args = tool_use["input"] + + tool_ids.append(tool_id) + tool_calls.append( + { + "type": "function", + "function": { + "name": tool_name, + "arguments": json.dumps(tool_args), + }, + } + ) + if tool_calls: + provider_response.tool_calls = tool_calls + if tool_requests: + provider_response.content = tool_requests[0]["text"] + provider_response.extra["tool_ids"] = tool_ids + + return provider_response + + # Override the base class method + def format_function_call_results(self, messages: List[Message], function_call_results: List[Message], tool_ids: List[str]) -> None: """ - Create an assistant message from the parsed response. - - Args: - parsed_response (Dict[str, Any]): The parsed response from the Bedrock API. - - Returns: - Message: The assistant message. + Format function call results. 
""" + if len(function_call_results) > 0: + fc_responses: List = [] - return Message( - role=parsed_response["role"], - content=parsed_response["content"], - metrics=parsed_response["metrics"], - ) + for _fc_message_index, _fc_message in enumerate(function_call_results): + tool_result = { + "toolUseId": tool_ids[_fc_message_index], + "content": [{"json": json.dumps(_fc_message.content)}], + } + tool_result_message = {"role": "user", "content": json.dumps([{"toolResult": tool_result}])} + fc_responses.append(tool_result_message) - def parse_response_delta(self, response: Dict[str, Any]) -> Optional[str]: - """ - Parse the response delta from the Bedrock API. + logger.debug(f"Tool call responses: {fc_responses}") + messages.append(Message(role="user", content=json.dumps(fc_responses))) - Args: - response (Dict[str, Any]): The response from the Bedrock API. - Returns: - Optional[str]: The response delta. + # Override the base class method + def process_response_stream(self, messages: List[Message], assistant_message: Message, stream_data: MessageData) -> Iterator[ModelResponse]: + """ + Process the streaming response from the Bedrock API. """ - if "delta" in response: - return response.get("delta", {}).get("text") - return response.get("completion") + + tool_use: Dict[str, Any] = {} + tool_ids: List[str] = [] + tool_calls: List[Dict[str, Any]] = [] + content: List[Dict[str, Any]] = [] + + # Process the streaming response + for chunk in self.invoke_stream(messages=messages): + if "contentBlockStart" in chunk: + tool = chunk["contentBlockStart"]["start"].get("toolUse") + if tool: + tool_use["toolUseId"] = tool["toolUseId"] + tool_use["name"] = tool["name"] + + elif "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "toolUse" in delta: + if "input" not in tool_use: + tool_use["input"] = "" + tool_use["input"] += delta["toolUse"]["input"] + elif "text" in delta: + # Update metrics + assistant_message.metrics.completion_tokens += 1 + if not assistant_message.metrics.time_to_first_token: + assistant_message.metrics.set_time_to_first_token() + + # Update provider response content + stream_data.response_content += delta["text"] + yield ModelResponse(content=delta["text"]) # Yield text content as it's received + + elif "contentBlockStop" in chunk: + if "input" in tool_use: + # Finish collecting tool use input + try: + tool_use["input"] = json.loads(tool_use["input"]) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse tool input as JSON: {e}") + tool_use["input"] = {} + content.append({"toolUse": tool_use}) + tool_ids.append(tool_use["toolUseId"]) + # Prepare the tool call + tool_call = { + "id": tool_use["toolUseId"], + "type": "function", + "function": { + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + } + tool_calls.append(tool_call) + # Reset tool use + tool_use = {} + else: + # Finish collecting text content + content.append({"text": stream_data.response_content}) + + elif "metadata" in chunk: + metadata = chunk["metadata"] + if "usage" in metadata: + response_usage = BedrockResponseUsage( + input_tokens=metadata["usage"]["inputTokens"], + output_tokens=metadata["usage"]["outputTokens"], + total_tokens=metadata["usage"]["totalTokens"], + ) + + # Update metrics + self.add_usage_metrics_to_assistant_message( + assistant_message=assistant_message, + response_usage=response_usage + ) + + if tool_ids: + stream_data.extra["tool_ids"] = tool_ids + + if tool_calls: + stream_data.response_tool_calls = tool_calls diff --git 
a/libs/agno/agno/models/base.py b/libs/agno/agno/models/base.py index b1d629f14..8f65211e2 100644 --- a/libs/agno/agno/models/base.py +++ b/libs/agno/agno/models/base.py @@ -596,9 +596,6 @@ async def ahandle_tool_calls( self.format_function_call_results(messages=messages, function_call_results=function_call_results, **kwargs) - if len(function_call_results) > 0: - messages.extend(function_call_results) - return model_response return None @@ -802,145 +799,6 @@ async def ahandle_post_tool_call_messages_stream(self, messages: List[Message]) async for model_response in self.aresponse_stream(messages=messages): # type: ignore yield model_response - def _process_image_url(self, image_url: str) -> Dict[str, Any]: - """Process image (base64 or URL).""" - - if image_url.startswith("data:image") or image_url.startswith(("http://", "https://")): - return {"type": "image_url", "image_url": {"url": image_url}} - else: - raise ValueError("Image URL must start with 'data:image' or 'http(s)://'.") - - def _process_image_path(self, image_path: Union[Path, str]) -> Dict[str, Any]: - """Process image ( file path).""" - # Process local file image - import base64 - import mimetypes - - path = image_path if isinstance(image_path, Path) else Path(image_path) - if not path.exists(): - raise FileNotFoundError(f"Image file not found: {image_path}") - - mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg" - with open(path, "rb") as image_file: - base64_image = base64.b64encode(image_file.read()).decode("utf-8") - image_url = f"data:{mime_type};base64,{base64_image}" - return {"type": "image_url", "image_url": {"url": image_url}} - - def _process_bytes_image(self, image: bytes) -> Dict[str, Any]: - """Process bytes image data.""" - import base64 - - base64_image = base64.b64encode(image).decode("utf-8") - image_url = f"data:image/jpeg;base64,{base64_image}" - return {"type": "image_url", "image_url": {"url": image_url}} - - def _process_image(self, image: Image) -> Optional[Dict[str, Any]]: - """Process an image based on the format.""" - - if image.url is not None: - image_payload = self._process_image_url(image.url) - - elif image.filepath is not None: - image_payload = self._process_image_path(image.filepath) - - elif image.content is not None: - image_payload = self._process_bytes_image(image.content) - - else: - logger.warning(f"Unsupported image type: {type(image)}") - return None - - if image.detail: - image_payload["image_url"]["detail"] = image.detail - - return image_payload - - def add_images_to_message(self, message: Message, images: Sequence[Image]) -> Message: - """ - Add images to a message for the model. By default, we use the OpenAI image format but other Models - can override this method to use a different image format. 
- - Args: - message: The message for the Model - images: Sequence of images in various formats: - - str: base64 encoded image, URL, or file path - - Dict: pre-formatted image data - - bytes: raw image data - - Returns: - Message content with images added in the format expected by the model - """ - # If no images are provided, return the message as is - if len(images) == 0: - return message - - # Ignore non-string message content - # because we assume that the images/audio are already added to the message - if not isinstance(message.content, str): - return message - - # Create a default message content with text - message_content_with_image: List[Dict[str, Any]] = [{"type": "text", "text": message.content}] - - # Add images to the message content - for image in images: - try: - image_data = self._process_image(image) - if image_data: - message_content_with_image.append(image_data) - except Exception as e: - logger.error(f"Failed to process image: {str(e)}") - continue - - # Update the message content with the images - message.content = message_content_with_image - return message - - @staticmethod - def add_audio_to_message(message: Message, audio: Sequence[Audio]) -> Message: - """ - Add audio to a message for the model. By default, we use the OpenAI audio format but other Models - can override this method to use a different audio format. - - Args: - message: The message for the Model - audio: Pre-formatted audio data like { - "content": encoded_string, - "format": "wav" - } - - Returns: - Message content with audio added in the format expected by the model - """ - if len(audio) == 0: - return message - - # Create a default message content with text - message_content_with_audio: List[Dict[str, Any]] = [{"type": "text", "text": message.content}] - - for audio_snippet in audio: - # This means the audio is raw data - if audio_snippet.content: - import base64 - - encoded_string = base64.b64encode(audio_snippet.content).decode("utf-8") - - # Create a message with audio - message_content_with_audio.append( - { - "type": "input_audio", - "input_audio": { - "data": encoded_string, - "format": audio_snippet.format, - }, - }, - ) - - # Update the message content with the audio - message.content = message_content_with_audio - message.audio = None # The message should not have an audio component after this - - return message def get_system_message_for_model(self) -> Optional[str]: return self.system_prompt @@ -1200,25 +1058,10 @@ async def aresponse(self, messages: List[Message]) -> ModelResponse: logger.debug(f"---------- {self.get_provider()} Async Response End ----------") return model_response - def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]: + def process_response_stream(self, messages: List[Message], assistant_message: Message, stream_data: MessageData) -> Iterator[ModelResponse]: """ - Generate a streaming response from the model. - - Args: - messages: List of messages in the conversation - - Returns: - Iterator[ModelResponse]: Iterator of model responses + Process a streaming response from the model. 
""" - logger.debug(f"---------- {self.get_provider()} Response Stream Start ----------") - self._log_messages(messages) - stream_data: MessageData = MessageData() - - # Create assistant message - assistant_message = Message(role=self.assistant_message_role) - - # Generate response - assistant_message.metrics.start_timer() for response in self.invoke_stream(messages=messages): # Parse provider response for provider_response in self.parse_model_provider_response_stream(response): @@ -1254,6 +1097,26 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]: response_usage=provider_response.response_usage ) + def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]: + """ + Generate a streaming response from the model. + + Args: + messages: List of messages in the conversation + + Returns: + Iterator[ModelResponse]: Iterator of model responses + """ + logger.debug(f"---------- {self.get_provider()} Response Stream Start ----------") + self._log_messages(messages) + stream_data: MessageData = MessageData() + + # Create assistant message + assistant_message = Message(role=self.assistant_message_role) + + # Generate response + assistant_message.metrics.start_timer() + yield from self.process_response_stream(messages=messages, assistant_message=assistant_message, stream_data=stream_data) assistant_message.metrics.stop_timer() # Add response content and audio to assistant message @@ -1286,25 +1149,10 @@ def response_stream(self, messages: List[Message]) -> Iterator[ModelResponse]: logger.debug(f"---------- {self.get_provider()} Response Stream End ----------") - async def aresponse_stream(self, messages: List[Message]) -> Any: + async def aprocess_response_stream(self, messages: List[Message], assistant_message: Message, stream_data: MessageData) -> AsyncIterator[ModelResponse]: """ - Generate an asynchronous streaming response from the model. - - Args: - messages: List of messages in the conversation - - Returns: - Any: Async iterator of model responses + Process a streaming response from the model. """ - logger.debug(f"---------- {self.get_provider()} Async Response Stream Start ----------") - self._log_messages(messages) - stream_data = MessageData() - - # Create assistant message - assistant_message = Message(role=self.assistant_message_role) - - # Generate response - assistant_message.metrics.start_timer() async for response in await self.ainvoke_stream(messages=messages): # Parse provider response for provider_response in self.parse_model_provider_response_stream(response): @@ -1340,6 +1188,27 @@ async def aresponse_stream(self, messages: List[Message]) -> Any: response_usage=provider_response.response_usage ) + async def aresponse_stream(self, messages: List[Message]) -> Any: + """ + Generate an asynchronous streaming response from the model. 
+ + Args: + messages: List of messages in the conversation + + Returns: + Any: Async iterator of model responses + """ + logger.debug(f"---------- {self.get_provider()} Async Response Stream Start ----------") + self._log_messages(messages) + stream_data = MessageData() + + # Create assistant message + assistant_message = Message(role=self.assistant_message_role) + + # Generate response + assistant_message.metrics.start_timer() + async for response in self.aprocess_response_stream(messages=messages, assistant_message=assistant_message, stream_data=stream_data): + yield response assistant_message.metrics.stop_timer() # Add response content and audio to assistant message @@ -1351,9 +1220,9 @@ async def aresponse_stream(self, messages: List[Message]) -> Any: # Add tool calls to assistant message if stream_data.response_tool_calls is not None and len(stream_data.response_tool_calls) > 0: - _tool_calls = self.build_tool_calls(stream_data.response_tool_calls) - if len(_tool_calls) > 0: - assistant_message.tool_calls = _tool_calls + parsed_tool_calls = self.parse_tool_calls(stream_data.response_tool_calls) + if len(parsed_tool_calls) > 0: + assistant_message.tool_calls = parsed_tool_calls # Add assistant message to messages messages.append(assistant_message) diff --git a/libs/agno/agno/models/groq/groq.py b/libs/agno/agno/models/groq/groq.py index 61731587f..98b19e619 100644 --- a/libs/agno/agno/models/groq/groq.py +++ b/libs/agno/agno/models/groq/groq.py @@ -8,6 +8,7 @@ from agno.models.message import Message from agno.models.response import ModelProviderResponse from agno.utils.log import logger +from agno.utils.openai import add_images_to_message try: from groq import AsyncGroq as AsyncGroqClient @@ -18,6 +19,23 @@ raise ImportError("`groq` not installed. Please install using `pip install groq`") +def format_message(message: Message) -> Dict[str, Any]: + """ + Format a message into the format expected by Groq. + + Args: + message (Message): The message to format. + + Returns: + Dict[str, Any]: The formatted message. + """ + if message.role == "user": + if message.images is not None: + message = add_images_to_message(message=message, images=message.images) + + return message.to_dict() + + @dataclass class Groq(Model): """ @@ -193,22 +211,6 @@ def to_dict(self) -> Dict[str, Any]: cleaned_dict = {k: v for k, v in model_dict.items() if v is not None} return cleaned_dict - def format_message(self, message: Message) -> Dict[str, Any]: - """ - Format a message into the format expected by Groq. - - Args: - message (Message): The message to format. - - Returns: - Dict[str, Any]: The formatted message. - """ - if message.role == "user": - if message.images is not None: - message = self.add_images_to_message(message=message, images=message.images) - - return message.to_dict() - def invoke(self, messages: List[Message]) -> ChatCompletion: """ Send a chat completion request to the Groq API. 
@@ -221,7 +223,7 @@ def invoke(self, messages: List[Message]) -> ChatCompletion: """ return self.get_client().chat.completions.create( model=self.id, - messages=[self.format_message(m) for m in messages], # type: ignore + messages=[format_message(m) for m in messages], # type: ignore **self.request_kwargs, ) @@ -237,7 +239,7 @@ async def ainvoke(self, messages: List[Message]) -> ChatCompletion: """ return await self.get_async_client().chat.completions.create( model=self.id, - messages=[self.format_message(m) for m in messages], # type: ignore + messages=[format_message(m) for m in messages], # type: ignore **self.request_kwargs, ) @@ -253,7 +255,7 @@ def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk """ return self.get_client().chat.completions.create( model=self.id, - messages=[self.format_message(m) for m in messages], # type: ignore + messages=[format_message(m) for m in messages], # type: ignore stream=True, **self.request_kwargs, ) @@ -270,7 +272,7 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any: """ return await self.get_async_client().chat.completions.create( model=self.id, - messages=[self.format_message(m) for m in messages], # type: ignore + messages=[format_message(m) for m in messages], # type: ignore stream=True, **self.request_kwargs, ) @@ -316,7 +318,7 @@ def parse_tool_calls(tool_calls_data: List[ChoiceDeltaToolCall]) -> List[Dict[st if _tool_call_type: tool_call_entry["type"] = _tool_call_type return tool_calls - + def parse_model_provider_response( self, response: ChatCompletion ) -> ModelProviderResponse: diff --git a/libs/agno/agno/models/openai/chat.py b/libs/agno/agno/models/openai/chat.py index 0711d8afa..e469c6615 100644 --- a/libs/agno/agno/models/openai/chat.py +++ b/libs/agno/agno/models/openai/chat.py @@ -10,6 +10,7 @@ from agno.models.message import Message from agno.models.response import ModelProviderResponse from agno.utils.log import logger +from agno.utils.openai import add_images_to_message, add_audio_to_message try: from openai import AsyncOpenAI as AsyncOpenAIClient @@ -228,6 +229,7 @@ def to_dict(self) -> Dict[str, Any]: cleaned_dict = {k: v for k, v in model_dict.items() if v is not None} return cleaned_dict + def format_message(self, message: Message) -> Dict[str, Any]: """ Format a message into the format expected by OpenAI. @@ -240,10 +242,10 @@ def format_message(self, message: Message) -> Dict[str, Any]: """ if message.role == "user": if message.images is not None: - message = self.add_images_to_message(message=message, images=message.images) + message = add_images_to_message(message=message, images=message.images) if message.audio is not None: - message = self.add_audio_to_message(message=message, audio=message.audio) + message = add_audio_to_message(message=message, audio=message.audio) if message.videos is not None: logger.warning("Video input is currently unsupported.") diff --git a/libs/agno/agno/utils/openai.py b/libs/agno/agno/utils/openai.py new file mode 100644 index 000000000..c1f8e370d --- /dev/null +++ b/libs/agno/agno/utils/openai.py @@ -0,0 +1,151 @@ +from pathlib import Path +from typing import Sequence, Any, Dict, Optional, List, Union + +from agno.media import Image, Audio +from agno.models.message import Message +from agno.utils.log import logger + + +def add_audio_to_message(message: Message, audio: Sequence[Audio]) -> Message: + """ + Add audio to a message for the model. 
By default, we use the OpenAI audio format but other Models + can override this method to use a different audio format. + + Args: + message: The message for the Model + audio: Pre-formatted audio data like { + "content": encoded_string, + "format": "wav" + } + + Returns: + Message content with audio added in the format expected by the model + """ + if len(audio) == 0: + return message + + # Create a default message content with text + message_content_with_audio: List[Dict[str, Any]] = [{"type": "text", "text": message.content}] + + for audio_snippet in audio: + # This means the audio is raw data + if audio_snippet.content: + import base64 + + encoded_string = base64.b64encode(audio_snippet.content).decode("utf-8") + + # Create a message with audio + message_content_with_audio.append( + { + "type": "input_audio", + "input_audio": { + "data": encoded_string, + "format": audio_snippet.format, + }, + }, + ) + + # Update the message content with the audio + message.content = message_content_with_audio + message.audio = None # The message should not have an audio component after this + + return message + + +def _process_bytes_image(image: bytes) -> Dict[str, Any]: + """Process bytes image data.""" + import base64 + + base64_image = base64.b64encode(image).decode("utf-8") + image_url = f"data:image/jpeg;base64,{base64_image}" + return {"type": "image_url", "image_url": {"url": image_url}} + + +def _process_image_path(image_path: Union[Path, str]) -> Dict[str, Any]: + """Process image ( file path).""" + # Process local file image + import base64 + import mimetypes + + path = image_path if isinstance(image_path, Path) else Path(image_path) + if not path.exists(): + raise FileNotFoundError(f"Image file not found: {image_path}") + + mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg" + with open(path, "rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode("utf-8") + image_url = f"data:{mime_type};base64,{base64_image}" + return {"type": "image_url", "image_url": {"url": image_url}} + + +def _process_image_url(image_url: str) -> Dict[str, Any]: + """Process image (base64 or URL).""" + + if image_url.startswith("data:image") or image_url.startswith(("http://", "https://")): + return {"type": "image_url", "image_url": {"url": image_url}} + else: + raise ValueError("Image URL must start with 'data:image' or 'http(s)://'.") + + +def _process_image(image: Image) -> Optional[Dict[str, Any]]: + """Process an image based on the format.""" + + if image.url is not None: + image_payload = _process_image_url(image.url) + + elif image.filepath is not None: + image_payload = _process_image_path(image.filepath) + + elif image.content is not None: + image_payload = _process_bytes_image(image.content) + + else: + logger.warning(f"Unsupported image type: {type(image)}") + return None + + if image.detail: + image_payload["image_url"]["detail"] = image.detail + + return image_payload + + +def add_images_to_message(message: Message, images: Sequence[Image]) -> Message: + """ + Add images to a message for the model. By default, we use the OpenAI image format but other Models + can override this method to use a different image format. 
+ + Args: + message: The message for the Model + images: Sequence of images in various formats: + - str: base64 encoded image, URL, or file path + - Dict: pre-formatted image data + - bytes: raw image data + + Returns: + Message content with images added in the format expected by the model + """ + # If no images are provided, return the message as is + if len(images) == 0: + return message + + # Ignore non-string message content + # because we assume that the images/audio are already added to the message + if not isinstance(message.content, str): + return message + + # Create a default message content with text + message_content_with_image: List[Dict[str, Any]] = [{"type": "text", "text": message.content}] + + # Add images to the message content + for image in images: + try: + image_data = _process_image(image) + if image_data: + message_content_with_image.append(image_data) + except Exception as e: + logger.error(f"Failed to process image: {str(e)}") + continue + + # Update the message content with the images + message.content = message_content_with_image + return message