Skip to content

Commit

Permalink
Feat(postgres): add support for the PARTITION OF property in CREATE (#…
Browse files Browse the repository at this point in the history
…2476)

* Feat(postgres): add support for the PARTITION OF property in CREATE

* Revert arg type

* Retreat if we don't match OF keyword after PARTITION
  • Loading branch information
georgesittas authored Oct 28, 2023
1 parent e6f31d6 commit c3852db
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 9 deletions.
4 changes: 3 additions & 1 deletion sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression:


def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
kind = expression.args["kind"]
kind = expression.args.get("kind")
if not kind:
return expression

if kind.this == exp.DataType.Type.SERIAL:
data_type = exp.DataType(this=exp.DataType.Type.INT)
Expand Down
16 changes: 16 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,6 +2145,22 @@ class PartitionedByProperty(Property):
arg_types = {"this": True}


# https://www.postgresql.org/docs/current/sql-createtable.html
class PartitionBoundSpec(Expression):
# this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...)
arg_types = {
"this": False,
"expression": False,
"from_expressions": False,
"to_expressions": False,
}


class PartitionedOfProperty(Property):
# this -> parent_table (schema), expression -> FOR VALUES ... / DEFAULT
arg_types = {"this": True, "expression": True}


class RemoteWithConnectionModelProperty(Property):
arg_types = {"this": True}

Expand Down
24 changes: 24 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ class Generator:
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
Expand Down Expand Up @@ -1262,6 +1263,29 @@ def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"

def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
if isinstance(expression.this, list):
return f"IN ({self.expressions(expression, key='this', flat=True)})"
if expression.this:
modulus = self.sql(expression, "this")
remainder = self.sql(expression, "expression")
return f"WITH (MODULUS {modulus}, REMAINDER {remainder})"

from_expressions = self.expressions(expression, key="from_expressions", flat=True)
to_expressions = self.expressions(expression, key="to_expressions", flat=True)
return f"FROM ({from_expressions}) TO ({to_expressions})"

def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str:
this = self.sql(expression, "this")

for_values_or_default = expression.expression
if isinstance(for_values_or_default, exp.PartitionBoundSpec):
for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}"
else:
for_values_or_default = " DEFAULT"

return f"PARTITION OF {this}{for_values_or_default}"

def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
this = f" {self.sql(expression, 'this')}" if expression.this else ""
Expand Down
53 changes: 53 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ class Parser(metaclass=_Parser):
"ON": lambda self: self._parse_on_property(),
"ORDER BY": lambda self: self._parse_order(skip_order_token=True),
"OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()),
"PARTITION": lambda self: self._parse_partitioned_of(),
"PARTITION BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED BY": lambda self: self._parse_partitioned_by(),
"PARTITIONED_BY": lambda self: self._parse_partitioned_by(),
Expand Down Expand Up @@ -1743,6 +1744,58 @@ def _parse_partition_by(self) -> t.List[exp.Expression]:
return self._parse_csv(self._parse_conjunction)
return []

def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec:
def _parse_partition_bound_expr() -> t.Optional[exp.Expression]:
if self._match_text_seq("MINVALUE"):
return exp.var("MINVALUE")
if self._match_text_seq("MAXVALUE"):
return exp.var("MAXVALUE")
return self._parse_bitwise()

this: t.Optional[exp.Expression | t.List[exp.Expression]] = None
expression = None
from_expressions = None
to_expressions = None

if self._match(TokenType.IN):
this = self._parse_wrapped_csv(self._parse_bitwise)
elif self._match(TokenType.FROM):
from_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr)
self._match_text_seq("TO")
to_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr)
elif self._match_text_seq("WITH", "(", "MODULUS"):
this = self._parse_number()
self._match_text_seq(",", "REMAINDER")
expression = self._parse_number()
self._match_r_paren()
else:
self.raise_error("Failed to parse partition bound spec.")

return self.expression(
exp.PartitionBoundSpec,
this=this,
expression=expression,
from_expressions=from_expressions,
to_expressions=to_expressions,
)

# https://www.postgresql.org/docs/current/sql-createtable.html
def _parse_partitioned_of(self) -> t.Optional[exp.PartitionedOfProperty]:
if not self._match_text_seq("OF"):
self._retreat(self._index - 1)
return None

this = self._parse_table(schema=True)

if self._match(TokenType.DEFAULT):
expression: exp.Var | exp.PartitionBoundSpec = exp.var("DEFAULT")
elif self._match_text_seq("FOR", "VALUES"):
expression = self._parse_partition_bound_spec()
else:
self.raise_error("Expecting either DEFAULT or FOR VALUES clause.")

return self.expression(exp.PartitionedOfProperty, this=this, expression=expression)

def _parse_partitioned_by(self) -> exp.PartitionedByProperty:
self._match(TokenType.EQ)
return self.expression(
Expand Down
35 changes: 27 additions & 8 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,6 @@ class TestPostgres(Validator):
dialect = "postgres"

def test_ddl(self):
self.validate_identity(
"CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)"
)
self.validate_identity(
"CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])",
"CREATE TABLE test (x TIMESTAMP[][])",
)
self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)")
self.validate_identity("CREATE TABLE test (elems JSONB[])")
self.validate_identity("CREATE TABLE public.y (x TSTZRANGE NOT NULL)")
Expand All @@ -26,6 +19,29 @@ def test_ddl(self):
self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a")
self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a, b")
self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING *")
self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a")
self.validate_identity("CREATE TABLE cities_partdef PARTITION OF cities DEFAULT")
self.validate_identity(
"CREATE TABLE cust_part3 PARTITION OF customers FOR VALUES WITH (MODULUS 3, REMAINDER 2)"
)
self.validate_identity(
"CREATE TABLE measurement_y2016m07 PARTITION OF measurement (unitsales DEFAULT 0) FOR VALUES FROM ('2016-07-01') TO ('2016-08-01')"
)
self.validate_identity(
"CREATE TABLE measurement_ym_older PARTITION OF measurement_year_month FOR VALUES FROM (MINVALUE, MINVALUE) TO (2016, 11)"
)
self.validate_identity(
"CREATE TABLE measurement_ym_y2016m11 PARTITION OF measurement_year_month FOR VALUES FROM (2016, 11) TO (2016, 12)"
)
self.validate_identity(
"CREATE TABLE cities_ab PARTITION OF cities (CONSTRAINT city_id_nonzero CHECK (city_id <> 0)) FOR VALUES IN ('a', 'b')"
)
self.validate_identity(
"CREATE TABLE cities_ab PARTITION OF cities (CONSTRAINT city_id_nonzero CHECK (city_id <> 0)) FOR VALUES IN ('a', 'b') PARTITION BY RANGE(population)"
)
self.validate_identity(
"CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)"
)
self.validate_identity(
"INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO NOTHING RETURNING *"
)
Expand All @@ -44,7 +60,10 @@ def test_ddl(self):
self.validate_identity(
"DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid RETURNING a"
)
self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a")
self.validate_identity(
"CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])",
"CREATE TABLE test (x TIMESTAMP[][])",
)

self.validate_all(
"CREATE OR REPLACE FUNCTION function_name (input_a character varying DEFAULT NULL::character varying)",
Expand Down

0 comments on commit c3852db

Please sign in to comment.