Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve SQL parsing #26767

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ describe('AdhocMetrics', () => {
});

it('Clear metric and set simple adhoc metric', () => {
const metric = 'sum(num_girls)';
const metric = 'SUM(num_girls)';
const metricName = 'Sum Girls';
cy.get('[data-test=metrics]')
.find('[data-test="remove-control-button"]')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ describe('Visualization > Table', () => {
});
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*name/i,
querySubstring: /group by\n.*name/i,
chartSelector: 'table',
});
});
Expand Down Expand Up @@ -246,7 +246,7 @@ describe('Visualization > Table', () => {
cy.visitChartByParams(formData);
cy.verifySliceSuccess({
waitAlias: '@chartData',
querySubstring: /group by.*state/i,
querySubstring: /group by\n.*state/i,
chartSelector: 'table',
});
cy.get('td').contains(/\d*%/);
Expand Down
1 change: 1 addition & 0 deletions superset-frontend/src/SqlLab/actions/sqlLab.js
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,7 @@ export function formatQuery(queryEditor) {
const { sql } = getUpToDateQuery(getState(), queryEditor);
return SupersetClient.post({
endpoint: `/api/v1/sqllab/format_sql/`,
// TODO (betodealmeida): pass engine as a parameter for better formatting
body: JSON.stringify({ sql }),
headers: { 'Content-Type': 'application/json' },
}).then(({ json }) => {
Expand Down
53 changes: 3 additions & 50 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlparse
from flask import escape, Markup
from flask_appbuilder import Model
from flask_appbuilder.security.sqla.models import User
Expand Down Expand Up @@ -104,7 +103,6 @@
ExploreMixin,
ImportExportMixin,
QueryResult,
QueryStringExtended,
validate_adhoc_subquery,
)
from superset.models.slice import Slice
Expand Down Expand Up @@ -1099,7 +1097,9 @@ def _process_sql_expression(


class SqlaTable(
Model, BaseDatasource, ExploreMixin
Model,
BaseDatasource,
ExploreMixin,
): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""

Expand Down Expand Up @@ -1413,26 +1413,6 @@ def mutate_query_from_config(self, sql: str) -> str:
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(table=self, database=self.database, **kwargs)

def get_query_str_extended(
self,
query_obj: QueryObjectDict,
mutate: bool = True,
) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
if mutate:
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
applied_template_filters=sqlaq.applied_template_filters,
applied_filter_columns=sqlaq.applied_filter_columns,
rejected_filter_columns=sqlaq.rejected_filter_columns,
labels_expected=sqlaq.labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
)

def get_query_str(self, query_obj: QueryObjectDict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
Expand Down Expand Up @@ -1474,33 +1454,6 @@ def get_from_clause(

return from_clause, cte

def get_rendered_sql(
self, template_processor: BaseTemplateProcessor | None = None
) -> str:
"""
Render sql with template engine (Jinja).
"""

sql = self.sql
if template_processor:
try:
sql = template_processor.process_template(sql)
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error while rendering virtual dataset query: %(msg)s",
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
)
return sql

def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
Expand Down
4 changes: 2 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import ParsedQuery, Table
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
Expand Down Expand Up @@ -1448,7 +1448,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
qry = partition_query
sql = database.compile_sqla_query(qry)
if indent:
sql = sqlparse.format(sql, reindent=True)
sql = SQLScript(sql, engine=cls.engine).format()
return sql

@classmethod
Expand Down
7 changes: 4 additions & 3 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from re import Pattern
from typing import Any, TYPE_CHECKING

import sqlparse
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
Expand All @@ -37,6 +36,7 @@
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.sql_lab import Query
from superset.sql_parse import SQLScript
from superset.utils import core as utils
from superset.utils.core import GenericDataType

Expand Down Expand Up @@ -281,8 +281,9 @@ def get_default_schema_for_query(
This method simply uses the parent method after checking that there are no
malicious path setting in the query.
"""
sql = sqlparse.format(query.sql, strip_comments=True)
if re.search(r"set\s+search_path\s*=", sql, re.IGNORECASE):
script = SQLScript(query.sql, engine=cls.engine)
settings = script.get_settings()
if "search_path" in settings:
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
Expand Down
2 changes: 2 additions & 0 deletions superset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class SupersetErrorType(StrEnum):
RESULTS_BACKEND_ERROR = "RESULTS_BACKEND_ERROR"
ASYNC_WORKERS_ERROR = "ASYNC_WORKERS_ERROR"
ADHOC_SUBQUERY_NOT_ALLOWED_ERROR = "ADHOC_SUBQUERY_NOT_ALLOWED_ERROR"
INVALID_SQL_ERROR = "INVALID_SQL_ERROR"

# Generic errors
GENERIC_COMMAND_ERROR = "GENERIC_COMMAND_ERROR"
Expand Down Expand Up @@ -176,6 +177,7 @@ class SupersetErrorType(StrEnum):
SupersetErrorType.INVALID_PAYLOAD_SCHEMA_ERROR: [1020],
SupersetErrorType.INVALID_CTAS_QUERY_ERROR: [1023],
SupersetErrorType.INVALID_CVAS_QUERY_ERROR: [1024, 1025],
SupersetErrorType.INVALID_SQL_ERROR: [1003],
SupersetErrorType.SQLLAB_TIMEOUT_ERROR: [1026, 1027],
SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR: [1029],
SupersetErrorType.SYNTAX_ERROR: [1030],
Expand Down
17 changes: 17 additions & 0 deletions superset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,20 @@ def __init__(self, exc: ValidationError, payload: dict[str, Any]):
extra={"messages": exc.messages, "payload": payload},
)
super().__init__(error)


class SupersetParseError(SupersetErrorException):
"""
Exception to be raised when we fail to parse SQL.
"""

status = 422

def __init__(self, sql: str, engine: Optional[str] = None):
error = SupersetError(
message=_("The SQL is invalid and cannot be parsed."),
error_type=SupersetErrorType.INVALID_SQL_ERROR,
level=ErrorLevel.ERROR,
extra={"sql": sql, "engine": engine},
)
super().__init__(error)
27 changes: 20 additions & 7 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
ColumnNotFoundException,
QueryClauseValidationException,
QueryObjectValidationError,
SupersetParseError,
SupersetSecurityException,
)
from superset.extensions import feature_flag_manager
Expand All @@ -73,6 +74,8 @@
insert_rls_in_predicate,
ParsedQuery,
sanitize_clause,
SQLScript,
SQLStatement,
)
from superset.superset_typing import (
AdhocMetric,
Expand Down Expand Up @@ -901,12 +904,18 @@
return sql

def get_query_str_extended(
self, query_obj: QueryObjectDict, mutate: bool = True
self,
query_obj: QueryObjectDict,
mutate: bool = True,
) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
try:
sql = SQLStatement(sql, engine=self.db_engine_spec.engine).format()
except SupersetParseError:
logger.warning("Unable to parse SQL to format it, passing it as-is")

if mutate:
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
Expand Down Expand Up @@ -1054,7 +1063,8 @@
)

def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
self,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> str:
"""
Render sql with template engine (Jinja).
Expand All @@ -1071,13 +1081,16 @@
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:

script = SQLScript(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine)
if len(script.statements) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
)

sql = script.statements[0].format(comments=False)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))

Check warning on line 1093 in superset/models/helpers.py

View check run for this annotation

Codecov / codecov/patch

superset/models/helpers.py#L1093

Added line #L1093 was not covered by tests
return sql

def text(self, clause: str) -> TextClause:
Expand Down
Loading
Loading