Skip to content

Commit

Permalink
replace sqlparse.format with sqlglot.transpile
Browse files Browse the repository at this point in the history
  • Loading branch information
dvchristianbors committed Oct 7, 2024
1 parent 62dfeb1 commit 8b1a28e
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 59 deletions.
2 changes: 1 addition & 1 deletion superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def validate(
# performance for WHERE ... IN (...) clauses
# Clauses are anyway checked for their validity in
# e.g., connectors/sqla/models/get_query_str_extended
# self._sanitize_filters()
self._sanitize_filters()
return None
except QueryObjectValidationError as ex:
if raise_exceptions:
Expand Down
2 changes: 1 addition & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,7 +1900,7 @@ class ExtraDynamicQueryFilters(TypedDict, total=False):
elif importlib.util.find_spec("superset_config") and not is_test():
try:
# pylint: disable=import-error,wildcard-import,unused-wildcard-import
import superset_config
import superset_config as superset_config
from superset_config import * # noqa: F403, F401

click.secho(
Expand Down
7 changes: 4 additions & 3 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

import pandas as pd
import requests
import sqlparse
import sqlglot
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
from deprecation import deprecated
Expand Down Expand Up @@ -1236,7 +1236,7 @@ def get_cte_query(cls, sql: str) -> str | None:
"""
if not cls.allows_cte_in_subquery:
stmt = sqlparse.parse(sql)[0]
stmt = sqlglot.tokenize(sql)

# The first meaningful token for CTE will be with WITH
idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
Expand Down Expand Up @@ -2158,7 +2158,8 @@ def cancel_query( # pylint: disable=unused-argument

@classmethod
def parse_sql(cls, sql: str) -> list[str]:
    """
    Split ``sql`` into a list of individual SQL statements.

    The statements are produced by ``sqlglot.transpile`` using its default
    dialect, so each returned string is sqlglot's regenerated form of the
    statement rather than a verbatim slice of the input.

    :param sql: one or more SQL statements
    :returns: one string per non-empty statement, in input order
    """
    # ``transpile`` emits an empty string for an empty statement (for example
    # when the input itself is empty). Those are dropped so callers never
    # receive a blank "statement" to execute.
    # NOTE(review): any sqlglot parsing error propagates to the caller
    # unchanged -- confirm that callers expect this.
    return [statement for statement in sqlglot.transpile(sql) if statement]

@classmethod
def get_impersonation_key(cls, user: User | None) -> Any:
Expand Down
7 changes: 1 addition & 6 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import ColumnElement, expression, Select
from sqlglot import parse

from superset import app, db_engine_specs, is_feature_enabled
from superset.commands.database.exceptions import DatabaseInvalidError
Expand Down Expand Up @@ -654,11 +653,7 @@ def get_df( # pylint: disable=too-many-locals
schema: str | None = None,
mutator: Callable[[pd.DataFrame], None] | None = None,
) -> pd.DataFrame:
# Previously the SQL was split into statements with sqlparse. However, this core
# code path is only reachable with single-statement queries, so the engine-spec parser is not used here.
# sqls = self.db_engine_spec.parse_sql(sql)
sqls = parse(sql)

sqls = self.db_engine_spec.parse_sql(sql)
with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
engine_url = engine.url

Expand Down
53 changes: 30 additions & 23 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
from collections.abc import Iterator
from typing import Any, cast, TYPE_CHECKING

import sqlglot
import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot.dialects.dialect import Dialects
from sqlglot.errors import ParseError
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Expand All @@ -42,7 +44,6 @@
Where,
)
from sqlparse.tokens import (
Comment,
CTE,
DDL,
DML,
Expand Down Expand Up @@ -257,6 +258,7 @@ def __init__(
sql_statement: str,
engine: str = "base",
):
sql_statement = sqlglot.transpile(sql_statement)
self.sql: str = sql_statement
self._engine = engine
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
Expand Down Expand Up @@ -579,30 +581,35 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:

def sanitize_clause(clause: str) -> str:
    """
    Validate a SQL filter clause and return it in a form that is safe to embed.

    The clause is parsed with sqlglot purely for validation. The text handed
    back is the caller's original clause, not sqlglot's regenerated SQL, with
    at most one trailing newline appended.

    :param clause: the SQL fragment to validate
    :returns: the validated clause, newline-terminated if its last line may
        contain a line comment
    :raises QueryClauseValidationException: if sqlglot cannot tokenize or
        parse the clause, or if the clause holds more than one statement
    """
    try:
        # Only the statement count is needed, so parse without generating SQL.
        statements = sqlglot.parse(clause)
    except sqlglot.errors.SqlglotError as ex:
        # SqlglotError is the common base of ParseError and TokenError, so
        # tokenizer failures are reported as validation errors as well.
        raise QueryClauseValidationException(str(ex)) from ex

    if len(statements) != 1:
        raise QueryClauseValidationException("Clause contains multiple statements")

    # A trailing ``-- comment`` with no newline after it would comment out
    # whatever SQL gets appended after this clause. The check is deliberately
    # conservative: it does not distinguish a real comment from ``--`` inside
    # a string literal, because an extra trailing newline is harmless.
    last_line = clause.rsplit("\n", 1)[-1]
    if "--" in last_line:
        clause = f"{clause}\n"

    return clause

Expand Down
40 changes: 15 additions & 25 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from unittest.mock import Mock

import pytest
import sqlglot
import sqlglot
import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy import text
Expand Down Expand Up @@ -817,9 +819,7 @@ def test_is_valid_ctas() -> None:
Test if a query is a valid CTAS.
A valid CTAS has a ``SELECT`` as its last statement.
"""
assert ParsedQuery("SELECT * FROM table").is_valid_ctas() is True

assert (
ParsedQuery(
"""
Expand All @@ -828,7 +828,6 @@ def test_is_valid_ctas() -> None:
-- comment 2
""",
).is_valid_ctas()
is True
)
assert (
Expand All @@ -841,7 +840,6 @@ def test_is_valid_ctas() -> None:
""",
).is_valid_ctas()
is True
)
assert (
ParsedQuery(
Expand All @@ -853,7 +851,6 @@ def test_is_valid_ctas() -> None:
).is_valid_ctas()
is False
)

assert (
ParsedQuery(
"""
Expand All @@ -864,7 +861,6 @@ def test_is_valid_ctas() -> None:
is False
)

def test_is_valid_cvas() -> None:
"""
Test if a query is a valid CVAS.
Expand All @@ -876,9 +872,7 @@ def test_is_valid_cvas() -> None:
assert (
ParsedQuery(
"""
-- comment
SELECT * FROM table
-- comment 2
assert ParsedQuery("SELECT * FROM table").is_valid_cvas() is True
""",
).is_valid_cvas()
is True
Expand All @@ -887,7 +881,6 @@ def test_is_valid_cvas() -> None:
assert (
ParsedQuery(
"""
-- comment
SET @value = 42;
SELECT @value as foo;
-- comment 2
Expand All @@ -900,7 +893,6 @@ def test_is_valid_cvas() -> None:
ParsedQuery(
"""
-- comment
EXPLAIN SELECT * FROM table
-- comment 2
""",
).is_valid_cvas()
Expand All @@ -912,7 +904,6 @@ def test_is_valid_cvas() -> None:
"""
SELECT * FROM table;
INSERT INTO TABLE (foo) VALUES (42);
""",
).is_valid_cvas()
is False
)
Expand All @@ -923,7 +914,6 @@ def test_is_select_cte_with_comments() -> None:
Some CTES with comments are not correctly identified as SELECTS.
"""
sql = ParsedQuery(
"""WITH blah AS
(SELECT * FROM core_dev.manager_team),

blah2 AS
Expand Down Expand Up @@ -1168,28 +1158,28 @@ def test_messy_breakdown_statements() -> None:
]
def test_sqlparse_formatting():
def test_sqlglot_formatting():
"""
Test that ``from_unixtime`` is formatted correctly.
"""
assert sqlparse.format(
assert sqlglot.transpile(
"SELECT extract(HOUR from from_unixtime(hour_ts) "
"AT TIME ZONE 'America/Los_Angeles') from table",
reindent=True,
) == (
"SELECT extract(HOUR\n from from_unixtime(hour_ts) "
"AT TIME ZONE 'America/Los_Angeles')\nfrom table"
)
pretty=True,
)[0] == (
"SELECT\n EXTRACT(HOUR FROM FROM_UNIXTIME(hour_ts) AT TIME ZONE 'America/Los_Angeles')"
"\nFROM table"
def test_sqlglot_formatting():
def test_strip_comments_from_sql() -> None:
"""
assert sqlglot.transpile(
Test that comments are stripped out correctly.
"""
assert (
strip_comments_from_sql("SELECT col1, col2 FROM table1")
== "SELECT col1, col2 FROM table1"
)
pretty=True,
)[0] == (
"SELECT\n EXTRACT(HOUR FROM FROM_UNIXTIME(hour_ts) AT TIME ZONE 'America/Los_Angeles')"
"\nFROM table"
assert (
strip_comments_from_sql("SELECT col1, col2 FROM table1\n-- comment")
== "SELECT col1, col2 FROM table1\n"
Expand Down

0 comments on commit 8b1a28e

Please sign in to comment.