Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: make some SQL builders pure #1526

Merged
merged 11 commits into from
May 4, 2023
120 changes: 72 additions & 48 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def load(cls, obj):


class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
def and_(self, *expressions, dialect=None, copy=True, **opts):
"""
AND this condition with one or multiple expressions.

Expand All @@ -662,14 +662,15 @@ def and_(self, *expressions, dialect=None, **opts):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.

Returns:
And: the new condition.
"""
return and_(self, *expressions, dialect=dialect, **opts)
return and_(self, *expressions, dialect=dialect, copy=copy, **opts)

def or_(self, *expressions, dialect=None, **opts):
def or_(self, *expressions, dialect=None, copy=True, **opts):
"""
OR this condition with one or multiple expressions.

Expand All @@ -681,50 +682,59 @@ def or_(self, *expressions, dialect=None, **opts):
*expressions (str | Expression): the SQL code strings to parse.
If an `Expression` instance is passed, it will be used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy the involved expressions (only applies to Expressions).
opts (kwargs): other options to use to parse the input expressions.

Returns:
Or: the new condition.
"""
return or_(self, *expressions, dialect=dialect, **opts)
return or_(self, *expressions, dialect=dialect, copy=copy, **opts)

def not_(self):
def not_(self, copy=True):
"""
Wrap this condition with NOT.

Example:
>>> condition("x=1").not_().sql()
'NOT x = 1'

Args:
copy (bool): whether or not to copy this object.

Returns:
Not: the new condition.
"""
return not_(self)
return not_(self, copy=copy)

def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
this = self
other = convert(other)
this = self.copy()
other = convert(other, copy=True)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
if reverse:
return klass(this=other, expression=this)
return klass(this=this, expression=other)

def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
if isinstance(other, slice):
return Between(
this=self,
low=convert(other.start),
high=convert(other.stop),
)
return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])
def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]):
return Bracket(
this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)]
)

def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
def isin(
self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts
) -> In:
return In(
this=self,
expressions=[convert(e) for e in expressions],
query=maybe_parse(query, **opts) if query else None,
this=_maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
)

def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between:
return Between(
this=_maybe_copy(self, copy),
low=convert(low, copy=copy, **opts),
high=convert(high, copy=copy, **opts),
)

def like(self, other: ExpOrStr) -> Like:
Expand Down Expand Up @@ -809,10 +819,10 @@ def __ror__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other, reverse=True)

def __neg__(self) -> Neg:
return Neg(this=_wrap(self, Binary))
return Neg(this=_wrap(self.copy(), Binary))
georgesittas marked this conversation as resolved.
Show resolved Hide resolved

def __invert__(self) -> Not:
return not_(self)
return not_(self.copy())


class Predicate(Condition):
Expand Down Expand Up @@ -2611,7 +2621,7 @@ def join(
join.set("kind", kind.text)

if on:
on = and_(*ensure_collection(on), dialect=dialect, **opts)
on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts)
join.set("on", on)

if using:
Expand Down Expand Up @@ -3538,14 +3548,20 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}

def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts)))
return this
instance = _maybe_copy(self, copy)
instance.append(
"ifs",
If(
this=maybe_parse(condition, copy=copy, **opts),
true=maybe_parse(then, copy=copy, **opts),
),
)
return instance

def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
this = self.copy() if copy else self
this.set("default", maybe_parse(condition, **opts))
return this
instance = _maybe_copy(self, copy)
instance.set("default", maybe_parse(condition, copy=copy, **opts))
return instance


class Cast(Func):
Expand Down Expand Up @@ -4405,14 +4421,16 @@ def _apply_conjunction_builder(
if append and existing is not None:
expressions = [existing.this if into else existing] + list(expressions)

node = and_(*expressions, dialect=dialect, **opts)
node = and_(*expressions, dialect=dialect, copy=copy, **opts)

inst.set(arg, into(this=node) if into else node)
return inst


def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
def _combine(expressions, operator, dialect=None, copy=True, **opts):
expressions = [
condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
]
this = expressions[0]
if expressions[1:]:
this = _wrap(this, Connector)
Expand Down Expand Up @@ -4626,7 +4644,7 @@ def delete(
return delete_expr


def condition(expression, dialect=None, **opts) -> Condition:
def condition(expression, dialect=None, copy=True, **opts) -> Condition:
"""
Initialize a logical condition expression.

Expand All @@ -4645,6 +4663,7 @@ def condition(expression, dialect=None, **opts) -> Condition:
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression (in the case that the
input expression is a SQL string).
copy (bool): Whether or not to copy `expression` (only applies to expressions).
**opts: other options to use to parse the input expressions (again, in the case
that the input expression is a SQL string).

Expand All @@ -4655,11 +4674,12 @@ def condition(expression, dialect=None, **opts) -> Condition:
expression,
into=Condition,
dialect=dialect,
copy=copy,
**opts,
)


def and_(*expressions, dialect=None, **opts) -> And:
def and_(*expressions, dialect=None, copy=True, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.

Expand All @@ -4671,15 +4691,16 @@ def and_(*expressions, dialect=None, **opts) -> And:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.

Returns:
And: the new condition
"""
return _combine(expressions, And, dialect, **opts)
return _combine(expressions, And, dialect, copy=copy, **opts)


def or_(*expressions, dialect=None, **opts) -> Or:
def or_(*expressions, dialect=None, copy=True, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.

Expand All @@ -4691,15 +4712,16 @@ def or_(*expressions, dialect=None, **opts) -> Or:
*expressions (str | Expression): the SQL code strings to parse.
If an Expression instance is passed, this is used as-is.
dialect (str): the dialect used to parse the input expression.
copy (bool): whether or not to copy `expressions` (only applies to Expressions).
**opts: other options to use to parse the input expressions.

Returns:
Or: the new condition
"""
return _combine(expressions, Or, dialect, **opts)
return _combine(expressions, Or, dialect, copy=copy, **opts)


def not_(expression, dialect=None, **opts) -> Not:
def not_(expression, dialect=None, copy=True, **opts) -> Not:
"""
Wrap a condition with a NOT operator.

Expand All @@ -4719,13 +4741,14 @@ def not_(expression, dialect=None, **opts) -> Not:
this = condition(
expression,
dialect=dialect,
copy=copy,
**opts,
)
return Not(this=_wrap(this, Connector))


def paren(expression) -> Paren:
return Paren(this=expression)
def paren(expression, copy=True) -> Paren:
return Paren(this=_maybe_copy(expression, copy))


SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$")
Expand Down Expand Up @@ -5068,19 +5091,20 @@ def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
)


def convert(value) -> Expression:
def convert(value: t.Any, copy: bool = False) -> Expression:
"""Convert a python value into an expression object.

Raises an error if a conversion is not possible.

Args:
value (Any): a python object
value: A python object.
copy: Whether or not to copy `value` (only applies to Expressions and collections).

Returns:
Expression: the equivalent expression object
Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
return value
return _maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):
Expand All @@ -5098,13 +5122,13 @@ def convert(value) -> Expression:
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
return Tuple(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, list):
return Array(expressions=[convert(v) for v in value])
return Array(expressions=[convert(v, copy=copy) for v in value])
if isinstance(value, dict):
return Map(
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
keys=[convert(k, copy=copy) for k in value],
values=[convert(v, copy=copy) for v in value.values()],
)
raise ValueError(f"Cannot convert {value}")

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/eliminate_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def extract_condition(condition):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False)

for condition in on.flatten():
if isinstance(condition, exp.EQ):
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _expand_using(scope, resolver):
tables[join_table] = None

join.args.pop("using")
join.set("on", exp.and_(*conditions))
join.set("on", exp.and_(*conditions, copy=False))

if column_tables:
for column in scope.columns:
Expand Down
1 change: 1 addition & 0 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def rewrite_between(expression: exp.Expression) -> exp.Expression:
return exp.and_(
exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
copy=False,
)
return expression

Expand Down
2 changes: 2 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3584,7 +3584,9 @@ def _parse_decode(self) -> t.Optional[exp.Expression]:
exp.and_(
exp.Is(this=expression.copy(), expression=exp.Null()),
exp.Is(this=search.copy(), expression=exp.Null()),
copy=False,
),
copy=False,
)
ifs.append(exp.If(this=cond, true=result))

Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6385,6 +6385,14 @@ WITH "tmp1" AS (
"item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1')
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
)
AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
Expand Down Expand Up @@ -7589,6 +7597,14 @@ WITH "tmp1" AS (
"item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1')
OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
)
AND (
"item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')
OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants')
)
AND (
"item"."i_category" IN ('Books', 'Children', 'Electronics')
OR "item"."i_category" IN ('Women', 'Music', 'Men')
Expand Down
6 changes: 6 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
class TestBuild(unittest.TestCase):
def test_build(self):
x = condition("x")
x_plus_one = x + 1

# Make sure we're not mutating x by changing its parent to be x_plus_one
self.assertIsNone(x.parent)
self.assertNotEqual(id(x_plus_one.this), id(x))

for expression, sql, *dialect in [
(lambda: x + 1, "x + 1"),
Expand Down Expand Up @@ -51,6 +56,7 @@ def test_build(self):
(lambda: x.neq(1), "x <> 1"),
(lambda: x.isin(1, "2"), "x IN (1, '2')"),
(lambda: x.isin(query="select 1"), "x IN (SELECT 1)"),
(lambda: x.between(1, 2), "x BETWEEN 1 AND 2"),
(lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"),
(lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"),
(lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"),
Expand Down