From 0510d20a807829db4f6b454ee2915c32ecedb323 Mon Sep 17 00:00:00 2001 From: Michael Bianco Date: Tue, 26 Nov 2024 08:01:12 -0700 Subject: [PATCH] feat: lots of stuff :) Generated-by: aiautocommit --- activemodel/__init__.py | 10 +++--- activemodel/base_model.py | 58 +++++++++++++++++++++------------- activemodel/query_wrapper.py | 16 +++------- activemodel/session_manager.py | 45 ++++++++++++++++++-------- activemodel/types/typeid.py | 56 ++++++++++++++++++++++++++++++++ activemodel/utils.py | 10 ++++++ 6 files changed, 143 insertions(+), 52 deletions(-) create mode 100644 activemodel/types/typeid.py create mode 100644 activemodel/utils.py diff --git a/activemodel/__init__.py b/activemodel/__init__.py index b408b62..20ce17f 100755 --- a/activemodel/__init__.py +++ b/activemodel/__init__.py @@ -1,6 +1,8 @@ +import logging + +from pydantic import InstanceOf + from .base_model import BaseModel -from .query_wrapper import QueryWrapper -from .timestamps import TimestampMixin +from .session_manager import SessionManager, get_engine, get_session, init -# TODO need a way to specify the session generator -# TODO need a way to specify the session generator +logger = logging.getLogger(__name__) diff --git a/activemodel/base_model.py b/activemodel/base_model.py index 01af580..e198519 100644 --- a/activemodel/base_model.py +++ b/activemodel/base_model.py @@ -3,10 +3,13 @@ import pydash import sqlalchemy as sa +import sqlmodel as sm from sqlalchemy.orm import declared_attr -from sqlmodel import Session, SQLModel +from sqlmodel import Session, SQLModel, select +from typeid import TypeID from .query_wrapper import QueryWrapper +from .session_manager import get_session class BaseModel(SQLModel): @@ -14,6 +17,8 @@ class BaseModel(SQLModel): Base model class to inherit from so we can hate python less https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py + + {before,after} hooks are modeled after Rails. """ # TODO implement actually calling these hooks @@ -24,16 +29,22 @@ def before_delete(self): def after_delete(self): pass - def before_save(self): + def before_update(self): pass - def after_save(self): + def after_update(self): pass - def before_update(self): + def before_create(self): pass - def after_update(self): + def after_create(self): + pass + + def before_save(self): + pass + + def after_save(self): pass @declared_attr @@ -55,9 +66,8 @@ def select(cls, *args): return QueryWrapper[cls](cls, *args) def save(self): - old_session = Session.object_session(self) with get_session() as session: - if old_session: + if old_session := Session.object_session(self): # I was running into an issue where the object was already # associated with a session, but the session had been closed, # to get around this, you need to remove it from the old one, @@ -65,34 +75,36 @@ def save(self): old_session.expunge(self) self.before_update() + self.before_save() + + # breakpoint() # self.before_save() session.add(self) session.commit() session.refresh(self) - self.after_update() - # self.after_save() + self.after_update() + self.after_save() + # self.after_create() - return self + return self - # except IntegrityError: - # log.quiet(f"{self} already exists in the database.") - # session.rollback() + # except IntegrityError: + # log.quiet(f"{self} already exists in the database.") + # session.rollback() # TODO shouldn't this be handled by pydantic? def json(self, **kwargs): return json.dumps(self.dict(), default=str, **kwargs) + # TODO should move this to the wrapper @classmethod - def count(cls): + def count(cls) -> int: """ Returns the number of records in the database. """ - # TODO should move this to the wrapper - with get_session() as session: - query = sql.select(sql.func.count()).select_from(cls) - return session.exec(query).one() + return get_session().exec(sm.select(sm.func.count()).select_from(cls)).one() # TODO what's super dangerous here is you pass a kwarg which does not map to a specific # field it will result in `True`, which will return all records, and not give you any typing @@ -109,15 +121,17 @@ def get(cls, *args: sa.BinaryExpression, **kwargs: t.Any): # TODO id is hardcoded, not good! Need to dynamically pick the best uid field kwargs["id"] = args[0] args = [] + elif len(args) == 1 and isinstance(args[0], TypeID): + kwargs["id"] = args[0] + args = [] - statement = sql.select(cls).filter(*args).filter_by(**kwargs) - with get_session() as session: - return session.exec(statement).first() + statement = select(cls).filter(*args).filter_by(**kwargs) + return get_session().exec(statement).first() @classmethod def all(cls): with get_session() as session: - results = session.exec(sql.select(cls)) + results = session.exec(sa.sql.select(cls)) # TODO do we need this or can we just return results? for result in results: diff --git a/activemodel/query_wrapper.py b/activemodel/query_wrapper.py index 1399516..886b9a8 100644 --- a/activemodel/query_wrapper.py +++ b/activemodel/query_wrapper.py @@ -1,15 +1,7 @@ -from typing import Generic, TypeVar +import sqlmodel -from sqlmodel.sql.expression import SelectOfScalar -WrappedModelType = TypeVar("WrappedModelType") - - -def compile_sql(target: SelectOfScalar): - return str(target.compile(get_engine().connect())) - - -class QueryWrapper(Generic[WrappedModelType]): +class QueryWrapper[T]: """ Make it easy to run queries off of a model """ @@ -20,7 +12,7 @@ def __init__(self, cls, *args) -> None: if args: # very naive, let's assume the args are specific select statements - self.target = sql.select(*args).select_from(cls) + self.target = sqlmodel.sql.select(*args).select_from(cls) else: self.target = sql.select(cls) @@ -75,7 +67,7 @@ def wrapper(*args, **kwargs): def sql(self): """ - Output the raw SQL of the query + Output the raw SQL of the query for debugging """ return compile_sql(self.target) diff --git a/activemodel/session_manager.py b/activemodel/session_manager.py index a8c4951..7d70186 100644 --- a/activemodel/session_manager.py +++ b/activemodel/session_manager.py @@ -1,45 +1,62 @@ """ -Class to make managing sessions with SQL Model easy +Class to make managing sessions with SQL Model easy. Also provides a common entrypoint to make it easy to mutate the +database environment when testing. """ +import typing as t + +from decouple import config from sqlalchemy import Engine from sqlmodel import Session, create_engine class SessionManager: + _instance: t.ClassVar[t.Optional["SessionManager"]] = None + + session_connection: str + + @classmethod + def get_instance(cls, database_url: str | None = None) -> "SessionManager": + if cls._instance is None: + assert ( + database_url is not None + ), "Database URL required for first initialization" + cls._instance = cls(database_url) + + return cls._instance + def __init__(self, database_url: str): self._database_url = database_url self._engine = None + self.session_connection = None # TODO why is this type not reimported? def get_engine(self) -> Engine: if not self._engine: self._engine = create_engine( self._database_url, - # echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False), - echo=True, + echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False), # https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic pool_pre_ping=True, + # some implementations include `future=True` but it's not required anymore ) return self._engine def get_session(self): + if self.session_connection: + return Session(bind=self.session_connection) + return Session(self.get_engine()) -import os +def init(database_url: str): + return SessionManager.get_instance(database_url) -# TODO need a way to specify the session generator -manager = SessionManager(os.environ["TEST_DATABASE_URL"]) -get_engine = manager.get_engine -get_session = manager.get_session -from sqlmodel.sql.expression import SelectOfScalar +def get_engine(): + return SessionManager.get_instance().get_engine() -def compile_sql(target: SelectOfScalar): - dialect = get_engine().dialect - # TODO I wonder if we could store the dialect to avoid getting an engine reference - compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) - return str(compiled) +def get_session(): + return SessionManager.get_instance().get_session() diff --git a/activemodel/types/typeid.py b/activemodel/types/typeid.py new file mode 100644 index 0000000..7f18222 --- /dev/null +++ b/activemodel/types/typeid.py @@ -0,0 +1,56 @@ +""" +Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sqlalchemy.py +""" + +from typing import Optional + +from sqlalchemy import types +from sqlalchemy.util import generic_repr +from typeid import TypeID + + +class TypeIDType(types.TypeDecorator): + """ + A SQLAlchemy TypeDecorator that allows storing TypeIDs in the database. + The prefix will not be persisted, instead the database-native UUID field will be used. + At retrieval time a TypeID will be constructed based on the configured prefix and the + UUID value from the database. + + Usage: + # will result in TypeIDs such as "user_01h45ytscbebyvny4gc8cr8ma2" + id = mapped_column( + TypeIDType("user"), + primary_key=True, + default=lambda: TypeID("user") + ) + """ + + impl = types.Uuid + # impl = uuid.UUID + cache_ok = True + prefix: Optional[str] = None + + def __init__(self, prefix: Optional[str], *args, **kwargs): + self.prefix = prefix + super().__init__(*args, **kwargs) + + def __repr__(self) -> str: + # Customize __repr__ to ensure that auto-generated code e.g. from alembic includes + # the right __init__ params (otherwise by default prefix will be omitted because + # uuid.__init__ does not have such an argument). + # TODO this makes it so inspected code does NOT include the suffix + return generic_repr( + self, + to_inspect=TypeID(self.prefix), + ) + + def process_bind_param(self, value, dialect): + if self.prefix is None: + assert value.prefix is None + else: + assert value.prefix == self.prefix + + return value.uuid + + def process_result_value(self, value, dialect): + return TypeID.from_uuid(value, self.prefix) diff --git a/activemodel/utils.py b/activemodel/utils.py new file mode 100644 index 0000000..e41f54c --- /dev/null +++ b/activemodel/utils.py @@ -0,0 +1,10 @@ +from sqlmodel.sql.expression import SelectOfScalar + +from activemodel import get_engine + + +def compile_sql(target: SelectOfScalar): + dialect = get_engine().dialect + # TODO I wonder if we could store the dialect to avoid getting an engine reference + compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + return str(compiled)