Skip to content

Commit

Permalink
add overlap lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Jan 7, 2025
1 parent 1e393c4 commit 039af42
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 0 deletions.
3 changes: 3 additions & 0 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# contains with Exists() doesn't work:
# https://github.com/mongodb-labs/django-mongodb/issues/204
"model_fields_.test_arrayfield.QueryingTests.test_contains_subquery",
# overlap with values() returns no results:
# https://github.com/mongodb-labs/django-mongodb/issues/209
"model_fields_.test_arrayfield.QueryingTests.test_overlap_values",
# icontains doesn't work on ArrayField:
# Unsupported conversion from array to string in $convert
"model_fields_.test_arrayfield.QueryingTests.test_icontains",
Expand Down
12 changes: 12 additions & 0 deletions django_mongodb/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,18 @@ class ArrayExact(ArrayRHSMixin, Exact):
pass


@ArrayField.register_lookup
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
lookup_name = "overlap"

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
return {
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
}


@ArrayField.register_lookup
class ArrayLenTransform(Transform):
lookup_name = "len"
Expand Down
20 changes: 20 additions & 0 deletions docs/source/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,26 @@ data. It uses the ``$setIntersection`` operator. For example:
>>> Post.objects.filter(tags__contains=["django", "thoughts"])
<QuerySet [<Post: First post>]>
.. fieldlookup:: arrayfield.overlap

``overlap``
~~~~~~~~~~~

Returns objects where the data shares any results with the values passed. It
uses the ``$setIntersection`` operator. For example:

.. code-block:: pycon
>>> Post.objects.create(name="First post", tags=["thoughts", "django"])
>>> Post.objects.create(name="Second post", tags=["thoughts", "tutorial"])
>>> Post.objects.create(name="Third post", tags=["tutorial", "django"])
>>> Post.objects.filter(tags__overlap=["thoughts"])
<QuerySet [<Post: First post>, <Post: Second post>]>
>>> Post.objects.filter(tags__overlap=["thoughts", "tutorial"])
<QuerySet [<Post: First post>, <Post: Second post>, <Post: Third post>]>
.. fieldlookup:: arrayfield.len

``len``
Expand Down
39 changes: 39 additions & 0 deletions tests/model_fields_/test_arrayfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from django.core.management import call_command
from django.db import IntegrityError, connection, models
from django.db.models.expressions import Exists, OuterRef, Value
from django.db.models.functions import Upper
from django.test import (
SimpleTestCase,
TestCase,
Expand Down Expand Up @@ -369,6 +370,38 @@ def test_icontains(self):
def test_contains_charfield(self):
self.assertSequenceEqual(CharArrayModel.objects.filter(field__contains=["text"]), [])

def test_overlap_charfield(self):
self.assertSequenceEqual(CharArrayModel.objects.filter(field__overlap=["text"]), [])

def test_overlap_charfield_including_expression(self):
obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"])
obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"])
CharArrayModel.objects.create(field=["lower text", "text"])
self.assertSequenceEqual(
CharArrayModel.objects.filter(
field__overlap=[
Upper(Value("text")),
"other",
]
),
[obj_1, obj_2],
)

def test_overlap_values(self):
qs = NullableIntegerArrayModel.objects.filter(order__lt=3)
self.assertCountEqual(
NullableIntegerArrayModel.objects.filter(
field__overlap=qs.values_list("field"),
),
self.objs[:3],
)
self.assertCountEqual(
NullableIntegerArrayModel.objects.filter(
field__overlap=qs.values("field"),
),
self.objs[:3],
)

def test_index(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3]
Expand All @@ -389,6 +422,12 @@ def test_index_used_on_nested_data(self):
NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance]
)

def test_overlap(self):
self.assertSequenceEqual(
NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]),
self.objs[0:3],
)

def test_index_annotation(self):
qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1"))
self.assertCountEqual(
Expand Down

0 comments on commit 039af42

Please sign in to comment.