Skip to content

Commit

Permalink
fix(sql_parse): Ensure table extraction handles Jinja templating
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Mar 12, 2024
1 parent b1adede commit 88ab840
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 15 deletions.
10 changes: 5 additions & 5 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import dateutil
from flask import current_app, has_request_context, request
from flask_babel import gettext as _
from jinja2 import DebugUndefined
from jinja2 import DebugUndefined, Environment
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.expression import bindparam
Expand Down Expand Up @@ -479,11 +479,11 @@ def __init__(
self._applied_filters = applied_filters
self._removed_filters = removed_filters
self._context: dict[str, Any] = {}
self._env = SandboxedEnvironment(undefined=DebugUndefined)
self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
self.set_context(**kwargs)

# custom filters
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
self.env.filters["where_in"] = WhereInMacro(database.get_dialect())

def set_context(self, **kwargs: Any) -> None:
self._context.update(kwargs)
Expand All @@ -496,7 +496,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
>>> process_template(sql)
"SELECT '2017-01-01T00:00:00'"
"""
template = self._env.from_string(sql)
template = self.env.from_string(sql)
kwargs.update(self._context)

context = validate_template_context(self.engine, kwargs)
Expand Down Expand Up @@ -643,7 +643,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
engine = "trino"

def process_template(self, sql: str, **kwargs: Any) -> str:
template = self._env.from_string(sql)
template = self.env.from_string(sql)
kwargs.update(self._context)

# Backwards compatibility if migrating from Presto.
Expand Down
61 changes: 52 additions & 9 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
from unittest.mock import Mock

import sqlparse
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
Expand Down Expand Up @@ -283,20 +285,61 @@ def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.
Due to Jinja templating a multiphase approach is necessary as the referenced SQL
statement is likely ill-defined (due to the presence of the Jinja macros) and
thus non-parsable by SQLGlot.
Firstly, we extract any tables referenced within the confines of specific Jinja
macros. Secondly, we replace these non-SQL Jinja calls with benign SQL to try
to ensure that the resulting SQL statements are parsable by SQLGlot.
Note: this uses sqlglot, since it's better at catching more edge cases.
"""

from superset.jinja_context import get_template_processor # pylint: disable=import-outside-toplevel

tables = set()
sql = self.stripped()

# Mock the required database as the processor signature is exposed publicly.
processor = get_template_processor(database=Mock(backend=self._dialect))
template = processor.env.parse(sql)

for node in template.find_all(nodes.Call):
if isinstance(node.node, nodes.Getattr) and node.node.attr in (
"latest_partition",
"latest_sub_partition",
):
# Extract the table referenced in the macro.
tables.add(
Table(
*[
remove_quotes(part)
for part in node.args[0].value.split(".")[::-1]
]
)
)

# Replace the potentially problematic macro with benign SQL.
node.__class__ = nodes.TemplateData
node.fields = nodes.TemplateData.fields
node.data = "NULL"

sql = processor.process_template(template)

try:
statements = parse(self.stripped(), dialect=self._dialect)
tables.update(
[
table
for statement in parse(sql, dialect=self._dialect)
for table in self._extract_tables_from_statement(statement)
if statement
]
)
except SqlglotError:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
return set()
logger.warning("Unable to parse SQL (%s): %s", self._dialect, sql)

return {
table
for statement in statements
for table in self._extract_tables_from_statement(statement)
if statement
}
return tables

def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
"""
Expand Down
2 changes: 1 addition & 1 deletion superset/sqllab/query_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _validate(
) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
# pylint: disable=protected-access
syntax_tree = sql_template_processor._env.parse(rendered_query)
syntax_tree = sql_template_processor.env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(syntax_tree)
if undefined_parameters:
self._raise_undefined_parameter_exception(
Expand Down
37 changes: 37 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,43 @@ def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
)


@pytest.mark.parametrize("engine", ["hive", "presto", "trino"])
@pytest.mark.parametrize(
    "macro",
    [
        "latest_partition('foo.bar')",
        "latest_sub_partition('foo.bar', baz='qux')",
    ],
)
@pytest.mark.parametrize(
    "sql,expected",
    [
        (
            "SELECT '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo")},
        ),
        (
            "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
        ),
    ],
)
def test_extract_tables_jinja(
    engine: str,
    macro: str,
    sql: str,
    expected: set[Table],
) -> None:
    """
    Test that tables referenced inside Jinja partition macros are extracted.
    """
    # The quadrupled braces in the fixtures collapse to ``{{ ... }}`` under
    # ``str.format``, yielding a Jinja expression that calls ``<engine>.<macro>``.
    rendered_sql = sql.format(engine=engine, macro=macro)
    assert extract_tables(rendered_sql, engine) == expected


def test_update() -> None:
"""
Test that ``UPDATE`` is not detected as ``SELECT``.
Expand Down

0 comments on commit 88ab840

Please sign in to comment.