Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sqlmodel #564

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions gel/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from gel import blocking_client
from gel.orm.introspection import get_schema_json, GelORMWarning
from gel.orm.sqla import ModelGenerator as SQLAModGen
from gel.orm.sqlmodel import ModelGenerator as SQLModGen
from gel.orm.django.generator import ModelGenerator as DjangoModGen


Expand Down Expand Up @@ -694,6 +695,26 @@ def get_dsn_for_sqla(cls):
return dsn


class SQLModelTestCase(ORMTestCase):
@classmethod
def setupORM(cls):
gen = SQLModGen(
outdir=os.path.join(cls.tmpormdir.name, cls.MODEL_PACKAGE),
basemodule=cls.MODEL_PACKAGE,
)
gen.render_models(cls.spec)

@classmethod
def get_dsn_for_sqla(cls):
cargs = cls.get_connect_args(database=cls.get_database_name())
dsn = (
f'postgresql://{cargs["user"]}:{cargs["password"]}'
f'@{cargs["host"]}:{cargs["port"]}/{cargs["database"]}'
)

return dsn


APPS_PY = '''\
from django.apps import AppConfig
Expand Down
13 changes: 12 additions & 1 deletion gel/orm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from gel.codegen.generator import _get_conn_args
from .introspection import get_schema_json, GelORMWarning
from .sqla import ModelGenerator as SQLAModGen
from .sqlmodel import ModelGenerator as SQLModGen
from .django.generator import ModelGenerator as DjangoModGen


Expand All @@ -41,7 +42,7 @@ def error(self, message):
)
parser.add_argument(
"orm",
choices=['sqlalchemy', 'django'],
choices=['sqlalchemy', 'sqlmodel', 'django'],
help="Pick which ORM to generate models for.",
)
parser.add_argument("--dsn")
Expand Down Expand Up @@ -97,6 +98,16 @@ def generate_models(args, spec):
)
gen.render_models(spec)

case 'sqlmodel':
if args.mod is None:
parser.error('sqlmodel requires to specify --mod')

gen = SQLModGen(
outdir=args.out,
basemodule=args.mod,
)
gen.render_models(spec)

case 'django':
gen = DjangoModGen(
out=args.out,
Expand Down
8 changes: 4 additions & 4 deletions gel/orm/django/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def __init__(self, name):
def table(self):
return self.meta['db_table'].strip("'")

def get_backlink_name(self, name):
return self.backlink_renames.get(name, f'backlink_via_{name}')
def get_backlink_name(self, name, srcname):
return self.backlink_renames.get(name, f'back_to_{srcname}')


class ModelGenerator(FilePrinter):
Expand Down Expand Up @@ -140,7 +140,7 @@ def build_models(self, maps):
continue

lname = link['name']
bklink = mod.get_backlink_name(lname)
bklink = mod.get_backlink_name(lname, name)
code = self.render_link(link, bklink)
if code:
mod.links[lname] = code
Expand Down Expand Up @@ -177,7 +177,7 @@ def build_models(self, maps):
# ManyToManyField.
src = modmap[source]
tgt = modmap[target]
bkname = src.get_backlink_name(fwname)
bkname = src.get_backlink_name(fwname, source)
src.mlinks[fwname] = (
f'models.ManyToManyField('
f'{tgt.name!r}, '
Expand Down
10 changes: 5 additions & 5 deletions gel/orm/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ def _process_links(types, modules):

objtype = type_map[target]
objtype['backlinks'].append({
'name': f'backlink_via_{sql_name}',
'name': f'back_to_{sql_source}',
'fwname': sql_name,
# flip cardinality and exclusivity
'cardinality': 'One' if exclusive else 'Many',
'exclusive': cardinality == 'One',
Expand Down Expand Up @@ -239,7 +240,7 @@ def _process_links(types, modules):
# Find collisions in backlink names
bk = collections.defaultdict(list)
for link in spec['backlinks']:
if link['name'].startswith('backlink_via_'):
if link['name'].startswith('back_to_'):
bk[link['name']].append(link)

for bklinks in bk.values():
Expand All @@ -249,12 +250,11 @@ def _process_links(types, modules):
for link in bklinks:
origsrc = get_sql_name(link['target']['name'])
lname = link['name']
link['name'] = f'{lname}_from_{origsrc}'
fwname = link['fwname']
link['name'] = f'follow_{fwname}_{lname}'
# Also update the original source of the link with the
# special backlink name.
source = type_map[link['target']['name']]
fwname = lname.replace('backlink_via_', '', 1)
link['fwname'] = fwname
source['backlink_renames'][fwname] = link['name']

return {
Expand Down
8 changes: 4 additions & 4 deletions gel/orm/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def render_link_object(self, spec, modules):
src = modules[mod]['object_types'][source_name]
bklink = src['backlink_renames'].get(
source_link,
f'backlink_via_{source_link}',
f'back_to_{source_name}',
)

self.write(
Expand Down Expand Up @@ -303,7 +303,7 @@ def render_type(self, spec, modules):
self.write(f'gel_type_id: Mapped[uuid.UUID] = mapped_column(')
self.indent()
self.write(
f"'__type__', Uuid(), unique=True, server_default='PLACEHOLDER')")
f"'__type__', Uuid(), server_default='PLACEHOLDER')")
self.dedent()

if spec['properties']:
Expand Down Expand Up @@ -370,7 +370,7 @@ def render_link(self, spec, mod, parent, modules):
tmod, target = get_mod_and_name(spec['target']['name'])
source = modules[mod]['object_types'][parent]
cardinality = spec['cardinality']
bklink = source['backlink_renames'].get(name, f'backlink_via_{name}')
bklink = source['backlink_renames'].get(name, f'back_to_{parent}')

if spec.get('has_link_object'):
# intermediate object will have the actual source and target
Expand Down Expand Up @@ -437,7 +437,7 @@ def render_backlink(self, spec, mod, modules):
tmod, target = get_mod_and_name(spec['target']['name'])
cardinality = spec['cardinality']
exclusive = spec['exclusive']
bklink = spec.get('fwname', name.replace('backlink_via_', '', 1))
bklink = spec['fwname']

if spec.get('has_link_object'):
# intermediate object will have the actual source and target
Expand Down
Loading
Loading