Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Stop pending queries when user close dashboard #10836

Closed
38 changes: 25 additions & 13 deletions superset-frontend/src/dashboard/components/Dashboard.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ import {
LOG_ACTIONS_MOUNT_DASHBOARD,
Logger,
} from '../../logger/LogUtils';
import { isFeatureEnabled, FeatureFlag } from '../../featureFlags';
import OmniContainer from '../../components/OmniContainer';
import { areObjectsEqual } from '../../reduxUtils';

import '../stylesheets/index.less';
import getLocationHash from '../util/getLocationHash';
import isDashboardEmpty from '../util/isDashboardEmpty';
import isDashboardLoading from '../util/isDashboardLoading';

const propTypes = {
actions: PropTypes.shape({
Expand Down Expand Up @@ -68,16 +70,7 @@ const defaultProps = {
};

class Dashboard extends React.PureComponent {
// eslint-disable-next-line react/sort-comp
static onBeforeUnload(hasChanged) {
if (hasChanged) {
window.addEventListener('beforeunload', Dashboard.unload);
} else {
window.removeEventListener('beforeunload', Dashboard.unload);
}
}

static unload() {
static showUnsavedMessage() {
const message = t('You have unsaved changes.');
window.event.returnValue = message; // Gecko + IE
return message; // Gecko + Webkit, Safari, Chrome etc.
Expand All @@ -86,7 +79,11 @@ class Dashboard extends React.PureComponent {
constructor(props) {
super(props);
this.appliedFilters = props.activeFilters || {};
this.canStopPendingQueries = isFeatureEnabled(
FeatureFlag.STOP_DASHBOARD_PENDING_QUERIES,
);

this.stopPendingQueries = this.stopPendingQueries.bind(this);
this.onVisibilityChange = this.onVisibilityChange.bind(this);
}

Expand Down Expand Up @@ -143,7 +140,8 @@ class Dashboard extends React.PureComponent {
}

componentDidUpdate() {
const { hasUnsavedChanges, editMode } = this.props.dashboardState;
const { charts, dashboardState } = this.props;
const { hasUnsavedChanges, editMode } = dashboardState;

const { appliedFilters } = this;
const { activeFilters } = this.props;
Expand All @@ -153,9 +151,15 @@ class Dashboard extends React.PureComponent {
}

if (hasUnsavedChanges) {
Dashboard.onBeforeUnload(true);
window.addEventListener('beforeunload', Dashboard.showUnsavedMessage);
} else {
window.removeEventListener('beforeunload', Dashboard.showUnsavedMessage);
}

if (this.canStopPendingQueries && isDashboardLoading(charts)) {
window.addEventListener('beforeunload', this.stopPendingQueries);
} else {
Dashboard.onBeforeUnload(false);
window.removeEventListener('beforeunload', this.stopPendingQueries);
}
}

Expand Down Expand Up @@ -185,6 +189,14 @@ class Dashboard extends React.PureComponent {
return Object.values(this.props.charts);
}

stopPendingQueries() {
if (navigator && navigator.sendBeacon) {
navigator.sendBeacon(
`/superset/dashboard/${this.props.dashboardInfo.id}/stop/`,
);
}
}

applyFilters() {
const { appliedFilters } = this;
const { activeFilters } = this.props;
Expand Down
1 change: 1 addition & 0 deletions superset-frontend/src/featureFlags.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export enum FeatureFlag {
SHARE_QUERIES_VIA_KV_STORE = 'SHARE_QUERIES_VIA_KV_STORE',
SQLLAB_BACKEND_PERSISTENCE = 'SQLLAB_BACKEND_PERSISTENCE',
THUMBNAILS = 'THUMBNAILS',
STOP_DASHBOARD_PENDING_QUERIES = 'STOP_DASHBOARD_PENDING_QUERIES',
SIP_34_SAVED_QUERIES_UI = 'SIP_34_SAVED_QUERIES_UI',
}

Expand Down
18 changes: 12 additions & 6 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ def _try_json_readsha( # pylint: disable=unused-argument
WTF_CSRF_ENABLED = True

# Add endpoints that need to be exempt from CSRF protection
WTF_CSRF_EXEMPT_LIST = ["superset.views.core.log", "superset.charts.api.data"]
WTF_CSRF_EXEMPT_LIST = [
"superset.views.core.log",
"superset.charts.api.data",
"superset.views.core.stop_dashboard_queries",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this dangerous? I assume this means that anyone who steals a session cookie could constantly kill all that user's queries?

it's probably fine for internal deployments, but I wonder if there's a way we could do this without removing CSRF protections

Copy link
Author

@graceguo-supercat graceguo-supercat Sep 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know of two APIs here that don't use a CSRF token: log and stop_dashboard_queries. Both of them use sendBeacon, which can't send a CSRF token. The advantage of sendBeacon over a regular POST is that it doesn't need to wait for a response, and it's pretty common for sendBeacon calls to go without a CSRF token.

A CSRF token is used to prevent a malicious site from executing a transaction on your behalf, like moving money from your bank account to mine :) I feel that killing other people's queries while their dashboard is still loading is not very dangerous.

As for superset.charts.api.data, I am not sure why it is in the exempt list, but that is not related to this PR.

Copy link
Member

@villebro villebro Sep 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason superset.charts.api.data is exempt is because it is only a POST due to having a large request payload that doesn't sit well with a GET. Therefore it's not really a state changing POST, but a simulated GET.

]

# Whether to run the web server in debug mode or not
DEBUG = os.environ.get("FLASK_ENV") == "development"
Expand Down Expand Up @@ -309,6 +313,8 @@ def _try_json_readsha( # pylint: disable=unused-argument
"SIP_38_VIZ_REARCHITECTURE": False,
"TAGGING_SYSTEM": False,
"SQLLAB_BACKEND_PERSISTENCE": False,
# stop pending queries when the user closes or reloads the dashboard in the browser
"STOP_DASHBOARD_PENDING_QUERIES": False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is behind a feature flag we can go even further and not register the endpoint at all. Take a look at: https://github.com/apache/incubator-superset/blob/master/superset/dashboards/api.py#L173

}

# This is merely a default.
Expand Down Expand Up @@ -627,11 +633,11 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
# db configuration and a result of this function.

# mypy doesn't catch that if case ensures list content being always str
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
["Database", "models.User"], List[str]
] = lambda database, user: [
UPLOADED_CSV_HIVE_NAMESPACE
] if UPLOADED_CSV_HIVE_NAMESPACE else []
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[["Database", "models.User"], List[str]] = (
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Black automatically reformatted this section; it is not related to my PR.

lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
if UPLOADED_CSV_HIVE_NAMESPACE
else []
)

# Values that should be treated as nulls for the csv uploads.
CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
Expand Down
10 changes: 10 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,3 +1025,13 @@ def get_extra_params(database: "Database") -> Dict[str, Any]:
logger.error(ex)
raise ex
return extra

@classmethod
def stop_queries(cls, username: str, dashboard_id: int) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we name this stop_dashboard_queries since it accepts dashboard_id as a required parameter? Also, would user_id be better suited for this API than username?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Airbnb's internal API accepts username and dashboard_id :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious whether there is interest in having this in Superset as well, e.g. using something like:
CALL system.runtime.kill_query('20151207_215727_00146_tx3nr');
prestodb/presto#1515

However, a challenge here would be tracking all queries and their ids.

Copy link
Author

@graceguo-supercat graceguo-supercat Sep 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Dashboards run in synchronous mode, so no query id is passed from the query engine to the dashboard. SQL Lab, by contrast, runs in asynchronous mode: the query id is saved into the database, and a Celery worker updates the query status.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you find another solution, i am happy to learn :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to introduce a hook here that can either link a database or a db engine spec to a function that handles the query termination? Something similar to what JINJA_CONTEXT_ADDONS or CUSTOM_TEMPLATE_PROCESSORS does in config.py.

"""
An empty function. The actual stop implementation depends on the engine

:param username: the user who sent out the queries
:param dashboard_id: the dashboard whose charts are waiting for queries
"""
return None
22 changes: 21 additions & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
check_slice_perms,
get_cta_schema_name,
get_dashboard_extra_filters,
get_database_ids,
get_datasource_info,
get_form_data,
get_viz,
Expand Down Expand Up @@ -1421,6 +1422,25 @@ def fave_slices( # pylint: disable=no-self-use
payload.append(dash)
return json_success(json.dumps(payload, default=utils.json_int_dttm_ser))

@api
@has_access_api
@event_logger.log_this
@expose("/dashboard/<int:dashboard_id>/stop/", methods=["POST"])
def stop_dashboard_queries( # pylint: disable=no-self-use
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be useful to sync with @dpgaspar on using api v1.
Also please add some tests, we have presto on CI

Copy link
Author

@graceguo-supercat graceguo-supercat Sep 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in open source code base, stop_queries function is empty, nothing to test. airbnb and other companies can add their internal implementation to stop queries by dashboard_id, but this is not a standard Presto API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be nice to have stop_queries implementation in the test config.
It will help in 2 ways:

  1. be a safeguard from open source contributions not to break it
  2. will work as an example for the companies willing to try it out

Copy link
Author

@graceguo-supercat graceguo-supercat Sep 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bkyryliuk do you have an existing example of what the implementation in the test config would look like? Thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think as I mentioned above, core apis are deprecated and new ones should be added here: https://github.com/apache/incubator-superset/blob/master/superset/dashboards/api.py @dpgaspar & @villebro would probably know more on this topic

self, dashboard_id: int
) -> FlaskResponse:
if is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES"):
username = g.user.username
database_ids = get_database_ids(dashboard_id)

# stop pending query is only available for certain database(s)
for dbid in database_ids:
mydb = db.session.query(models.Database).get(dbid)
if mydb:
mydb.db_engine_spec.stop_queries(username, int(dashboard_id))

return Response(status=200)

@event_logger.log_this
@api
@has_access_api
Expand Down Expand Up @@ -1778,7 +1798,7 @@ def sync_druid_source(self) -> FlaskResponse: # pylint: disable=no-self-use
@expose("/get_or_create_table/", methods=["POST"])
@event_logger.log_this
def sqllab_table_viz(self) -> FlaskResponse: # pylint: disable=no-self-use
""" Gets or creates a table object with attributes passed to the API.
"""Gets or creates a table object with attributes passed to the API.

It expects the json with params:
* datasourceName - e.g. table name, required
Expand Down
27 changes: 27 additions & 0 deletions superset/views/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,33 @@ def get_datasource_info(
return datasource_id, datasource_type


def get_database_ids(dashboard_id: int) -> List[int]:
"""
Find all database ids used by a given dashboard

:param dashboard_id: The dashboard id
:returns: A list of database ids used by the given dashboard
"""
dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one()
slices = dashboard.slices
datasource_ids: Set[int] = set()
database_ids: Set[int] = set()

for slc in slices:
datasource = slc.datasource
if (
datasource
and datasource.type == "table"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only return database id for table type datasources.

and datasource.id not in datasource_ids
):
datasource_ids.add(datasource.id)
database = datasource.database
if database:
database_ids.add(database.id)

return list(database_ids)


def apply_display_max_row_limit(
sql_results: Dict[str, Any], rows: Optional[int] = None
) -> Dict[str, Any]:
Expand Down
16 changes: 16 additions & 0 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,22 @@ def test_get_column_names_from_metric(self):
"my_col"
]

@mock.patch.dict(
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a test for stop_dashboard_queries. @bkyryliuk Do you think it makes sense?

"superset.extensions.feature_flag_manager._feature_flags",
{"STOP_DASHBOARD_PENDING_QUERIES": True},
clear=True,
)
def test_stop_dashboard_queries(self):
username = "admin"
self.login(username)
dashboard = self.get_dash_by_slug("births")
with mock.patch.object(BaseEngineSpec, "stop_queries") as mock_stop_queries:
resp = self.client.post(f"/superset/dashboard/{dashboard.id}/stop/")

self.assertTrue(is_feature_enabled("STOP_DASHBOARD_PENDING_QUERIES"))
self.assertEqual(resp.status_code, 200)
mock_stop_queries.assert_called_once_with(username, dashboard.id)


if __name__ == "__main__":
unittest.main()
81 changes: 81 additions & 0 deletions tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import json
import os
import re
from typing import List
from unittest.mock import Mock, patch

import numpy
Expand All @@ -32,6 +33,7 @@

import tests.test_app
from superset import app, db, security_manager
from superset.connectors.base.models import BaseDatasource
from superset.exceptions import CertificateException, SupersetException
from superset.models.core import Database, Log
from superset.models.dashboard import Dashboard
Expand All @@ -45,6 +47,7 @@
get_form_data_token,
get_iterable,
get_email_address_list,
get_example_database,
get_or_create_db,
get_since_until,
get_stacktrace,
Expand All @@ -67,6 +70,7 @@
from superset.utils import schema
from superset.views.utils import (
build_extra_filters,
get_database_ids,
get_form_data,
get_time_range_endpoints,
)
Expand Down Expand Up @@ -1134,3 +1138,80 @@ def test_get_form_data_token(self):
assert get_form_data_token({"token": "token_abcdefg1"}) == "token_abcdefg1"
generated_token = get_form_data_token({})
assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None

def test_get_database_ids(self) -> None:
world_health = db.session.query(Dashboard).filter_by(slug="world_health").one()
Copy link
Member

@bkyryliuk bkyryliuk Sep 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's test some edge cases here:

  1. there is no database behind it e.g. markdown only
  2. there are no slices in the dashboard
  3. there are 2 databases used for the dashboard
  4. druid use case

Copy link
Author

@graceguo-supercat graceguo-supercat Sep 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no markdown slice any more, 1 and 2 are same case: no slices.
Added test cases for no slice, 2 database and druid slice cases.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow nice, what about iframe chart?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The markup, iframe and separator visualization types are not real slices since they do not generate queries; I converted them into dashboard components and retired these viz types.
#10590
apache-superset/superset-ui#746

dash_id = world_health.id
database_ids = get_database_ids(dash_id)
assert len(database_ids) == 1
assert database_ids == [get_example_database().id]

def test_get_database_ids_empty_dash(self) -> None:
    """A dashboard that owns no slices resolves to an empty list of database ids."""
    empty_dash = Dashboard(dashboard_title="no slices", id=101, slices=[])
    with patch("superset.db.session.query") as query_mock:
        # Make the mocked query(...).filter_by(...).one() chain hand back
        # the in-memory dashboard.
        lookup = query_mock.return_value.filter_by.return_value
        lookup.one.return_value = empty_dash
        found_ids = get_database_ids(empty_dash.id)
        assert found_ids == []

def test_get_database_ids_multiple_databases(self) -> None:
    """Slices backed by two table datasources yield the ids of both databases."""
    first_source = Mock()
    first_source.type = "table"
    first_source.datasource_name = "table_datasource_1"
    first_source.database = Mock()

    second_source = Mock()
    second_source.type = "table"
    second_source.datasource_name = "table_datasource_2"
    second_source.database = Mock()

    # One slice per mocked datasource, in the same order as the datasources.
    chart_specs = [(first_source, "slice_name_1"), (second_source, "slice_name_2")]
    charts = [
        Slice(
            datasource_id=source.id,
            datasource_type=source.type,
            datasource_name=source.datasource_name,
            slice_name=label,
        )
        for source, label in chart_specs
    ]
    two_db_dash = Dashboard(dashboard_title="with 2 slices", id=102, slices=charts)
    expected_ids = [first_source.database.id, second_source.database.id]

    with patch("superset.db.session.query") as query_mock:
        lookup = query_mock.return_value.filter_by.return_value
        lookup.one.return_value = two_db_dash
        # NOTE(review): presumably consumed by each slice's datasource lookup,
        # one mocked datasource per call - confirm against Slice.datasource.
        lookup.first.side_effect = [first_source, second_source]
        found_ids = get_database_ids(two_db_dash.id)
        self.assertCountEqual(found_ids, expected_ids)

def test_get_database_ids_druid(self) -> None:
    """A dashboard whose only slice is druid-backed yields no database ids."""
    druid_source = Mock()
    druid_source.type = "druid"
    druid_source.datasource_name = "druid_datasource_1"
    druid_source.cluster = Mock()

    druid_slice = Slice(
        datasource_id=druid_source.id,
        datasource_type=druid_source.type,
        datasource_name=druid_source.datasource_name,
        slice_name="slice_name_1",
    )
    druid_dash = Dashboard(dashboard_title="druid dash", id=103, slices=[druid_slice])

    with patch("superset.db.session.query") as query_mock:
        lookup = query_mock.return_value.filter_by.return_value
        lookup.one.return_value = druid_dash
        lookup.first.return_value = druid_source
        found_ids = get_database_ids(druid_dash.id)
        # druid slice has no database id
        assert found_ids == []