-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
6 changed files
with
143 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import logging | ||
|
||
from pydantic import InstanceOf | ||
|
||
from .base_model import BaseModel | ||
from .query_wrapper import QueryWrapper | ||
from .timestamps import TimestampMixin | ||
from .session_manager import SessionManager, get_engine, get_session, init | ||
|
||
# TODO need a way to specify the session generator | ||
# TODO need a way to specify the session generator | ||
logger = logging.getLogger(__name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,62 @@ | ||
""" | ||
Class to make managing sessions with SQL Model easy | ||
Class to make managing sessions with SQL Model easy. Also provides a common entrypoint to make it easy to mutate the | ||
database environment when testing. | ||
""" | ||
|
||
import typing as t | ||
|
||
from decouple import config | ||
from sqlalchemy import Engine | ||
from sqlmodel import Session, create_engine | ||
|
||
|
||
class SessionManager: | ||
_instance: t.ClassVar[t.Optional["SessionManager"]] = None | ||
|
||
session_connection: str | ||
|
||
@classmethod | ||
def get_instance(cls, database_url: str | None = None) -> "SessionManager": | ||
if cls._instance is None: | ||
assert ( | ||
database_url is not None | ||
), "Database URL required for first initialization" | ||
cls._instance = cls(database_url) | ||
|
||
return cls._instance | ||
|
||
def __init__(self, database_url: str): | ||
self._database_url = database_url | ||
self._engine = None | ||
self.session_connection = None | ||
|
||
# TODO why is this type not reimported? | ||
def get_engine(self) -> Engine: | ||
if not self._engine: | ||
self._engine = create_engine( | ||
self._database_url, | ||
# echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False), | ||
echo=True, | ||
echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False), | ||
# https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic | ||
pool_pre_ping=True, | ||
# some implementations include `future=True` but it's not required anymore | ||
) | ||
|
||
return self._engine | ||
|
||
def get_session(self): | ||
if self.session_connection: | ||
return Session(bind=self.session_connection) | ||
|
||
return Session(self.get_engine()) | ||
|
||
|
||
import os | ||
def init(database_url: str): | ||
return SessionManager.get_instance(database_url) | ||
|
||
# TODO need a way to specify the session generator | ||
manager = SessionManager(os.environ["TEST_DATABASE_URL"]) | ||
get_engine = manager.get_engine | ||
get_session = manager.get_session | ||
|
||
from sqlmodel.sql.expression import SelectOfScalar | ||
def get_engine(): | ||
return SessionManager.get_instance().get_engine() | ||
|
||
|
||
def compile_sql(target: SelectOfScalar): | ||
dialect = get_engine().dialect | ||
# TODO I wonder if we could store the dialect to avoid getting an engine reference | ||
compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) | ||
return str(compiled) | ||
def get_session(): | ||
return SessionManager.get_instance().get_session() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
""" | ||
Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sqlalchemy.py | ||
""" | ||
|
||
from typing import Optional | ||
|
||
from sqlalchemy import types | ||
from sqlalchemy.util import generic_repr | ||
from typeid import TypeID | ||
|
||
|
||
class TypeIDType(types.TypeDecorator): | ||
""" | ||
A SQLAlchemy TypeDecorator that allows storing TypeIDs in the database. | ||
The prefix will not be persisted, instead the database-native UUID field will be used. | ||
At retrieval time a TypeID will be constructed based on the configured prefix and the | ||
UUID value from the database. | ||
Usage: | ||
# will result in TypeIDs such as "user_01h45ytscbebyvny4gc8cr8ma2" | ||
id = mapped_column( | ||
TypeIDType("user"), | ||
primary_key=True, | ||
default=lambda: TypeID("user") | ||
) | ||
""" | ||
|
||
impl = types.Uuid | ||
# impl = uuid.UUID | ||
cache_ok = True | ||
prefix: Optional[str] = None | ||
|
||
def __init__(self, prefix: Optional[str], *args, **kwargs): | ||
self.prefix = prefix | ||
super().__init__(*args, **kwargs) | ||
|
||
def __repr__(self) -> str: | ||
# Customize __repr__ to ensure that auto-generated code e.g. from alembic includes | ||
# the right __init__ params (otherwise by default prefix will be omitted because | ||
# uuid.__init__ does not have such an argument). | ||
# TODO this makes it so inspected code does NOT include the suffix | ||
return generic_repr( | ||
self, | ||
to_inspect=TypeID(self.prefix), | ||
) | ||
|
||
def process_bind_param(self, value, dialect): | ||
if self.prefix is None: | ||
assert value.prefix is None | ||
else: | ||
assert value.prefix == self.prefix | ||
|
||
return value.uuid | ||
|
||
def process_result_value(self, value, dialect): | ||
return TypeID.from_uuid(value, self.prefix) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from sqlmodel.sql.expression import SelectOfScalar | ||
|
||
from activemodel import get_engine | ||
|
||
|
||
def compile_sql(target: SelectOfScalar): | ||
dialect = get_engine().dialect | ||
# TODO I wonder if we could store the dialect to avoid getting an engine reference | ||
compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) | ||
return str(compiled) |