From 82ea96cf6e50e2c3f91c75eb9e8b653a9d0b1595 Mon Sep 17 00:00:00 2001 From: Robert Myers Date: Tue, 1 Oct 2024 23:20:02 -0500 Subject: [PATCH] Adding initial support for custom scalars (#61) --- cannula/codegen.py | 44 +++++------ cannula/types.py | 8 ++ examples/extension/_generated.py | 93 ++++++++---------------- examples/extension/main.py | 40 +++++++++- examples/extension/schema/base.graphql | 7 +- examples/extension/schema/books.graphql | 2 +- examples/extension/schema/movies.graphql | 17 +---- tests/test_codegen.py | 66 +++++++++-------- tests/test_examples.py | 14 +++- 9 files changed, 152 insertions(+), 139 deletions(-) diff --git a/cannula/codegen.py b/cannula/codegen.py index 071e3e1..abee986 100644 --- a/cannula/codegen.py +++ b/cannula/codegen.py @@ -126,7 +126,7 @@ def render_function_args_ast( return pos_args_ast, kwonly_args_ast, defaults -def render_computed_field_ast(field: Field) -> ast.FunctionDef: +def render_computed_field_ast(field: Field) -> ast.AsyncFunctionDef: """ Render a computed field as an AST node for a function definition. """ @@ -136,7 +136,6 @@ def render_computed_field_ast(field: Field) -> ast.FunctionDef: ast.arg("info", annotation=ast_for_name("cannula.ResolveInfo")), *pos_args, ] - value = field.value if field.required else f"Optional[{field.value}]" args_node = ast.arguments( args=[*args], vararg=None, @@ -146,18 +145,18 @@ def render_computed_field_ast(field: Field) -> ast.FunctionDef: kwarg=None, defaults=[], ) - func_node = ast.FunctionDef( + func_node = ast.AsyncFunctionDef( name=field.name, args=args_node, body=[ast.Pass()], # Placeholder for the function body decorator_list=[ast.Name(id="abc.abstractmethod", ctx=ast.Load())], - returns=ast.Name(id=f"Awaitable[{value}]", ctx=ast.Load()), + returns=ast.Name(id=field.type, ctx=ast.Load()), lineno=None, # type: ignore ) return func_node -def render_operation_field_ast(field: Field) -> ast.FunctionDef: +def render_operation_field_ast(field: Field) -> ast.AsyncFunctionDef: """ Render a computed field as an AST node for a function definition. """ @@ -177,12 +176,12 @@ def render_operation_field_ast(field: Field) -> ast.FunctionDef: kwarg=None, defaults=[], ) - func_node = ast.FunctionDef( + func_node = ast.AsyncFunctionDef( name="__call__", args=args_node, body=[ELLIPSIS], # Placeholder for the function body decorator_list=[], - returns=ast.Name(id=f"Awaitable[{field.value}]", ctx=ast.Load()), + returns=ast.Name(id=field.value, ctx=ast.Load()), lineno=None, # type: ignore ) return func_node @@ -343,8 +342,7 @@ def parse_schema( def ast_for_class_field(field: Field) -> ast.AnnAssign: - field_type_str = field.value if field.required else f"Optional[{field.value}]" - field_type = ast_for_name(field_type_str) + field_type = ast_for_name(field.type) # Handle the defaults properly. When the field is required we don't want to # set a default value of `None`. But when it is optional we need to properly @@ -359,16 +357,12 @@ def ast_for_class_field(field: Field) -> ast.AnnAssign: def ast_for_dict_field(field: Field) -> ast.AnnAssign: - field_type_str = field.value if field.required else f"NotRequired[{field.value}]" - field_type = ast_for_name(field_type_str) + field_type = ast_for_name(field.type) return ast_for_annotation_assignment(field.name, annotation=field_type) def ast_for_operation_field(field: Field) -> ast.AnnAssign: - field_type_str = ( - field.func_name if field.required else f"NotRequired[{field.func_name}]" - ) - field_type = ast_for_name(field_type_str) + field_type = ast_for_name(field.operation_type) return ast_for_annotation_assignment(field.name, annotation=field_type) @@ -407,7 +401,7 @@ def render_object(obj: ObjectType) -> typing.List[ast.ClassDef | ast.Assign]: name=dict_name, body=[*dict_fields], bases=[ast_for_name("TypedDict")], - keywords=[], + keywords=[ast.keyword(arg="total", value=ast_for_constant(False))], decorator_list=[], ), ast_for_assign( @@ -428,13 +422,7 @@ def render_interface(obj: ObjectType) -> typing.List[ast.ClassDef | ast.Assign]: body=[type_def, *klass_fields], bases=[ast_for_name("Protocol")], keywords=[], - decorator_list=[ - ast.Call( - func=ast_for_name("dataclass"), - args=[], - keywords=[ast_for_keyword("kw_only", True)], - ) - ], + decorator_list=[], ) ] @@ -466,7 +454,7 @@ def ast_for_root_type(fields: typing.List[Field]) -> ast.ClassDef: name="RootType", body=[*dict_fields], bases=[ast_for_name("TypedDict")], - keywords=[], + keywords=[ast.keyword(arg="total", value=ast_for_constant(False))], decorator_list=[], ) @@ -481,11 +469,14 @@ def render_file( object_types: typing.List[ObjectType] = [] interface_types: typing.List[ObjectType] = [] union_types: typing.List[ObjectType] = [] + scalar_types: typing.List[ObjectType] = [] operation_fields: typing.List[Field] = [] for obj in parsed.values(): if obj.name in ["Query", "Mutation", "Subscription"]: for field in obj.fields: operation_fields.append(field) + elif obj.kind == "scalar_type_definition": + scalar_types.append(obj) elif obj.kind in [ "object_type_definition", "input_object_type_definition", @@ -508,6 +499,7 @@ def render_file( ast_for_import_from( "typing", [ + "Any", "Awaitable", "List", "Optional", @@ -522,6 +514,10 @@ def render_file( ast_for_import_from("typing_extensions", ["NotRequired", "TypedDict"]) ) + for obj in scalar_types: + # TODO(rmyers): add support for specific types + root.body.append(ast_for_assign(f"{obj.name}Type", ast_for_name("Any"))) + for obj in interface_types: root.body.extend(render_interface(obj)) diff --git a/cannula/types.py b/cannula/types.py index a7a582c..0200015 100644 --- a/cannula/types.py +++ b/cannula/types.py @@ -28,6 +28,14 @@ class Field: default: typing.Any = None required: bool = False + @property + def type(self) -> str: + return self.value if self.required else f"Optional[{self.value}]" + + @property + def operation_type(self) -> str: + return self.func_name if self.required else f"Optional[{self.func_name}]" + @property def is_computed(self) -> bool: has_args = bool(self.args) diff --git a/examples/extension/_generated.py b/examples/extension/_generated.py index 2dea1a5..26304dd 100644 --- a/examples/extension/_generated.py +++ b/examples/extension/_generated.py @@ -3,8 +3,15 @@ import cannula from abc import ABC from dataclasses import dataclass -from typing import Awaitable, List, Optional, Protocol, Union -from typing_extensions import NotRequired, TypedDict +from typing import Any, List, Optional, Protocol, Union +from typing_extensions import TypedDict + +DatetimeType = Any + + +class GenericType(Protocol): + __typename = "Generic" + name: Optional[str] = None @dataclass(kw_only=True) @@ -14,32 +21,21 @@ class BookTypeBase(ABC): author: Optional[str] = None @abc.abstractmethod - def movies(self, info: cannula.ResolveInfo) -> Awaitable[Optional[List[MovieType]]]: + async def movies( + self, info: cannula.ResolveInfo, *, limit: Optional[int] = 100 + ) -> Optional[List[MovieType]]: pass -class BookTypeDict(TypedDict): - movies: NotRequired[List[MovieType]] - name: NotRequired[str] - author: NotRequired[str] +class BookTypeDict(TypedDict, total=False): + movies: Optional[List[MovieType]] + name: Optional[str] + author: Optional[str] BookType = Union[BookTypeBase, BookTypeDict] -@dataclass(kw_only=True) -class GenericThingTypeBase(ABC): - __typename = "GenericThing" - name: Optional[str] = None - - -class GenericThingTypeDict(TypedDict): - name: NotRequired[str] - - -GenericThingType = Union[GenericThingTypeBase, GenericThingTypeDict] - - @dataclass(kw_only=True) class MovieTypeBase(ABC): __typename = "Movie" @@ -47,59 +43,30 @@ class MovieTypeBase(ABC): director: Optional[str] = None book: Optional[BookType] = None views: Optional[int] = None + created: Optional[DatetimeType] = None -class MovieTypeDict(TypedDict): - name: NotRequired[str] - director: NotRequired[str] - book: NotRequired[BookType] - views: NotRequired[int] +class MovieTypeDict(TypedDict, total=False): + name: Optional[str] + director: Optional[str] + book: Optional[BookType] + views: Optional[int] + created: Optional[DatetimeType] MovieType = Union[MovieTypeBase, MovieTypeDict] -@dataclass(kw_only=True) -class MovieInputTypeBase(ABC): - __typename = "MovieInput" - name: Optional[str] = None - director: Optional[str] = None - limit: Optional[int] = 100 - - -class MovieInputTypeDict(TypedDict): - name: NotRequired[str] - director: NotRequired[str] - limit: NotRequired[int] - - -MovieInputType = Union[MovieInputTypeBase, MovieInputTypeDict] - - class booksQuery(Protocol): - def __call__(self, info: cannula.ResolveInfo) -> Awaitable[List[BookType]]: ... - - -class createMovieMutation(Protocol): - def __call__( - self, info: cannula.ResolveInfo, *, input: Optional[MovieInputType] = None - ) -> Awaitable[MovieType]: ... - - -class genericQuery(Protocol): - def __call__( - self, info: cannula.ResolveInfo - ) -> Awaitable[List[GenericThingType]]: ... + async def __call__(self, info: cannula.ResolveInfo) -> List[BookType]: ... -class moviesQuery(Protocol): - def __call__( +class mediaQuery(Protocol): + async def __call__( self, info: cannula.ResolveInfo, *, limit: Optional[int] = 100 - ) -> Awaitable[List[MovieType]]: ... + ) -> List[GenericType]: ... -class RootType(TypedDict): - books: NotRequired[booksQuery] - createMovie: NotRequired[createMovieMutation] - generic: NotRequired[genericQuery] - movies: NotRequired[moviesQuery] +class RootType(TypedDict, total=False): + books: Optional[booksQuery] + media: Optional[mediaQuery] diff --git a/examples/extension/main.py b/examples/extension/main.py index f45e40c..411efbd 100644 --- a/examples/extension/main.py +++ b/examples/extension/main.py @@ -4,7 +4,14 @@ import cannula import cannula.middleware -from ._generated import BookType, BookTypeBase, MovieType, RootType +from ._generated import ( + BookType, + BookTypeBase, + GenericType, + MovieType, + MovieTypeBase, + RootType, +) BASE_DIR = pathlib.Path(__file__).parent @@ -14,16 +21,30 @@ class Book(BookTypeBase): - async def movies(self, info: cannula.ResolveInfo) -> list[MovieType]: - LOG.info(f"{self.name}") + async def movies( + self, info: cannula.ResolveInfo, *, limit: int | None = 100 + ) -> list[MovieType] | None: return [{"name": "Lost the Movie", "director": "Ted"}] +class Movie(MovieTypeBase): + pass + + async def get_books(info: cannula.ResolveInfo) -> list[BookType]: return [Book(name="Lost", author="Frank")] -root_value: RootType = {"books": get_books} +async def get_media( + info: cannula.ResolveInfo, limit: int | None = 100 +) -> list[GenericType]: + return [ + Book(name="the Best Movies", author="Jane"), + Movie(name="the Best Books", director="Sally"), + ] + + +root_value: RootType = {"books": get_books, "media": get_media} api = cannula.API[RootType]( root_value=root_value, @@ -44,6 +65,16 @@ async def get_books(info: cannula.ResolveInfo) -> list[BookType]: director } } + media { + __typename + name + ... on Book { + author + } + ... on Movie { + director + } + } } """ ) @@ -52,3 +83,4 @@ async def get_books(info: cannula.ResolveInfo) -> list[BookType]: if __name__ == "__main__": results = api.call_sync(QUERY, None) pprint.pprint(results.data) + pprint.pprint(results.errors) diff --git a/examples/extension/schema/base.graphql b/examples/extension/schema/base.graphql index 518cacf..eaa3a6e 100644 --- a/examples/extension/schema/base.graphql +++ b/examples/extension/schema/base.graphql @@ -1,7 +1,6 @@ -type GenericThing { +scalar Datetime + +interface Generic { name: String } -type Query { - generic: [GenericThing] -} \ No newline at end of file diff --git a/examples/extension/schema/books.graphql b/examples/extension/schema/books.graphql index 62d79a8..afc8643 100644 --- a/examples/extension/schema/books.graphql +++ b/examples/extension/schema/books.graphql @@ -1,4 +1,4 @@ -type Book { +type Book implements Generic { name: String author: String } diff --git a/examples/extension/schema/movies.graphql b/examples/extension/schema/movies.graphql index d9d7252..00149a6 100644 --- a/examples/extension/schema/movies.graphql +++ b/examples/extension/schema/movies.graphql @@ -1,24 +1,15 @@ -type Movie { +type Movie implements Generic { name: String director: String book: Book views: Int + created: Datetime } extend type Book { - movies: [Movie] @computed -} - -extend type Query { movies(limit: Int = 100): [Movie] } -input MovieInput { - name: String - director: String - limit: Int = 100 -} - -extend type Mutation { - createMovie(input: MovieInput): Movie +extend type Query { + media(limit: Int = 100): [Generic] } \ No newline at end of file diff --git a/tests/test_codegen.py b/tests/test_codegen.py index ffb939b..4014728 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -55,8 +55,8 @@ import cannula from abc import ABC from dataclasses import dataclass -from typing import Awaitable, List, Optional, Protocol, Union -from typing_extensions import NotRequired, TypedDict +from typing import List, Optional, Protocol, Union +from typing_extensions import TypedDict @dataclass(kw_only=True) @@ -68,11 +68,11 @@ class EmailSearchTypeBase(ABC): include: Optional[bool] = False -class EmailSearchTypeDict(TypedDict): +class EmailSearchTypeDict(TypedDict, total=False): email: str - limit: NotRequired[int] - other: NotRequired[str] - include: NotRequired[bool] + limit: Optional[int] + other: Optional[str] + include: Optional[bool] EmailSearchType = Union[EmailSearchTypeBase, EmailSearchTypeDict] @@ -85,9 +85,9 @@ class MessageTypeBase(ABC): sender: Optional[SenderType] = None -class MessageTypeDict(TypedDict): - text: NotRequired[str] - sender: NotRequired[SenderType] +class MessageTypeDict(TypedDict, total=False): + text: Optional[str] + sender: Optional[SenderType] MessageType = Union[MessageTypeBase, MessageTypeDict] @@ -100,8 +100,8 @@ class SenderTypeBase(ABC): email: str -class SenderTypeDict(TypedDict): - name: NotRequired[str] +class SenderTypeDict(TypedDict, total=False): + name: Optional[str] email: str @@ -109,30 +109,31 @@ class SenderTypeDict(TypedDict): class get_sender_by_emailQuery(Protocol): - def __call__( + async def __call__( self, info: cannula.ResolveInfo, *, input: Optional[EmailSearchType] = None - ) -> Awaitable[SenderType]: ... + ) -> SenderType: ... class messageMutation(Protocol): - def __call__( + async def __call__( self, info: cannula.ResolveInfo, text: str, sender: str - ) -> Awaitable[MessageType]: ... + ) -> MessageType: ... class messagesQuery(Protocol): - def __call__( + async def __call__( self, info: cannula.ResolveInfo, limit: int - ) -> Awaitable[List[MessageType]]: ... + ) -> List[MessageType]: ... -class RootType(TypedDict): - get_sender_by_email: NotRequired[get_sender_by_emailQuery] - message: NotRequired[messageMutation] - messages: NotRequired[messagesQuery] +class RootType(TypedDict, total=False): + get_sender_by_email: Optional[get_sender_by_emailQuery] + message: Optional[messageMutation] + messages: Optional[messagesQuery] """ schema_interface = """\ +scalar Datetime interface Persona { id: ID! } @@ -143,6 +144,7 @@ class RootType(TypedDict): type Admin implements Persona { id: ID! + created: Datetime } union Person = User | Admin @@ -152,11 +154,12 @@ class RootType(TypedDict): from __future__ import annotations from abc import ABC from dataclasses import dataclass -from typing import Protocol, Union +from typing import Any, Optional, Protocol, Union from typing_extensions import TypedDict +DatetimeType = Any + -@dataclass(kw_only=True) class PersonaType(Protocol): __typename = "Persona" id: str @@ -166,10 +169,12 @@ class PersonaType(Protocol): class AdminTypeBase(ABC): __typename = "Admin" id: str + created: Optional[DatetimeType] = None -class AdminTypeDict(TypedDict): +class AdminTypeDict(TypedDict, total=False): id: str + created: Optional[DatetimeType] AdminType = Union[AdminTypeBase, AdminTypeDict] @@ -181,7 +186,7 @@ class UserTypeBase(ABC): id: str -class UserTypeDict(TypedDict): +class UserTypeDict(TypedDict, total=False): id: str @@ -251,7 +256,7 @@ async def test_render_file(dry_run: bool, schema: list[str], expected: str): targets=[ Name(id='__typename', ctx=Load())], value=Constant(value='Test')), - FunctionDef( + AsyncFunctionDef( name='name', args=arguments( posonlyargs=[], @@ -267,7 +272,7 @@ async def test_render_file(dry_run: bool, schema: list[str], expected: str): Pass()], decorator_list=[ Name(id='abc.abstractmethod', ctx=Load())], - returns=Name(id='Awaitable[Optional[str]]', ctx=Load()))], + returns=Name(id='Optional[str]', ctx=Load()))], decorator_list=[ Call( func=Name(id='dataclass', ctx=Load()), @@ -280,11 +285,14 @@ async def test_render_file(dry_run: bool, schema: list[str], expected: str): name='TestTypeDict', bases=[ Name(id='TypedDict', ctx=Load())], - keywords=[], + keywords=[ + keyword( + arg='total', + value=Constant(value=False))], body=[ AnnAssign( target=Name(id='name', ctx=Store()), - annotation=Name(id='NotRequired[str]', ctx=Load()), + annotation=Name(id='Optional[str]', ctx=Load()), simple=1)], decorator_list=[]), Assign( diff --git a/tests/test_examples.py b/tests/test_examples.py index 43ab1fe..60da258 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -16,7 +16,19 @@ def test_extension_works_properly_from_multiple_file(): "movies": [{"director": "Ted", "name": "Lost the Movie"}], "name": "Lost", } - ] + ], + "media": [ + { + "__typename": "Book", + "author": "Jane", + "name": "the Best Movies", + }, + { + "__typename": "Movie", + "director": "Sally", + "name": "the Best Books", + }, + ], }