diff --git a/.gitignore b/.gitignore index 8035d84..893ddbe 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ *.db .env +env/ .cache .coverage @@ -18,4 +19,6 @@ build/ dist/ _build/ -tags \ No newline at end of file +tags + +.idea/ diff --git a/README.rst b/README.rst index ebdd401..e36cd84 100644 --- a/README.rst +++ b/README.rst @@ -47,6 +47,17 @@ Now you can install the djangosaml2idp package using pip. This will also install pip install djangosaml2idp +Running the test suite +====================== +Install the dev dependencies in ``requirements-dev.txt``:: + + pip install -r requirements-dev.txt + +Run ``py.test`` from the project root:: + + py.test + + Configuration & Usage ===================== diff --git a/djangosaml2idp/error_views.py b/djangosaml2idp/error_views.py index 737418d..a50485b 100644 --- a/djangosaml2idp/error_views.py +++ b/djangosaml2idp/error_views.py @@ -9,8 +9,9 @@ class SamlIDPErrorView(TemplateView): def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) exception = kwargs.get("exception") + context.update({ - "exception_type": ".".join([exception.__module__, exception.__class__.__name__]) if exception else None, + "exception_type": exception.__class__.__name__ if exception else None, "exception_msg": str(exception) if exception else None, "extra_message": kwargs.get("extra_message"), }) diff --git a/djangosaml2idp/processors.py b/djangosaml2idp/processors.py index cc99834..b138413 100644 --- a/djangosaml2idp/processors.py +++ b/djangosaml2idp/processors.py @@ -1,9 +1,15 @@ +from django.conf import settings + + class BaseProcessor: """ Processor class is used to determine if a user has access to a client service of this IDP and to construct the identity dictionary which is sent to the SP """ - def has_access(self, user): + def __init__(self, entity_id): + self._entity_id = entity_id + + def has_access(self, request): """ Check if this user is allowed to use this IDP """ return True @@ -13,9 +19,15 @@ def enable_multifactor(self, user): """ return False - def create_identity(self, user, sp_mapping): + def get_user_id(self, user): + user_field = getattr(settings, 'SAML_IDP_DJANGO_USERNAME_FIELD', None) or \ + getattr(user, 'USERNAME_FIELD', 'username') + return str(getattr(user, user_field)) + + def create_identity(self, user, sp_mapping, **extra_config): """ Generate an identity dictionary of the user based on the given mapping of desired user attributes by the SP """ + return { out_attr: getattr(user, user_attr) for user_attr, out_attr in sp_mapping.items() diff --git a/djangosaml2idp/views.py b/djangosaml2idp/views.py index 8307201..5eb3493 100644 --- a/djangosaml2idp/views.py +++ b/djangosaml2idp/views.py @@ -15,7 +15,7 @@ from django.views.decorators.cache import never_cache from django.views.decorators.csrf import csrf_exempt from django.views.decorators.http import require_http_methods -from saml2 import BINDING_HTTP_POST +from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT from saml2.authn_context import PASSWORD, AuthnBroker, authn_context_class_ref from saml2.config import IdPConfig from saml2.ident import NameID @@ -42,7 +42,15 @@ def sso_entry(request): """ Entrypoint view for SSO. Gathers the parameters from the HTTP request, stores them in the session and redirects the requester to the login_process view. """ - passed_data = request.POST if request.method == 'POST' else request.GET + if request.method == 'POST': + passed_data = request.POST + binding = BINDING_HTTP_POST + else: + passed_data = request.GET + binding = BINDING_HTTP_REDIRECT + + request.session['Binding'] = binding + try: request.session['SAMLRequest'] = passed_data['SAMLRequest'] except (KeyError, MultiValueDictKeyError) as e: @@ -74,22 +82,23 @@ def dispatch(self, request, *args, **kwargs): return self.handle_error(request, exception=e) return super(IdPHandlerViewMixin, self).dispatch(request, *args, **kwargs) - def get_processor(self, sp_config): + def get_processor(self, entity_id, sp_config): """ "Instantiate user-specified processor or fallback to all-access base processor """ processor_string = sp_config.get('processor', None) if processor_string: try: - return import_string(processor_string)() + return import_string(processor_string)(entity_id) except Exception as e: logger.error("Failed to instantiate processor: {} - {}".format(processor_string, e), exc_info=True) - return BaseProcessor + raise + return BaseProcessor(entity_id) def get_identity(self, processor, user, sp_config): """ Create Identity dict (using SP-specific mapping) """ sp_mapping = sp_config.get('attribute_mapping', {'username': 'username'}) - return processor.create_identity(user, sp_mapping) + return processor.create_identity(user, sp_mapping, **sp_config.get('extra_config', {})) @method_decorator(never_cache, name='dispatch') @@ -99,9 +108,11 @@ class LoginProcessView(LoginRequiredMixin, IdPHandlerViewMixin, View): """ def get(self, request, *args, **kwargs): + binding = request.session.get('Binding', BINDING_HTTP_POST) + # Parse incoming request try: - req_info = self.IDP.parse_authn_request(request.session['SAMLRequest'], BINDING_HTTP_POST) + req_info = self.IDP.parse_authn_request(request.session['SAMLRequest'], binding) except Exception as excp: return self.handle_error(request, exception=excp) # TODO this is taken from example, but no idea how this works or whats it does. Check SAML2 specification? @@ -129,10 +140,10 @@ def get(self, request, *args, **kwargs): except Exception: return self.handle_error(request, exception=ImproperlyConfigured("No config for SP %s defined in SAML_IDP_SPCONFIG" % resp_args['sp_entity_id']), status=400) - processor = self.get_processor(sp_config) + processor = self.get_processor(resp_args['sp_entity_id'], sp_config) # Check if user has access to the service of this SP - if not processor.has_access(request.user): + if not processor.has_access(request): return self.handle_error(request, exception=PermissionDenied("You do not have access to this resource"), status=403) identity = self.get_identity(processor, request.user, sp_config) @@ -142,11 +153,13 @@ def get(self, request, *args, **kwargs): AUTHN_BROKER = AuthnBroker() AUTHN_BROKER.add(authn_context_class_ref(req_authn_context), "") + user_id = processor.get_user_id(request.user) + # Construct SamlResponse message try: authn_resp = self.IDP.create_authn_response( - identity=identity, userid=request.user.username, - name_id=NameID(format=resp_args['name_id_policy'].format, sp_name_qualifier=resp_args['sp_entity_id'], text=request.user.username), + identity=identity, userid=user_id, + name_id=NameID(format=resp_args['name_id_policy'].format, sp_name_qualifier=resp_args['sp_entity_id'], text=user_id), authn=AUTHN_BROKER.get_authn_by_accr(req_authn_context), sign_response=self.IDP.config.getattr("sign_response", "idp") or False, sign_assertion=self.IDP.config.getattr("sign_assertion", "idp") or False, @@ -201,10 +214,10 @@ def get(self, request, *args, **kwargs): service="assertion_consumer_service", entity_id=sp_entity_id) - processor = self.get_processor(sp_config) + processor = self.get_processor(sp_entity_id, sp_config) # Check if user has access to the service of this SP - if not processor.has_access(request.user): + if not processor.has_access(request): return self.handle_error(request, exception=PermissionDenied("You do not have access to this resource"), status=403) identity = self.get_identity(processor, request.user, sp_config) @@ -213,10 +226,12 @@ def get(self, request, *args, **kwargs): AUTHN_BROKER = AuthnBroker() AUTHN_BROKER.add(authn_context_class_ref(req_authn_context), "") + user_id = processor.get_user_id(request.user) + # Construct SamlResponse messages try: name_id_formats = self.IDP.config.getattr("name_id_format", "idp") or [NAMEID_FORMAT_UNSPECIFIED] - name_id = NameID(format=name_id_formats[0], text=request.user.username) + name_id = NameID(format=name_id_formats[0], text=user_id) authn = AUTHN_BROKER.get_authn_by_accr(req_authn_context) sign_response = self.IDP.config.getattr("sign_response", "idp") or False sign_assertion = self.IDP.config.getattr("sign_assertion", "idp") or False @@ -225,7 +240,7 @@ def get(self, request, *args, **kwargs): in_response_to="IdP_Initiated_Login", destination=destination, sp_entity_id=sp_entity_id, - userid=request.user.username, + userid=user_id, name_id=name_id, authn=authn, sign_response=sign_response, diff --git a/pytest.ini b/pytest.ini index 7dffe55..9b6f225 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,3 +5,8 @@ skip = build/*,dist/*,docs/*,*/manage.py [pep8] ignore = C0111,C0301,E122,E127,E128,E131,E501,E502,E722,E731,W605 + +[pytest] +django_find_project = false +norecursedirs = env +DJANGO_SETTINGS_MODULE=tests.settings diff --git a/requirements-dev.txt b/requirements-dev.txt index d29b092..9abf3a9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,2 +1,7 @@ django>2.0,<2.1 pysaml2>=4.4.0 + +pytest==4.0.1 +pytest-django==3.4.4 +pytest-pythonpath==0.7.3 +python-dateutil==2.7.5 diff --git a/runtests.py b/runtests.py deleted file mode 100644 index 0c2ef30..0000000 --- a/runtests.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env python -import os -import sys - -import django -from django.conf import settings -from django.test.utils import get_runner - -if __name__ == "__main__": - os.environ['DJANGO_SETTINGS_MODULE'] = 'tests.settings' - django.setup() - TestRunner = get_runner(settings) - test_runner = TestRunner() - failures = test_runner.run_tests(["tests"]) - sys.exit(bool(failures)) diff --git a/tests/settings.py b/tests/settings.py index 8d8fe7f..e48bcc3 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -6,12 +6,12 @@ DATABASES = { 'default': { - 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. - 'NAME': './idptest.sqlite', # Or path to database file if using sqlite3. - 'USER': '', # Not used with sqlite3. - 'PASSWORD': '', # Not used with sqlite3. - 'HOST': '', # Set to empty string for localhost. Not used with sqlite3. - 'PORT': '', # Set to empty string for default. Not used with sqlite3. + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': './idptest.sqlite', + 'USER': '', + 'PASSWORD': '', + 'HOST': '', + 'PORT': '', } } @@ -42,3 +42,11 @@ ) ROOT_URLCONF = 'tests.urls' + +# --- + +SAML_IDP_SPCONFIG = { + 'test_idp_1': { + 'attribute_mapping': {} + } +} diff --git a/tests/test_processor.py b/tests/test_processor.py new file mode 100644 index 0000000..5b440df --- /dev/null +++ b/tests/test_processor.py @@ -0,0 +1,29 @@ +import pytest + +from django.contrib.auth import get_user_model + +from djangosaml2idp.processors import BaseProcessor + + +User = get_user_model() + + +class TestBaseProcessor: + + def test_extract_user_id_configure_by_user_class(self): + + user = User() + user.USERNAME_FIELD = 'email' + user.email = 'test_email' + + assert BaseProcessor('entity-id').get_user_id(user) == 'test_email' + + def test_extract_user_id_configure_by_settings(self, settings): + """Should use `settings.SAML_IDP_DJANGO_USERNAME_FIELD` to determine the user id field""" + + settings.SAML_IDP_DJANGO_USERNAME_FIELD = 'first_name' + + user = User() + user.first_name = 'test_first_name' + + assert BaseProcessor('entity-id').get_user_id(user) == 'test_first_name' diff --git a/tests/test_views.py b/tests/test_views.py new file mode 100644 index 0000000..9170495 --- /dev/null +++ b/tests/test_views.py @@ -0,0 +1,54 @@ +import pytest + +from django.contrib.auth import get_user_model + +from djangosaml2idp.views import IdPHandlerViewMixin + +from djangosaml2idp.processors import BaseProcessor + + +User = get_user_model() + + +class CustomProcessor(BaseProcessor): + pass + + +class TestIdPHandlerViewMixin: + def test_get_identity_provides_extra_config(self): + obj = IdPHandlerViewMixin() + + def test_get_processor_errors_if_processor_cannot_be_loaded(self): + sp_config = { + 'processor': 'this.does.not.exist' + } + + with pytest.raises(Exception): + IdPHandlerViewMixin().get_processor('entity_id', sp_config) + + def test_get_processor_defaults_to_base_processor(self): + sp_config = { + } + + assert isinstance(IdPHandlerViewMixin().get_processor('entity_id', sp_config), BaseProcessor) + + def test_get_processor_loads_custom_processor(self): + sp_config = { + 'processor': 'tests.test_views.CustomProcessor' + } + + assert isinstance(IdPHandlerViewMixin().get_processor('entity_id', sp_config), CustomProcessor) + + +class TestIdpInitiatedFlow: + pass + + +class TestMetadata: + pass + + +class LoginFlow: + def test_requires_authentication(self): + """test redriect to settings.LOGIN_VIEW""" + pass