diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index b4383d3e..43022662 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -24,19 +24,19 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string } if o.Column.Comment != nil { - if err := addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column.Name), o.Column.Comment); err != nil { + if err := addCommentToColumn(ctx, conn, table.Name, TemporaryName(o.Column.Name), o.Column.Comment); err != nil { return nil, fmt.Errorf("failed to add comment to column: %w", err) } } if !o.Column.IsNullable() && o.Column.Default == nil { - if err := addNotNullConstraint(ctx, conn, o.Table, o.Column.Name, TemporaryName(o.Column.Name)); err != nil { + if err := addNotNullConstraint(ctx, conn, table.Name, o.Column.Name, TemporaryName(o.Column.Name)); err != nil { return nil, fmt.Errorf("failed to add not null constraint: %w", err) } } if o.Column.Check != nil { - if err := o.addCheckConstraint(ctx, conn); err != nil { + if err := o.addCheckConstraint(ctx, table.Name, conn); err != nil { return nil, fmt.Errorf("failed to add check constraint: %w", err) } } @@ -231,9 +231,9 @@ func addNotNullConstraint(ctx context.Context, conn db.DB, table, column, physic return err } -func (o *OpAddColumn) addCheckConstraint(ctx context.Context, conn db.DB) error { +func (o *OpAddColumn) addCheckConstraint(ctx context.Context, tableName string, conn db.DB) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID", - pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(tableName), pq.QuoteIdentifier(o.Column.Check.Name), rewriteCheckExpression(o.Column.Check.Constraint, o.Column.Name, TemporaryName(o.Column.Name)), )) diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index 51d29506..c8fb2796 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -1494,6 +1494,74 @@ func TestAddColumnDefaultTransformation(t *testing.T) { }, roll.WithSQLTransformer(sqlTransformer)) } +func TestAddColumnToATableCreatedInTheSameMigration(t *testing.T) { + t.Parallel() + + ExecuteTests(t, TestCases{ + { + name: "add column to newly created table", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + }, + }, + }, + &migrations.OpAddColumn{ + Table: "users", + Column: migrations.Column{ + Name: "age", + Type: "integer", + Nullable: ptr(false), + Check: &migrations.CheckConstraint{ + Name: "age_check", + Constraint: "age >= 18", + }, + Comment: ptr("the age of the user"), + }, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // Inserting into the new column on the new table works. + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "Alice", "age": "30", + }) + + // Inserting a value that doesn't meet the check constraint fails. + MustNotInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "Bob", "age": "8", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Inserting into the new column on the new table works. + MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "Bob", "age": "31", + }) + + // Inserting a value that doesn't meet the check constraint fails. + MustNotInsert(t, db, schema, "01_add_table", "users", map[string]string{ + "name": "Carl", "age": "8", + }, testutils.CheckViolationErrorCode) + }, + }, + }, roll.WithSkipValidation(true)) // TODO: remove once this migration can be validated +} + func TestAddColumnInvalidNameLength(t *testing.T) { t.Parallel()