Skip to content

Commit

Permalink
feat: enhance session management and JSON serialization
Browse files Browse the repository at this point in the history
- Introduced `global_session` for session management in `SessionManager`.
- Updated `PydanticJSONMixin` to better handle JSONB columns conversion.
- Expanded test suite with new test cases for session and serialization.
- Added computed properties to models, supporting validation and serialization.
- Ensured session state is correctly managed, especially in test scenarios.

Generated-by: aiautocommit
  • Loading branch information
iloveitaly committed Jan 14, 2025
1 parent 085c6a9 commit c74804b
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 9 deletions.
6 changes: 6 additions & 0 deletions activemodel/mixins/pydantic_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@


class PydanticJSONMixin:
"""
By default, SQLModel does not convert JSONB columns into pydantic models when they are loaded from the database.
This mixin, combined with a custom serializer, fixes that issue.
"""

@reconstructor
def init_on_load(self):
# TODO do we need to inspect sa_type
Expand Down
36 changes: 36 additions & 0 deletions activemodel/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
database environment when testing.
"""

import contextlib
import json
import typing as t

Expand Down Expand Up @@ -41,6 +42,9 @@ class SessionManager:
_instance: t.ClassVar[t.Optional["SessionManager"]] = None

session_connection: Connection | None
"optionally specify a specific session connection to use for all get_session() calls, useful for testing"

session: Session | None

@classmethod
def get_instance(cls, database_url: str | None = None) -> "SessionManager":
Expand All @@ -55,7 +59,9 @@ def get_instance(cls, database_url: str | None = None) -> "SessionManager":
def __init__(self, database_url: str):
self._database_url = database_url
self._engine = None

self.session_connection = None
self.session = None

# TODO why is this type not reimported?
def get_engine(self) -> Engine:
Expand All @@ -72,11 +78,36 @@ def get_engine(self) -> Engine:
return self._engine

def get_session(self):
if self.session:

@contextlib.contextmanager
def _fake():
assert self.session
yield self.session

return _fake()

if self.session_connection:
return Session(bind=self.session_connection)

return Session(self.get_engine())

@contextlib.contextmanager
def global_session(self):
"""
Context manager that generates a new session and sets it as the
`session_connection`, restoring it to `None` at the end.
"""

# Generate a new connection and set it as the session_connection
with self.get_session() as session:
self.session = session

try:
yield
finally:
self.session = None


def init(database_url: str):
return SessionManager.get_instance(database_url)
Expand All @@ -88,3 +119,8 @@ def get_engine():

def get_session():
return SessionManager.get_instance().get_session()


def global_session():
with SessionManager.get_instance().global_session():
yield
69 changes: 63 additions & 6 deletions test/fastapi_test.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,75 @@
from fastapi import FastAPI
from typing import Annotated

from test.models import ExampleWithId
import pytest
import sqlalchemy
from fastapi import Depends, FastAPI, Path, Request
from fastapi.testclient import TestClient
from starlette.testclient import TestClient

from activemodel.session_manager import global_session
from activemodel.types.typeid import TypeIDType
from test.models import AnotherExample, ExampleWithComputedProperty, ExampleWithId


def fake_app():
api_app = FastAPI() # type: ignore
api_app = FastAPI(dependencies=[Depends(global_session)])

@api_app.get("/typeid")
async def index() -> ExampleWithId:
return "hi"
return ExampleWithId().save()

@api_app.get("/computed")
async def computed():
another_example = AnotherExample(note="hello").save()
return ExampleWithComputedProperty(another_example_id=another_example.id).save()

@api_app.post("/example/{example_id}")
async def get_record(
request: Request,
example_id: Annotated[TypeIDType, Path()],
) -> ExampleWithId:
example = ExampleWithId.get(id=example_id)
assert example
return example

return api_app


def test_openapi():
def fake_client():
app = fake_app()
return app, TestClient(app)


def test_openapi_generation():
openapi = fake_app().openapi()
breakpoint()


def test_typeid_input_parsing(create_and_wipe_database):
example = ExampleWithId().save()
example_id = example.id

app, client = fake_client()

response = client.post(f"/example/{example_id}")

assert response.status_code == 200


def test_typeid_invalid_prefix_match(create_and_wipe_database):
app, client = fake_client()

# TODO we should really be able to assert against this:
# with pytest.raises(TypeIDValidationError):
# we'll need to

with pytest.raises(sqlalchemy.exc.StatementError):
response = client.post("/example/user_01h45ytscbebyvny4gc8cr8ma2")


def test_computed_property(create_and_wipe_database):
app, client = fake_client()

response = client.get("/computed")

assert response.status_code == 200
assert response.json()["special_note"] == "SPECIAL: hello"
18 changes: 16 additions & 2 deletions test/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
from pydantic import computed_field
from sqlmodel import Field, Relationship

from activemodel import BaseModel
from activemodel.mixins import TypeIDMixin
from activemodel.types.typeid import TypeIDType
from sqlmodel import Relationship

TYPEID_PREFIX = "myid"


class AnotherExample(BaseModel, TypeIDMixin("myotherid"), table=True):
pass
note: str | None = Field(nullable=True)


class ExampleWithId(BaseModel, TypeIDMixin(TYPEID_PREFIX), table=True):
another_example_id: TypeIDType = AnotherExample.foreign_key(nullable=True)
another_example: AnotherExample = Relationship()


class ExampleWithComputedProperty(
BaseModel, TypeIDMixin("example_computed"), table=True
):
another_example_id: TypeIDType = AnotherExample.foreign_key()
another_example: AnotherExample = Relationship()

@computed_field
@property
def special_note(self) -> str:
return f"SPECIAL: {self.another_example.note}"
26 changes: 25 additions & 1 deletion test/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import Field
from sqlmodel import Field, Session

from activemodel import BaseModel
from activemodel.mixins import PydanticJSONMixin, TypeIDMixin
from activemodel.session_manager import SessionManager
from test.models import AnotherExample, ExampleWithComputedProperty


class SubObject(PydanticBaseModel):
Expand Down Expand Up @@ -55,3 +57,25 @@ def test_json_serialization(create_and_wipe_database):
assert fresh_example.optional_list_field
assert isinstance(fresh_example.optional_list_field[0], SubObject)
assert isinstance(fresh_example.unstructured_field, dict)


def test_computed_serialization(create_and_wipe_database):
# count()s are a bit paranoid because I don't understand the sqlalchemy session model yet

with SessionManager.get_instance().global_session():
another_example = AnotherExample(note="test").save()

example = ExampleWithComputedProperty(
another_example_id=another_example.id,
).save()

assert ExampleWithComputedProperty.count() == 1
assert AnotherExample.count() == 1

assert Session.object_session(another_example)
assert Session.object_session(example)

example.model_dump_json()

assert ExampleWithComputedProperty.count() == 1
assert AnotherExample.count() == 1

0 comments on commit c74804b

Please sign in to comment.