diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md
index 02179f4b9e..86fdc4b045 100644
--- a/sqlglot/dataframe/README.md
+++ b/sqlglot/dataframe/README.md
@@ -9,7 +9,7 @@ Currently many of the common operations are covered and more functionality will
 ## Instructions
 * [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library.
 * Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
-* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>, dialect="spark")`.
   * The column structure can be defined the following ways:
     * Dictionary where the keys are column names and values are string of the Spark SQL type name.
       * Ex: `{'cola': 'string', 'colb': 'int'}`
@@ -33,12 +33,16 @@ import sqlglot
 from sqlglot.dataframe.sql.session import SparkSession
 from sqlglot.dataframe.sql import functions as F
 
-sqlglot.schema.add_table('employee', {
-  'employee_id': 'INT',
-  'fname': 'STRING',
-  'lname': 'STRING',
-  'age': 'INT',
-}) # Register the table structure prior to reading from the table
+sqlglot.schema.add_table(
+    'employee',
+    {
+        'employee_id': 'INT',
+        'fname': 'STRING',
+        'lname': 'STRING',
+        'age': 'INT',
+    },
+    dialect="spark",
+) # Register the table structure prior to reading from the table
 
 spark = SparkSession()
 
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index a8b89d1a72..f4cfebaf2d 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -5,6 +5,7 @@
 import sqlglot
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.types import DataType
+from sqlglot.dialects import Spark
 from sqlglot.helper import flatten, is_iterable
 
 if t.TYPE_CHECKING:
@@ -22,6 +23,10 @@ def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expressio
             expression = sqlglot.maybe_parse(expression, dialect="spark")
         if expression is None:
             raise ValueError(f"Could not parse {expression}")
+
+        if isinstance(expression, exp.Column):
+            expression.transform(Spark.normalize_identifier, copy=False)
+
         self.expression: exp.Expression = expression
 
     def __repr__(self):
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 3fc923238f..64cceeac02 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -316,6 +316,7 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
                         expression.alias_or_name: expression.type.sql("spark")
                         for expression in select_expression.expressions
                     },
+                    dialect="spark",
                 )
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 75feba7c29..4eec782429 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -5,6 +5,7 @@
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.column import Column
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dialects import Spark
 from sqlglot.helper import ensure_list
 
 NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -19,6 +20,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
     for expression in expressions:
         identifiers = expression.find_all(exp.Identifier)
         for identifier in identifiers:
+            Spark.normalize_identifier(identifier)
             replace_alias_name_with_cte_name(spark, expression_context, identifier)
             replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
 
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 7da2901aa1..9d87d4a785 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -19,17 +19,14 @@ def __init__(self, spark: SparkSession):
     def table(self, tableName: str) -> DataFrame:
         from sqlglot.dataframe.sql.dataframe import DataFrame
 
-        sqlglot.schema.add_table(tableName)
+        sqlglot.schema.add_table(tableName, dialect="spark")
 
         return DataFrame(
             self.spark,
             exp.Select()
-            .from_(tableName)
+            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
             .select(
-                *(
-                    column if Spark.can_identify(column, "safe") else f'"{column}"'
-                    for column in sqlglot.schema.column_names(tableName)
-                )
+                *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
             ),
         )
 
@@ -74,7 +71,7 @@ def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> Data
         )
         df = self._df.copy(output_expression_container=output_expression_container)
         if self._by_name:
-            columns = sqlglot.schema.column_names(tableName, only_visible=True)
+            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
             df = df._convert_leaf_to_cte().select(*columns)
 
         return self.copy(_df=df)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 7629c3f646..a2c81e544d 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -105,6 +105,8 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
 
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
+
+    # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
 
     TIME_MAPPING = {
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index d46b12f23a..0e25b9bcfd 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -108,7 +108,6 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
             },
             "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0],
             "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0],
-            "can_identify": klass.can_identify,
         }
 
         if enum not in ("", "bigquery"):
@@ -123,6 +122,8 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
         if not klass.STRICT_STRING_CONCAT:
             klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe
 
+        klass.generator_class.can_identify = klass.can_identify
+
         return klass
 
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 4aa5035fdc..164b212b73 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -88,6 +88,9 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
 class DuckDB(Dialect):
     NULL_ORDERING = "nulls_are_last"
 
+    # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 8847119515..eeba60ed8b 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -153,6 +153,9 @@ class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
 
+    # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     TIME_MAPPING = {
         "y": "%Y",
         "Y": "%Y",
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index f71515159d..265780e4e8 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -172,6 +172,11 @@ class Presto(Dialect):
     TIME_MAPPING = MySQL.TIME_MAPPING
     STRICT_STRING_CONCAT = True
 
+    # https://github.com/trinodb/trino/issues/17
+    # https://github.com/trinodb/trino/issues/12289
+    # https://github.com/prestodb/presto/issues/2863
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index a7e25fae0d..db6cc3f153 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -14,6 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx
 
 
 class Redshift(Postgres):
+    # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
     TIME_MAPPING = {
         **Postgres.TIME_MAPPING,
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 3b837ea3f8..803f361e8b 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -59,6 +59,9 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:
 
 
 class SQLite(Dialect):
+    # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
+    RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
+
     class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 6882aaea12..7f7d5dec0a 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -266,7 +266,7 @@ class Generator:
     NORMALIZE_FUNCTIONS: bool | str = "upper"
     NULL_ORDERING = "nulls_are_small"
 
-    can_identify: t.Callable[[str, str | bool], bool] = lambda *_: False
+    can_identify: t.Callable[[str, str | bool], bool]
 
     # Delimiters for quotes, identifiers and the corresponding escape characters
     QUOTE_START = "'"
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
index c84a34282f..22d4982444 100644
--- a/tests/dataframe/integration/dataframe_validator.py
+++ b/tests/dataframe/integration/dataframe_validator.py
@@ -135,9 +135,9 @@ def setUpClass(cls):
             data=district_data, schema=cls.sqlglot_district_schema
         )
         cls.df_district.createOrReplaceTempView("district")
-        sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
-        sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
-        sqlglot.schema.add_table("district", cls.sqlglot_district_schema)
+        sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema, dialect="spark")
+        sqlglot.schema.add_table("store", cls.sqlglot_store_schema, dialect="spark")
+        sqlglot.schema.add_table("district", cls.sqlglot_district_schema, dialect="spark")
 
     def setUp(self) -> None:
         warnings.filterwarnings("ignore", category=ResourceWarning)
diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py
index 3f45468398..303d2f987e 100644
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ b/tests/dataframe/unit/test_dataframe_writer.py
@@ -30,7 +30,7 @@ def test_insertInto_overwrite(self):
@mock.patch("sqlglot.schema", MappingSchema()) def test_insertInto_byName(self): - sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) + sqlglot.schema.add_table("table_name", {"employee_id": "INT"}, dialect="spark") df = self.df_employee.write.byName.insertInto("table_name") expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) @@ -88,8 +88,8 @@ def test_saveAsTable_cache(self): self.compare_sql(df, expected_statements) def test_quotes(self): - sqlglot.schema.add_table('"Test"', {'"ID"': "STRING"}) - df = self.spark.table('"Test"') + sqlglot.schema.add_table("`Test`", {"`ID`": "STRING"}, dialect="spark") + df = self.spark.table("`Test`") self.compare_sql( - df.select(df['"ID"']), ["SELECT `Test`.`ID` AS `ID` FROM `Test` AS `Test`"] + df.select(df["`ID`"]), ["SELECT `test`.`id` AS `id` FROM `test` AS `test`"] ) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 0970a2e88f..4c275e9c0d 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -71,7 +71,7 @@ def test_typed_schema_nested(self): @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_select_only(self): query = "SELECT cola, colb FROM table" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) self.assertEqual( "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", @@ -80,17 +80,17 @@ def test_sql_select_only(self): @mock.patch("sqlglot.schema", MappingSchema()) def test_select_quoted(self): - sqlglot.schema.add_table('"TEST"', {"name": "string"}) + sqlglot.schema.add_table("`TEST`", {"name": "string"}, dialect="spark") self.assertEqual( - SparkSession().table('"TEST"').select(F.col("name")).sql(dialect="snowflake")[0], - '''SELECT "TEST"."name" AS "name" FROM "TEST" AS "TEST"''', + SparkSession().table("`TEST`").select(F.col("name")).sql(dialect="snowflake")[0], + '''SELECT "test"."name" AS "name" FROM "test" AS "test"''', ) @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_with_aggs(self): query = "SELECT cola, colb FROM table" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb")) self.assertEqual( "WITH t38189 AS (SELECT cola, colb FROM table), t42330 AS (SELECT cola, colb FROM t38189) SELECT cola, SUM(colb) FROM t42330 GROUP BY cola", @@ -100,7 +100,7 @@ def test_sql_with_aggs(self): @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_create(self): query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) @@ -108,7 +108,7 @@ def test_sql_create(self): @mock.patch("sqlglot.schema", MappingSchema()) def 
test_sql_insert(self): query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" - sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) + sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") df = self.spark.sql(query) expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) diff --git a/tests/test_schema.py b/tests/test_schema.py index e43d830856..bffad376dc 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -201,7 +201,7 @@ def test_schema_get_column_type(self): def test_schema_normalization(self): schema = MappingSchema( schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}}, - dialect="spark", + dialect="clickhouse", ) table_z = exp.Table(this="z", db="y", catalog="x") @@ -228,7 +228,9 @@ def test_schema_normalization(self): # Check that the correct dialect is used when calling schema methods schema = MappingSchema(schema={"[Fo]": {"x": "int"}}, dialect="tsql") - self.assertEqual(schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="spark")) + self.assertEqual( + schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="clickhouse") + ) # Check that all identifiers are normalized to lowercase for BigQuery, even quoted ones schema = MappingSchema(schema={"`Foo`": {"BaR": "int"}}, dialect="bigquery")
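Usage sketch (not part of the patch): the snippet below illustrates the dialect-aware behaviour that the updated README and tests describe. The table name, columns, and the expected output in the comments are illustrative assumptions modelled on the tests in this diff, not captured output.

```python
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

# Register the table structure with an explicit dialect, as the updated README instructs.
sqlglot.schema.add_table(
    "employee",
    {"employee_id": "INT", "fname": "STRING", "lname": "STRING", "age": "INT"},
    dialect="spark",
)

spark = SparkSession()
df = spark.table("employee").select(F.col("fname"), F.col("age"))

# sql() returns a list of statements. Per the updated tests (test_quotes, test_select_quoted),
# identifiers are normalized through the Spark dialect, so the output should look roughly like:
#   SELECT `employee`.`fname` AS `fname`, `employee`.`age` AS `age` FROM `employee` AS `employee`
print(df.sql(dialect="spark")[0])
```

The schema itself is dialect-aware in the same way; this mirrors the updated assertion in tests/test_schema.py, where a T-SQL-quoted name and a ClickHouse-quoted name resolve to the same registered table:

```python
from sqlglot.schema import MappingSchema

# The schema normalizes names with its own dialect ("tsql"), while lookups may pass another one.
schema = MappingSchema(schema={"[Fo]": {"x": "int"}}, dialect="tsql")
assert schema.column_names("[Fo]") == schema.column_names("`Fo`", dialect="clickhouse")
```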