From 96bb150fa3ef3d58cbae8003b3f29f9690cbe481 Mon Sep 17 00:00:00 2001 From: Toby Mao Date: Tue, 2 May 2023 13:00:57 -0700 Subject: [PATCH] Feat: builder methods for basic ops (#1516) --- sqlglot/__init__.py | 1 + sqlglot/expressions.py | 176 ++++++++++++++++++++++++---- sqlglot/optimizer/annotate_types.py | 4 +- tests/test_build.py | 49 ++++++++ 4 files changed, 204 insertions(+), 26 deletions(-) diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 5faafb92f4..156f134574 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -21,6 +21,7 @@ Expression as Expression, alias_ as alias, and_ as and_, + coalesce as coalesce, column as column, condition as condition, except_ as except_, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index cd7bde9d88..00ad10b94b 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -701,6 +701,119 @@ def not_(self): """ return not_(self) + def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E: + this = self + other = convert(other) + if not isinstance(this, klass) and not isinstance(other, klass): + this = _wrap(this, Binary) + other = _wrap(other, Binary) + if reverse: + return klass(this=other, expression=this) + return klass(this=this, expression=other) + + def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]): + if isinstance(other, slice): + return Between( + this=self, + low=convert(other.start), + high=convert(other.stop), + ) + return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)]) + + def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In: + return In( + this=self, + expressions=[convert(e) for e in expressions], + query=maybe_parse(query, **opts) if query else None, + ) + + def like(self, other: ExpOrStr) -> Like: + return self._binop(Like, other) + + def ilike(self, other: ExpOrStr) -> ILike: + return self._binop(ILike, other) + + def eq(self, other: ExpOrStr) -> EQ: + return self._binop(EQ, other) + + def neq(self, other: ExpOrStr) -> NEQ: + return self._binop(NEQ, other) + + def rlike(self, other: ExpOrStr) -> RegexpLike: + return self._binop(RegexpLike, other) + + def __lt__(self, other: ExpOrStr) -> LT: + return self._binop(LT, other) + + def __le__(self, other: ExpOrStr) -> LTE: + return self._binop(LTE, other) + + def __gt__(self, other: ExpOrStr) -> GT: + return self._binop(GT, other) + + def __ge__(self, other: ExpOrStr) -> GTE: + return self._binop(GTE, other) + + def __add__(self, other: ExpOrStr) -> Add: + return self._binop(Add, other) + + def __radd__(self, other: ExpOrStr) -> Add: + return self._binop(Add, other, reverse=True) + + def __sub__(self, other: ExpOrStr) -> Sub: + return self._binop(Sub, other) + + def __rsub__(self, other: ExpOrStr) -> Sub: + return self._binop(Sub, other, reverse=True) + + def __mul__(self, other: ExpOrStr) -> Mul: + return self._binop(Mul, other) + + def __rmul__(self, other: ExpOrStr) -> Mul: + return self._binop(Mul, other, reverse=True) + + def __truediv__(self, other: ExpOrStr) -> Div: + return self._binop(Div, other) + + def __rtruediv__(self, other: ExpOrStr) -> Div: + return self._binop(Div, other, reverse=True) + + def __floordiv__(self, other: ExpOrStr) -> IntDiv: + return self._binop(IntDiv, other) + + def __rfloordiv__(self, other: ExpOrStr) -> IntDiv: + return self._binop(IntDiv, other, reverse=True) + + def __mod__(self, other: ExpOrStr) -> Mod: + return self._binop(Mod, other) + + def __rmod__(self, other: ExpOrStr) -> Mod: + return self._binop(Mod, other, reverse=True) + + def __pow__(self, other: ExpOrStr) -> Pow: + return self._binop(Pow, other) + + def __rpow__(self, other: ExpOrStr) -> Pow: + return self._binop(Pow, other, reverse=True) + + def __and__(self, other: ExpOrStr) -> And: + return self._binop(And, other) + + def __rand__(self, other: ExpOrStr) -> And: + return self._binop(And, other, reverse=True) + + def __or__(self, other: ExpOrStr) -> Or: + return self._binop(Or, other) + + def __ror__(self, other: ExpOrStr) -> Or: + return self._binop(Or, other, reverse=True) + + def __neg__(self) -> Neg: + return Neg(this=_wrap(self, Binary)) + + def __invert__(self) -> Not: + return not_(self) + class Predicate(Condition): """Relationships like x = y, x > 1, x >= y.""" @@ -3006,7 +3119,7 @@ class DropPartition(Expression): # Binary expressions like (ADD a b) -class Binary(Expression): +class Binary(Condition): arg_types = {"this": True, "expression": True} @property @@ -3022,7 +3135,7 @@ class Add(Binary): pass -class Connector(Binary, Condition): +class Connector(Binary): pass @@ -3184,7 +3297,7 @@ class ArrayOverlaps(Binary): # Unary Expressions # (NOT a) -class Unary(Expression): +class Unary(Condition): pass @@ -3192,11 +3305,11 @@ class BitwiseNot(Unary): pass -class Not(Unary, Condition): +class Not(Unary): pass -class Paren(Unary, Condition): +class Paren(Unary): arg_types = {"this": True, "with": False} @@ -4290,15 +4403,15 @@ def _combine(expressions, operator, dialect=None, **opts): expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions] this = expressions[0] if expressions[1:]: - this = _wrap_operator(this) + this = _wrap(this, Connector) for expression in expressions[1:]: - this = operator(this=this, expression=_wrap_operator(expression)) + this = operator(this=this, expression=_wrap(expression, Connector)) return this -def _wrap_operator(expression): - if isinstance(expression, (And, Or, Not)): - expression = Paren(this=expression) +def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: + if isinstance(expression, kind): + return Paren(this=expression) return expression @@ -4596,7 +4709,7 @@ def not_(expression, dialect=None, **opts) -> Not: dialect=dialect, **opts, ) - return Not(this=_wrap_operator(this)) + return Not(this=_wrap(this, Connector)) def paren(expression) -> Paren: @@ -4838,6 +4951,23 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca return Cast(this=expression, to=DataType.build(to, **opts)) +def coalesce(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Coalesce: + """Create a coalesce node. + + Example: + >>> coalesce('x + 1', '0').sql() + 'COALESCE(x + 1, 0)' + + Args: + expressions: The expressions to coalesce. + + Returns: + A coalesce node. + """ + this, *exprs = [maybe_parse(e, **opts) for e in expressions] + return Coalesce(this=this, expressions=exprs) + + def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: """Build a Table. @@ -4956,16 +5086,22 @@ def convert(value) -> Expression: """ if isinstance(value, Expression): return value - if value is None: - return NULL - if isinstance(value, bool): - return Boolean(this=value) if isinstance(value, str): return Literal.string(value) - if isinstance(value, float) and math.isnan(value): + if isinstance(value, bool): + return Boolean(this=value) + if value is None or (isinstance(value, float) and math.isnan(value)): return NULL if isinstance(value, numbers.Number): return Literal.number(value) + if isinstance(value, datetime.datetime): + datetime_literal = Literal.string( + (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() + ) + return TimeStrToTime(this=datetime_literal) + if isinstance(value, datetime.date): + date_literal = Literal.string(value.strftime("%Y-%m-%d")) + return DateStrToDate(this=date_literal) if isinstance(value, tuple): return Tuple(expressions=[convert(v) for v in value]) if isinstance(value, list): @@ -4975,14 +5111,6 @@ def convert(value) -> Expression: keys=[convert(k) for k in value], values=[convert(v) for v in value.values()], ) - if isinstance(value, datetime.datetime): - datetime_literal = Literal.string( - (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() - ) - return TimeStrToTime(this=datetime_literal) - if isinstance(value, datetime.date): - date_literal = Literal.string(value.strftime("%Y-%m-%d")) - return DateStrToDate(this=date_literal) raise ValueError(f"Cannot convert {value}") diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 9d08e20556..623875945f 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -338,7 +338,7 @@ def _annotate_binary(self, expression): left_type = expression.left.type.this right_type = expression.right.type.this - if isinstance(expression, (exp.And, exp.Or)): + if isinstance(expression, exp.Connector): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: expression.type = exp.DataType.Type.NULL elif exp.DataType.Type.NULL in (left_type, right_type): @@ -347,7 +347,7 @@ def _annotate_binary(self, expression): ) else: expression.type = exp.DataType.Type.BOOLEAN - elif isinstance(expression, (exp.Condition, exp.Predicate)): + elif isinstance(expression, exp.Predicate): expression.type = exp.DataType.Type.BOOLEAN else: expression.type = self._maybe_coerce(left_type, right_type) diff --git a/tests/test_build.py b/tests/test_build.py index 43707b0347..3f469a5598 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -3,6 +3,7 @@ from sqlglot import ( alias, and_, + coalesce, condition, except_, exp, @@ -18,7 +19,55 @@ class TestBuild(unittest.TestCase): def test_build(self): + x = condition("x") + for expression, sql, *dialect in [ + (lambda: x + 1, "x + 1"), + (lambda: 1 + x, "1 + x"), + (lambda: x - 1, "x - 1"), + (lambda: 1 - x, "1 - x"), + (lambda: x * 1, "x * 1"), + (lambda: 1 * x, "1 * x"), + (lambda: x / 1, "x / 1"), + (lambda: 1 / x, "1 / x"), + (lambda: x // 1, "CAST(x / 1 AS INT)"), + (lambda: 1 // x, "CAST(1 / x AS INT)"), + (lambda: x % 1, "x % 1"), + (lambda: 1 % x, "1 % x"), + (lambda: x**1, "POWER(x, 1)"), + (lambda: 1**x, "POWER(1, x)"), + (lambda: x & 1, "x AND 1"), + (lambda: 1 & x, "1 AND x"), + (lambda: x | 1, "x OR 1"), + (lambda: 1 | x, "1 OR x"), + (lambda: x < 1, "x < 1"), + (lambda: 1 < x, "x > 1"), + (lambda: x <= 1, "x <= 1"), + (lambda: 1 <= x, "x >= 1"), + (lambda: x > 1, "x > 1"), + (lambda: 1 > x, "x < 1"), + (lambda: x >= 1, "x >= 1"), + (lambda: 1 >= x, "x <= 1"), + (lambda: x.eq(1), "x = 1"), + (lambda: x.neq(1), "x <> 1"), + (lambda: x.isin(1, "2"), "x IN (1, '2')"), + (lambda: x.isin(query="select 1"), "x IN (SELECT 1)"), + (lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"), + (lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"), + (lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"), + (lambda: 1 + (x * 2) / 3, "1 + ((x * 2) / 3)"), + (lambda: x & "y", "x AND 'y'"), + (lambda: x | "y", "x OR 'y'"), + (lambda: -x, "-x"), + (lambda: ~x, "NOT x"), + (lambda: x[1], "x[1]"), + (lambda: x[1, 2], "x[1, 2]"), + (lambda: x["y"] + 1, "x['y'] + 1"), + (lambda: x.like("y"), "x LIKE 'y'"), + (lambda: x.ilike("y"), "x ILIKE 'y'"), + (lambda: x.rlike("y"), "REGEXP_LIKE(x, 'y')"), + (lambda: coalesce("x", 1), "COALESCE(x, 1)"), + (lambda: select("x"), "SELECT x"), (lambda: select("x"), "SELECT x"), (lambda: select("x", "y"), "SELECT x, y"), (lambda: select("x").from_("tbl"), "SELECT x FROM tbl"),