Skip to content

Commit

Permalink
fix docstring inference (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 20, 2024
1 parent bd5ade0 commit 3354245
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 15 deletions.
95 changes: 86 additions & 9 deletions pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations as _annotations

import re
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, cast, get_origin

Expand Down Expand Up @@ -216,15 +217,91 @@ def _doc_descriptions(

def _infer_docstring_style(doc: str) -> DocstringStyle:
"""Simplistic docstring style inference."""
if ' Args:' in doc:
return 'google'
elif ' :param ' in doc:
return 'sphinx'
elif ' Parameters' in doc:
return 'numpy'
else:
# fallback to google style
return 'google'
for pattern, replacements, style in _docstring_style_patterns:
matches = (
re.search(pattern.format(replacement), doc, re.IGNORECASE | re.MULTILINE) for replacement in replacements
)
if any(matches):
return style
# fallback to google style
return 'google'


# See https://github.com/mkdocstrings/griffe/issues/329#issuecomment-2425017804
_docstring_style_patterns: list[tuple[str, list[str], DocstringStyle]] = [
(
r'\n[ \t]*:{0}([ \t]+\w+)*:([ \t]+.+)?\n',
[
'param',
'parameter',
'arg',
'argument',
'key',
'keyword',
'type',
'var',
'ivar',
'cvar',
'vartype',
'returns',
'return',
'rtype',
'raises',
'raise',
'except',
'exception',
],
'sphinx',
),
(
r'\n[ \t]*{0}:([ \t]+.+)?\n[ \t]+.+',
[
'args',
'arguments',
'params',
'parameters',
'keyword args',
'keyword arguments',
'other args',
'other arguments',
'other params',
'other parameters',
'raises',
'exceptions',
'returns',
'yields',
'receives',
'examples',
'attributes',
'functions',
'methods',
'classes',
'modules',
'warns',
'warnings',
],
'google',
),
(
r'\n[ \t]*{0}\n[ \t]*---+\n',
[
'deprecated',
'parameters',
'other parameters',
'returns',
'yields',
'receives',
'raises',
'warns',
'attributes',
'functions',
'methods',
'classes',
'modules',
],
'numpy',
),
]


def _is_call_ctx(annotation: Any) -> bool:
Expand Down
6 changes: 0 additions & 6 deletions tests/test_retrievers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import sys
from typing import Annotated

import pytest
Expand Down Expand Up @@ -107,10 +106,6 @@ def sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover
return str(foo)


IS_PY313 = sys.version_info[:2] >= (3, 13)


@pytest.mark.skipif(IS_PY313, reason='https://github.com/mkdocstrings/griffe/issues/329')
def test_docstring_sphinx():
agent = Agent(FunctionModel(get_json_schema), deps=None)
agent.retriever_plain(sphinx_style_docstring)
Expand Down Expand Up @@ -144,7 +139,6 @@ def numpy_style_docstring(*, foo: int, bar: str) -> str: # pragma: no cover
return f'{foo} {bar}'


@pytest.mark.skipif(IS_PY313, reason='https://github.com/mkdocstrings/griffe/issues/329')
def test_docstring_numpy():
agent = Agent(FunctionModel(get_json_schema), deps=None)
agent.retriever_plain(numpy_style_docstring)
Expand Down

0 comments on commit 3354245

Please sign in to comment.