Adding support for use_pydantic (#99)
rmyers authored Dec 8, 2024
1 parent 1b5e1cb commit b9c5a50
Showing 8 changed files with 248 additions and 74 deletions.
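In practical terms, the new use_pydantic flag switches the generated classes from the default ABC-based output to pydantic models: the render_object change below wraps the same constants and fields in a class whose base is BaseModel. A hedged sketch of what generated output might look like for a schema type User { name: String } — the class and field shapes here are illustrative, inferred from the render_object branch rather than captured from a real run, and assume pydantic is installed:

    # Hypothetical generated module with use_pydantic=True; names are illustrative.
    from __future__ import annotations

    from typing import Optional

    from pydantic import BaseModel


    class UserType(BaseModel):
        # A nullable String field maps to an Optional Python attribute.
        name: Optional[str] = None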
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@
*.egg-info
build
dist
.mypy_cache
.*_cache
venv
.coverage
coverage.json
8 changes: 0 additions & 8 deletions cannula/__init__.py
@@ -1,26 +1,18 @@
from .api import CannulaAPI
from .codegen import render_file
from .context import Context, ResolveInfo
from .errors import format_errors
from .schema import build_and_extend_schema, concat_documents, load_schema
from .types import Argument, Directive, Field, FieldType, ObjectType
from .utils import gql

__all__ = [
"CannulaAPI",
"Argument",
"Context",
"Directive",
"Field",
"FieldType",
"ObjectType",
"ResolveInfo",
"format_errors",
"build_and_extend_schema",
"concat_documents",
"gql",
"load_schema",
"render_file",
]

__VERSION__ = "0.18.0"
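Note the knock-on effect for imports: render_file and the AST helper types (Argument, Field, and friends) are no longer re-exported from the top-level package; they now live in the new cannula.codegen package introduced below. A one-line migration sketch, assuming the dropped __all__ entries above are exactly the names that moved:

    # Before this commit:
    # from cannula import render_file, Field
    # After:
    from cannula.codegen import render_file, Field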
16 changes: 14 additions & 2 deletions cannula/cli.py
@@ -7,6 +7,7 @@
import tomli

import cannula
from cannula.codegen import render_file
from cannula.scalars import ScalarInterface

# create the top-level parser for global options
@@ -58,6 +59,12 @@
action="append",
dest="scalars",
)
codegen_parser.add_argument(
"--use-pydantic",
"--use_pydantic",
action="store_true",
help="Use Pydantic models for generated classes.",
)


def load_config(config) -> dict:
@@ -85,16 +92,19 @@ def resolve_scalars(scalars: list[str]) -> list[ScalarInterface]:
return _scalars


def run_codegen(dry_run: bool, schema: str, dest: str, scalars: list[str] | None):
def run_codegen(
dry_run: bool, schema: str, dest: str, scalars: list[str] | None, use_pydantic: bool
):
source = pathlib.Path(schema)
documents = cannula.load_schema(source)
destination = pathlib.Path(dest)
_scalars = resolve_scalars(scalars or [])
cannula.render_file(
render_file(
type_defs=documents,
dest=destination,
dry_run=dry_run,
scalars=_scalars,
use_pydantic=use_pydantic,
)


@@ -115,9 +125,11 @@ def main():
schema = codegen_config.get("schema", options.schema)
dest = codegen_config.get("dest", options.dest)
scalars = codegen_config.get("scalars", options.scalars)
use_pydantic = codegen_config.get("use_pydantic", options.use_pydantic)
run_codegen(
dry_run=options.dry_run,
schema=schema,
dest=dest,
scalars=scalars,
use_pydantic=use_pydantic,
)
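For callers that drive codegen from Python rather than the CLI, the new keyword threads straight through render_file. A minimal sketch using the signature as changed in this commit; the schema and destination paths are placeholders:

    import pathlib

    import cannula
    from cannula.codegen import render_file

    # Placeholder paths; adjust to your project layout.
    source = pathlib.Path("schema.graphql")
    documents = cannula.load_schema(source)

    render_file(
        type_defs=documents,
        dest=pathlib.Path("generated/types.py"),
        dry_run=False,   # True logs the generated code instead of writing it
        scalars=[],
        use_pydantic=True,  # the new flag; CLI spellings: --use-pydantic or --use_pydantic
    )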
14 changes: 14 additions & 0 deletions cannula/codegen/__init__.py
@@ -0,0 +1,14 @@
from .codegen import parse_schema, render_code, render_file, render_object
from .types import Argument, Directive, Field, FieldType, ObjectType

__all__ = [
"Argument",
"Directive",
"Field",
"FieldType",
"ObjectType",
"parse_schema",
"render_code",
"render_file",
"render_object",
]
117 changes: 81 additions & 36 deletions cannula/codegen.py → cannula/codegen/codegen.py
@@ -5,6 +5,7 @@

import ast
import collections
import contextvars
import logging
import pathlib
import pprint
@@ -18,25 +19,27 @@

from cannula.format import format_code
from cannula.schema import concat_documents
from cannula.types import Argument, Directive, Field, FieldType, ObjectType
from cannula.codegen.types import Argument, Directive, Field, FieldType, ObjectType

LOG = logging.getLogger(__name__)

TYPES = {
_TYPES = {
"Boolean": "bool",
"Float": "float",
"ID": "str",
"Int": "int",
"String": "str",
}
TYPES = contextvars.ContextVar("types", default=_TYPES)

VALUE_FUNCS = {
"boolean_value": lambda value: value in ["true", "True"],
"int_value": lambda value: int(value),
"float_value": lambda value: float(value),
}

IMPORTS: typing.DefaultDict[str, set[str]] = collections.defaultdict(set[str])
IMPORTS.update(
_IMPORTS: typing.DefaultDict[str, set[str]] = collections.defaultdict(set[str])
_IMPORTS.update(
{
"__future__": set(["annotations"]),
"abc": set(["ABC", "abstractmethod"]),
@@ -64,12 +67,26 @@
ELLIPSIS = ast.Expr(value=ast.Constant(value=Ellipsis))


def add_custom_scalar_handlers(scalars: list[ScalarInterface]) -> None:
def add_custom_scalar_types(
scalars: list[ScalarInterface],
) -> dict[str, str]:
_types = _TYPES.copy()
for scalar in scalars:
_types[scalar.name] = scalar.input_module.klass
_types[f"{scalar.name}InputType"] = scalar.output_module.klass

return _types


def add_custom_scalar_imports(
scalars: list[ScalarInterface],
) -> dict[str, set[str]]:
_imports = _IMPORTS.copy()
for scalar in scalars:
TYPES[scalar.name] = scalar.input_module.klass
TYPES[f"{scalar.name}InputType"] = scalar.output_module.klass
IMPORTS[scalar.input_module.module].add(scalar.input_module.klass)
IMPORTS[scalar.output_module.module].add(scalar.output_module.klass)
_imports[scalar.input_module.module].add(scalar.input_module.klass)
_imports[scalar.output_module.module].add(scalar.output_module.klass)

return _imports


def ast_for_import_from(module: str, names: set[str]) -> ast.ImportFrom:
@@ -118,7 +135,7 @@ def ast_for_argument(arg: Argument) -> ast.arg:
"""
Create an AST node for a function argument.
"""
LOG.debug(f"AST for arg: {arg.__dict__}")
# LOG.debug(f"AST for arg: {arg.__dict__}")
arg_type = arg.type if arg.required else f"Optional[{arg.type}]"
return ast.arg(arg=arg.name, annotation=ast.Name(id=arg_type, ctx=ast.Load()))

@@ -264,7 +281,7 @@ def parse_default(obj: typing.Dict[str, typing.Any]) -> typing.Any:
if not default_value:
return None

LOG.debug(f"Default Value: {default_value}")
# LOG.debug(f"Default Value: {default_value}")
return parse_value(default_value)


@@ -298,6 +315,7 @@ def parse_args(obj: dict) -> list[Argument]:
def parse_type(type_obj: dict) -> FieldType:
required = False
value = None
_types = TYPES.get()

if type_obj["kind"] == "non_null_type":
required = True
@@ -310,8 +328,8 @@ def parse_type(type_obj: dict) -> FieldType:
if type_obj["kind"] == "named_type":
name = type_obj["name"]
value = name["value"]
if value in TYPES:
value = TYPES[value]
if value in _types:
value = _types[value]
else:
value = f"{value}Type"

@@ -328,7 +346,7 @@ def parse_directives(field: typing.Dict[str, typing.Any]) -> list[Directive]:


def parse_field(field: typing.Dict[str, typing.Any], parent: str) -> Field:
LOG.debug("Field: %s", pprint.pformat(field))
# LOG.debug("Field: %s", pprint.pformat(field))
name = parse_name(field)
field_type = parse_type(field["type"])
default = parse_default(field)
@@ -355,9 +373,13 @@ def parse_node(node: Node):
raw_fields = details.get("fields", [])

description = parse_description(details)
# Check if this has Union types
raw_types = details.get("types", [])
types = [parse_type(t) for t in raw_types]
directives: typing.Dict[str, list[Directive]] = {}
# Check for name in defined scalar types
_types = TYPES.get()
defined_scalar_type = name in _types

fields: list[Field] = []
for field in raw_fields:
@@ -373,31 +395,44 @@
directives=directives,
description=description,
types=types,
defined_scalar_type=defined_scalar_type,
)


def parse_schema(
type_defs: typing.Iterable[typing.Union[str, DocumentNode]]
type_defs: typing.Iterable[typing.Union[str, DocumentNode]],
_types: dict[str, str] = _TYPES,
) -> typing.Dict[str, ObjectType]:
document = concat_documents(type_defs)
types: typing.Dict[str, ObjectType] = {}

# First we need to pull out the input types since the default
# names are different than normal object types.
for definition in document.definitions:
node = parse_node(definition)
if node.kind == "input_object_type_definition":
TYPES[node.name] = f"{node.name}Input"
if definition.kind == "input_object_type_definition":
details = definition.to_dict()
name = parse_name(details)
_types[name] = f"{name}Input"

# Set the contextvar for types so the parse functions can access it.
token = TYPES.set(_types)

for definition in document.definitions:
node = parse_node(definition)
if node.name in types:
types[node.name].fields.extend(node.fields)
types[node.name].directives.update(node.directives)
types[node.name].description = node.description
# Descriptions are only allowed on the type and not extensions
# however we might get the extension first. So only update
# if we have a description. Skipping coverage on this as it
# is really hard to test this scenario.
if node.description: # pragma: no cover
types[node.name].description = node.description
else:
types[node.name] = node

LOG.debug("TYPES: %s", pprint.pformat(types))
TYPES.reset(token)
return types


@@ -426,7 +461,9 @@ def ast_for_operation_field(field: Field) -> ast.AnnAssign:
return ast_for_annotation_assignment(field.name, annotation=field_type)


def render_object(obj: ObjectType) -> list[ast.ClassDef | ast.Assign]:
def render_object(
obj: ObjectType, use_pydantic: bool = False
) -> list[ast.ClassDef | ast.Assign]:
non_computed: list[Field] = []
computed: list[Field] = []
for field in obj.fields:
@@ -447,6 +484,18 @@ def render_object(obj: ObjectType) -> list[ast.ClassDef | ast.Assign]:
else:
constants = [type_def]

if use_pydantic:
return [
ast.ClassDef(
name=type_name,
body=[*constants, *klass_fields, *computed_fields],
bases=[ast_for_name("BaseModel")],
keywords=[],
decorator_list=[],
# type_params=[],
),
]

return [
ast.ClassDef(
name=type_name,
@@ -462,15 +511,6 @@
],
# type_params=[],
),
# TODO(rmyers): add option for pydantic
# ast.ClassDef(
# name=type_name,
# body=[*constants, *klass_fields, *computed_fields],
# bases=[ast_for_name("BaseModel")],
# keywords=[],
# decorator_list=[],
# type_params=[],
# ),
]


@@ -543,11 +583,13 @@ def ast_for_root_type(fields: list[Field]) -> ast.ClassDef:
def render_code(
type_defs: typing.Iterable[typing.Union[str, DocumentNode]],
scalars: list[ScalarInterface] = [],
use_pydantic: bool = False,
) -> str:
# first setup custom scalars so the parsed schema includes them
add_custom_scalar_handlers(scalars)
_types = add_custom_scalar_types(scalars)
_imports = add_custom_scalar_imports(scalars)

parsed = parse_schema(type_defs)
parsed = parse_schema(type_defs, _types)

object_types: list[ObjectType] = []
interface_types: list[ObjectType] = []
@@ -575,15 +617,15 @@

root = ast.Module(body=[], type_ignores=[])

module_imports = list(IMPORTS.keys())
module_imports = list(_imports.keys())
module_imports.sort()
for module in module_imports:
if module == "builtins":
continue
root.body.append(ast_for_import_from(module=module, names=IMPORTS[module]))
root.body.append(ast_for_import_from(module=module, names=_imports[module]))

for obj in scalar_types:
if obj.name in TYPES:
if obj.defined_scalar_type:
continue
root.body.append(ast_for_assign(f"{obj.name}Type", ast_for_name("Any")))

Expand All @@ -595,7 +637,7 @@ def render_code(

object_types.sort(key=lambda o: o.name)
for obj in object_types:
root.body.extend(render_object(obj))
root.body.extend(render_object(obj, use_pydantic))

for obj in union_types:
root.body.extend(render_union(obj))
@@ -613,9 +655,12 @@ def render_file(
type_defs: typing.Iterable[typing.Union[str, DocumentNode]],
dest: pathlib.Path,
scalars: list[ScalarInterface] = [],
use_pydantic: bool = False,
dry_run: bool = False,
) -> None:
formatted_code = render_code(type_defs=type_defs, scalars=scalars)
formatted_code = render_code(
type_defs=type_defs, scalars=scalars, use_pydantic=use_pydantic
)

if dry_run:
LOG.info(f"DRY_RUN would produce: \n{formatted_code}")
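The TYPES rework above is the standard contextvars set/get/reset pattern: parse_schema installs a per-run mapping (the defaults plus input-type and scalar overrides), the parse_* helpers read it with TYPES.get(), and the saved token restores the previous value afterwards so repeated or concurrent runs don't see each other's scalars. A self-contained sketch of the same pattern, independent of cannula:

    import contextvars

    # Module-level defaults, mirroring _TYPES in codegen.py.
    _DEFAULTS = {"ID": "str", "Int": "int"}
    REGISTRY: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
        "registry", default=_DEFAULTS
    )

    def lookup(name: str) -> str:
        # Readers (like parse_type) just consult the current mapping.
        return REGISTRY.get().get(name, f"{name}Type")

    def run(extra: dict[str, str]) -> list[str]:
        # Writers (like parse_schema) install a copy for the duration of a run...
        token = REGISTRY.set({**_DEFAULTS, **extra})
        try:
            return [lookup("ID"), lookup("DateTime")]
        finally:
            # ...then reset, so later runs start from the shared defaults again.
            REGISTRY.reset(token)

    print(run({"DateTime": "datetime.datetime"}))  # ['str', 'datetime.datetime']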
1 change: 1 addition & 0 deletions cannula/types.py → cannula/codegen/types.py
@@ -59,6 +59,7 @@ class ObjectType:
types: typing.List[FieldType]
directives: typing.Dict[str, typing.List[Directive]]
description: typing.Optional[str] = None
defined_scalar_type: bool = False


@dataclasses.dataclass
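The scalar helpers follow the same no-shared-mutation idea as the contextvar change: add_custom_scalar_types and add_custom_scalar_imports copy the module defaults and return extended mappings instead of writing into the old TYPES/IMPORTS globals in place, so one render_code call's scalars cannot leak into the next. A small illustrative sketch of that copy-and-extend shape (the names here are mine, not cannula's), copying the nested sets as well so additions can't reach back into the defaults:

    # Shared defaults, analogous to _IMPORTS in codegen.py.
    _DEFAULT_IMPORTS: dict[str, set[str]] = {"typing": {"Any", "Optional"}}

    def with_scalar_import(module: str, klass: str) -> dict[str, set[str]]:
        # Copy per run, including the inner sets, then extend only the copy.
        imports = {mod: set(names) for mod, names in _DEFAULT_IMPORTS.items()}
        imports.setdefault(module, set()).add(klass)
        return imports

    result = with_scalar_import("datetime", "datetime")
    assert result["datetime"] == {"datetime"}
    assert "datetime" not in _DEFAULT_IMPORTS  # defaults stay untouched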