From 10015233029220a7bcf9364e6309946686938f83 Mon Sep 17 00:00:00 2001 From: Grace Date: Thu, 10 Sep 2020 20:35:05 -0700 Subject: [PATCH 1/7] feat: Stop pending queries when user close dashboard --- .../src/dashboard/components/Dashboard.jsx | 38 ++++++++++------ superset-frontend/src/featureFlags.ts | 1 + superset/config.py | 18 +++++--- superset/db_engine_specs/base.py | 4 ++ superset/views/core.py | 44 ++++++++++++++++++- 5 files changed, 85 insertions(+), 20 deletions(-) diff --git a/superset-frontend/src/dashboard/components/Dashboard.jsx b/superset-frontend/src/dashboard/components/Dashboard.jsx index 19833c47065c4..646b1bad9f696 100644 --- a/superset-frontend/src/dashboard/components/Dashboard.jsx +++ b/superset-frontend/src/dashboard/components/Dashboard.jsx @@ -34,12 +34,14 @@ import { LOG_ACTIONS_MOUNT_DASHBOARD, Logger, } from '../../logger/LogUtils'; +import { isFeatureEnabled, FeatureFlag } from '../../featureFlags'; import OmniContainer from '../../components/OmniContainer'; import { areObjectsEqual } from '../../reduxUtils'; import '../stylesheets/index.less'; import getLocationHash from '../util/getLocationHash'; import isDashboardEmpty from '../util/isDashboardEmpty'; +import isDashboardLoading from '../util/isDashboardLoading'; const propTypes = { actions: PropTypes.shape({ @@ -68,16 +70,7 @@ const defaultProps = { }; class Dashboard extends React.PureComponent { - // eslint-disable-next-line react/sort-comp - static onBeforeUnload(hasChanged) { - if (hasChanged) { - window.addEventListener('beforeunload', Dashboard.unload); - } else { - window.removeEventListener('beforeunload', Dashboard.unload); - } - } - - static unload() { + static showUnsavedMessage() { const message = t('You have unsaved changes.'); window.event.returnValue = message; // Gecko + IE return message; // Gecko + Webkit, Safari, Chrome etc. @@ -86,7 +79,11 @@ class Dashboard extends React.PureComponent { constructor(props) { super(props); this.appliedFilters = props.activeFilters || {}; + this.canStopPendingQueries = isFeatureEnabled( + FeatureFlag.STOP_DASHBOARD_PENDING_QUERIES, + ); + this.onStopPendingQueries = this.onStopPendingQueries.bind(this); this.onVisibilityChange = this.onVisibilityChange.bind(this); } @@ -143,7 +140,8 @@ class Dashboard extends React.PureComponent { } componentDidUpdate() { - const { hasUnsavedChanges, editMode } = this.props.dashboardState; + const { charts, dashboardState } = this.props; + const { hasUnsavedChanges, editMode } = dashboardState; const appliedFilters = this.appliedFilters; const { activeFilters } = this.props; @@ -153,9 +151,15 @@ class Dashboard extends React.PureComponent { } if (hasUnsavedChanges) { - Dashboard.onBeforeUnload(true); + window.addEventListener('beforeunload', Dashboard.showUnsavedMessage); + } else { + window.removeEventListener('beforeunload', Dashboard.showUnsavedMessage); + } + + if (this.canStopPendingQueries && isDashboardLoading(charts)) { + window.addEventListener('beforeunload', this.onStopPendingQueries); } else { - Dashboard.onBeforeUnload(false); + window.removeEventListener('beforeunload', this.onStopPendingQueries); } } @@ -163,6 +167,14 @@ class Dashboard extends React.PureComponent { window.removeEventListener('visibilitychange', this.onVisibilityChange); } + onStopPendingQueries() { + if (navigator && navigator.sendBeacon) { + navigator.sendBeacon( + `/superset/dashboard/${this.props.dashboardInfo.id}/stop/`, + ); + } + } + onVisibilityChange() { if (document.visibilityState === 'hidden') { // from visible to hidden diff --git a/superset-frontend/src/featureFlags.ts b/superset-frontend/src/featureFlags.ts index 35817c74bf72c..35a39ad6ef974 100644 --- a/superset-frontend/src/featureFlags.ts +++ b/superset-frontend/src/featureFlags.ts @@ -27,6 +27,7 @@ export enum FeatureFlag { SHARE_QUERIES_VIA_KV_STORE = 'SHARE_QUERIES_VIA_KV_STORE', SQLLAB_BACKEND_PERSISTENCE = 'SQLLAB_BACKEND_PERSISTENCE', THUMBNAILS = 'THUMBNAILS', + STOP_DASHBOARD_PENDING_QUERIES = 'STOP_DASHBOARD_PENDING_QUERIES', } export type FeatureFlagMap = { diff --git a/superset/config.py b/superset/config.py index 58a6b0d0cb30e..0949d665d5a5f 100644 --- a/superset/config.py +++ b/superset/config.py @@ -172,7 +172,11 @@ def _try_json_readsha( # pylint: disable=unused-argument WTF_CSRF_ENABLED = True # Add endpoints that need to be exempt from CSRF protection -WTF_CSRF_EXEMPT_LIST = ["superset.views.core.log", "superset.charts.api.data"] +WTF_CSRF_EXEMPT_LIST = [ + "superset.views.core.log", + "superset.charts.api.data", + "superset.views.core.stop_dashboard_queries", +] # Whether to run the web server in debug mode or not DEBUG = os.environ.get("FLASK_ENV") == "development" @@ -310,6 +314,8 @@ def _try_json_readsha( # pylint: disable=unused-argument "TAGGING_SYSTEM": False, "SQLLAB_BACKEND_PERSISTENCE": False, "SIP_34_DATABASE_UI": False, + # stop pending queries when user close/reload dashboard in browser + "STOP_DASHBOARD_PENDING_QUERIES": False, } # This is merely a default. @@ -628,11 +634,11 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # db configuration and a result of this function. # mypy doesn't catch that if case ensures list content being always str -ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[ - ["Database", "models.User"], List[str] -] = lambda database, user: [ - UPLOADED_CSV_HIVE_NAMESPACE -] if UPLOADED_CSV_HIVE_NAMESPACE else [] +ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[["Database", "models.User"], List[str]] = ( + lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE] + if UPLOADED_CSV_HIVE_NAMESPACE + else [] +) # Values that should be treated as nulls for the csv uploads. CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 502cd6e0673e8..991d203dcdb0a 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -206,6 +206,10 @@ def is_db_column_type_match( def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool: return False + @classmethod + def get_allow_stop_pending_queries(cls) -> bool: + return False + @classmethod def get_engine( cls, diff --git a/superset/views/core.py b/superset/views/core.py index 070d84f29e93e..021cf9d9e40a5 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1421,6 +1421,48 @@ def fave_slices( # pylint: disable=no-self-use payload.append(dash) return json_success(json.dumps(payload, default=utils.json_int_dttm_ser)) + @api + @has_access_api + @event_logger.log_this + @expose("/dashboard//stop/", methods=["POST"]) + def stop_dashboard_queries( # pylint: disable=no-self-use + self, dashboard_id: int + ) -> FlaskResponse: + if is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES"): + username = g.user.username if g.user else None + dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one() + slices = dashboard.slices + datasource_ids = set() + database_ids = set() + + # find databases for all charts in a given dashboard + # stop pending query is only available for certain database(s) + for slc in slices: + datasource_type = slc.datasource.type + datasource_id = slc.datasource.id + + if datasource_id and datasource_type: + ds_class = ConnectorRegistry.sources.get(datasource_type) + datasource = ( + db.session.query(ds_class).filter_by(id=datasource_id).one() + ) + if datasource and datasource_id not in datasource_ids: + datasource_ids.add(datasource_id) + database_id = datasource.database.id + + if database_id in database_ids: + continue + + database_ids.add(database_id) + mydb = db.session.query(models.Database).get(database_id) + if ( + mydb + and mydb.db_engine_spec.get_allow_stop_pending_queries() + ): + mydb.db_engine_spec.stop_queries(username, dashboard_id) + + return Response(status=200) + @event_logger.log_this @api @has_access_api @@ -1777,7 +1819,7 @@ def sync_druid_source(self) -> FlaskResponse: # pylint: disable=no-self-use @expose("/get_or_create_table/", methods=["POST"]) @event_logger.log_this def sqllab_table_viz(self) -> FlaskResponse: # pylint: disable=no-self-use - """ Gets or creates a table object with attributes passed to the API. + """Gets or creates a table object with attributes passed to the API. It expects the json with params: * datasourceName - e.g. table name, required From eeb9e333237a8548b40e780f24bcbb9b0c1aa1c7 Mon Sep 17 00:00:00 2001 From: Grace Date: Tue, 15 Sep 2020 13:43:44 -0700 Subject: [PATCH 2/7] fix comments and add test --- .../src/dashboard/components/Dashboard.jsx | 8 ++--- superset/db_engine_specs/base.py | 14 +++++--- superset/views/core.py | 36 ++++--------------- superset/views/utils.py | 28 +++++++++++++++ tests/utils_tests.py | 14 ++++++++ 5 files changed, 63 insertions(+), 37 deletions(-) diff --git a/superset-frontend/src/dashboard/components/Dashboard.jsx b/superset-frontend/src/dashboard/components/Dashboard.jsx index 646b1bad9f696..c2fba30bcd33a 100644 --- a/superset-frontend/src/dashboard/components/Dashboard.jsx +++ b/superset-frontend/src/dashboard/components/Dashboard.jsx @@ -83,7 +83,7 @@ class Dashboard extends React.PureComponent { FeatureFlag.STOP_DASHBOARD_PENDING_QUERIES, ); - this.onStopPendingQueries = this.onStopPendingQueries.bind(this); + this.stopPendingQueries = this.stopPendingQueries.bind(this); this.onVisibilityChange = this.onVisibilityChange.bind(this); } @@ -157,9 +157,9 @@ class Dashboard extends React.PureComponent { } if (this.canStopPendingQueries && isDashboardLoading(charts)) { - window.addEventListener('beforeunload', this.onStopPendingQueries); + window.addEventListener('beforeunload', this.stopPendingQueries); } else { - window.removeEventListener('beforeunload', this.onStopPendingQueries); + window.removeEventListener('beforeunload', this.stopPendingQueries); } } @@ -167,7 +167,7 @@ class Dashboard extends React.PureComponent { window.removeEventListener('visibilitychange', this.onVisibilityChange); } - onStopPendingQueries() { + stopPendingQueries() { if (navigator && navigator.sendBeacon) { navigator.sendBeacon( `/superset/dashboard/${this.props.dashboardInfo.id}/stop/`, diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 991d203dcdb0a..8f333f0718149 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -206,10 +206,6 @@ def is_db_column_type_match( def get_allow_cost_estimate(cls, version: Optional[str] = None) -> bool: return False - @classmethod - def get_allow_stop_pending_queries(cls) -> bool: - return False - @classmethod def get_engine( cls, @@ -1005,3 +1001,13 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: logger.error(ex) raise ex return extra + + @classmethod + def stop_queries(cls, username: str, dashboard_id: int) -> None: + """ + An empty function. The actual stop implementation depends on the engine + + :param: username: user sends out queries + :param dashboard_id: dashboard has charts that waiting for queries + """ + return None diff --git a/superset/views/core.py b/superset/views/core.py index 021cf9d9e40a5..b2d7de5526036 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -118,6 +118,7 @@ check_slice_perms, get_cta_schema_name, get_dashboard_extra_filters, + get_database_ids, get_datasource_info, get_form_data, get_viz, @@ -1429,37 +1430,14 @@ def stop_dashboard_queries( # pylint: disable=no-self-use self, dashboard_id: int ) -> FlaskResponse: if is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES"): - username = g.user.username if g.user else None - dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one() - slices = dashboard.slices - datasource_ids = set() - database_ids = set() + username = g.user.username + database_ids = get_database_ids(dashboard_id) - # find databases for all charts in a given dashboard # stop pending query is only available for certain database(s) - for slc in slices: - datasource_type = slc.datasource.type - datasource_id = slc.datasource.id - - if datasource_id and datasource_type: - ds_class = ConnectorRegistry.sources.get(datasource_type) - datasource = ( - db.session.query(ds_class).filter_by(id=datasource_id).one() - ) - if datasource and datasource_id not in datasource_ids: - datasource_ids.add(datasource_id) - database_id = datasource.database.id - - if database_id in database_ids: - continue - - database_ids.add(database_id) - mydb = db.session.query(models.Database).get(database_id) - if ( - mydb - and mydb.db_engine_spec.get_allow_stop_pending_queries() - ): - mydb.db_engine_spec.stop_queries(username, dashboard_id) + for dbid in database_ids: + mydb = db.session.query(models.Database).get(dbid) + if mydb: + mydb.db_engine_spec.stop_queries(username, dashboard_id) return Response(status=200) diff --git a/superset/views/utils.py b/superset/views/utils.py index eaecc5fe87031..f0509a3df43fd 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -213,6 +213,34 @@ def get_datasource_info( return datasource_id, datasource_type +def get_database_ids(dashboard_id: int,) -> List[int]: + """ + Find all database ids used by a given dashboard + + :param dashboard_id: The dashboard id + :returns: A list of database ids used by the given dashboard + """ + dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one() + slices = dashboard.slices + datasource_ids = set() + database_ids = set() + + for slc in slices: + datasource_type = slc.datasource.type + datasource_id = slc.datasource.id + + if datasource_id and datasource_type: + ds_class = ConnectorRegistry.sources.get(datasource_type) + datasource = db.session.query(ds_class).filter_by(id=datasource_id).one() + if datasource and datasource_id not in datasource_ids: + datasource_ids.add(datasource_id) + database = datasource.database + if database: + database_ids.add(database.id) + + return list(database_ids) + + def apply_display_max_row_limit( sql_results: Dict[str, Any], rows: Optional[int] = None ) -> Dict[str, Any]: diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 91d1ad39d5a55..4ca0cbe2fc96f 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -23,6 +23,7 @@ import json import os import re +from typing import List from unittest.mock import Mock, patch import numpy @@ -69,6 +70,7 @@ from superset.utils import schema from superset.views.utils import ( build_extra_filters, + get_database_ids, get_form_data, get_time_range_endpoints, ) @@ -1162,3 +1164,15 @@ def test_get_form_data_token(self): assert get_form_data_token({"token": "token_abcdefg1"}) == "token_abcdefg1" generated_token = get_form_data_token({}) assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None + + def test_get_database_ids(self) -> None: + world_health = db.session.query(Dashboard).filter_by(slug="world_health").one() + dash_id = world_health.id + database_ids = get_database_ids(dash_id) + assert len(database_ids) == 1 + + world_slice = ( + db.session.query(Slice).filter_by(slice_name="World's Population").one() + ) + database_id = world_slice.datasource.database.id + assert database_ids == [database_id] From 7e52c2db64d67b39ee2dd90f57f8c214c840e074 Mon Sep 17 00:00:00 2001 From: Grace Date: Wed, 16 Sep 2020 10:40:42 -0700 Subject: [PATCH 3/7] try add test for stop_dashboard_queries --- .../src/dashboard/components/Dashboard.jsx | 16 ++++++++-------- superset/views/utils.py | 2 +- tests/core_tests.py | 16 ++++++++++++++++ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/superset-frontend/src/dashboard/components/Dashboard.jsx b/superset-frontend/src/dashboard/components/Dashboard.jsx index c2fba30bcd33a..f0c4812cff251 100644 --- a/superset-frontend/src/dashboard/components/Dashboard.jsx +++ b/superset-frontend/src/dashboard/components/Dashboard.jsx @@ -167,14 +167,6 @@ class Dashboard extends React.PureComponent { window.removeEventListener('visibilitychange', this.onVisibilityChange); } - stopPendingQueries() { - if (navigator && navigator.sendBeacon) { - navigator.sendBeacon( - `/superset/dashboard/${this.props.dashboardInfo.id}/stop/`, - ); - } - } - onVisibilityChange() { if (document.visibilityState === 'hidden') { // from visible to hidden @@ -197,6 +189,14 @@ class Dashboard extends React.PureComponent { return Object.values(this.props.charts); } + stopPendingQueries() { + if (navigator && navigator.sendBeacon) { + navigator.sendBeacon( + `/superset/dashboard/${this.props.dashboardInfo.id}/stop/`, + ); + } + } + applyFilters() { const appliedFilters = this.appliedFilters; const { activeFilters } = this.props; diff --git a/superset/views/utils.py b/superset/views/utils.py index f0509a3df43fd..b8ba9adab1b49 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -213,7 +213,7 @@ def get_datasource_info( return datasource_id, datasource_type -def get_database_ids(dashboard_id: int,) -> List[int]: +def get_database_ids(dashboard_id: int) -> List[int]: """ Find all database ids used by a given dashboard diff --git a/tests/core_tests.py b/tests/core_tests.py index 0941852e2bcb9..5fdc38d1c25c6 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -1231,6 +1231,22 @@ def test_get_column_names_from_metric(self): "my_col" ] + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"STOP_DASHBOARD_PENDING_QUERIES": True}, + clear=True, + ) + def test_stop_dashboard_queries(self): + username = "admin" + self.login(username) + dashboard = self.get_dash_by_slug("births") + with mock.patch.object(BaseEngineSpec, "stop_queries") as mock_stop_queries: + resp = self.client.post(f"/superset/dashboard/{dashboard.id}/stop/") + + self.assertTrue(is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES")) + self.assertEqual(resp.status_code, 200) + mock_stop_queries.assert_called_once() + if __name__ == "__main__": unittest.main() From 0825159e3b7ac27657e70df626b38b3a87e585c6 Mon Sep 17 00:00:00 2001 From: Grace Date: Fri, 18 Sep 2020 10:20:21 -0700 Subject: [PATCH 4/7] add extra tests --- superset/views/core.py | 4 +-- superset/views/utils.py | 25 +++++++------ tests/core_tests.py | 2 +- tests/utils_tests.py | 77 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 87 insertions(+), 21 deletions(-) diff --git a/superset/views/core.py b/superset/views/core.py index 969e4279f0b0b..289f26cb16d50 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1425,7 +1425,7 @@ def fave_slices( # pylint: disable=no-self-use @api @has_access_api @event_logger.log_this - @expose("/dashboard//stop/", methods=["POST"]) + @expose("/dashboard//stop/", methods=["POST"]) def stop_dashboard_queries( # pylint: disable=no-self-use self, dashboard_id: int ) -> FlaskResponse: @@ -1437,7 +1437,7 @@ def stop_dashboard_queries( # pylint: disable=no-self-use for dbid in database_ids: mydb = db.session.query(models.Database).get(dbid) if mydb: - mydb.db_engine_spec.stop_queries(username, dashboard_id) + mydb.db_engine_spec.stop_queries(username, int(dashboard_id)) return Response(status=200) diff --git a/superset/views/utils.py b/superset/views/utils.py index b8ba9adab1b49..27b0332369eb0 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -222,21 +222,20 @@ def get_database_ids(dashboard_id: int) -> List[int]: """ dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one() slices = dashboard.slices - datasource_ids = set() - database_ids = set() + datasource_ids: Set[int] = set() + database_ids: Set[int] = set() for slc in slices: - datasource_type = slc.datasource.type - datasource_id = slc.datasource.id - - if datasource_id and datasource_type: - ds_class = ConnectorRegistry.sources.get(datasource_type) - datasource = db.session.query(ds_class).filter_by(id=datasource_id).one() - if datasource and datasource_id not in datasource_ids: - datasource_ids.add(datasource_id) - database = datasource.database - if database: - database_ids.add(database.id) + datasource = slc.datasource + if ( + datasource + and datasource.type == "table" + and datasource.id not in datasource_ids + ): + datasource_ids.add(datasource.id) + database = datasource.database + if database: + database_ids.add(database.id) return list(database_ids) diff --git a/tests/core_tests.py b/tests/core_tests.py index 4781506d1af93..d8fc09dfd8696 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -1251,7 +1251,7 @@ def test_stop_dashboard_queries(self): self.assertTrue(is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES")) self.assertEqual(resp.status_code, 200) - mock_stop_queries.assert_called_once() + mock_stop_queries.assert_called_once_with(username, dashboard.id) if __name__ == "__main__": diff --git a/tests/utils_tests.py b/tests/utils_tests.py index e3095b9ec6ed7..85227c2c96956 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -33,6 +33,7 @@ import tests.test_app from superset import app, db, security_manager +from superset.connectors.base.models import BaseDatasource from superset.exceptions import CertificateException, SupersetException from superset.models.core import Database, Log from superset.models.dashboard import Dashboard @@ -46,6 +47,7 @@ get_form_data_token, get_iterable, get_email_address_list, + get_example_database, get_or_create_db, get_since_until, get_stacktrace, @@ -1142,9 +1144,74 @@ def test_get_database_ids(self) -> None: dash_id = world_health.id database_ids = get_database_ids(dash_id) assert len(database_ids) == 1 + assert database_ids == [get_example_database().id] + + def test_get_database_ids_empty_dash(self) -> None: + # test dash with no slice + dashboard = Dashboard(dashboard_title="no slices", id=101, slices=[]) + with patch("superset.db.session.query") as mock_query: + mock_query.return_value.filter_by.return_value.one.return_value = dashboard + database_ids = get_database_ids(dashboard.id) + assert database_ids == [] + + def test_get_database_ids_multiple_databases(self) -> None: + # test dash with 2 databases + datasource_1 = Mock() + datasource_1.type = "table" + datasource_1.datasource_name = "table_datasource_1" + datasource_1.database = Mock() + + datasource_2 = Mock() + datasource_2.type = "table" + datasource_2.datasource_name = "table_datasource_2" + datasource_2.database = Mock() + + slices = [ + Slice( + datasource_id=datasource_1.id, + datasource_type=datasource_1.type, + datasource_name=datasource_1.datasource_name, + slice_name="slice_name_1", + ), + Slice( + datasource_id=datasource_2.id, + datasource_type=datasource_2.type, + datasource_name=datasource_2.datasource_name, + slice_name="slice_name_2", + ), + ] + dashboard = Dashboard(dashboard_title="with 2 slices", id=102, slices=slices) + with patch("superset.db.session.query") as mock_query: + mock_query.return_value.filter_by.return_value.one.return_value = dashboard + mock_query.return_value.filter_by.return_value.first.side_effect = [ + datasource_1, + datasource_2, + ] + database_ids = get_database_ids(dashboard.id) + self.assertCountEqual( + database_ids, [datasource_1.database.id, datasource_2.database.id] + ) - world_slice = ( - db.session.query(Slice).filter_by(slice_name="World's Population").one() - ) - database_id = world_slice.datasource.database.id - assert database_ids == [database_id] + def test_get_database_ids_druid(self) -> None: + druid_datasource = Mock() + druid_datasource.type = "druid" + druid_datasource.datasource_name = "druid_datasource_1" + druid_datasource.cluster = Mock() + + slices = [ + Slice( + datasource_id=druid_datasource.id, + datasource_type=druid_datasource.type, + datasource_name=druid_datasource.datasource_name, + slice_name="slice_name_1", + ), + ] + dashboard = Dashboard(dashboard_title="druid dash", id=103, slices=slices) + with patch("superset.db.session.query") as mock_query: + mock_query.return_value.filter_by.return_value.one.return_value = dashboard + mock_query.return_value.filter_by.return_value.first.return_value = ( + druid_datasource + ) + database_ids = get_database_ids(dashboard.id) + # druid slice has no database id + assert database_ids == [] From 50515586c731ab712f52d416ffece7e8f9281bb9 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Wed, 30 Sep 2020 15:35:49 +0300 Subject: [PATCH 5/7] add new endpoint --- .../src/dashboard/components/Dashboard.jsx | 3 +- superset/charts/api.py | 53 +++++++++++++++++-- superset/charts/schemas.py | 8 +++ superset/config.py | 16 +++++- superset/db_engine_specs/base.py | 10 ---- superset/views/core.py | 19 ------- tests/charts/api_tests.py | 23 ++++++++ tests/core_tests.py | 16 ------ 8 files changed, 98 insertions(+), 50 deletions(-) diff --git a/superset-frontend/src/dashboard/components/Dashboard.jsx b/superset-frontend/src/dashboard/components/Dashboard.jsx index 4721709538d6b..97114dda83c7b 100644 --- a/superset-frontend/src/dashboard/components/Dashboard.jsx +++ b/superset-frontend/src/dashboard/components/Dashboard.jsx @@ -192,7 +192,8 @@ class Dashboard extends React.PureComponent { stopPendingQueries() { if (navigator && navigator.sendBeacon) { navigator.sendBeacon( - `/superset/dashboard/${this.props.dashboardInfo.id}/stop/`, + '/api/v1/chart/data/stop/', + JSON.stringify({ dashboard_id: this.props.dashboardInfo.id }), ); } } diff --git a/superset/charts/api.py b/superset/charts/api.py index fc65275673db1..e4076c4e18b63 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -19,7 +19,7 @@ from typing import Any, Dict import simplejson -from flask import g, make_response, redirect, request, Response, url_for +from flask import current_app, g, make_response, redirect, request, Response, url_for from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import gettext as _, ngettext @@ -82,6 +82,7 @@ class ChartRestApi(BaseSupersetModelRestApi): RouteMethod.RELATED, "bulk_delete", # not using RouteMethod since locally defined "data", + "data_stop", "viz_types", } class_permission_name = "SliceModelView" @@ -184,6 +185,7 @@ def __init__(self) -> None: "screenshot", "cache_screenshot", } + super().__init__() @expose("/", methods=["POST"]) @@ -421,8 +423,6 @@ def bulk_delete(self, **kwargs: Any) -> Response: @expose("/data", methods=["POST"]) @event_logger.log_this - @protect() - @safe @statsd_metrics def data(self) -> Response: """ @@ -503,6 +503,53 @@ def data(self) -> Response: return response + @expose("/data/stop", methods=["POST"]) + @event_logger.log_this + @protect() + @safe + @statsd_metrics + def data_stop(self) -> Response: + """ + Takes a dashboard id and tries to cancel all associated chart data requests + issued by the user. + --- + post: + description: >- + Takes a dashboard id and tries to cancel all associated chart data requests + issued by the user + requestBody: + description: >- + The dashboard id. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ChartDataStopSchema" + responses: + 200: + description: Pending dashboard queries terminated + content: + application/json: + schema: + type: object + 400: + $ref: '#/components/responses/400' + 500: + $ref: '#/components/responses/500' + """ + if request.is_json: + json_body = request.json + dashboard_id = json_body.get("dashboard_id") + if not dashboard_id: + return self.response(400, message="dashboard_id missing in body") + hook = current_app.config["STOP_DASHBOARD_PENDING_QUERIES_HOOK"] + try: + hook(dashboard_id, g.user.username) + return self.response(200) + except Exception as ex: + return self.response(500, message=str(ex)) + return self.response(400, message="body missing") + @expose("//cache_screenshot/", methods=["GET"]) @protect() @rison(screenshot_query_schema) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 58012a820a879..05ae0885b5161 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -903,7 +903,15 @@ class ChartDataResponseSchema(Schema): ) +class ChartDataStopSchema(Schema): + dashboard_id = fields.Integer( + description="the dashboard for which to terminate any pending chart data requests", + required=True, + ) + + CHART_SCHEMAS = ( + ChartDataStopSchema, ChartDataQueryContextSchema, ChartDataResponseSchema, # TODO: These should optimally be included in the QueryContext schema as an `anyOf` diff --git a/superset/config.py b/superset/config.py index 7d6b64dd4454b..5ae35305fa5ba 100644 --- a/superset/config.py +++ b/superset/config.py @@ -173,7 +173,7 @@ def _try_json_readsha( # pylint: disable=unused-argument WTF_CSRF_EXEMPT_LIST = [ "superset.views.core.log", "superset.charts.api.data", - "superset.views.core.stop_dashboard_queries", + "superset.charts.api.data_stop", ] # Whether to run the web server in debug mode or not @@ -656,6 +656,20 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # language. This allows you to define custom logic to process macro template. CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {} +# A dictionary mapping database_ids to hooks that are called with username_id and +# dashboard_id if the `STOP_DASHBOARD_PENDING_QUERIES` feature flag is enabled. +# Example: +# def STOP_DASHBOARD_PENDING_QUERIES_HOOK( +# dashboard_id: int, +# username: str +# ) -> None: +# if datasource_id == 10: +# call_external_api(dashboard_id, username) +# return None +STOP_DASHBOARD_PENDING_QUERIES_HOOK: Callable[ + [int, str], None +] = lambda dashboard_id, username: None + # Roles that are controlled by the API / Superset and should not be changes # by humans. ROBOT_PERMISSION_ROLES = ["Public", "Gamma", "Alpha", "Admin", "sql_lab"] diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 4d3cc744a231c..8456bb7f7e651 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1025,13 +1025,3 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: logger.error(ex) raise ex return extra - - @classmethod - def stop_queries(cls, username: str, dashboard_id: int) -> None: - """ - An empty function. The actual stop implementation depends on the engine - - :param: username: user sends out queries - :param dashboard_id: dashboard has charts that waiting for queries - """ - return None diff --git a/superset/views/core.py b/superset/views/core.py index d11d95e6f58ec..c613632d45ee4 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1435,25 +1435,6 @@ def fave_slices( # pylint: disable=no-self-use payload.append(dash) return json_success(json.dumps(payload, default=utils.json_int_dttm_ser)) - @api - @has_access_api - @event_logger.log_this - @expose("/dashboard//stop/", methods=["POST"]) - def stop_dashboard_queries( # pylint: disable=no-self-use - self, dashboard_id: int - ) -> FlaskResponse: - if is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES"): - username = g.user.username - database_ids = get_database_ids(dashboard_id) - - # stop pending query is only available for certain database(s) - for dbid in database_ids: - mydb = db.session.query(models.Database).get(dbid) - if mydb: - mydb.db_engine_spec.stop_queries(username, int(dashboard_id)) - - return Response(status=200) - @event_logger.log_this @api @has_access_api diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 7127180fee2e1..304cb0b6ce03a 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -28,6 +28,7 @@ from superset.utils.core import get_example_database from tests.test_app import app +from superset import is_feature_enabled from superset.connectors.connector_registry import ConnectorRegistry from superset.extensions import db, security_manager from superset.models.dashboard import Dashboard @@ -951,3 +952,25 @@ def test_chart_data_jinja_filter_request(self): result = response_payload["result"][0]["query"] if get_example_database().backend != "presto": assert "('boy' = 'boy')" in result + + @mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", + {"STOP_DASHBOARD_PENDING_QUERIES": True}, + clear=True, + ) + def test_stop_dashboard_queries(self): + hook = app.config["STOP_DASHBOARD_PENDING_QUERIES_HOOK"] + mock_hook = mock.Mock() + app.config["STOP_DASHBOARD_PENDING_QUERIES_HOOK"] = mock_hook + + username = "admin" + self.login(username) + dashboard = self.get_dash_by_slug("births") + resp = self.client.post( + f"/api/v1/chart/data/stop", json={"dashboard_id": dashboard.id} + ) + + self.assertTrue(is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES")) + self.assertEqual(resp.status_code, 200) + mock_hook.assert_called_once_with(dashboard.id, username) + app.config["STOP_DASHBOARD_PENDING_QUERIES_HOOK"] = hook diff --git a/tests/core_tests.py b/tests/core_tests.py index d8fc09dfd8696..44889d72be8b6 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -1237,22 +1237,6 @@ def test_get_column_names_from_metric(self): "my_col" ] - @mock.patch.dict( - "superset.extensions.feature_flag_manager._feature_flags", - {"STOP_DASHBOARD_PENDING_QUERIES": True}, - clear=True, - ) - def test_stop_dashboard_queries(self): - username = "admin" - self.login(username) - dashboard = self.get_dash_by_slug("births") - with mock.patch.object(BaseEngineSpec, "stop_queries") as mock_stop_queries: - resp = self.client.post(f"/superset/dashboard/{dashboard.id}/stop/") - - self.assertTrue(is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES")) - self.assertEqual(resp.status_code, 200) - mock_stop_queries.assert_called_once_with(username, dashboard.id) - if __name__ == "__main__": unittest.main() From 1c302dffc5479dde6abec52f2a72a895dfd40d60 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 1 Oct 2020 10:11:07 +0300 Subject: [PATCH 6/7] remove get_dashboard_ids and tests --- superset/views/core.py | 1 - superset/views/utils.py | 27 -------------- tests/utils_tests.py | 81 ----------------------------------------- 3 files changed, 109 deletions(-) diff --git a/superset/views/core.py b/superset/views/core.py index c613632d45ee4..5e5ff051d9528 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -131,7 +131,6 @@ get_dashboard, get_dashboard_changedon_dt, get_dashboard_extra_filters, - get_database_ids, get_datasource_info, get_form_data, get_viz, diff --git a/superset/views/utils.py b/superset/views/utils.py index b0fbfd8cbe5c2..dc164944c2b57 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -221,33 +221,6 @@ def get_datasource_info( return datasource_id, datasource_type -def get_database_ids(dashboard_id: int) -> List[int]: - """ - Find all database ids used by a given dashboard - - :param dashboard_id: The dashboard id - :returns: A list of database ids used by the given dashboard - """ - dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one() - slices = dashboard.slices - datasource_ids: Set[int] = set() - database_ids: Set[int] = set() - - for slc in slices: - datasource = slc.datasource - if ( - datasource - and datasource.type == "table" - and datasource.id not in datasource_ids - ): - datasource_ids.add(datasource.id) - database = datasource.database - if database: - database_ids.add(database.id) - - return list(database_ids) - - def apply_display_max_row_limit( sql_results: Dict[str, Any], rows: Optional[int] = None ) -> Dict[str, Any]: diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 69aa9b8b6f82e..c63fcc119328d 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -23,7 +23,6 @@ import json import os import re -from typing import List from unittest.mock import Mock, patch import numpy @@ -33,7 +32,6 @@ import tests.test_app from superset import app, db, security_manager -from superset.connectors.base.models import BaseDatasource from superset.exceptions import CertificateException, SupersetException from superset.models.core import Database, Log from superset.models.dashboard import Dashboard @@ -47,7 +45,6 @@ get_form_data_token, get_iterable, get_email_address_list, - get_example_database, get_or_create_db, get_since_until, get_stacktrace, @@ -71,7 +68,6 @@ from superset.views.utils import ( build_extra_filters, get_dashboard_changedon_dt, - get_database_ids, get_form_data, get_time_range_endpoints, ) @@ -1150,80 +1146,3 @@ def test_get_dashboard_changedon_dt(self) -> None: assert get_dashboard_changedon_dt(self, slug) == max( dashboard_last_changedon, slices_last_changedon ).replace(microsecond=0) - - def test_get_database_ids(self) -> None: - world_health = db.session.query(Dashboard).filter_by(slug="world_health").one() - dash_id = world_health.id - database_ids = get_database_ids(dash_id) - assert len(database_ids) == 1 - assert database_ids == [get_example_database().id] - - def test_get_database_ids_empty_dash(self) -> None: - # test dash with no slice - dashboard = Dashboard(dashboard_title="no slices", id=101, slices=[]) - with patch("superset.db.session.query") as mock_query: - mock_query.return_value.filter_by.return_value.one.return_value = dashboard - database_ids = get_database_ids(dashboard.id) - assert database_ids == [] - - def test_get_database_ids_multiple_databases(self) -> None: - # test dash with 2 databases - datasource_1 = Mock() - datasource_1.type = "table" - datasource_1.datasource_name = "table_datasource_1" - datasource_1.database = Mock() - - datasource_2 = Mock() - datasource_2.type = "table" - datasource_2.datasource_name = "table_datasource_2" - datasource_2.database = Mock() - - slices = [ - Slice( - datasource_id=datasource_1.id, - datasource_type=datasource_1.type, - datasource_name=datasource_1.datasource_name, - slice_name="slice_name_1", - ), - Slice( - datasource_id=datasource_2.id, - datasource_type=datasource_2.type, - datasource_name=datasource_2.datasource_name, - slice_name="slice_name_2", - ), - ] - dashboard = Dashboard(dashboard_title="with 2 slices", id=102, slices=slices) - with patch("superset.db.session.query") as mock_query: - mock_query.return_value.filter_by.return_value.one.return_value = dashboard - mock_query.return_value.filter_by.return_value.first.side_effect = [ - datasource_1, - datasource_2, - ] - database_ids = get_database_ids(dashboard.id) - self.assertCountEqual( - database_ids, [datasource_1.database.id, datasource_2.database.id] - ) - - def test_get_database_ids_druid(self) -> None: - druid_datasource = Mock() - druid_datasource.type = "druid" - druid_datasource.datasource_name = "druid_datasource_1" - druid_datasource.cluster = Mock() - - slices = [ - Slice( - datasource_id=druid_datasource.id, - datasource_type=druid_datasource.type, - datasource_name=druid_datasource.datasource_name, - slice_name="slice_name_1", - ), - ] - dashboard = Dashboard(dashboard_title="druid dash", id=103, slices=slices) - with patch("superset.db.session.query") as mock_query: - mock_query.return_value.filter_by.return_value.one.return_value = dashboard - mock_query.return_value.filter_by.return_value.first.return_value = ( - druid_datasource - ) - database_ids = get_database_ids(dashboard.id) - # druid slice has no database id - assert database_ids == [] From 847b740cd8405a06117fa0e2d8f22d114c89f602 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 1 Oct 2020 10:25:00 +0300 Subject: [PATCH 7/7] address CI errors --- .../src/dashboard/components/Dashboard.jsx | 2 +- superset/charts/api.py | 10 ++++------ superset/charts/schemas.py | 4 +++- superset/config.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/superset-frontend/src/dashboard/components/Dashboard.jsx b/superset-frontend/src/dashboard/components/Dashboard.jsx index 97114dda83c7b..473dc4d7f856a 100644 --- a/superset-frontend/src/dashboard/components/Dashboard.jsx +++ b/superset-frontend/src/dashboard/components/Dashboard.jsx @@ -192,7 +192,7 @@ class Dashboard extends React.PureComponent { stopPendingQueries() { if (navigator && navigator.sendBeacon) { navigator.sendBeacon( - '/api/v1/chart/data/stop/', + '/api/v1/chart/data/stop', JSON.stringify({ dashboard_id: this.props.dashboardInfo.id }), ); } diff --git a/superset/charts/api.py b/superset/charts/api.py index e4076c4e18b63..e80d03fcf171c 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -185,7 +185,6 @@ def __init__(self) -> None: "screenshot", "cache_screenshot", } - super().__init__() @expose("/", methods=["POST"]) @@ -423,6 +422,8 @@ def bulk_delete(self, **kwargs: Any) -> Response: @expose("/data", methods=["POST"]) @event_logger.log_this + @protect() + @safe @statsd_metrics def data(self) -> Response: """ @@ -543,11 +544,8 @@ def data_stop(self) -> Response: if not dashboard_id: return self.response(400, message="dashboard_id missing in body") hook = current_app.config["STOP_DASHBOARD_PENDING_QUERIES_HOOK"] - try: - hook(dashboard_id, g.user.username) - return self.response(200) - except Exception as ex: - return self.response(500, message=str(ex)) + hook(dashboard_id, g.user.username) + return self.response(200) return self.response(400, message="body missing") @expose("//cache_screenshot/", methods=["GET"]) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 05ae0885b5161..da253777a8153 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -905,7 +905,9 @@ class ChartDataResponseSchema(Schema): class ChartDataStopSchema(Schema): dashboard_id = fields.Integer( - description="the dashboard for which to terminate any pending chart data requests", + description="the dashboard for which to terminate pending chart data requests. " + "Requires defining a hook for handling query cancellation requests " + "by setting `STOP_DASHBOARD_PENDING_QUERIES_HOOK`.", required=True, ) diff --git a/superset/config.py b/superset/config.py index 5ae35305fa5ba..8b704f43033a1 100644 --- a/superset/config.py +++ b/superset/config.py @@ -974,7 +974,7 @@ class CeleryConfig: # pylint: disable=too-few-public-methods elif importlib.util.find_spec("superset_config"): try: import superset_config # pylint: disable=import-error - from superset_config import * # pylint: disable=import-error,wildcard-import,unused-wildcard-import + from superset_config import * # type: ignore # pylint: disable=import-error,wildcard-import,unused-wildcard-import print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]") except Exception: