diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1c642559621..613bd325a6d 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -356,6 +356,10 @@ def __init__(self, *, read_only=False, write_only=False, messages.update(error_messages or {}) self.error_messages = messages + # Allow generic typing checking for fields. + def __class_getitem__(cls, *args, **kwargs): + return cls + def bind(self, field_name, parent): """ Initializes the field name and parent for the field instance. diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55cfafda443..1673033214a 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -45,6 +45,10 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + # Allow generic typing checking for generic views. + def __class_getitem__(cls, *args, **kwargs): + return cls + def get_queryset(self): """ Get the list of items for this view. diff --git a/rest_framework/request.py b/rest_framework/request.py index 194be5f6d4d..93109226d98 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -186,6 +186,10 @@ def __repr__(self): self.method, self.get_full_path()) + # Allow generic typing checking for requests. + def __class_getitem__(cls, *args, **kwargs): + return cls + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/response.py b/rest_framework/response.py index 49542373478..6e756544c69 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -46,6 +46,10 @@ def __init__(self, data=None, status=None, for name, value in headers.items(): self[name] = value + # Allow generic typing checking for responses. + def __class_getitem__(cls, *args, **kwargs): + return cls + @property def rendered_content(self): renderer = getattr(self, 'accepted_renderer', None) diff --git a/tests/test_fields.py b/tests/test_fields.py index 512f3f78953..56e2a45bad6 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2,6 +2,7 @@ import math import os import re +import sys import uuid from decimal import ROUND_DOWN, ROUND_UP, Decimal @@ -625,6 +626,15 @@ def test_parent_binding(self): assert field.root is parent +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_field_is_subscriptable(self): + assert serializers.Field is serializers.Field["foo"] + + # Tests for field input and output values. # ---------------------------------------- diff --git a/tests/test_generics.py b/tests/test_generics.py index 78dc5afb64f..9990389c947 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,3 +1,5 @@ +import sys + import pytest from django.db import models from django.http import Http404 @@ -698,3 +700,26 @@ def list(self, request): serializer = response.serializer assert serializer.context is context + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_genericview_is_subscriptable(self): + assert generics.GenericAPIView is generics.GenericAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_listview_is_subscriptable(self): + assert generics.ListAPIView is generics.ListAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_instanceview_is_subscriptable(self): + assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"] diff --git a/tests/test_request.py b/tests/test_request.py index 8c18aea9e67..e37aa7dda14 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -3,6 +3,7 @@ """ import copy import os.path +import sys import tempfile import pytest @@ -352,3 +353,12 @@ class TestDeepcopy(TestCase): def test_deepcopy_works(self): request = Request(factory.get('/', secure=False)) copy.deepcopy(request) + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_request_is_subscriptable(self): + assert Request is Request["foo"] diff --git a/tests/test_response.py b/tests/test_response.py index 0d5528dc9a0..cab19a1eb8c 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,3 +1,6 @@ +import sys + +import pytest from django.test import TestCase, override_settings from django.urls import include, path, re_path @@ -283,3 +286,12 @@ def test_form_has_label_and_help_text(self): self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') # self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text description.') + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_response_is_subscriptable(self): + assert Response is Response["foo"]