Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds tests for schema generation and fix adding categories #4390

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions rest_framework/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,19 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None):

def get_schema(self, request=None):
if self.endpoints is None:
self.endpoints = self.get_api_endpoints(self.patterns)
endpoints = self.get_api_endpoints(self.patterns)
self.endpoints = self.add_categories(endpoints)

links = []
for path, method, category, action, callback in self.endpoints:
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
view = self.get_view(callback)
view.args = ()
view.kwargs = {}
view.format_kwarg = None

if request is not None:
view.request = clone_request(request, method)

try:
view.check_permissions(view.request)
except exceptions.APIException:
Expand Down Expand Up @@ -128,7 +128,7 @@ def get_api_endpoints(self, patterns, prefix=''):
)
api_endpoints.extend(nested_endpoints)

return self.add_categories(api_endpoints)
return api_endpoints

def add_categories(self, api_endpoints):
"""
Expand All @@ -144,6 +144,15 @@ def add_categories(self, api_endpoints):
for (path, method, action, callback) in api_endpoints
]

def get_view(self, callback):
"""
Return constructed view with respect of overrided attributes by detail_route and list_route
"""
view = callback.cls()
for attr, val in getattr(callback, 'initkwargs', {}).items():
setattr(view, attr, val)
return view

def get_path(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
Expand Down Expand Up @@ -174,9 +183,10 @@ def get_allowed_methods(self, callback):
if hasattr(callback, 'actions'):
return [method.upper() for method in callback.actions.keys()]

view = self.get_view(callback)
return [
method for method in
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
view.allowed_methods if method not in ('OPTIONS', 'HEAD')
]

def get_action(self, path, method, callback):
Expand Down
75 changes: 70 additions & 5 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator
from rest_framework.test import APIClient
from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet

Expand All @@ -33,15 +33,25 @@ class AnotherSerializer(serializers.Serializer):
d = serializers.CharField(required=False)


class ForbidAll(permissions.BasePermission):
def has_permission(self, request, view):
return False


class ExampleViewSet(ModelViewSet):
pagination_class = ExamplePagination
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filter_backends = [filters.OrderingFilter]
serializer_class = ExampleSerializer

@detail_route(methods=['post'], serializer_class=AnotherSerializer)
@detail_route(methods=['put', 'post'],
serializer_class=AnotherSerializer)
def custom_action(self, request, pk):
return super(ExampleSerializer, self).retrieve(self, request)
return super(ExampleSerializer, self).update(self, request)

@detail_route(permission_classes=[ForbidAll])
def forbidden_action(self, request, pk):
return super(ExampleSerializer, self).update(self, request)

@list_route()
def custom_list_action(self, request):
Expand All @@ -52,6 +62,15 @@ def get_serializer(self, *args, **kwargs):
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)


class RestrictiveViewSet(ModelViewSet):
permission_classes = [ForbidAll]
serializer_class = ExampleSerializer

@detail_route(methods=['put'], permission_classes=[permissions.AllowAny])
def allowed_action(self, request):
return super(RestrictiveViewSet, self).update(self, request)


class ExampleView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]

Expand All @@ -67,7 +86,14 @@ def post(self, request, *args, **kwargs):
urlpatterns = [
url(r'^', include(router.urls))
]
urlpatterns2 = [

router = DefaultRouter(schema_title='Restrictive API' if coreapi else None)
router.register('example', RestrictiveViewSet, base_name='example')
urlpatterns_restrict = [
url(r'^', include(router.urls))
]

urlpatterns_view = [
url(r'^example-view/$', ExampleView.as_view(), name='example-view')
]

Expand Down Expand Up @@ -142,6 +168,16 @@ def test_authenticated_request(self):
coreapi.Field('pk', required=True, location='path')
]
),
'custom_action': coreapi.Link(
url='/example/{pk}/custom_action/',
action='put',
encoding='application/json',
fields=[
coreapi.Field('pk', required=True, location='path'),
coreapi.Field('c', required=True, location='form'),
coreapi.Field('d', required=False, location='form'),
]
),
'custom_action': coreapi.Link(
url='/example/{pk}/custom_action/',
action='post',
Expand Down Expand Up @@ -189,10 +225,39 @@ def test_authenticated_request(self):
self.assertEqual(response.data, expected)


@unittest.skipUnless(coreapi, 'coreapi is not installed')
class TestSchemaForRestrictedMethods(TestCase):
def test_resctricted_methods(self):
schema_generator = SchemaGenerator(title='Restrictive API', patterns=urlpatterns_restrict)
factory = APIRequestFactory()
from rest_framework.request import Request
mock_request = factory.get('/')
schema = schema_generator.get_schema(request=Request(mock_request))
expected = coreapi.Document(
url='',
title='Restrictive API',
content={
'example': {
'allowed_action': coreapi.Link(
url='/example/{pk}/allowed_action/',
action='put',
encoding='application/json',
fields=[
coreapi.Field('pk', required=True, location='path'),
coreapi.Field('a', required=True, location='form', description='A field description'),
coreapi.Field('b', required=False, location='form')
]
),
}
}
)
self.assertEqual(schema, expected)


@unittest.skipUnless(coreapi, 'coreapi is not installed')
class TestSchemaGenerator(TestCase):
def test_view(self):
schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2)
schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns_view)
schema = schema_generator.get_schema()
expected = coreapi.Document(
url='',
Expand Down