Skip to content

Commit

Permalink
Refactor!: use a dictionary for query modifier search
Browse files Browse the repository at this point in the history
tobymao committed Jul 3, 2023

Verified

This commit was signed with the committer’s verified signature.
Keith-CY Chen Yu
1 parent fe69102 commit df4448d
Showing 5 changed files with 56 additions and 43 deletions.
9 changes: 5 additions & 4 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
@@ -109,10 +109,11 @@ class Parser(parser.Parser):

QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"settings": lambda self: self._parse_csv(self._parse_conjunction)
if self._match(TokenType.SETTINGS)
else None,
"format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None,
TokenType.SETTINGS: (
"settings",
lambda self: self._advance() or self._parse_csv(self._parse_conjunction),
),
TokenType.FORMAT: ("format", lambda self: self._advance() or self._parse_id_var()),
}

def _parse_conjunction(self) -> t.Optional[exp.Expression]:
14 changes: 0 additions & 14 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
@@ -273,13 +273,6 @@ class Parser(parser.Parser):
),
}

QUERY_MODIFIER_PARSERS = {
**parser.Parser.QUERY_MODIFIER_PARSERS,
"cluster": lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
"distribute": lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
"sort": lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY),
}

def _parse_types(
self, check_func: bool = False, schema: bool = False
) -> t.Optional[exp.Expression]:
@@ -429,10 +422,3 @@ def datatype_sql(self, expression: exp.DataType) -> str:
expression = exp.DataType.build(expression.this)

return super().datatype_sql(expression)

def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
return super().after_having_modifiers(expression) + [
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]
8 changes: 7 additions & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
@@ -491,7 +491,10 @@ def sql(
return expression

if key:
return self.sql(expression.args.get(key))
value = expression.args.get(key)
if value:
return self.sql(value)
return ""

if self._cache is not None:
expression_id = hash(expression)
@@ -1600,6 +1603,9 @@ def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
]

def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
67 changes: 43 additions & 24 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
@@ -737,19 +737,29 @@ class Parser(metaclass=_Parser):
}

QUERY_MODIFIER_PARSERS = {
"joins": lambda self: list(iter(self._parse_join, None)),
"laterals": lambda self: list(iter(self._parse_lateral, None)),
"match": lambda self: self._parse_match_recognize(),
"where": lambda self: self._parse_where(),
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
"windows": lambda self: self._parse_window_clause(),
"order": lambda self: self._parse_order(),
"limit": lambda self: self._parse_limit(),
"offset": lambda self: self._parse_offset(),
"locks": lambda self: self._parse_locks(),
"sample": lambda self: self._parse_table_sample(as_modifier=True),
TokenType.MATCH_RECOGNIZE: ("match", lambda self: self._parse_match_recognize()),
TokenType.WHERE: ("where", lambda self: self._parse_where()),
TokenType.GROUP_BY: ("group", lambda self: self._parse_group()),
TokenType.HAVING: ("having", lambda self: self._parse_having()),
TokenType.QUALIFY: ("qualify", lambda self: self._parse_qualify()),
TokenType.WINDOW: ("windows", lambda self: self._parse_window_clause()),
TokenType.ORDER_BY: ("order", lambda self: self._parse_order()),
TokenType.LIMIT: ("limit", lambda self: self._parse_limit()),
TokenType.FETCH: ("limit", lambda self: self._parse_limit()),
TokenType.OFFSET: ("offset", lambda self: self._parse_offset()),
TokenType.FOR: ("locks", lambda self: self._parse_locks()),
TokenType.LOCK: ("locks", lambda self: self._parse_locks()),
TokenType.TABLE_SAMPLE: ("sample", lambda self: self._parse_table_sample(as_modifier=True)),
TokenType.USING: ("sample", lambda self: self._parse_table_sample(as_modifier=True)),
TokenType.CLUSTER_BY: (
"cluster",
lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY),
),
TokenType.DISTRIBUTE_BY: (
"distribute",
lambda self: self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY),
),
TokenType.SORT_BY: ("sort", lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY)),
}

SET_PARSERS = {
@@ -2037,15 +2047,24 @@ def _parse_query_modifiers(
self, this: t.Optional[exp.Expression]
) -> t.Optional[exp.Expression]:
if isinstance(this, self.MODIFIABLES):
for key, parser in self.QUERY_MODIFIER_PARSERS.items():
expression = parser(self)

if expression:
if key == "limit":
offset = expression.args.pop("offset", None)
if offset:
this.set("offset", exp.Offset(expression=offset))
this.set(key, expression)
for join in iter(self._parse_join, None):
this.append("joins", join)
for lateral in iter(self._parse_lateral, None):
this.append("laterals", lateral)

while True:
if self._match_set(self.QUERY_MODIFIER_PARSERS, advance=False):
key, parser = self.QUERY_MODIFIER_PARSERS[self._curr.token_type]
expression = parser(self)

if expression:
this.set(key, expression)
if key == "limit":
offset = expression.args.pop("offset", None)
if offset:
this.set("offset", exp.Offset(expression=offset))
continue
break
return this

def _parse_hint(self) -> t.Optional[exp.Hint]:
@@ -2508,8 +2527,8 @@ def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Table
kind=kind,
)

def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
return list(iter(self._parse_pivot, None))
def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
return list(iter(self._parse_pivot, None)) or None

# https://duckdb.org/docs/sql/statements/pivot
def _parse_simplified_pivot(self) -> exp.Pivot:
1 change: 1 addition & 0 deletions tests/fixtures/identity.sql
Original file line number Diff line number Diff line change
@@ -794,6 +794,7 @@ ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla
SELECT partition FROM a
SELECT end FROM a
SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1
SELECT * FROM x WHERE a GROUP BY a HAVING b SORT BY s ORDER BY c LIMIT d
SELECT LEFT.FOO FROM BLA AS LEFT
SELECT RIGHT.FOO FROM BLA AS RIGHT
SELECT LEFT FROM LEFT LEFT JOIN RIGHT RIGHT JOIN LEFT

0 comments on commit df4448d

Please sign in to comment.