Skip to content

Commit

Permalink
handle 1:1 and "manually" remove items from reverse FKs since set i…
Browse files Browse the repository at this point in the history
…s bugged
  • Loading branch information
claytondaley committed Feb 7, 2020
1 parent 12151c8 commit 164f1a9
Showing 1 changed file with 37 additions and 10 deletions.
47 changes: 37 additions & 10 deletions drf_writable_nested/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

from django.contrib.contenttypes.fields import GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import transaction
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction, router
from django.db.models import ProtectedError, FieldDoesNotExist, OneToOneRel
from django.db.models.fields.related import ForeignObjectRel, ManyToManyField
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty
from rest_framework.relations import ManyRelatedField
from rest_framework.serializers import BaseSerializer
from rest_framework.serializers import BaseSerializer, ListSerializer
from rest_framework.validators import UniqueValidator, UniqueTogetherValidator

# permit writable nested serializers
Expand Down Expand Up @@ -557,6 +558,8 @@ def save(self, **kwargs):

def _save_direct_relations(self, kwargs):
"""Save direct relations so FKs exist when committing the base instance"""
if self._validated_data is None and kwargs == {}:
return # delete-only
for field_name, field in self.fields.items():
if self.field_types[field_name] != self.TYPE_DIRECT:
continue
Expand All @@ -580,24 +583,47 @@ def _save_reverse_relations(self, instance, kwargs):
for field_name, field in self.fields.items():
if self.field_types[field_name] != self.TYPE_REVERSE:
continue
if self._validated_data is None and kwargs == {}:
return # delete-only
if self._validated_data.get(field.source, empty) == empty and kwargs.get(field_name, empty) == empty:
continue # nothing to save
# inject the instance into reverse relations so the <parent>_id ForeignKey field is valid when saved
related_field = self._get_model_field(field.source).field
print("{} populating reverse field {}".format(self.__class__.__name__, related_field.name))
model_field = self._get_model_field(field.source)
print("{} populating reverse field {}".format(self.__class__.__name__, model_field.field.name))
if isinstance(field, serializers.ListSerializer):
# reverse FK, inject the instance into reverse relations so the <parent>_id FK field is valid when saved
for obj in field._validated_data:
obj[related_field.name] = instance
obj[model_field.field.name] = instance
elif isinstance(field, serializers.ModelSerializer):
if field._validated_data is None:
field._validated_data = {} # delete situation, but need a place to put FK
field._validated_data[related_field.name] = instance
# 1:1
if self._validated_data[field.source] is None:
# indicates that we should delete 1:1 relation (if it exists)
try:
getattr(instance, field.source).delete()
continue
except ObjectDoesNotExist:
pass
else:
field._validated_data[model_field.field.name] = instance
else:
raise Exception("unexpected serializer type")
# no tests fail if we do not cache this value in _validated_data, but it's consistent with forward relations
# create/update (as appropriate) related objects
self._validated_data[field.source] = field.save(**kwargs.get(field_name, {}))
print("{}._validated_data[{}] to reverse {}".format(self.__class__.__name__, field_name, self._validated_data[field.source]))

# eliminate related objects that weren't in the request
if isinstance(field, ListSerializer):
# due to a bug in Django, calling `set` on a non-nullable reverse relation will only `add`
if model_field.field.null:
getattr(instance, field.source).set(self._validated_data[field.source])
else:
# models should be attached when saved so we only need to delete
obj_field = getattr(instance, field.source)
db = router.db_for_write(obj_field.model, instance=instance)
old_objs = set(obj_field.using(db).all())
for obj in old_objs:
if obj not in self._validated_data[field.source]:
obj.delete()


class FocalSaveMixin(FieldLookupMixin):
"""Provides a framework for extracting the values needed to get or create the focal object."""
Expand Down Expand Up @@ -696,6 +722,7 @@ def save(self, **kwargs):

new_values = []

# TODO: we don't actually delete absent reverse FKs
for item in self._validated_data:
# integrate save kwargs
self.child._validated_data = item
Expand Down

0 comments on commit 164f1a9

Please sign in to comment.