Skip to content

Commit

Permalink
Fix(schema): use to_identifier as fallback when normalizing names (#2492
Browse files Browse the repository at this point in the history
)
  • Loading branch information
georgesittas authored Oct 31, 2023
1 parent f25b61c commit 8e20328
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 15 deletions.
21 changes: 20 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5892,7 +5892,7 @@ def to_identifier(name, quoted=None, copy=True):
Args:
name: The name to turn into an identifier.
quoted: Whether or not force quote the identifier.
copy: Whether or not to copy a passed in Identefier node.
copy: Whether or not to copy name if it's an Identifier.
Returns:
The identifier ast node.
Expand All @@ -5913,6 +5913,25 @@ def to_identifier(name, quoted=None, copy=True):
return identifier


def parse_identifier(name: str, dialect: DialectType = None) -> Identifier:
"""
Parses a given string into an identifier.
Args:
name: The name to parse into an identifier.
dialect: The dialect to parse against.
Returns:
The identifier ast node.
"""
try:
expression = maybe_parse(name, dialect=dialect, into=Identifier)
except ParseError:
expression = to_identifier(name)

return expression


INTERVAL_STRING_RE = re.compile(r"\s*([0-9]+)\s*([a-zA-Z]+)\s*")


Expand Down
7 changes: 2 additions & 5 deletions sqlglot/optimizer/normalize_identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing as t

from sqlglot import ParseError, exp, parse_one
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType

Expand Down Expand Up @@ -49,10 +49,7 @@ def normalize_identifiers(expression, dialect=None):
The transformed expression.
"""
if isinstance(expression, str):
try:
expression = parse_one(expression, dialect=dialect, into=exp.Identifier)
except ParseError:
expression = exp.to_identifier(expression)
expression = exp.parse_identifier(expression, dialect=dialect)

dialect = Dialect.get_or_raise(dialect)

Expand Down
14 changes: 5 additions & 9 deletions sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import abc
import typing as t

import sqlglot
from sqlglot import expressions as exp
from sqlglot.dialects.dialect import Dialect
from sqlglot.errors import ParseError, SchemaError
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
from sqlglot.trie import TrieResult, in_trie, new_trie

Expand Down Expand Up @@ -448,19 +447,16 @@ def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.Da


def normalize_name(
name: str | exp.Identifier,
identifier: str | exp.Identifier,
dialect: DialectType = None,
is_table: bool = False,
normalize: t.Optional[bool] = True,
) -> str:
try:
identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier)
except ParseError:
return name if isinstance(name, str) else name.name
if isinstance(identifier, str):
identifier = exp.parse_identifier(identifier, dialect=dialect)

name = identifier.name
if not normalize:
return name
return identifier.name

# This can be useful for normalize_identifier
identifier.meta["is_table"] = is_table
Expand Down
9 changes: 9 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,15 @@ def test_qualify_columns(self, logger):
"CREATE FUNCTION `udfs`.`myTest`(`x` FLOAT64) AS (1)",
)

self.assertEqual(
optimizer.qualify.qualify(
parse_one("SELECT `bar_bazfoo_$id` FROM test", read="spark"),
schema={"test": {"bar_bazFoo_$id": "BIGINT"}},
dialect="spark",
).sql(dialect="spark"),
"SELECT `test`.`bar_bazfoo_$id` AS `bar_bazfoo_$id` FROM `test` AS `test`",
)

self.check_file(
"qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
)
Expand Down

0 comments on commit 8e20328

Please sign in to comment.