From 783a9fe7f733b346cd6bf2401a9ae134ec785cf6 Mon Sep 17 00:00:00 2001 From: Michael Bianco Date: Mon, 6 Jan 2025 12:33:39 -0700 Subject: [PATCH] fix: support lots of UUID inputs Generated-by: aiautocommit --- activemodel/base_model.py | 17 ++++++++++---- activemodel/mixins/typeid.py | 3 ++- activemodel/types/typeid.py | 45 +++++++++++++++++++++++++++++++----- test/typeid_test.py | 45 ++++++++++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 11 deletions(-) diff --git a/activemodel/base_model.py b/activemodel/base_model.py index b43dfc6..e01462e 100644 --- a/activemodel/base_model.py +++ b/activemodel/base_model.py @@ -1,5 +1,6 @@ import json import typing as t +from uuid import UUID import pydash from typeid import TypeID @@ -150,7 +151,7 @@ def foreign_key(cls): """ return Field( - # TODO id field is hard coded + # TODO id field is hard coded, should pick the PK field in case it's different sa_type=cls.model_fields["id"].sa_column.type, # type: ignore foreign_key=f"{cls.__tablename__}.id", nullable=False, @@ -248,13 +249,21 @@ def get(cls, *args: t.Any, **kwargs: t.Any): Gets a single record from the database. Pass an PK ID or a kwarg to filter by. """ + # TODO id is hardcoded, not good! Need to dynamically pick the best uid field + id_field_name = "id" + # special case for getting by ID if len(args) == 1 and isinstance(args[0], int): - # TODO id is hardcoded, not good! Need to dynamically pick the best uid field - kwargs["id"] = args[0] + kwargs[id_field_name] = args[0] args = [] elif len(args) == 1 and isinstance(args[0], TypeID): - kwargs["id"] = args[0] + kwargs[id_field_name] = args[0] + args = [] + elif len(args) == 1 and isinstance(args[0], str): + kwargs[id_field_name] = args[0] + args = [] + elif len(args) == 1 and isinstance(args[0], UUID): + kwargs[id_field_name] = args[0] args = [] statement = select(cls).filter(*args).filter_by(**kwargs) diff --git a/activemodel/mixins/typeid.py b/activemodel/mixins/typeid.py index a9b69fe..79e0984 100644 --- a/activemodel/mixins/typeid.py +++ b/activemodel/mixins/typeid.py @@ -1,10 +1,11 @@ import uuid -from sqlmodel import Column, Field from typeid import TypeID from activemodel.types.typeid import TypeIDType +from sqlmodel import Column, Field +# global list of prefixes to ensure uniqueness _prefixes = [] diff --git a/activemodel/types/typeid.py b/activemodel/types/typeid.py index 7f18222..dee402a 100644 --- a/activemodel/types/typeid.py +++ b/activemodel/types/typeid.py @@ -3,10 +3,12 @@ """ from typing import Optional +from uuid import UUID + +from typeid import TypeID from sqlalchemy import types from sqlalchemy.util import generic_repr -from typeid import TypeID class TypeIDType(types.TypeDecorator): @@ -45,12 +47,43 @@ def __repr__(self) -> str: ) def process_bind_param(self, value, dialect): - if self.prefix is None: - assert value.prefix is None - else: - assert value.prefix == self.prefix + """ + This is run when a search query is built or ... + """ + + if isinstance(value, UUID): + # then it's a UUID class, such as UUID('01942886-7afc-7129-8f57-db09137ed002') + return value + + if isinstance(value, str) and value.startswith(self.prefix + "_"): + # then it's a TypeID such as 'user_01h45ytscbebyvny4gc8cr8ma2' + value = TypeID.from_string(value) - return value.uuid + if isinstance(value, str): + # no prefix, raw UUID, let's coerce it into a UUID which SQLAlchemy can handle + # ex: '01942886-7afc-7129-8f57-db09137ed002' + return UUID(value) + + if isinstance(value, TypeID): + # TODO in what case could this None prefix ever occur? + if self.prefix is None: + assert value.prefix is None + else: + assert value.prefix == self.prefix + + return value.uuid + + raise ValueError("Unexpected input type") def process_result_value(self, value, dialect): return TypeID.from_uuid(value, self.prefix) + + # def coerce_compared_value(self, op, value): + # """ + # This method is called when SQLAlchemy needs to compare a column to a value. + # By returning self, we indicate that this type can handle TypeID instances. + # """ + # if isinstance(value, TypeID): + # return self + + # return super().coerce_compared_value(op, value) diff --git a/test/typeid_test.py b/test/typeid_test.py index f5652ac..2a23cc0 100644 --- a/test/typeid_test.py +++ b/test/typeid_test.py @@ -1,5 +1,9 @@ import pytest +from typeid import TypeID +from test.utils import temporary_tables + +from activemodel import BaseModel from activemodel.mixins import TypeIDMixin @@ -13,3 +17,44 @@ def test_enforces_unique_prefixes(): def test_no_empty_prefixes_test(): with pytest.raises(AssertionError): TypeIDMixin("") + + +TYPEID_PREFIX = "myid" + + +class ExampleWithId(BaseModel, TypeIDMixin(TYPEID_PREFIX), table=True): + pass + + +# the UIDs stored in the DB are NOT the same as the +def test_get_through_prefixed_uid(): + type_uid = TypeID(prefix=TYPEID_PREFIX) + + with temporary_tables(): + record = ExampleWithId.get(type_uid) + assert record is None + + +def test_get_through_prefixed_uid_as_str(): + type_uid = TypeID(prefix=TYPEID_PREFIX) + + with temporary_tables(): + record = ExampleWithId.get(str(type_uid)) + assert record is None + + +def test_get_through_plain_uid_as_str(): + type_uid = TypeID(prefix=TYPEID_PREFIX) + + with temporary_tables(): + # pass uid as string. Ex: '01942886-7afc-7129-8f57-db09137ed002' + record = ExampleWithId.get(str(type_uid.uuid)) + assert record is None + + +def test_get_through_plain_uid(): + type_uid = TypeID(prefix=TYPEID_PREFIX) + + with temporary_tables(): + record = ExampleWithId.get(type_uid.uuid) + assert record is None