Skip to content

Commit

Permalink
fix: support lots of UUID inputs
Browse files Browse the repository at this point in the history
Generated-by: aiautocommit
  • Loading branch information
iloveitaly committed Jan 6, 2025
1 parent 78af758 commit 783a9fe
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 11 deletions.
17 changes: 13 additions & 4 deletions activemodel/base_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import typing as t
from uuid import UUID

import pydash
from typeid import TypeID
Expand Down Expand Up @@ -150,7 +151,7 @@ def foreign_key(cls):
"""

return Field(
# TODO id field is hard coded
# TODO id field is hard coded, should pick the PK field in case it's different
sa_type=cls.model_fields["id"].sa_column.type, # type: ignore
foreign_key=f"{cls.__tablename__}.id",
nullable=False,
Expand Down Expand Up @@ -248,13 +249,21 @@ def get(cls, *args: t.Any, **kwargs: t.Any):
Gets a single record from the database. Pass an PK ID or a kwarg to filter by.
"""

# TODO id is hardcoded, not good! Need to dynamically pick the best uid field
id_field_name = "id"

# special case for getting by ID
if len(args) == 1 and isinstance(args[0], int):
# TODO id is hardcoded, not good! Need to dynamically pick the best uid field
kwargs["id"] = args[0]
kwargs[id_field_name] = args[0]
args = []
elif len(args) == 1 and isinstance(args[0], TypeID):
kwargs["id"] = args[0]
kwargs[id_field_name] = args[0]
args = []
elif len(args) == 1 and isinstance(args[0], str):
kwargs[id_field_name] = args[0]
args = []
elif len(args) == 1 and isinstance(args[0], UUID):
kwargs[id_field_name] = args[0]
args = []

statement = select(cls).filter(*args).filter_by(**kwargs)
Expand Down
3 changes: 2 additions & 1 deletion activemodel/mixins/typeid.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import uuid

from sqlmodel import Column, Field
from typeid import TypeID

from activemodel.types.typeid import TypeIDType
from sqlmodel import Column, Field

# global list of prefixes to ensure uniqueness
_prefixes = []


Expand Down
45 changes: 39 additions & 6 deletions activemodel/types/typeid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""

from typing import Optional
from uuid import UUID

from typeid import TypeID

from sqlalchemy import types
from sqlalchemy.util import generic_repr
from typeid import TypeID


class TypeIDType(types.TypeDecorator):
Expand Down Expand Up @@ -45,12 +47,43 @@ def __repr__(self) -> str:
)

def process_bind_param(self, value, dialect):
if self.prefix is None:
assert value.prefix is None
else:
assert value.prefix == self.prefix
"""
This is run when a search query is built or ...
"""

if isinstance(value, UUID):
# then it's a UUID class, such as UUID('01942886-7afc-7129-8f57-db09137ed002')
return value

if isinstance(value, str) and value.startswith(self.prefix + "_"):
# then it's a TypeID such as 'user_01h45ytscbebyvny4gc8cr8ma2'
value = TypeID.from_string(value)

return value.uuid
if isinstance(value, str):
# no prefix, raw UUID, let's coerce it into a UUID which SQLAlchemy can handle
# ex: '01942886-7afc-7129-8f57-db09137ed002'
return UUID(value)

if isinstance(value, TypeID):
# TODO in what case could this None prefix ever occur?
if self.prefix is None:
assert value.prefix is None
else:
assert value.prefix == self.prefix

return value.uuid

raise ValueError("Unexpected input type")

def process_result_value(self, value, dialect):
return TypeID.from_uuid(value, self.prefix)

# def coerce_compared_value(self, op, value):
# """
# This method is called when SQLAlchemy needs to compare a column to a value.
# By returning self, we indicate that this type can handle TypeID instances.
# """
# if isinstance(value, TypeID):
# return self

# return super().coerce_compared_value(op, value)
45 changes: 45 additions & 0 deletions test/typeid_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import pytest
from typeid import TypeID

from test.utils import temporary_tables

from activemodel import BaseModel
from activemodel.mixins import TypeIDMixin


Expand All @@ -13,3 +17,44 @@ def test_enforces_unique_prefixes():
def test_no_empty_prefixes_test():
with pytest.raises(AssertionError):
TypeIDMixin("")


TYPEID_PREFIX = "myid"


class ExampleWithId(BaseModel, TypeIDMixin(TYPEID_PREFIX), table=True):
pass


# the UIDs stored in the DB are NOT the same as the
def test_get_through_prefixed_uid():
type_uid = TypeID(prefix=TYPEID_PREFIX)

with temporary_tables():
record = ExampleWithId.get(type_uid)
assert record is None


def test_get_through_prefixed_uid_as_str():
type_uid = TypeID(prefix=TYPEID_PREFIX)

with temporary_tables():
record = ExampleWithId.get(str(type_uid))
assert record is None


def test_get_through_plain_uid_as_str():
type_uid = TypeID(prefix=TYPEID_PREFIX)

with temporary_tables():
# pass uid as string. Ex: '01942886-7afc-7129-8f57-db09137ed002'
record = ExampleWithId.get(str(type_uid.uuid))
assert record is None


def test_get_through_plain_uid():
type_uid = TypeID(prefix=TYPEID_PREFIX)

with temporary_tables():
record = ExampleWithId.get(type_uid.uuid)
assert record is None

0 comments on commit 783a9fe

Please sign in to comment.