Skip to content

Commit

Permalink
Support docstring_format and require_parameter_descriptions on to…
Browse files Browse the repository at this point in the history
…ols (#643)
  • Loading branch information
sydney-runkle authored Jan 9, 2025
1 parent 5510485 commit b10ab78
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 22 deletions.
13 changes: 10 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_griffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
import re
from contextlib import contextmanager
from inspect import Signature
from typing import Any, Callable, Literal, cast
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

from griffe import Docstring, DocstringSectionKind, Object as GriffeObject

if TYPE_CHECKING:
from .tools import DocstringFormat

DocstringStyle = Literal['google', 'numpy', 'sphinx']


def doc_descriptions(
func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None
func: Callable[..., Any],
sig: Signature,
*,
docstring_format: DocstringFormat,
) -> tuple[str, dict[str, str]]:
"""Extract the function description and parameter descriptions from a function's docstring.
Expand All @@ -26,7 +32,8 @@ def doc_descriptions(
# see https://github.com/mkdocstrings/griffe/issues/293
parent = cast(GriffeObject, sig)

docstring = Docstring(doc, lineno=1, parser=style or _infer_docstring_style(doc), parent=parent)
docstring_style = _infer_docstring_style(doc) if docstring_format == 'auto' else docstring_format
docstring = Docstring(doc, lineno=1, parser=docstring_style, parent=parent)
with _disable_griffe_logging():
sections = docstring.parse()

Expand Down
19 changes: 16 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._utils import check_object_json_schema, is_model_like

if TYPE_CHECKING:
from .tools import ObjectJsonSchema
from .tools import DocstringFormat, ObjectJsonSchema


__all__ = ('function_schema',)
Expand All @@ -38,12 +38,19 @@ class FunctionSchema(TypedDict):
var_positional_field: str | None


def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema: # noqa: C901
def function_schema( # noqa: C901
function: Callable[..., Any],
takes_ctx: bool,
docstring_format: DocstringFormat,
require_parameter_descriptions: bool,
) -> FunctionSchema:
"""Build a Pydantic validator and JSON schema from a tool function.
Args:
function: The function to build a validator and JSON schema for.
takes_ctx: Whether the function takes a `RunContext` first argument.
docstring_format: The docstring format to use.
require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
Returns:
A `FunctionSchema` instance.
Expand All @@ -62,7 +69,13 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc
var_positional_field: str | None = None
errors: list[str] = []
decorators = _decorators.DecoratorInfos()
description, field_descriptions = doc_descriptions(function, sig)

description, field_descriptions = doc_descriptions(function, sig, docstring_format=docstring_format)

if require_parameter_descriptions:
if len(field_descriptions) != len(sig.parameters):
missing_params = set(sig.parameters) - set(field_descriptions)
errors.append(f'Missing parameter descriptions for {", ".join(missing_params)}')

for index, (name, p) in enumerate(sig.parameters.items()):
if p.annotation is sig.empty:
Expand Down
36 changes: 31 additions & 5 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .settings import ModelSettings, merge_model_settings
from .tools import (
AgentDeps,
DocstringFormat,
RunContext,
Tool,
ToolDefinition,
Expand Down Expand Up @@ -774,6 +775,8 @@ def tool(
*,
retries: int | None = None,
prepare: ToolPrepareFunc[AgentDeps] | None = None,
docstring_format: DocstringFormat = 'auto',
require_parameter_descriptions: bool = False,
) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...

def tool(
Expand All @@ -783,6 +786,8 @@ def tool(
*,
retries: int | None = None,
prepare: ToolPrepareFunc[AgentDeps] | None = None,
docstring_format: DocstringFormat = 'auto',
require_parameter_descriptions: bool = False,
) -> Any:
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
Expand Down Expand Up @@ -820,20 +825,23 @@ async def spam(ctx: RunContext[str], y: float) -> float:
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
tool from a given step. This is useful if you want to customise a tool at call time,
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
"""
if func is None:

def tool_decorator(
func_: ToolFuncContext[AgentDeps, ToolParams],
) -> ToolFuncContext[AgentDeps, ToolParams]:
# noinspection PyTypeChecker
self._register_function(func_, True, retries, prepare)
self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
return func_

return tool_decorator
else:
# noinspection PyTypeChecker
self._register_function(func, True, retries, prepare)
self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
return func

@overload
Expand All @@ -846,6 +854,8 @@ def tool_plain(
*,
retries: int | None = None,
prepare: ToolPrepareFunc[AgentDeps] | None = None,
docstring_format: DocstringFormat = 'auto',
require_parameter_descriptions: bool = False,
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

def tool_plain(
Expand All @@ -855,6 +865,8 @@ def tool_plain(
*,
retries: int | None = None,
prepare: ToolPrepareFunc[AgentDeps] | None = None,
docstring_format: DocstringFormat = 'auto',
require_parameter_descriptions: bool = False,
) -> Any:
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
Expand Down Expand Up @@ -892,17 +904,22 @@ async def spam(ctx: RunContext[str]) -> float:
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
tool from a given step. This is useful if you want to customise a tool at call time,
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
"""
if func is None:

def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
# noinspection PyTypeChecker
self._register_function(func_, False, retries, prepare)
self._register_function(
func_, False, retries, prepare, docstring_format, require_parameter_descriptions
)
return func_

return tool_decorator
else:
self._register_function(func, False, retries, prepare)
self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
return func

def _register_function(
Expand All @@ -911,10 +928,19 @@ def _register_function(
takes_ctx: bool,
retries: int | None,
prepare: ToolPrepareFunc[AgentDeps] | None,
docstring_format: DocstringFormat,
require_parameter_descriptions: bool,
) -> None:
"""Private utility to register a function as a tool."""
retries_ = retries if retries is not None else self._default_retries
tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
tool = Tool(
func,
takes_ctx=takes_ctx,
max_retries=retries_,
prepare=prepare,
docstring_format=docstring_format,
require_parameter_descriptions=require_parameter_descriptions,
)
self._register_tool(tool)

def _register_tool(self, tool: Tool[AgentDeps]) -> None:
Expand Down
23 changes: 21 additions & 2 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast

from pydantic import ValidationError
from pydantic_core import SchemaValidator
Expand All @@ -18,6 +18,7 @@

__all__ = (
'AgentDeps',
'DocstringFormat',
'RunContext',
'SystemPromptFunc',
'ToolFuncContext',
Expand Down Expand Up @@ -127,6 +128,15 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str:
Usage `ToolPrepareFunc[AgentDeps]`.
"""

DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
"""Supported docstring formats.
* `'google'` — [Google-style](https://google.github.io/styleguide/pyguide.html#381-docstrings) docstrings.
* `'numpy'` — [Numpy-style](https://numpydoc.readthedocs.io/en/latest/format.html) docstrings.
* `'sphinx'` — [Sphinx-style](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format) docstrings.
* `'auto'` — Automatically infer the format based on the structure of the docstring.
"""

A = TypeVar('A')


Expand All @@ -140,6 +150,8 @@ class Tool(Generic[AgentDeps]):
name: str
description: str
prepare: ToolPrepareFunc[AgentDeps] | None
docstring_format: DocstringFormat
require_parameter_descriptions: bool
_is_async: bool = field(init=False)
_single_arg_name: str | None = field(init=False)
_positional_fields: list[str] = field(init=False)
Expand All @@ -157,6 +169,8 @@ def __init__(
name: str | None = None,
description: str | None = None,
prepare: ToolPrepareFunc[AgentDeps] | None = None,
docstring_format: DocstringFormat = 'auto',
require_parameter_descriptions: bool = False,
):
"""Create a new tool instance.
Expand Down Expand Up @@ -203,17 +217,22 @@ async def prep_my_tool(
prepare: custom method to prepare the tool definition for each step, return `None` to omit this
tool from a given step. This is useful if you want to customise a tool at call time,
or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
"""
if takes_ctx is None:
takes_ctx = _pydantic.takes_ctx(function)

f = _pydantic.function_schema(function, takes_ctx)
f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
self.function = function
self.takes_ctx = takes_ctx
self.max_retries = max_retries
self.name = name or function.__name__
self.description = description or f['description']
self.prepare = prepare
self.docstring_format = docstring_format
self.require_parameter_descriptions = require_parameter_descriptions
self._is_async = inspect.iscoroutinefunction(self.function)
self._single_arg_name = f['single_arg_name']
self._positional_fields = f['positional_fields']
Expand Down
43 changes: 34 additions & 9 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from dataclasses import dataclass
from typing import Annotated, Any, Callable, Union
from typing import Annotated, Any, Callable, Literal, Union

import pydantic_core
import pytest
Expand Down Expand Up @@ -78,9 +78,10 @@ async def get_json_schema(_messages: list[ModelMessage], info: AgentInfo) -> Mod
return ModelResponse.from_text(pydantic_core.to_json(r).decode())


def test_docstring_google(set_event_loop: None):
@pytest.mark.parametrize('docstring_format', ['google', 'auto'])
def test_docstring_google(set_event_loop: None, docstring_format: Literal['google', 'auto']):
agent = Agent(FunctionModel(get_json_schema))
agent.tool_plain(google_style_docstring)
agent.tool_plain(docstring_format=docstring_format)(google_style_docstring)

result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
Expand Down Expand Up @@ -116,9 +117,10 @@ def sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover
return str(foo)


def test_docstring_sphinx(set_event_loop: None):
@pytest.mark.parametrize('docstring_format', ['sphinx', 'auto'])
def test_docstring_sphinx(set_event_loop: None, docstring_format: Literal['sphinx', 'auto']):
agent = Agent(FunctionModel(get_json_schema))
agent.tool_plain(sphinx_style_docstring)
agent.tool_plain(docstring_format=docstring_format)(sphinx_style_docstring)

result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
Expand Down Expand Up @@ -150,9 +152,10 @@ def numpy_style_docstring(*, foo: int, bar: str) -> str: # pragma: no cover
return f'{foo} {bar}'


def test_docstring_numpy(set_event_loop: None):
@pytest.mark.parametrize('docstring_format', ['numpy', 'auto'])
def test_docstring_numpy(set_event_loop: None, docstring_format: Literal['numpy', 'auto']):
agent = Agent(FunctionModel(get_json_schema))
agent.tool_plain(numpy_style_docstring)
agent.tool_plain(docstring_format=docstring_format)(numpy_style_docstring)

result = agent.run_sync('Hello')
json_schema = json.loads(result.data)
Expand Down Expand Up @@ -209,9 +212,10 @@ async def google_style_docstring_no_body(
# fmt: on


def test_docstring_google_no_body(set_event_loop: None):
@pytest.mark.parametrize('docstring_format', ['google', 'auto'])
def test_docstring_google_no_body(set_event_loop: None, docstring_format: Literal['google', 'auto']):
agent = Agent(FunctionModel(get_json_schema))
agent.tool_plain(google_style_docstring_no_body)
agent.tool_plain(docstring_format=docstring_format)(google_style_docstring_no_body)

result = agent.run_sync('')
json_schema = json.loads(result.data)
Expand Down Expand Up @@ -566,3 +570,24 @@ def test_suppress_griffe_logging(set_event_loop: None, caplog: LogCaptureFixture
# Without suppressing griffe logging, we get:
# assert caplog.messages == snapshot(['<module>:4: No type or annotation for returned value 1'])
assert caplog.messages == snapshot([])


async def missing_parameter_descriptions_docstring(foo: int, bar: str) -> str: # pragma: no cover
"""Describes function ops, but missing parameter descriptions."""
return f'{foo} {bar}'


def test_enforce_parameter_descriptions() -> None:
agent = Agent(FunctionModel(get_json_schema))

with pytest.raises(UserError) as exc_info:
agent.tool_plain(require_parameter_descriptions=True)(missing_parameter_descriptions_docstring)

error_reason = exc_info.value.args[0]
error_parts = [
'Error generating schema for missing_parameter_descriptions_docstring',
'Missing parameter descriptions for ',
'foo',
'bar',
]
assert all(err_part in error_reason for err_part in error_parts)

0 comments on commit b10ab78

Please sign in to comment.