Skip to content

Commit

Permalink
Refactor(schema): replace _ensure_table with exp.maybe_parse (#1709)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored May 31, 2023
1 parent 24d44ad commit 12d3cca
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
1 change: 1 addition & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4575,6 +4575,7 @@ def maybe_parse(
sql = str(sql_or_expression)
if prefix:
sql = f"{prefix} {sql}"

return sqlglot.parse_one(sql, read=dialect, into=into, **opts)


Expand Down
19 changes: 3 additions & 16 deletions sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def add_table(
dialect: the SQL dialect that will be used to parse `table` if it's a string.
"""
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
exp.maybe_parse(table, into=exp.Table, dialect=dialect), dialect=dialect
)
normalized_column_mapping = {
self._normalize_name(key, dialect=dialect): value
Expand All @@ -250,7 +250,7 @@ def column_names(
dialect: DialectType = None,
) -> t.List[str]:
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
exp.maybe_parse(table, into=exp.Table, dialect=dialect), dialect=dialect
)

schema = self.find(normalized_table)
Expand All @@ -270,7 +270,7 @@ def get_column_type(
dialect: DialectType = None,
) -> exp.DataType:
normalized_table = self._normalize_table(
self._ensure_table(table, dialect=dialect), dialect=dialect
exp.maybe_parse(table, into=exp.Table, dialect=dialect), dialect=dialect
)
normalized_column_name = self._normalize_name(
column if isinstance(column, str) else column.this, dialect=dialect
Expand Down Expand Up @@ -345,19 +345,6 @@ def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1

def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table:
if isinstance(table, exp.Table):
return table

dialect = dialect or self.dialect
parsed_table = sqlglot.parse_one(table, read=dialect, into=exp.Table)

if not parsed_table:
in_dialect = f" in dialect {dialect}" if dialect else ""
raise SchemaError(f"Failed to parse table '{table}'{in_dialect}.")

return parsed_table

def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object.
Expand Down

0 comments on commit 12d3cca

Please sign in to comment.