-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
326 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
from importlib import import_module | ||
|
||
from django.db import IntegrityError, models | ||
from django.db.models.fields.related import lazy_related_operation | ||
|
||
|
||
class EmbeddedModelField(models.Field): | ||
"""Field that stores a model instance.""" | ||
|
||
def __init__(self, embedded_model=None, *args, **kwargs): | ||
""" | ||
`embedded_model` is the model class of the instance that will be | ||
stored. Like other relational fields, it may also be passed as a | ||
string. | ||
""" | ||
self.embedded_model = embedded_model | ||
super().__init__(*args, **kwargs) | ||
|
||
def deconstruct(self): | ||
name, path, args, kwargs = super().deconstruct() | ||
if path.startswith("django_mongodb.fields.embedded_model"): | ||
path = path.replace("django_mongodb.fields.embedded_model", "django_mongodb.fields") | ||
if self.embedded_model: | ||
kwargs["embedded_model"] = self.embedded_model | ||
return name, path, args, kwargs | ||
|
||
def get_internal_type(self): | ||
return "EmbeddedModelField" | ||
|
||
def _set_model(self, model): | ||
""" | ||
Resolve embedded model class once the field knows the model it belongs | ||
to. | ||
If the model argument passed to __init__() was a string, resolve that | ||
string to the corresponding model class, similar to relation fields. | ||
However, we need to know our own model to generate a valid key | ||
for the embedded model class lookup and EmbeddedModelFields are | ||
not contributed_to_class if used in iterable fields. Thus the | ||
collection field sets this field's "model" attribute in its | ||
contribute_to_class(). | ||
""" | ||
self._model = model | ||
if model is not None and isinstance(self.embedded_model, str): | ||
|
||
def _resolve_lookup(_, resolved_model): | ||
self.embedded_model = resolved_model | ||
|
||
lazy_related_operation(_resolve_lookup, model, self.embedded_model) | ||
|
||
model = property(lambda self: self._model, _set_model) | ||
|
||
def stored_model(self, column_values): | ||
""" | ||
Return the fixed embedded_model this field was initialized | ||
with (typed embedding) or tries to determine the model from | ||
_module / _model keys stored together with column_values | ||
(untyped embedding). | ||
Give precedence to the field's definition model, as silently using a | ||
differing serialized one could hide some data integrity problems. | ||
Note that a single untyped EmbeddedModelField may process | ||
instances of different models (especially when used as a type | ||
of a collection field). | ||
""" | ||
module = column_values.pop("_module", None) | ||
model = column_values.pop("_model", None) | ||
if self.embedded_model is not None: | ||
return self.embedded_model | ||
if module is not None: | ||
return getattr(import_module(module), model) | ||
raise IntegrityError( | ||
"Untyped EmbeddedModelField trying to load data without serialized model class info." | ||
) | ||
|
||
def from_db_value(self, value, expression, connection): | ||
return self.to_python(value) | ||
|
||
def to_python(self, value): | ||
""" | ||
Passes embedded model fields' values through embedded fields | ||
to_python methods and reinstiatates the embedded instance. | ||
We expect to receive a field.attname => value dict together | ||
with a model class from back-end database deconversion (which | ||
needs to know fields of the model beforehand). | ||
""" | ||
# Either the model class has already been determined during | ||
# deconverting values from the database or we've got a dict | ||
# from a deserializer that may contain model class info. | ||
if isinstance(value, tuple): | ||
embedded_model, attribute_values = value | ||
elif isinstance(value, dict): | ||
embedded_model = self.stored_model(value) | ||
attribute_values = value | ||
else: | ||
return value | ||
# Create the model instance. | ||
instance = embedded_model( | ||
**{ | ||
# Pass values through respective fields' to_python(), leaving | ||
# fields for which no value is specified uninitialized. | ||
field.attname: field.to_python(attribute_values[field.attname]) | ||
for field in embedded_model._meta.fields | ||
if field.attname in attribute_values | ||
} | ||
) | ||
instance._state.adding = False | ||
return instance | ||
|
||
def get_db_prep_save(self, embedded_instance, connection): | ||
""" | ||
Apply pre_save() and get_db_prep_save() of embedded instance | ||
fields and passes a field => value mapping down to database | ||
type conversions. | ||
The embedded instance will be saved as a column => value dict | ||
in the end (possibly augmented with info about instance's model | ||
for untyped embedding), but because we need to apply database | ||
type conversions on embedded instance fields' values and for | ||
these we need to know fields those values come from, we need to | ||
entrust the database layer with creating the dict. | ||
""" | ||
if embedded_instance is None: | ||
return None | ||
# The field's value should be an instance of the model given in | ||
# its declaration or at least of some model. | ||
embedded_model = self.embedded_model or models.Model | ||
if not isinstance(embedded_instance, embedded_model): | ||
raise TypeError( | ||
f"Expected instance of type {embedded_model!r}, not {type(embedded_instance)!r}." | ||
) | ||
# Apply pre_save() and get_db_prep_save() of embedded instance | ||
# fields, create the field => value mapping to be passed to | ||
# storage preprocessing. | ||
field_values = {} | ||
add = embedded_instance._state.adding | ||
for field in embedded_instance._meta.fields: | ||
value = field.get_db_prep_save( | ||
field.pre_save(embedded_instance, add), connection=connection | ||
) | ||
# Exclude unset primary keys (e.g. {'id': None}). | ||
if field.primary_key and value is None: | ||
continue | ||
field_values[field.attname] = value | ||
if self.embedded_model is None: | ||
# Untyped fields must store model info alongside values. | ||
field_values.update( | ||
( | ||
("_module", embedded_instance.__class__.__module__), | ||
("_model", embedded_instance.__class__.__name__), | ||
) | ||
) | ||
# This instance will exist in the database soon. | ||
# TODO.XXX: Ensure that this doesn't cause race conditions. | ||
embedded_instance._state.adding = False | ||
return field_values | ||
|
||
def validate(self, value, model_instance): | ||
super().validate(value, model_instance) | ||
if self.embedded_model is None: | ||
return | ||
for field in self.embedded_model._meta.fields: | ||
attname = field.attname | ||
field.validate(getattr(value, attname), model_instance) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from django.db import models | ||
|
||
from django_mongodb.fields import EmbeddedModelField | ||
|
||
|
||
class Target(models.Model): | ||
index = models.IntegerField() | ||
|
||
|
||
class DecimalModel(models.Model): | ||
decimal = models.DecimalField(max_digits=9, decimal_places=2) | ||
|
||
|
||
class DecimalKey(models.Model): | ||
decimal = models.DecimalField(max_digits=9, decimal_places=2, primary_key=True) | ||
|
||
|
||
class DecimalParent(models.Model): | ||
child = models.ForeignKey(DecimalKey, models.CASCADE) | ||
|
||
|
||
class EmbeddedModelFieldModel(models.Model): | ||
simple = EmbeddedModelField("EmbeddedModel", null=True, blank=True) | ||
untyped = EmbeddedModelField(null=True, blank=True) | ||
decimal_parent = EmbeddedModelField(DecimalParent, null=True, blank=True) | ||
|
||
|
||
class EmbeddedModel(models.Model): | ||
some_relation = models.ForeignKey(Target, models.CASCADE, null=True, blank=True) | ||
someint = models.IntegerField(db_column="custom_column") | ||
auto_now = models.DateTimeField(auto_now=True) | ||
auto_now_add = models.DateTimeField(auto_now_add=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import time | ||
from decimal import Decimal | ||
|
||
from django.core.exceptions import ValidationError | ||
from django.db import models | ||
from django.test import SimpleTestCase, TestCase | ||
|
||
from django_mongodb.fields import EmbeddedModelField | ||
|
||
from .models import ( | ||
DecimalKey, | ||
DecimalParent, | ||
EmbeddedModel, | ||
EmbeddedModelFieldModel, | ||
Target, | ||
) | ||
|
||
|
||
class MethodTests(SimpleTestCase): | ||
def test_deconstruct(self): | ||
field = EmbeddedModelField() | ||
name, path, args, kwargs = field.deconstruct() | ||
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField") | ||
self.assertEqual(args, []) | ||
self.assertEqual(kwargs, {}) | ||
|
||
def test_deconstruct_with_model(self): | ||
field = EmbeddedModelField("EmbeddedModel", null=True) | ||
name, path, args, kwargs = field.deconstruct() | ||
self.assertEqual(path, "django_mongodb.fields.EmbeddedModelField") | ||
self.assertEqual(args, []) | ||
self.assertEqual(kwargs, {"embedded_model": "EmbeddedModel", "null": True}) | ||
|
||
def test_validate(self): | ||
instance = EmbeddedModelFieldModel(simple=EmbeddedModel(someint=None)) | ||
# This isn't quite right because "someint" is the field that's non-null. | ||
msg = "{'simple': ['This field cannot be null.']}" | ||
with self.assertRaisesMessage(ValidationError, msg): | ||
instance.full_clean() | ||
|
||
|
||
class QueryingTests(TestCase): | ||
def assertEqualDatetime(self, d1, d2): | ||
"""Compares d1 and d2, ignoring microseconds.""" | ||
self.assertEqual(d1.replace(microsecond=0), d2.replace(microsecond=0)) | ||
|
||
def assertNotEqualDatetime(self, d1, d2): | ||
self.assertNotEqual(d1.replace(microsecond=0), d2.replace(microsecond=0)) | ||
|
||
def test_save_load(self): | ||
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5")) | ||
instance = EmbeddedModelFieldModel.objects.get() | ||
self.assertIsInstance(instance.simple, EmbeddedModel) | ||
# Make sure get_prep_value is called. | ||
self.assertEqual(instance.simple.someint, 5) | ||
# Primary keys should not be populated... | ||
self.assertEqual(instance.simple.id, None) | ||
# ... unless set explicitly. | ||
instance.simple.id = instance.id | ||
instance.save() | ||
instance = EmbeddedModelFieldModel.objects.get() | ||
self.assertEqual(instance.simple.id, instance.id) | ||
|
||
def test_save_load_untyped(self): | ||
EmbeddedModelFieldModel.objects.create(simple=EmbeddedModel(someint="5")) | ||
instance = EmbeddedModelFieldModel.objects.get() | ||
self.assertIsInstance(instance.simple, EmbeddedModel) | ||
# Make sure get_prep_value is called. | ||
self.assertEqual(instance.simple.someint, 5) | ||
# Primary keys should not be populated... | ||
self.assertEqual(instance.simple.id, None) | ||
# ... unless set explicitly. | ||
instance.simple.id = instance.id | ||
instance.save() | ||
instance = EmbeddedModelFieldModel.objects.get() | ||
self.assertEqual(instance.simple.id, instance.id) | ||
|
||
def _test_pre_save(self, instance, get_field): | ||
# Field.pre_save() is called on embedded model fields. | ||
|
||
instance.save() | ||
auto_now = get_field(instance).auto_now | ||
auto_now_add = get_field(instance).auto_now_add | ||
self.assertNotEqual(auto_now, None) | ||
self.assertNotEqual(auto_now_add, None) | ||
|
||
time.sleep(1) # FIXME | ||
instance.save() | ||
self.assertNotEqualDatetime(get_field(instance).auto_now, get_field(instance).auto_now_add) | ||
|
||
instance = EmbeddedModelFieldModel.objects.get() | ||
instance.save() | ||
# auto_now_add shouldn't have changed now, but auto_now should. | ||
self.assertEqualDatetime(get_field(instance).auto_now_add, auto_now_add) | ||
self.assertGreater(get_field(instance).auto_now, auto_now) | ||
|
||
def test_pre_save(self): | ||
obj = EmbeddedModelFieldModel(simple=EmbeddedModel()) | ||
self._test_pre_save(obj, lambda instance: instance.simple) | ||
|
||
def test_pre_save_untyped(self): | ||
obj = EmbeddedModelFieldModel(untyped=EmbeddedModel()) | ||
self._test_pre_save(obj, lambda instance: instance.untyped) | ||
|
||
def test_error_messages(self): | ||
for model_kwargs, expected in ( | ||
({"simple": 42}, EmbeddedModel), | ||
({"untyped": 42}, models.Model), | ||
): | ||
msg = "Expected instance of type %r" % expected | ||
with self.assertRaisesMessage(TypeError, msg): | ||
EmbeddedModelFieldModel(**model_kwargs).save() | ||
|
||
def test_foreign_key_in_embedded_object(self): | ||
simple = EmbeddedModel(some_relation=Target.objects.create(index=1)) | ||
obj = EmbeddedModelFieldModel.objects.create(simple=simple) | ||
simple = EmbeddedModelFieldModel.objects.get().simple | ||
self.assertNotIn("some_relation", simple.__dict__) | ||
self.assertIsInstance(simple.__dict__["some_relation_id"], type(obj.id)) | ||
self.assertIsInstance(simple.some_relation, Target) | ||
|
||
def test_embedded_field_with_foreign_conversion(self): | ||
decimal = DecimalKey.objects.create(decimal=Decimal("1.5")) | ||
decimal_parent = DecimalParent.objects.create(child=decimal) | ||
EmbeddedModelFieldModel.objects.create(decimal_parent=decimal_parent) |