Skip to content

Commit

Permalink
Add timezones (#25)
Browse files Browse the repository at this point in the history
Co-authored-by: Samuel Colvin <[email protected]>
  • Loading branch information
dmontagu and samuelcolvin authored Oct 30, 2024
1 parent 7f9e9df commit 0b626f8
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 66 deletions.
16 changes: 10 additions & 6 deletions pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timezone
from typing import Annotated, Any, Literal, Union

import pydantic
Expand All @@ -11,6 +11,10 @@
from . import _pydantic


def _now_utc() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime.

    Used as the ``default_factory`` for the message ``timestamp`` fields below,
    so recorded timestamps are always aware (UTC) rather than naive local times.
    """
    return datetime.now(tz=timezone.utc)


@dataclass
class SystemPrompt:
content: str
Expand All @@ -20,7 +24,7 @@ class SystemPrompt:
@dataclass
class UserPrompt:
content: str
timestamp: datetime = field(default_factory=datetime.now)
timestamp: datetime = field(default_factory=_now_utc)
role: Literal['user'] = 'user'


Expand All @@ -32,7 +36,7 @@ class ToolReturn:
tool_name: str
content: str | dict[str, Any]
tool_id: str | None = None
timestamp: datetime = field(default_factory=datetime.now)
timestamp: datetime = field(default_factory=_now_utc)
role: Literal['tool-return'] = 'tool-return'

def model_response_str(self) -> str:
Expand All @@ -54,7 +58,7 @@ class RetryPrompt:
content: list[pydantic_core.ErrorDetails] | str
tool_name: str | None = None
tool_id: str | None = None
timestamp: datetime = field(default_factory=datetime.now)
timestamp: datetime = field(default_factory=_now_utc)
role: Literal['retry-prompt'] = 'retry-prompt'

def model_response(self) -> str:
Expand All @@ -68,7 +72,7 @@ def model_response(self) -> str:
@dataclass
class LLMResponse:
content: str
timestamp: datetime = field(default_factory=datetime.now)
timestamp: datetime = field(default_factory=_now_utc)
role: Literal['llm-response'] = 'llm-response'


Expand Down Expand Up @@ -102,7 +106,7 @@ def from_object(cls, tool_name: str, args_object: dict[str, Any]) -> ToolCall:
@dataclass
class LLMToolCalls:
calls: list[ToolCall]
timestamp: datetime = field(default_factory=datetime.now)
timestamp: datetime = field(default_factory=_now_utc)
role: Literal['llm-tool-calls'] = 'llm-tool-calls'


Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timezone
from typing import Literal

from httpx import AsyncClient as AsyncHTTPClient
Expand Down Expand Up @@ -91,7 +91,7 @@ async def request(self, messages: list[Message]) -> tuple[LLMMessage, shared.Cos
@staticmethod
def process_response(response: chat.ChatCompletion) -> LLMMessage:
choice = response.choices[0]
timestamp = datetime.fromtimestamp(response.created)
timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
if choice.message.tool_calls is not None:
return LLMToolCalls(
[ToolCall.from_json(c.function.name, c.function.arguments, c.id) for c in choice.message.tool_calls],
Expand Down
23 changes: 13 additions & 10 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timezone

import httpx
import pytest
Expand Down Expand Up @@ -361,8 +362,8 @@ async def test_request_simple_success(get_gemini_client: GetGeminiClient):
assert result.response == 'Hello world'
assert result.message_history == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow()),
LLMResponse(content='Hello world', timestamp=IsNow()),
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
LLMResponse(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
]
)
assert result.cost == snapshot(Cost(request_tokens=1, response_tokens=2, total_tokens=3))
Expand All @@ -382,15 +383,15 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
assert result.response == [1, 2, 123]
assert result.message_history == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow()),
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
LLMToolCalls(
calls=[
ToolCall(
tool_name='final_result',
args=ArgsObject(args_object={'response': [1, 2, 123]}),
)
],
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
),
]
)
Expand Down Expand Up @@ -426,28 +427,30 @@ async def get_location(loc_name: str) -> str:
assert result.message_history == snapshot(
[
SystemPrompt(content='this is the system prompt'),
UserPrompt(content='Hello', timestamp=IsNow()),
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
LLMToolCalls(
calls=[
ToolCall(
tool_name='get_location',
args=ArgsObject(args_object={'loc_name': 'San Fransisco'}),
)
],
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
),
RetryPrompt(
tool_name='get_location', content='Wrong location, please try again', timestamp=IsNow(tz=timezone.utc)
),
RetryPrompt(tool_name='get_location', content='Wrong location, please try again', timestamp=IsNow()),
LLMToolCalls(
calls=[
ToolCall(
tool_name='get_location',
args=ArgsObject(args_object={'loc_name': 'London'}),
)
],
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
),
ToolReturn(tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow()),
LLMResponse(content='final response', timestamp=IsNow()),
ToolReturn(tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)),
LLMResponse(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)
assert result.cost == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9))
Expand Down
52 changes: 31 additions & 21 deletions tests/models/test_model_function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
import re
from dataclasses import asdict
from datetime import timezone

import pydantic_core
import pytest
from dirty_equals import IsStr
from inline_snapshot import snapshot
from pydantic import BaseModel

Expand Down Expand Up @@ -38,12 +41,12 @@ def test_simple():
[
UserPrompt(
content='Hello',
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='user',
),
LLMResponse(
content="content='Hello' role='user' message_count=1",
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='llm-response',
),
]
Expand All @@ -55,22 +58,22 @@ def test_simple():
[
UserPrompt(
content='Hello',
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='user',
),
LLMResponse(
content="content='Hello' role='user' message_count=1",
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='llm-response',
),
UserPrompt(
content='World',
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='user',
),
LLMResponse(
content="content='World' role='user' message_count=3",
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='llm-response',
),
]
Expand Down Expand Up @@ -128,16 +131,19 @@ def test_weather():
[
UserPrompt(
content='London',
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='user',
),
LLMToolCalls(
calls=[ToolCall.from_json('get_location', '{"location_description": "London"}')],
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='llm-tool-calls',
),
ToolReturn(
tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(), role='tool-return'
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
timestamp=IsNow(tz=timezone.utc),
role='tool-return',
),
LLMToolCalls(
calls=[
Expand All @@ -146,18 +152,18 @@ def test_weather():
'{"lat": 51, "lng": 0}',
)
],
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='llm-tool-calls',
),
ToolReturn(
tool_name='get_weather',
content='Raining',
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='tool-return',
),
LLMResponse(
content='Raining in London',
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
role='llm-response',
),
]
Expand Down Expand Up @@ -198,12 +204,14 @@ def get_var_args(ctx: CallContext[int], *args: int):
def test_var_args():
result = var_args_agent.run_sync('{"function": "get_var_args", "arguments": {"args": [1, 2, 3]}}')
response_data = json.loads(result.response)
# Can't parse ISO timestamps with trailing 'Z' in older versions of python:
response_data['timestamp'] = re.sub('Z$', '+00:00', response_data['timestamp'])
assert response_data == snapshot(
{
'tool_name': 'get_var_args',
'content': '{"args": [1, 2, 3]}',
'tool_id': None,
'timestamp': IsNow(iso_string=True),
'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc),
'role': 'tool-return',
}
)
Expand Down Expand Up @@ -317,7 +325,7 @@ def test_call_all():
assert result.message_history == snapshot(
[
SystemPrompt(content='foobar'),
UserPrompt(content='Hello', timestamp=IsNow()),
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
LLMToolCalls(
calls=[
ToolCall.from_object('foo', {'x': 0}),
Expand All @@ -326,14 +334,16 @@ def test_call_all():
ToolCall.from_object('qux', {'x': 0}),
ToolCall.from_object('quz', {'x': 'a'}),
],
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
),
ToolReturn(tool_name='foo', content='1', timestamp=IsNow(tz=timezone.utc)),
ToolReturn(tool_name='bar', content='2', timestamp=IsNow(tz=timezone.utc)),
ToolReturn(tool_name='baz', content='3', timestamp=IsNow(tz=timezone.utc)),
ToolReturn(tool_name='qux', content='4', timestamp=IsNow(tz=timezone.utc)),
ToolReturn(tool_name='quz', content='a', timestamp=IsNow(tz=timezone.utc)),
LLMResponse(
content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}', timestamp=IsNow(tz=timezone.utc)
),
ToolReturn(tool_name='foo', content='1', timestamp=IsNow()),
ToolReturn(tool_name='bar', content='2', timestamp=IsNow()),
ToolReturn(tool_name='baz', content='3', timestamp=IsNow()),
ToolReturn(tool_name='qux', content='4', timestamp=IsNow()),
ToolReturn(tool_name='quz', content='a', timestamp=IsNow()),
LLMResponse(content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}', timestamp=IsNow()),
]
)

Expand Down
13 changes: 7 additions & 6 deletions tests/models/test_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations as _annotations

from datetime import timezone
from typing import Annotated, Any, Literal

import pytest
Expand Down Expand Up @@ -83,15 +84,15 @@ async def my_ret(x: int) -> str:
assert result.response == snapshot('{"my_ret":"1"}')
assert result.message_history == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow()),
UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
LLMToolCalls(
calls=[ToolCall.from_object('my_ret', {'x': 0})],
timestamp=IsNow(),
timestamp=IsNow(tz=timezone.utc),
),
RetryPrompt(tool_name='my_ret', content='First call failed', timestamp=IsNow()),
LLMToolCalls(calls=[ToolCall.from_object('my_ret', {'x': 0})], timestamp=IsNow()),
ToolReturn(tool_name='my_ret', content='1', timestamp=IsNow()),
LLMResponse(content='{"my_ret":"1"}', timestamp=IsNow()),
RetryPrompt(tool_name='my_ret', content='First call failed', timestamp=IsNow(tz=timezone.utc)),
LLMToolCalls(calls=[ToolCall.from_object('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc)),
ToolReturn(tool_name='my_ret', content='1', timestamp=IsNow(tz=timezone.utc)),
LLMResponse(content='{"my_ret":"1"}', timestamp=IsNow(tz=timezone.utc)),
]
)

Expand Down
21 changes: 13 additions & 8 deletions tests/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def test_request_structured_response():
assert result.response == [1, 2, 123]
assert result.message_history == snapshot(
[
UserPrompt(content='Hello', timestamp=IsNow()),
UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)),
LLMToolCalls(
calls=[
ToolCall(
Expand All @@ -124,7 +124,7 @@ async def test_request_structured_response():
tool_id='123',
)
],
timestamp=datetime.datetime(2024, 1, 1),
timestamp=datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc),
),
]
)
Expand Down Expand Up @@ -188,7 +188,7 @@ async def get_location(loc_name: str) -> str:
assert result.message_history == snapshot(
[
SystemPrompt(content='this is the system prompt'),
UserPrompt(content='Hello', timestamp=IsNow()),
UserPrompt(content='Hello', timestamp=IsNow(tz=datetime.timezone.utc)),
LLMToolCalls(
calls=[
ToolCall(
Expand All @@ -197,10 +197,13 @@ async def get_location(loc_name: str) -> str:
tool_id='1',
)
],
timestamp=datetime.datetime(2024, 1, 1, 0, 0),
timestamp=datetime.datetime(2024, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),
),
RetryPrompt(
tool_name='get_location', content='Wrong location, please try again', tool_id='1', timestamp=IsNow()
tool_name='get_location',
content='Wrong location, please try again',
tool_id='1',
timestamp=IsNow(tz=datetime.timezone.utc),
),
LLMToolCalls(
calls=[
Expand All @@ -210,15 +213,17 @@ async def get_location(loc_name: str) -> str:
tool_id='2',
)
],
timestamp=datetime.datetime(2024, 1, 1, 0, 0),
timestamp=datetime.datetime(2024, 1, 1, 0, 0, tzinfo=datetime.timezone.utc),
),
ToolReturn(
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
tool_id='2',
timestamp=IsNow(),
timestamp=IsNow(tz=datetime.timezone.utc),
),
LLMResponse(
content='final response', timestamp=datetime.datetime(2024, 1, 1, 0, 0, tzinfo=datetime.timezone.utc)
),
LLMResponse(content='final response', timestamp=datetime.datetime(2024, 1, 1, 0, 0)),
]
)
assert result.cost == snapshot(
Expand Down
Loading

0 comments on commit 0b626f8

Please sign in to comment.