Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support TypeAliasType unions #26

Merged
merged 3 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@
from .shared import AgentDeps


__all__ = 'function_schema', 'LazyTypeAdapter', 'is_union'


def is_union(tp: Any) -> bool:
origin = get_origin(tp)
return _typing_extra.origin_is_union(origin)
__all__ = 'function_schema', 'LazyTypeAdapter'


class FunctionSchema(TypedDict):
Expand Down
36 changes: 27 additions & 9 deletions pydantic_ai/_result.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations as _annotations

import inspect
import sys
import types
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Callable, Generic, Union, cast, get_args
from typing import Any, Callable, Generic, Union, cast, get_args, get_origin

from pydantic import TypeAdapter, ValidationError
from typing_extensions import Self, TypedDict
from typing_extensions import Self, TypeAliasType, TypedDict

from . import _pydantic, _utils, messages
from . import _utils, messages
from .messages import LLMToolCalls, ToolCall
from .shared import AgentDeps, CallContext, ModelRetry, ResultData

Expand Down Expand Up @@ -106,7 +108,7 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat
)

tools: dict[str, ResultTool[ResultData]] = {}
if args := union_args(response_type):
if args := get_union_args(response_type):
for arg in args:
tool_name = union_tool_name(name, arg)
tools[tool_name] = _build_tool(arg, tool_name, True)
Expand Down Expand Up @@ -204,10 +206,11 @@ def union_arg_name(union_arg: Any) -> str:

def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
"""Extract the string type from a Union, return the remaining union or remaining type."""
if _pydantic.is_union(response_type) and any(t is str for t in get_args(response_type)):
union_args = get_union_args(response_type)
if any(t is str for t in union_args):
remain_args: list[Any] = []
includes_str = False
for arg in get_args(response_type):
for arg in union_args:
if arg is str:
includes_str = True
else:
Expand All @@ -219,9 +222,24 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
return _utils.Some(Union[tuple(remain_args)])


def union_args(response_type: Any) -> tuple[Any, ...]:
def get_union_args(tp: Any) -> tuple[Any, ...]:
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty union."""
if _pydantic.is_union(response_type):
return get_args(response_type)
if isinstance(tp, TypeAliasType):
tp = tp.__value__

origin = get_origin(tp)
if origin_is_union(origin):
return get_args(tp)
else:
return ()


if sys.version_info < (3, 10):

def origin_is_union(tp: type[Any] | None) -> bool:
return tp is Union

else:

def origin_is_union(tp: type[Any] | None) -> bool:
return tp is Union or tp is types.UnionType
27 changes: 14 additions & 13 deletions pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,47 +417,48 @@ def __init__(self, schema: _utils.ObjectJsonSchema):
self.defs = self.schema.pop('$defs', {})

def simplify(self) -> dict[str, Any]:
self._simplify(self.schema, allow_ref=True)
self._simplify(self.schema, refs_stack=())
return self.schema

def _simplify(self, schema: dict[str, Any], allow_ref: bool) -> None:
def _simplify(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
schema.pop('title', None)
schema.pop('default', None)
if ref := schema.pop('$ref', None):
if not allow_ref:
raise shared.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
# noinspection PyTypeChecker
key = re.sub(r'^#/\$defs/', '', ref)
if key in refs_stack:
raise shared.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
refs_stack += (key,)
schema_def = self.defs[key]
self._simplify(schema_def, allow_ref=False)
self._simplify(schema_def, refs_stack)
schema.update(schema_def)
return

if any_of := schema.get('anyOf'):
for schema in any_of:
self._simplify(schema, allow_ref)
self._simplify(schema, refs_stack)

type_ = schema.get('type')

if type_ == 'object':
self._object(schema, allow_ref)
self._object(schema, refs_stack)
elif type_ == 'array':
return self._array(schema, allow_ref)
return self._array(schema, refs_stack)

def _object(self, schema: dict[str, Any], allow_ref: bool) -> None:
def _object(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
ad_props = schema.pop('additionalProperties', None)
if ad_props:
raise shared.UserError('Additional properties in JSON Schema are not supported by Gemini')

if properties := schema.get('properties'): # pragma: no branch
for value in properties.values():
self._simplify(value, allow_ref)
self._simplify(value, refs_stack)

def _array(self, schema: dict[str, Any], allow_ref: bool) -> None:
def _array(self, schema: dict[str, Any], refs_stack: tuple[str, ...]) -> None:
if prefix_items := schema.get('prefixItems'):
# TODO I think this is not supported by Gemini, maybe we should raise an error?
for prefix_item in prefix_items:
self._simplify(prefix_item, allow_ref)
self._simplify(prefix_item, refs_stack)

if items_schema := schema.get('items'): # pragma: no branch
self._simplify(items_schema, allow_ref)
self._simplify(items_schema, refs_stack)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ filterwarnings = [

# https://coverage.readthedocs.io/en/latest/config.html#run
[tool.coverage.run]
# required to avoid warnings about files created by create_module fixture
include = ["pydantic_ai/**/*.py", "tests/**/*.py"]
branch = true

# https://coverage.readthedocs.io/en/latest/config.html#report
Expand Down
54 changes: 54 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations as _annotations

import importlib.util
import os
import re
import secrets
import sys
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, Callable

import httpx
import pytest
from _pytest.assertion.rewrite import AssertionRewritingHook
from typing_extensions import TypeAlias

__all__ = 'IsNow', 'TestEnv'
Expand Down Expand Up @@ -78,3 +85,50 @@ def create_client(handler: Callable[[httpx.Request], httpx.Response]) -> httpx.A


ClientWithHandler: TypeAlias = Callable[[Callable[[httpx.Request], httpx.Response]], httpx.AsyncClient]


# pyright: reportUnknownMemberType=false, reportUnknownArgumentType=false
@pytest.fixture
def create_module(tmp_path: Path, request: pytest.FixtureRequest) -> Callable[[str], Any]:
    """Taken from `pydantic/tests/conftest.py`, create module object, execute and return it.

    The fixture value is the inner `run` callable: it writes the given source code to a
    uniquely named file under `tmp_path`, imports it as a module and returns the module.
    """

    def run(
        source_code: str,
        rewrite_assertions: bool = True,
        module_name_prefix: str | None = None,
    ) -> ModuleType:
        """Create module object, execute and return it.

        Args:
            source_code: Python source code of the module
            rewrite_assertions: whether to rewrite assertions in module or not
            module_name_prefix: string prefix to use in the name of the module, does not affect the name of the file.

        Returns:
            The executed module, which is also registered in `sys.modules` under its generated name.
        """

        # Max path length in Windows is 260. Leaving some buffer here
        max_name_len = 240 - len(str(tmp_path))
        # Windows does not allow these characters in paths. Linux bans slashes only.
        sanitized_name = re.sub('[' + re.escape('<>:"/\\|?*') + ']', '-', request.node.name)[:max_name_len]
        # The random suffix keeps module names unique when `run` is called more than once per test.
        module_name = f'{sanitized_name}_{secrets.token_hex(5)}'
        path = tmp_path / f'{module_name}.py'
        # Python reads source files as UTF-8 by default, so write UTF-8 explicitly rather than
        # relying on the locale encoding (which differs on e.g. Windows).
        path.write_text(source_code, encoding='utf-8')
        filename = str(path)

        if module_name_prefix:
            module_name = module_name_prefix + module_name

        if rewrite_assertions:
            # Use pytest's import hook as the loader so `assert` statements in the module are rewritten.
            loader = AssertionRewritingHook(config=request.config)
            loader.mark_rewrite(module_name)
        else:
            loader = None

        spec = importlib.util.spec_from_file_location(module_name, filename, loader=loader)
        sys.modules[module_name] = module = importlib.util.module_from_spec(spec)  # pyright: ignore[reportArgumentType]
        spec.loader.exec_module(module)  # pyright: ignore[reportOptionalMemberAccess]
        return module

    return run
44 changes: 33 additions & 11 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from datetime import timezone
from typing import Any, Callable, Union

Expand Down Expand Up @@ -255,23 +256,44 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
)


# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
@pytest.mark.parametrize(
'union_code',
[
pytest.param('ResultType = Union[Foo, Bar]'),
pytest.param('ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')),
pytest.param(
'ResultType: TypeAlias = Foo | Bar',
marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='Python 3.10+'),
),
pytest.param(
'type ResultType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 12), reason='3.12+')
),
],
)
def test_response_multiple_return_tools(create_module: Callable[[str], Any], union_code: str):
module_code = f'''
from pydantic import BaseModel
from typing import Union
from typing_extensions import TypeAlias

class Foo(BaseModel):
a: int
b: str


class Bar(BaseModel):
"""This is a bar model."""

b: str

{union_code}
'''

@pytest.mark.parametrize(
'input_union_callable', [lambda: Union[Foo, Bar], lambda: Foo | Bar], ids=['Union[Foo, Bar]', 'Foo | Bar']
)
def test_response_multiple_return_tools(input_union_callable: Callable[[], Any]):
try:
union = input_union_callable()
except TypeError:
raise pytest.skip('Python version does not support `|` syntax for unions')
mod = create_module(module_code)

m = TestModel()
agent: Agent[None, Union[Foo, Bar]] = Agent(m, result_type=union)
agent = Agent(m, result_type=mod.ResultType)
got_tool_call_name = 'unset'

@agent.result_validator
Expand All @@ -281,7 +303,7 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
return r

result = agent.run_sync('Hello')
assert result.response == Foo(a=0, b='a')
assert result.response == mod.Foo(a=0, b='a')
assert got_tool_call_name == snapshot('final_result_Foo')

assert m.agent_model_retrievers == snapshot({})
Expand Down Expand Up @@ -324,5 +346,5 @@ def validate_result(ctx: CallContext[None], r: Any) -> Any:
)

result = agent.run_sync('Hello', model=TestModel(seed=1))
assert result.response == Bar(b='b')
assert result.response == mod.Bar(b='b')
assert got_tool_call_name == snapshot('final_result_Bar')
Loading