Skip to content

Commit

Permalink
Fix(trino): wrap SEQUENCE in an UNNEST call if used as a source (toby…
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored and adrianisk committed Jun 21, 2023
1 parent 8e0ee40 commit 07a4601
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
19 changes: 12 additions & 7 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s
def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str:
start = expression.args["start"]
end = expression.args["end"]
step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series
step = expression.args.get("step")

target_type = None

Expand All @@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) ->
else:
start = exp.Cast(this=start, to=to)

return self.func("SEQUENCE", start, end, step)
sql = self.func("SEQUENCE", start, end, step)
if isinstance(expression.parent, exp.Table):
sql = f"UNNEST({sql})"

return sql


def _ensure_utf8(charset: exp.Literal) -> None:
Expand Down Expand Up @@ -204,6 +208,7 @@ class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
Expand All @@ -219,23 +224,23 @@ class Parser(parser.Parser):
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
"DATE_TRUNC": date_trunc_to_time,
"FROM_HEX": exp.Unhex.from_arg_list,
"FROM_UNIXTIME": _from_unixtime,
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
"NOW": exp.CurrentTimestamp.from_arg_list,
"SEQUENCE": exp.GenerateSeries.from_arg_list,
"STRPOS": lambda args: exp.StrPosition(
this=seq_get(args, 0),
substr=seq_get(args, 1),
instance=seq_get(args, 2),
),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": _approx_percentile,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
}
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("TRIM")
Expand Down
27 changes: 27 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,33 @@ def test_postgres(self):
"trino": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)",
},
)
self.validate_all(
"GENERATE_SERIES(a, b)",
write={
"postgres": "GENERATE_SERIES(a, b)",
"presto": "SEQUENCE(a, b)",
"trino": "SEQUENCE(a, b)",
"tsql": "GENERATE_SERIES(a, b)",
},
)
self.validate_all(
"GENERATE_SERIES(a, b)",
read={
"postgres": "GENERATE_SERIES(a, b)",
"presto": "SEQUENCE(a, b)",
"trino": "SEQUENCE(a, b)",
"tsql": "GENERATE_SERIES(a, b)",
},
)
self.validate_all(
"SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
write={
"postgres": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
"presto": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))",
"trino": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))",
"tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)",
},
)
self.validate_all(
"END WORK AND NO CHAIN",
write={"postgres": "COMMIT AND NO CHAIN"},
Expand Down

0 comments on commit 07a4601

Please sign in to comment.