Skip to content

Commit

Permalink
Adding support for root_value (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmyers authored Jan 20, 2024
1 parent 04e563c commit 43bd2c9
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 47 deletions.
13 changes: 12 additions & 1 deletion cannula/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

LOG = logging.getLogger(__name__)

RootType = typing.TypeVar("RootType", dict, typing.Mapping, covariant=True)


class ParseResults(typing.NamedTuple):
document_ast: DocumentNode
Expand Down Expand Up @@ -172,7 +174,7 @@ def decorator(function):
return decorator


class API:
class API(typing.Generic[RootType]):
"""
Your entry point into the fun filled world of graphql. Just dive right in::
Expand All @@ -191,22 +193,29 @@ def hello(who):
:param schema: GraphQL Schema for this resolver. This can either be a str or `pathlib.Path` object.
:param context: Context class to hold shared state, added to GraphQLResolveInfo object.
:param middleware: List of middleware to enable.
:param root_value: Mapping of operation names to resolver functions.
:param kwargs: Any extra kwargs passed directly to graphql.execute function.
"""

_schema: typing.Union[str, DocumentNode, pathlib.Path]
_resolvers: typing.List[Resolver]
_root_value: typing.Optional[RootType]
_kwargs: typing.Dict[str, typing.Any]

def __init__(
self,
schema: typing.Union[str, DocumentNode, pathlib.Path],
context: typing.Optional[typing.Any] = None,
middleware: typing.List[typing.Any] = [],
root_value: typing.Optional[RootType] = None,
**kwargs,
):
self._context = context or Context
self._resolvers = []
self._schema = schema
self.middleware = middleware
self._root_value = root_value
self._kwargs = kwargs

def query(self, field_name: typing.Optional[str] = None) -> typing.Any:
"""Query Resolver
Expand Down Expand Up @@ -404,6 +413,8 @@ async def call(
context_value=context,
variable_values=variables,
middleware=self.middleware,
root_value=self._root_value,
**self._kwargs,
)
if inspect.isawaitable(result):
return await result
Expand Down
58 changes: 40 additions & 18 deletions cannula/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def parse_field(field: typing.Dict[str, typing.Any], parent: str) -> Field:
default = parse_default(field)
directives = parse_directives(field)
args = parse_args(field)
func_name = f"{parent}__{name}"
func_name = f"{name}{parent}"

return Field(
name=name,
Expand Down Expand Up @@ -186,13 +186,23 @@ def parse_schema(
{field.name}: typing.Optional[{field.value}] = {field.default!r}
"""

object_dict_fields = """\
{field.name}: {field_type}
"""


object_template = """\
@dataclasses.dataclass
class {obj.name}Type:
class {obj.name}TypeBase(abc.ABC):
__typename = "{obj.name}"
{rendered_fields}"""
{rendered_base_fields}
class {obj.name}TypeDict(typing.TypedDict):
{rendered_dict_fields}
{obj.name}Type = typing.Union[{obj.name}TypeBase, {obj.name}TypeDict]
"""


function_args_template = """\
Expand All @@ -203,7 +213,6 @@ class {obj.name}Type:
class {field.func_name}(typing.Protocol):
def __call__(
self,
root: typing.Any,
info: cannula.ResolveInfo,
{rendered_args}) -> typing.Awaitable[{field.value}]:
...
Expand All @@ -215,15 +224,16 @@ def __call__(


operation_template = """\
class {obj.name}Type(typing.TypedDict):
class RootType(typing.TypedDict):
{rendered_fields}"""


base_template = """\
from __future__ import annotations
import typing
import abc
import dataclasses
import typing
from typing_extensions import NotRequired
Expand All @@ -239,11 +249,18 @@ def render_field(field: Field) -> str:
return optional_field_template.format(field=field)


def render_dict_field(field: Field) -> str:
field_type = field.value if field.required else f"NotRequired[{field.value}]"
return object_dict_fields.format(field=field, field_type=field_type)


def render_object(obj: ObjectType) -> str:
rendered_fields = "".join([render_field(f) for f in obj.fields])
rendered_base_fields = "".join([render_field(f) for f in obj.fields])
rendered_dict_fields = "".join([render_dict_field(f) for f in obj.fields])
return object_template.format(
obj=obj,
rendered_fields=rendered_fields,
rendered_base_fields=rendered_base_fields,
rendered_dict_fields=rendered_dict_fields,
)


Expand All @@ -270,10 +287,9 @@ def render_operation_field(field: Field) -> str:
return operation_field_template.format(field=field)


def render_operation(obj: ObjectType) -> str:
rendered_fields = "".join([render_operation_field(f) for f in obj.fields])
def render_operation(fields: typing.List[Field]) -> str:
rendered_fields = "".join([render_operation_field(f) for f in fields])
return operation_template.format(
obj=obj,
rendered_fields=rendered_fields,
)

Expand All @@ -285,18 +301,24 @@ def render_file(
) -> None:
parsed = parse_schema(type_defs)

objects: typing.List[str] = []
operations: typing.List[str] = []
functions: typing.List[str] = []
object_types: typing.List[ObjectType] = []
operation_fields: typing.List[Field] = []
for obj in parsed.values():
if obj.name in ["Query", "Mutation", "Subscription"]:
operations.append(render_operation(obj))
for field in obj.fields:
functions.append(render_function(field))
operation_fields.append(field)
else:
objects.append(render_object(obj))
object_types.append(obj)

object_types.sort(key=lambda o: o.name)
objects = [render_object(obj) for obj in object_types]

operation_fields.sort(key=lambda f: f.name)
operations = render_operation(operation_fields)
functions = [render_function(field) for field in operation_fields]

rendered_items = "\n\n".join(objects + functions + operations)
rendered_objects = "\n\n".join(objects + functions)
rendered_items = f"{rendered_objects}\n\n{operations}"
content = base_template.format(
rendered_items=rendered_items,
)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ classifiers = [
]
dependencies = [
"graphql-core==3.2.3",
"typing-extensions==4.9.0",
]

[project.optional-dependencies]
Expand Down
77 changes: 49 additions & 28 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,78 +52,99 @@
expected_output = """\
from __future__ import annotations
import typing
import abc
import dataclasses
import typing
from typing_extensions import NotRequired
import cannula
@dataclasses.dataclass
class SenderType:
__typename = "Sender"
class EmailSearchTypeBase(abc.ABC):
__typename = "EmailSearch"
name: typing.Optional[str] = None
email: str
limit: typing.Optional[int] = 100
other: typing.Optional[str] = 'blah'
include: typing.Optional[bool] = False
class EmailSearchTypeDict(typing.TypedDict):
email: str
limit: NotRequired[int]
other: NotRequired[str]
include: NotRequired[bool]
EmailSearchType = typing.Union[EmailSearchTypeBase, EmailSearchTypeDict]
@dataclasses.dataclass
class MessageType:
class MessageTypeBase(abc.ABC):
__typename = "Message"
text: typing.Optional[str] = None
sender: typing.Optional[SenderType] = None
class MessageTypeDict(typing.TypedDict):
text: NotRequired[str]
sender: NotRequired[SenderType]
MessageType = typing.Union[MessageTypeBase, MessageTypeDict]
@dataclasses.dataclass
class EmailSearchType:
__typename = "EmailSearch"
class SenderTypeBase(abc.ABC):
__typename = "Sender"
name: typing.Optional[str] = None
email: str
limit: typing.Optional[int] = 100
other: typing.Optional[str] = 'blah'
include: typing.Optional[bool] = False
class Query__messages(typing.Protocol):
def __call__(
self,
root: typing.Any,
info: cannula.ResolveInfo,
limit: int,
) -> typing.Awaitable[typing.List[MessageType]]:
...
class SenderTypeDict(typing.TypedDict):
name: NotRequired[str]
email: str
SenderType = typing.Union[SenderTypeBase, SenderTypeDict]
class Query__get_sender_by_email(typing.Protocol):
class get_sender_by_emailQuery(typing.Protocol):
def __call__(
self,
root: typing.Any,
info: cannula.ResolveInfo,
input: typing.Optional[EmailSearchType] = None,
) -> typing.Awaitable[SenderType]:
...
class Mutation__message(typing.Protocol):
class messageMutation(typing.Protocol):
def __call__(
self,
root: typing.Any,
info: cannula.ResolveInfo,
text: str,
sender: str,
) -> typing.Awaitable[MessageType]:
...
class QueryType(typing.TypedDict):
messages: NotRequired[Query__messages]
get_sender_by_email: NotRequired[Query__get_sender_by_email]
class messagesQuery(typing.Protocol):
def __call__(
self,
info: cannula.ResolveInfo,
limit: int,
) -> typing.Awaitable[typing.List[MessageType]]:
...
class MutationType(typing.TypedDict):
message: NotRequired[Mutation__message]
class RootType(typing.TypedDict):
get_sender_by_email: NotRequired[get_sender_by_emailQuery]
message: NotRequired[messageMutation]
messages: NotRequired[messagesQuery]
"""


Expand All @@ -147,7 +168,7 @@ async def test_parse_schema_dict():
),
],
args=[],
func_name="Test__name",
func_name="nameTest",
default=None,
required=False,
)
Expand Down

0 comments on commit 43bd2c9

Please sign in to comment.