Skip to content

Commit

Permalink
Validate that FK constraint name is unique (#428)
Browse files Browse the repository at this point in the history
We can do this in our validation step before even touching the database.

Part of #105
  • Loading branch information
ryanslade authored Oct 23, 2024
1 parent d00dd1e commit 9922e5f
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 6 deletions.
13 changes: 13 additions & 0 deletions pkg/migrations/op_set_fk.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"github.com/lib/pq"

"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
Expand Down Expand Up @@ -58,6 +59,18 @@ func (o *OpSetForeignKey) Validate(ctx context.Context, s *schema.Schema) error
}
}

table := s.GetTable(o.Table)
if table == nil {
return TableDoesNotExistError{Name: o.Table}
}

if table.ConstraintExists(o.References.Name) {
return ConstraintAlreadyExistsError{
Table: table.Name,
Constraint: o.References.Name,
}
}

if o.Up == "" {
return FieldRequiredError{Name: "up"}
}
Expand Down
86 changes: 84 additions & 2 deletions pkg/migrations/op_set_fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"database/sql"
"testing"

"github.com/xataio/pgroll/internal/testutils"

"github.com/stretchr/testify/assert"

"github.com/xataio/pgroll/internal/testutils"
"github.com/xataio/pgroll/pkg/migrations"
)

Expand Down Expand Up @@ -1120,6 +1120,88 @@ func TestSetForeignKey(t *testing.T) {
ColumnMustHaveComment(t, db, schema, "posts", "user_id", "the id of the author")
},
},
{
name: "validate that foreign key name is unique",
migrations: []migrations.Migration{
{
Name: "01_add_tables",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "name",
Type: "text",
},
},
},
&migrations.OpCreateTable{
Name: "posts",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: ptr(true),
},
{
Name: "title",
Type: "text",
},
{
Name: "user_id",
Type: "integer",
Nullable: ptr(true),
},
},
},
},
},
{
Name: "02_add_fk_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "user_id",
References: &migrations.ForeignKeyReference{
Name: "fk_users_id",
Table: "users",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
{
Name: "03_add_fk_constraint_again",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "user_id",
References: &migrations.ForeignKeyReference{
Name: "fk_users_id",
Table: "users",
Column: "id",
},
Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)",
Down: "user_id",
},
},
},
},
wantStartErr: migrations.ConstraintAlreadyExistsError{
Table: "posts",
Constraint: "fk_users_id",
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {},
},
})
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/roll/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (m *Roll) Complete(ctx context.Context) error {
}

// read the current schema
schema, err := m.state.ReadSchema(ctx, m.schema)
currentSchema, err := m.state.ReadSchema(ctx, m.schema)
if err != nil {
return fmt.Errorf("unable to read schema: %w", err)
}
Expand All @@ -175,7 +175,7 @@ func (m *Roll) Complete(ctx context.Context) error {
// execute operations
refreshViews := false
for _, op := range migration.Operations {
err := op.Complete(ctx, m.pgConn, m.sqlTransformer, schema)
err := op.Complete(ctx, m.pgConn, m.sqlTransformer, currentSchema)
if err != nil {
return fmt.Errorf("unable to execute complete operation: %w", err)
}
Expand All @@ -189,12 +189,12 @@ func (m *Roll) Complete(ctx context.Context) error {

// recreate views for the new version (if some operations require it, ie SQL)
if refreshViews && !m.disableVersionSchemas {
schema, err = m.state.ReadSchema(ctx, m.schema)
currentSchema, err = m.state.ReadSchema(ctx, m.schema)
if err != nil {
return fmt.Errorf("unable to read schema: %w", err)
}

err = m.ensureViews(ctx, schema, migration.Name)
err = m.ensureViews(ctx, currentSchema, migration.Name)
if err != nil {
return err
}
Expand Down

0 comments on commit 9922e5f

Please sign in to comment.