Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding orm datasource and updating http to be more useful #83

Merged
merged 4 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cannula/datasource/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .base import GraphModel, cacheable, expected_fields

__all__ = [
"GraphModel",
"cacheable",
"expected_fields",
]
52 changes: 52 additions & 0 deletions cannula/datasource/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import asyncio
import dataclasses
import functools
import typing

GraphModel = typing.TypeVar("GraphModel")


def cacheable(f):
    """Decorator that allows a coroutine's result to be cached and re-awaited.

    Solves the issue of `cannot reuse already awaited coroutine` by
    scheduling the coroutine as a Future, which (unlike a bare coroutine)
    may be awaited any number of times.

    Example::

        _memoized: dict[str, Awaitable]

        async def get(self, pk: str):
            cache_key = f"get:{pk}"

            @cacheable
            async def process_get():
                return await session.get(pk)

            if results := _memoized.get(cache_key):
                return await results

            _memoized[cache_key] = process_get()
            return await _memoized[cache_key]

        # These calls will all share the same result and only hit the
        # datasource once.
        results = await asyncio.gather(get(1), get(1), get(1))

    """

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # ensure_future schedules the coroutine immediately and returns a
        # Future; awaiting the Future multiple times yields the same result.
        return asyncio.ensure_future(f(*args, **kwargs))

    return wrapped


def expected_fields(obj: typing.Any) -> set[str]:
    """Extract all the field names declared on the model.

    This is used when constructing a new instance from a datasource so
    that unexpected response keys can be filtered out.

    Supports dataclasses and, by duck-typing, pydantic v2 models — matching
    the contract stated in the error message — without making pydantic a
    hard dependency of this module.

    :param obj: model class (or instance) to inspect
    :returns: set of declared field names
    :raises ValueError: if ``obj`` is neither a dataclass nor a pydantic model
    """
    if dataclasses.is_dataclass(obj):
        return {field.name for field in dataclasses.fields(obj)}

    # Pydantic v2 models expose ``model_fields`` (a dict of name -> FieldInfo);
    # check via duck-typing so pydantic does not need to be installed.
    model_fields = getattr(obj, "model_fields", None)
    if isinstance(model_fields, dict):
        return set(model_fields)

    raise ValueError(
        "Invalid model for 'GraphModel' must be a dataclass or pydantic model"
    )
147 changes: 107 additions & 40 deletions cannula/datasource/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,16 @@

import asyncio
import logging
import types
import typing

import httpx

LOG = logging.getLogger("cannula.datasource.http")

from cannula.datasource import GraphModel, cacheable, expected_fields

# solves the issue of `cannot reuse already awaited coroutine`
def cacheable(f):
def wrapped(*args, **kwargs):
r = f(*args, **kwargs)
return asyncio.ensure_future(r)
LOG = logging.getLogger("cannula.datasource.http")

return wrapped
AnyDict = typing.Dict[typing.Any, typing.Any]
Response = typing.Union[typing.List[AnyDict], AnyDict]


class Request(typing.NamedTuple):
Expand All @@ -38,20 +33,79 @@ class Request(typing.NamedTuple):
headers: typing.Dict = {}


class HTTPDataSource:
class HTTPDataSource(typing.Generic[GraphModel]):
"""
HTTP Data Source

This is modeled after the apollo http datasource. It uses httpx to preform
async requests to any remote service you wish to query.
async requests to any remote service you wish to query. All GET and HEAD
requests will be memoized so that they are only performed once per
graph resolution.

Properties:

* `graph_model`: This is the object type your schema is expecting to respond with.
* `base_url`: Optional base_url to apply to all requests
* `timeout`: Default timeout in seconds for requests (5 seconds)
* `resource_name`: Optional name to use for `__typename` in responses.

Example::

@dataclass(kw_only=True)
class User(UserTypeBase):
id: UUID
name: str

class UserAPI(
HTTPDataSource[User],
graph_model=User,
base_url="https://auth.com",
):

async def get_user(self, id) -> User:
response = await self.get(f"/users/{id}")
return self.model_from_response(response)

You can then add this to your context to make it available to your resolvers. It is
best practice to setup a client for all your http datasources to share in order to
handle auth and use the built in connection pool. First add to your context object::

class Context(cannula.Context):

def __init__(self, client: httpx.AsyncClient) -> None:
self.userAPI = UserAPI(client=client)
self.groupAPI = GroupAPI(client=client)

Next in your graph handler function create a httpx client to use::

@api.post('/graph')
async def graph(
graph_call: Annotated[
GraphQLExec,
Depends(GraphQLDepends(cannula_app)),
],
request: Request,
) -> ExecutionResponse:
# Grab the authorization header and create the client
authorization = request.headers.get('authorization')
headers = {'authorization': authorization}

async with httpx.AsyncClient(headers=headers) as client:
context = Context(client)
return await graph_call(context=context)

Finally you can now use this datasource in your resolver functions like so::

async def resolve_person(
# Using this type hint for the ResolveInfo will make it so that
# we can inspect the `info` object in our editors and find the `user_api`
info: cannula.ResolveInfo[Context],
id: uuid.UUID,
) -> UserType | None:
return await info.context.user_api.get_user(id)
"""

_graph_model: type[GraphModel]
_expected_fields: set[str]
# The base url of this resource
base_url: typing.Optional[str] = None
# A mapping of requests using the cache_key_for_request. Multiple resolvers
Expand All @@ -62,31 +116,32 @@ class HTTPDataSource:
# Timeout for an individual request in seconds.
timeout: int = 5

# Resource name for the type that this datasource returns by default this
# will use the class name of the datasource.
resource_name: typing.Optional[str] = None
def __init_subclass__(
    cls,
    graph_model: type[GraphModel],
    base_url: typing.Optional[str] = None,
    timeout: int = 5,
) -> None:
    """Capture per-subclass configuration passed as class keyword arguments.

    Allows subclasses to be declared as::

        class UserAPI(HTTPDataSource[User], graph_model=User, base_url=...):
            ...

    :param graph_model: model type constructed from responses (required)
    :param base_url: optional base URL prepended to request paths
    :param timeout: per-request timeout in seconds (default 5)
    """
    cls._graph_model = graph_model
    # Pre-compute the model's constructor fields once per subclass so that
    # model_from_response can cheaply filter out unexpected response keys.
    cls._expected_fields = expected_fields(graph_model)
    cls.base_url = base_url
    cls.timeout = timeout
    return super().__init_subclass__()

def __init__(
self,
request: typing.Any,
client: typing.Optional[httpx.AsyncClient] = None,
):
self.client = client or httpx.AsyncClient()
# close the client if this instance opened it
self._should_close_client = client is None
self.request = request
self.memoized_requests = {}
self.assert_has_resource_name()

def __del__(self):
def __del__(self): # pragma: no cover
if self._should_close_client:
LOG.debug(f"Closing httpx session for {self.resource_name}")
LOG.debug(f"Closing httpx session for {self.__class__.__name__}")
asyncio.ensure_future(self.client.aclose())

def assert_has_resource_name(self) -> None:
if self.resource_name is None:
self.resource_name = self.__class__.__name__

def will_send_request(self, request: Request) -> Request:
"""Hook for subclasses to modify the request before it is sent.

Expand Down Expand Up @@ -120,13 +175,11 @@ def did_receive_error(self, error: Exception, request: Request):
"""Handle errors from the remote resource"""
raise error

def convert_to_object(self, json_obj):
json_obj.update({"__typename": self.resource_name})
return types.SimpleNamespace(**json_obj)

async def did_receive_response(
self, response: httpx.Response, request: Request
) -> typing.Any:
self,
response: httpx.Response,
request: Request,
) -> Response:
"""Hook to alter the response from the server.

example::
Expand All @@ -138,49 +191,47 @@ async def did_receive_response(
return Widget(**response.json())
"""
response.raise_for_status()
return response.json(object_hook=self.convert_to_object)
return response.json()

async def get(self, path: str) -> typing.Any:
async def get(self, path: str) -> Response:
"""Preform a GET request

:param path: path of the request
"""
return await self.fetch("GET", path)

async def post(self, path: str, body: typing.Any) -> typing.Any:
async def post(self, path: str, body: typing.Any) -> Response:
"""Preform a POST request

:param path: path of the request
:param body: body of the request
"""
return await self.fetch("POST", path, body)

async def patch(self, path: str, body: typing.Any) -> typing.Any:
async def patch(self, path: str, body: typing.Any) -> Response:
"""Preform a PATCH request

:param path: path of the request
:param body: body of the request
"""
return await self.fetch("PATCH", path, body)

async def put(self, path: str, body: typing.Any) -> typing.Any:
async def put(self, path: str, body: typing.Any) -> Response:
"""Preform a PUT request

:param path: path of the request
:param body: body of the request
"""
return await self.fetch("PUT", path, body)

async def delete(self, path: str) -> typing.Any:
async def delete(self, path: str) -> Response:
"""Preform a DELETE request

:param path: path of the request
"""
return await self.fetch("DELETE", path)

async def fetch(
self, method: str, path: str, body: typing.Any = None
) -> typing.Any:
async def fetch(self, method: str, path: str, body: typing.Any = None) -> Response:
url = self.get_request_url(path)

request = Request(url, method, body)
Expand All @@ -190,7 +241,7 @@ async def fetch(
cache_key = self.cache_key_for_request(request)

@cacheable
async def process_request():
async def process_request() -> Response:
try:
response = await self.client.request(
request.method,
Expand All @@ -217,3 +268,19 @@ async def process_request():
else:
self.memoized_requests.pop(cache_key, None)
return await process_request()

def model_from_response(self, response: AnyDict, **kwargs) -> GraphModel:
    """Build a graph model instance from a raw JSON response dict.

    Extra ``kwargs`` override keys from the response; any keys the model's
    constructor does not declare are silently dropped.

    :param response: JSON dict returned by the remote service
    :param kwargs: extra attributes to set on the constructed model
    """
    merged = {**response, **kwargs}
    accepted = {
        name: value for name, value in merged.items() if name in self._expected_fields
    }
    return self._graph_model(**accepted)

def model_list_from_response(
    self, response: Response, **kwargs
) -> typing.List[GraphModel]:
    """Build a list of graph models from a list-shaped response.

    :param response: list of JSON dicts from the remote service
    :param kwargs: extra attributes forwarded to every constructed model
    """
    # Bug fix: previously **kwargs was accepted but silently dropped
    # (map() did not forward it); pass it through to each item so it
    # behaves like model_from_response.
    return [self.model_from_response(item, **kwargs) for item in response]
Loading