diff --git a/.gitignore b/.gitignore
index 659899c..f2d0bef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@
 *.egg-info
 build
 dist
-.mypy_cache
+.*_cache
 venv
 .coverage
 coverage.json
diff --git a/cannula/__init__.py b/cannula/__init__.py
index e9112dd..ea38834 100644
--- a/cannula/__init__.py
+++ b/cannula/__init__.py
@@ -1,26 +1,18 @@
 from .api import CannulaAPI
-from .codegen import render_file
 from .context import Context, ResolveInfo
 from .errors import format_errors
 from .schema import build_and_extend_schema, concat_documents, load_schema
-from .types import Argument, Directive, Field, FieldType, ObjectType
 from .utils import gql
 
 __all__ = [
     "CannulaAPI",
-    "Argument",
     "Context",
-    "Directive",
-    "Field",
-    "FieldType",
-    "ObjectType",
     "ResolveInfo",
     "format_errors",
     "build_and_extend_schema",
     "concat_documents",
     "gql",
     "load_schema",
-    "render_file",
 ]
 
 __VERSION__ = "0.18.0"
diff --git a/cannula/cli.py b/cannula/cli.py
index 197363e..abff11a 100644
--- a/cannula/cli.py
+++ b/cannula/cli.py
@@ -7,6 +7,7 @@
 import tomli
 
 import cannula
+from cannula.codegen import render_file
 from cannula.scalars import ScalarInterface
 
 # create the top-level parser for global options
@@ -58,6 +59,12 @@
     action="append",
     dest="scalars",
 )
+codegen_parser.add_argument(
+    "--use-pydantic",
+    "--use_pydantic",
+    action="store_true",
+    help="Use Pydantic models for generated classes.",
+)
 
 
 def load_config(config) -> dict:
@@ -85,16 +92,19 @@ def resolve_scalars(scalars: list[str]) -> list[ScalarInterface]:
     return _scalars
 
 
-def run_codegen(dry_run: bool, schema: str, dest: str, scalars: list[str] | None):
+def run_codegen(
+    dry_run: bool, schema: str, dest: str, scalars: list[str] | None, use_pydantic: bool
+):
     source = pathlib.Path(schema)
     documents = cannula.load_schema(source)
     destination = pathlib.Path(dest)
     _scalars = resolve_scalars(scalars or [])
-    cannula.render_file(
+    render_file(
         type_defs=documents,
         dest=destination,
         dry_run=dry_run,
         scalars=_scalars,
+        use_pydantic=use_pydantic,
     )
 
 
@@ -115,9 +125,11 @@ def main():
     schema = codegen_config.get("schema", options.schema)
     dest = codegen_config.get("dest", options.dest)
     scalars = codegen_config.get("scalars", options.scalars)
+    use_pydantic = codegen_config.get("use_pydantic", options.use_pydantic)
    run_codegen(
         dry_run=options.dry_run,
         schema=schema,
         dest=dest,
         scalars=scalars,
+        use_pydantic=use_pydantic,
     )
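
With `render_file` no longer re-exported from the package root, scripts need the same import the CLI now uses. A minimal sketch of the programmatic equivalent of `cannula codegen --use-pydantic`, with hypothetical schema and destination paths:

```python
import pathlib

import cannula
from cannula.codegen import render_file

# Hypothetical paths; the CLI normally takes these from its config/arguments.
documents = cannula.load_schema(pathlib.Path("schema/"))
render_file(
    type_defs=documents,
    dest=pathlib.Path("app/_generated.py"),
    scalars=[],
    use_pydantic=True,  # emit pydantic BaseModel classes instead of dataclasses
)
```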
"Boolean": "bool", "Float": "float", "ID": "str", "Int": "int", "String": "str", } +TYPES = contextvars.ContextVar("types", default=_TYPES) + VALUE_FUNCS = { "boolean_value": lambda value: value in ["true", "True"], "int_value": lambda value: int(value), "float_value": lambda value: float(value), } -IMPORTS: typing.DefaultDict[str, set[str]] = collections.defaultdict(set[str]) -IMPORTS.update( +_IMPORTS: typing.DefaultDict[str, set[str]] = collections.defaultdict(set[str]) +_IMPORTS.update( { "__future__": set(["annotations"]), "abc": set(["ABC", "abstractmethod"]), @@ -64,12 +67,26 @@ ELLIPSIS = ast.Expr(value=ast.Constant(value=Ellipsis)) -def add_custom_scalar_handlers(scalars: list[ScalarInterface]) -> None: +def add_custom_scalar_types( + scalars: list[ScalarInterface], +) -> dict[str, str]: + _types = _TYPES.copy() + for scalar in scalars: + _types[scalar.name] = scalar.input_module.klass + _types[f"{scalar.name}InputType"] = scalar.output_module.klass + + return _types + + +def add_custom_scalar_imports( + scalars: list[ScalarInterface], +) -> dict[str, set[str]]: + _imports = _IMPORTS.copy() for scalar in scalars: - TYPES[scalar.name] = scalar.input_module.klass - TYPES[f"{scalar.name}InputType"] = scalar.output_module.klass - IMPORTS[scalar.input_module.module].add(scalar.input_module.klass) - IMPORTS[scalar.output_module.module].add(scalar.output_module.klass) + _imports[scalar.input_module.module].add(scalar.input_module.klass) + _imports[scalar.output_module.module].add(scalar.output_module.klass) + + return _imports def ast_for_import_from(module: str, names: set[str]) -> ast.ImportFrom: @@ -118,7 +135,7 @@ def ast_for_argument(arg: Argument) -> ast.arg: """ Create an AST node for a function argument. """ - LOG.debug(f"AST for arg: {arg.__dict__}") + # LOG.debug(f"AST for arg: {arg.__dict__}") arg_type = arg.type if arg.required else f"Optional[{arg.type}]" return ast.arg(arg=arg.name, annotation=ast.Name(id=arg_type, ctx=ast.Load())) @@ -264,7 +281,7 @@ def parse_default(obj: typing.Dict[str, typing.Any]) -> typing.Any: if not default_value: return None - LOG.debug(f"Default Value: {default_value}") + # LOG.debug(f"Default Value: {default_value}") return parse_value(default_value) @@ -298,6 +315,7 @@ def parse_args(obj: dict) -> list[Argument]: def parse_type(type_obj: dict) -> FieldType: required = False value = None + _types = TYPES.get() if type_obj["kind"] == "non_null_type": required = True @@ -310,8 +328,8 @@ def parse_type(type_obj: dict) -> FieldType: if type_obj["kind"] == "named_type": name = type_obj["name"] value = name["value"] - if value in TYPES: - value = TYPES[value] + if value in _types: + value = _types[value] else: value = f"{value}Type" @@ -328,7 +346,7 @@ def parse_directives(field: typing.Dict[str, typing.Any]) -> list[Directive]: def parse_field(field: typing.Dict[str, typing.Any], parent: str) -> Field: - LOG.debug("Field: %s", pprint.pformat(field)) + # LOG.debug("Field: %s", pprint.pformat(field)) name = parse_name(field) field_type = parse_type(field["type"]) default = parse_default(field) @@ -355,9 +373,13 @@ def parse_node(node: Node): raw_fields = details.get("fields", []) description = parse_description(details) + # Check if this has Union types raw_types = details.get("types", []) types = [parse_type(t) for t in raw_types] directives: typing.Dict[str, list[Directive]] = {} + # Check for name in defined scalar types + _types = TYPES.get() + defined_scalar_type = name in _types fields: list[Field] = [] for field in raw_fields: @@ -373,11 +395,13 
diff --git a/cannula/codegen.py b/cannula/codegen/codegen.py
similarity index 84%
rename from cannula/codegen.py
rename to cannula/codegen/codegen.py
index 68163e7..a888a0e 100644
--- a/cannula/codegen.py
+++ b/cannula/codegen/codegen.py
@@ -5,6 +5,7 @@
 
 import ast
 import collections
+import contextvars
 import logging
 import pathlib
 import pprint
@@ -18,25 +19,27 @@
 from cannula.format import format_code
 from cannula.schema import concat_documents
-from cannula.types import Argument, Directive, Field, FieldType, ObjectType
+from cannula.codegen.types import Argument, Directive, Field, FieldType, ObjectType
 
 LOG = logging.getLogger(__name__)
 
-TYPES = {
+_TYPES = {
     "Boolean": "bool",
     "Float": "float",
     "ID": "str",
     "Int": "int",
     "String": "str",
 }
 
+TYPES = contextvars.ContextVar("types", default=_TYPES)
+
 VALUE_FUNCS = {
     "boolean_value": lambda value: value in ["true", "True"],
     "int_value": lambda value: int(value),
     "float_value": lambda value: float(value),
 }
 
-IMPORTS: typing.DefaultDict[str, set[str]] = collections.defaultdict(set[str])
-IMPORTS.update(
+_IMPORTS: typing.DefaultDict[str, set[str]] = collections.defaultdict(set[str])
+_IMPORTS.update(
     {
         "__future__": set(["annotations"]),
         "abc": set(["ABC", "abstractmethod"]),
@@ -64,12 +67,26 @@
 ELLIPSIS = ast.Expr(value=ast.Constant(value=Ellipsis))
 
 
-def add_custom_scalar_handlers(scalars: list[ScalarInterface]) -> None:
+def add_custom_scalar_types(
+    scalars: list[ScalarInterface],
+) -> dict[str, str]:
+    _types = _TYPES.copy()
+    for scalar in scalars:
+        _types[scalar.name] = scalar.input_module.klass
+        _types[f"{scalar.name}InputType"] = scalar.output_module.klass
+
+    return _types
+
+
+def add_custom_scalar_imports(
+    scalars: list[ScalarInterface],
+) -> dict[str, set[str]]:
+    _imports = _IMPORTS.copy()
     for scalar in scalars:
-        TYPES[scalar.name] = scalar.input_module.klass
-        TYPES[f"{scalar.name}InputType"] = scalar.output_module.klass
-        IMPORTS[scalar.input_module.module].add(scalar.input_module.klass)
-        IMPORTS[scalar.output_module.module].add(scalar.output_module.klass)
+        _imports[scalar.input_module.module].add(scalar.input_module.klass)
+        _imports[scalar.output_module.module].add(scalar.output_module.klass)
+
+    return _imports
 
 
 def ast_for_import_from(module: str, names: set[str]) -> ast.ImportFrom:
@@ -118,7 +135,7 @@ def ast_for_argument(arg: Argument) -> ast.arg:
     """
     Create an AST node for a function argument.
     """
-    LOG.debug(f"AST for arg: {arg.__dict__}")
+    # LOG.debug(f"AST for arg: {arg.__dict__}")
     arg_type = arg.type if arg.required else f"Optional[{arg.type}]"
     return ast.arg(arg=arg.name, annotation=ast.Name(id=arg_type, ctx=ast.Load()))
@@ -264,7 +281,7 @@ def parse_default(obj: typing.Dict[str, typing.Any]) -> typing.Any:
     if not default_value:
         return None
 
-    LOG.debug(f"Default Value: {default_value}")
+    # LOG.debug(f"Default Value: {default_value}")
     return parse_value(default_value)
@@ -298,6 +315,7 @@ def parse_args(obj: dict) -> list[Argument]:
 def parse_type(type_obj: dict) -> FieldType:
     required = False
     value = None
+    _types = TYPES.get()
 
     if type_obj["kind"] == "non_null_type":
         required = True
@@ -310,8 +328,8 @@
     if type_obj["kind"] == "named_type":
         name = type_obj["name"]
         value = name["value"]
-        if value in TYPES:
-            value = TYPES[value]
+        if value in _types:
+            value = _types[value]
         else:
             value = f"{value}Type"
@@ -328,7 +346,7 @@ def parse_directives(field: typing.Dict[str, typing.Any]) -> list[Directive]:
 
 
 def parse_field(field: typing.Dict[str, typing.Any], parent: str) -> Field:
-    LOG.debug("Field: %s", pprint.pformat(field))
+    # LOG.debug("Field: %s", pprint.pformat(field))
     name = parse_name(field)
     field_type = parse_type(field["type"])
     default = parse_default(field)
@@ -355,9 +373,13 @@ def parse_node(node: Node):
     raw_fields = details.get("fields", [])
     description = parse_description(details)
 
+    # Check if this has Union types
     raw_types = details.get("types", [])
     types = [parse_type(t) for t in raw_types]
 
     directives: typing.Dict[str, list[Directive]] = {}
+    # Check for name in defined scalar types
+    _types = TYPES.get()
+    defined_scalar_type = name in _types
 
     fields: list[Field] = []
     for field in raw_fields:
@@ -373,11 +395,13 @@ def parse_node(node: Node):
         directives=directives,
         description=description,
         types=types,
+        defined_scalar_type=defined_scalar_type,
     )
 
 
 def parse_schema(
-    type_defs: typing.Iterable[typing.Union[str, DocumentNode]]
+    type_defs: typing.Iterable[typing.Union[str, DocumentNode]],
+    _types: dict[str, str] = _TYPES,
 ) -> typing.Dict[str, ObjectType]:
     document = concat_documents(type_defs)
     types: typing.Dict[str, ObjectType] = {}
@@ -385,19 +409,30 @@
     # First we need to pull out the input types since the default
     # names are different than normal object types.
     for definition in document.definitions:
-        node = parse_node(definition)
-        if node.kind == "input_object_type_definition":
-            TYPES[node.name] = f"{node.name}Input"
+        if definition.kind == "input_object_type_definition":
+            details = definition.to_dict()
+            name = parse_name(details)
+            _types[name] = f"{name}Input"
+
+    # Set the contextvar for types so the parse functions can access it.
+    token = TYPES.set(_types)
 
     for definition in document.definitions:
         node = parse_node(definition)
         if node.name in types:
             types[node.name].fields.extend(node.fields)
             types[node.name].directives.update(node.directives)
-            types[node.name].description = node.description
+            # Descriptions are only allowed on the type and not extensions,
+            # however we might get the extension first. So only update
+            # if we have a description. Skipping coverage on this as it
+            # is really hard to test this scenario.
+            if node.description:  # pragma: no cover
+                types[node.name].description = node.description
         else:
             types[node.name] = node
 
+    LOG.debug("TYPES: %s", pprint.pformat(types))
+    TYPES.reset(token)
     return types
@@ -426,7 +461,9 @@ def ast_for_operation_field(field: Field) -> ast.AnnAssign:
     return ast_for_annotation_assignment(field.name, annotation=field_type)
 
 
-def render_object(obj: ObjectType) -> list[ast.ClassDef | ast.Assign]:
+def render_object(
+    obj: ObjectType, use_pydantic: bool = False
+) -> list[ast.ClassDef | ast.Assign]:
     non_computed: list[Field] = []
     computed: list[Field] = []
     for field in obj.fields:
@@ -447,6 +484,18 @@
     else:
         constants = [type_def]
 
+    if use_pydantic:
+        return [
+            ast.ClassDef(
+                name=type_name,
+                body=[*constants, *klass_fields, *computed_fields],
+                bases=[ast_for_name("BaseModel")],
+                keywords=[],
+                decorator_list=[],
+                # type_params=[],
+            ),
+        ]
+
     return [
         ast.ClassDef(
             name=type_name,
@@ -462,15 +511,6 @@
             ],
             # type_params=[],
         ),
-        # TODO(rmyers): add option for pydantic
-        # ast.ClassDef(
-        #     name=type_name,
-        #     body=[*constants, *klass_fields, *computed_fields],
-        #     bases=[ast_for_name("BaseModel")],
-        #     keywords=[],
-        #     decorator_list=[],
-        #     type_params=[],
-        # ),
     ]
@@ -543,11 +583,13 @@ def ast_for_root_type(fields: list[Field]) -> ast.ClassDef:
 def render_code(
     type_defs: typing.Iterable[typing.Union[str, DocumentNode]],
     scalars: list[ScalarInterface] = [],
+    use_pydantic: bool = False,
 ) -> str:
     # first setup custom scalars so the parsed schema includes them
-    add_custom_scalar_handlers(scalars)
+    _types = add_custom_scalar_types(scalars)
+    _imports = add_custom_scalar_imports(scalars)
 
-    parsed = parse_schema(type_defs)
+    parsed = parse_schema(type_defs, _types)
 
     object_types: list[ObjectType] = []
     interface_types: list[ObjectType] = []
@@ -575,15 +617,15 @@
 
     root = ast.Module(body=[], type_ignores=[])
 
-    module_imports = list(IMPORTS.keys())
+    module_imports = list(_imports.keys())
     module_imports.sort()
     for module in module_imports:
         if module == "builtins":
             continue
 
-        root.body.append(ast_for_import_from(module=module, names=IMPORTS[module]))
+        root.body.append(ast_for_import_from(module=module, names=_imports[module]))
 
     for obj in scalar_types:
-        if obj.name in TYPES:
+        if obj.defined_scalar_type:
             continue
         root.body.append(ast_for_assign(f"{obj.name}Type", ast_for_name("Any")))
@@ -595,7 +637,7 @@
 
     object_types.sort(key=lambda o: o.name)
     for obj in object_types:
-        root.body.extend(render_object(obj))
+        root.body.extend(render_object(obj, use_pydantic))
 
     for obj in union_types:
         root.body.extend(render_union(obj))
@@ -613,9 +655,12 @@ def render_file(
     type_defs: typing.Iterable[typing.Union[str, DocumentNode]],
     dest: pathlib.Path,
     scalars: list[ScalarInterface] = [],
+    use_pydantic: bool = False,
     dry_run: bool = False,
 ) -> None:
-    formatted_code = render_code(type_defs=type_defs, scalars=scalars)
+    formatted_code = render_code(
+        type_defs=type_defs, scalars=scalars, use_pydantic=use_pydantic
+    )
 
     if dry_run:
         LOG.info(f"DRY_RUN would produce: \n{formatted_code}")
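
Replacing the mutated module-level `TYPES` dict with a `ContextVar` keeps per-run scalar mappings isolated: `parse_schema` publishes its merged dict via `TYPES.set(...)`, the parse helpers read it with `TYPES.get()`, and `TYPES.reset(token)` restores the default afterwards. A standalone sketch of that set/get/reset discipline (with an added `try/finally` that the patch itself does not use):

```python
import contextvars

_DEFAULTS = {"Boolean": "bool", "ID": "str", "Int": "int", "String": "str"}
TYPES: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
    "types", default=_DEFAULTS
)


def parse_with_scalars(extra: dict[str, str]) -> dict[str, str]:
    # Overlay per-run scalars on a copy so the defaults are never mutated.
    token = TYPES.set({**_DEFAULTS, **extra})
    try:
        return TYPES.get()  # stand-in for parse_type/parse_node calling TYPES.get()
    finally:
        TYPES.reset(token)  # later runs see the pristine defaults again


assert "Datetime" not in TYPES.get()
assert "Datetime" in parse_with_scalars({"Datetime": "datetime.datetime"})
assert "Datetime" not in TYPES.get()
```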
diff --git a/cannula/types.py b/cannula/codegen/types.py
similarity index 97%
rename from cannula/types.py
rename to cannula/codegen/types.py
index 0200015..306b72b 100644
--- a/cannula/types.py
+++ b/cannula/codegen/types.py
@@ -59,6 +59,7 @@ class ObjectType:
     types: typing.List[FieldType]
     directives: typing.Dict[str, typing.List[Directive]]
     description: typing.Optional[str] = None
+    defined_scalar_type: bool = False
 
 
 @dataclasses.dataclass
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 6ca0203..ffe92e6 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -18,7 +18,7 @@ def test_help(mocker: MockerFixture):
 
 
 def test_invalid_command_does_not_hang(mocker: MockerFixture):
-    mocker.patch("cannula.render_file")
+    mocker.patch("cannula.cli.render_file")
     mocker.patch.object(sys, "argv", ["cli", "codegen", "--invalid"])
     main()
 
@@ -26,7 +26,7 @@
 def test_codegen(mocker: MockerFixture):
     mock_schema = mocker.Mock()
     mocker.patch("cannula.load_schema", return_value=mock_schema)
-    mock_render = mocker.patch("cannula.render_file")
+    mock_render = mocker.patch("cannula.cli.render_file")
     mocker.patch.object(sys, "argv", ["cli", "codegen"])
     main()
     mock_render.assert_called_with(
@@ -34,13 +34,14 @@
         dest=mocker.ANY,
         dry_run=False,
         scalars=[],
+        use_pydantic=False,
     )
 
 
 def test_codegen_dry_run(mocker: MockerFixture):
     mock_schema = mocker.Mock()
     mocker.patch("cannula.load_schema", return_value=mock_schema)
-    mock_render = mocker.patch("cannula.render_file")
+    mock_render = mocker.patch("cannula.cli.render_file")
     mocker.patch.object(
         sys,
         "argv",
@@ -56,6 +57,7 @@
         dest=mocker.ANY,
         scalars=[],
         dry_run=True,
+        use_pydantic=False,
     )
 
 
@@ -63,7 +65,7 @@ def test_codegen_scalars(mocker: MockerFixture):
     expected_scalars = resolve_scalars(["cannula.scalars.date.Datetime"])
     mock_schema = mocker.Mock()
     mocker.patch("cannula.load_schema", return_value=mock_schema)
-    mock_render = mocker.patch("cannula.render_file")
+    mock_render = mocker.patch("cannula.cli.render_file")
     mocker.patch.object(
         sys,
         "argv",
@@ -79,6 +81,7 @@
         dest=mocker.ANY,
         scalars=expected_scalars,
         dry_run=False,
+        use_pydantic=False,
     )
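
The patch targets in the tests change because `mocker.patch` must name the function where it is looked up, not where it is defined: `cli.py` now does `from cannula.codegen import render_file`, binding the name into the `cannula.cli` namespace at import time. For illustration (the second call is a counter-example, not taken from the tests):

```python
# Intercepts the function the CLI actually calls:
mocker.patch("cannula.cli.render_file")

# Would replace the attribute on the package but leave the binding
# already captured by cannula.cli untouched:
mocker.patch("cannula.codegen.render_file")
```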
--- a/tests/test_codegen.py
+++ b/tests/test_codegen.py
@@ -4,8 +4,14 @@
 
 import pytest
 
-from cannula.codegen import parse_schema, render_file, render_object
-from cannula.types import Argument, Directive, Field
+from cannula.codegen import (
+    Argument,
+    Directive,
+    Field,
+    parse_schema,
+    render_file,
+    render_object,
+)
 from cannula.scalars import ScalarInterface
 from cannula.scalars.date import Datetime
 
@@ -35,23 +41,23 @@
 }
 '''
 
-EXTENTIONS = """
-    extend type Sender {
-        email: String!
-    }
-    input EmailSearch {
-        "email to search"
-        email: String!
-        limit: Int = 100
-        other: String = "blah"
-        include: Boolean = false
-    }
-    extend type Query {
-        get_sender_by_email(input: EmailSearch): Sender
-    }
+EXTENTIONS = """\
+extend type Sender {
+    email: String!
+}
+input EmailSearch {
+    "email to search"
+    email: String!
+    limit: Int = 100
+    other: String = "blah"
+    include: Boolean = false
+}
+extend type Query {
+    get_sender_by_email(input: EmailSearch): Sender
+}
 """
 
-expected_output = """\
+expected_output = '''\
 from __future__ import annotations
 from abc import ABC
 from cannula import ResolveInfo
@@ -76,6 +82,14 @@ class MessageType(ABC):
 
 @dataclass(kw_only=True)
 class SenderType(ABC):
+    """
+    Some sender action:
+
+    ```
+    Sender(foo)
+    ```
+    """
+
     __typename = "Sender"
     name: Optional[str] = None
     email: str
@@ -106,7 +120,69 @@ class RootType(TypedDict, total=False):
     get_sender_by_email: Optional[get_sender_by_emailQuery]
     message: Optional[messageMutation]
     messages: Optional[messagesQuery]
-"""
+'''
+
+expected_pydantic = '''\
+from __future__ import annotations
+from cannula import ResolveInfo
+from pydantic import BaseModel
+from typing import Optional, Protocol, Sequence
+from typing_extensions import TypedDict
+
+
+class EmailSearchInput(TypedDict):
+    email: str
+    limit: int
+    other: str
+    include: bool
+
+
+class MessageType(BaseModel):
+    __typename = "Message"
+    text: Optional[str] = None
+    sender: Optional[SenderType] = None
+
+
+class SenderType(BaseModel):
+    """
+    Some sender action:
+
+    ```
+    Sender(foo)
+    ```
+    """
+
+    __typename = "Sender"
+    name: Optional[str] = None
+    email: str
+
+
+class get_sender_by_emailQuery(Protocol):
+
+    async def __call__(
+        self, info: ResolveInfo, *, input: Optional[EmailSearchInput] = None
+    ) -> Optional[SenderType]: ...
+
+
+class messageMutation(Protocol):
+
+    async def __call__(
+        self, info: ResolveInfo, text: str, sender: str
+    ) -> Optional[MessageType]: ...
+
+
+class messagesQuery(Protocol):
+
+    async def __call__(
+        self, info: ResolveInfo, limit: int
+    ) -> Optional[Sequence[MessageType]]: ...
+
+
+class RootType(TypedDict, total=False):
+    get_sender_by_email: Optional[get_sender_by_emailQuery]
+    message: Optional[messageMutation]
+    messages: Optional[messagesQuery]
+'''
 
 schema_interface = """\
 scalar Datetime
@@ -215,26 +291,56 @@ async def test_parse_schema_dict():
 
 
 @pytest.mark.parametrize(
-    "dry_run, schema, scalars, expected",
+    "dry_run, schema, scalars, use_pydantic, expected",
     [
-        pytest.param(True, [SCHEMA, EXTENTIONS], [], "", id="dry-run:True"),
         pytest.param(
-            False, [SCHEMA, EXTENTIONS], [], expected_output, id="dry-run:False"
+            True,
+            [SCHEMA, EXTENTIONS],
+            [],
+            False,
+            "",
+            id="dry-run:True",
+        ),
+        pytest.param(
+            False,
+            [SCHEMA, EXTENTIONS],
+            [],
+            False,
+            expected_output,
+            id="dry-run:False",
         ),
         pytest.param(
-            False, [schema_interface], [], expected_interface, id="interfaces"
+            False,
+            [schema_interface],
+            [],
+            False,
+            expected_interface,
+            id="interfaces",
         ),
         pytest.param(
             False,
             [schema_scalars],
             [Datetime],
+            False,
             expected_scalars,
             id="scalars",
         ),
+        pytest.param(
+            False,
+            [SCHEMA, EXTENTIONS],
+            [Datetime],
+            True,
+            expected_pydantic,
+            id="pydantic",
+        ),
     ],
 )
 async def test_render_file(
-    dry_run: bool, schema: list[str], expected: str, scalars: list[ScalarInterface]
+    dry_run: bool,
+    schema: list[str],
+    scalars: list[ScalarInterface],
+    use_pydantic: bool,
+    expected: str,
 ):
     with tempfile.NamedTemporaryFile() as generated_file:
         render_file(
@@ -242,6 +348,7 @@
             dest=pathlib.Path(generated_file.name),
             dry_run=dry_run,
             scalars=scalars,
+            use_pydantic=use_pydantic,
         )
         with open(generated_file.name, "r") as rendered:
             content = rendered.read()
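
For context on what the `expected_pydantic` fixture exercises: the generated `BaseModel` classes report structured `ValidationError`s for missing or mistyped fields at construction, which the default dataclass output does not. A self-contained sketch mirroring the generated `SenderType` above (trimmed to its two fields, `__typename` omitted):

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class SenderType(BaseModel):
    name: Optional[str] = None
    email: str  # String! in the schema, so required and non-null here


sender = SenderType(email="ada@example.com")

try:
    SenderType(name="no-email")  # missing the required email field
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('email',)
```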