Skip to content

Commit

Permalink
feat: enhance query handling and testing in ORM
Browse files Browse the repository at this point in the history
- Introduced `get_session` from `session_manager` for improved session management.
- Added utility function `compile_sql` to streamline SQL compilation.
- Refined imports and streamlined select query calls with alias `sm`.
- Enhanced `QueryWrapper` by adding a docstring for `one()` method.
- Expanded ORM test coverage with `test_basic_query`, refined `test_all_and_count`, and adjusted assertion methods for robustness.
- Added a new attribute `something` to the `ExampleRecord` class for more flexible data representation.

Generated-by: aiautocommit
  • Loading branch information
iloveitaly committed Jan 14, 2025
1 parent 3181cca commit 363c0dd
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
10 changes: 7 additions & 3 deletions activemodel/query_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import sqlmodel
import sqlmodel as sm

from .session_manager import get_session
from .utils import compile_sql


class QueryWrapper[T]:
Expand All @@ -12,9 +15,9 @@ def __init__(self, cls, *args) -> None:

if args:
# very naive, let's assume the args are specific select statements
self.target = sqlmodel.sql.select(*args).select_from(cls)
self.target = sm.select(*args).select_from(cls)
else:
self.target = sql.select(cls)
self.target = sm.select(cls)

# TODO the .exec results should be handled in one shot

Expand All @@ -23,6 +26,7 @@ def first(self):
return session.exec(self.target).first()

def one(self):
"requires exactly one result in the dataset"
with get_session() as session:
return session.exec(self.target).one()

Expand Down
2 changes: 1 addition & 1 deletion activemodel/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlmodel.sql.expression import SelectOfScalar

from activemodel import get_engine
from .session_manager import get_engine, get_session


def compile_sql(target: SelectOfScalar):
Expand Down
35 changes: 22 additions & 13 deletions test/orm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Test core ORM functions
"""

from test.utils import temporary_tables

from activemodel import BaseModel
from activemodel.mixins.timestamps import TimestampsMixin
from activemodel.mixins.typeid import TypeIDMixin
Expand All @@ -14,24 +12,35 @@
class ExampleRecord(
BaseModel, TimestampsMixin, TypeIDMixin(EXAMPLE_TABLE_PREFIX), table=True
):
pass
something: str | None


def test_all_and_count(create_and_wipe_database):
records_to_create = 10

# create 10 example records
for i in range(records_to_create):
ExampleRecord().save()

def test_list():
with temporary_tables():
# create 10 example records
for i in range(10):
ExampleRecord().save()
assert ExampleRecord.count() == records_to_create

assert ExampleRecord.count() == 10
all_records = list(ExampleRecord.all())
assert len(all_records) == records_to_create

all_records = list(ExampleRecord.all())
assert len(all_records) == 10
assert ExampleRecord.count() == records_to_create

record = all_records[0]
assert isinstance(record, ExampleRecord)
record = all_records[0]
assert isinstance(record, ExampleRecord)


def test_foreign_key():
field = ExampleRecord.foreign_key()
assert field.sa_type.prefix == EXAMPLE_TABLE_PREFIX


def test_basic_query(create_and_wipe_database):
example = ExampleRecord(something="hi").save()
query = ExampleRecord.select().where(ExampleRecord.something == "hi")

query_as_str = str(query)
result = query.first()

0 comments on commit 363c0dd

Please sign in to comment.