Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(pydantic v1): exclude specific properties when rich printing #1751

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import os
import inspect
from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, Optional, cast
from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast
from datetime import date, datetime
from typing_extensions import (
Unpack,
Literal,
ClassVar,
Protocol,
Required,
Sequence,
ParamSpec,
TypedDict,
TypeGuard,
Expand Down Expand Up @@ -72,6 +73,8 @@

P = ParamSpec("P")

ReprArgs = Sequence[Tuple[Optional[str], Any]]


@runtime_checkable
class _ConfigProtocol(Protocol):
Expand All @@ -94,6 +97,11 @@ def model_fields_set(self) -> set[str]:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore

@override
def __repr_args__(self) -> ReprArgs:
# we don't want these attributes to be included when something like `rich.print` is used
return [arg for arg in super().__repr_args__() if arg[0] not in {"_request_id", "__exclude_fields__"}]

if TYPE_CHECKING:
_request_id: Optional[str] = None
"""The ID of the request, returned via the X-Request-ID header. Useful for debugging requests and reporting issues to OpenAI.
Expand Down
11 changes: 3 additions & 8 deletions tests/lib/chat/_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import io
import inspect
from typing import Any, Iterable
from typing_extensions import TypeAlias

import rich
import pytest
import pydantic

from ...utils import rich_print_str

ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"


Expand All @@ -26,12 +26,7 @@ def __repr_args__(self: pydantic.BaseModel) -> ReprArgs:
with monkeypatch.context() as m:
m.setattr(pydantic.BaseModel, "__repr_args__", __repr_args__)

buf = io.StringIO()

console = rich.console.Console(file=buf, width=120)
console.print(obj)

string = buf.getvalue()
string = rich_print_str(obj)

# we remove all `fn_name.<locals>.` occurences
# so that we can share the same snapshots between
Expand Down
4 changes: 4 additions & 0 deletions tests/test_legacy_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from openai._base_client import FinalRequestOptions
from openai._legacy_response import LegacyAPIResponse

from .utils import rich_print_str


class PydanticModel(pydantic.BaseModel): ...

Expand Down Expand Up @@ -85,6 +87,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
assert obj.foo == "hello!"
assert obj.bar == 2
assert obj.to_dict() == {"foo": "hello!", "bar": 2}
assert "_request_id" not in rich_print_str(obj)
assert "__exclude_fields__" not in rich_print_str(obj)


def test_response_parse_annotated_type(client: OpenAI) -> None:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from openai._streaming import Stream
from openai._base_client import FinalRequestOptions

from .utils import rich_print_str


class ConcreteBaseAPIResponse(APIResponse[bytes]): ...

Expand Down Expand Up @@ -175,6 +177,8 @@ def test_response_basemodel_request_id(client: OpenAI) -> None:
assert obj.foo == "hello!"
assert obj.bar == 2
assert obj.to_dict() == {"foo": "hello!", "bar": 2}
assert "_request_id" not in rich_print_str(obj)
assert "__exclude_fields__" not in rich_print_str(obj)


@pytest.mark.asyncio
Expand Down
13 changes: 13 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
import os
import inspect
import traceback
Expand All @@ -8,6 +9,8 @@
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type

import rich

from openai._types import Omit, NoneType
from openai._utils import (
is_dict,
Expand Down Expand Up @@ -138,6 +141,16 @@ def _assert_list_type(type_: type[object], value: object) -> None:
assert_type(inner_type, entry) # type: ignore


def rich_print_str(obj: object) -> str:
"""Like `rich.print()` but returns the string instead"""
buf = io.StringIO()

console = rich.console.Console(file=buf, width=120)
console.print(obj)

return buf.getvalue()


@contextlib.contextmanager
def update_env(**new_env: str | Omit) -> Iterator[None]:
old = os.environ.copy()
Expand Down