Skip to content

Commit

Permalink
Fix: allow optimizer to handle non unionables
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 25, 2023
1 parent eb6eaf5 commit 964b04c
Showing 8 changed files with 65 additions and 44 deletions.
4 changes: 4 additions & 0 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from dataclasses import dataclass, field

from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.qualify_columns import qualify_columns
@@ -71,6 +72,9 @@ def lineage(
optimized = optimize(expression, schema=schema, rules=rules)
scope = build_scope(optimized)

if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")

def to_node(
column_name: str,
scope: Scope,
39 changes: 20 additions & 19 deletions sqlglot/optimizer/eliminate_ctes.py
Original file line number Diff line number Diff line change
@@ -19,24 +19,25 @@ def eliminate_ctes(expression):
"""
root = build_scope(expression)

ref_count = root.ref_count()

# Traverse the scope tree in reverse so we can remove chains of unused CTEs
for scope in reversed(list(root.traverse())):
if scope.is_cte:
count = ref_count[id(scope)]
if count <= 0:
cte_node = scope.expression.parent
with_node = cte_node.parent
cte_node.pop()

# Pop the entire WITH clause if this is the last CTE
if len(with_node.expressions) <= 0:
with_node.pop()

# Decrement the ref count for all sources this CTE selects from
for _, source in scope.selected_sources.values():
if isinstance(source, Scope):
ref_count[id(source)] -= 1
if root:
ref_count = root.ref_count()

# Traverse the scope tree in reverse so we can remove chains of unused CTEs
for scope in reversed(list(root.traverse())):
if scope.is_cte:
count = ref_count[id(scope)]
if count <= 0:
cte_node = scope.expression.parent
with_node = cte_node.parent
cte_node.pop()

# Pop the entire WITH clause if this is the last CTE
if len(with_node.expressions) <= 0:
with_node.pop()

# Decrement the ref count for all sources this CTE selects from
for _, source in scope.selected_sources.values():
if isinstance(source, Scope):
ref_count[id(source)] -= 1

return expression
3 changes: 3 additions & 0 deletions sqlglot/optimizer/eliminate_subqueries.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,9 @@ def eliminate_subqueries(expression):

root = build_scope(expression)

if not root:
return expression

# Map of alias->Scope|Table
# These are all aliases that are already used in the expression.
# We don't want to create new CTEs that conflict with these names.
42 changes: 22 additions & 20 deletions sqlglot/optimizer/pushdown_predicates.py
Original file line number Diff line number Diff line change
@@ -21,26 +21,28 @@ def pushdown_predicates(expression):
sqlglot.Expression: optimized expression
"""
root = build_scope(expression)
scope_ref_count = root.ref_count()

for scope in reversed(list(root.traverse())):
select = scope.expression
where = select.args.get("where")
if where:
selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)

# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)

if root:
scope_ref_count = root.ref_count()

for scope in reversed(list(root.traverse())):
select = scope.expression
where = select.args.get("where")
if where:
selected_sources = scope.selected_sources
# a right join can only push down to itself and not the source FROM table
for k, (node, source) in selected_sources.items():
parent = node.find_ancestor(exp.Join, exp.From)
if isinstance(parent, exp.Join) and parent.side == "RIGHT":
selected_sources = {k: (node, source)}
break
pushdown(where.this, selected_sources, scope_ref_count)

# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
name = join.this.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)

return expression

12 changes: 9 additions & 3 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import typing as t
from collections import defaultdict
from enum import Enum, auto

@@ -477,7 +478,7 @@ def ref_count(self):
return scope_ref_count


def traverse_scope(expression):
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
"""
Traverse an expression by its "scopes".
@@ -502,10 +503,12 @@ def traverse_scope(expression):
Returns:
list[Scope]: scope instances
"""
if not isinstance(expression, exp.Unionable):
return []
return list(_traverse_scope(Scope(expression)))


def build_scope(expression):
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
"""
Build a scope tree.
@@ -514,7 +517,10 @@ def build_scope(expression):
Returns:
Scope: root scope
"""
return traverse_scope(expression)[-1]
scopes = traverse_scope(expression)
if scopes:
return scopes[-1]
return None


def _traverse_scope(scope):
5 changes: 4 additions & 1 deletion sqlglot/transforms.py
Original file line number Diff line number Diff line change
@@ -161,7 +161,10 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import build_scope

taken_select_names = set(expression.named_selects)
taken_source_names = set(build_scope(expression).selected_sources)
scope = build_scope(expression)
if not scope:
return expression
taken_source_names = set(scope.selected_sources)

for select in expression.selects:
to_replace = select
2 changes: 1 addition & 1 deletion tests/fixtures/optimizer/qualify_columns__invalid.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
SELECT z.a FROM x;
SELECT z.* FROM x;
SELECT x FROM x;
INSERT INTO x VALUES (1, 2);
SELECT x FROM VALUES (1, 2);
SELECT a FROM x AS z JOIN y AS z;
SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c);
SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a;
2 changes: 2 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
@@ -141,6 +141,8 @@ def check_file(self, file, func, pretty=False, execute=False, **kwargs):
assert_frame_equal(df1, df2)

def test_optimize(self):
self.assertEqual(optimizer.optimize("x = 1 + 1", identify=None).sql(), "x = 2")

schema = {
"x": {"a": "INT", "b": "INT"},
"y": {"b": "INT", "c": "INT"},

0 comments on commit 964b04c

Please sign in to comment.