diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 156fd51805..9ccd02e70e 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -61,6 +61,8 @@ class Oracle(Dialect): } class Parser(parser.Parser): + WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} + FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1824aaad39..dd73fd10ed 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2772,6 +2772,8 @@ class Window(Expression): "order": False, "spec": False, "alias": False, + "over": False, + "first": False, } diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ae863a1ad6..500cbbc753 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1553,12 +1553,18 @@ def window_sql(self, expression: exp.Window) -> str: spec_sql = " " + self.window_spec_sql(spec) if spec else "" alias = self.sql(expression, "alias") - this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}" + over = self.sql(expression, "over") or "OVER" + this = f"{this} {'AS' if expression.arg_key == 'windows' else over}" + + first = expression.args.get("first") + if first is not None: + first = " FIRST " if first else " LAST " + first = first or "" if not partition and not order and not spec and alias: return f"{this} {alias}" - window_args = alias + partition_sql + order_sql + spec_sql + window_args = alias + first + partition_sql + order_sql + spec_sql return f"{this} ({window_args.strip()})" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index c9ebcdaa34..70d5239698 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -232,6 +232,7 @@ class Parser(metaclass=_Parser): TokenType.IS, TokenType.ISNULL, TokenType.INTERVAL, + TokenType.KEEP, TokenType.LAZY, TokenType.LEADING, TokenType.LEFT, @@ -750,6 +751,7 @@ class Parser(metaclass=_Parser): INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} + WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} @@ -3768,14 +3770,24 @@ def _parse_window( # bigquery select from window x AS (partition by ...) if alias: + over = None self._match(TokenType.ALIAS) - elif not self._match(TokenType.OVER): + elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS): return this + else: + over = self._prev.text.upper() if not self._match(TokenType.L_PAREN): - return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) + return self.expression( + exp.Window, this=this, alias=self._parse_id_var(False), over=over + ) window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) + + first = self._match(TokenType.FIRST) + if self._match_text_seq("LAST"): + first = False + partition = self._parse_partition_by() order = self._parse_order() kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text @@ -3806,6 +3818,8 @@ def _parse_window( order=order, spec=spec, alias=window_alias, + over=over, + first=first, ) def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 4272b996d0..b7756cf5ad 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -215,6 +215,7 @@ class TokenType(AutoName): ISNULL = auto() JOIN = auto() JOIN_MARKER = auto() + KEEP = auto() LANGUAGE = auto() LATERAL = auto() LAZY = auto() @@ -561,6 +562,7 @@ class Tokenizer(metaclass=_Tokenizer): "IS": TokenType.IS, "ISNULL": TokenType.ISNULL, "JOIN": TokenType.JOIN, + "KEEP": TokenType.KEEP, "LATERAL": TokenType.LATERAL, "LAZY": TokenType.LAZY, "LEADING": TokenType.LEADING, diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 9eedd76081..dd297d6bf0 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -6,6 +6,9 @@ class TestOracle(Validator): def test_oracle(self): self.validate_identity("SELECT * FROM V$SESSION") + self.validate_identity( + "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" + ) self.validate_all( "NVL(NULL, 1)", diff --git a/tests/test_transpile.py b/tests/test_transpile.py index bb495cd913..d68f6f8444 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -20,6 +20,7 @@ def validate(self, sql, target, **kwargs): self.assertEqual(transpile(sql, **kwargs)[0], target) def test_alias(self): + self.assertEqual(transpile("SELECT SUM(y) KEEP")[0], "SELECT SUM(y) AS KEEP") self.assertEqual(transpile("SELECT 1 overwrite")[0], "SELECT 1 AS overwrite") self.assertEqual(transpile("SELECT 1 is")[0], "SELECT 1 AS is") self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")