Update other dialects that have case-insensitive identifiers too
georgesittas committed Jun 16, 2023
1 parent f3c754d commit d5a1305
Showing 17 changed files with 63 additions and 32 deletions.
18 changes: 11 additions & 7 deletions sqlglot/dataframe/README.md
@@ -9,7 +9,7 @@ Currently many of the common operations are covered and more functionality will
## Instructions
* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library.
* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>, dialect="spark")`.
* The column structure can be defined in the following ways:
* Dictionary where the keys are column names and the values are strings of the Spark SQL type names.
* Ex: `{'cola': 'string', 'colb': 'int'}`
@@ -33,12 +33,16 @@ import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

sqlglot.schema.add_table('employee', {
'employee_id': 'INT',
'fname': 'STRING',
'lname': 'STRING',
'age': 'INT',
}) # Register the table structure prior to reading from the table
sqlglot.schema.add_table(
'employee',
{
'employee_id': 'INT',
'fname': 'STRING',
'lname': 'STRING',
'age': 'INT',
},
dialect="spark",
) # Register the table structure prior to reading from the table

spark = SparkSession()

5 changes: 5 additions & 0 deletions sqlglot/dataframe/sql/column.py
@@ -5,6 +5,7 @@
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
from sqlglot.dialects import Spark
from sqlglot.helper import flatten, is_iterable

if t.TYPE_CHECKING:
@@ -22,6 +23,10 @@ def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expressio
expression = sqlglot.maybe_parse(expression, dialect="spark")
if expression is None:
raise ValueError(f"Could not parse {expression}")

if isinstance(expression, exp.Column):
expression.transform(Spark.normalize_identifier, copy=False)

self.expression: exp.Expression = expression

def __repr__(self):
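The `transform(Spark.normalize_identifier, copy=False)` call added above is what normalizes identifiers inside dataframe `Column` objects. A minimal sketch of the effect, with invented names and the lowercased output assumed from the Spark/Hive change in this commit:

```python
import sqlglot
from sqlglot.dialects import Spark

# Invented example: parse a Spark expression, then normalize its identifiers.
expr = sqlglot.parse_one("SELECT `EmployeeID` FROM `Employees`", dialect="spark")
expr = expr.transform(Spark.normalize_identifier)

# Spark/Hive are case-insensitive, so even quoted identifiers are expected to
# come back lowercased (assumed output, not taken from the diff):
print(expr.sql(dialect="spark"))  # SELECT `employeeid` FROM `employees`
```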
1 change: 1 addition & 0 deletions sqlglot/dataframe/sql/dataframe.py
@@ -316,6 +316,7 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
expression.alias_or_name: expression.type.sql("spark")
for expression in select_expression.expressions
},
dialect="spark",
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
2 changes: 2 additions & 0 deletions sqlglot/dataframe/sql/normalize.py
@@ -5,6 +5,7 @@
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list

NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -19,6 +20,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
Spark.normalize_identifier(identifier)
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)

11 changes: 4 additions & 7 deletions sqlglot/dataframe/sql/readwriter.py
@@ -19,17 +19,14 @@ def __init__(self, spark: SparkSession):
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame

sqlglot.schema.add_table(tableName)
sqlglot.schema.add_table(tableName, dialect="spark")

return DataFrame(
self.spark,
exp.Select()
.from_(tableName)
.from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
.select(
*(
column if Spark.can_identify(column, "safe") else f'"{column}"'
for column in sqlglot.schema.column_names(tableName)
)
*(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
),
)

@@ -74,7 +71,7 @@ def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> Data
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
columns = sqlglot.schema.column_names(tableName, only_visible=True)
columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
df = df._convert_leaf_to_cte().select(*columns)

return self.copy(_df=df)
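Putting the reader changes together, a hedged usage sketch (the table and column names are invented; the output shape is inferred from the updated `test_quotes` expectation further down, not guaranteed):

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# Register the (invented) table with the Spark dialect so quoted identifiers
# are normalized the same way the reader normalizes the table reference.
sqlglot.schema.add_table("`Employees`", {"`Id`": "INT"}, dialect="spark")

spark = SparkSession()
df = spark.read.table("`Employees`")

# Expected shape of the output (assumed): everything lowercased and backquoted.
print(df.sql(dialect="spark")[0])
# SELECT `employees`.`id` AS `id` FROM `employees` AS `employees`
```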
2 changes: 2 additions & 0 deletions sqlglot/dialects/bigquery.py
@@ -105,6 +105,8 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:

class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True

# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

TIME_MAPPING = {
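The same three-valued flag is added to the other case-insensitive dialects below. My reading of its values, illustrated with a hedged sketch (identifier name invented, outputs assumed rather than taken from the diff): `True` folds unquoted identifiers to uppercase (e.g. Snowflake), `False` folds them to lowercase (the default), and `None` marks the dialect as fully case-insensitive, so even quoted identifiers are normalized.

```python
from sqlglot import exp
from sqlglot.dialects import BigQuery, Postgres

ident = exp.to_identifier("MyTable", quoted=True)

# BigQuery (RESOLVES_IDENTIFIERS_AS_UPPERCASE = None): quoting does not protect case.
print(BigQuery.normalize_identifier(ident.copy()).name)   # mytable (assumed)

# Postgres (default, False): a quoted identifier keeps its original casing.
print(Postgres.normalize_identifier(ident.copy()).name)   # MyTable (assumed)
```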
3 changes: 2 additions & 1 deletion sqlglot/dialects/dialect.py
@@ -108,7 +108,6 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
},
"STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
"IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
"can_identify": klass.can_identify,
}

if enum not in ("", "bigquery"):
Expand All @@ -123,6 +122,8 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
if not klass.STRICT_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

klass.generator_class.can_identify = klass.can_identify

return klass


3 changes: 3 additions & 0 deletions sqlglot/dialects/duckdb.py
@@ -88,6 +88,9 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"

# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
3 changes: 3 additions & 0 deletions sqlglot/dialects/hive.py
@@ -153,6 +153,9 @@ class Hive(Dialect):
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True

# https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

TIME_MAPPING = {
"y": "%Y",
"Y": "%Y",
5 changes: 5 additions & 0 deletions sqlglot/dialects/presto.py
@@ -172,6 +172,11 @@ class Presto(Dialect):
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True

# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
# https://github.com/prestodb/presto/issues/2863
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
3 changes: 3 additions & 0 deletions sqlglot/dialects/redshift.py
@@ -14,6 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx


class Redshift(Postgres):
# https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
**Postgres.TIME_MAPPING,
3 changes: 3 additions & 0 deletions sqlglot/dialects/sqlite.py
@@ -59,6 +59,9 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:


class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
2 changes: 1 addition & 1 deletion sqlglot/generator.py
@@ -266,7 +266,7 @@ class Generator:
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"

can_identify: t.Callable[[str, str | bool], bool] = lambda *_: False
can_identify: t.Callable[[str, str | bool], bool]

# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
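`can_identify` is the hook the generator consults for the `identify="safe"` option; with this change the dialect builder injects the dialect's own classmethod instead of the old always-False lambda default. A hedged sketch of what that hook answers (the concrete outputs are my assumption about the "safe" policy, not stated in the diff):

```python
from sqlglot.dialects import Snowflake

# can_identify(text, "safe") should report whether adding identifier quotes to
# `text` is harmless, i.e. whether the dialect would resolve the unquoted name
# to the same spelling anyway.
print(Snowflake.can_identify("FOO", "safe"))  # True (assumed): Snowflake resolves unquoted names to uppercase
print(Snowflake.can_identify("Foo", "safe"))  # False (assumed): quoting would pin mixed case that would otherwise fold
```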
6 changes: 3 additions & 3 deletions tests/dataframe/integration/dataframe_validator.py
@@ -135,9 +135,9 @@ def setUpClass(cls):
data=district_data, schema=cls.sqlglot_district_schema
)
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
sqlglot.schema.add_table("district", cls.sqlglot_district_schema)
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema, dialect="spark")
sqlglot.schema.add_table("store", cls.sqlglot_store_schema, dialect="spark")
sqlglot.schema.add_table("district", cls.sqlglot_district_schema, dialect="spark")

def setUp(self) -> None:
warnings.filterwarnings("ignore", category=ResourceWarning)
8 changes: 4 additions & 4 deletions tests/dataframe/unit/test_dataframe_writer.py
@@ -30,7 +30,7 @@ def test_insertInto_overwrite(self):

@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
sqlglot.schema.add_table("table_name", {"employee_id": "INT"}, dialect="spark")
df = self.df_employee.write.byName.insertInto("table_name")
expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@@ -88,8 +88,8 @@ def test_saveAsTable_cache(self):
self.compare_sql(df, expected_statements)

def test_quotes(self):
sqlglot.schema.add_table('"Test"', {'"ID"': "STRING"})
df = self.spark.table('"Test"')
sqlglot.schema.add_table("`Test`", {"`ID`": "STRING"}, dialect="spark")
df = self.spark.table("`Test`")
self.compare_sql(
df.select(df['"ID"']), ["SELECT `Test`.`ID` AS `ID` FROM `Test` AS `Test`"]
df.select(df["`ID`"]), ["SELECT `test`.`id` AS `id` FROM `test` AS `test`"]
)
14 changes: 7 additions & 7 deletions tests/dataframe/unit/test_session.py
@@ -71,7 +71,7 @@ def test_typed_schema_nested(self):
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_select_only(self):
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
self.assertEqual(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
@@ -80,17 +80,17 @@ def test_sql_select_only(self):

@mock.patch("sqlglot.schema", MappingSchema())
def test_select_quoted(self):
sqlglot.schema.add_table('"TEST"', {"name": "string"})
sqlglot.schema.add_table("`TEST`", {"name": "string"}, dialect="spark")

self.assertEqual(
SparkSession().table('"TEST"').select(F.col("name")).sql(dialect="snowflake")[0],
'''SELECT "TEST"."name" AS "name" FROM "TEST" AS "TEST"''',
SparkSession().table("`TEST`").select(F.col("name")).sql(dialect="snowflake")[0],
'''SELECT "test"."name" AS "name" FROM "test" AS "test"''',
)

@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_with_aggs(self):
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
self.assertEqual(
"WITH t38189 AS (SELECT cola, colb FROM table), t42330 AS (SELECT cola, colb FROM t38189) SELECT cola, SUM(colb) FROM t42330 GROUP BY cola",
@@ -100,15 +100,15 @@ def test_sql_with_aggs(self):
@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_create(self):
query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)

@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_insert(self):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark")
df = self.spark.sql(query)
expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
6 changes: 4 additions & 2 deletions tests/test_schema.py
@@ -201,7 +201,7 @@ def test_schema_get_column_type(self):
def test_schema_normalization(self):
schema = MappingSchema(
schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}},
dialect="spark",
dialect="clickhouse",
)

table_z = exp.Table(this="z", db="y", catalog="x")
@@ -228,7 +228,9 @@

# Check that the correct dialect is used when calling schema methods
schema = MappingSchema(schema={"[Fo]": {"x": "int"}}, dialect="tsql")
self.assertEqual(schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="spark"))
self.assertEqual(
schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="clickhouse")
)

# Check that all identifiers are normalized to lowercase for BigQuery, even quoted ones
schema = MappingSchema(schema={"`Foo`": {"BaR": "int"}}, dialect="bigquery")
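For reference, a small sketch of the BigQuery case the last comment above describes (schema contents copied from the test, the printed result is assumed from the stated behavior):

```python
from sqlglot.schema import MappingSchema

# BigQuery is case-insensitive, so even the quoted `Foo` key should be stored
# in lowercase, and a lookup by the lowercased name should find its column.
schema = MappingSchema(schema={"`Foo`": {"BaR": "int"}}, dialect="bigquery")
print(schema.column_names("foo"))  # ['bar'] (assumed)
```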
