diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f900d2f..f834a90 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,9 @@ Unreleased ---------- +* Fixes #79, enabling compatibility with ``django.contrib.postgres`` +* Adds basic infinite recursion prevention for chainable transforms + v0.8.0 ------ diff --git a/rest_framework_filters/utils.py b/rest_framework_filters/utils.py index 627e5b9..3052154 100644 --- a/rest_framework_filters/utils.py +++ b/rest_framework_filters/utils.py @@ -1,7 +1,9 @@ from collections import OrderedDict +import django from django.db.models.constants import LOOKUP_SEP +from django.db.models.expressions import Expression from django.db.models.lookups import Transform from django.utils import six @@ -13,11 +15,47 @@ def lookups_for_field(model_field): lookups = [] for expr, lookup in six.iteritems(class_lookups(model_field)): + if issubclass(lookup, Transform) and django.VERSION >= (1, 9): + transform = lookup(Expression(model_field)) + lookups += [ + LOOKUP_SEP.join([expr, sub_expr]) for sub_expr + in lookups_for_transform(transform) + ] + + else: + lookups.append(expr) + + return lookups + + +def lookups_for_transform(transform): + """ + Generates a list of subsequent lookup expressions for a transform. + + Note: + Infinite transform recursion is only prevented when the subsequent and + passed in transforms are the same class. For example, the ``Unaccent`` + transform from ``django.contrib.postgres``. + There is no cycle detection across multiple transforms. For example, + ``a__b__a__b`` would continue to recurse. However, this is not currently + a problem (no builtin transforms exhibit this behavior). + + """ + lookups = [] + + for expr, lookup in six.iteritems(class_lookups(transform.output_field)): if issubclass(lookup, Transform): + + # type match indicates recursion. + if type(transform) == lookup: + continue + + sub_transform = lookup(transform) lookups += [ - LOOKUP_SEP.join([expr, transform]) for transform - in lookups_for_field(lookup(model_field).output_field) + LOOKUP_SEP.join([expr, sub_expr]) for sub_expr + in lookups_for_transform(sub_transform) ] + else: lookups.append(expr) @@ -28,12 +66,12 @@ def class_lookups(model_field): """ Get a compiled set of class_lookups for a model field. """ - field_class = model_field.__class__ + field_class = type(model_field) class_lookups = OrderedDict() # traverse MRO in reverse, as this puts standard # lookups before subclass transforms/lookups - for cls in field_class.mro()[::-1]: + for cls in reversed(field_class.mro()): if hasattr(cls, 'class_lookups'): class_lookups.update(getattr(cls, 'class_lookups')) diff --git a/tests/test_utils.py b/tests/test_utils.py index ab14744..50b78db 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -30,6 +30,16 @@ def test_transformed_field(self): self.assertIn('date__year__exact', lookups) +@unittest.skipIf(django.VERSION < (1, 9), "version does not support transformed lookup expressions") +class LookupsForTransformTests(TestCase): + def test_recursion_prevention(self): + model_field = Person._meta.get_field('name') + lookups = utils.lookups_for_field(model_field) + + self.assertIn('unaccent__exact', lookups) + self.assertNotIn('unaccent__unaccent__exact', lookups) + + class ClassLookupsTests(TestCase): def test_standard_field(self): model_field = Person._meta.get_field('name') diff --git a/tests/testapp/__init__.py b/tests/testapp/__init__.py index e69de29..d6c0860 100644 --- a/tests/testapp/__init__.py +++ b/tests/testapp/__init__.py @@ -0,0 +1,2 @@ + +default_app_config = 'tests.testapp.apps.TestappConfig' diff --git a/tests/testapp/apps.py b/tests/testapp/apps.py new file mode 100644 index 0000000..eafad9f --- /dev/null +++ b/tests/testapp/apps.py @@ -0,0 +1,13 @@ + +from django.apps import AppConfig +from django.db.models import CharField, TextField + +from .lookups import Unaccent + + +class TestappConfig(AppConfig): + name = 'tests.testapp' + + def ready(self): + CharField.register_lookup(Unaccent) + TextField.register_lookup(Unaccent) diff --git a/tests/testapp/lookups.py b/tests/testapp/lookups.py new file mode 100644 index 0000000..235c17c --- /dev/null +++ b/tests/testapp/lookups.py @@ -0,0 +1,9 @@ +from django.db.models import Transform + + +# This is a copy of the `Unaccent` transform from `django.contrib.postgres`. +# This is necessary as the postgres app requires psycopg2 to be installed. +class Unaccent(Transform): + bilateral = True + lookup_name = 'unaccent' + function = 'UNACCENT'