Skip to content

Commit

Permalink
Fix(mysql): convert (U)BIGINT to (UN)SIGNED in CAST expressions (#1832)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jun 26, 2023
1 parent 763d25b commit f7abc28
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
11 changes: 11 additions & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,17 @@ class Generator(generator.Generator):

LIMIT_FETCH = "LIMIT"

def cast_sql(self, expression: exp.Cast) -> str:
"""(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead."""
if expression.to.this == exp.DataType.Type.BIGINT:
to = "SIGNED"
elif expression.to.this == exp.DataType.Type.UBIGINT:
to = "UNSIGNED"
else:
return super().cast_sql(expression)

return f"CAST({self.sql(expression, 'this')} AS {to})"

def show_sql(self, expression: exp.Show) -> str:
this = f" {expression.name}"
full = " FULL" if expression.args.get("full") else ""
Expand Down
9 changes: 5 additions & 4 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class TestMySQL(Validator):
dialect = "mysql"

def test_ddl(self):
self.validate_identity("CREATE TABLE foo (id BIGINT)")
self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10")
self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10")
self.validate_identity(
Expand Down Expand Up @@ -49,10 +50,6 @@ def test_ddl(self):
def test_identity(self):
self.validate_identity("CAST(x AS ENUM('a', 'b'))")
self.validate_identity("CAST(x AS SET('a', 'b'))")
self.validate_identity("CAST(x AS SIGNED)", "CAST(x AS BIGINT)")
self.validate_identity("CAST(x AS SIGNED INTEGER)", "CAST(x AS BIGINT)")
self.validate_identity("CAST(x AS UNSIGNED)", "CAST(x AS UBIGINT)")
self.validate_identity("CAST(x AS UNSIGNED INTEGER)", "CAST(x AS UBIGINT)")
self.validate_identity("SELECT CURRENT_TIMESTAMP(6)")
self.validate_identity("x ->> '$.name'")
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
Expand Down Expand Up @@ -401,6 +398,10 @@ def test_mysql_time(self):
self.validate_identity("TIME_STR_TO_UNIX(x)", "UNIX_TIMESTAMP(x)")

def test_mysql(self):
self.validate_all("CAST(x AS SIGNED)", write={"mysql": "CAST(x AS SIGNED)"})
self.validate_all("CAST(x AS SIGNED INTEGER)", write={"mysql": "CAST(x AS SIGNED)"})
self.validate_all("CAST(x AS UNSIGNED)", write={"mysql": "CAST(x AS UNSIGNED)"})
self.validate_all("CAST(x AS UNSIGNED INTEGER)", write={"mysql": "CAST(x AS UNSIGNED)"})
self.validate_all(
"SELECT * FROM t LOCK IN SHARE MODE", write={"mysql": "SELECT * FROM t FOR SHARE"}
)
Expand Down

0 comments on commit f7abc28

Please sign in to comment.