feat(weave): Patch send_message in Google GenAI SDK (#2746)
* add: patching for google.generativeai.ChatSession.send_message

* update: tests

* add: patching for send_message_async

* add: test_send_message_async

* add: ValueError

* fix: gemini stream

* fix: bug in gemini_on_finish

* add: comments in tests

* update: tests

* remove: unnecessary files

* fix: bug in gemini_on_finish

* update: tests

* update: tests

* remove: test.py

* fix: lint

* fix: lint

* add: skips for google ai studio tests
soumik12345 authored Nov 26, 2024
1 parent b89825e commit 92a60bb
Showing 2 changed files with 201 additions and 7 deletions.
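
For orientation, a minimal usage sketch (not part of this commit; the Weave project name below is a placeholder): once this patch is applied, a chat round-trip is recorded as two nested traces, ChatSession.send_message and the GenerativeModel.generate_content call it makes internally.

import os

import google.generativeai as genai

import weave

weave.init("google-genai-demo")  # hypothetical project name
genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))

model = genai.GenerativeModel("gemini-1.5-flash")
chat = model.start_chat()
# Both this call and the generate_content call it makes are traced by Weave.
response = chat.send_message("What is the capital of France?")
print(response.text)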
159 changes: 154 additions & 5 deletions tests/integrations/google_ai_studio/google_ai_studio_test.py
@@ -1,10 +1,16 @@
import os

import pytest
from pydantic import BaseModel

from weave.integrations.integration_utilities import op_name_from_ref


class Recipe(BaseModel):
    recipe_name: str
    ingredients: list[str]


# NOTE: These asserts are slightly more relaxed than other integrations because we can't yet save
# the output with vcrpy. When VCR.py supports GRPC, we should add recordings for these tests!
# NOTE: We have retries because these tests are not deterministic (they use the live Gemini APIs),
@@ -47,6 +53,31 @@ def assert_correct_summary(summary: dict, trace_name: str):
    assert summary["weave"]["latency_ms"] > 0


def is_part_presence_in_content_parts(parts: list[dict], part_type: str) -> bool:
    for part in parts:
        if part_type in part:
            return True
    return False


def assert_code_execution(output: dict):
    assert is_part_presence_in_content_parts(
        output["candidates"][0]["content"]["parts"], "text"
    )
    assert is_part_presence_in_content_parts(
        output["candidates"][0]["content"]["parts"], "executable_code"
    )
    assert is_part_presence_in_content_parts(
        output["candidates"][0]["content"]["parts"], "code_execution_result"
    )
    assert output["candidates"][0]["content"]["role"] == "model"
    assert isinstance(output["usage_metadata"], dict)
    assert isinstance(output["usage_metadata"]["prompt_token_count"], int)
    assert isinstance(output["usage_metadata"]["candidates_token_count"], int)
    assert isinstance(output["usage_metadata"]["total_token_count"], int)
    assert isinstance(output["usage_metadata"]["cached_content_token_count"], int)


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@@ -57,7 +88,7 @@ def test_content_generation(client):

    genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
    model = genai.GenerativeModel("gemini-1.5-flash")
-    model.generate_content("Explain how AI works in simple terms")
+    model.generate_content("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1
@@ -82,9 +113,7 @@ def test_content_generation_stream(client):

    genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
    model = genai.GenerativeModel("gemini-1.5-flash")
-    response = model.generate_content(
-        "Explain how AI works in simple terms", stream=True
-    )
+    response = model.generate_content("What is the capital of France?", stream=True)
    chunks = [chunk.text for chunk in response]
    assert len(chunks) > 1

@@ -113,7 +142,7 @@ async def test_content_generation_async(client):
    genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
    model = genai.GenerativeModel("gemini-1.5-flash")

-    _ = await model.generate_content_async("Explain how AI works in simple terms")
+    _ = await model.generate_content_async("What is the capital of France?")

    calls = list(client.calls())
    assert len(calls) == 1
@@ -125,3 +154,123 @@ async def test_content_generation_async(client):
    assert call.output is not None
    assert_correct_output_shape(call.output)
    assert_correct_summary(call.summary, trace_name)


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_send_message(client):
    import google.generativeai as genai

    genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
    model = genai.GenerativeModel(model_name="gemini-1.5-pro", tools="code_execution")
    chat = model.start_chat()
    chat.send_message(
        "What is the sum of the first 50 prime numbers? "
        "Generate and run code for the calculation, and make sure you get all 50."
    )

    calls = list(client.calls())
    # `send_message` uses `GenerativeModel.generate_content` under the hood,
    # which we're already patching. Hence, we have 2 calls here.
    assert len(calls) == 2

    call = calls[0]
    assert call.started_at < call.ended_at
    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "google.generativeai.ChatSession.send_message"
    assert call.output is not None
    output = call.output
    assert_code_execution(output)

    call = calls[1]
    assert call.started_at < call.ended_at
    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "google.generativeai.GenerativeModel.generate_content"
    assert call.output is not None
    output = call.output
    assert_code_execution(output)


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.skip_clickhouse_client
def test_send_message_stream(client):
    import google.generativeai as genai

    genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
    model = genai.GenerativeModel(model_name="gemini-1.5-pro", tools="code_execution")
    chat = model.start_chat()
    response = chat.send_message(
        (
            "What is the sum of the first 50 prime numbers? "
            "Generate and run code for the calculation, and make sure you get all 50."
        ),
        stream=True,
    )
    chunks = [r.text for r in response]
    assert len(chunks) > 1

    calls = list(client.calls())
    # `send_message` uses `GenerativeModel.generate_content` under the hood,
    # which we're already patching. Hence, we have 2 calls here.
    assert len(calls) == 2

    call = calls[0]
    assert call.started_at < call.ended_at
    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "google.generativeai.ChatSession.send_message"
    assert call.output is not None
    output = call.output
    assert_code_execution(output)

    call = calls[1]
    assert call.started_at < call.ended_at
    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "google.generativeai.GenerativeModel.generate_content"
    assert call.output is not None
    output = call.output
    assert_code_execution(output)


@pytest.mark.skip(
    reason="This test depends on a non-deterministic external service provider"
)
@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.asyncio
@pytest.mark.skip_clickhouse_client
async def test_send_message_async(client):
    import google.generativeai as genai

    genai.configure(api_key=os.getenv("GOOGLE_GENAI_KEY"))
    model = genai.GenerativeModel(model_name="gemini-1.5-pro", tools="code_execution")
    chat = model.start_chat()
    await chat.send_message_async(
        "What is the sum of the first 50 prime numbers? "
        "Generate and run code for the calculation, and make sure you get all 50."
    )

    calls = list(client.calls())
    # `send_message_async` uses `GenerativeModel.generate_content_async` under
    # the hood, which we're already patching. Hence, we have 2 calls here.
    assert len(calls) == 2

    call = calls[0]
    assert call.started_at < call.ended_at
    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "google.generativeai.ChatSession.send_message_async"
    assert call.output is not None
    output = call.output
    assert_code_execution(output)

    call = calls[1]
    assert call.started_at < call.ended_at
    trace_name = op_name_from_ref(call.op_name)
    assert trace_name == "google.generativeai.GenerativeModel.generate_content_async"
    assert call.output is not None
    output = call.output
    assert_code_execution(output)
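
For reference, a code-execution response that satisfies assert_code_execution is shaped roughly like the hand-written illustration below (values are invented, not an API recording; field names follow the google.generativeai response schema, and 5117 is the actual sum of the first 50 primes).

# Illustrative only: a plausible code-execution response shape.
example_output = {
    "candidates": [
        {
            "content": {
                "role": "model",
                "parts": [
                    {"text": "Computing the sum of the first 50 primes."},
                    {"executable_code": {"language": "PYTHON", "code": "..."}},
                    {"code_execution_result": {"status": "OUTCOME_OK", "output": "5117\n"}},
                    {"text": "The sum of the first 50 prime numbers is 5117."},
                ],
            }
        }
    ],
    "usage_metadata": {
        "prompt_token_count": 46,
        "candidates_token_count": 210,
        "total_token_count": 256,
        "cached_content_token_count": 0,
    },
}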
49 changes: 47 additions & 2 deletions weave/integrations/google_ai_studio/google_ai_studio_sdk.py
@@ -26,7 +26,35 @@ def gemini_accumulator(

    for i, value_candidate in enumerate(value.candidates):
        for j, value_part in enumerate(value_candidate.content.parts):
-            acc.candidates[i].content.parts[j].text += value_part.text
+            if len(value_part.text) > 0:
+                acc.candidates[i].content.parts[j].text += value_part.text
+            elif len(value_part.executable_code.code) > 0:
+                if len(acc.candidates[i].content.parts[j].executable_code.code) == 0:
+                    acc.candidates[i].content.parts.append(value_part)
+                else:
+                    acc.candidates[i].content.parts[
+                        j
+                    ].executable_code.code += value_part.executable_code.code
+                    acc.candidates[i].content.parts[
+                        j
+                    ].executable_code.language = value_part.executable_code.language
+            elif len(value_part.code_execution_result.output) > 0:
+                if (
+                    len(acc.candidates[i].content.parts[j].code_execution_result.output)
+                    == 0
+                ):
+                    acc.candidates[i].content.parts.append(value_part)
+                else:
+                    acc.candidates[i].content.parts[
+                        j
+                    ].code_execution_result.output += (
+                        value_part.code_execution_result.output
+                    )
+                    acc.candidates[i].content.parts[
+                        j
+                    ].code_execution_result.status = (
+                        value_part.code_execution_result.status
+                    )

    acc.usage_metadata.prompt_token_count += value.usage_metadata.prompt_token_count
    acc.usage_metadata.candidates_token_count += (
@@ -39,7 +67,12 @@ def gemini_on_finish(
def gemini_on_finish(
    call: Call, output: Any, exception: Optional[BaseException]
) -> None:
-    original_model_name = call.inputs["self"]["model_name"]
+    if "model_name" in call.inputs["self"]:
+        original_model_name = call.inputs["self"]["model_name"]
+    elif "model" in call.inputs["self"]:
+        original_model_name = call.inputs["self"]["model"]["model_name"]
+    else:
+        raise ValueError("Unknown model type")
    model_name = original_model_name.split("/")[-1]
    usage = {model_name: {"requests": 1}}
    summary_update = {"usage": usage}
@@ -110,5 +143,17 @@ async def _async_wrapper(*args: Any, **kwargs: Any) -> Any:
name="google.generativeai.GenerativeModel.generate_content_async"
),
),
SymbolPatcher(
lambda: importlib.import_module("google.generativeai.generative_models"),
"ChatSession.send_message",
gemini_wrapper_sync(name="google.generativeai.ChatSession.send_message"),
),
SymbolPatcher(
lambda: importlib.import_module("google.generativeai.generative_models"),
"ChatSession.send_message_async",
gemini_wrapper_async(
name="google.generativeai.ChatSession.send_message_async"
),
),
]
)
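
A standalone sketch of the accumulation rule the patched gemini_accumulator implements, written with plain dicts instead of the SDK's protobuf types (the helper name and dict shapes here are illustrative, not part of the commit): text chunks concatenate in place, while executable_code and code_execution_result chunks either become a new part or extend the matching fields of an existing one.

def merge_streamed_part(acc_parts: list[dict], j: int, part: dict) -> None:
    # Plain-dict sketch of the protobuf merge above; assumes each streamed
    # part carries exactly one of the three keys handled here.
    if part.get("text"):
        acc_parts[j]["text"] = acc_parts[j].get("text", "") + part["text"]
    elif part.get("executable_code", {}).get("code"):
        if not acc_parts[j].get("executable_code", {}).get("code"):
            acc_parts.append(part)  # first code chunk starts a new part
        else:
            acc_parts[j]["executable_code"]["code"] += part["executable_code"]["code"]
            acc_parts[j]["executable_code"]["language"] = part["executable_code"]["language"]
    elif part.get("code_execution_result", {}).get("output"):
        if not acc_parts[j].get("code_execution_result", {}).get("output"):
            acc_parts.append(part)
        else:
            acc_parts[j]["code_execution_result"]["output"] += part["code_execution_result"]["output"]
            acc_parts[j]["code_execution_result"]["status"] = part["code_execution_result"]["status"]

# Example: two streamed text chunks merge into one part.
parts = [{"text": "The sum "}]
merge_streamed_part(parts, 0, {"text": "is 5117."})
assert parts[0]["text"] == "The sum is 5117."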
