diff --git a/.env.default b/.env.default index 4c379b4..0e44f3f 100644 --- a/.env.default +++ b/.env.default @@ -3,5 +3,5 @@ AWS_SECRET_ACCESS_KEY=dummy123 AWS_ENDPOINT_URL=http://localhost:8080 AWS_DEFAULT_REGION=eu-central-1 AWS_DYNAMODB_TABLE_NAME=test-db -ALLOWED_DOMAINS=.*localhost((:[0-9]*)?|\/)?,.*admin\.ch,.*bgdi\.ch +ALLOWED_DOMAINS=localhost,.*\.geo\.admin\.ch,.*\.bgdi\.ch STAGING=local diff --git a/.env.testing b/.env.testing index a4f99a6..113c19f 100644 --- a/.env.testing +++ b/.env.testing @@ -1,4 +1,4 @@ -ALLOWED_DOMAINS=.*\.geo\.admin\.ch,.*\.bgdi\.ch,http://localhost((:[0-9]*)?|\/)? +ALLOWED_DOMAINS=localhost,.*\.geo\.admin\.ch,.*\.bgdi\.ch AWS_ACCESS_KEY_ID=testing AWS_SECRET_ACCESS_KEY=testing AWS_SECURITY_TOKEN=testing diff --git a/app/__init__.py b/app/__init__.py index f8ade0e..f259b94 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -12,8 +12,8 @@ from app.helpers.utils import get_redirect_param from app.helpers.utils import get_registered_method +from app.helpers.utils import is_domain_allowed from app.helpers.utils import make_error_msg -from app.settings import ALLOWED_DOMAINS_PATTERN from app.settings import CACHE_CONTROL from app.settings import CACHE_CONTROL_4XX @@ -25,10 +25,6 @@ app.config.from_mapping({"TRAP_HTTP_EXCEPTIONS": True}) -def is_domain_allowed(domain): - return re.fullmatch(ALLOWED_DOMAINS_PATTERN, domain) is not None - - @app.before_request # Add quick log of the routes used to all request. # Important: this should be the first before_request method, to ensure diff --git a/app/helpers/utils.py b/app/helpers/utils.py index 5e7a1a5..382d39e 100644 --- a/app/helpers/utils.py +++ b/app/helpers/utils.py @@ -113,8 +113,12 @@ def get_url(): f"The url given as parameter was too long. 
(limit is 2046 " f"characters, {len(url)} given)" ) - if not re.fullmatch(ALLOWED_DOMAINS_PATTERN, urlparse(url).netloc): - logger.error('URL(%s) given as a parameter is not allowed', url) + if not is_domain_allowed(url): + logger.error( + 'URL(%s) given as a parameter is not allowed, test pattern %s', + url, + ALLOWED_DOMAINS_PATTERN + ) abort(400, 'URL given as a parameter is not allowed.') return url @@ -132,3 +136,12 @@ def strtobool(value) -> bool: if value in ('n', 'no', 'f', 'false', 'off', '0'): return False raise ValueError(f"invalid truth value \'{value}\'") + + +def is_domain_allowed(url): + """Check if the url contain a domain that is allowed + """ + domain = urlparse(url).hostname + if domain: + return re.fullmatch(ALLOWED_DOMAINS_PATTERN, domain) is not None + return False diff --git a/tests/unit_tests/base.py b/tests/unit_tests/base.py index 795370c..0d5293c 100644 --- a/tests/unit_tests/base.py +++ b/tests/unit_tests/base.py @@ -1,6 +1,7 @@ import logging import re import unittest +from urllib.parse import urlparse import boto3 @@ -83,18 +84,19 @@ def setUp(self): def tearDown(self): self.table.delete() - def assertCors( - self, - response, - expected_allowed_methods, - origin_pattern=ALLOWED_DOMAINS_PATTERN - ): # pylint: disable=invalid-name + def assertCors(self, response, expected_allowed_methods, all_origin=False): # pylint: disable=invalid-name self.assertIn('Access-Control-Allow-Origin', response.headers) - self.assertIsNotNone( - re.fullmatch(origin_pattern, response.headers['Access-Control-Allow-Origin']), - msg=f"Access-Control-Allow-Origin={response.headers['Access-Control-Allow-Origin']}" - f" doesn't match {origin_pattern}" - ) + if all_origin: + self.assertEqual(response.headers['Access-Control-Allow-Origin'], '*') + else: + allow_origin_domain = urlparse(response.headers['Access-Control-Allow-Origin']).hostname + self.assertIsNotNone( + re.fullmatch( + ALLOWED_DOMAINS_PATTERN, allow_origin_domain if allow_origin_domain else '' + ), + 
msg=f"Access-Control-Allow-Origin={response.headers['Access-Control-Allow-Origin']}" + f" doesn't match {ALLOWED_DOMAINS_PATTERN}" + ) self.assertIn('Access-Control-Allow-Methods', response.headers) self.assertListEqual( sorted(expected_allowed_methods), diff --git a/tests/unit_tests/test_routes.py b/tests/unit_tests/test_routes.py index e3a21e2..d2b4722 100644 --- a/tests/unit_tests/test_routes.py +++ b/tests/unit_tests/test_routes.py @@ -18,7 +18,7 @@ class TestRoutes(BaseShortlinkTestCase): def test_checker_ok(self): # checker - response = self.app.get(url_for('checker'), headers={"Origin": "map.geo.admin.ch"}) + response = self.app.get(url_for('checker'), headers={"Origin": "https://map.geo.admin.ch"}) self.assertEqual(response.status_code, 200) self.assertNotIn('Cache-Control', response.headers) self.assertEqual(response.content_type, "application/json; charset=utf-8") @@ -27,7 +27,9 @@ def test_checker_ok(self): def test_create_shortlink_ok(self): url = "https://map.geo.admin.ch/#/map?lang=en&center=2647850.83,1120124.2&z=1.812&bgLayer=ch.swisstopo.pixelkarte-farbe&top" # pylint: disable=line-too-long response = self.app.post( - url_for('create_shortlink'), json={"url": url}, headers={"Origin": "map.geo.admin.ch"} + url_for('create_shortlink'), + json={"url": url}, + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 201) self.assertCors(response, ['POST', 'OPTIONS']) @@ -49,7 +51,9 @@ def test_create_shortlink_ok(self): ) # Check that second call returns 200 and the same short url response = self.app.post( - url_for('create_shortlink'), json={"url": url}, headers={"Origin": "map.geo.admin.ch"} + url_for('create_shortlink'), + json={"url": url}, + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 200) self.assertCors(response, ['POST', 'OPTIONS']) @@ -59,7 +63,7 @@ def test_create_shortlink_ok(self): def test_create_shortlink_no_json(self): response = self.app.post( -
url_for('create_shortlink'), headers={"Origin": "map.geo.admin.ch"} + url_for('create_shortlink'), headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(415, response.status_code) self.assertCors(response, ['POST', 'OPTIONS']) @@ -77,7 +81,7 @@ def test_create_shortlink_no_json(self): def test_create_shortlink_no_url(self): response = self.app.post( - url_for('create_shortlink'), json={}, headers={"Origin": "map.geo.admin.ch"} + url_for('create_shortlink'), json={}, headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(400, response.status_code) self.assertCors(response, ['POST', 'OPTIONS']) @@ -97,7 +101,7 @@ def test_create_shortlink_no_hostname(self): response = self.app.post( url_for('create_shortlink'), json={"url": f"{wrong_url}"}, - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 400) self.assertCors(response, ['POST', 'OPTIONS']) @@ -116,7 +120,7 @@ def test_create_shortlink_non_allowed_hostname(self): response = self.app.post( url_for('create_shortlink'), json={"url": "https://non-allowed.hostname.ch/test"}, - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 400) self.assertCors(response, ['POST', 'OPTIONS']) @@ -135,7 +139,7 @@ def test_create_shortlink_non_allowed_hostname_containing_admin_address(self): response = self.app.post( url_for('create_shortlink'), json={"url": "https://map.geo.admin.ch.non-allowed.hostname.ch/test"}, - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 400) self.assertCors(response, ['POST', 'OPTIONS']) @@ -156,7 +160,7 @@ def test_create_shortlink_url_too_long(self): url_for('create_shortlink'), json={"url": url}, content_type="application/json", - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) 
self.assertEqual(response.status_code, 400) self.assertCors(response, ['POST', 'OPTIONS']) @@ -178,7 +182,7 @@ def test_redirect_shortlink_ok(self): for short_id, url in self.uuid_to_url_dict.items(): response = self.app.get(url_for('get_shortlink', shortlink_id=short_id)) self.assertEqual(response.status_code, 301) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True) self.assertIn('Cache-Control', response.headers) self.assertIn('max-age=', response.headers['Cache-Control']) self.assertEqual(response.content_type, "text/html; charset=utf-8") @@ -192,7 +196,7 @@ def test_redirect_shortlink_ok_with_query(self): headers={"Origin": "www.example.com"} ) self.assertEqual(response.status_code, 301) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True) self.assertIn('Cache-Control', response.headers) self.assertIn('max-age=', response.headers['Cache-Control']) self.assertEqual(response.content_type, "text/html; charset=utf-8") @@ -204,7 +208,7 @@ def test_shortlink_fetch_nok_invalid_redirect_parameter(self): url_for('get_shortlink', shortlink_id=short_id), query_string={'redirect': 'banana'}, content_type="text/html", - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) expected_json = { 'success': False, @@ -226,7 +230,7 @@ def test_shortlink_fetch_nok_invalid_redirect_parameter(self): def test_redirect_shortlink_url_not_found(self): response = self.app.get( url_for('get_shortlink', shortlink_id='nonexistent'), - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) expected_json = { 'success': False, @@ -235,7 +239,7 @@ def test_redirect_shortlink_url_not_found(self): } } self.assertEqual(response.status_code, 404) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") + 
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True) self.assertIn('Cache-Control', response.headers) self.assertIn('max-age=3600', response.headers['Cache-Control']) self.assertIn('application/json', response.content_type) @@ -246,7 +250,7 @@ def test_fetch_full_url_from_shortlink_ok(self): response = self.app.get( url_for('get_shortlink', shortlink_id=short_id), query_string={'redirect': 'false'}, - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 200) self.assertCors(response, ['GET', 'HEAD', 'OPTIONS']) @@ -262,7 +266,7 @@ def test_fetch_full_url_from_shortlink_ok_explicit_parameter(self): response = self.app.get( url_for('get_shortlink', shortlink_id=short_id), query_string={'redirect': 'false'}, - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 200) self.assertCors(response, ['GET', 'HEAD', 'OPTIONS']) @@ -277,7 +281,7 @@ def test_fetch_full_url_from_shortlink_url_not_found(self): response = self.app.get( url_for('get_shortlink', shortlink_id='nonexistent'), query_string={'redirect': 'false'}, - headers={"Origin": "map.geo.admin.ch"} + headers={"Origin": "https://map.geo.admin.ch"} ) self.assertEqual(response.status_code, 404) self.assertCors(response, ['GET', 'HEAD', 'OPTIONS']) @@ -325,12 +329,12 @@ def test_create_shortlink_origin_not_allowed(self, headers): ) @params( - {'Origin': 'map.geo.admin.ch'}, + {'Origin': 'https://map.geo.admin.ch'}, { - 'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' + 'Origin': 'https://map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' }, { - 'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' + 'Origin': 'https://s.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' }, { 'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site' @@ -389,19 +393,19 @@ def test_get_shortlink_redirect_origin_allowed(self, headers): headers=headers ) 
self.assertEqual(response.status_code, 301) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True) response = self.app.get(url_for('get_shortlink', shortlink_id=short_id), headers=headers) self.assertEqual(response.status_code, 301) - self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$") + self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True) @params( - {'Origin': 'map.geo.admin.ch'}, + {'Origin': 'https://map.geo.admin.ch'}, { - 'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' + 'Origin': 'https://map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site' }, { - 'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' + 'Origin': 'https://s.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin' }, { 'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site'