Skip to content

Commit

Permalink
Adding initial support for custom scalars (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmyers authored Oct 2, 2024
1 parent cefb2a6 commit 82ea96c
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 139 deletions.
44 changes: 20 additions & 24 deletions cannula/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def render_function_args_ast(
return pos_args_ast, kwonly_args_ast, defaults


def render_computed_field_ast(field: Field) -> ast.FunctionDef:
def render_computed_field_ast(field: Field) -> ast.AsyncFunctionDef:
"""
Render a computed field as an AST node for a function definition.
"""
Expand All @@ -136,7 +136,6 @@ def render_computed_field_ast(field: Field) -> ast.FunctionDef:
ast.arg("info", annotation=ast_for_name("cannula.ResolveInfo")),
*pos_args,
]
value = field.value if field.required else f"Optional[{field.value}]"
args_node = ast.arguments(
args=[*args],
vararg=None,
Expand All @@ -146,18 +145,18 @@ def render_computed_field_ast(field: Field) -> ast.FunctionDef:
kwarg=None,
defaults=[],
)
func_node = ast.FunctionDef(
func_node = ast.AsyncFunctionDef(
name=field.name,
args=args_node,
body=[ast.Pass()], # Placeholder for the function body
decorator_list=[ast.Name(id="abc.abstractmethod", ctx=ast.Load())],
returns=ast.Name(id=f"Awaitable[{value}]", ctx=ast.Load()),
returns=ast.Name(id=field.type, ctx=ast.Load()),
lineno=None, # type: ignore
)
return func_node


def render_operation_field_ast(field: Field) -> ast.FunctionDef:
def render_operation_field_ast(field: Field) -> ast.AsyncFunctionDef:
"""
Render a computed field as an AST node for a function definition.
"""
Expand All @@ -177,12 +176,12 @@ def render_operation_field_ast(field: Field) -> ast.FunctionDef:
kwarg=None,
defaults=[],
)
func_node = ast.FunctionDef(
func_node = ast.AsyncFunctionDef(
name="__call__",
args=args_node,
body=[ELLIPSIS], # Placeholder for the function body
decorator_list=[],
returns=ast.Name(id=f"Awaitable[{field.value}]", ctx=ast.Load()),
returns=ast.Name(id=field.value, ctx=ast.Load()),
lineno=None, # type: ignore
)
return func_node
Expand Down Expand Up @@ -343,8 +342,7 @@ def parse_schema(


def ast_for_class_field(field: Field) -> ast.AnnAssign:
field_type_str = field.value if field.required else f"Optional[{field.value}]"
field_type = ast_for_name(field_type_str)
field_type = ast_for_name(field.type)

# Handle the defaults properly. When the field is required we don't want to
# set a default value of `None`. But when it is optional we need to properly
Expand All @@ -359,16 +357,12 @@ def ast_for_class_field(field: Field) -> ast.AnnAssign:


def ast_for_dict_field(field: Field) -> ast.AnnAssign:
field_type_str = field.value if field.required else f"NotRequired[{field.value}]"
field_type = ast_for_name(field_type_str)
field_type = ast_for_name(field.type)
return ast_for_annotation_assignment(field.name, annotation=field_type)


def ast_for_operation_field(field: Field) -> ast.AnnAssign:
field_type_str = (
field.func_name if field.required else f"NotRequired[{field.func_name}]"
)
field_type = ast_for_name(field_type_str)
field_type = ast_for_name(field.operation_type)
return ast_for_annotation_assignment(field.name, annotation=field_type)


Expand Down Expand Up @@ -407,7 +401,7 @@ def render_object(obj: ObjectType) -> typing.List[ast.ClassDef | ast.Assign]:
name=dict_name,
body=[*dict_fields],
bases=[ast_for_name("TypedDict")],
keywords=[],
keywords=[ast.keyword(arg="total", value=ast_for_constant(False))],
decorator_list=[],
),
ast_for_assign(
Expand All @@ -428,13 +422,7 @@ def render_interface(obj: ObjectType) -> typing.List[ast.ClassDef | ast.Assign]:
body=[type_def, *klass_fields],
bases=[ast_for_name("Protocol")],
keywords=[],
decorator_list=[
ast.Call(
func=ast_for_name("dataclass"),
args=[],
keywords=[ast_for_keyword("kw_only", True)],
)
],
decorator_list=[],
)
]

Expand Down Expand Up @@ -466,7 +454,7 @@ def ast_for_root_type(fields: typing.List[Field]) -> ast.ClassDef:
name="RootType",
body=[*dict_fields],
bases=[ast_for_name("TypedDict")],
keywords=[],
keywords=[ast.keyword(arg="total", value=ast_for_constant(False))],
decorator_list=[],
)

Expand All @@ -481,11 +469,14 @@ def render_file(
object_types: typing.List[ObjectType] = []
interface_types: typing.List[ObjectType] = []
union_types: typing.List[ObjectType] = []
scalar_types: typing.List[ObjectType] = []
operation_fields: typing.List[Field] = []
for obj in parsed.values():
if obj.name in ["Query", "Mutation", "Subscription"]:
for field in obj.fields:
operation_fields.append(field)
elif obj.kind == "scalar_type_definition":
scalar_types.append(obj)
elif obj.kind in [
"object_type_definition",
"input_object_type_definition",
Expand All @@ -508,6 +499,7 @@ def render_file(
ast_for_import_from(
"typing",
[
"Any",
"Awaitable",
"List",
"Optional",
Expand All @@ -522,6 +514,10 @@ def render_file(
ast_for_import_from("typing_extensions", ["NotRequired", "TypedDict"])
)

for obj in scalar_types:
# TODO(rmyers): add support for specific types
root.body.append(ast_for_assign(f"{obj.name}Type", ast_for_name("Any")))

for obj in interface_types:
root.body.extend(render_interface(obj))

Expand Down
8 changes: 8 additions & 0 deletions cannula/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ class Field:
default: typing.Any = None
required: bool = False

@property
def type(self) -> str:
return self.value if self.required else f"Optional[{self.value}]"

@property
def operation_type(self) -> str:
return self.func_name if self.required else f"Optional[{self.func_name}]"

@property
def is_computed(self) -> bool:
has_args = bool(self.args)
Expand Down
93 changes: 30 additions & 63 deletions examples/extension/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
import cannula
from abc import ABC
from dataclasses import dataclass
from typing import Awaitable, List, Optional, Protocol, Union
from typing_extensions import NotRequired, TypedDict
from typing import Any, List, Optional, Protocol, Union
from typing_extensions import TypedDict

DatetimeType = Any


class GenericType(Protocol):
__typename = "Generic"
name: Optional[str] = None


@dataclass(kw_only=True)
Expand All @@ -14,92 +21,52 @@ class BookTypeBase(ABC):
author: Optional[str] = None

@abc.abstractmethod
def movies(self, info: cannula.ResolveInfo) -> Awaitable[Optional[List[MovieType]]]:
async def movies(
self, info: cannula.ResolveInfo, *, limit: Optional[int] = 100
) -> Optional[List[MovieType]]:
pass


class BookTypeDict(TypedDict):
movies: NotRequired[List[MovieType]]
name: NotRequired[str]
author: NotRequired[str]
class BookTypeDict(TypedDict, total=False):
movies: Optional[List[MovieType]]
name: Optional[str]
author: Optional[str]


BookType = Union[BookTypeBase, BookTypeDict]


@dataclass(kw_only=True)
class GenericThingTypeBase(ABC):
__typename = "GenericThing"
name: Optional[str] = None


class GenericThingTypeDict(TypedDict):
name: NotRequired[str]


GenericThingType = Union[GenericThingTypeBase, GenericThingTypeDict]


@dataclass(kw_only=True)
class MovieTypeBase(ABC):
__typename = "Movie"
name: Optional[str] = None
director: Optional[str] = None
book: Optional[BookType] = None
views: Optional[int] = None
created: Optional[DatetimeType] = None


class MovieTypeDict(TypedDict):
name: NotRequired[str]
director: NotRequired[str]
book: NotRequired[BookType]
views: NotRequired[int]
class MovieTypeDict(TypedDict, total=False):
name: Optional[str]
director: Optional[str]
book: Optional[BookType]
views: Optional[int]
created: Optional[DatetimeType]


MovieType = Union[MovieTypeBase, MovieTypeDict]


@dataclass(kw_only=True)
class MovieInputTypeBase(ABC):
__typename = "MovieInput"
name: Optional[str] = None
director: Optional[str] = None
limit: Optional[int] = 100


class MovieInputTypeDict(TypedDict):
name: NotRequired[str]
director: NotRequired[str]
limit: NotRequired[int]


MovieInputType = Union[MovieInputTypeBase, MovieInputTypeDict]


class booksQuery(Protocol):
def __call__(self, info: cannula.ResolveInfo) -> Awaitable[List[BookType]]: ...


class createMovieMutation(Protocol):
def __call__(
self, info: cannula.ResolveInfo, *, input: Optional[MovieInputType] = None
) -> Awaitable[MovieType]: ...


class genericQuery(Protocol):
def __call__(
self, info: cannula.ResolveInfo
) -> Awaitable[List[GenericThingType]]: ...
async def __call__(self, info: cannula.ResolveInfo) -> List[BookType]: ...


class moviesQuery(Protocol):
def __call__(
class mediaQuery(Protocol):
async def __call__(
self, info: cannula.ResolveInfo, *, limit: Optional[int] = 100
) -> Awaitable[List[MovieType]]: ...
) -> List[GenericType]: ...


class RootType(TypedDict):
books: NotRequired[booksQuery]
createMovie: NotRequired[createMovieMutation]
generic: NotRequired[genericQuery]
movies: NotRequired[moviesQuery]
class RootType(TypedDict, total=False):
books: Optional[booksQuery]
media: Optional[mediaQuery]
40 changes: 36 additions & 4 deletions examples/extension/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

import cannula
import cannula.middleware
from ._generated import BookType, BookTypeBase, MovieType, RootType
from ._generated import (
BookType,
BookTypeBase,
GenericType,
MovieType,
MovieTypeBase,
RootType,
)

BASE_DIR = pathlib.Path(__file__).parent

Expand All @@ -14,16 +21,30 @@


class Book(BookTypeBase):
async def movies(self, info: cannula.ResolveInfo) -> list[MovieType]:
LOG.info(f"{self.name}")
async def movies(
self, info: cannula.ResolveInfo, *, limit: int | None = 100
) -> list[MovieType] | None:
return [{"name": "Lost the Movie", "director": "Ted"}]


class Movie(MovieTypeBase):
pass


async def get_books(info: cannula.ResolveInfo) -> list[BookType]:
return [Book(name="Lost", author="Frank")]


root_value: RootType = {"books": get_books}
async def get_media(
info: cannula.ResolveInfo, limit: int | None = 100
) -> list[GenericType]:
return [
Book(name="the Best Movies", author="Jane"),
Movie(name="the Best Books", director="Sally"),
]


root_value: RootType = {"books": get_books, "media": get_media}

api = cannula.API[RootType](
root_value=root_value,
Expand All @@ -44,6 +65,16 @@ async def get_books(info: cannula.ResolveInfo) -> list[BookType]:
director
}
}
media {
__typename
name
... on Book {
author
}
... on Movie {
director
}
}
}
"""
)
Expand All @@ -52,3 +83,4 @@ async def get_books(info: cannula.ResolveInfo) -> list[BookType]:
if __name__ == "__main__":
results = api.call_sync(QUERY, None)
pprint.pprint(results.data)
pprint.pprint(results.errors)
7 changes: 3 additions & 4 deletions examples/extension/schema/base.graphql
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
type GenericThing {
scalar Datetime

interface Generic {
name: String
}

type Query {
generic: [GenericThing]
}
2 changes: 1 addition & 1 deletion examples/extension/schema/books.graphql
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
type Book {
type Book implements Generic {
name: String
author: String
}
Expand Down
Loading

0 comments on commit 82ea96c

Please sign in to comment.