diff --git a/pkg/migrations/duplicate.go b/pkg/migrations/duplicate.go index c6ce3d07..a9f1e308 100644 --- a/pkg/migrations/duplicate.go +++ b/pkg/migrations/duplicate.go @@ -17,87 +17,231 @@ import ( // Duplicator duplicates a column in a table, including all constraints and // comments. type Duplicator struct { + stmtBuilder *duplicatorStmtBuilder conn db.DB - table *schema.Table - column *schema.Column - asName string - withoutNotNull bool - withType string - withoutConstraint string + columns map[string]*columnToDuplicate + withoutConstraint []string +} + +type columnToDuplicate struct { + column *schema.Column + asName string + withoutNotNull bool + withType string +} + +// duplicatorStmtBuilder is a helper for building SQL statements to duplicate +// columns and constraints in a table. +type duplicatorStmtBuilder struct { + table *schema.Table } const ( dataTypeMismatchErrorCode pq.ErrorCode = "42804" undefinedFunctionErrorCode pq.ErrorCode = "42883" + + cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)` + cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s` + cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID` ) // NewColumnDuplicator creates a new Duplicator for a column. -func NewColumnDuplicator(conn db.DB, table *schema.Table, column *schema.Column) *Duplicator { +func NewColumnDuplicator(conn db.DB, table *schema.Table, columns ...*schema.Column) *Duplicator { + cols := make(map[string]*columnToDuplicate, len(columns)) + for _, column := range columns { + cols[column.Name] = &columnToDuplicate{ + column: column, + asName: TemporaryName(column.Name), + withType: column.Type, + } + } return &Duplicator{ - conn: conn, - table: table, - column: column, - asName: TemporaryName(column.Name), - withType: column.Type, + stmtBuilder: &duplicatorStmtBuilder{ + table: table, + }, + conn: conn, + columns: cols, + withoutConstraint: make([]string, 0), } } // WithType sets the type of the new column. -func (d *Duplicator) WithType(t string) *Duplicator { - d.withType = t +func (d *Duplicator) WithType(columnName, t string) *Duplicator { + d.columns[columnName].withType = t return d } // WithoutConstraint excludes a constraint from being duplicated. func (d *Duplicator) WithoutConstraint(c string) *Duplicator { - d.withoutConstraint = c + d.withoutConstraint = append(d.withoutConstraint, c) return d } // WithoutNotNull excludes the NOT NULL constraint from being duplicated. -func (d *Duplicator) WithoutNotNull() *Duplicator { - d.withoutNotNull = true +func (d *Duplicator) WithoutNotNull(columnName string) *Duplicator { + d.columns[columnName].withoutNotNull = true return d } // Duplicate duplicates a column in the table, including all constraints and // comments. func (d *Duplicator) Duplicate(ctx context.Context) error { + colNames := make([]string, 0, len(d.columns)) + for name, c := range d.columns { + colNames = append(colNames, name) + + // Duplicate the column with the new type + // and check and fk constraints + if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" { + _, err := d.conn.ExecContext(ctx, sql) + if err != nil { + return err + } + } + + // Duplicate the column's default value + if sql := d.stmtBuilder.duplicateDefault(c.column, c.asName); sql != "" { + _, err := d.conn.ExecContext(ctx, sql) + err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode) + if err != nil { + return err + } + } + + if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" { + _, err := d.conn.ExecContext(ctx, sql) + if err != nil { + return err + } + } + } + + // Generate SQL to duplicate any check constraints on the columns. This may faile + // if the check constraint is not valid for the new column type, in which case + // the error is ignored. + for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) { + // Update the check constraint expression to use the new column names if any of the columns are duplicated + _, err := d.conn.ExecContext(ctx, sql) + err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode) + if err != nil { + return err + } + } + + // Generate SQL to duplicate any unique constraints on the columns + // The constraint is duplicated by adding a unique index on the column concurrently. + // The index is converted into a unique constraint on migration completion. + for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) { + // Update the unique constraint columns to use the new column names if any of the columns are duplicated + if _, err := d.conn.ExecContext(ctx, sql); err != nil { + return err + } + } + + return nil +} + +func (d *duplicatorStmtBuilder) duplicateCheckConstraints(withoutConstraint []string, colNames ...string) []string { + stmts := make([]string, 0, len(d.table.CheckConstraints)) + for _, cc := range d.table.CheckConstraints { + if slices.Contains(withoutConstraint, cc.Name) { + continue + } + if duplicatedConstraintColumns := d.duplicatedConstraintColumns(cc.Columns, colNames...); len(duplicatedConstraintColumns) > 0 { + stmts = append(stmts, fmt.Sprintf(cAlterTableAddCheckConstraintSQL, + pq.QuoteIdentifier(d.table.Name), + pq.QuoteIdentifier(DuplicationName(cc.Name)), + rewriteCheckExpression(cc.Definition, duplicatedConstraintColumns...), + )) + } + } + return stmts +} + +func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []string, colNames ...string) []string { + stmts := make([]string, 0, len(d.table.UniqueConstraints)) + for _, uc := range d.table.UniqueConstraints { + if slices.Contains(withoutConstraint, uc.Name) { + continue + } + if duplicatedMember, constraintColumns := d.allConstraintColumns(uc.Columns, colNames...); duplicatedMember { + stmts = append(stmts, fmt.Sprintf(cCreateUniqueIndexSQL, + pq.QuoteIdentifier(DuplicationName(uc.Name)), + pq.QuoteIdentifier(d.table.Name), + strings.Join(quoteColumnNames(constraintColumns), ", "), + )) + } + } + return stmts +} + +// duplicatedConstraintColumns returns a new slice of constraint columns with +// the columns that are duplicated replaced with temporary names. +func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string { + newConstraintColumns := make([]string, 0) + for _, column := range constraintColumns { + if slices.Contains(duplicatedColumns, column) { + newConstraintColumns = append(newConstraintColumns, column) + } + } + return newConstraintColumns +} + +// allConstraintColumns returns a new slice of constraint columns with the columns +// that are duplicated replaced with temporary names and a boolean indicating if +// any of the columns are duplicated. +func (d *duplicatorStmtBuilder) allConstraintColumns(constraintColumns []string, duplicatedColumns ...string) (bool, []string) { + duplicatedMember := false + newConstraintColumns := make([]string, len(constraintColumns)) + for i, column := range constraintColumns { + if slices.Contains(duplicatedColumns, column) { + newConstraintColumns[i] = TemporaryName(column) + duplicatedMember = true + } else { + newConstraintColumns[i] = column + } + } + return duplicatedMember, newConstraintColumns +} + +func (d *duplicatorStmtBuilder) duplicateColumn( + column *schema.Column, + asName string, + withoutNotNull bool, + withType string, + withoutConstraint []string, +) string { const ( - cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s` - cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s` - cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID` - cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)` - cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s` - cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID` - cCommentOnColumnSQL = `COMMENT ON COLUMN %s.%s IS %s` + cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s` + cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s` + cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID` ) // Generate SQL to duplicate the column's name and type sql := fmt.Sprintf(cAlterTableSQL, pq.QuoteIdentifier(d.table.Name), - pq.QuoteIdentifier(d.asName), - d.withType) + pq.QuoteIdentifier(asName), + withType) // Generate SQL to add an unchecked NOT NULL constraint if the original column // is NOT NULL. The constraint will be validated on migration completion. - if !d.column.Nullable && !d.withoutNotNull { + if !column.Nullable && !withoutNotNull { sql += fmt.Sprintf(", "+cAddCheckConstraintSQL, - pq.QuoteIdentifier(DuplicationName(NotNullConstraintName(d.column.Name))), - fmt.Sprintf("CHECK (%s IS NOT NULL)", pq.QuoteIdentifier(d.asName)), + pq.QuoteIdentifier(DuplicationName(NotNullConstraintName(column.Name))), + fmt.Sprintf("CHECK (%s IS NOT NULL)", pq.QuoteIdentifier(asName)), ) } // Generate SQL to duplicate any foreign key constraints on the column for _, fk := range d.table.ForeignKeys { - if fk.Name == d.withoutConstraint { + if slices.Contains(withoutConstraint, fk.Name) { continue } - if slices.Contains(fk.Columns, d.column.Name) { + if slices.Contains(fk.Columns, column.Name) { sql += fmt.Sprintf(", "+cAddForeignKeySQL, pq.QuoteIdentifier(DuplicationName(fk.Name)), - strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, d.column.Name, d.asName)), ", "), + strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "), pq.QuoteIdentifier(fk.ReferencedTable), strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "), fk.OnDelete, @@ -105,86 +249,35 @@ func (d *Duplicator) Duplicate(ctx context.Context) error { } } - _, err := d.conn.ExecContext(ctx, sql) - if err != nil { - return err + return sql +} + +func (d *duplicatorStmtBuilder) duplicateDefault(column *schema.Column, asName string) string { + if column.Default == nil { + return "" } + const cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s` + // Generate SQL to duplicate any default value on the column. This may fail // if the default value is not valid for the new column type, in which case // the error is ignored. - if d.column.Default != nil { - sql := fmt.Sprintf(cSetDefaultSQL, pq.QuoteIdentifier(d.table.Name), d.asName, *d.column.Default) - - _, err := d.conn.ExecContext(ctx, sql) + return fmt.Sprintf(cSetDefaultSQL, pq.QuoteIdentifier(d.table.Name), asName, *column.Default) +} - err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode) - if err != nil { - return err - } +func (d *duplicatorStmtBuilder) duplicateComment(column *schema.Column, asName string) string { + if column.Comment == "" { + return "" } - // Generate SQL to duplicate any check constraints on the column. This may faile - // if the check constraint is not valid for the new column type, in which case - // the error is ignored. - for _, cc := range d.table.CheckConstraints { - if cc.Name == d.withoutConstraint { - continue - } - - if slices.Contains(cc.Columns, d.column.Name) { - sql := fmt.Sprintf(cAlterTableAddCheckConstraintSQL, - pq.QuoteIdentifier(d.table.Name), - pq.QuoteIdentifier(DuplicationName(cc.Name)), - rewriteCheckExpression(cc.Definition, d.column.Name, d.asName), - ) - - _, err := d.conn.ExecContext(ctx, sql) - - err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode) - if err != nil { - return err - } - } - } + const cCommentOnColumnSQL = `COMMENT ON COLUMN %s.%s IS %s` // Generate SQL to duplicate the column's comment - if d.column.Comment != "" { - sql = fmt.Sprintf(cCommentOnColumnSQL, - pq.QuoteIdentifier(d.table.Name), - pq.QuoteIdentifier(d.asName), - pq.QuoteLiteral(d.column.Comment), - ) - - _, err = d.conn.ExecContext(ctx, sql) - if err != nil { - return err - } - } - - // Generate SQL to duplicate any unique constraints on the column - // The constraint is duplicated by adding a unique index on the column concurrently. - // The index is converted into a unique constraint on migration completion. - for _, uc := range d.table.UniqueConstraints { - if uc.Name == d.withoutConstraint { - continue - } - - if slices.Contains(uc.Columns, d.column.Name) { - sql = fmt.Sprintf(cCreateUniqueIndexSQL, - pq.QuoteIdentifier(DuplicationName(uc.Name)), - pq.QuoteIdentifier(d.table.Name), - strings.Join(quoteColumnNames(copyAndReplace(uc.Columns, d.column.Name, d.asName)), ", "), - ) - - _, err = d.conn.ExecContext(ctx, sql) - if err != nil { - return err - } - } - } - - return nil + return fmt.Sprintf(cCommentOnColumnSQL, + pq.QuoteIdentifier(d.table.Name), + pq.QuoteIdentifier(asName), + pq.QuoteLiteral(column.Comment), + ) } // DiplicationName returns the name of a duplicated column. diff --git a/pkg/migrations/duplicate_test.go b/pkg/migrations/duplicate_test.go new file mode 100644 index 00000000..13d7dc16 --- /dev/null +++ b/pkg/migrations/duplicate_test.go @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: Apache-2.0 + +package migrations + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/xataio/pgroll/pkg/schema" +) + +var table = &schema.Table{ + Name: "test_table", + Columns: map[string]schema.Column{ + "id": {Name: "id", Type: "serial"}, + "name": {Name: "name", Type: "text"}, + "nick": {Name: "nick", Type: "text"}, + "age": {Name: "age", Type: "integer"}, + "email": {Name: "email", Type: "text"}, + "city": {Name: "city", Type: "text"}, + "description": {Name: "description", Type: "text"}, + }, + UniqueConstraints: map[string]schema.UniqueConstraint{ + "unique_email": {Name: "unique_email", Columns: []string{"email"}}, + "unique_name_nick": {Name: "unique_name_nick", Columns: []string{"name", "nick"}}, + }, + CheckConstraints: map[string]schema.CheckConstraint{ + "email_at": {Name: "email_at", Columns: []string{"email"}, Definition: `"email" ~ '@'`}, + "adults": {Name: "adults", Columns: []string{"age"}, Definition: `"age" > 18`}, + "new_york_adults": {Name: "new_york_adults", Columns: []string{"city", "age"}, Definition: `"city" = 'New York' AND "age" > 21`}, + "different_nick": {Name: "different_nick", Columns: []string{"name", "nick"}, Definition: `"name" != "nick"`}, + }, +} + +func TestDuplicateStmtBuilderCheckConstraints(t *testing.T) { + d := &duplicatorStmtBuilder{table} + for name, testCases := range map[string]struct { + columns []string + expectedStmts []string + }{ + "single column duplicated with no constraint": { + columns: []string{"description"}, + expectedStmts: []string{}, + }, + "single-column check constraint with single column duplicated": { + columns: []string{"email"}, + expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_email_at" "_pgroll_new_email" ~ '@' NOT VALID`}, + }, + "multiple multi and single column check constraint with single column duplicated": { + columns: []string{"age"}, + expectedStmts: []string{ + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_adults" "_pgroll_new_age" > 18 NOT VALID`, + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_new_york_adults" "city" = 'New York' AND "_pgroll_new_age" > 21 NOT VALID`, + }, + }, + "multiple multi and single column check constraint with multiple column duplicated": { + columns: []string{"age", "description"}, + expectedStmts: []string{ + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_adults" "_pgroll_new_age" > 18 NOT VALID`, + `ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_new_york_adults" "city" = 'New York' AND "_pgroll_new_age" > 21 NOT VALID`, + }, + }, + "multi-column check constraint with multiple columns with single column duplicated": { + columns: []string{"name"}, + expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_different_nick" "_pgroll_new_name" != "nick" NOT VALID`}, + }, + "multi-column check constraint with multiple columns duplicated": { + columns: []string{"name", "nick"}, + expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_different_nick" "_pgroll_new_name" != "_pgroll_new_nick" NOT VALID`}, + }, + } { + t.Run(name, func(t *testing.T) { + stmts := d.duplicateCheckConstraints(nil, testCases.columns...) + assert.Equal(t, len(testCases.expectedStmts), len(stmts)) + for _, stmt := range stmts { + assert.True(t, slices.Contains(testCases.expectedStmts, stmt)) + } + }) + } +} + +func TestDuplicateStmtBuilderUniqueConstraints(t *testing.T) { + d := &duplicatorStmtBuilder{table} + for name, testCases := range map[string]struct { + columns []string + expectedStmts []string + }{ + "single column duplicated": { + columns: []string{"city"}, + expectedStmts: []string{}, + }, + "single-column constraint with single column duplicated": { + columns: []string{"email"}, + expectedStmts: []string{`CREATE UNIQUE INDEX CONCURRENTLY "_pgroll_dup_unique_email" ON "test_table" ("_pgroll_new_email")`}, + }, + "single-column constraint with multiple column duplicated": { + columns: []string{"email", "description"}, + expectedStmts: []string{`CREATE UNIQUE INDEX CONCURRENTLY "_pgroll_dup_unique_email" ON "test_table" ("_pgroll_new_email")`}, + }, + "multi-column constraint with single column duplicated": { + columns: []string{"name"}, + expectedStmts: []string{`CREATE UNIQUE INDEX CONCURRENTLY "_pgroll_dup_unique_name_nick" ON "test_table" ("_pgroll_new_name", "nick")`}, + }, + "multi-column constraint with multiple unrelated column duplicated": { + columns: []string{"name", "description"}, + expectedStmts: []string{`CREATE UNIQUE INDEX CONCURRENTLY "_pgroll_dup_unique_name_nick" ON "test_table" ("_pgroll_new_name", "nick")`}, + }, + "multi-column constraint with multiple columns": { + columns: []string{"name", "nick"}, + expectedStmts: []string{`CREATE UNIQUE INDEX CONCURRENTLY "_pgroll_dup_unique_name_nick" ON "test_table" ("_pgroll_new_name", "_pgroll_new_nick")`}, + }, + } { + t.Run(name, func(t *testing.T) { + stmts := d.duplicateUniqueConstraints(nil, testCases.columns...) + assert.Equal(t, len(testCases.expectedStmts), len(stmts)) + for _, stmt := range stmts { + assert.True(t, slices.Contains(testCases.expectedStmts, stmt)) + } + }) + } +} diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index b6595683..199b98f6 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -241,7 +241,7 @@ func (o *OpAddColumn) addCheckConstraint(ctx context.Context, tableName string, _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", pq.QuoteIdentifier(tableName), pq.QuoteIdentifier(o.Column.Check.Name), - rewriteCheckExpression(o.Column.Check.Constraint, o.Column.Name, TemporaryName(o.Column.Name)), + rewriteCheckExpression(o.Column.Check.Constraint, o.Column.Name), )) return err } diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index 1b90bab0..08c45030 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -336,9 +336,9 @@ func duplicatorForOperations(ops []Operation, conn db.DB, table *schema.Table, c for _, op := range ops { switch op := (op).(type) { case *OpDropNotNull: - d = d.WithoutNotNull() + d = d.WithoutNotNull(column.Name) case *OpChangeType: - d = d.WithType(op.Type) + d = d.WithType(column.Name, op.Type) } } return d diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go index ff8613a7..148eaf47 100644 --- a/pkg/migrations/op_create_constraint.go +++ b/pkg/migrations/op_create_constraint.go @@ -16,11 +16,52 @@ import ( var _ Operation = (*OpCreateConstraint)(nil) func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { - var err error - var table *schema.Table - for _, col := range o.Columns { - if table, err = o.duplicateColumnBeforeStart(ctx, conn, latestSchema, tr, col, s); err != nil { - return nil, err + table := s.GetTable(o.Table) + columns := make([]*schema.Column, len(o.Columns)) + for i, colName := range o.Columns { + columns[i] = table.GetColumn(colName) + } + + d := NewColumnDuplicator(conn, table, columns...) + if err := d.Duplicate(ctx); err != nil { + return nil, fmt.Errorf("failed to duplicate columns for new constraint: %w", err) + } + + // Setup triggers + for _, colName := range o.Columns { + upSQL := o.Up[colName] + physicalColumnName := TemporaryName(colName) + err := createTrigger(ctx, conn, tr, triggerConfig{ + Name: TriggerName(o.Table, colName), + Direction: TriggerDirectionUp, + Columns: table.Columns, + SchemaName: s.Name, + LatestSchema: latestSchema, + TableName: o.Table, + PhysicalColumn: physicalColumnName, + SQL: upSQL, + }) + if err != nil { + return nil, fmt.Errorf("failed to create up trigger: %w", err) + } + + table.AddColumn(colName, schema.Column{ + Name: physicalColumnName, + }) + + downSQL := o.Down[colName] + err = createTrigger(ctx, conn, tr, triggerConfig{ + Name: TriggerName(o.Table, physicalColumnName), + Direction: TriggerDirectionDown, + Columns: table.Columns, + LatestSchema: latestSchema, + SchemaName: s.Name, + TableName: o.Table, + PhysicalColumn: colName, + SQL: downSQL, + }) + if err != nil { + return nil, fmt.Errorf("failed to create down trigger: %w", err) } } @@ -32,58 +73,6 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema return table, nil } -func (o *OpCreateConstraint) duplicateColumnBeforeStart(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, colName string, s *schema.Schema) (*schema.Table, error) { - table := s.GetTable(o.Table) - column := table.GetColumn(colName) - - d := NewColumnDuplicator(conn, table, column) - if err := d.Duplicate(ctx); err != nil { - return nil, fmt.Errorf("failed to duplicate column for new constraint: %w", err) - } - - upSQL, ok := o.Up[colName] - if !ok { - return nil, fmt.Errorf("up migration is missing for column %s", colName) - } - physicalColumnName := TemporaryName(colName) - err := createTrigger(ctx, conn, tr, triggerConfig{ - Name: TriggerName(o.Table, colName), - Direction: TriggerDirectionUp, - Columns: table.Columns, - SchemaName: s.Name, - LatestSchema: latestSchema, - TableName: o.Table, - PhysicalColumn: physicalColumnName, - SQL: upSQL, - }) - if err != nil { - return nil, fmt.Errorf("failed to create up trigger: %w", err) - } - - table.AddColumn(colName, schema.Column{ - Name: physicalColumnName, - }) - - downSQL, ok := o.Down[colName] - if !ok { - return nil, fmt.Errorf("down migration is missing for column %s", colName) - } - err = createTrigger(ctx, conn, tr, triggerConfig{ - Name: TriggerName(o.Table, physicalColumnName), - Direction: TriggerDirectionDown, - Columns: table.Columns, - LatestSchema: latestSchema, - SchemaName: s.Name, - TableName: o.Table, - PhysicalColumn: colName, - SQL: downSQL, - }) - if err != nil { - return nil, fmt.Errorf("failed to create down trigger: %w", err) - } - return table, nil -} - func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { switch o.Type { //nolint:gocritic // more cases will be added case OpCreateConstraintTypeUnique: diff --git a/pkg/migrations/op_set_check.go b/pkg/migrations/op_set_check.go index e3fbe8d3..37512adf 100644 --- a/pkg/migrations/op_set_check.go +++ b/pkg/migrations/op_set_check.go @@ -86,7 +86,7 @@ func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.D _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Check.Name), - rewriteCheckExpression(o.Check.Constraint, o.Column, TemporaryName(o.Column)), + rewriteCheckExpression(o.Check.Constraint, o.Column), )) return err @@ -97,6 +97,9 @@ func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn db.D // On migration start, however, the check is actually applied to the new (temporary) // column. // This function naively rewrites the check expression to apply to the new column. -func rewriteCheckExpression(check string, oldColumn, newColumn string) string { - return strings.ReplaceAll(check, oldColumn, newColumn) +func rewriteCheckExpression(check string, columns ...string) string { + for _, col := range columns { + check = strings.ReplaceAll(check, col, TemporaryName(col)) + } + return check } diff --git a/pkg/migrations/rename.go b/pkg/migrations/rename.go index 8c9ee7c5..2b27b77e 100644 --- a/pkg/migrations/rename.go +++ b/pkg/migrations/rename.go @@ -151,6 +151,8 @@ func RenameDuplicatedColumn(ctx context.Context, conn db.DB, table *schema.Table if err != nil { return fmt.Errorf("failed to create unique constraint from index %q: %w", ui.Name, err) } + // Index no longer exists, remove it from the table + delete(table.Indexes, ui.Name) } }