Skip to content

Commit

Permalink
Feat: builder methods for basic ops (#1516)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao authored May 2, 2023
1 parent 20cacba commit 96bb150
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 26 deletions.
1 change: 1 addition & 0 deletions sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Expression as Expression,
alias_ as alias,
and_ as and_,
coalesce as coalesce,
column as column,
condition as condition,
except_ as except_,
Expand Down
176 changes: 152 additions & 24 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,119 @@ def not_(self):
"""
return not_(self)

def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E:
this = self
other = convert(other)
if not isinstance(this, klass) and not isinstance(other, klass):
this = _wrap(this, Binary)
other = _wrap(other, Binary)
if reverse:
return klass(this=other, expression=this)
return klass(this=this, expression=other)

def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]):
if isinstance(other, slice):
return Between(
this=self,
low=convert(other.start),
high=convert(other.stop),
)
return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)])

def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In:
return In(
this=self,
expressions=[convert(e) for e in expressions],
query=maybe_parse(query, **opts) if query else None,
)

def like(self, other: ExpOrStr) -> Like:
return self._binop(Like, other)

def ilike(self, other: ExpOrStr) -> ILike:
return self._binop(ILike, other)

def eq(self, other: ExpOrStr) -> EQ:
return self._binop(EQ, other)

def neq(self, other: ExpOrStr) -> NEQ:
return self._binop(NEQ, other)

def rlike(self, other: ExpOrStr) -> RegexpLike:
return self._binop(RegexpLike, other)

def __lt__(self, other: ExpOrStr) -> LT:
return self._binop(LT, other)

def __le__(self, other: ExpOrStr) -> LTE:
return self._binop(LTE, other)

def __gt__(self, other: ExpOrStr) -> GT:
return self._binop(GT, other)

def __ge__(self, other: ExpOrStr) -> GTE:
return self._binop(GTE, other)

def __add__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other)

def __radd__(self, other: ExpOrStr) -> Add:
return self._binop(Add, other, reverse=True)

def __sub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other)

def __rsub__(self, other: ExpOrStr) -> Sub:
return self._binop(Sub, other, reverse=True)

def __mul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other)

def __rmul__(self, other: ExpOrStr) -> Mul:
return self._binop(Mul, other, reverse=True)

def __truediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other)

def __rtruediv__(self, other: ExpOrStr) -> Div:
return self._binop(Div, other, reverse=True)

def __floordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other)

def __rfloordiv__(self, other: ExpOrStr) -> IntDiv:
return self._binop(IntDiv, other, reverse=True)

def __mod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other)

def __rmod__(self, other: ExpOrStr) -> Mod:
return self._binop(Mod, other, reverse=True)

def __pow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other)

def __rpow__(self, other: ExpOrStr) -> Pow:
return self._binop(Pow, other, reverse=True)

def __and__(self, other: ExpOrStr) -> And:
return self._binop(And, other)

def __rand__(self, other: ExpOrStr) -> And:
return self._binop(And, other, reverse=True)

def __or__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other)

def __ror__(self, other: ExpOrStr) -> Or:
return self._binop(Or, other, reverse=True)

def __neg__(self) -> Neg:
return Neg(this=_wrap(self, Binary))

def __invert__(self) -> Not:
return not_(self)


class Predicate(Condition):
"""Relationships like x = y, x > 1, x >= y."""
Expand Down Expand Up @@ -3006,7 +3119,7 @@ class DropPartition(Expression):


# Binary expressions like (ADD a b)
class Binary(Expression):
class Binary(Condition):
arg_types = {"this": True, "expression": True}

@property
Expand All @@ -3022,7 +3135,7 @@ class Add(Binary):
pass


class Connector(Binary, Condition):
class Connector(Binary):
pass


Expand Down Expand Up @@ -3184,19 +3297,19 @@ class ArrayOverlaps(Binary):

# Unary Expressions
# (NOT a)
class Unary(Expression):
class Unary(Condition):
pass


class BitwiseNot(Unary):
pass


class Not(Unary, Condition):
class Not(Unary):
pass


class Paren(Unary, Condition):
class Paren(Unary):
arg_types = {"this": True, "with": False}


Expand Down Expand Up @@ -4290,15 +4403,15 @@ def _combine(expressions, operator, dialect=None, **opts):
expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
this = _wrap(this, Connector)
for expression in expressions[1:]:
this = operator(this=this, expression=_wrap_operator(expression))
this = operator(this=this, expression=_wrap(expression, Connector))
return this


def _wrap_operator(expression):
if isinstance(expression, (And, Or, Not)):
expression = Paren(this=expression)
def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren:
if isinstance(expression, kind):
return Paren(this=expression)
return expression


Expand Down Expand Up @@ -4596,7 +4709,7 @@ def not_(expression, dialect=None, **opts) -> Not:
dialect=dialect,
**opts,
)
return Not(this=_wrap_operator(this))
return Not(this=_wrap(this, Connector))


def paren(expression) -> Paren:
Expand Down Expand Up @@ -4838,6 +4951,23 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca
return Cast(this=expression, to=DataType.build(to, **opts))


def coalesce(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Coalesce:
"""Create a coalesce node.
Example:
>>> coalesce('x + 1', '0').sql()
'COALESCE(x + 1, 0)'
Args:
expressions: The expressions to coalesce.
Returns:
A coalesce node.
"""
this, *exprs = [maybe_parse(e, **opts) for e in expressions]
return Coalesce(this=this, expressions=exprs)


def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
Expand Down Expand Up @@ -4956,16 +5086,22 @@ def convert(value) -> Expression:
"""
if isinstance(value, Expression):
return value
if value is None:
return NULL
if isinstance(value, bool):
return Boolean(this=value)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, float) and math.isnan(value):
if isinstance(value, bool):
return Boolean(this=value)
if value is None or (isinstance(value, float) and math.isnan(value)):
return NULL
if isinstance(value, numbers.Number):
return Literal.number(value)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
if isinstance(value, tuple):
return Tuple(expressions=[convert(v) for v in value])
if isinstance(value, list):
Expand All @@ -4975,14 +5111,6 @@ def convert(value) -> Expression:
keys=[convert(k) for k in value],
values=[convert(v) for v in value.values()],
)
if isinstance(value, datetime.datetime):
datetime_literal = Literal.string(
(value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat()
)
return TimeStrToTime(this=datetime_literal)
if isinstance(value, datetime.date):
date_literal = Literal.string(value.strftime("%Y-%m-%d"))
return DateStrToDate(this=date_literal)
raise ValueError(f"Cannot convert {value}")


Expand Down
4 changes: 2 additions & 2 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _annotate_binary(self, expression):
left_type = expression.left.type.this
right_type = expression.right.type.this

if isinstance(expression, (exp.And, exp.Or)):
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
Expand All @@ -347,7 +347,7 @@ def _annotate_binary(self, expression):
)
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
elif isinstance(expression, exp.Predicate):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(left_type, right_type)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlglot import (
alias,
and_,
coalesce,
condition,
except_,
exp,
Expand All @@ -18,7 +19,55 @@

class TestBuild(unittest.TestCase):
def test_build(self):
x = condition("x")

for expression, sql, *dialect in [
(lambda: x + 1, "x + 1"),
(lambda: 1 + x, "1 + x"),
(lambda: x - 1, "x - 1"),
(lambda: 1 - x, "1 - x"),
(lambda: x * 1, "x * 1"),
(lambda: 1 * x, "1 * x"),
(lambda: x / 1, "x / 1"),
(lambda: 1 / x, "1 / x"),
(lambda: x // 1, "CAST(x / 1 AS INT)"),
(lambda: 1 // x, "CAST(1 / x AS INT)"),
(lambda: x % 1, "x % 1"),
(lambda: 1 % x, "1 % x"),
(lambda: x**1, "POWER(x, 1)"),
(lambda: 1**x, "POWER(1, x)"),
(lambda: x & 1, "x AND 1"),
(lambda: 1 & x, "1 AND x"),
(lambda: x | 1, "x OR 1"),
(lambda: 1 | x, "1 OR x"),
(lambda: x < 1, "x < 1"),
(lambda: 1 < x, "x > 1"),
(lambda: x <= 1, "x <= 1"),
(lambda: 1 <= x, "x >= 1"),
(lambda: x > 1, "x > 1"),
(lambda: 1 > x, "x < 1"),
(lambda: x >= 1, "x >= 1"),
(lambda: 1 >= x, "x <= 1"),
(lambda: x.eq(1), "x = 1"),
(lambda: x.neq(1), "x <> 1"),
(lambda: x.isin(1, "2"), "x IN (1, '2')"),
(lambda: x.isin(query="select 1"), "x IN (SELECT 1)"),
(lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"),
(lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"),
(lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"),
(lambda: 1 + (x * 2) / 3, "1 + ((x * 2) / 3)"),
(lambda: x & "y", "x AND 'y'"),
(lambda: x | "y", "x OR 'y'"),
(lambda: -x, "-x"),
(lambda: ~x, "NOT x"),
(lambda: x[1], "x[1]"),
(lambda: x[1, 2], "x[1, 2]"),
(lambda: x["y"] + 1, "x['y'] + 1"),
(lambda: x.like("y"), "x LIKE 'y'"),
(lambda: x.ilike("y"), "x ILIKE 'y'"),
(lambda: x.rlike("y"), "REGEXP_LIKE(x, 'y')"),
(lambda: coalesce("x", 1), "COALESCE(x, 1)"),
(lambda: select("x"), "SELECT x"),
(lambda: select("x"), "SELECT x"),
(lambda: select("x", "y"), "SELECT x, y"),
(lambda: select("x").from_("tbl"), "SELECT x FROM tbl"),
Expand Down

0 comments on commit 96bb150

Please sign in to comment.