Mistral vision support #1956

Merged: 12 commits, Feb 3, 2025
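A short orientation, inferred from the diff below rather than from a PR description: the change adds Mistral vision cookbook examples for pixtral-12b-2409 (image bytes input, image comparison, image file input, OCR with structured output, and document transcription), reformats cookbook/tools/exa_tools.py, and moves Claude's message formatting out of the class into the module-level helpers _format_messages and _format_image_for_message. The new cookbook files all follow the same basic pattern; a minimal sketch, using the same agno API as the examples below:

from agno.agent import Agent
from agno.media import Image
from agno.models.mistral.mistral import MistralChat

# Minimal vision call: pass one or more Image objects alongside the prompt.
agent = Agent(model=MistralChat(id="pixtral-12b-2409"), markdown=True)
agent.print_response(
    "Describe this image.",
    images=[
        Image(
            url="https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"
        )
    ],
)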
7 changes: 7 additions & 0 deletions cookbook/models/mistral/.gitignore
@@ -0,0 +1,7 @@
*.jpg
*.jpeg
*.png
*.mp3
*.wav
*.mp4
*.mp3
36 changes: 36 additions & 0 deletions cookbook/models/mistral/image_bytes_input_agent.py
@@ -0,0 +1,36 @@
import requests
from agno.agent import Agent
from agno.media import Image
from agno.models.mistral.mistral import MistralChat
from agno.tools.duckduckgo import DuckDuckGoTools

agent = Agent(
model=MistralChat(id="pixtral-12b-2409"),
show_tool_calls=True,
markdown=True,
)

image_url = (
"https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"
)


def fetch_image_bytes(url: str) -> bytes:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
"Accept": "image/avif,image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8",
"Accept-Language": "en-US,en;q=0.9",
}
response = requests.get(url, headers=headers)
response.raise_for_status()
return response.content


image_bytes_from_url = fetch_image_bytes(image_url)

agent.print_response(
"Tell me about this image.",
images=[
Image(content=image_bytes_from_url),
],
)
21 changes: 21 additions & 0 deletions cookbook/models/mistral/image_compare_agent.py
@@ -0,0 +1,21 @@
from agno.agent import Agent
from agno.media import Image
from agno.models.mistral.mistral import MistralChat

agent = Agent(
model=MistralChat(id="pixtral-12b-2409"),
markdown=True,
)

agent.print_response(
"what are the differences between two images?",
images=[
Image(
url="https://tripfixers.com/wp-content/uploads/2019/11/eiffel-tower-with-snow.jpeg"
),
Image(
url="https://assets.visitorscoverage.com/production/wp-content/uploads/2024/04/AdobeStock_626542468-min-1024x683.jpeg"
),
],
stream=True,
)
25 changes: 25 additions & 0 deletions cookbook/models/mistral/image_file_input_agent.py
@@ -0,0 +1,25 @@
from pathlib import Path

from agno.agent import Agent
from agno.media import Image
from agno.models.mistral.mistral import MistralChat
from agno.tools.duckduckgo import DuckDuckGoTools

agent = Agent(
model=MistralChat(id="pixtral-12b-2409"),
tools=[
DuckDuckGoTools()
], # pixtral-12b-2409 is not so great at tool calls, but it might work.
show_tool_calls=True,
markdown=True,
)

image_path = Path(__file__).parent.joinpath("sample.jpeg")

agent.print_response(
"Tell me about this image and give me the latest news about it from duckduckgo.",
images=[
Image(filepath=image_path),
],
stream=True,
)
32 changes: 32 additions & 0 deletions cookbook/models/mistral/image_ocr_with_structured_output.py
@@ -0,0 +1,32 @@
from typing import List

from agno.agent import Agent
from agno.media import Image
from agno.models.mistral.mistral import MistralChat
from pydantic import BaseModel


class GroceryItem(BaseModel):
item_name: str
price: float


class GroceryListElements(BaseModel):
bill_number: str
items: List[GroceryItem]
total_price: float


agent = Agent(
model=MistralChat(id="pixtral-12b-2409"),
instructions=[
"Extract the text elements described by the user from the picture",
],
response_model=GroceryListElements,
markdown=True,
)

agent.print_response(
"From this restaurant bill, extract the bill number, item names and associated prices, and total price and return it as a string in a Json object",
images=[Image(url="https://i.imghippo.com/files/kgXi81726851246.jpg")],
)
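Because response_model=GroceryListElements is set, the parsed object can also be captured instead of printed. A minimal follow-up sketch, appended to the example above, assuming agno's Agent.run() returns a RunResponse whose content field holds the parsed model:

# Sketch only (not part of this diff): capture the structured output.
# Assumption: Agent.run() returns a RunResponse whose .content is the parsed
# GroceryListElements instance when response_model is set.
run_response = agent.run(
    "From this restaurant bill, extract the bill number, the item names with "
    "their associated prices, and the total price.",
    images=[Image(url="https://i.imghippo.com/files/kgXi81726851246.jpg")],
)
grocery_list = run_response.content
print(grocery_list)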
19 changes: 19 additions & 0 deletions cookbook/models/mistral/image_transcribe_document_agent.py
@@ -0,0 +1,19 @@
"""
This agent transcribes an old written document from an image.
"""

from agno.agent import Agent
from agno.media import Image
from agno.models.mistral.mistral import MistralChat

agent = Agent(
model=MistralChat(id="pixtral-12b-2409"),
markdown=True,
)

agent.print_response(
"Transcribe this document.",
images=[
Image(url="https://ciir.cs.umass.edu/irdemo/hw-demo/page_example.jpg"),
],
)
11 changes: 9 additions & 2 deletions cookbook/tools/exa_tools.py
@@ -2,13 +2,20 @@
from agno.tools.exa import ExaTools

agent = Agent(
tools=[ExaTools(include_domains=["cnbc.com", "reuters.com", "bloomberg.com"], show_results=True)],
tools=[
ExaTools(
include_domains=["cnbc.com", "reuters.com", "bloomberg.com"],
show_results=True,
)
],
show_tool_calls=True,
)

agent.print_response("Search for AAPL news", markdown=True)

agent.print_response("What is the paper at https://arxiv.org/pdf/2307.06435 about?", markdown=True)
agent.print_response(
"What is the paper at https://arxiv.org/pdf/2307.06435 about?", markdown=True
)

agent.print_response(
"Find me similar papers to https://arxiv.org/pdf/2307.06435 and provide a summary of what they contain",
113 changes: 57 additions & 56 deletions libs/agno/agno/models/anthropic/claude.py
@@ -33,7 +33,7 @@ class MessageData:
tool_ids: List[str] = field(default_factory=list)


def format_image_for_message(image: Image) -> Optional[Dict[str, Any]]:
def _format_image_for_message(image: Image) -> Optional[Dict[str, Any]]:
"""
Add an image to a message by converting it to base64 encoded format.
"""
@@ -51,7 +51,7 @@ def format_image_for_message(image: Image) -> Optional[Dict[str, Any]]:
elif image.filepath is not None:
from pathlib import Path

path = Path(image.filepath)
path = Path(image.filepath) if isinstance(image.filepath, str) else image.filepath
if path.exists() and path.is_file():
with open(image.filepath, "rb") as f:
content_bytes = f.read()
@@ -91,6 +91,57 @@ def format_image_for_message(image: Image) -> Optional[Dict[str, Any]]:
return None


def _format_messages(messages: List[Message]) -> Tuple[List[Dict[str, str]], str]:
"""
Process the list of messages and separate them into API messages and system messages.

Args:
messages (List[Message]): The list of messages to process.

Returns:
Tuple[List[Dict[str, str]], str]: A tuple containing the list of API messages and the concatenated system messages.
"""
chat_messages: List[Dict[str, str]] = []
system_messages: List[str] = []

for idx, message in enumerate(messages):
content = message.content or ""
if message.role == "system" or (message.role != "user" and idx in [0, 1]):
if content is not None:
system_messages.append(content) # type: ignore
continue
elif message.role == "user":
if isinstance(content, str):
content = [{"type": "text", "text": content}]

if message.images is not None:
for image in message.images:
image_content = _format_image_for_message(image)
if image_content:
content.append(image_content)

# Handle tool calls from history
elif message.role == "assistant" and isinstance(message.content, str) and message.tool_calls:
if message.content:
content = [TextBlock(text=message.content, type="text")]
else:
content = []
for tool_call in message.tool_calls:
content.append(
ToolUseBlock(
id=tool_call["id"],
input=json.loads(tool_call["function"]["arguments"])
if "arguments" in tool_call["function"]
else {},
name=tool_call["function"]["name"],
type="tool_use",
)
)

chat_messages.append({"role": message.role, "content": content}) # type: ignore
return chat_messages, " ".join(system_messages)


@dataclass
class Claude(Model):
"""
@@ -177,56 +228,6 @@ def request_kwargs(self) -> Dict[str, Any]:
_request_params.update(self.request_params)
return _request_params

def format_messages(self, messages: List[Message]) -> Tuple[List[Dict[str, str]], str]:
"""
Process the list of messages and separate them into API messages and system messages.

Args:
messages (List[Message]): The list of messages to process.

Returns:
Tuple[List[Dict[str, str]], str]: A tuple containing the list of API messages and the concatenated system messages.
"""
chat_messages: List[Dict[str, str]] = []
system_messages: List[str] = []

for idx, message in enumerate(messages):
content = message.content or ""
if message.role == "system" or (message.role != "user" and idx in [0, 1]):
if content is not None:
system_messages.append(content) # type: ignore
continue
elif message.role == "user":
if isinstance(content, str):
content = [{"type": "text", "text": content}]

if message.images is not None:
for image in message.images:
image_content = format_image_for_message(image)
if image_content:
content.append(image_content)

# Handle tool calls from history
elif message.role == "assistant" and isinstance(message.content, str) and message.tool_calls:
if message.content:
content = [TextBlock(text=message.content, type="text")]
else:
content = []
for tool_call in message.tool_calls:
content.append(
ToolUseBlock(
id=tool_call["id"],
input=json.loads(tool_call["function"]["arguments"])
if "arguments" in tool_call["function"]
else {},
name=tool_call["function"]["name"],
type="tool_use",
)
)

chat_messages.append({"role": message.role, "content": content}) # type: ignore
return chat_messages, " ".join(system_messages)

def prepare_request_kwargs(self, system_message: str) -> Dict[str, Any]:
"""
Prepare the request keyword arguments for the API call.
@@ -297,7 +298,7 @@ def invoke(self, messages: List[Message]) -> AnthropicMessage:
Returns:
AnthropicMessage: The response from the model.
"""
chat_messages, system_message = self.format_messages(messages)
chat_messages, system_message = _format_messages(messages)
request_kwargs = self.prepare_request_kwargs(system_message)

return self.get_client().messages.create(
@@ -316,7 +317,7 @@ def invoke_stream(self, messages: List[Message]) -> Any:
Returns:
Any: The streamed response from the model.
"""
chat_messages, system_message = self.format_messages(messages)
chat_messages, system_message = _format_messages(messages)
request_kwargs = self.prepare_request_kwargs(system_message)

return self.get_client().messages.stream(
@@ -673,7 +674,7 @@ async def ainvoke(self, messages: List[Message]) -> AnthropicMessage:
Returns:
AnthropicMessage: The response from the model.
"""
chat_messages, system_message = self.format_messages(messages)
chat_messages, system_message = _format_messages(messages)
request_kwargs = self.prepare_request_kwargs(system_message)

return await self.get_async_client().messages.create(
@@ -692,7 +693,7 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
Returns:
Any: The streamed response from the model.
"""
chat_messages, system_message = self.format_messages(messages)
chat_messages, system_message = _format_messages(messages)
request_kwargs = self.prepare_request_kwargs(system_message)

return self.get_async_client().messages.stream(
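For context on what _format_image_for_message produces (its full body is collapsed in this diff), the Anthropic Messages API represents images as base64-encoded content blocks. An illustrative sketch of that block shape, not the implementation from this PR:

import base64

import httpx


def example_anthropic_image_block(url: str) -> dict:
    """Build an Anthropic-style base64 image content block from a URL.

    Illustrative only: the PR's helper also handles filepath and raw-bytes
    inputs and determines the real media type instead of assuming JPEG.
    """
    raw = httpx.get(url).content
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/jpeg",  # assumed for this sketch
            "data": base64.b64encode(raw).decode("utf-8"),
        },
    }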