Skip to content

Commit

Permalink
add Enum support in type hints #492
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Sep 3, 2021
1 parent 0ed0b80 commit 2ee41ae
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 3 deletions.
5 changes: 5 additions & 0 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,11 @@ def resolve_type_hint(hint):
if all(type(args[0]) is type(choice) for choice in args):
schema.update(build_basic_type(type(args[0])))
return schema
elif inspect.isclass(hint) and issubclass(hint, Choices):
return {
'enum': [item.value for item in hint],
**build_basic_type([t for t in hint.__mro__ if is_basic_type(t)][0])
}
elif hasattr(typing, 'TypedDict') and isinstance(hint, typing._TypedDictMeta):
return build_object_type(
properties={
Expand Down
13 changes: 13 additions & 0 deletions tests/test_plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import datetime

import pytest
from django import __version__ as DJANGO_VERSION
from django.conf.urls import include
from django.db import models
from django.urls import re_path
Expand Down Expand Up @@ -135,6 +136,18 @@ class NamedTupleB(typing.NamedTuple):
)
]

if DJANGO_VERSION > '3':
from django.db.models.enums import TextChoices # only available in Django>3

class LanguageChoices(TextChoices):
EN = 'en'
DE = 'de'

TYPE_HINT_TEST_PARAMS.append((
LanguageChoices,
{'enum': ['en', 'de'], 'type': 'string'}
))

if sys.version_info >= (3, 7):
TYPE_HINT_TEST_PARAMS.append((
typing.Iterable[collections.namedtuple("NamedTupleA", "a, b")], # noqa
Expand Down
34 changes: 31 additions & 3 deletions tests/test_postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import typing
from enum import Enum
from unittest import mock

import pytest
from django import __version__ as DJANGO_VERSION
from rest_framework import generics, mixins, serializers, viewsets
from rest_framework.decorators import action
from rest_framework.views import APIView

try:
from django.db.models.enums import Choices
from django.db.models.enums import TextChoices
except ImportError:
Choices = object # type: ignore # django < 3.0 handling
TextChoices = object # type: ignore # django < 3.0 handling

from drf_spectacular.plumbing import list_hash, load_enum_name_overrides
from drf_spectacular.utils import OpenApiParameter, extend_schema
Expand All @@ -35,7 +37,7 @@ class LanguageEnum(Enum):
EN = 'en'


class LanguageChoices(Choices):
class LanguageChoices(TextChoices):
EN = 'en'


Expand Down Expand Up @@ -190,3 +192,29 @@ def test_enum_override_variations(no_warnings):
def test_enum_override_loading_fail(capsys):
load_enum_name_overrides()
assert 'unable to load choice override for LanguageEnum' in capsys.readouterr().err


@pytest.mark.skipif(DJANGO_VERSION < '3', reason='Not available before Django 3.0')
def test_textchoice_annotation(no_warnings):
class QualityChoices(TextChoices):
GOOD = 'GOOD'
BAD = 'BAD'

class XSerializer(serializers.Serializer):
quality_levels = serializers.SerializerMethodField()

def get_quality_levels(self, obj) -> typing.List[QualityChoices]:
return [QualityChoices.GOOD, QualityChoices.BAD] # pragma: no cover

class XAPIView(APIView):
@extend_schema(responses=XSerializer)
def get(self, request):
pass # pragma: no cover

schema = generate_schema('x', view=XAPIView)
assert 'QualityLevelsEnum' in schema['components']['schemas']
assert schema['components']['schemas']['X']['properties']['quality_levels'] == {
'type': 'array',
'items': {'$ref': '#/components/schemas/QualityLevelsEnum'},
'readOnly': True
}

0 comments on commit 2ee41ae

Please sign in to comment.