From b57540141076526943644d5a0e0d0d56cf860f52 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 24 May 2023 14:37:59 +0300 Subject: [PATCH 1/6] Fix: allow >3 table parts in exp.to_table --- sqlglot/expressions.py | 8 +++++++- tests/dialects/test_bigquery.py | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 620578cd10..46f5f67186 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5038,7 +5038,13 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") - catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3)) + catalog, db, table_name, *rest = ( + t.cast(Expression, to_identifier(x)) for x in split_num_words(sql_path, ".", 3) + ) + + if rest: + table_name = Dot.build([table_name, *rest]) + return Table(this=table_name, db=db, catalog=catalog, **kwargs) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 8db7643be2..f305a9a9d3 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -37,6 +37,12 @@ def test_bigquery(self): "CREATE TEMP TABLE foo AS SELECT 1", write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"}, ) + self.validate_all( + "SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`", + write={ + "bigquery": "SELECT * FROM SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW", + }, + ) self.validate_all( "SELECT * FROM `my-project.my-dataset.my-table`", write={"bigquery": "SELECT * FROM `my-project`.`my-dataset`.`my-table`"}, From cb53907ce3ed2b8226327397c9ace57589442558 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 24 May 2023 14:46:27 +0300 Subject: [PATCH 2/6] Fix mypy types --- sqlglot/expressions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 46f5f67186..8c9d96a6bd 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5039,11 +5039,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") catalog, db, table_name, *rest = ( - t.cast(Expression, to_identifier(x)) for x in split_num_words(sql_path, ".", 3) + t.cast(t.Optional[Expression], to_identifier(x)) for x in split_num_words(sql_path, ".", 3) ) - if rest: - table_name = Dot.build([table_name, *rest]) + if rest and table_name: + table_name = Dot.build(t.cast(t.List[Expression], [table_name, *rest])) return Table(this=table_name, db=db, catalog=catalog, **kwargs) From c3fa460faa17fe5ea5f4bf85fdcf1d7f7494c8ae Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 24 May 2023 20:47:08 +0300 Subject: [PATCH 3/6] Keep parity between to_table and _parse_table by using maybe_parse --- sqlglot/dialects/bigquery.py | 8 ++++++++ sqlglot/expressions.py | 21 +++++++++++---------- tests/test_expressions.py | 4 ---- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 369298def2..c1ed03aa61 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -231,9 +231,17 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: return this def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + def quote_unsafe_table_part(part: exp.Expression) -> exp.Expression: + if isinstance(part, exp.Identifier): + part.set("quoted", not exp.SAFE_IDENTIFIER_RE.match(part.name)) + return part + table = super()._parse_table_parts(schema=schema) if isinstance(table.this, exp.Identifier) and "." in table.name: table = exp.to_table(table.name, dialect="bigquery") + for part in table.parts: + part.transform(quote_unsafe_table_part, copy=False) + return table class Generator(generator.Generator): diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8c9d96a6bd..f520bee616 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -28,7 +28,6 @@ ensure_collection, ensure_list, seq_get, - split_num_words, subclasses, ) from sqlglot.tokens import Token @@ -2188,7 +2187,7 @@ def catalog(self) -> str: @property def parts(self) -> t.List[Identifier]: - """Return the parts of a column in order catalog, db, table.""" + """Return the parts of a table in order catalog, db, table.""" return [ t.cast(Identifier, self.args[part]) for part in ("catalog", "db", "this") @@ -5022,13 +5021,17 @@ def to_table(sql_path: None, **kwargs) -> None: ... -def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: +def to_table( + sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs +) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. If a table is passed in then that table is returned. Args: sql_path: a `[catalog].[schema].[table]` string. + dialect: the source dialect according to which the table name will be parsed. + kwargs: the kwargs to instantiate the resulting `Table` expression with. Returns: A table expression. @@ -5038,14 +5041,12 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") - catalog, db, table_name, *rest = ( - t.cast(t.Optional[Expression], to_identifier(x)) for x in split_num_words(sql_path, ".", 3) - ) - - if rest and table_name: - table_name = Dot.build(t.cast(t.List[Expression], [table_name, *rest])) + table = maybe_parse(sql_path, into=Table, dialect=dialect) + if table: + for k, v in kwargs.items(): + table.set(k, v) - return Table(this=table_name, db=db, catalog=catalog, **kwargs) + return table def to_column(sql_path: str | Column, **kwargs) -> Column: diff --git a/tests/test_expressions.py b/tests/test_expressions.py index d9d327811a..7735e78608 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -726,10 +726,6 @@ def test_to_table(self): self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) with self.assertRaises(ValueError): exp.to_table(1) - empty_string = exp.to_table("") - self.assertEqual(empty_string.name, "") - self.assertIsNone(table_only.args.get("db")) - self.assertIsNone(table_only.args.get("catalog")) def test_to_column(self): column_only = exp.to_column("column_name") From 425fd07f75e0f96482870ed99d700a4865715e85 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Wed, 24 May 2023 21:16:27 +0300 Subject: [PATCH 4/6] Refactor bigquery _parse_table_parts --- sqlglot/dialects/bigquery.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index c1ed03aa61..e7278251d0 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -18,7 +18,7 @@ timestrtotime_sql, ts_or_ds_to_date_sql, ) -from sqlglot.helper import seq_get +from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType E = t.TypeVar("E", bound=exp.Expression) @@ -231,16 +231,17 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: return this def _parse_table_parts(self, schema: bool = False) -> exp.Expression: - def quote_unsafe_table_part(part: exp.Expression) -> exp.Expression: - if isinstance(part, exp.Identifier): - part.set("quoted", not exp.SAFE_IDENTIFIER_RE.match(part.name)) - return part - table = super()._parse_table_parts(schema=schema) if isinstance(table.this, exp.Identifier) and "." in table.name: - table = exp.to_table(table.name, dialect="bigquery") - for part in table.parts: - part.transform(quote_unsafe_table_part, copy=False) + catalog, db, this, *rest = ( + t.cast(t.Optional[exp.Expression], exp.to_identifier(x)) + for x in split_num_words(table.name, ".", 3) + ) + + if rest and this: + this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest])) + + table = exp.Table(this=this, db=db, catalog=catalog) return table From c7e0fc5271b37e00c89234a08264f20169a96473 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 25 May 2023 00:55:45 +0300 Subject: [PATCH 5/6] Fix exp.Table expression parser --- sqlglot/parser.py | 4 ++-- tests/test_parser.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f41e8355b3..f79add06e7 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -469,7 +469,7 @@ class Parser(metaclass=_Parser): exp.Limit: lambda self: self._parse_limit(), exp.Offset: lambda self: self._parse_offset(), exp.TableAlias: lambda self: self._parse_table_alias(), - exp.Table: lambda self: self._parse_table(), + exp.Table: lambda self: self._parse_table_parts(), exp.Condition: lambda self: self._parse_conjunction(), exp.Expression: lambda self: self._parse_statement(), exp.Properties: lambda self: self._parse_properties(), @@ -2255,7 +2255,7 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: or self._parse_placeholder() ) - def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + def _parse_table_parts(self, schema: bool = False) -> exp.Table: catalog = None db = None table = self._parse_table_part(schema=schema) diff --git a/tests/test_parser.py b/tests/test_parser.py index 11df56690e..ded51a5a2a 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -15,6 +15,14 @@ def test_parse_into(self): self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join) self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType) self.assertIsInstance(parse_one("array", into=exp.DataType), exp.DataType) + self.assertIsInstance(parse_one("foo", into=exp.Table), exp.Table) + + with self.assertRaises(ParseError) as ctx: + parse_one("SELECT * FROM tbl", into=exp.Table) + + self.assertEqual( + str(ctx.exception), "Failed to parse into ", + ) def test_parse_into_error(self): expected_message = "Failed to parse into []" From 842876b5779ae823ea78c586fb73674bbbff4803 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 25 May 2023 00:59:00 +0300 Subject: [PATCH 6/6] Fixup --- sqlglot/dialects/bigquery.py | 2 +- sqlglot/dialects/tsql.py | 2 +- sqlglot/parser.py | 2 +- tests/test_parser.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index e7278251d0..f24f887608 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -230,7 +230,7 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: return this - def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + def _parse_table_parts(self, schema: bool = False) -> exp.Table: table = super()._parse_table_parts(schema=schema) if isinstance(table.this, exp.Identifier) and "." in table.name: catalog, db, this, *rest = ( diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 30137a5366..dd1ff9209f 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -377,7 +377,7 @@ def _parse_system_time(self) -> t.Optional[exp.Expression]: return system_time - def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + def _parse_table_parts(self, schema: bool = False) -> exp.Table: table = super()._parse_table_parts(schema=schema) table.set("system_time", self._parse_system_time()) return table diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f79add06e7..93429e35e8 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -2299,7 +2299,7 @@ def _parse_table( subquery.set("pivots", self._parse_pivots()) return subquery - this = self._parse_table_parts(schema=schema) + this: exp.Expression = self._parse_table_parts(schema=schema) if schema: return self._parse_schema(this=this) diff --git a/tests/test_parser.py b/tests/test_parser.py index ded51a5a2a..e5788e9876 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -21,7 +21,8 @@ def test_parse_into(self): parse_one("SELECT * FROM tbl", into=exp.Table) self.assertEqual( - str(ctx.exception), "Failed to parse into ", + str(ctx.exception), + "Failed to parse into ", ) def test_parse_into_error(self):