Skip to content

Commit

Permalink
Allow generic requets, responses, fields, views
Browse files Browse the repository at this point in the history
Allow Request, Response, Field, and GenericAPIView to be subscriptable.
This allows the classes to be made generic for type checking.

This is especially useful since monkey patching DRF can be problematic
as seen in this [issue][1].

[1]: typeddjango/djangorestframework-stubs#299
  • Loading branch information
jalaziz committed Jan 7, 2023
1 parent 89d6ce7 commit 581fa59
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 0 deletions.
4 changes: 4 additions & 0 deletions rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,10 @@ def __init__(self, *, read_only=False, write_only=False,
messages.update(error_messages or {})
self.error_messages = messages

# Allow generic typing checking for fields.
def __class_getitem__(cls, *args, **kwargs):
return cls

def bind(self, field_name, parent):
"""
Initializes the field name and parent for the field instance.
Expand Down
4 changes: 4 additions & 0 deletions rest_framework/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class GenericAPIView(views.APIView):
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS

# Allow generic typing checking for generic views.
def __class_getitem__(cls, *args, **kwargs):
return cls

def get_queryset(self):
"""
Get the list of items for this view.
Expand Down
4 changes: 4 additions & 0 deletions rest_framework/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ def __repr__(self):
self.method,
self.get_full_path())

# Allow generic typing checking for requests.
def __class_getitem__(cls, *args, **kwargs):
return cls

def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()

Expand Down
4 changes: 4 additions & 0 deletions rest_framework/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def __init__(self, data=None, status=None,
for name, value in headers.items():
self[name] = value

# Allow generic typing checking for responses.
def __class_getitem__(cls, *args, **kwargs):
return cls

@property
def rendered_content(self):
renderer = getattr(self, 'accepted_renderer', None)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import os
import re
import sys
import uuid
from decimal import ROUND_DOWN, ROUND_UP, Decimal

Expand Down Expand Up @@ -625,6 +626,15 @@ def test_parent_binding(self):
assert field.root is parent


class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_field_is_subscriptable(self):
assert serializers.Field is serializers.Field["foo"]


# Tests for field input and output values.
# ----------------------------------------

Expand Down
25 changes: 25 additions & 0 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import pytest
from django.db import models
from django.http import Http404
Expand Down Expand Up @@ -698,3 +700,26 @@ def list(self, request):
serializer = response.serializer

assert serializer.context is context


class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_genericview_is_subscriptable(self):
assert generics.GenericAPIView is generics.GenericAPIView["foo"]

@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_listview_is_subscriptable(self):
assert generics.ListAPIView is generics.ListAPIView["foo"]

@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_instanceview_is_subscriptable(self):
assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"]
10 changes: 10 additions & 0 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import copy
import os.path
import sys
import tempfile

import pytest
Expand Down Expand Up @@ -352,3 +353,12 @@ class TestDeepcopy(TestCase):
def test_deepcopy_works(self):
request = Request(factory.get('/', secure=False))
copy.deepcopy(request)


class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_request_is_subscriptable(self):
assert Request is Request["foo"]
12 changes: 12 additions & 0 deletions tests/test_response.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import sys

import pytest
from django.test import TestCase, override_settings
from django.urls import include, path, re_path

Expand Down Expand Up @@ -283,3 +286,12 @@ def test_form_has_label_and_help_text(self):
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')


class TestTyping(TestCase):
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_response_is_subscriptable(self):
assert Response is Response["foo"]

0 comments on commit 581fa59

Please sign in to comment.