Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fall back to SQL for unconvertible column constraint options in CREATE TABLE statements #550

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/sql2pgroll/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func convertAlterTableAddUniqueConstraint(stmt *pgq.AlterTableStmt, constraint *
}

func convertAlterTableAddForeignKeyConstraint(stmt *pgq.AlterTableStmt, constraint *pgq.Constraint) (migrations.Operation, error) {
if !canConvertAlterTableAddForeignKeyConstraint(constraint) {
if !canConvertForeignKeyConstraint(constraint) {
return nil, nil
}

Expand Down Expand Up @@ -240,7 +240,7 @@ func parseOnDeleteAction(action string) (migrations.ForeignKeyReferenceOnDelete,
}
}

func canConvertAlterTableAddForeignKeyConstraint(constraint *pgq.Constraint) bool {
func canConvertForeignKeyConstraint(constraint *pgq.Constraint) bool {
if constraint.SkipValidation {
return false
}
Expand Down
63 changes: 50 additions & 13 deletions pkg/sql2pgroll/create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package sql2pgroll

import (
"fmt"
"slices"

pgq "github.com/xataio/pg_query_go/v6"

Expand Down Expand Up @@ -80,35 +81,56 @@ func convertColumnDef(col *pgq.ColumnDef) (*migrations.Column, error) {
return nil, nil
}

// Convert the column type
// Deparse the column type
typeString, err := pgq.DeparseTypeName(col.TypeName)
if err != nil {
return nil, fmt.Errorf("error deparsing column type: %w", err)
}

// Determine column nullability, uniqueness, and primary key status
var notNull, unique, pk bool
var defaultValue *string
for _, constraint := range col.Constraints {
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_NOTNULL {
// Named inline constraints are not supported
anyNamed := slices.ContainsFunc(col.GetConstraints(), func(c *pgq.Node) bool {
return c.GetConstraint().GetConname() != ""
})
if anyNamed {
return nil, nil
}

// Convert column constraints
var notNull, pk, unique bool
for _, c := range col.GetConstraints() {
switch c.GetConstraint().GetContype() {
case pgq.ConstrType_CONSTR_NULL:
notNull = false
case pgq.ConstrType_CONSTR_NOTNULL:
notNull = true
}
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_UNIQUE {
case pgq.ConstrType_CONSTR_UNIQUE:
if !canConvertUniqueConstraint(c.GetConstraint()) {
return nil, nil
}
unique = true
}
if constraint.GetConstraint().GetContype() == pgq.ConstrType_CONSTR_PRIMARY {
case pgq.ConstrType_CONSTR_PRIMARY:
if !canConvertPrimaryKeyConstraint(c.GetConstraint()) {
return nil, nil
}
pk = true
notNull = true
case pgq.ConstrType_CONSTR_CHECK:
if !canConvertCheckConstraint(c.GetConstraint()) {
return nil, nil
}
case pgq.ConstrType_CONSTR_FOREIGN:
if !canConvertForeignKeyConstraint(c.GetConstraint()) {
return nil, nil
}
}
}

return &migrations.Column{
Name: col.Colname,
Name: col.GetColname(),
Type: typeString,
Nullable: !notNull,
Unique: unique,
Default: defaultValue,
Pk: pk,
Unique: unique,
}, nil
}

Expand All @@ -127,3 +149,18 @@ func canConvertColumnDef(col *pgq.ColumnDef) bool {
return true
}
}

// canConvertPrimaryKeyConstraint returns true iff `constraint` can be converted
// to a pgroll primary key constraint.
func canConvertPrimaryKeyConstraint(constraint *pgq.Constraint) bool {
switch {
case
// Specifying an index tablespace is not supported
constraint.GetIndexspace() != "",
// Storage options are not supported
len(constraint.GetOptions()) != 0:
return false
default:
return true
}
}
32 changes: 32 additions & 0 deletions pkg/sql2pgroll/create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func TestConvertCreateTableStatements(t *testing.T) {
sql: "CREATE TABLE foo(a int)",
expectedOp: expect.CreateTableOp1,
},
{
sql: "CREATE TABLE foo(a int NULL)",
expectedOp: expect.CreateTableOp1,
},
{
sql: "CREATE TABLE foo(a int NOT NULL)",
expectedOp: expect.CreateTableOp2,
Expand Down Expand Up @@ -131,6 +135,34 @@ func TestUnconvertableCreateTableStatements(t *testing.T) {
"CREATE TABLE foo(a int, UNIQUE (a))",
"CREATE TABLE foo(a int, PRIMARY KEY (a))",
"CREATE TABLE foo(a int, FOREIGN KEY (a) REFERENCES bar(b))",

// Primary key constraint options are not supported
"CREATE TABLE foo(a int PRIMARY KEY USING INDEX TABLESPACE bar)",
"CREATE TABLE foo(a int PRIMARY KEY WITH (fillfactor=70))",

// CHECK constraint NO INHERIT option is not supported
"CREATE TABLE foo(a int CHECK (a > 0) NO INHERIT)",

// Options on UNIQUE constraints are not supported
"CREATE TABLE foo(a int UNIQUE NULLS NOT DISTINCT)",
"CREATE TABLE foo(a int UNIQUE WITH (fillfactor=70))",
"CREATE TABLE foo(a int UNIQUE USING INDEX TABLESPACE baz)",

// Some options on FOREIGN KEY constraints are not supported
"CREATE TABLE foo(a int REFERENCES bar (b) ON UPDATE RESTRICT)",
"CREATE TABLE foo(a int REFERENCES bar (b) ON UPDATE CASCADE)",
"CREATE TABLE foo(a int REFERENCES bar (b) ON UPDATE SET NULL)",
"CREATE TABLE foo(a int REFERENCES bar (b) ON UPDATE SET DEFAULT)",
"CREATE TABLE foo(a int REFERENCES bar (b) MATCH FULL)",

// Named inline constraints are not supported
"CREATE TABLE foo(a int CONSTRAINT foo_check CHECK (a > 0))",
"CREATE TABLE foo(a int CONSTRAINT foo_unique UNIQUE)",
"CREATE TABLE foo(a int CONSTRAINT foo_pk PRIMARY KEY)",
"CREATE TABLE foo(a int CONSTRAINT foo_fk REFERENCES bar(b))",
"CREATE TABLE foo(a int CONSTRAINT foo_default DEFAULT 0)",
"CREATE TABLE foo(a int CONSTRAINT foo_null NULL)",
"CREATE TABLE foo(a int CONSTRAINT foo_notnull NOT NULL)",
}

for _, sql := range tests {
Expand Down