Skip to content

Commit

Permalink
Feat(oracle): support KEEP (.. [FIRST|LAST] ..) window function syntax (
Browse files Browse the repository at this point in the history
#1522)

* Feat(oracle): support KEEP (.. [FIRST|LAST] ..) window function syntax

* Fixup

* Fixup
  • Loading branch information
georgesittas authored May 2, 2023
1 parent e11a5ce commit 455b9e9
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 4 deletions.
2 changes: 2 additions & 0 deletions sqlglot/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class Oracle(Dialect):
}

class Parser(parser.Parser):
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP}

FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2885,6 +2885,8 @@ class Window(Expression):
"order": False,
"spec": False,
"alias": False,
"over": False,
"first": False,
}


Expand Down
10 changes: 8 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,12 +1553,18 @@ def window_sql(self, expression: exp.Window) -> str:
spec_sql = " " + self.window_spec_sql(spec) if spec else ""

alias = self.sql(expression, "alias")
this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
over = self.sql(expression, "over") or "OVER"
this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"

first = expression.args.get("first")
if first is not None:
first = " FIRST " if first else " LAST "
first = first or ""

if not partition and not order and not spec and alias:
return f"{this} {alias}"

window_args = alias + partition_sql + order_sql + spec_sql
window_args = alias + first + partition_sql + order_sql + spec_sql

return f"{this} ({window_args.strip()})"

Expand Down
18 changes: 16 additions & 2 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class Parser(metaclass=_Parser):
TokenType.IS,
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.KEEP,
TokenType.LAZY,
TokenType.LEADING,
TokenType.LEFT,
Expand Down Expand Up @@ -755,6 +756,7 @@ class Parser(metaclass=_Parser):
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}

WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER}

ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY}

Expand Down Expand Up @@ -3773,14 +3775,24 @@ def _parse_window(

# bigquery select from window x AS (partition by ...)
if alias:
over = None
self._match(TokenType.ALIAS)
elif not self._match(TokenType.OVER):
elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS):
return this
else:
over = self._prev.text.upper()

if not self._match(TokenType.L_PAREN):
return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
return self.expression(
exp.Window, this=this, alias=self._parse_id_var(False), over=over
)

window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)

first = self._match(TokenType.FIRST)
if self._match_text_seq("LAST"):
first = False

partition = self._parse_partition_by()
order = self._parse_order()
kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
Expand Down Expand Up @@ -3811,6 +3823,8 @@ def _parse_window(
order=order,
spec=spec,
alias=window_alias,
over=over,
first=first,
)

def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ class TokenType(AutoName):
ISNULL = auto()
JOIN = auto()
JOIN_MARKER = auto()
KEEP = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
Expand Down Expand Up @@ -562,6 +563,7 @@ class Tokenizer(metaclass=_Tokenizer):
"IS": TokenType.IS,
"ISNULL": TokenType.ISNULL,
"JOIN": TokenType.JOIN,
"KEEP": TokenType.KEEP,
"LATERAL": TokenType.LATERAL,
"LAZY": TokenType.LAZY,
"LEADING": TokenType.LEADING,
Expand Down
3 changes: 3 additions & 0 deletions tests/dialects/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ class TestOracle(Validator):

def test_oracle(self):
self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity(
"SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name"
)

self.validate_all(
"NVL(NULL, 1)",
Expand Down
1 change: 1 addition & 0 deletions tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def validate(self, sql, target, **kwargs):
self.assertEqual(transpile(sql, **kwargs)[0], target)

def test_alias(self):
self.assertEqual(transpile("SELECT SUM(y) KEEP")[0], "SELECT SUM(y) AS KEEP")
self.assertEqual(transpile("SELECT 1 overwrite")[0], "SELECT 1 AS overwrite")
self.assertEqual(transpile("SELECT 1 is")[0], "SELECT 1 AS is")
self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
Expand Down

0 comments on commit 455b9e9

Please sign in to comment.