Skip to content

Commit

Permalink
feat: lots of stuff :)
Browse files Browse the repository at this point in the history
Generated-by: aiautocommit
  • Loading branch information
iloveitaly committed Nov 26, 2024
1 parent 9a55ff4 commit 0510d20
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 52 deletions.
10 changes: 6 additions & 4 deletions activemodel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging

from pydantic import InstanceOf

from .base_model import BaseModel
from .query_wrapper import QueryWrapper
from .timestamps import TimestampMixin
from .session_manager import SessionManager, get_engine, get_session, init

# TODO need a way to specify the session generator
# TODO need a way to specify the session generator
logger = logging.getLogger(__name__)
58 changes: 36 additions & 22 deletions activemodel/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@

import pydash
import sqlalchemy as sa
import sqlmodel as sm
from sqlalchemy.orm import declared_attr
from sqlmodel import Session, SQLModel
from sqlmodel import Session, SQLModel, select
from typeid import TypeID

from .query_wrapper import QueryWrapper
from .session_manager import get_session


class BaseModel(SQLModel):
"""
Base model class to inherit from so we can hate python less
https://github.com/woofz/sqlmodel-basecrud/blob/main/sqlmodel_basecrud/basecrud.py
{before,after} hooks are modeled after Rails.
"""

# TODO implement actually calling these hooks
Expand All @@ -24,16 +29,22 @@ def before_delete(self):
def after_delete(self):
pass

def before_save(self):
def before_update(self):
pass

def after_save(self):
def after_update(self):
pass

def before_update(self):
def before_create(self):
pass

def after_update(self):
def after_create(self):
pass

def before_save(self):
pass

def after_save(self):
pass

@declared_attr
Expand All @@ -55,44 +66,45 @@ def select(cls, *args):
return QueryWrapper[cls](cls, *args)

def save(self):
old_session = Session.object_session(self)
with get_session() as session:
if old_session:
if old_session := Session.object_session(self):
# I was running into an issue where the object was already
# associated with a session, but the session had been closed,
# to get around this, you need to remove it from the old one,
# then add it to the new one (below)
old_session.expunge(self)

self.before_update()
self.before_save()

# breakpoint()
# self.before_save()

session.add(self)
session.commit()
session.refresh(self)

self.after_update()
# self.after_save()
self.after_update()
self.after_save()
# self.after_create()

return self
return self

# except IntegrityError:
# log.quiet(f"{self} already exists in the database.")
# session.rollback()
# except IntegrityError:
# log.quiet(f"{self} already exists in the database.")
# session.rollback()

# TODO shouldn't this be handled by pydantic?
def json(self, **kwargs):
return json.dumps(self.dict(), default=str, **kwargs)

# TODO should move this to the wrapper
@classmethod
def count(cls):
def count(cls) -> int:
"""
Returns the number of records in the database.
"""
# TODO should move this to the wrapper
with get_session() as session:
query = sql.select(sql.func.count()).select_from(cls)
return session.exec(query).one()
return get_session().exec(sm.select(sm.func.count()).select_from(cls)).one()

# TODO what's super dangerous here is you pass a kwarg which does not map to a specific
# field it will result in `True`, which will return all records, and not give you any typing
Expand All @@ -109,15 +121,17 @@ def get(cls, *args: sa.BinaryExpression, **kwargs: t.Any):
# TODO id is hardcoded, not good! Need to dynamically pick the best uid field
kwargs["id"] = args[0]
args = []
elif len(args) == 1 and isinstance(args[0], TypeID):
kwargs["id"] = args[0]
args = []

statement = sql.select(cls).filter(*args).filter_by(**kwargs)
with get_session() as session:
return session.exec(statement).first()
statement = select(cls).filter(*args).filter_by(**kwargs)
return get_session().exec(statement).first()

@classmethod
def all(cls):
with get_session() as session:
results = session.exec(sql.select(cls))
results = session.exec(sa.sql.select(cls))

# TODO do we need this or can we just return results?
for result in results:
Expand Down
16 changes: 4 additions & 12 deletions activemodel/query_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
from typing import Generic, TypeVar
import sqlmodel

from sqlmodel.sql.expression import SelectOfScalar

WrappedModelType = TypeVar("WrappedModelType")


def compile_sql(target: SelectOfScalar):
return str(target.compile(get_engine().connect()))


class QueryWrapper(Generic[WrappedModelType]):
class QueryWrapper[T]:
"""
Make it easy to run queries off of a model
"""
Expand All @@ -20,7 +12,7 @@ def __init__(self, cls, *args) -> None:

if args:
# very naive, let's assume the args are specific select statements
self.target = sql.select(*args).select_from(cls)
self.target = sqlmodel.sql.select(*args).select_from(cls)
else:
self.target = sql.select(cls)

Expand Down Expand Up @@ -75,7 +67,7 @@ def wrapper(*args, **kwargs):

def sql(self):
"""
Output the raw SQL of the query
Output the raw SQL of the query for debugging
"""

return compile_sql(self.target)
Expand Down
45 changes: 31 additions & 14 deletions activemodel/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,62 @@
"""
Class to make managing sessions with SQL Model easy
Class to make managing sessions with SQL Model easy. Also provides a common entrypoint to make it easy to mutate the
database environment when testing.
"""

import typing as t

from decouple import config
from sqlalchemy import Engine
from sqlmodel import Session, create_engine


class SessionManager:
_instance: t.ClassVar[t.Optional["SessionManager"]] = None

session_connection: str

@classmethod
def get_instance(cls, database_url: str | None = None) -> "SessionManager":
if cls._instance is None:
assert (
database_url is not None
), "Database URL required for first initialization"
cls._instance = cls(database_url)

return cls._instance

def __init__(self, database_url: str):
self._database_url = database_url
self._engine = None
self.session_connection = None

# TODO why is this type not reimported?
def get_engine(self) -> Engine:
if not self._engine:
self._engine = create_engine(
self._database_url,
# echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False),
echo=True,
echo=config("ACTIVEMODEL_LOG_SQL", cast=bool, default=False),
# https://docs.sqlalchemy.org/en/20/core/pooling.html#disconnect-handling-pessimistic
pool_pre_ping=True,
# some implementations include `future=True` but it's not required anymore
)

return self._engine

def get_session(self):
if self.session_connection:
return Session(bind=self.session_connection)

return Session(self.get_engine())


import os
def init(database_url: str):
return SessionManager.get_instance(database_url)

# TODO need a way to specify the session generator
manager = SessionManager(os.environ["TEST_DATABASE_URL"])
get_engine = manager.get_engine
get_session = manager.get_session

from sqlmodel.sql.expression import SelectOfScalar
def get_engine():
return SessionManager.get_instance().get_engine()


def compile_sql(target: SelectOfScalar):
dialect = get_engine().dialect
# TODO I wonder if we could store the dialect to avoid getting an engine reference
compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
return str(compiled)
def get_session():
return SessionManager.get_instance().get_session()
56 changes: 56 additions & 0 deletions activemodel/types/typeid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Lifted from: https://github.com/akhundMurad/typeid-python/blob/main/examples/sqlalchemy.py
"""

from typing import Optional

from sqlalchemy import types
from sqlalchemy.util import generic_repr
from typeid import TypeID


class TypeIDType(types.TypeDecorator):
"""
A SQLAlchemy TypeDecorator that allows storing TypeIDs in the database.
The prefix will not be persisted, instead the database-native UUID field will be used.
At retrieval time a TypeID will be constructed based on the configured prefix and the
UUID value from the database.
Usage:
# will result in TypeIDs such as "user_01h45ytscbebyvny4gc8cr8ma2"
id = mapped_column(
TypeIDType("user"),
primary_key=True,
default=lambda: TypeID("user")
)
"""

impl = types.Uuid
# impl = uuid.UUID
cache_ok = True
prefix: Optional[str] = None

def __init__(self, prefix: Optional[str], *args, **kwargs):
self.prefix = prefix
super().__init__(*args, **kwargs)

def __repr__(self) -> str:
# Customize __repr__ to ensure that auto-generated code e.g. from alembic includes
# the right __init__ params (otherwise by default prefix will be omitted because
# uuid.__init__ does not have such an argument).
# TODO this makes it so inspected code does NOT include the suffix
return generic_repr(
self,
to_inspect=TypeID(self.prefix),
)

def process_bind_param(self, value, dialect):
if self.prefix is None:
assert value.prefix is None
else:
assert value.prefix == self.prefix

return value.uuid

def process_result_value(self, value, dialect):
return TypeID.from_uuid(value, self.prefix)
10 changes: 10 additions & 0 deletions activemodel/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from sqlmodel.sql.expression import SelectOfScalar

from activemodel import get_engine


def compile_sql(target: SelectOfScalar):
dialect = get_engine().dialect
# TODO I wonder if we could store the dialect to avoid getting an engine reference
compiled = target.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
return str(compiled)

0 comments on commit 0510d20

Please sign in to comment.