refact: make GeminiMultiModal a thin wrapper around Gemini
fix unit test

fix request_options handling

save progress

address review comment

update notebook example

fix completion docs handling
masci committed Jan 14, 2025
1 parent 3f7e66e commit a80e2ab
Showing 10 changed files with 425 additions and 371 deletions.
270 changes: 157 additions & 113 deletions docs/docs/examples/multi_modal/gemini.ipynb

Large diffs are not rendered by default.
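
Since the notebook diff itself is not rendered, here is a minimal sketch of the usage pattern this refactor enables, based on the reworked `chat_message_to_gemini` and the unit tests further down: multimodal prompts now go through the plain `Gemini` LLM by attaching content blocks to a `ChatMessage`. It assumes `GOOGLE_API_KEY` is set in the environment.

```python
# Minimal sketch (not taken from the notebook): a text + image prompt sent through
# the plain Gemini LLM using content blocks. Assumes GOOGLE_API_KEY is set.
from llama_index.core.base.llms.types import ChatMessage, ImageBlock
from llama_index.llms.gemini import Gemini

llm = Gemini()  # defaults to the first entry in GEMINI_MODELS

msg = ChatMessage("Tell me the brand of the car in this image:")
msg.blocks.append(
    ImageBlock(
        url="https://upload.wikimedia.org/wikipedia/commons/5/52/Ferrari_SP_FFX.jpg"
    )
)

print(llm.chat(messages=[msg]))
```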

10 changes: 7 additions & 3 deletions llama-index-core/llama_index/core/base/llms/types.py
@@ -85,12 +85,14 @@ def image_to_base64(self) -> Self:
# Not base64 - encode it
self.image = base64.b64encode(self.image)

self._guess_mimetype(decoded_img)
return self

def _guess_mimetype(self, img_data: bytes) -> None:
if not self.image_mimetype:
guess = filetype.guess(decoded_img)
guess = filetype.guess(img_data)
self.image_mimetype = guess.mime if guess else None

return self

def resolve_image(self, as_base64: bool = False) -> BytesIO:
"""Resolve an image such that PIL can read it.
@@ -103,13 +105,15 @@ def resolve_image(self, as_base64: bool = False) -> BytesIO:
return BytesIO(base64.b64decode(self.image))
elif self.path is not None:
img_bytes = self.path.read_bytes()
self._guess_mimetype(img_bytes)
if as_base64:
return BytesIO(base64.b64encode(img_bytes))
return BytesIO(img_bytes)
elif self.url is not None:
# load image from URL
response = requests.get(str(self.url))
img_bytes = response.content
self._guess_mimetype(img_bytes)
if as_base64:
return BytesIO(base64.b64encode(img_bytes))
return BytesIO(img_bytes)
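
The change above factors the MIME sniffing into `_guess_mimetype` and calls it wherever the image bytes are first read, so an `ImageBlock` built from a path (or URL) picks up its `image_mimetype` lazily. A small sketch of that behaviour, with a placeholder file name:

```python
# Hedged sketch of the new lazy MIME detection; "photo.png" is a placeholder path.
from pathlib import Path

from llama_index.core.base.llms.types import ImageBlock

block = ImageBlock(path=Path("photo.png"))  # no mimetype supplied up front
image_bytes = block.resolve_image()         # reading the bytes triggers _guess_mimetype
print(block.image_mimetype)                 # e.g. "image/png", sniffed via filetype
```
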
11 changes: 7 additions & 4 deletions llama-index-core/llama_index/core/llms/custom.py
@@ -1,5 +1,9 @@
from typing import Any, Sequence

from llama_index.core.base.llms.generic_utils import (
completion_response_to_chat_response,
stream_completion_response_to_chat_response,
)
from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
@@ -12,10 +16,6 @@
llm_chat_callback,
llm_completion_callback,
)
from llama_index.core.base.llms.generic_utils import (
completion_response_to_chat_response,
stream_completion_response_to_chat_response,
)
from llama_index.core.llms.llm import LLM


@@ -26,6 +26,9 @@ class CustomLLM(LLM):
`_stream_complete`, and `metadata` methods.
"""

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
assert self.messages_to_prompt is not None
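
For context, `CustomLLM` derives the chat and streaming-chat methods from the completion ones, so a subclass typically provides `metadata`, `complete`, and `stream_complete` -- which is what the `Gemini` class below does. A minimal subclass sketch in the spirit of the library docs (the class name and echo behaviour are illustrative, not part of this commit):

```python
# Illustrative CustomLLM subclass; the echo behaviour is an assumption for demo purposes.
from typing import Any

from llama_index.core.llms import (
    CompletionResponse,
    CompletionResponseGen,
    CustomLLM,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback


class EchoLLM(CustomLLM):
    model_name: str = "echo"

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(model_name=self.model_name)

    @llm_completion_callback()
    def complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
        # Echo the prompt back as the "completion".
        return CompletionResponse(text=prompt)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponseGen:
        # Stream the echoed prompt token by token.
        text = ""
        for token in prompt.split():
            text += token + " "
            yield CompletionResponse(text=text, delta=token + " ")
```
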
17 changes: 11 additions & 6 deletions llama-index-core/llama_index/core/utilities/gemini_utils.py
@@ -1,11 +1,12 @@
"""Global Gemini Utilities (shared between Gemini LLM and Vertex)."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Dict

from llama_index.core.base.llms.types import ChatMessage, MessageRole

ROLES_TO_GEMINI: Dict[MessageRole, MessageRole] = {
ROLES_TO_GEMINI: dict[MessageRole, MessageRole] = {
MessageRole.USER: MessageRole.USER,
MessageRole.ASSISTANT: MessageRole.MODEL,
## Gemini chat mode only has user and model roles. Put the rest in user role.
@@ -17,17 +18,21 @@
MessageRole.TOOL: MessageRole.USER,
MessageRole.FUNCTION: MessageRole.USER,
}
ROLES_FROM_GEMINI: Dict[MessageRole, MessageRole] = {
ROLES_FROM_GEMINI: dict[str, MessageRole] = {
## Gemini has user, model and function roles.
MessageRole.USER: MessageRole.USER,
MessageRole.MODEL: MessageRole.ASSISTANT,
MessageRole.FUNCTION: MessageRole.TOOL,
"user": MessageRole.USER,
"model": MessageRole.ASSISTANT,
"function": MessageRole.TOOL,
}


def merge_neighboring_same_role_messages(
messages: Sequence[ChatMessage],
) -> Sequence[ChatMessage]:
if len(messages) < 2:
# Nothing to merge
return messages

# Gemini does not support multiple messages of the same role in a row, so we merge them
merged_messages = []
i = 0
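
The early return above short-circuits single-message conversations; otherwise the function collapses consecutive messages with the same role, since the Gemini chat API rejects two messages from the same role in a row. An illustrative call (the exact merged content is handled inside the function):

```python
# Illustrative: two consecutive USER messages are merged before building Gemini history.
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.utilities.gemini_utils import merge_neighboring_same_role_messages

messages = [
    ChatMessage(role=MessageRole.USER, content="Here is some context."),
    ChatMessage(role=MessageRole.USER, content="And here is my question."),
    ChatMessage(role=MessageRole.ASSISTANT, content="Understood."),
]

merged = merge_neighboring_same_role_messages(messages)
assert len(merged) == 2  # USER + USER -> one USER; the ASSISTANT message is untouched
```
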
@@ -1,12 +1,15 @@
"""Google's hosted Gemini API."""

import os
from typing import Any, Dict, Optional, Sequence
import warnings
from typing import Any, Dict, Optional, Sequence, cast

import google.generativeai as genai
from google.generativeai.types import generation_types
from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseAsyncGen,
ChatResponseGen,
CompletionResponse,
CompletionResponseGen,
@@ -21,13 +24,12 @@
ROLES_FROM_GEMINI,
merge_neighboring_same_role_messages,
)
from llama_index.llms.gemini.utils import (

from .utils import (
chat_from_gemini_response,
chat_message_to_gemini,
completion_from_gemini_response,
)
import google.generativeai as genai


GEMINI_MODELS = (
"models/gemini-2.0-flash-exp",
@@ -86,7 +88,7 @@ class Gemini(CustomLLM):
def __init__(
self,
api_key: Optional[str] = None,
model: Optional[str] = GEMINI_MODELS[0],
model: str = GEMINI_MODELS[0],
temperature: float = DEFAULT_TEMPERATURE,
max_tokens: Optional[int] = None,
generation_config: Optional[genai.types.GenerationConfigDict] = None,
@@ -118,7 +120,7 @@ def __init__(
if transport:
config_params["transport"] = transport
if default_headers:
default_metadata: Sequence[Dict[str, str]] = []
default_metadata = []
for key, value in default_headers.items():
default_metadata.append((key, value))
# `default_metadata` contains (key, value) pairs that will be sent with every request.
@@ -129,7 +131,10 @@

base_gen_config = generation_config if generation_config else {}
# Explicitly passed args take precedence over the generation_config.
final_gen_config = {"temperature": temperature, **base_gen_config}
final_gen_config = cast(
generation_types.GenerationConfigDict,
{"temperature": temperature, **base_gen_config},
)

model_meta = genai.get_model(model)

@@ -187,6 +192,15 @@ def complete(
)
return completion_from_gemini_response(result)

async def acomplete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponse:
request_options = self._request_options or kwargs.pop("request_options", None)
result = await self._model.generate_content_async(
prompt, request_options=request_options, **kwargs
)
return completion_from_gemini_response(result)

def stream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponseGen:
@@ -209,6 +223,18 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
)
return chat_from_gemini_response(response)

async def achat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponse:
request_options = self._request_options or kwargs.pop("request_options", None)
merged_messages = merge_neighboring_same_role_messages(messages)
*history, next_msg = map(chat_message_to_gemini, merged_messages)
chat = self._model.start_chat(history=history)
response = await chat.send_message_async(
next_msg, request_options=request_options, **kwargs
)
return chat_from_gemini_response(response)

@llm_chat_callback()
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
@@ -228,9 +254,41 @@ def gen() -> ChatResponseGen:
content_delta = top_candidate.content.parts[0].text
role = ROLES_FROM_GEMINI[top_candidate.content.role]
raw = {
**(type(top_candidate).to_dict(top_candidate)),
**(type(top_candidate).to_dict(top_candidate)), # type: ignore
**(
type(response.prompt_feedback).to_dict(response.prompt_feedback) # type: ignore
),
}
content += content_delta
yield ChatResponse(
message=ChatMessage(role=role, content=content),
delta=content_delta,
raw=raw,
)

return gen()

async def astream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseAsyncGen:
request_options = self._request_options or kwargs.pop("request_options", None)
merged_messages = merge_neighboring_same_role_messages(messages)
*history, next_msg = map(chat_message_to_gemini, messages)
chat = self._model.start_chat(history=history)
response = await chat.send_message_async(
next_msg, stream=True, request_options=request_options, **kwargs
)

async def gen() -> ChatResponseAsyncGen:
content = ""
async for r in response:
top_candidate = r.candidates[0]
content_delta = top_candidate.content.parts[0].text
role = ROLES_FROM_GEMINI[top_candidate.content.role]
raw = {
**(type(top_candidate).to_dict(top_candidate)), # type: ignore
**(
type(response.prompt_feedback).to_dict(response.prompt_feedback)
type(response.prompt_feedback).to_dict(response.prompt_feedback) # type: ignore
),
}
content += content_delta
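
This file also gains async counterparts (`acomplete`, `achat`, `astream_chat`), each resolving `request_options` the same way as the sync paths. A hedged sketch of how they might be called; the timeout value is an illustrative assumption and `GOOGLE_API_KEY` must be set:

```python
# Hedged sketch of the new async methods; the request_options dict is an assumption.
import asyncio

from llama_index.core.base.llms.types import ChatMessage
from llama_index.llms.gemini import Gemini


async def main() -> None:
    llm = Gemini()

    completion = await llm.acomplete("Say hello in one word.")
    chat = await llm.achat(messages=[ChatMessage("Say hello in one word.")])
    print(completion.text, chat.message.content)

    # Per-call request options are popped from kwargs, e.g. a timeout in seconds.
    stream = await llm.astream_chat(
        messages=[ChatMessage("Count from 1 to 5.")],
        request_options={"timeout": 120},
    )
    async for chunk in stream:
        print(chunk.delta, end="")


asyncio.run(main())
```
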
@@ -2,13 +2,14 @@

import google.ai.generativelanguage as glm
import google.generativeai as genai
import PIL

from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
CompletionResponse,
ImageBlock,
TextBlock,
)
from llama_index.core.multi_modal_llms.base import ChatMessage
from llama_index.core.utilities.gemini_utils import ROLES_FROM_GEMINI, ROLES_TO_GEMINI


@@ -39,8 +40,8 @@ def completion_from_gemini_response(
_error_if_finished_early(top_candidate)

raw = {
**(type(top_candidate).to_dict(top_candidate)),
**(type(response.prompt_feedback).to_dict(response.prompt_feedback)),
**(type(top_candidate).to_dict(top_candidate)), # type: ignore
**(type(response.prompt_feedback).to_dict(response.prompt_feedback)), # type: ignore
}
if response.usage_metadata:
raw["usage_metadata"] = type(response.usage_metadata).to_dict(
@@ -59,8 +60,8 @@ def chat_from_gemini_response(
_error_if_finished_early(top_candidate)

raw = {
**(type(top_candidate).to_dict(top_candidate)),
**(type(response.prompt_feedback).to_dict(response.prompt_feedback)),
**(type(top_candidate).to_dict(top_candidate)), # type: ignore
**(type(response.prompt_feedback).to_dict(response.prompt_feedback)), # type: ignore
}
if response.usage_metadata:
raw["usage_metadata"] = type(response.usage_metadata).to_dict(
@@ -72,9 +73,22 @@

def chat_message_to_gemini(message: ChatMessage) -> "genai.types.ContentDict":
"""Convert ChatMessages to Gemini-specific history, including ImageDocuments."""
parts = [message.content]
if images := message.additional_kwargs.get("images"):
parts += [PIL.Image.open(doc.resolve_image()) for doc in images]
parts = []
content_txt = ""
for block in message.blocks:
if isinstance(block, TextBlock):
parts.append(block.text)
elif isinstance(block, ImageBlock):
base64_bytes = block.resolve_image(as_base64=False).read()
parts.append(
{
"mime_type": block.image_mimetype,
"data": base64_bytes,
}
)
else:
msg = f"Unsupported content block type: {type(block).__name__}"
raise ValueError(msg)

return {
"role": ROLES_TO_GEMINI[message.role],
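
The converter now walks `message.blocks` instead of the old `additional_kwargs["images"]`: text blocks pass through as plain strings, image blocks become `{"mime_type", "data"}` dicts, and anything else raises. A tiny illustration mirroring the unit test below (the image bytes are a placeholder):

```python
# Mirrors the unit test: text blocks become strings, image blocks become dicts.
from llama_index.core.base.llms.types import ChatMessage, ImageBlock, MessageRole
from llama_index.llms.gemini.utils import chat_message_to_gemini

msg = ChatMessage("What is in this picture?")
msg.blocks.append(ImageBlock(image=b"<png bytes>", image_mimetype="image/png"))

content = chat_message_to_gemini(msg)
assert content == {
    "role": MessageRole.USER,  # mapped via ROLES_TO_GEMINI
    "parts": [
        "What is in this picture?",
        {"mime_type": "image/png", "data": b"<png bytes>"},
    ],
}
```
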
@@ -1,7 +1,50 @@
import os

import pytest
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatMessage, ImageBlock, MessageRole
from llama_index.llms.gemini import Gemini
from llama_index.llms.gemini.utils import chat_message_to_gemini


def test_embedding_class():
names_of_base_classes = [b.__name__ for b in Gemini.__mro__]
assert BaseLLM.__name__ in names_of_base_classes


def test_chat_message_to_gemini():
msg = ChatMessage("Some content")
assert chat_message_to_gemini(msg) == {
"role": MessageRole.USER,
"parts": ["Some content"],
}

msg = ChatMessage("Some content")
msg.blocks.append(ImageBlock(image=b"foo", image_mimetype="image/png"))
assert chat_message_to_gemini(msg) == {
"role": MessageRole.USER,
"parts": ["Some content", {"data": b"foo", "mime_type": "image/png"}],
}


@pytest.mark.skipif(
os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
)
def test_generate_image_prompt():
msg = ChatMessage("Tell me the brand of the car in this image:")
msg.blocks.append(
ImageBlock(
url="https://upload.wikimedia.org/wikipedia/commons/5/52/Ferrari_SP_FFX.jpg"
)
)
response = Gemini().chat(messages=[msg])
assert "ferrari" in str(response).lower()


@pytest.mark.skipif(
os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
)
def test_chat_stream():
msg = ChatMessage("List three types of software testing strategies")
response = list(Gemini().stream_chat(messages=[msg]))
assert response