Skip to content

Commit

Permalink
[fix] SQL parsing of table names (#7490)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley authored Jun 3, 2019
1 parent 78c1674 commit 45b41aa
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 15 deletions.
12 changes: 5 additions & 7 deletions superset/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,11 @@ def datasource_access_by_name(

def get_schema_and_table(self, table_in_query, schema):
table_name_pieces = table_in_query.split('.')
if len(table_name_pieces) == 2:
table_schema = table_name_pieces[0]
table_name = table_name_pieces[1]
else:
table_schema = schema
table_name = table_name_pieces[0]
return (table_schema, table_name)
if len(table_name_pieces) == 3:
return tuple(table_name_pieces[1:])
elif len(table_name_pieces) == 2:
return tuple(table_name_pieces)
return (schema, table_name_pieces[0])

def datasource_access_by_fullname(
self, database, table_in_query, schema):
Expand Down
39 changes: 32 additions & 7 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# under the License.
# pylint: disable=C,R,W
import logging
from typing import Optional

import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import Keyword, Name
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
ON_KEYWORD = 'ON'
Expand Down Expand Up @@ -75,11 +77,34 @@ def get_statements(self):
return statements

@staticmethod
def __get_full_name(tlist: TokenList):
if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
return '{}.{}'.format(tlist.tokens[0].value,
tlist.tokens[2].value)
return tlist.get_real_name()
def __get_full_name(tlist: TokenList) -> Optional[str]:
"""
Return the full unquoted table name if valid, i.e., conforms to the following
[[cluster.]schema.]table construct.
:param tlist: The SQL tokens
:returns: The valid full table name
"""

# Strip the alias if present.
idx = len(tlist.tokens)

if tlist.has_alias():
ws_idx, _ = tlist.token_next_by(t=Whitespace)

if ws_idx != -1:
idx = ws_idx

tokens = tlist.tokens[:idx]

if (
len(tokens) in (1, 3, 5) and
all(imt(token, t=[Name, String]) for token in tokens[0::2]) and
all(imt(token, m=(Punctuation, '.')) for token in tokens[1::2])
):
return '.'.join([remove_quotes(token.value) for token in tokens[0::2]])

return None

@staticmethod
def __is_identifier(token: Token):
Expand Down
37 changes: 36 additions & 1 deletion tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def test_simple_select(self):
query = 'SELECT * FROM tbname'
self.assertEquals({'tbname'}, self.extract_tables(query))

query = 'SELECT * FROM tbname foo'
self.assertEquals({'tbname'}, self.extract_tables(query))

query = 'SELECT * FROM tbname AS foo'
self.assertEquals({'tbname'}, self.extract_tables(query))

# underscores
query = 'SELECT * FROM tb_name'
self.assertEquals({'tb_name'},
Expand All @@ -47,11 +53,40 @@ def test_simple_select(self):
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname'))

# Ill-defined schema/table.
self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM "schemaname"."tbname"'))

self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname foo'))

self.assertEquals(
{'schemaname.tbname'},
self.extract_tables('SELECT * FROM schemaname.tbname AS foo'))

# cluster
self.assertEquals(
{'clustername.schemaname.tbname'},
self.extract_tables('SELECT * FROM clustername.schemaname.tbname'))

# Ill-defined cluster/schema/table.
self.assertEquals(
set(),
self.extract_tables('SELECT * FROM schemaname.'))

self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername.schemaname.'))

self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername..'))

self.assertEquals(
set(),
self.extract_tables('SELECT * FROM clustername..tbname'))

# quotes
query = 'SELECT field1, field2 FROM tb_name'
self.assertEquals({'tb_name'}, self.extract_tables(query))
Expand Down

0 comments on commit 45b41aa

Please sign in to comment.