Skip to content

Commit

Permalink
Model generation and validation of relations (#108)
Browse files Browse the repository at this point in the history
* Adding support for relationships
* Add relationship validation and adding db_type to type_info
  • Loading branch information
rmyers authored Jan 3, 2025
1 parent b154941 commit bd857c3
Show file tree
Hide file tree
Showing 9 changed files with 363 additions and 53 deletions.
1 change: 1 addition & 0 deletions cannula/codegen/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# AST constants for common values
NONE = ast.Constant(value=None)
ELLIPSIS = ast.Expr(value=ast.Constant(value=Ellipsis))
PASS = ast.Pass()


def ast_for_import_from(module: str, names: set[str]) -> ast.ImportFrom:
Expand Down
239 changes: 204 additions & 35 deletions cannula/codegen/generate_sql.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Set, Union
import ast

from graphql import GraphQLObjectType, DocumentNode
from cannula.scalars import ScalarInterface
from cannula.codegen.base import ast_for_docstring, ast_for_keyword
from cannula.codegen.base import (
PASS,
ast_for_annotation_assignment,
ast_for_assign,
ast_for_docstring,
ast_for_keyword,
ast_for_name,
ast_for_subscript,
)
from cannula.codegen.schema_analyzer import SchemaAnalyzer, TypeInfo, CodeGenerator
from cannula.schema import build_and_extend_schema
from cannula.format import format_code
from cannula.codegen.generate_types import _IMPORTS
from cannula.types import Field


class SchemaValidationError(Exception):
Expand All @@ -19,6 +28,35 @@ class SchemaValidationError(Exception):
class SQLAlchemyGenerator(CodeGenerator):
"""Generates SQLAlchemy models from GraphQL schema."""

def validate_relationship_metadata(
    self, field: Field, type_info: TypeInfo[GraphQLObjectType]
) -> None:
    """Check that any 'relation' metadata on *field* is well formed.

    A missing or falsy 'relation' entry is fine (the field is a plain
    column). Otherwise the metadata must be a dictionary, and its
    optional 'cascade' entry, when present, must be a string.

    Raises:
        SchemaValidationError: when the relation metadata is malformed.
    """
    relation_metadata = field.metadata.get("relation")
    if not relation_metadata:
        # Not a relationship field — nothing to validate.
        return

    if not isinstance(relation_metadata, dict):
        raise SchemaValidationError(
            f"Relation metadata for {type_info.name}.{field.name} must be a dictionary"
        )

    # 'cascade' is optional, but when supplied it must be a string.
    if "cascade" in relation_metadata and not isinstance(
        relation_metadata["cascade"], str
    ):
        raise SchemaValidationError(
            f"Cascade option in relationship {type_info.name}.{field.name} must be a string"
        )

def get_db_table_types(self) -> Set[str]:
    """Collect the model class names of every type backed by a table.

    A type participates when its metadata contains a 'db_table' entry;
    the returned set holds each such type's db_type name.
    """
    tables: Set[str] = set()
    for type_info in self.analyzer.object_types:
        if "db_table" in type_info.metadata:
            tables.add(type_info.db_type)
    return tables

def validate_field_metadata(
self, field_name: str, is_required: bool, metadata: Dict[str, Any]
) -> None:
Expand All @@ -45,7 +83,7 @@ def get_primary_key_fields(
def create_column_args(
self, field_name: str, is_required: bool, metadata: Dict[str, Any]
) -> tuple[list[ast.expr], list[ast.keyword]]:
"""Create SQLAlchemy Column arguments based on field metadata."""
"""Create SQLAlchemy Column arguments for a regular column."""
args: list[ast.expr] = []
keywords: list[ast.keyword] = []

Expand All @@ -57,17 +95,31 @@ def create_column_args(
if is_primary_key:
keywords.append(ast_for_keyword("primary_key", True))

# Handle foreign key
if foreign_key := metadata.get("foreign_key"):
keywords.append(
# This does not use a constant so we cannot use ast_for_keyword
ast.keyword(
arg="foreign_key",
value=ast.Call(
func=ast_for_name("ForeignKey"),
args=[ast.Constant(value=foreign_key)],
keywords=[],
),
)
)

# Handle index
if not is_primary_key and metadata.get("index"):
keywords.append(ast_for_keyword("index", True))
keywords.append(ast_for_keyword(arg="index", value=True))

# Handle unique constraint
if not is_primary_key and metadata.get("unique"):
keywords.append(ast_for_keyword("unique", True))
keywords.append(ast_for_keyword(arg="unique", value=True))

# Handle custom column name
if db_column := metadata.get("db_column"):
keywords.append(ast_for_keyword("name", db_column))
keywords.append(ast_for_keyword(arg="name", value=db_column))

# Handle nullable based on GraphQL schema
if not is_primary_key:
Expand All @@ -76,10 +128,82 @@ def create_column_args(
nullable = (
not is_required if metadata_nullable is None else metadata_nullable
)
keywords.append(ast_for_keyword("nullable", nullable))
keywords.append(ast_for_keyword(arg="nullable", value=nullable))

return args, keywords

def get_db_type(self, type_name: str) -> str:
    """Map a GraphQL python type name to its SQLAlchemy model name.

    Scans the analyzer's object types for a matching py_type and
    returns its db_type; falls back to *type_name* itself when no
    object type matches.
    """
    for type_info in self.analyzer.object_types:
        if type_info.py_type == type_name:
            return type_info.db_type
    return type_name

def create_relationship_args(
    self, field: Field, metadata: Dict[str, Any]
) -> tuple[list[ast.expr], list[ast.keyword]]:
    """Build the AST arguments for a SQLAlchemy ``relationship()`` call.

    The first positional argument is the related model's class name
    (as a string constant); 'back_populates' and 'cascade' keywords
    are appended when present in the field's relation metadata.
    """
    relation_metadata = metadata.get("relation", {})

    # Resolve the target model class. For list fields of_type holds the
    # element type; otherwise fall back to safe_value.
    target = self.get_db_type(
        field.field_type.of_type or field.field_type.safe_value
    )

    positional: list[ast.expr] = [ast.Constant(value=target)]
    kwargs: list[ast.keyword] = []

    back_populates = relation_metadata.get("back_populates")
    if back_populates:
        kwargs.append(ast_for_keyword(arg="back_populates", value=back_populates))

    cascade = relation_metadata.get("cascade")
    if cascade:
        kwargs.append(ast_for_keyword(arg="cascade", value=cascade))

    return positional, kwargs

def create_field_definition(
    self, field: Field, type_info: TypeInfo[GraphQLObjectType]
) -> ast.AnnAssign:
    """Build the annotated assignment AST for one model attribute.

    Relationship fields become ``Mapped[DBType] = relationship(...)``;
    every other field becomes ``Mapped[Type] = mapped_column(...)``.

    Raises:
        SchemaValidationError: via validate_relationship_metadata when
            the field carries malformed relation metadata.
    """
    # Fail fast on malformed relation metadata before generating code.
    self.validate_relationship_metadata(field, type_info)

    if field.metadata.get("relation"):
        call_name = "relationship"
        args, keywords = self.create_relationship_args(field, field.metadata)
        # Annotate with the related DB model type.
        target = self.get_db_type(
            field.field_type.of_type or field.field_type.safe_value
        )
        annotation = ast_for_subscript(ast_for_name("Mapped"), target)
    else:
        call_name = "mapped_column"
        args, keywords = self.create_column_args(
            field.name, field.field_type.required, field.metadata
        )
        # Annotate with the plain Python type for regular columns.
        annotation = ast_for_subscript(ast_for_name("Mapped"), field.field_type.type)

    return ast_for_annotation_assignment(
        target=field.name,
        annotation=annotation,
        default=ast.Call(
            func=ast_for_name(call_name),
            args=args,
            keywords=keywords,
        ),
    )

def create_model_class(
self, type_info: TypeInfo[GraphQLObjectType]
) -> ast.ClassDef:
Expand All @@ -105,54 +229,95 @@ def create_model_class(
# Add table name
table_name = type_info.metadata.get("db_table", type_info.name.lower())
body.append(
ast.Assign(
targets=[ast.Name(id="__tablename__", ctx=ast.Store())],
ast_for_assign(
"__tablename__",
value=ast.Constant(value=table_name),
)
)

# Add columns
# Add fields
for field in type_info.fields:
args, keywords = self.create_column_args(
field.name, field.field_type.required, field.metadata
)

# Create the Mapped[Type] annotation
mapped_type = ast.Subscript(
value=ast.Name(id="Mapped", ctx=ast.Load()),
slice=ast.Name(id=field.type, ctx=ast.Load()),
ctx=ast.Load(),
)
if field.is_computed:
continue

column_def = ast.AnnAssign(
target=ast.Name(id=field.name, ctx=ast.Store()),
annotation=mapped_type,
value=ast.Call(
func=ast.Name(id="mapped_column", ctx=ast.Load()),
args=args,
keywords=keywords,
),
simple=1,
)
body.append(column_def)
field_def = self.create_field_definition(field, type_info)
body.append(field_def)

return ast.ClassDef(
name=type_info.name,
bases=[ast.Name(id="Base", ctx=ast.Load())],
name=type_info.db_type,
bases=[ast_for_name("Base")],
keywords=[],
body=body,
decorator_list=[],
)

def validate_relationships(self) -> None:
    """Validate relationship fields across every database-backed type.

    For each relation on a table-backed type this checks that:
      * the related type is itself marked as a database table,
      * any 'back_populates' target names an existing field on the
        referenced type,
      * non-list (to-one) relationships have a foreign key field
        available on the owning type.

    Raises:
        SchemaValidationError: on the first violation found.
    """
    db_tables = self.get_db_table_types()

    for type_info in self.analyzer.object_types:
        # Only types persisted as tables can own relationships.
        if "db_table" not in type_info.metadata:
            continue

        for field in type_info.fields:
            relation_metadata = field.metadata.get("relation")
            if not relation_metadata:
                continue

            # Element type for list relations, otherwise the direct type.
            schema_type = field.field_type.of_type or field.field_type.safe_value
            related_type = self.get_db_type(schema_type)

            if related_type not in db_tables:
                raise SchemaValidationError(
                    f"Relationship {type_info.name}.{field.name} references type {related_type} "
                    "which is not marked as a database table"
                )

            if "back_populates" in relation_metadata:
                back_populates = relation_metadata["back_populates"]
                referenced_type = self.analyzer.object_types_by_name.get(
                    schema_type
                )

                if not referenced_type:
                    raise SchemaValidationError(
                        f"Relationship {type_info.name}.{field.name} references non-existent type {schema_type}"
                    )

                # The back_populates target must be a field on the other side.
                if all(f.name != back_populates for f in referenced_type.fields):
                    raise SchemaValidationError(
                        f"Relationship {type_info.name}.{field.name} references non-existent field "
                        f"{back_populates} in type {schema_type}"
                    )

            # A *-to-one relationship needs a foreign key column on this
            # model. NOTE(review): this accepts ANY foreign-key field on
            # the type, not necessarily one pointing at related_type.
            if not field.field_type.is_list:
                has_foreign_key = any(
                    f.metadata.get("foreign_key") for f in type_info.fields
                )
                if not has_foreign_key:
                    raise SchemaValidationError(
                        f"Relationship {type_info.name}.{field.name} to {related_type} "
                        "requires a foreign key field"
                    )

def generate(self) -> str:
"""Generate SQLAlchemy models from the schema."""
# Create base class definition
body: list[ast.stmt] = [
ast.ClassDef(
name="Base",
bases=[ast.Name(id="DeclarativeBase", ctx=ast.Load())],
bases=[ast_for_name("DeclarativeBase")],
keywords=[],
body=[ast.Pass()],
body=[PASS],
decorator_list=[],
)
]
Expand All @@ -161,9 +326,13 @@ def generate(self) -> str:
for type_info in self.analyzer.object_types:
if "db_table" not in type_info.metadata:
continue

model_class = self.create_model_class(type_info)
body.append(model_class)

# Validate all relationships
self.validate_relationships()

# Create and format the complete module
module = self.create_module(body)
return format_code(module)
Expand Down
16 changes: 13 additions & 3 deletions cannula/codegen/parse_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,22 @@ def parse_graphql_type(
if is_non_null_type(type_obj):
non_null_type = cast(GraphQLNonNull, type_obj)
inner = parse_graphql_type(non_null_type.of_type, schema_types)
return FieldType(value=inner.value, required=True)
return FieldType(
value=inner.value,
required=True,
of_type=inner.of_type,
is_list=inner.is_list,
)

if is_list_type(type_obj):
list_type = cast(GraphQLList, type_obj)
inner = parse_graphql_type(list_type.of_type, schema_types)
return FieldType(value=f"Sequence[{inner.value}]", required=False)
return FieldType(
value=f"Sequence[{inner.value}]",
required=False,
of_type=inner.of_type,
is_list=True,
)

# At this point we have a named type
named_type = get_named_type(type_obj)
Expand All @@ -43,4 +53,4 @@ def parse_graphql_type(
if type_name in schema_types:
type_name = schema_types[type_name].extensions.get("py_type", type_name)

return FieldType(value=type_name, required=False)
return FieldType(value=type_name, required=False, of_type=type_name)
11 changes: 11 additions & 0 deletions cannula/codegen/schema_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ class TypeInfo(Generic[T]):
fields: List[Field]
description: Optional[str] = None

@property
def is_db_type(self) -> bool:
    """True when this type is persisted as a database table.

    Any truthy 'db_table' metadata entry marks the type as table-backed.
    """
    value = self.metadata.get("db_table")
    return bool(value)

@property
def db_type(self) -> str:
    """Name of the generated SQLAlchemy model class.

    Uses the schema extension 'db_type' when present, otherwise falls
    back to the conventional 'DB' + type-name form.
    """
    default_name = f"DB{self.name}"
    return self.type_def.extensions.get("db_type", default_name)


class SchemaAnalyzer:
"""
Expand All @@ -109,6 +117,8 @@ def _analyze(self) -> None:
self.union_types: List[UnionType] = []
self.operation_types: List[TypeInfo[GraphQLObjectType]] = []
self.operation_fields: List[Field] = []
# Add helper to access object types by name
self.object_types_by_name: Dict[str, TypeInfo[GraphQLObjectType]] = {}

for name, type_def in self.schema.type_map.items():
is_operation = name in ("Query", "Mutation", "Subscription")
Expand Down Expand Up @@ -141,6 +151,7 @@ def _analyze(self) -> None:
self.operation_fields.sort(key=lambda o: o.name)
self.operation_types.sort(key=lambda o: o.name)
self.union_types.sort(key=lambda o: o.name)
self.object_types_by_name = {t.py_type: t for t in self.object_types}

def parse_union(self, node: GraphQLUnionType) -> UnionType:
"""Parse a GraphQL Union type into a UnionType object"""
Expand Down
Loading

0 comments on commit bd857c3

Please sign in to comment.