Skip to content

Commit

Permalink
Feat: create builders for the INSERT statement (#1630)
Browse files Browse the repository at this point in the history
* Feat: create builders for the INSERT statement

* Add ability to specify columns

* Add ability to do INSERT OVERWRITE through flag

* Remove unnecessary builders
  • Loading branch information
georgesittas authored May 16, 2023
1 parent bba360c commit c01edb0
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 38 deletions.
160 changes: 122 additions & 38 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,42 @@ class Insert(Expression):
"alternative": False,
}

def with_(
self,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Insert:
"""
Append to or set the common table expressions.
Example:
>>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql()
'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte'
Args:
alias: the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
as_: the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
recursive: set the RECURSIVE part of the expression. Defaults to `False`.
append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified expression.
"""
return _apply_cte_builder(
self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
)


class OnConflict(Expression):
arg_types = {
Expand Down Expand Up @@ -2062,14 +2098,14 @@ def named_selects(self):

def with_(
self,
alias,
as_,
recursive=None,
append=True,
dialect=None,
copy=True,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
):
) -> Subqueryable:
"""
Append to or set the common table expressions.
Expand All @@ -2078,43 +2114,22 @@ def with_(
'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'
Args:
alias (str | Expression): the SQL code string to parse as the table name.
alias: the SQL code string to parse as the table name.
If an `Expression` instance is passed, this is used as-is.
as_ (str | Expression): the SQL code string to parse as the table expression.
as_: the SQL code string to parse as the table expression.
If an `Expression` instance is passed, it will be used as-is.
recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`.
append (bool): if `True`, add to any existing expressions.
recursive: set the RECURSIVE part of the expression. Defaults to `False`.
append: if `True`, add to any existing expressions.
Otherwise, this resets the expressions.
dialect (str): the dialect used to parse the input expression.
copy (bool): if `False`, modify this expression instance in-place.
opts (kwargs): other options to use to parse the input expressions.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
Select: the modified expression.
The modified expression.
"""
alias_expression = maybe_parse(
alias,
dialect=dialect,
into=TableAlias,
**opts,
)
as_expression = maybe_parse(
as_,
dialect=dialect,
**opts,
)
cte = CTE(
this=as_expression,
alias=alias_expression,
)
return _apply_child_list_builder(
cte,
instance=self,
arg="with",
append=append,
copy=copy,
into=With,
properties={"recursive": recursive or False},
return _apply_cte_builder(
self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts
)


Expand Down Expand Up @@ -4525,6 +4540,30 @@ def _apply_conjunction_builder(
return inst


def _apply_cte_builder(
instance: E,
alias: ExpOrStr,
as_: ExpOrStr,
recursive: t.Optional[bool] = None,
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> E:
alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts)
as_expression = maybe_parse(as_, dialect=dialect, **opts)
cte = CTE(this=as_expression, alias=alias_expression)
return _apply_child_list_builder(
cte,
instance=instance,
arg="with",
append=append,
copy=copy,
into=With,
properties={"recursive": recursive or False},
)


def _combine(expressions, operator, dialect=None, copy=True, **opts):
expressions = [
condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions
Expand Down Expand Up @@ -4742,6 +4781,51 @@ def delete(
return delete_expr


def insert(
expression: ExpOrStr,
into: ExpOrStr,
columns: t.Optional[t.Sequence[ExpOrStr]] = None,
overwrite: t.Optional[bool] = None,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Insert:
"""
Builds an INSERT statement.
Example:
>>> insert("VALUES (1, 2, 3)", "tbl").sql()
'INSERT INTO tbl VALUES (1, 2, 3)'
Args:
expression: the sql string or expression of the INSERT statement
into: the tbl to insert data to.
columns: optionally the table's column names.
overwrite: whether to INSERT OVERWRITE or not.
dialect: the dialect used to parse the input expressions.
copy: whether or not to copy the expression.
**opts: other options to use to parse the input expressions.
Returns:
Insert: the syntax tree for the INSERT statement.
"""
expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts)
this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts)

if columns:
this = _apply_list_builder(
*columns,
instance=Schema(this=this),
arg="expressions",
into=Identifier,
copy=False,
dialect=dialect,
**opts,
)

return Insert(this=this, expression=expr, overwrite=overwrite)


def condition(expression, dialect=None, copy=True, **opts) -> Condition:
"""
Initialize a logical condition expression.
Expand Down
16 changes: 16 additions & 0 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,22 @@ def test_build(self):
"DELETE FROM tbl WHERE x = 1 RETURNING *",
"postgres",
),
(
lambda: exp.insert("SELECT * FROM tbl2", "tbl"),
"INSERT INTO tbl SELECT * FROM tbl2",
),
(
lambda: exp.insert("SELECT * FROM tbl2", "tbl", overwrite=True),
"INSERT OVERWRITE TABLE tbl SELECT * FROM tbl2",
),
(
lambda: exp.insert("VALUES (1, 2), (3, 4)", "tbl", columns=["cola", "colb"]),
"INSERT INTO tbl (cola, colb) VALUES (1, 2), (3, 4)",
),
(
lambda: exp.insert("SELECT * FROM cte", "t").with_("cte", as_="SELECT x FROM tbl"),
"WITH cte AS (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte",
),
(
lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)),
"(x, y) IN ((1, 2), (3, 4))",
Expand Down

0 comments on commit c01edb0

Please sign in to comment.