-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add ArrayField #197
Open
timgraham
wants to merge
2
commits into
mongodb-labs:main
Choose a base branch
from
timgraham:arrayfield
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
add ArrayField #197
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,344 @@ | ||
import json | ||
|
||
from django.contrib.postgres.validators import ArrayMaxLengthValidator | ||
from django.core import checks, exceptions | ||
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value | ||
from django.db.models.fields.mixins import CheckFieldDefaultMixin | ||
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup | ||
from django.utils.translation import gettext_lazy as _ | ||
|
||
from ..forms import SimpleArrayField | ||
from ..query_utils import process_lhs, process_rhs | ||
from ..utils import prefix_validation_error | ||
|
||
__all__ = ["ArrayField"] | ||
|
||
|
||
class AttributeSetter: | ||
def __init__(self, name, value): | ||
setattr(self, name, value) | ||
|
||
|
||
class ArrayField(CheckFieldDefaultMixin, Field): | ||
empty_strings_allowed = False | ||
default_error_messages = { | ||
"item_invalid": _("Item %(nth)s in the array did not validate:"), | ||
"nested_array_mismatch": _("Nested arrays must have the same length."), | ||
} | ||
_default_hint = ("list", "[]") | ||
|
||
def __init__(self, base_field, size=None, **kwargs): | ||
self.base_field = base_field | ||
self.size = size | ||
if self.size: | ||
self.default_validators = [ | ||
*self.default_validators, | ||
ArrayMaxLengthValidator(self.size), | ||
] | ||
# For performance, only add a from_db_value() method if the base field | ||
# implements it. | ||
if hasattr(self.base_field, "from_db_value"): | ||
self.from_db_value = self._from_db_value | ||
super().__init__(**kwargs) | ||
|
||
@property | ||
def model(self): | ||
try: | ||
return self.__dict__["model"] | ||
except KeyError: | ||
raise AttributeError( | ||
"'%s' object has no attribute 'model'" % self.__class__.__name__ | ||
) from None | ||
|
||
@model.setter | ||
def model(self, model): | ||
self.__dict__["model"] = model | ||
self.base_field.model = model | ||
|
||
@classmethod | ||
def _choices_is_value(cls, value): | ||
return isinstance(value, list | tuple) or super()._choices_is_value(value) | ||
|
||
def check(self, **kwargs): | ||
errors = super().check(**kwargs) | ||
if self.base_field.remote_field: | ||
errors.append( | ||
checks.Error( | ||
"Base field for array cannot be a related field.", | ||
obj=self, | ||
id="django_mongodb_backend.array.E002", | ||
) | ||
) | ||
else: | ||
base_checks = self.base_field.check() | ||
if base_checks: | ||
error_messages = "\n ".join( | ||
f"{base_check.msg} ({base_check.id})" | ||
for base_check in base_checks | ||
if isinstance(base_check, checks.Error) | ||
) | ||
if error_messages: | ||
errors.append( | ||
checks.Error( | ||
f"Base field for array has errors:\n {error_messages}", | ||
obj=self, | ||
id="django_mongodb_backend.array.E001", | ||
) | ||
) | ||
warning_messages = "\n ".join( | ||
f"{base_check.msg} ({base_check.id})" | ||
for base_check in base_checks | ||
if isinstance(base_check, checks.Warning) | ||
) | ||
if warning_messages: | ||
errors.append( | ||
checks.Warning( | ||
f"Base field for array has warnings:\n {warning_messages}", | ||
obj=self, | ||
id="django_mongodb_backend.array.W004", | ||
) | ||
) | ||
return errors | ||
|
||
def set_attributes_from_name(self, name): | ||
super().set_attributes_from_name(name) | ||
self.base_field.set_attributes_from_name(name) | ||
|
||
@property | ||
def description(self): | ||
return f"Array of {self.base_field.description}" | ||
|
||
def db_type(self, connection): | ||
return "array" | ||
|
||
def get_db_prep_value(self, value, connection, prepared=False): | ||
if isinstance(value, list | tuple): | ||
# Workaround for https://code.djangoproject.com/ticket/35982 | ||
# (fixed in Django 5.2). | ||
if isinstance(self.base_field, DecimalField): | ||
return [self.base_field.get_db_prep_save(i, connection) for i in value] | ||
return [self.base_field.get_db_prep_value(i, connection, prepared=False) for i in value] | ||
return value | ||
|
||
def deconstruct(self): | ||
name, path, args, kwargs = super().deconstruct() | ||
if path == "django_mongodb_backend.fields.array.ArrayField": | ||
path = "django_mongodb_backend.fields.ArrayField" | ||
kwargs.update( | ||
{ | ||
"base_field": self.base_field.clone(), | ||
"size": self.size, | ||
} | ||
) | ||
return name, path, args, kwargs | ||
|
||
def to_python(self, value): | ||
if isinstance(value, str): | ||
# Assume value is being deserialized. | ||
vals = json.loads(value) | ||
value = [self.base_field.to_python(val) for val in vals] | ||
return value | ||
|
||
def _from_db_value(self, value, expression, connection): | ||
if value is None: | ||
return value | ||
return [self.base_field.from_db_value(item, expression, connection) for item in value] | ||
|
||
def value_to_string(self, obj): | ||
values = [] | ||
vals = self.value_from_object(obj) | ||
base_field = self.base_field | ||
|
||
for val in vals: | ||
if val is None: | ||
values.append(None) | ||
else: | ||
obj = AttributeSetter(base_field.attname, val) | ||
values.append(base_field.value_to_string(obj)) | ||
return json.dumps(values) | ||
|
||
def get_transform(self, name): | ||
transform = super().get_transform(name) | ||
if transform: | ||
return transform | ||
if "_" not in name: | ||
try: | ||
index = int(name) | ||
except ValueError: | ||
pass | ||
else: | ||
return IndexTransformFactory(index, self.base_field) | ||
try: | ||
start, end = name.split("_") | ||
start = int(start) | ||
end = int(end) | ||
except ValueError: | ||
pass | ||
else: | ||
return SliceTransformFactory(start, end) | ||
|
||
def validate(self, value, model_instance): | ||
super().validate(value, model_instance) | ||
for index, part in enumerate(value): | ||
try: | ||
self.base_field.validate(part, model_instance) | ||
except exceptions.ValidationError as error: | ||
raise prefix_validation_error( | ||
error, | ||
prefix=self.error_messages["item_invalid"], | ||
code="item_invalid", | ||
params={"nth": index + 1}, | ||
) from None | ||
if isinstance(self.base_field, ArrayField) and len({len(i) for i in value}) > 1: | ||
raise exceptions.ValidationError( | ||
self.error_messages["nested_array_mismatch"], | ||
code="nested_array_mismatch", | ||
) | ||
|
||
def run_validators(self, value): | ||
super().run_validators(value) | ||
for index, part in enumerate(value): | ||
try: | ||
self.base_field.run_validators(part) | ||
except exceptions.ValidationError as error: | ||
raise prefix_validation_error( | ||
error, | ||
prefix=self.error_messages["item_invalid"], | ||
code="item_invalid", | ||
params={"nth": index + 1}, | ||
) from None | ||
|
||
def formfield(self, **kwargs): | ||
return super().formfield( | ||
**{ | ||
"form_class": SimpleArrayField, | ||
"base_field": self.base_field.formfield(), | ||
"max_length": self.size, | ||
**kwargs, | ||
} | ||
) | ||
|
||
|
||
class Array(Func): | ||
def as_mql(self, compiler, connection): | ||
return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()] | ||
|
||
|
||
class ArrayRHSMixin: | ||
def __init__(self, lhs, rhs): | ||
if isinstance(rhs, tuple | list): | ||
expressions = [] | ||
for value in rhs: | ||
if not hasattr(value, "resolve_expression"): | ||
field = lhs.output_field | ||
value = Value(field.base_field.get_prep_value(value)) | ||
expressions.append(value) | ||
rhs = Array(*expressions) | ||
super().__init__(lhs, rhs) | ||
|
||
|
||
@ArrayField.register_lookup | ||
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): | ||
lookup_name = "contains" | ||
|
||
def as_mql(self, compiler, connection): | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
value = process_rhs(self, compiler, connection) | ||
return {"$and": [{"$ne": [lhs_mql, None]}, {"$setIsSubset": [value, lhs_mql]}]} | ||
|
||
|
||
@ArrayField.register_lookup | ||
class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): | ||
lookup_name = "contained_by" | ||
|
||
def as_mql(self, compiler, connection): | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
value = process_rhs(self, compiler, connection) | ||
return {"$and": [{"$ne": [lhs_mql, None]}, {"$setIsSubset": [lhs_mql, value]}]} | ||
|
||
|
||
@ArrayField.register_lookup | ||
class ArrayExact(ArrayRHSMixin, Exact): | ||
pass | ||
|
||
|
||
@ArrayField.register_lookup | ||
class ArrayOverlap(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): | ||
lookup_name = "overlap" | ||
|
||
def as_mql(self, compiler, connection): | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
value = process_rhs(self, compiler, connection) | ||
return { | ||
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}] | ||
} | ||
|
||
|
||
@ArrayField.register_lookup | ||
class ArrayLenTransform(Transform): | ||
lookup_name = "len" | ||
output_field = IntegerField() | ||
|
||
def as_mql(self, compiler, connection): | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$size": lhs_mql}}} | ||
|
||
|
||
@ArrayField.register_lookup | ||
class ArrayInLookup(In): | ||
def get_prep_lookup(self): | ||
values = super().get_prep_lookup() | ||
if hasattr(values, "resolve_expression"): | ||
return values | ||
# process_rhs() expects hashable values, so convert lists to tuples. | ||
prepared_values = [] | ||
for value in values: | ||
if hasattr(value, "resolve_expression"): | ||
prepared_values.append(value) | ||
else: | ||
prepared_values.append(tuple(value)) | ||
return prepared_values | ||
|
||
|
||
class IndexTransform(Transform): | ||
def __init__(self, index, base_field, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.index = index | ||
self.base_field = base_field | ||
|
||
def as_mql(self, compiler, connection): | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
return {"$arrayElemAt": [lhs_mql, self.index]} | ||
|
||
@property | ||
def output_field(self): | ||
return self.base_field | ||
|
||
|
||
class IndexTransformFactory: | ||
def __init__(self, index, base_field): | ||
self.index = index | ||
self.base_field = base_field | ||
|
||
def __call__(self, *args, **kwargs): | ||
return IndexTransform(self.index, self.base_field, *args, **kwargs) | ||
|
||
|
||
class SliceTransform(Transform): | ||
def __init__(self, start, end, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.start = start | ||
self.end = end | ||
|
||
def as_mql(self, compiler, connection): | ||
lhs_mql = process_lhs(self, compiler, connection) | ||
return {"$slice": [lhs_mql, self.start, self.end]} | ||
|
||
|
||
class SliceTransformFactory: | ||
def __init__(self, start, end): | ||
self.start = start | ||
self.end = end | ||
|
||
def __call__(self, *args, **kwargs): | ||
return SliceTransform(self.start, self.end, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
from .fields import ObjectIdField | ||
from .fields import ObjectIdField, SimpleArrayField, SplitArrayField, SplitArrayWidget | ||
|
||
__all__ = ["ObjectIdField"] | ||
__all__ = [ | ||
"SimpleArrayField", | ||
"SplitArrayField", | ||
"SplitArrayWidget", | ||
"ObjectIdField", | ||
] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Debugin the test
model_fields_.test_arrayfield.QueryingTests.test_contains_subquery
I find that the value must be checked for nullability also, sometimes subqueries will pass None as a value.return {"$and": [{"$ne": [lhs_mql, None]}, {"$ne": [value, None]}, {"$setIsSubset": [value, lhs_mql]}]}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, this PR could be merge as it is. that bug could be fixed latter. Just decide to fix this future issue now, or with the corresponding issue