diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index b9de6745fe..937eb4d425 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -389,6 +389,26 @@ differentiate between request and response objects. By default returns `get_serializer()` but can be overridden to differentiate between request and response objects. +#### `get_security_schemes()` + +Generates the OpenAPI `securitySchemes` components based on: +- Your default `authentication_classes` (`settings.DEFAULT_AUTHENTICATION_CLASSES`) +- Per-view non-default `authentication_classes` + +These are generated using the authentication classes' `openapi_security_scheme()` class method. If you +extend `BaseAuthentication` with your own authentication class, you can add this class method to return +the appropriate security scheme object. + +#### `get_security_requirements()` + +Root-level security requirements (the top-level `security` object) are generated based on the +default authentication classes. Operation-level security requirements are generated only if the given view's +`authentication_classes` differ from the defaults. + +These are generated using the authentication classes' `openapi_security_requirement()` class +method. If you extended `BaseAuthentication` with your own authentication class, you can add this +class method to return the appropriate list of security requirements objects. + ### `AutoSchema.__init__()` kwargs `AutoSchema` provides a number of `__init__()` kwargs that can be used for diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 382abf1580..ffff1f4567 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -49,6 +49,32 @@ def authenticate_header(self, request): """ pass + #: Name of openapi security scheme. Override if you want to customize it. + openapi_security_scheme_name = None + + @classmethod + def openapi_security_scheme(cls): + """ + Override this to return an Open API Specification `securityScheme object + `_ + """ + return {} + + @classmethod + def openapi_security_requirement(cls, view, method): + """ + Override this to return an Open API Specification `security requirement object + `_ + + :param view: used to find view attributes used by a permission class or None for root-level + :param method: used to distinguish among method-specific permissions or None for root-level + :return:list: [security requirement objects] + """ + # At this point, none of the built-in DRF authentication classes fill in the + # requirement list: OAuth2/OIDC are the only security types that currently uses the list + # (for scopes). See http://spec.openapis.org/oas/v3.0.3#patterned-fields-2. + return [{}] + class BasicAuthentication(BaseAuthentication): """ @@ -108,6 +134,22 @@ def authenticate_credentials(self, userid, password, request=None): def authenticate_header(self, request): return 'Basic realm="%s"' % self.www_authenticate_realm + openapi_security_scheme_name = 'basicAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'http', + 'scheme': 'basic', + 'description': 'Basic Authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] + class SessionAuthentication(BaseAuthentication): """ @@ -147,6 +189,23 @@ def dummy_get_response(request): # pragma: no cover # CSRF failed, bail with explicit error message raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + openapi_security_scheme_name = 'sessionAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'apiKey', + 'in': 'cookie', + 'name': 'JSESSIONID', + 'description': 'Session authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] + class TokenAuthentication(BaseAuthentication): """ @@ -210,6 +269,23 @@ def authenticate_credentials(self, key): def authenticate_header(self, request): return self.keyword + openapi_security_scheme_name = 'tokenAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'http', + 'in': 'header', + 'name': 'Authorization', # Authorization: token ... + 'description': 'Token authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] + class RemoteUserAuthentication(BaseAuthentication): """ @@ -230,3 +306,20 @@ def authenticate(self, request): user = authenticate(request=request, remote_user=request.META.get(self.header)) if user and user.is_active: return (user, None) + + openapi_security_scheme_name = 'remoteUserAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'http', + 'in': 'header', + 'name': 'REMOTE_USER', + 'description': 'Remote User authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 5e9d59f8bf..481fe4c42b 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -70,6 +70,14 @@ def get_schema(self, request=None, public=False): """ self._initialise_endpoints() components_schemas = {} + security_schemes_schemas = {} + root_security_requirements = [] + + if api_settings.DEFAULT_AUTHENTICATION_CLASSES: + for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES: + req = auth_class.openapi_security_requirement(None, None) + if req: + root_security_requirements += req # Iterate endpoints generating per method path operations. paths = {} @@ -80,6 +88,7 @@ def get_schema(self, request=None, public=False): operation = view.schema.get_operation(path, method) components = view.schema.get_components(path, method) + for k in components.keys(): if k not in components_schemas: continue @@ -89,6 +98,16 @@ def get_schema(self, request=None, public=False): components_schemas.update(components) + security_schemes = view.schema.get_security_schemes(path, method) + for k in security_schemes.keys(): + if k not in security_schemes_schemas: + continue + if security_schemes_schemas[k] == security_schemes[k]: + continue + warnings.warn('Security scheme component "{}" has been overriden with a different ' + 'value.'.format(k)) + security_schemes_schemas.update(security_schemes) + # Normalise path for any provided mount url. if path.startswith('/'): path = path[1:] @@ -111,6 +130,14 @@ def get_schema(self, request=None, public=False): 'schemas': components_schemas } + if len(security_schemes_schemas) > 0: + if 'components' not in schema: + schema['components'] = {} + schema['components']['securitySchemes'] = security_schemes_schemas + + if len(root_security_requirements) > 0: + schema['security'] = root_security_requirements + return schema # View Inspectors @@ -146,6 +173,9 @@ def get_operation(self, path, method): operation['operationId'] = self.get_operation_id(path, method) operation['description'] = self.get_description(path, method) + security = self.get_security_requirements(path, method) + if security is not None: + operation['security'] = security parameters = [] parameters += self.get_path_parameters(path, method) @@ -713,6 +743,34 @@ def get_tags(self, path, method): return [path.split('/')[0].replace('_', '-')] + def get_security_schemes(self, path, method): + """ + Get components.schemas.securitySchemes required by this path. + returns dict of securitySchemes. + """ + schemes = {} + for auth_class in self.view.authentication_classes: + if hasattr(auth_class, 'openapi_security_scheme'): + schemes.update(auth_class.openapi_security_scheme()) + return schemes + + def get_security_requirements(self, path, method): + """ + Get Security Requirement Object list for this operation. + Returns a list of security requirement objects based on the view's authentication classes + unless this view's authentication classes are the same as the root-level defaults. + """ + # references the securityScheme names described above in get_security_schemes() + security = [] + if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES: + return None + for auth_class in self.view.authentication_classes: + if hasattr(auth_class, 'openapi_security_requirement'): + req = auth_class.openapi_security_requirement(self.view, method) + if req: + security += req + return security + def _get_path_parameters(self, path, method): warnings.warn( "Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. " diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index daa035a3f3..c62ae1a081 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -8,6 +8,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import filters, generics, pagination, routers, serializers +from rest_framework.authentication import TokenAuthentication from rest_framework.authtoken.views import obtain_auth_token from rest_framework.compat import uritemplate from rest_framework.parsers import JSONParser, MultiPartParser @@ -1235,5 +1236,51 @@ class ExampleView(generics.DestroyAPIView): ] generator = SchemaGenerator(patterns=url_patterns) schema = generator.get_schema(request=create_request('/')) - assert 'components' not in schema + assert 'schemas' not in schema['components'] assert 'content' not in schema['paths']['/example/']['delete']['responses']['204'] + + def test_default_root_security_schemes(self): + patterns = [ + path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + assert 'security' in schema + assert {'sessionAuth': []} in schema['security'] + assert {'basicAuth': []} in schema['security'] + assert 'security' not in schema['paths']['/example/']['get'] + + @override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None}) + def test_no_default_root_security_schemes(self): + patterns = [ + path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + assert 'security' not in schema + + def test_operation_security_schemes(self): + class MyExample(views.ExampleAutoSchemaComponentName): + authentication_classes = [TokenAuthentication] + + patterns = [ + path('^example/?$', MyExample.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + assert 'security' in schema + assert {'sessionAuth': []} in schema['security'] + assert {'basicAuth': []} in schema['security'] + get_operation = schema['paths']['/example/']['get'] + assert 'security' in get_operation + assert {'tokenAuth': []} in get_operation['security'] + assert len(get_operation['security']) == 1