Skip to content

Commit

Permalink
Feat(bigquery): pushdown CTE column names (#1847)
Browse files Browse the repository at this point in the history
* Feat(bigquery): pushdown CTE column names

* Formatting

* Minor cleanup
  • Loading branch information
georgesittas authored Jun 28, 2023
1 parent 898f1a2 commit 95a4b70
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
36 changes: 36 additions & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import re
import typing as t

Expand All @@ -21,6 +22,8 @@
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType

logger = logging.getLogger("sqlglot")


def _date_add_sql(
data_type: str, kind: str
Expand Down Expand Up @@ -104,6 +107,33 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
return expression


def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
"""BigQuery doesn't allow column names when defining a CTE, so we try to push them down."""
if isinstance(expression, exp.CTE) and expression.alias_column_names:
cte_query = expression.this

if cte_query.is_star:
logger.warning(
"Can't push down CTE column names for star queries. Run the query through"
" the optimizer or use 'qualify' to expand the star projections first."
)
return expression

column_names = expression.alias_column_names
expression.args["alias"].set("columns", None)

for name, select in zip(column_names, cte_query.selects):
to_replace = select

if isinstance(select, exp.Alias):
select = select.this

# Inner aliases are shadowed by the CTE column names
to_replace.replace(exp.alias_(select, name))

return expression


class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True

Expand Down Expand Up @@ -309,6 +339,7 @@ class Generator(generator.Generator):
"TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
Expand Down Expand Up @@ -486,6 +517,11 @@ class Generator(generator.Generator):
"within",
}

def cte_sql(self, expression: exp.CTE) -> str:
if expression.alias_column_names:
self.unsupported("Column names in CTE definition are not supported.")
return super().cte_sql(expression)

def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):
Expand Down
33 changes: 33 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest import mock

from sqlglot import ErrorLevel, ParseError, UnsupportedError, transpile
from tests.dialects.test_dialect import Validator

Expand Down Expand Up @@ -571,3 +573,34 @@ def test_rename_table(self):
"bigquery": "ALTER TABLE db.t1 RENAME TO t2",
},
)

@mock.patch("sqlglot.dialects.bigquery.logger")
def test_pushdown_cte_column_names(self, mock_logger):
with self.assertRaises(UnsupportedError):
transpile(
"WITH cte(foo) AS (SELECT * FROM tbl) SELECT foo FROM cte",
read="spark",
write="bigquery",
unsupported_level=ErrorLevel.RAISE,
)

self.validate_all(
"WITH cte AS (SELECT 1 AS foo) SELECT foo FROM cte",
read={"spark": "WITH cte(foo) AS (SELECT 1) SELECT foo FROM cte"},
)
self.validate_all(
"WITH cte AS (SELECT 1 AS foo) SELECT foo FROM cte",
read={"spark": "WITH cte(foo) AS (SELECT 1 AS bar) SELECT foo FROM cte"},
)
self.validate_all(
"WITH cte AS (SELECT 1 AS bar) SELECT bar FROM cte",
read={"spark": "WITH cte AS (SELECT 1 AS bar) SELECT bar FROM cte"},
)
self.validate_all(
"WITH cte AS (SELECT 1 AS foo, 2) SELECT foo FROM cte",
read={"postgres": "WITH cte(foo) AS (SELECT 1, 2) SELECT foo FROM cte"},
)
self.validate_all(
"WITH cte AS (SELECT 1 AS foo UNION ALL SELECT 2) SELECT foo FROM cte",
read={"postgres": "WITH cte(foo) AS (SELECT 1 UNION ALL SELECT 2) SELECT foo FROM cte"},
)

0 comments on commit 95a4b70

Please sign in to comment.