From b4f887912981d73c552780d2e46ff2c62de499e2 Mon Sep 17 00:00:00 2001
From: Robert Myers
Date: Sun, 10 Nov 2024 22:22:22 -0600
Subject: [PATCH 1/4] Adding orm datasource and updating http to be more useful

---
 cannula/datasource/__init__.py |   7 ++
 cannula/datasource/base.py     |  54 ++++++++++++
 cannula/datasource/http.py     | 145 ++++++++++++++++++++++--------
 cannula/datasource/orm.py      | 157 +++++++++++++++++++++++++++++++++
 docs/ref/datasources.rst       |   3 +
 examples/http_datasource.py    |  45 +++++-----
 tests/datasource/test_http.py  |  18 ++--
 tests/test_utils.py            |   1 +
 8 files changed, 366 insertions(+), 64 deletions(-)
 create mode 100644 cannula/datasource/base.py
 create mode 100644 cannula/datasource/orm.py

diff --git a/cannula/datasource/__init__.py b/cannula/datasource/__init__.py
index e69de29..a8b4c06 100644
--- a/cannula/datasource/__init__.py
+++ b/cannula/datasource/__init__.py
@@ -0,0 +1,7 @@
+from .base import GraphModel, cacheable, expected_fields
+
+__all__ = [
+    "GraphModel",
+    "cacheable",
+    "expected_fields",
+]
diff --git a/cannula/datasource/base.py b/cannula/datasource/base.py
new file mode 100644
index 0000000..00cec83
--- /dev/null
+++ b/cannula/datasource/base.py
@@ -0,0 +1,54 @@
+import asyncio
+import dataclasses
+import typing
+
+GraphModel = typing.TypeVar("GraphModel")
+
+
+def cacheable(f):
+    """Decorator that is used to allow coroutines to be cached.
+
+    Solves the issue of `cannot reuse already awaited coroutine`
+
+    Example::
+
+        _memoized: dict[str, Awaitable]
+
+        async def get(self, pk: str):
+            cache_key = f"get:{pk}"
+
+            @cacheable
+            async def process_get():
+                return await session.get(pk)
+
+            if results := _memoized.get(cache_key):
+                return await results
+
+            _memoized[cache_key] = process_get()
+            return await _memoized[cache_key]
+
+        # These calls will all share the same result and only run the query once
+        results = await asyncio.gather(get(1), get(1), get(1))
+
+    """
+
+    def wrapped(*args, **kwargs):
+        r = f(*args, **kwargs)
+        return asyncio.ensure_future(r)
+
+    return wrapped
+
+
+def expected_fields(obj: typing.Any) -> set[str]:
+    """Extract all the fields that are on the object.
+
+    This is used when constructing a new instance from a datasource.
+    """
+    if dataclasses.is_dataclass(obj):
+        return {field.name for field in dataclasses.fields(obj)}
+    elif hasattr(obj, "model_fields"):
+        return {obj.model_fields.keys()}
+
+    raise ValueError(
+        "Invalid model for 'GraphModel' must be a dataclass or pydantic model"
+    )
diff --git a/cannula/datasource/http.py b/cannula/datasource/http.py
index cbd13b1..8490169 100644
--- a/cannula/datasource/http.py
+++ b/cannula/datasource/http.py
@@ -14,21 +14,16 @@
 import asyncio
 import logging
-import types
 import typing
 
 import httpx
 
-LOG = logging.getLogger("cannula.datasource.http")
-
+from cannula.datasource import GraphModel, cacheable, expected_fields
-
-# solves the issue of `cannot reuse already awaited coroutine`
-def cacheable(f):
-    def wrapped(*args, **kwargs):
-        r = f(*args, **kwargs)
-        return asyncio.ensure_future(r)
+LOG = logging.getLogger("cannula.datasource.http")
-
-    return wrapped
+AnyDict = typing.Dict[typing.Any, typing.Any]
+Response = typing.Union[typing.List[AnyDict], AnyDict]
 
 
 class Request(typing.NamedTuple):
@@ -38,20 +33,79 @@ class Request(typing.NamedTuple):
     headers: typing.Dict = {}
 
 
-class HTTPDataSource:
+class HTTPDataSource(typing.Generic[GraphModel]):
     """
     HTTP Data Source
 
     This is modeled after the apollo http datasource. It uses httpx to preform
-    async requests to any remote service you wish to query.
+    async requests to any remote service you wish to query. All GET and HEAD
+    requests will be memoized so that they are only performed once per
+    graph resolution.
 
     Properties:
 
+    * `graph_model`: This is the object type your schema is expecting to respond with.
     * `base_url`: Optional base_url to apply to all requests
     * `timeout`: Default timeout in seconds for requests (5 seconds)
-    * `resource_name`: Optional name to use for `__typename` in responses.
+
+    Example::
+
+        @dataclass(kw_only=True)
+        class User(UserTypeBase):
+            id: UUID
+            name: str
+
+        class UserAPI(
+            HTTPDataSource[User],
+            graph_model=User,
+            base_url="https://auth.com",
+        ):
+
+            async def get_user(self, id) -> User:
+                response = await self.get(f"/users/{id}")
+                return self.model_from_response(response)
+
+    You can then add this to your context to make it available to your resolvers. It is
+    best practice to set up a client for all your http datasources to share in order to
+    handle auth and use the built-in connection pool. First add to your context object::
+
+        class Context(cannula.Context):
+
+            def __init__(self, client: httpx.AsyncClient) -> None:
+                self.user_api = UserAPI(client=client)
+                self.group_api = GroupAPI(client=client)
+
+    Next, in your graph handler function, create an httpx client to use::
+
+        @api.post('/graph')
+        async def graph(
+            graph_call: Annotated[
+                GraphQLExec,
+                Depends(GraphQLDepends(cannula_app)),
+            ],
+            request: Request,
+        ) -> ExecutionResponse:
+            # Grab the authorization header and create the client
+            authorization = request.headers.get('authorization')
+            headers = {'authorization': authorization}
+
+            async with httpx.AsyncClient(headers=headers) as client:
+                context = Context(client)
+                return await graph_call(context=context)
+
+    Finally, you can use this datasource in your resolver functions like so::
+
+        async def resolve_person(
+            # Using this type hint for the ResolveInfo will make it so that
+            # we can inspect the `info` object in our editors and find the `user_api`
+            info: cannula.ResolveInfo[Context],
+            id: uuid.UUID,
+        ) -> UserType | None:
+            return await info.context.user_api.get_user(id)
     """
 
+    _graph_model: type[GraphModel]
+    _expected_fields: set[str]
     # The base url of this resource
     base_url: typing.Optional[str] = None
     # A mapping of requests using the cache_key_for_request. Multiple resolvers
@@ -62,31 +116,32 @@ class HTTPDataSource:
     # Timeout for an individual request in seconds.
     timeout: int = 5
 
-    # Resource name for the type that this datasource returns by default this
-    # will use the class name of the datasource.
-    resource_name: typing.Optional[str] = None
+    def __init_subclass__(
+        cls,
+        graph_model: type[GraphModel],
+        base_url: typing.Optional[str] = None,
+        timeout: int = 5,
+    ) -> None:
+        cls._graph_model = graph_model
+        cls._expected_fields = expected_fields(graph_model)
+        cls.base_url = base_url
+        cls.timeout = timeout
+        return super().__init_subclass__()
 
     def __init__(
         self,
-        request: typing.Any,
         client: typing.Optional[httpx.AsyncClient] = None,
     ):
         self.client = client or httpx.AsyncClient()
         # close the client if this instance opened it
         self._should_close_client = client is None
-        self.request = request
         self.memoized_requests = {}
-        self.assert_has_resource_name()
 
     def __del__(self):
         if self._should_close_client:
-            LOG.debug(f"Closing httpx session for {self.resource_name}")
+            LOG.debug(f"Closing httpx session for {self.__class__.__name__}")
             asyncio.ensure_future(self.client.aclose())
 
-    def assert_has_resource_name(self) -> None:
-        if self.resource_name is None:
-            self.resource_name = self.__class__.__name__
-
     def will_send_request(self, request: Request) -> Request:
         """Hook for subclasses to modify the request before it is sent.
@@ -120,13 +175,11 @@ def did_receive_error(self, error: Exception, request: Request):
         """Handle errors from the remote resource"""
         raise error
 
-    def convert_to_object(self, json_obj):
-        json_obj.update({"__typename": self.resource_name})
-        return types.SimpleNamespace(**json_obj)
-
-    async def did_receive_response(
-        self, response: httpx.Response, request: Request
-    ) -> typing.Any:
+    async def did_receive_response(
+        self,
+        response: httpx.Response,
+        request: Request,
+    ) -> Response:
         """Hook to alter the response from the server.
 
         example::
@@ -138,16 +191,16 @@ async def did_receive_response(
             return Widget(**response.json())
         """
         response.raise_for_status()
-        return response.json(object_hook=self.convert_to_object)
+        return response.json()
 
-    async def get(self, path: str) -> typing.Any:
+    async def get(self, path: str) -> Response:
         """Preform a GET request
 
         :param path: path of the request
         """
         return await self.fetch("GET", path)
 
-    async def post(self, path: str, body: typing.Any) -> typing.Any:
+    async def post(self, path: str, body: typing.Any) -> Response:
         """Preform a POST request
 
         :param path: path of the request
@@ -155,7 +208,7 @@ async def post(self, path: str, body: typing.Any) -> typing.Any:
         """
         return await self.fetch("POST", path, body)
 
-    async def patch(self, path: str, body: typing.Any) -> typing.Any:
+    async def patch(self, path: str, body: typing.Any) -> Response:
         """Preform a PATCH request
 
         :param path: path of the request
@@ -163,7 +216,7 @@ async def patch(self, path: str, body: typing.Any) -> typing.Any:
         """
         return await self.fetch("PATCH", path, body)
 
-    async def put(self, path: str, body: typing.Any) -> typing.Any:
+    async def put(self, path: str, body: typing.Any) -> Response:
         """Preform a PUT request
 
         :param path: path of the request
@@ -171,16 +224,14 @@ async def put(self, path: str, body: typing.Any) -> typing.Any:
         """
         return await self.fetch("PUT", path, body)
 
-    async def delete(self, path: str) -> typing.Any:
+    async def delete(self, path: str) -> Response:
         """Preform a DELETE request
 
         :param path: path of the request
         """
         return await self.fetch("DELETE", path)
 
-    async def fetch(
-        self, method: str, path: str, body: typing.Any = None
-    ) -> typing.Any:
+    async def fetch(self, method: str, path: str, body: typing.Any = None) -> Response:
         url = self.get_request_url(path)
 
         request = Request(url, method, body)
@@ -190,7 +241,7 @@ async def fetch(
         cache_key = self.cache_key_for_request(request)
 
         @cacheable
-        async def process_request():
+        async def process_request() -> Response:
             try:
                 response = await self.client.request(
                     request.method,
@@ -217,3 +268,19 @@ async def process_request():
         else:
             self.memoized_requests.pop(cache_key, None)
             return await process_request()
+
+    def model_from_response(self, response: AnyDict, **kwargs) -> GraphModel:
+        model_kwargs = response.copy()
+        model_kwargs.update(kwargs)
+        cleaned_kwargs = {
+            key: value
+            for key, value in model_kwargs.items()
+            if key in self._expected_fields
+        }
+        obj = self._graph_model(**cleaned_kwargs)
+        return obj
+
+    def model_list_from_response(
+        self, response: Response, **kwargs
+    ) -> typing.List[GraphModel]:
+        return list(map(self.model_from_response, response))
diff --git a/cannula/datasource/orm.py b/cannula/datasource/orm.py
new file mode 100644
index 0000000..fc6f9a7
--- /dev/null
+++ b/cannula/datasource/orm.py
@@ -0,0 +1,157 @@
+import logging
+
+import uuid
+from typing import (
+    Any,
+    Awaitable,
+    Dict,
+    Generic,
+    TypeVar,
+)
+
+from sqlalchemy import BinaryExpression, ColumnExpressionArgument, select
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
+from sqlalchemy.orm import DeclarativeBase
+
+
+from cannula.datasource import GraphModel, cacheable, expected_fields
+
+
+DBModel = TypeVar("DBModel", bound=DeclarativeBase)
+
+LOG = logging.getLogger(__name__)
+
+
+class DatabaseRepository(Generic[DBModel, GraphModel]):
+    """Repository for performing database queries."""
+
+    _memoized_get: Dict[str, Awaitable[DBModel | None]]
+    _memoized_list: Dict[str, Awaitable[list[DBModel]]]
+    _db_model: type[DBModel]
+    _graph_model: type[GraphModel]
+    _expected_fields: set[str]
+
+    def __init_subclass__(
+        cls, *, db_model: type[DBModel], graph_model: type[GraphModel]
+    ) -> None:
+        cls._db_model = db_model
+        cls._graph_model = graph_model
+        cls._expected_fields = expected_fields(graph_model)
+        return super().__init_subclass__()
+
+    def __init__(
+        self,
+        session_maker: async_sessionmaker[AsyncSession],
+        readonly_session_maker: async_sessionmaker[AsyncSession] | None = None,
+    ) -> None:
+        self.session_maker = session_maker
+        self.readonly_session_maker = readonly_session_maker or session_maker
+        self._memoized_get = {}
+        self._memoized_list = {}
+
+    def from_db(self, db_obj: DBModel, **kwargs) -> GraphModel:
+        model_kwargs = db_obj.__dict__.copy()
+        model_kwargs.update(kwargs)
+        cleaned_kwargs = {
+            key: value
+            for key, value in model_kwargs.items()
+            if key in self._expected_fields
+        }
+        obj = self._graph_model(**cleaned_kwargs)
+        return obj
+
+    async def add(self, **data: Any) -> GraphModel:
+        async with self.session_maker() as session:
+            instance = self._db_model(**data)
+            session.add(instance)
+            await session.commit()
+            await session.refresh(instance)
+            return self.from_db(instance)
+
+    async def get_by_pk(self, pk: uuid.UUID) -> DBModel | None:
+        cache_key = f"get:{pk}"
+
+        @cacheable
+        async def process_get():
+            async with self.readonly_session_maker() as session:
+                return await session.get(self._db_model, pk)
+
+        if results := self._memoized_get.get(cache_key):
+            LOG.error(f"Found cached query for {cache_key}")
+            return await results
+
+        self._memoized_get[cache_key] = process_get()
+        return await self._memoized_get[cache_key]
+
+    async def get_by_query(
+        self, *expressions: BinaryExpression | ColumnExpressionArgument
+    ) -> DBModel | None:
+        query = select(self._db_model).where(*expressions)
+
+        # Get the query as a string with bound values
+        cache_key = str(query.compile(compile_kwargs={"literal_binds": True}))
+
+        @cacheable
+        async def process_get():
+            async with self.readonly_session_maker() as session:
+                results = await session.scalars(query)
+                return results.one_or_none()
+
+        if results := self._memoized_get.get(cache_key):
+            LOG.error(f"Found cached query for {cache_key}")
+            return await results
+
+        self._memoized_get[cache_key] = process_get()
+        return await self._memoized_get[cache_key]
+
+    async def get_model(self, pk: uuid.UUID) -> GraphModel | None:
+        if db_obj := await self.get_by_pk(pk):
+            return self.from_db(db_obj)
+
+    async def get_model_by_query(
+        self, *expressions: BinaryExpression | ColumnExpressionArgument
+    ) -> GraphModel | None:
+        if db_obj := await self.get_by_query(*expressions):
+            return self.from_db(db_obj)
+
+    async def filter(
+        self,
+        *expressions: BinaryExpression | ColumnExpressionArgument,
+        limit: int = 100,
+        offset: int = 0,
+    ) -> list[DBModel]:
+        query = select(self._db_model).limit(limit).offset(offset)
+        if expressions:
+            query = query.where(*expressions)
+
+        # Get the query as a string with bound values
+        cache_key = str(query.compile(compile_kwargs={"literal_binds": True}))
+
+        @cacheable
+        async def process_filter():
+            async with self.readonly_session_maker() as session:
+                # If we don't convert this to a list only the first
+                # coroutine that awaits this will be able to read the data.
+                return list(await session.scalars(query))
+
+        if results := self._memoized_list.get(cache_key):
+            LOG.error(cache_key)
+            LOG.error(f"\nfound cached results for {self.__class__.__name__}\n")
+            return await results
+
+        LOG.error(f"Caching data for {self.__class__.__name__}")
+        self._memoized_list[cache_key] = process_filter()
+        return await self._memoized_list[cache_key]
+
+    async def get_models(
+        self,
+        *expressions: BinaryExpression | ColumnExpressionArgument,
+        limit: int = 100,
+        offset: int = 0,
+    ) -> list[GraphModel]:
+        return list(
+            map(
+                self.from_db,
+                await self.filter(*expressions, limit=limit, offset=offset),
+            )
+        )
diff --git a/docs/ref/datasources.rst b/docs/ref/datasources.rst
index edb35ca..84ab24b 100644
--- a/docs/ref/datasources.rst
+++ b/docs/ref/datasources.rst
@@ -24,3 +24,6 @@ that for longer you'll need to implement that yourself.
 
 .. automodule:: cannula.datasource.http
     :members:
+
+.. automodule:: cannula.datasource.orm
+    :members:
\ No newline at end of file
diff --git a/examples/http_datasource.py b/examples/http_datasource.py
index 3bbea29..d0908e0 100644
--- a/examples/http_datasource.py
+++ b/examples/http_datasource.py
@@ -1,5 +1,6 @@
 import logging
 import typing
+from dataclasses import dataclass
 
 import cannula
 import fastapi
@@ -19,10 +20,6 @@ async def get_widgets():
     return [{"name": "hammer", "type": "tool"}]
 
 
-# Create a httpx client that responds with the 'remote_app'
-client = httpx.AsyncClient(transport=httpx.ASGITransport(app=remote_app))
-
-
 SCHEMA = cannula.gql(
     """
     type Widget {
@@ -38,32 +35,37 @@ async def get_widgets():
 )
 
 
+@dataclass
+class Widget:
+    name: str
+    type: str
+
+
 # Our actual datasource object
-class WidgetDatasource(http.HTTPDataSource):
-    # set our base url to work with the demo fastapi app
-    base_url = "http://localhost"
+class WidgetDatasource(
+    http.HTTPDataSource[Widget], graph_model=Widget, base_url="http://localhost"
+):
 
-    async def get_widgets(self):
-        return await self.get("/widgets")
+    async def get_widgets(self) -> list[Widget]:
+        response = await self.get("/widgets")
+        return self.model_list_from_response(response)
 
 
 # Create a custom context and add the datasource
 class CustomContext(Context):
-    widget_datasource: WidgetDatasource
 
-    def handle_request(self, request: typing.Any) -> typing.Any:
-        # Initialize the datasource using the request and
-        # set the client to use the demo client app
-        self.widget_datasource = WidgetDatasource(request, client=client)
-        return request
+    def __init__(self, client) -> None:
+        self.widget_datasource = WidgetDatasource(client=client)
 
 
-api = cannula.CannulaAPI(schema=SCHEMA, context=CustomContext)
+async def list_widgets(info: ResolveInfo[CustomContext]):
+    return await info.context.widget_datasource.get_widgets()
 
 
-@api.query("widgets")
-async def list_widgets(parent, info: ResolveInfo[CustomContext]):
-    return await info.context.widget_datasource.get_widgets()
+api = cannula.CannulaAPI(
+    schema=SCHEMA,
+    root_value={"widgets": list_widgets},
+)
 
 
 async def main():
@@ -83,7 +85,10 @@ async def main():
         """
     )
 
-    results = await api.call(query)
+    # Create a httpx client that responds with the 'remote_app' and add to context
+    client = httpx.AsyncClient(transport=httpx.ASGITransport(app=remote_app))
+
+    results = await api.call(query, context=CustomContext(client))
     print(results.data, results.errors)
     return results.data
 
diff --git a/tests/datasource/test_http.py b/tests/datasource/test_http.py
index 96c63c0..e89f206 100644
--- a/tests/datasource/test_http.py
+++ b/tests/datasource/test_http.py
@@ -1,9 +1,15 @@
+from dataclasses import dataclass
 import httpx
 import fastapi
 
 from cannula.datasource import http
 
 
+@dataclass
+class Widgety:
+    some: str
+
+
 class MockDB:
     # This is just a object that we can use in assertions
     # that it was called the specific number of times
@@ -21,15 +27,17 @@ async def widget():
 
 
 async def test_http_datasource(mocker):
-    class Widget(http.HTTPDataSource):
-        base_url = "http://localhost/"
+    class Widget(
+        http.HTTPDataSource[Widgety], graph_model=Widgety, base_url="http://localhost"
+    ):
 
-        async def get_widgets(self):
-            return await self.get("widgets")
+        async def get_widgets(self) -> list[Widgety]:
+            response = await self.get("widgets")
+            return list(map(self.model_from_response, response))
 
     get_widget_spy = mocker.spy(mockDB, "get_widgets")
     mocked_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=fake_app))
-    widget = Widget(request=mocker.Mock(), client=mocked_client)
+    widget = Widget(client=mocked_client)
 
     results_one = await widget.get_widgets()
     results_two = await widget.get_widgets()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 153ee4b..e1e6163 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from cannula import utils
 

From 6fb136e66dc9a2373ffcacad7b852572a8cfee4a Mon Sep 17 00:00:00 2001
From: Robert Myers
Date: Mon, 11 Nov 2024 08:40:50 -0600
Subject: [PATCH 2/4] Adding tests for orm

---
 cannula/datasource/orm.py    | 22 ++++-----
 pyproject.toml               |  1 +
 tests/datasource/test_orm.py | 95 ++++++++++++++++++++++++++++++
 3 files changed, 107 insertions(+), 11 deletions(-)
 create mode 100644 tests/datasource/test_orm.py

diff --git a/cannula/datasource/orm.py b/cannula/datasource/orm.py
index fc6f9a7..c5948be 100644
--- a/cannula/datasource/orm.py
+++ b/cannula/datasource/orm.py
@@ -1,12 +1,13 @@
 import logging
 
-import uuid
 from typing import (
     Any,
     Awaitable,
     Dict,
     Generic,
     TypeVar,
+    Tuple,
+    Union,
 )
 
 from sqlalchemy import BinaryExpression, ColumnExpressionArgument, select
@@ -18,6 +19,7 @@
 
 DBModel = TypeVar("DBModel", bound=DeclarativeBase)
 
+_PKIdentityArgument = Union[Any, Tuple[Any, ...]]
 
 LOG = logging.getLogger(__name__)
 
@@ -68,16 +70,16 @@ async def add(self, **data: Any) -> GraphModel:
             await session.refresh(instance)
             return self.from_db(instance)
 
-    async def get_by_pk(self, pk: uuid.UUID) -> DBModel | None:
+    async def get_by_pk(self, pk: _PKIdentityArgument) -> DBModel | None:
         cache_key = f"get:{pk}"
 
         @cacheable
-        async def process_get():
+        async def process_get() -> DBModel | None:
             async with self.readonly_session_maker() as session:
                 return await session.get(self._db_model, pk)
 
         if results := self._memoized_get.get(cache_key):
-            LOG.error(f"Found cached query for {cache_key}")
+            LOG.debug(f"Found cached query for {cache_key}")
             return await results
 
         self._memoized_get[cache_key] = process_get()
@@ -92,19 +94,19 @@ async def process_get():
         cache_key = str(query.compile(compile_kwargs={"literal_binds": True}))
 
         @cacheable
-        async def process_get():
+        async def process_get() -> DBModel | None:
             async with self.readonly_session_maker() as session:
                 results = await session.scalars(query)
                 return results.one_or_none()
 
         if results := self._memoized_get.get(cache_key):
-            LOG.error(f"Found cached query for {cache_key}")
+            LOG.debug(f"Found cached query for {self.__class__.__name__}")
             return await results
 
         self._memoized_get[cache_key] = process_get()
         return await self._memoized_get[cache_key]
 
-    async def get_model(self, pk: uuid.UUID) -> GraphModel | None:
+    async def get_model(self, pk: _PKIdentityArgument) -> GraphModel | None:
         if db_obj := await self.get_by_pk(pk):
             return self.from_db(db_obj)
 
@@ -128,18 +130,16 @@ async def filter(
         cache_key = str(query.compile(compile_kwargs={"literal_binds": True}))
 
         @cacheable
-        async def process_filter():
+        async def process_filter() -> list[DBModel]:
             async with self.readonly_session_maker() as session:
                 # If we don't convert this to a list only the first
                 # coroutine that awaits this will be able to read the data.
                 return list(await session.scalars(query))
 
         if results := self._memoized_list.get(cache_key):
-            LOG.error(cache_key)
-            LOG.error(f"\nfound cached results for {self.__class__.__name__}\n")
+            LOG.debug(f"Found cached results for {self.__class__.__name__}")
             return await results
 
-        LOG.error(f"Caching data for {self.__class__.__name__}")
         self._memoized_list[cache_key] = process_filter()
         return await self._memoized_list[cache_key]
 
diff --git a/pyproject.toml b/pyproject.toml
index 428e530..8605fc7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ test = [
     "Sphinx==8.0.2",
     "sphinx-autodoc-typehints",
     "sqlalchemy==2.0.36",
+    "aiosqlite==0.19.0",
    "pydata-sphinx-theme",
     "twine==5.1.1",
     "types-python-dateutil",
diff --git a/tests/datasource/test_orm.py b/tests/datasource/test_orm.py
new file mode 100644
index 0000000..48ac9a7
--- /dev/null
+++ b/tests/datasource/test_orm.py
@@ -0,0 +1,95 @@
+import dataclasses
+import pytest
+from sqlalchemy.ext.asyncio import AsyncAttrs, create_async_engine, async_sessionmaker
+from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped
+
+from cannula.datasource.orm import DatabaseRepository
+
+database_uri = "sqlite+aiosqlite:///:memory:"
+engine = create_async_engine(database_uri, echo=True)
+session = async_sessionmaker(engine, expire_on_commit=False)
+
+
+class Base(AsyncAttrs, DeclarativeBase):
+    pass
+
+
+class DBUser(Base):
+    __tablename__ = "users"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    email: Mapped[str] = mapped_column(unique=True)
+    name: Mapped[str]
+    password: Mapped[str]
+
+
+@dataclasses.dataclass
+class User:
+    id: int
+    email: str | None
+    name: str | None
+    password: str | None
+
+
+async def create_tables() -> None:
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
+
+
+async def drop_tables() -> None:
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.drop_all)
+
+
+class UserRepository(
+    DatabaseRepository[DBUser, User],
+    db_model=DBUser,
+    graph_model=User,
+):
+    pass
+
+
+@pytest.fixture(autouse=True)
+async def db_session():
+    await create_tables()
+    yield
+    await drop_tables()
+
+
+async def test_orm_defaults(mocker):
+    mock_logger = mocker.patch("cannula.datasource.orm.LOG")
+    users = UserRepository(session)
+    new_user = await users.add(id=1, name="test", email="u@c.com", password="secret")
+    assert new_user.id == 1
+
+    all_users = await users.get_models()
+    assert len(all_users) == 1
+    assert all_users[0].name == "test"
+
+    specific_user = await users.get_model(1)
+    assert specific_user is not None
+    assert specific_user.name == "test"
+
+    specific_user_again = await users.get_model(1)
+    assert specific_user_again is not None
+    mock_logger.debug.assert_called_with("Found cached query for get:1")
+    mock_logger.reset()
+
+    not_found = await users.get_model(2)
+    assert not_found is None
+
+    query_user = await users.get_model_by_query(DBUser.email == "u@c.com")
+    assert query_user is not None
+    assert query_user.password == "secret"
+
+    query_user_again = await users.get_model_by_query(DBUser.email == "u@c.com")
+    assert query_user_again is not None
+    mock_logger.debug.assert_called_with("Found cached query for UserRepository")
+    mock_logger.reset()
+
+    filter_users = await users.get_models(DBUser.name == "test")
+    assert len(filter_users) == 1
+
+    filter_users_again = await users.get_models(DBUser.name == "test")
+    assert len(filter_users_again) == 1
+    mock_logger.debug.assert_called_with("Found cached results for UserRepository")
+    mock_logger.reset()

From 38bccd701b5e7352683425d84cc94d5856e62857 Mon Sep 17 00:00:00 2001
From: Robert Myers
Date: Mon, 11 Nov 2024 08:42:48 -0600
Subject: [PATCH 3/4] fixing types

---
 cannula/datasource/orm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/cannula/datasource/orm.py b/cannula/datasource/orm.py
index c5948be..a39ab56 100644
--- a/cannula/datasource/orm.py
+++ b/cannula/datasource/orm.py
@@ -109,12 +109,14 @@ async def process_get() -> DBModel | None:
     async def get_model(self, pk: _PKIdentityArgument) -> GraphModel | None:
         if db_obj := await self.get_by_pk(pk):
             return self.from_db(db_obj)
+        return None
 
     async def get_model_by_query(
         self, *expressions: BinaryExpression | ColumnExpressionArgument
     ) -> GraphModel | None:
         if db_obj := await self.get_by_query(*expressions):
             return self.from_db(db_obj)
+        return None
 
     async def filter(
         self,

From b0500ef1de95a52658363be2be70a3684821c970 Mon Sep 17 00:00:00 2001
From: Robert Myers
Date: Mon, 11 Nov 2024 08:53:52 -0600
Subject: [PATCH 4/4] Adding coverage

---
 cannula/datasource/base.py   |  2 --
 cannula/datasource/http.py   |  2 +-
 tests/datasource/test_orm.py | 20 ++++++++++++++++++++
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/cannula/datasource/base.py b/cannula/datasource/base.py
index 00cec83..34e287b 100644
--- a/cannula/datasource/base.py
+++ b/cannula/datasource/base.py
@@ -46,8 +46,6 @@ def expected_fields(obj: typing.Any) -> set[str]:
     """
     if dataclasses.is_dataclass(obj):
         return {field.name for field in dataclasses.fields(obj)}
-    elif hasattr(obj, "model_fields"):
-        return {obj.model_fields.keys()}
 
     raise ValueError(
         "Invalid model for 'GraphModel' must be a dataclass or pydantic model"
     )
diff --git a/cannula/datasource/http.py b/cannula/datasource/http.py
index 8490169..3529541 100644
--- a/cannula/datasource/http.py
+++ b/cannula/datasource/http.py
@@ -137,7 +137,7 @@ def __init__(
         self._should_close_client = client is None
         self.memoized_requests = {}
 
-    def __del__(self):
+    def __del__(self):  # pragma: no cover
         if self._should_close_client:
             LOG.debug(f"Closing httpx session for {self.__class__.__name__}")
             asyncio.ensure_future(self.client.aclose())
diff --git a/tests/datasource/test_orm.py b/tests/datasource/test_orm.py
index 48ac9a7..7f9eb74 100644
--- a/tests/datasource/test_orm.py
+++ b/tests/datasource/test_orm.py
@@ -86,6 +86,9 @@ async def test_orm_defaults(mocker):
     mock_logger.debug.assert_called_with("Found cached query for UserRepository")
     mock_logger.reset()
 
+    query_not_found = await users.get_model_by_query(DBUser.email == "not-found")
+    assert query_not_found is None
+
     filter_users = await users.get_models(DBUser.name == "test")
     assert len(filter_users) == 1
 
@@ -93,3 +96,20 @@ async def test_orm_defaults(mocker):
     assert len(filter_users_again) == 1
     mock_logger.debug.assert_called_with("Found cached results for UserRepository")
     mock_logger.reset()
+
+
+async def test_invalid_graph_model():
+    class NotCorrect:
+        pass
+
+    with pytest.raises(
+        ValueError,
+        match="Invalid model for 'GraphModel' must be a dataclass or pydantic model",
+    ):
+
+        class InvalidRepository(
+            DatabaseRepository[DBUser, NotCorrect],
+            db_model=DBUser,
+            graph_model=NotCorrect,
+        ):
+            pass
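
Usage note (not part of the patches above): a minimal sketch of how the new
DatabaseRepository could be wired into a resolver, mirroring the HTTPDataSource
docstring example from patch 1. The Context class, resolver name, and engine
setup below are illustrative assumptions, reusing the DBUser and User models
defined in tests/datasource/test_orm.py::

    import cannula
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    from cannula.datasource.orm import DatabaseRepository

    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    session_maker = async_sessionmaker(engine, expire_on_commit=False)


    class UserRepository(
        DatabaseRepository[DBUser, User],
        db_model=DBUser,
        graph_model=User,
    ):
        pass


    class Context(cannula.Context):
        def __init__(self) -> None:
            self.users = UserRepository(session_maker)


    async def resolve_user(
        info: cannula.ResolveInfo[Context],
        id: int,
    ) -> User | None:
        # Repeated lookups for the same primary key within a single graph
        # resolution share one memoized query, so the database is only hit once.
        return await info.context.users.get_model(id)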