def is_safe_url(url: str) -> bool:
    """Return whether ``url`` targets the host serving the current request.

    A URL is accepted only when both its scheme and its network location
    are identical to those of ``request.host_url``.

    :param url: The URL to validate.
    :returns: ``True`` when the scheme and netloc of ``url`` match the
        current request's host URL. ``False`` otherwise, including for
        empty input, input starting with ``///``, input whose first
        character is in a Unicode ``C*`` general category, and input that
        ``urlparse`` rejects with ``ValueError``.
    """
    # An empty value has no first character to inspect (``url[0]`` would
    # raise IndexError) and can never match the host, so reject it up front.
    if not url:
        return False
    if url.startswith("///"):
        return False
    # Reject URLs whose first character belongs to a Unicode "C" general
    # category (the category code returned by unicodedata starts with "C").
    # This check is pure string inspection, so it runs before anything that
    # needs the request context.
    if unicodedata.category(url[0])[0] == "C":
        return False
    try:
        ref_url = urlparse(request.host_url)
        test_url = urlparse(url)
    except ValueError:
        # urlparse raises ValueError for malformed input such as an
        # unbalanced IPv6 bracket in the netloc.
        return False
    return test_url.scheme == ref_url.scheme and test_url.netloc == ref_url.netloc
request +from flask import current_app, g, request from flask_appbuilder import expose from flask_appbuilder.api import rison from flask_appbuilder.security.decorators import has_access_api @@ -39,6 +39,8 @@ from superset.extensions import security_manager from superset.models.core import Database from superset.superset_typing import FlaskResponse +from superset.utils.core import DatasourceType +from superset.utils.urls import is_safe_url from superset.views.base import ( api, BaseSupersetView, @@ -74,8 +76,22 @@ def save(self) -> FlaskResponse: datasource_id = datasource_dict.get("id") datasource_type = datasource_dict.get("type") database_id = datasource_dict["database"].get("id") + default_endpoint = datasource_dict["default_endpoint"] + if ( + default_endpoint + and not is_safe_url(default_endpoint) + and current_app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] + ): + return json_error_response( + _( + "The submitted URL is not considered safe," + " only use URLs with the same domain as Superset." 
def test_save_default_endpoint_validation_unsafe(self):
    """Saving an external default endpoint succeeds when the check is off.

    Disables ``PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET``, posts a datasource
    whose ``default_endpoint`` points at an external host, and expects a
    200 response.
    """
    self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = False
    try:
        self.login(username="admin")
        tbl_id = self.get_table(name="birth_names").id

        datasource_post = get_datasource_post()
        datasource_post["id"] = tbl_id
        datasource_post["owners"] = [1]
        datasource_post["default_endpoint"] = "http://www.google.com"
        data = dict(data=json.dumps(datasource_post))
        resp = self.client.post("/datasource/save/", data=data)
        assert resp.status_code == 200
    finally:
        # Re-enable the flag even when the request or the assertion fails,
        # so the modified config cannot leak into other tests.
        self.app.config["PREVENT_UNSAFE_DEFAULT_URLS_ON_DATASET"] = True
@pytest.mark.parametrize(
    "url,is_safe",
    [
        ("http://localhost/", True),
        ("http://localhost/superset/1", True),
        ("https://localhost/", False),
        ("https://localhost/superset/1", False),
        ("localhost/superset/1", False),
        ("ftp://localhost/superset/1", False),
        ("http://external.com", False),
        ("https://external.com", False),
        ("external.com", False),
        ("///localhost", False),
        ("xpto://localhost:[3/1/", False),
    ],
)
def test_is_safe_url(url: str, is_safe: bool) -> None:
    """Check ``is_safe_url`` against each parametrized URL.

    Every case is evaluated inside a test request context for ``/`` so that
    ``is_safe_url`` has a request to compare the candidate URL against.
    """
    # Imported lazily, as in the other tests of this module's diff, so that
    # importing the test module does not require the application.
    from superset import app
    from superset.utils.urls import is_safe_url

    with app.test_request_context("/"):
        verdict = is_safe_url(url)

    assert verdict == is_safe