
test gemini and openai
samuelcolvin committed Oct 21, 2024
1 parent f26c306 commit e19851e
Showing 4 changed files with 49 additions and 6 deletions.
9 changes: 9 additions & 0 deletions tests/models/test_gemini.py
@@ -36,6 +36,7 @@
     _GeminiTools,  # pyright: ignore[reportPrivateUsage]
     _GeminiUsageMetaData,  # pyright: ignore[reportPrivateUsage]
 )
+from pydantic_ai.shared import Cost
 from tests.conftest import ClientWithHandler, IsNow, TestEnv
 
 pytestmark = pytest.mark.anyio
@@ -359,6 +360,13 @@ async def test_request_simple_success(get_gemini_client: GetGeminiClient):
 
     result = await agent.run('Hello')
     assert result.response == 'Hello world'
+    assert result.message_history == snapshot(
+        [
+            UserPrompt(content='Hello', timestamp=IsNow()),
+            LLMResponse(content='Hello world', timestamp=IsNow()),
+        ]
+    )
+    assert result.cost == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3, details={}))
 
 
 async def test_request_structured_response(get_gemini_client: GetGeminiClient):
@@ -443,6 +451,7 @@ async def get_location(loc_name: str) -> str:
             LLMResponse(content='final response', timestamp=IsNow()),
         ]
     )
+    assert result.cost == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9, details={}))
 
 
 async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv):
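The new Gemini assertions exercise a Cost value imported from pydantic_ai.shared. Below is a minimal sketch of what the snapshots above imply about that type: the field names (request_tokens, response_tokens, total_tokens, details) are taken directly from the asserts, while the dataclass form and the summing __add__ are assumptions about the library, not its actual source.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Cost:
    """Sketch of the per-run cost accumulator the tests assert against."""

    request_tokens: int | None = None
    response_tokens: int | None = None
    total_tokens: int | None = None
    details: dict[str, int] = field(default_factory=dict)

    def __add__(self, other: Cost) -> Cost:
        # Assumed behaviour: per-request costs are summed across an agent run,
        # which is consistent with the simple-request test reporting 1/2/3 and
        # the retriever test reporting 3/6/9 after several model requests.
        def add(a: int | None, b: int | None) -> int | None:
            return b if a is None else a if b is None else a + b

        details = dict(self.details)
        for key, value in other.details.items():
            details[key] = details.get(key, 0) + value
        return Cost(
            add(self.request_tokens, other.request_tokens),
            add(self.response_tokens, other.response_tokens),
            add(self.total_tokens, other.total_tokens),
            details,
        )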
2 changes: 1 addition & 1 deletion tests/models/test_model_test.py
@@ -27,7 +27,7 @@ async def ret_a(x: str) -> str:
        return f'{x}-a'
 
    @agent.retriever_plain
-    async def ret_b(x: str) -> str:
+    async def ret_b(x: str) -> str:  # pragma: no cover
        calls.append('b')
        return f'{x}-b'
 
42 changes: 38 additions & 4 deletions tests/models/test_openai.py
@@ -8,8 +8,10 @@
 from inline_snapshot import snapshot
 from openai import AsyncOpenAI
 from openai.types import chat
-from openai.types.chat.chat_completion import ChatCompletionMessage, Choice  # pyright: ignore[reportPrivateImportUsage]
+from openai.types.chat.chat_completion import Choice
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from openai.types.chat.chat_completion_message_tool_call import Function
+from openai.types.completion_usage import CompletionUsage, PromptTokensDetails
 
 from pydantic_ai import Agent, ModelRetry
 from pydantic_ai.messages import (
@@ -23,6 +25,7 @@
     UserPrompt,
 )
 from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.shared import Cost
 from tests.conftest import IsNow
 
 pytestmark = pytest.mark.anyio
@@ -54,13 +57,14 @@ async def chat_completions_create(self, *_args: Any, **_kwargs: Any) -> chat.Cha
         return completion
 
 
-def completion_message(message: ChatCompletionMessage) -> chat.ChatCompletion:
+def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
     return chat.ChatCompletion(
         id='123',
         choices=[Choice(finish_reason='stop', index=0, message=message)],
         created=1704067200,  # 2024-01-01
         model='gpt-4',
         object='chat.completion',
+        usage=usage,
     )
 
 
@@ -72,6 +76,21 @@ async def test_request_simple_success():
 
     result = await agent.run('Hello')
     assert result.response == 'world'
+    assert result.cost == snapshot(Cost())
+
+
+async def test_request_simple_usage():
+    c = completion_message(
+        ChatCompletionMessage(content='world', role='assistant'),
+        usage=CompletionUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3),
+    )
+    mock_client = MockOpenAI.create_mock(c)
+    m = OpenAIModel('gpt-4', openai_client=mock_client)
+    agent = Agent(m, deps=None)
+
+    result = await agent.run('Hello')
+    assert result.response == 'world'
+    assert result.cost == snapshot(Cost(request_tokens=2, response_tokens=1, total_tokens=3, details={}))
 
 
 async def test_request_structured_response():
@@ -124,7 +143,13 @@ async def test_request_tool_call():
                     type='function',
                 )
             ],
-        )
+        ),
+        usage=CompletionUsage(
+            completion_tokens=1,
+            prompt_tokens=2,
+            total_tokens=3,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=1),
+        ),
     ),
     completion_message(
         ChatCompletionMessage(
@@ -137,7 +162,13 @@
                     type='function',
                 )
             ],
-        )
+        ),
+        usage=CompletionUsage(
+            completion_tokens=2,
+            prompt_tokens=3,
+            total_tokens=6,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
+        ),
     ),
     completion_message(ChatCompletionMessage(content='final response', role='assistant')),
 ]
@@ -190,3 +221,6 @@ async def get_location(loc_name: str) -> str:
             LLMResponse(content='final response', timestamp=datetime.datetime(2024, 1, 1, 0, 0)),
         ]
     )
+    assert result.cost == snapshot(
+        Cost(request_tokens=5, response_tokens=3, total_tokens=9, details={'cached_tokens': 3})
+    )
2 changes: 1 addition & 1 deletion tests/test_retrievers.py
@@ -32,7 +32,7 @@ def test_retriever_plain_with_ctx():
     with pytest.raises(UserError) as exc_info:
 
         @agent.retriever_plain
-        async def invalid_retriever(ctx: CallContext[None]) -> str:
+        async def invalid_retriever(ctx: CallContext[None]) -> str:  # pragma: no cover
             return 'Hello'
 
     assert str(exc_info.value) == snapshot(
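Both # pragma: no cover additions mark retriever bodies the tests intentionally never execute: ret_b is registered but not called in its test, and invalid_retriever's registration raises UserError inside pytest.raises before the body can run. The comment tells coverage.py to exclude those lines so the coverage check stays green. A minimal illustration of the pattern (the function name is hypothetical):

def debug_only() -> str:  # pragma: no cover
    # Excluded from coverage measurement: the test suite never calls this body,
    # so without the pragma coverage.py would report these lines as missed.
    return 'never exercised in tests'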

