Skip to content

Commit

Permalink
Merge pull request philipn#80 from rpkilby/prevent-transform-recursion
Browse files Browse the repository at this point in the history
Fix philipn#79, add transform recursion prevention
  • Loading branch information
Ryan P Kilby committed Apr 4, 2016
2 parents 2b79a7c + 32276c4 commit 885b3a5
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Unreleased
----------

* Fixes #79, enabling compatibility with ``django.contrib.postgres``
* Adds basic infinite recursion prevention for chainable transforms

v0.8.0
------

Expand Down
46 changes: 42 additions & 4 deletions rest_framework_filters/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

from collections import OrderedDict

import django
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Expression
from django.db.models.lookups import Transform
from django.utils import six

Expand All @@ -13,11 +15,47 @@ def lookups_for_field(model_field):
lookups = []

for expr, lookup in six.iteritems(class_lookups(model_field)):
if issubclass(lookup, Transform) and django.VERSION >= (1, 9):
transform = lookup(Expression(model_field))
lookups += [
LOOKUP_SEP.join([expr, sub_expr]) for sub_expr
in lookups_for_transform(transform)
]

else:
lookups.append(expr)

return lookups


def lookups_for_transform(transform):
"""
Generates a list of subsequent lookup expressions for a transform.
Note:
Infinite transform recursion is only prevented when the subsequent and
passed in transforms are the same class. For example, the ``Unaccent``
transform from ``django.contrib.postgres``.
There is no cycle detection across multiple transforms. For example,
``a__b__a__b`` would continue to recurse. However, this is not currently
a problem (no builtin transforms exhibit this behavior).
"""
lookups = []

for expr, lookup in six.iteritems(class_lookups(transform.output_field)):
if issubclass(lookup, Transform):

# type match indicates recursion.
if type(transform) == lookup:
continue

sub_transform = lookup(transform)
lookups += [
LOOKUP_SEP.join([expr, transform]) for transform
in lookups_for_field(lookup(model_field).output_field)
LOOKUP_SEP.join([expr, sub_expr]) for sub_expr
in lookups_for_transform(sub_transform)
]

else:
lookups.append(expr)

Expand All @@ -28,12 +66,12 @@ def class_lookups(model_field):
"""
Get a compiled set of class_lookups for a model field.
"""
field_class = model_field.__class__
field_class = type(model_field)
class_lookups = OrderedDict()

# traverse MRO in reverse, as this puts standard
# lookups before subclass transforms/lookups
for cls in field_class.mro()[::-1]:
for cls in reversed(field_class.mro()):
if hasattr(cls, 'class_lookups'):
class_lookups.update(getattr(cls, 'class_lookups'))

Expand Down
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def test_transformed_field(self):
self.assertIn('date__year__exact', lookups)


@unittest.skipIf(django.VERSION < (1, 9), "version does not support transformed lookup expressions")
class LookupsForTransformTests(TestCase):
def test_recursion_prevention(self):
model_field = Person._meta.get_field('name')
lookups = utils.lookups_for_field(model_field)

self.assertIn('unaccent__exact', lookups)
self.assertNotIn('unaccent__unaccent__exact', lookups)


class ClassLookupsTests(TestCase):
def test_standard_field(self):
model_field = Person._meta.get_field('name')
Expand Down
2 changes: 2 additions & 0 deletions tests/testapp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

default_app_config = 'tests.testapp.apps.TestappConfig'
13 changes: 13 additions & 0 deletions tests/testapp/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

from django.apps import AppConfig
from django.db.models import CharField, TextField

from .lookups import Unaccent


class TestappConfig(AppConfig):
name = 'tests.testapp'

def ready(self):
CharField.register_lookup(Unaccent)
TextField.register_lookup(Unaccent)
9 changes: 9 additions & 0 deletions tests/testapp/lookups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from django.db.models import Transform


# This is a copy of the `Unaccent` transform from `django.contrib.postgres`.
# This is necessary as the postgres app requires psycopg2 to be installed.
class Unaccent(Transform):
bilateral = True
lookup_name = 'unaccent'
function = 'UNACCENT'

0 comments on commit 885b3a5

Please sign in to comment.