Skip to content

Commit

Permalink
chore(sql): clean up invalid filter clause exception types (#17702)
Browse files Browse the repository at this point in the history
* chore(sql): clean up invalid filter clause exception types

* fix lint

* rename exception
  • Loading branch information
villebro authored and eschutho committed Dec 11, 2021
1 parent 4c00bd4 commit 8b0ab83
Show file tree
Hide file tree
Showing 6 changed files with 1,235 additions and 1 deletion.
16 changes: 16 additions & 0 deletions superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric, OrderBy
from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.sql_parse import validate_filter_clause
from superset.typing import Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
apply_max_row_limit,
Expand Down Expand Up @@ -277,6 +283,7 @@ def validate(
try:
self._validate_there_are_no_missing_series()
self._validate_no_have_duplicate_labels()
self._validate_filters()
return None
except QueryObjectValidationError as ex:
if raise_exceptions:
Expand All @@ -295,6 +302,15 @@ def _validate_no_have_duplicate_labels(self) -> None:
)
)

def _validate_filters(self) -> None:
for param in ("where", "having"):
clause = self.extras.get(param)
if clause:
try:
validate_filter_clause(clause)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex

def _validate_there_are_no_missing_series(self) -> None:
missing_series = [col for col in self.series_columns if col not in self.columns]
if missing_series:
Expand Down
4 changes: 4 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ class CacheLoadError(SupersetException):
status = 404


class QueryClauseValidationException(SupersetException):
status = 400


class DashboardImportException(SupersetException):
pass

Expand Down
22 changes: 22 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

from superset.exceptions import QueryClauseValidationException

RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
ON_KEYWORD = "ON"
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
Expand Down Expand Up @@ -345,3 +347,23 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
for i in statement.tokens:
str_res += str(i.value)
return str_res


def validate_filter_clause(clause: str) -> None:
if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
raise QueryClauseValidationException("Filter clause contains comment")

statements = sqlparse.parse(clause)
if len(statements) != 1:
raise QueryClauseValidationException("Filter clause contains multiple queries")
open_parens = 0

for token in statements[0]:
if token.value in (")", "("):
open_parens += 1 if token.value == "(" else -1
if open_parens < 0:
raise QueryClauseValidationException(
"Closing unclosed parenthesis in filter clause"
)
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
11 changes: 11 additions & 0 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@
from superset.exceptions import (
CacheLoadError,
NullValueException,
QueryClauseValidationException,
QueryObjectValidationError,
SpatialException,
SupersetSecurityException,
)
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.sql_parse import validate_filter_clause
from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache
Expand Down Expand Up @@ -354,6 +356,15 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals
self.from_dttm = from_dttm
self.to_dttm = to_dttm

# validate sql filters
for param in ("where", "having"):
clause = self.form_data.get(param)
if clause:
try:
validate_filter_clause(clause)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex

# extras are used to query elements specific to a datasource type
# for instance the extra where clause that applies only to Tables
extras = {
Expand Down
28 changes: 28 additions & 0 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,32 @@ def test_chart_data_incorrect_request(self):
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_with_invalid_where_parameter_closing_unclosed__400(self):
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = []
request_payload["queries"][0]["extras"][
"where"
] = "state = 'CA') OR (state = 'NY'"

rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")

assert rv.status_code == 400

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_with_invalid_having_parameter_closing_and_comment__400(self):
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["filters"] = []
request_payload["queries"][0]["extras"][
"having"
] = "COUNT(1) = 0) UNION ALL SELECT 'abc', 1--comment"

rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")

assert rv.status_code == 400

def test_chart_data_with_invalid_datasource(self):
"""
Chart data API: Test chart data query with invalid schema
Expand Down Expand Up @@ -2092,3 +2118,5 @@ def test_chart_data_virtual_table_with_colons(self):
assert "':asdf'" in result["query"]
assert "':xyz:qwerty'" in result["query"]
assert "':qwerty:'" in result["query"]


Loading

0 comments on commit 8b0ab83

Please sign in to comment.