Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[cherries] Picking a couple of SQL parsing cherries #140

Merged
merged 2 commits into from
Jul 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions superset/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,11 @@ def datasource_access_by_name(

def get_schema_and_table(self, table_in_query, schema):
table_name_pieces = table_in_query.split('.')
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema
table_name = table_name_pieces[0]
return (table_schema, table_name)
if len(table_name_pieces) == 3:
return tuple(table_name_pieces[1:])
elif len(table_name_pieces) == 2:
return tuple(table_name_pieces)
return (schema, table_name_pieces[0])

def datasource_access_by_fullname(
self, database, table_in_query, schema):
Expand Down
66 changes: 46 additions & 20 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# under the License.
# pylint: disable=C,R,W
import logging
from typing import Optional

import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword, Name
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
ON_KEYWORD = 'ON'
Expand Down Expand Up @@ -75,32 +77,55 @@ def get_statements(self):
return statements

@staticmethod
def __get_full_name(identifier):
if len(identifier.tokens) > 2 and identifier.tokens[1].value == '.':
return '{}.{}'.format(identifier.tokens[0].value,
identifier.tokens[2].value)
return identifier.get_real_name()
def __get_full_name(tlist: TokenList) -> Optional[str]:
"""
Return the full unquoted table name if valid, i.e., conforms to the following
[[cluster.]schema.]table construct.

:param tlist: The SQL tokens
:returns: The valid full table name
"""

# Strip the alias if present.
idx = len(tlist.tokens)

if tlist.has_alias():
ws_idx, _ = tlist.token_next_by(t=Whitespace)

if ws_idx != -1:
idx = ws_idx

tokens = tlist.tokens[:idx]

if (
len(tokens) in (1, 3, 5) and
all(imt(token, t=[Name, String]) for token in tokens[0::2]) and
all(imt(token, m=(Punctuation, '.')) for token in tokens[1::2])
):
return '.'.join([remove_quotes(token.value) for token in tokens[0::2]])

return None

@staticmethod
def __is_identifier(token):
def __is_identifier(token: Token):
return isinstance(token, (IdentifierList, Identifier))

def __process_identifier(self, identifier):
def __process_tokenlist(self, tlist: TokenList):
# exclude subselects
if '(' not in str(identifier):
table_name = self.__get_full_name(identifier)
if '(' not in str(tlist):
table_name = self.__get_full_name(tlist)
if table_name and not table_name.startswith(CTE_PREFIX):
self._table_names.add(table_name)
return

# store aliases
if hasattr(identifier, 'get_alias'):
self._alias_names.add(identifier.get_alias())
if hasattr(identifier, 'tokens'):
# some aliases are not parsed properly
if identifier.tokens[0].ttype == Name:
self._alias_names.add(identifier.tokens[0].value)
self.__extract_from_token(identifier)
if tlist.has_alias():
self._alias_names.add(tlist.get_alias())

# some aliases are not parsed properly
if tlist.tokens[0].ttype == Name:
self._alias_names.add(tlist.tokens[0].value)
self.__extract_from_token(tlist)

def as_create_table(self, table_name, overwrite=False):
"""Reformats the query into the create table as query.
Expand Down Expand Up @@ -144,10 +169,11 @@ def __extract_from_token(self, token, depth=0):

if table_name_preceding_token:
if isinstance(item, Identifier):
self.__process_identifier(item)
self.__process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token in item.get_identifiers():
self.__process_identifier(token)
if isinstance(token, TokenList):
self.__process_tokenlist(token)
elif isinstance(item, IdentifierList):
for token in item.tokens:
if not self.__is_identifier(token):
Expand Down
46 changes: 45 additions & 1 deletion tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def test_simple_select(self):
query = 'SELECT * FROM tbname'
self.assertEquals({'tbname'}, self.extract_tables(query))

query = 'SELECT * FROM tbname foo'
self.assertEquals({'tbname'}, self.extract_tables(query))

query = 'SELECT * FROM tbname AS foo'
self.assertEquals({'tbname'}, self.extract_tables(query))

# underscores
query = 'SELECT * FROM tb_name'
self.assertEquals({'tb_name'},
Expand All @@ -47,11 +53,40 @@ def test_simple_select(self):
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname'))

# Ill-defined schema/table.
self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM "schemaname"."tbname"'))

self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname foo'))

self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname AS foo'))

# cluster
self.assertEquals(
{'clustername.schemaname.tbname'},
self.extract_tables('SELECT * FROM clustername.schemaname.tbname'))

# Ill-defined cluster/schema/table.
self.assertEquals(
set(),
self.extract_tables('SELECT * FROM schemaname.'))

self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername.schemaname.'))

self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername..'))

self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername..tbname'))

# quotes
query = 'SELECT field1, field2 FROM tb_name'
self.assertEquals({'tb_name'}, self.extract_tables(query))
Expand Down Expand Up @@ -462,3 +497,12 @@ def test_messy_breakdown_statements(self):
'SELECT * FROM ab_user LIMIT 1',
]
self.assertEquals(statements, expected)

def test_identifier_list_with_keyword_as_alias(self):
query = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
self.assertEquals({'foo'}, self.extract_tables(query))