Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add OAS securitySchemes and security objects #7516

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/api-guide/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,26 @@ differentiate between request and response objects.
By default returns `get_serializer()` but can be overridden to
differentiate between request and response objects.

#### `get_security_schemes()`

Generates the OpenAPI `securitySchemes` components based on:
- Your default `authentication_classes` (`settings.DEFAULT_AUTHENTICATION_CLASSES`)
- Per-view non-default `authentication_classes`

These are generated using the authentication classes' `openapi_security_scheme()` class method. If you
extend `BaseAuthentication` with your own authentication class, you can add this class method to return
the appropriate security scheme object.

#### `get_security_requirements()`

Root-level security requirements (the top-level `security` object) are generated based on the
default authentication classes. Operation-level security requirements are generated only if the given view's
`authentication_classes` differ from the defaults.

These are generated using the authentication classes' `openapi_security_requirement()` class
method. If you extended `BaseAuthentication` with your own authentication class, you can add this
class method to return the appropriate list of security requirements objects.

### `AutoSchema.__init__()` kwargs

`AutoSchema` provides a number of `__init__()` kwargs that can be used for
Expand Down
93 changes: 93 additions & 0 deletions rest_framework/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,32 @@ def authenticate_header(self, request):
"""
pass

#: Name of openapi security scheme. Override if you want to customize it.
openapi_security_scheme_name = None

@classmethod
def openapi_security_scheme(cls):
"""
Override this to return an Open API Specification `securityScheme object
<http://spec.openapis.org/oas/v3.0.3#security-scheme-object>`_
"""
return {}
Comment on lines +55 to +61
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH I would rather see this sort of hook added via a decorator (or such) rather than added in as part of the DRF API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@carltongibson Could you provide an example of how that might work? I don't quite understand.


@classmethod
def openapi_security_requirement(cls, view, method):
"""
Override this to return an Open API Specification `security requirement object
<http://spec.openapis.org/oas/v3.0.3#security-requirement-object>`_

:param view: used to find view attributes used by a permission class or None for root-level
:param method: used to distinguish among method-specific permissions or None for root-level
:return:list: [security requirement objects]
"""
# At this point, none of the built-in DRF authentication classes fill in the
# requirement list: OAuth2/OIDC are the only security types that currently uses the list
# (for scopes). See http://spec.openapis.org/oas/v3.0.3#patterned-fields-2.
return [{}]


class BasicAuthentication(BaseAuthentication):
"""
Expand Down Expand Up @@ -108,6 +134,22 @@ def authenticate_credentials(self, userid, password, request=None):
def authenticate_header(self, request):
return 'Basic realm="%s"' % self.www_authenticate_realm

openapi_security_scheme_name = 'basicAuth'

@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'http',
'scheme': 'basic',
'description': 'Basic Authentication'
}
}

@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]


class SessionAuthentication(BaseAuthentication):
"""
Expand Down Expand Up @@ -147,6 +189,23 @@ def dummy_get_response(request): # pragma: no cover
# CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)

openapi_security_scheme_name = 'sessionAuth'

@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'apiKey',
'in': 'cookie',
'name': 'JSESSIONID',
'description': 'Session authentication'
}
}

@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]


class TokenAuthentication(BaseAuthentication):
"""
Expand Down Expand Up @@ -210,6 +269,23 @@ def authenticate_credentials(self, key):
def authenticate_header(self, request):
return self.keyword

openapi_security_scheme_name = 'tokenAuth'

@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'http',
'in': 'header',
'name': 'Authorization', # Authorization: token ...
'description': 'Token authentication'
}
}

@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]


class RemoteUserAuthentication(BaseAuthentication):
"""
Expand All @@ -230,3 +306,20 @@ def authenticate(self, request):
user = authenticate(request=request, remote_user=request.META.get(self.header))
if user and user.is_active:
return (user, None)

openapi_security_scheme_name = 'remoteUserAuth'

@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'http',
'in': 'header',
'name': 'REMOTE_USER',
'description': 'Remote User authentication'
}
}

@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]
58 changes: 58 additions & 0 deletions rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def get_schema(self, request=None, public=False):
"""
self._initialise_endpoints()
components_schemas = {}
security_schemes_schemas = {}
root_security_requirements = []

if api_settings.DEFAULT_AUTHENTICATION_CLASSES:
for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES:
req = auth_class.openapi_security_requirement(None, None)
if req:
root_security_requirements += req

# Iterate endpoints generating per method path operations.
paths = {}
Expand All @@ -80,6 +88,7 @@ def get_schema(self, request=None, public=False):

operation = view.schema.get_operation(path, method)
components = view.schema.get_components(path, method)

for k in components.keys():
if k not in components_schemas:
continue
Expand All @@ -89,6 +98,16 @@ def get_schema(self, request=None, public=False):

components_schemas.update(components)

security_schemes = view.schema.get_security_schemes(path, method)
for k in security_schemes.keys():
if k not in security_schemes_schemas:
continue
if security_schemes_schemas[k] == security_schemes[k]:
continue
warnings.warn('Security scheme component "{}" has been overriden with a different '
'value.'.format(k))
security_schemes_schemas.update(security_schemes)

# Normalise path for any provided mount url.
if path.startswith('/'):
path = path[1:]
Expand All @@ -111,6 +130,14 @@ def get_schema(self, request=None, public=False):
'schemas': components_schemas
}

if len(security_schemes_schemas) > 0:
if 'components' not in schema:
schema['components'] = {}
schema['components']['securitySchemes'] = security_schemes_schemas

if len(root_security_requirements) > 0:
schema['security'] = root_security_requirements

return schema

# View Inspectors
Expand Down Expand Up @@ -146,6 +173,9 @@ def get_operation(self, path, method):

operation['operationId'] = self.get_operation_id(path, method)
operation['description'] = self.get_description(path, method)
security = self.get_security_requirements(path, method)
if security is not None:
operation['security'] = security

parameters = []
parameters += self.get_path_parameters(path, method)
Expand Down Expand Up @@ -713,6 +743,34 @@ def get_tags(self, path, method):

return [path.split('/')[0].replace('_', '-')]

def get_security_schemes(self, path, method):
"""
Get components.schemas.securitySchemes required by this path.
returns dict of securitySchemes.
"""
schemes = {}
for auth_class in self.view.authentication_classes:
if hasattr(auth_class, 'openapi_security_scheme'):
schemes.update(auth_class.openapi_security_scheme())
return schemes

def get_security_requirements(self, path, method):
"""
Get Security Requirement Object list for this operation.
Returns a list of security requirement objects based on the view's authentication classes
unless this view's authentication classes are the same as the root-level defaults.
"""
# references the securityScheme names described above in get_security_schemes()
security = []
if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES:
return None
for auth_class in self.view.authentication_classes:
if hasattr(auth_class, 'openapi_security_requirement'):
req = auth_class.openapi_security_requirement(self.view, method)
if req:
security += req
return security

def _get_path_parameters(self, path, method):
warnings.warn(
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
Expand Down
49 changes: 48 additions & 1 deletion tests/schemas/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.utils.translation import gettext_lazy as _

from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.authentication import TokenAuthentication
from rest_framework.authtoken.views import obtain_auth_token
from rest_framework.compat import uritemplate
from rest_framework.parsers import JSONParser, MultiPartParser
Expand Down Expand Up @@ -1235,5 +1236,51 @@ class ExampleView(generics.DestroyAPIView):
]
generator = SchemaGenerator(patterns=url_patterns)
schema = generator.get_schema(request=create_request('/'))
assert 'components' not in schema
assert 'schemas' not in schema['components']
assert 'content' not in schema['paths']['/example/']['delete']['responses']['204']

def test_default_root_security_schemes(self):
patterns = [
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
]

generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)
assert 'security' in schema
assert {'sessionAuth': []} in schema['security']
assert {'basicAuth': []} in schema['security']
assert 'security' not in schema['paths']['/example/']['get']

@override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None})
def test_no_default_root_security_schemes(self):
patterns = [
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
]

generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)
assert 'security' not in schema

def test_operation_security_schemes(self):
class MyExample(views.ExampleAutoSchemaComponentName):
authentication_classes = [TokenAuthentication]

patterns = [
path('^example/?$', MyExample.as_view()),
]

generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)
assert 'security' in schema
assert {'sessionAuth': []} in schema['security']
assert {'basicAuth': []} in schema['security']
get_operation = schema['paths']['/example/']['get']
assert 'security' in get_operation
assert {'tokenAuth': []} in get_operation['security']
assert len(get_operation['security']) == 1