Skip to content

Commit

Permalink
Handle 'lookup_field' containing relationships for path parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
spookylukey committed Sep 9, 2021
1 parent 1407059 commit 378e79e
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 8 deletions.
16 changes: 8 additions & 8 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from drf_spectacular.plumbing import (
ComponentRegistry, ResolvedComponent, UnableToProceedError, append_meta, build_array_type,
build_basic_type, build_choice_field, build_examples_list, build_media_type_object,
build_object_type, build_parameter_type, error, follow_field_source, force_instance, get_doc,
get_type_hints, get_view_model, is_basic_type, is_field, is_list_serializer,
is_patched_serializer, is_serializer, is_trivial_string_variation, resolve_regex_path_parameter,
resolve_type_hint, safe_ref, warn,
build_object_type, build_parameter_type, error, follow_field_source, follow_model_field_lookup,
force_instance, get_doc, get_type_hints, get_view_model, is_basic_type, is_field,
is_list_serializer, is_patched_serializer, is_serializer, is_trivial_string_variation,
resolve_regex_path_parameter, resolve_type_hint, safe_ref, warn,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes, build_generic_type
Expand Down Expand Up @@ -342,22 +342,22 @@ def _resolve_path_parameters(self, variables):
schema = resolved_parameter['schema']
elif get_view_model(self.view) is None:
warn(
f'could not derive type of path parameter "{variable}" because because it '
f'could not derive type of path parameter "{variable}" because it '
f'is untyped and obtaining queryset from the viewset failed. '
f'Consider adding a type to the path (e.g. <int:{variable}>) or annotating '
f'the parameter type with @extend_schema. Defaulting to "string".'
)
else:
try:
model = get_view_model(self.view)
model_field = model._meta.get_field(variable)
model_field = follow_model_field_lookup(model, variable)
schema = self._map_model_field(model_field, direction=None)
if 'description' not in schema and model_field.primary_key:
description = get_pk_description(model, model_field)
except django_exceptions.FieldDoesNotExist:
except django_exceptions.FieldError:
warn(
f'could not derive type of path parameter "{variable}" because '
f'model "{model}" did contain no such field. Consider annotating '
f'model "{model.__module__}.{model.__name__}" contained no such field. Consider annotating '
f'parameter with @extend_schema. Defaulting to "string".'
)

Expand Down
13 changes: 13 additions & 0 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
ForwardManyToOneDescriptor, ManyToManyDescriptor, ReverseManyToOneDescriptor,
ReverseOneToOneDescriptor,
)
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.db.models.sql.query import Query
from django.urls.resolvers import ( # type: ignore
_PATH_PARAMETER_COMPONENT_RE, RegexPattern, Resolver404, RoutePattern, URLPattern, URLResolver,
get_resolver,
Expand Down Expand Up @@ -460,6 +462,17 @@ def dummy_property(obj) -> str:
return dummy_property


def follow_model_field_lookup(model, lookup):
"""
Follow a model lookup `foreignkey__foreignkey__field` in the same
way that Django QuerySet.filter() does, returning the final models.Field.
"""
query = Query(model)
lookup_splitted = lookup.split(LOOKUP_SEP)
_, field, _, _ = query.names_to_path(lookup_splitted, query.get_meta())
return field


def alpha_operation_sorter(endpoint):
""" sort endpoints first alphanumerically by path, then by method order """
path, path_regex, method, callback = endpoint
Expand Down
135 changes: 135 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,141 @@ class XViewset(viewsets.ModelViewSet):
assert '/x/{related_field}/{id}/' in schema['paths']


def test_path_parameter_with_relationships(no_warnings):
class PathParamParent(models.Model):
pass

class PathParamChild(models.Model):
parent = models.ForeignKey(PathParamParent, on_delete=models.CASCADE)

class PathParamGrandChild(models.Model):
parent = models.ForeignKey(PathParamChild, on_delete=models.CASCADE)

class PathParamChildSerializer(serializers.ModelSerializer):
class Meta:
fields = '__all__'
model = PathParamChild

class XViewset1(viewsets.ModelViewSet):
serializer_class = PathParamChildSerializer
queryset = PathParamChild.objects.none()
lookup_field = 'id'

class XViewset2(viewsets.ModelViewSet):
serializer_class = PathParamChildSerializer
queryset = PathParamChild.objects.none()
lookup_field = 'parent'

class XViewset3(viewsets.ModelViewSet):
serializer_class = PathParamChildSerializer
queryset = PathParamChild.objects.none()
lookup_field = 'parent__id' # Functionally the same as above

class PathParamGrandChildSerializer(serializers.ModelSerializer):
class Meta:
fields = '__all__'
model = PathParamGrandChild

class XViewset4(viewsets.ModelViewSet):
serializer_class = PathParamGrandChildSerializer
queryset = PathParamGrandChild.objects.none()
lookup_field = 'parent__parent'

class XViewset5(viewsets.ModelViewSet):
serializer_class = PathParamGrandChildSerializer
queryset = PathParamGrandChild.objects.none()
lookup_field = 'parent__parent__id'

router = routers.SimpleRouter()
router.register('child_by_id', XViewset1)
router.register('child_by_parent_id', XViewset2)
router.register('child_by_parent_id_alt', XViewset3)
router.register('grand_child_by_grand_parent_id', XViewset4)
router.register('grand_child_by_grand_parent_id_alt', XViewset5)

schema = generate_schema(None, patterns=router.urls)

# Basic cases:
assert schema['paths']['/child_by_id/{id}/']['get']['parameters'][0] == {
'description': 'A unique integer value identifying this path param child.',
'in': 'path',
'name': 'id',
'schema': {'type': 'integer'},
'required': True
}
assert schema['paths']['/child_by_parent_id/{parent}/']['get']['parameters'][0] == {
'in': 'path',
'name': 'parent',
'schema': {'type': 'integer'},
'required': True
}

# Can we traverse relationships?
assert schema['paths']['/grand_child_by_grand_parent_id/{parent__parent}/']['get']['parameters'][0] == {
'in': 'path',
'name': 'parent__parent',
'schema': {'type': 'integer'},
'required': True
}

# Explicit `__id` handling:
assert schema['paths']['/grand_child_by_grand_parent_id_alt/{parent__parent__id}/']['get']['parameters'][0] == {
'description': 'A unique integer value identifying this path param grand child.',
'in': 'path',
'name': 'parent__parent__id',
'schema': {'type': 'integer'},
'required': True
}
assert schema['paths']['/child_by_parent_id_alt/{parent__id}/']['get']['parameters'][0] == {
'description': 'A unique integer value identifying this path param child.',
'in': 'path',
'name': 'parent__id',
'schema': {'type': 'integer'},
'required': True
}


def test_path_parameter_with_lookups(no_warnings):
class JournalEntry(models.Model):
recorded_at = models.DateTimeField()

class JournalEntrySerializer(serializers.ModelSerializer):
class Meta:
fields = '__all__'
model = JournalEntry

class JournalEntryViewset(viewsets.ModelViewSet):
serializer_class = JournalEntrySerializer
queryset = JournalEntry.objects.none()
lookup_field = 'recorded_at__date'

class JournalEntryAltViewset(viewsets.ModelViewSet):
serializer_class = JournalEntrySerializer
queryset = JournalEntry.objects.none()
lookup_field = 'recorded_at__date'
lookup_url_kwarg = 'recorded_at'

router = routers.SimpleRouter()
router.register('journal', JournalEntryViewset)
router.register('journal_alt', JournalEntryAltViewset)

schema = generate_schema(None, patterns=router.urls)

assert schema['paths']['/journal/{recorded_at__date}/']['get']['parameters'][0] == {
'in': 'path',
'name': 'recorded_at__date',
'required': True,
'schema': {'format': 'date-time', 'type': 'string'},
}

assert schema['paths']['/journal_alt/{recorded_at}/']['get']['parameters'][0] == {
'in': 'path',
'name': 'recorded_at',
'required': True,
'schema': {'format': 'date-time', 'type': 'string'},
}


@pytest.mark.contrib('psycopg2')
def test_multiple_choice_enum(no_warnings):
from django.contrib.postgres.fields import ArrayField
Expand Down
12 changes: 12 additions & 0 deletions tests/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,15 @@ class XViewSet(viewsets.ModelViewSet):
generate_schema('x', XViewSet)
stderr = capsys.readouterr().err
assert 'Could not derive type for ReadOnlyField "field"' in stderr


def test_warning_missing_lookup_field_on_model_serializer(capsys):
class XViewSet(viewsets.ModelViewSet):
serializer_class = SimpleSerializer
queryset = SimpleModel.objects.all()
lookup_field = 'non_existent_field'

generate_schema('x', XViewSet)
stderr = capsys.readouterr().err
assert ('could not derive type of path parameter "non_existent_field" because model '
'"tests.models.SimpleModel" contained no such field.') in stderr

0 comments on commit 378e79e

Please sign in to comment.