Skip to content

Commit

Permalink
prevent endless loop in extensions when augmenting schema #426
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Jun 15, 2021
1 parent 475e411 commit a33d55d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
7 changes: 4 additions & 3 deletions drf_spectacular/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,17 @@ def get_name(self) -> Optional[str]:

def map_serializer(self, auto_schema: 'AutoSchema', direction):
""" override for customized serializer mapping """
return auto_schema._map_basic_serializer(self.target_class, direction)
return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True)


class OpenApiSerializerFieldExtension(OpenApiGeneratorExtension['OpenApiSerializerFieldExtension']):
"""
Extension for replacing an insufficient or specifying an unknown SerializerField schema.
To augment the default schema, you can get what `drf-spectacular` would generate with
``auto_schema._map_serializer_field(self.target, direction)``. Beware that this may
still emit warnings, in which case manual construction is advisable.
``auto_schema._map_serializer_field(self.target, direction, bypass_extensions=True)``.
and edit the returned schema at your discretion. Beware that this may still emit
warnings, in which case manual construction is advisable.
``map_serializer_field()`` is expected to return a valid `OpenAPI schema object
<https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#schema-object>`_.
Expand Down
8 changes: 4 additions & 4 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def _map_model_field(self, model_field, direction):
)
return build_basic_type(OpenApiTypes.STR)

def _map_serializer_field(self, field, direction):
def _map_serializer_field(self, field, direction, bypass_extensions=False):
meta = self._get_serializer_field_meta(field)

if has_override(field, 'field'):
Expand Down Expand Up @@ -476,7 +476,7 @@ def _map_serializer_field(self, field, direction):
return append_meta(schema, meta)

serializer_field_extension = OpenApiSerializerFieldExtension.get_match(field)
if serializer_field_extension:
if serializer_field_extension and not bypass_extensions:
schema = serializer_field_extension.map_serializer_field(self, direction)
if serializer_field_extension.get_name():
component = ResolvedComponent(
Expand Down Expand Up @@ -695,11 +695,11 @@ def _map_min_max(self, field, content):
if field.min_value:
content['minimum'] = field.min_value

def _map_serializer(self, serializer, direction):
def _map_serializer(self, serializer, direction, bypass_extensions=False):
serializer = force_instance(serializer)
serializer_extension = OpenApiSerializerExtension.get_match(serializer)

if serializer_extension:
if serializer_extension and not bypass_extensions:
schema = serializer_extension.map_serializer(self, direction)
else:
schema = self._map_basic_serializer(serializer, direction)
Expand Down

0 comments on commit a33d55d

Please sign in to comment.