From 17fa5e3ce891064231707bf30413b38b89bd6d7f Mon Sep 17 00:00:00 2001 From: Michael Merickel Date: Fri, 30 Sep 2016 18:55:25 -0500 Subject: [PATCH] add a callback hook to set_default_csrf_options for disabling checks per-request fixes #2596 --- pyramid/config/security.py | 19 +++++++++++++-- pyramid/interfaces.py | 1 + pyramid/tests/test_config/test_security.py | 6 ++++- pyramid/tests/test_viewderivers.py | 28 ++++++++++++++++++++++ pyramid/viewderivers.py | 7 +++++- 5 files changed, 57 insertions(+), 4 deletions(-) diff --git a/pyramid/config/security.py b/pyramid/config/security.py index e387eade94..02732c0426 100644 --- a/pyramid/config/security.py +++ b/pyramid/config/security.py @@ -169,6 +169,7 @@ def set_default_csrf_options( token='csrf_token', header='X-CSRF-Token', safe_methods=('GET', 'HEAD', 'OPTIONS', 'TRACE'), + callback=None, ): """ Set the default CSRF options used by subsequent view registrations. @@ -192,8 +193,20 @@ def set_default_csrf_options( never be automatically checked for CSRF tokens. Default: ``('GET', 'HEAD', 'OPTIONS', TRACE')``. + If ``callback`` is set, it must be a callable accepting ``(request)`` + and returning ``True`` if the request should be checked for a valid + CSRF token. This callback allows an application to support + alternate authentication methods that do not rely on cookies which + are not subject to CSRF attacks. For example, if a request is + authenticated using the ``Authorization`` header instead of a cookie, + this may return ``False`` for that request so that clients do not + need to send the ``X-CSRF-Token` header. The callback is only tested + for non-safe methods as defined by ``safe_methods``. + """ - options = DefaultCSRFOptions(require_csrf, token, header, safe_methods) + options = DefaultCSRFOptions( + require_csrf, token, header, safe_methods, callback, + ) def register(): self.registry.registerUtility(options, IDefaultCSRFOptions) intr = self.introspectable('default csrf view options', @@ -204,13 +217,15 @@ def register(): intr['token'] = token intr['header'] = header intr['safe_methods'] = as_sorted_tuple(safe_methods) + intr['callback'] = callback self.action(IDefaultCSRFOptions, register, order=PHASE1_CONFIG, introspectables=(intr,)) @implementer(IDefaultCSRFOptions) class DefaultCSRFOptions(object): - def __init__(self, require_csrf, token, header, safe_methods): + def __init__(self, require_csrf, token, header, safe_methods, callback): self.require_csrf = require_csrf self.token = token self.header = header self.safe_methods = frozenset(safe_methods) + self.callback = callback diff --git a/pyramid/interfaces.py b/pyramid/interfaces.py index 114f802aac..c1ddea63fd 100644 --- a/pyramid/interfaces.py +++ b/pyramid/interfaces.py @@ -925,6 +925,7 @@ class IDefaultCSRFOptions(Interface): token = Attribute('The key to be matched in the body of the request.') header = Attribute('The header to be matched with the CSRF token.') safe_methods = Attribute('A set of safe methods that skip CSRF checks.') + callback = Attribute('A callback to disable CSRF checks per-request.') class ISessionFactory(Interface): """ An interface representing a factory which accepts a request object and diff --git a/pyramid/tests/test_config/test_security.py b/pyramid/tests/test_config/test_security.py index e461bfd4a6..5db8e21fcf 100644 --- a/pyramid/tests/test_config/test_security.py +++ b/pyramid/tests/test_config/test_security.py @@ -108,14 +108,18 @@ def test_set_default_csrf_options(self): self.assertEqual(result.header, 'X-CSRF-Token') self.assertEqual(list(sorted(result.safe_methods)), ['GET', 'HEAD', 'OPTIONS', 'TRACE']) + self.assertTrue(result.callback is None) def test_changing_set_default_csrf_options(self): from pyramid.interfaces import IDefaultCSRFOptions config = self._makeOne(autocommit=True) + def callback(request): return True config.set_default_csrf_options( - require_csrf=False, token='DUMMY', header=None, safe_methods=('PUT',)) + require_csrf=False, token='DUMMY', header=None, + safe_methods=('PUT',), callback=callback) result = config.registry.getUtility(IDefaultCSRFOptions) self.assertEqual(result.require_csrf, False) self.assertEqual(result.token, 'DUMMY') self.assertEqual(result.header, None) self.assertEqual(list(sorted(result.safe_methods)), ['PUT']) + self.assertTrue(result.callback is callback) diff --git a/pyramid/tests/test_viewderivers.py b/pyramid/tests/test_viewderivers.py index 676c6f66a7..51d0bd367f 100644 --- a/pyramid/tests/test_viewderivers.py +++ b/pyramid/tests/test_viewderivers.py @@ -1291,6 +1291,34 @@ def inner_view(request): pass view = self.config._derive_view(inner_view) self.assertRaises(BadCSRFToken, lambda: view(None, request)) + def test_csrf_view_enabled_via_callback(self): + def callback(request): + return True + from pyramid.exceptions import BadCSRFToken + def inner_view(request): pass + request = self._makeRequest() + request.scheme = "http" + request.method = 'POST' + request.session = DummySession({'csrf_token': 'foo'}) + self.config.set_default_csrf_options(require_csrf=True, callback=callback) + view = self.config._derive_view(inner_view) + self.assertRaises(BadCSRFToken, lambda: view(None, request)) + + def test_csrf_view_disabled_via_callback(self): + def callback(request): + return False + response = DummyResponse() + def inner_view(request): + return response + request = self._makeRequest() + request.scheme = "http" + request.method = 'POST' + request.session = DummySession({'csrf_token': 'foo'}) + self.config.set_default_csrf_options(require_csrf=True, callback=callback) + view = self.config._derive_view(inner_view) + result = view(None, request) + self.assertTrue(result is response) + def test_csrf_view_uses_custom_csrf_token(self): response = DummyResponse() def inner_view(request): diff --git a/pyramid/viewderivers.py b/pyramid/viewderivers.py index 513ddf0223..4eb0ce704f 100644 --- a/pyramid/viewderivers.py +++ b/pyramid/viewderivers.py @@ -481,11 +481,13 @@ def csrf_view(view, info): token = 'csrf_token' header = 'X-CSRF-Token' safe_methods = frozenset(["GET", "HEAD", "OPTIONS", "TRACE"]) + callback = None else: default_val = defaults.require_csrf token = defaults.token header = defaults.header safe_methods = defaults.safe_methods + callback = defaults.callback enabled = ( explicit_val is True or @@ -501,7 +503,10 @@ def csrf_view(view, info): wrapped_view = view if enabled: def csrf_view(context, request): - if request.method not in safe_methods: + if ( + request.method not in safe_methods and + (callback is None or callback(request)) + ): check_csrf_origin(request, raises=True) check_csrf_token(request, token, header, raises=True) return view(context, request)