diff --git a/internal/testutils/error_codes.go b/internal/testutils/error_codes.go index d5391e9f..cd6b54ea 100644 --- a/internal/testutils/error_codes.go +++ b/internal/testutils/error_codes.go @@ -8,4 +8,5 @@ const ( NotNullViolationErrorCode string = "not_null_violation" UniqueViolationErrorCode string = "unique_violation" UndefinedColumnErrorCode string = "undefined_column" + UndefinedTableErrorCode string = "undefined_table" ) diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index da8d397b..8cb92f9c 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -46,10 +46,10 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string err := createTrigger(ctx, conn, tr, triggerConfig{ Name: TriggerName(o.Table, o.Column.Name), Direction: TriggerDirectionUp, - Columns: s.GetTable(o.Table).Columns, + Columns: table.Columns, SchemaName: s.Name, LatestSchema: latestSchema, - TableName: o.Table, + TableName: table.Name, PhysicalColumn: TemporaryName(o.Column.Name), SQL: o.Up, }) @@ -120,10 +120,11 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransforme } func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { + table := s.GetTable(o.Table) tempName := TemporaryName(o.Column.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s", - pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(table.Name), pq.QuoteIdentifier(tempName))) if err != nil { return err diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index ffcfcc2a..a0edb6e8 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -1658,18 +1658,18 @@ func TestAddColumnDefaultTransformation(t *testing.T) { }, roll.WithSQLTransformer(sqlTransformer)) } -func TestAddColumnToATableCreatedInTheSameMigration(t *testing.T) { +func TestAddColumnInMultiOperationMigrations(t *testing.T) { t.Parallel() ExecuteTests(t, TestCases{ { - name: "add column to newly created table", + name: "create table, add column", migrations: []migrations.Migration{ { - Name: "01_add_table", + Name: "01_multi_operation", Operations: migrations.Operations{ &migrations.OpCreateTable{ - Name: "users", + Name: "items", Columns: []migrations.Column{ { Name: "id", @@ -1683,46 +1683,193 @@ func TestAddColumnToATableCreatedInTheSameMigration(t *testing.T) { }, }, &migrations.OpAddColumn{ - Table: "users", + Table: "items", Column: migrations.Column{ - Name: "age", - Type: "integer", - Nullable: false, - Default: ptr("18"), - Check: &migrations.CheckConstraint{ - Name: "age_check", - Constraint: "age >= 18", + Name: "description", + Type: "text", + }, + Up: "UPPER(name)", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // Can insert into the view in the new schema (the only version schema) + MustInsert(t, db, schema, "01_multi_operation", "items", map[string]string{ + "name": "apples", + "description": "green", + }) + + // The table has the expected rows + rows := MustSelect(t, db, schema, "01_multi_operation", "items") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "apples", "description": "green"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // The table no longer exists + TableMustNotExist(t, db, schema, "items") + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // Can insert into the view in the new schema (the only version schema) + MustInsert(t, db, schema, "01_multi_operation", "items", map[string]string{ + "name": "bananas", + "description": "yellow", + }) + + // The table has the expected rows + rows := MustSelect(t, db, schema, "01_multi_operation", "items") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "bananas", "description": "yellow"}, + }, rows) + }, + }, + { + name: "rename table, add column", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "items", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "varchar(255)", }, - Comment: ptr("the age of the user"), }, }, }, }, + { + Name: "02_multi_operation", + Operations: migrations.Operations{ + &migrations.OpRenameTable{ + From: "items", + To: "products", + }, + &migrations.OpAddColumn{ + Table: "products", + Column: migrations.Column{ + Name: "description", + Type: "text", + }, + Up: "UPPER(name)", + }, + }, + }, }, afterStart: func(t *testing.T, db *sql.DB, schema string) { - // Inserting into the new column on the new table works. - MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ - "name": "Alice", "age": "30", + // Can insert into the new table in the new schema using its new name + MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{ + "name": "apples", + "description": "green", }) - // Inserting a value that doesn't meet the check constraint fails. - MustNotInsert(t, db, schema, "01_add_table", "users", map[string]string{ - "name": "Bob", "age": "8", - }, testutils.CheckViolationErrorCode) + // Can't insert into the new table in the new schema using its old name + MustNotInsert(t, db, schema, "02_multi_operation", "items", map[string]string{ + "name": "bananas", + "description": "yellow", + }, testutils.UndefinedTableErrorCode) + + // The table has the expected rows in the old schema + rows := MustSelect(t, db, schema, "01_create_table", "items") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "apples"}, + }, rows) + + // The table has the expected rows in the new schema + rows = MustSelect(t, db, schema, "02_multi_operation", "products") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "apples", "description": "green"}, + }, rows) }, afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // Can insert into the old table in the old schema using its old name + MustInsert(t, db, schema, "01_create_table", "items", map[string]string{ + "name": "bananas", + }) + + // The temporary column, functions and triggers have been cleaned up + TableMustBeCleanedUp(t, db, schema, "items", "description") }, afterComplete: func(t *testing.T, db *sql.DB, schema string) { - // Inserting into the new column on the new table works. - MustInsert(t, db, schema, "01_add_table", "users", map[string]string{ - "name": "Bob", "age": "31", + // Can insert into the new table in the new schema using its new name + MustInsert(t, db, schema, "02_multi_operation", "products", map[string]string{ + "name": "carrots", + "description": "crunchy", }) - // Inserting a value that doesn't meet the check constraint fails. - MustNotInsert(t, db, schema, "01_add_table", "users", map[string]string{ - "name": "Carl", "age": "8", - }, testutils.CheckViolationErrorCode) + // The table has the new name in the new schema and has the expected + // rows + rows := MustSelect(t, db, schema, "02_multi_operation", "products") + assert.Equal(t, []map[string]any{ + {"id": 1, "name": "apples", "description": "APPLES"}, + {"id": 2, "name": "bananas", "description": "BANANAS"}, + {"id": 3, "name": "carrots", "description": "crunchy"}, + }, rows) + + // The temporary column, functions and triggers have been cleaned up + TableMustBeCleanedUp(t, db, schema, "products", "description") + }, + }, + }) +} + +func TestAddColumnValidationInMultiOperationMigrations(t *testing.T) { + t.Parallel() + + ExecuteTests(t, TestCases{ + { + name: "adding a column with the same name twice fails to validate", + migrations: []migrations.Migration{ + { + Name: "01_create_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "items", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "varchar(255)", + }, + }, + }, + }, + }, + { + Name: "02_multi_operation", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "items", + Column: migrations.Column{ + Name: "description", + Type: "text", + }, + Up: "UPPER(name)", + }, + &migrations.OpAddColumn{ + Table: "items", + Column: migrations.Column{ + Name: "description", + Type: "varchar(255)", + }, + Up: "UPPER(name)", + }, + }, + }, }, + wantStartErr: migrations.ColumnAlreadyExistsError{Table: "items", Name: "description"}, }, }) } diff --git a/pkg/migrations/op_create_table.go b/pkg/migrations/op_create_table.go index a2b85b4a..f06e286a 100644 --- a/pkg/migrations/op_create_table.go +++ b/pkg/migrations/op_create_table.go @@ -137,7 +137,9 @@ func (o *OpCreateTable) updateSchema(s *schema.Schema) *schema.Schema { columns := make(map[string]*schema.Column, len(o.Columns)) for _, col := range o.Columns { columns[col.Name] = &schema.Column{ - Name: col.Name, + Name: col.Name, + Unique: col.Unique, + Nullable: col.Nullable, } } var uniqueConstraints map[string]*schema.UniqueConstraint @@ -153,10 +155,20 @@ func (o *OpCreateTable) updateSchema(s *schema.Schema) *schema.Schema { } } } + + // Build the table's primary key from the columns that have the `Pk` flag set + var primaryKey []string + for _, col := range o.Columns { + if col.Pk { + primaryKey = append(primaryKey, col.Name) + } + } + s.AddTable(o.Name, &schema.Table{ Name: o.Name, Columns: columns, UniqueConstraints: uniqueConstraints, + PrimaryKey: primaryKey, }) return s