diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index 5b463cdef..d228d210d 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -75,7 +75,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { db.RegisterModel((*PublisherToJournalist)(nil)) - dbInspector, err := sqlschema.NewInspector(db) + dbInspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable) if err != nil { t.Skip(err) } diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 28b45553c..518909592 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -3,6 +3,8 @@ package dbtest_test import ( "context" "errors" + "os" + "path/filepath" "strings" "testing" "time" @@ -19,14 +21,28 @@ const ( migrationLocksTable = "test_migration_locks" ) +var migrationsDir = filepath.Join(os.TempDir(), "dbtest") + +// cleanupMigrations adds a cleanup function to reset migration tables. +// The reset does not run for skipped tests to avoid unnecessary work. +// +// Usage: +// +// testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { +// cleanupMigrations(t, ctx, db) +// // some test that may generate migration entries in the db +// }) func cleanupMigrations(tb testing.TB, ctx context.Context, db *bun.DB) { tb.Cleanup(func() { - var err error - _, err = db.NewDropTable().ModelTableExpr(migrationsTable).Exec(ctx) - require.NoError(tb, err, "drop %q table", migrationsTable) + if tb.Skipped() { + return + } - _, err = db.NewDropTable().ModelTableExpr(migrationLocksTable).Exec(ctx) - require.NoError(tb, err, "drop %q table", migrationLocksTable) + m := migrate.NewMigrator(db, migrate.NewMigrations(), + migrate.WithTableName(migrationsTable), + migrate.WithLocksTableName(migrationLocksTable), + ) + require.NoError(tb, m.Reset(ctx)) }) } @@ -163,27 +179,45 @@ func testMigrateUpError(t *testing.T, db *bun.DB) { require.Equal(t, []string{"down2", "down1"}, history) } -// newAutoMigrator creates an AutoMigrator configured to use test migratins/locks tables. -// If the dialect doesn't support schema inspections or migrations, the test will fail with the corresponding error. -func newAutoMigrator(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigratorOption) *migrate.AutoMigrator { +// newAutoMigratorOrSkip creates an AutoMigrator configured to use test migratins/locks +// tables and dedicated migrations directory. If an AutoMigrator cannob be created because +// the dialect doesn't support either schema inspections or migrations, the test will be *skipped* +// with the corresponding error. +// Additionally, it will create the migrations directory and if +// one does not exist and add a function to tear it down on cleanup. +func newAutoMigratorOrSkip(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigratorOption) *migrate.AutoMigrator { tb.Helper() opts = append(opts, migrate.WithTableNameAuto(migrationsTable), migrate.WithLocksTableNameAuto(migrationLocksTable), + migrate.WithMigrationsDirectoryAuto(migrationsDir), ) m, err := migrate.NewAutoMigrator(db, opts...) - require.NoError(tb, err) + if err != nil { + tb.Skip(err) + } + + err = os.MkdirAll(migrationsDir, os.ModePerm) + require.NoError(tb, err, "cannot continue test without migrations directory") + + tb.Cleanup(func() { + if err := os.RemoveAll(migrationsDir); err != nil { + tb.Logf("cleanup: remove migrations dir: %v", err) + } + }) + return m } // inspectDbOrSkip returns a function to inspect the current state of the database. -// It calls tb.Skip() if the current dialect doesn't support database inpection and -// fails the test if the inspector cannot successfully retrieve database state. +// The test will be *skipped* if the current dialect doesn't support database inpection +// and fail if the inspector cannot successfully retrieve database state. func inspectDbOrSkip(tb testing.TB, db *bun.DB) func(context.Context) sqlschema.State { tb.Helper() - inspector, err := sqlschema.NewInspector(db) + // AutoMigrator excludes these tables by default, but here we need to do this explicitly. + inspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable) if err != nil { tb.Skip(err) } @@ -194,7 +228,78 @@ func inspectDbOrSkip(tb testing.TB, db *bun.DB) func(context.Context) sqlschema. } } -func TestAutoMigrator_Run(t *testing.T) { +func TestAutoMigrator_CreateSQLMigrations(t *testing.T) { + type NewTable struct { + bun.BaseModel `bun:"table:new_table"` + Bar string + Baz time.Time + } + + testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { + ctx := context.Background() + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*NewTable)(nil))) + + migrations, err := m.CreateSQLMigrations(ctx) + require.NoError(t, err, "should create migrations successfully") + + require.Len(t, migrations, 2, "expected up/down migration pair") + require.DirExists(t, migrationsDir) + checkMigrationFileContains(t, ".up.sql", "CREATE TABLE") + checkMigrationFileContains(t, ".down.sql", "DROP TABLE") + }) +} + +// checkMigrationFileContains expected SQL snippet. +func checkMigrationFileContains(t *testing.T, fileSuffix string, content string) { + t.Helper() + + files, err := os.ReadDir(migrationsDir) + require.NoErrorf(t, err, "list files in %s", migrationsDir) + + for _, f := range files { + if strings.HasSuffix(f.Name(), fileSuffix) { + b, err := os.ReadFile(filepath.Join(migrationsDir, f.Name())) + require.NoError(t, err) + require.Containsf(t, string(b), content, "expected %s file to contain string", f.Name()) + return + } + } + t.Errorf("no *%s file in migrations directory (%s)", fileSuffix, migrationsDir) +} + +// checkMigrationFilesExist makes sure both up- and down- SQL migration files were created. +func checkMigrationFilesExist(t *testing.T) { + t.Helper() + + files, err := os.ReadDir(migrationsDir) + require.NoErrorf(t, err, "list files in %s", migrationsDir) + + var up, down bool + for _, f := range files { + if !up && strings.HasSuffix(f.Name(), ".up.sql") { + up = true + } else if !down && strings.HasSuffix(f.Name(), ".down.sql") { + down = true + } + } + + if !up { + t.Errorf("no .up.sql file created in migrations directory (%s)", migrationsDir) + } + if !down { + t.Errorf("no .down.sql file created in migrations directory (%s)", migrationsDir) + } +} + +func runMigrations(t *testing.T, m *migrate.AutoMigrator) { + t.Helper() + + _, err := m.Migrate(ctx) + require.NoError(t, err, "auto migration failed") + checkMigrationFilesExist(t) +} + +func TestAutoMigrator_Migrate(t *testing.T) { tests := []struct { fn func(t *testing.T, db *bun.DB) @@ -219,6 +324,11 @@ func TestAutoMigrator_Run(t *testing.T) { testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { for _, tt := range tests { t.Run(funcName(tt.fn), func(t *testing.T) { + // Because they are executed so fast, tests may generate migrations + // with the same timestamp, so that only the first of them will apply. + // To eliminate these side-effects we cleanup migration tables after + // after every test case. + cleanupMigrations(t, ctx, db) tt.fn(t, db) }) } @@ -241,16 +351,14 @@ func testRenameTable(t *testing.T, db *bun.DB) { inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*initial)(nil)) mustDropTableOnCleanup(t, ctx, db, (*changed)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel((*changed)(nil))) + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*changed)(nil))) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) tables := state.Tables - require.Len(t, tables, 1) require.Equal(t, "changed", tables[0].Name) } @@ -272,16 +380,14 @@ func testCreateDropTable(t *testing.T, db *bun.DB) { inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*DropMe)(nil)) mustDropTableOnCleanup(t, ctx, db, (*CreateMe)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel((*CreateMe)(nil))) + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*CreateMe)(nil))) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) tables := state.Tables - require.Len(t, tables, 1) require.Equal(t, "createme", tables[0].Name) } @@ -332,15 +438,14 @@ func testAlterForeignKeys(t *testing.T, db *bun.DB) { ) mustDropTableOnCleanup(t, ctx, db, (*ThingsToOwner)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel( + m := newAutoMigratorOrSkip(t, db, migrate.WithModel( (*ThingCommon)(nil), (*OwnerCommon)(nil), (*ThingsToOwner)(nil), )) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) @@ -399,7 +504,7 @@ func testForceRenameFK(t *testing.T, db *bun.DB) { ) mustDropTableOnCleanup(t, ctx, db, (*Person)(nil)) - m := newAutoMigrator(t, db, + m := newAutoMigratorOrSkip(t, db, migrate.WithModel( (*Person)(nil), (*PersonalThing)(nil), @@ -413,13 +518,11 @@ func testForceRenameFK(t *testing.T, db *bun.DB) { ) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) schema := db.Dialect().DefaultSchema() - wantName, ok := state.FKs[sqlschema.FK{ From: sqlschema.C(schema, "things", "owner_id"), To: sqlschema.C(schema, "people", "id"), @@ -459,7 +562,7 @@ func testCustomFKNameFunc(t *testing.T, db *bun.DB) { (*Column)(nil), ) - m := newAutoMigrator(t, db, + m := newAutoMigratorOrSkip(t, db, migrate.WithFKNameFunc(func(sqlschema.FK) string { return "test_fkey" }), migrate.WithModel( (*TableM)(nil), @@ -468,8 +571,7 @@ func testCustomFKNameFunc(t *testing.T, db *bun.DB) { ) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) @@ -514,18 +616,16 @@ func testRenamedColumns(t *testing.T, db *bun.DB) { (*Model1)(nil), ) mustDropTableOnCleanup(t, ctx, db, (*Renamed)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel( + m := newAutoMigratorOrSkip(t, db, migrate.WithModel( (*Model2)(nil), (*Renamed)(nil), )) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) - require.Len(t, state.Tables, 2) var renamed, model2 sqlschema.Table @@ -565,18 +665,16 @@ func testRenameColumnRenamesFK(t *testing.T, db *bun.DB) { ctx := context.Background() inspect := inspectDbOrSkip(t, db) mustCreateTableWithFKs(t, ctx, db, (*TennantBefore)(nil)) - m := newAutoMigrator(t, db, + m := newAutoMigratorOrSkip(t, db, migrate.WithRenameFK(true), migrate.WithModel((*TennantAfter)(nil)), ) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) - fkName := state.FKs[sqlschema.FK{ From: sqlschema.C(db.Dialect().DefaultSchema(), "tennants", "my_neighbour"), To: sqlschema.C(db.Dialect().DefaultSchema(), "tennants", "tennant_id"), @@ -655,11 +753,10 @@ func testChangeColumnType_AutoCast(t *testing.T, db *bun.DB) { ctx := context.Background() inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*TableBefore)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*TableAfter)(nil))) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) @@ -699,11 +796,10 @@ func testIdentity(t *testing.T, db *bun.DB) { ctx := context.Background() inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*TableBefore)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*TableAfter)(nil))) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) @@ -743,11 +839,10 @@ func testAddDropColumn(t *testing.T, db *bun.DB) { ctx := context.Background() inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*TableBefore)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*TableAfter)(nil))) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) @@ -823,11 +918,10 @@ func testUnique(t *testing.T, db *bun.DB) { ctx := context.Background() inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*TableBefore)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*TableAfter)(nil))) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) @@ -894,11 +988,10 @@ func testUniqueRenamedTable(t *testing.T, db *bun.DB) { inspect := inspectDbOrSkip(t, db) mustResetModel(t, ctx, db, (*TableBefore)(nil)) mustDropTableOnCleanup(t, ctx, db, (*TableAfter)(nil)) - m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil))) + m := newAutoMigratorOrSkip(t, db, migrate.WithModel((*TableAfter)(nil))) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) @@ -1011,17 +1104,16 @@ func testUpdatePrimaryKeys(t *testing.T, db *bun.DB) { (*AddNewPKBefore)(nil), (*ChangePKBefore)(nil), ) - m := newAutoMigrator(t, db, migrate.WithModel( + m := newAutoMigratorOrSkip(t, db, migrate.WithModel( (*DropPKAfter)(nil), (*AddNewPKAfter)(nil), (*ChangePKAfter)(nil)), ) // Act - err := m.Run(ctx) - require.NoError(t, err) + runMigrations(t, m) // Assert state := inspect(ctx) cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables) -} \ No newline at end of file +} diff --git a/migrate/auto.go b/migrate/auto.go index b1cacf691..70236e8e5 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -1,8 +1,11 @@ package migrate import ( + "bytes" "context" "fmt" + "os" + "path/filepath" "github.com/uptrace/bun" "github.com/uptrace/bun/migrate/sqlschema" @@ -72,6 +75,12 @@ func WithMarkAppliedOnSuccessAuto(enabled bool) AutoMigratorOption { } } +func WithMigrationsDirectoryAuto(directory string) AutoMigratorOption { + return func(m *AutoMigrator) { + m.migrationsOpts = append(m.migrationsOpts, WithMigrationsDirectory(directory)) + } +} + type AutoMigrator struct { db *bun.DB @@ -98,6 +107,9 @@ type AutoMigrator struct { // migratorOpts are passed to Migrator constructor. migratorOpts []MigratorOption + + // migrationsOpts are passed to Migrations constructor. + migrationsOpts []MigrationsOption } func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, error) { @@ -156,14 +168,37 @@ func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) { // Migrate writes required changes to a new migration file and runs the migration. // This will create and entry in the migrations table, making it possible to revert // the changes with Migrator.Rollback(). -func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) error { +func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { + migrations, _, err := am.createSQLMigrations(ctx) + if err != nil { + return nil, fmt.Errorf("auto migrate: %w", err) + } + + migrator := NewMigrator(am.db, migrations, am.migratorOpts...) + if err := migrator.Init(ctx); err != nil { + return nil, fmt.Errorf("auto migrate: %w", err) + } + + group, err := migrator.Migrate(ctx, opts...) + if err != nil { + return nil, fmt.Errorf("auto migrate: %w", err) + } + return group, nil +} + +func (am *AutoMigrator) CreateSQLMigrations(ctx context.Context) ([]*MigrationFile, error) { + _, files, err := am.createSQLMigrations(ctx) + return files, err +} + +func (am *AutoMigrator) createSQLMigrations(ctx context.Context) (*Migrations, []*MigrationFile, error) { changes, err := am.plan(ctx) if err != nil { - return fmt.Errorf("auto migrate: %w", err) + return nil, nil, fmt.Errorf("create sql migrations: %w", err) } - migrations := NewMigrations() name, _ := genMigrationName("auto") + migrations := NewMigrations(am.migrationsOpts...) migrations.Add(Migration{ Name: name, Up: changes.Up(am.dbMigrator), @@ -171,26 +206,34 @@ func (am *AutoMigrator) Migrate(ctx context.Context, opts ...MigrationOption) er Comment: "Changes detected by bun.migrate.AutoMigrator", }) - migrator := NewMigrator(am.db, migrations, am.migratorOpts...) - if err := migrator.Init(ctx); err != nil { - return fmt.Errorf("auto migrate: %w", err) + up, err := am.createSQL(ctx, migrations, name+".up.sql", changes) + if err != nil { + return nil, nil, fmt.Errorf("create sql migration up: %w", err) } - if _, err := migrator.Migrate(ctx, opts...); err != nil { - return fmt.Errorf("auto migrate: %w", err) + down, err := am.createSQL(ctx, migrations, name+".down.sql", changes.GetReverse()) + if err != nil { + return nil, nil, fmt.Errorf("create sql migration down: %w", err) } - return nil + return migrations, []*MigrationFile{up, down}, nil } -// Run runs required migrations in-place and without creating a database entry. -func (am *AutoMigrator) Run(ctx context.Context) error { - changes, err := am.plan(ctx) - if err != nil { - return fmt.Errorf("auto migrate: %w", err) +func (am *AutoMigrator) createSQL(_ context.Context, migrations *Migrations, fname string, changes *changeset) (*MigrationFile, error) { + var buf bytes.Buffer + if err := changes.WriteTo(&buf, am.dbMigrator); err != nil { + return nil, err } - up := changes.Up(am.dbMigrator) - if err := up(ctx, am.db); err != nil { - return fmt.Errorf("auto migrate: %w", err) + content := buf.Bytes() + + fpath := filepath.Join(migrations.getDirectory(), fname) + if err := os.WriteFile(fpath, content, 0o644); err != nil { + return nil, err + } + + mf := &MigrationFile{ + Name: fname, + Path: fpath, + Content: string(content), } - return nil + return mf, nil } diff --git a/migrate/diff.go b/migrate/diff.go index fa1743671..e1ea59832 100644 --- a/migrate/diff.go +++ b/migrate/diff.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "strings" "github.com/uptrace/bun" @@ -149,6 +150,15 @@ func (c *changeset) Func(m sqlschema.Migrator) MigrationFunc { } } +// GetReverse returns a new changeset with each operation in it "reversed" and in reverse order. +func (c *changeset) GetReverse() *changeset { + var reverse changeset + for i := len(c.operations) - 1; i >= 0; i-- { + reverse.Add(c.operations[i].GetReverse()) + } + return &reverse +} + // Up is syntactic sugar. func (c *changeset) Up(m sqlschema.Migrator) MigrationFunc { return c.Func(m) @@ -156,11 +166,7 @@ func (c *changeset) Up(m sqlschema.Migrator) MigrationFunc { // Down is syntactic sugar. func (c *changeset) Down(m sqlschema.Migrator) MigrationFunc { - var reverse changeset - for i := len(c.operations) - 1; i >= 0; i-- { - reverse.Add(c.operations[i].GetReverse()) - } - return reverse.Func(m) + return c.GetReverse().Func(m) } // apply generates SQL for each operation and executes it. @@ -184,6 +190,29 @@ func (c *changeset) apply(ctx context.Context, db *bun.DB, m sqlschema.Migrator) return nil } +func (c *changeset) WriteTo(w io.Writer, m sqlschema.Migrator) error { + var err error + + b := internal.MakeQueryBytes() + for _, op := range c.operations { + if _, isNoop := op.(*noop); isNoop { + // TODO: write migration-specific commend instead + b = append(b, "-- Down-migrations are not supported for some changes.\n"...) + continue + } + + b, err = m.AppendSQL(b, op) + if err != nil { + return fmt.Errorf("write changeset: %w", err) + } + b = append(b, ";\n"...) + } + if _, err := w.Write(b); err != nil { + return fmt.Errorf("write changeset: %w", err) + } + return nil +} + func (c *changeset) ResolveDependencies() error { if len(c.operations) <= 1 { return nil diff --git a/migrate/migrator.go b/migrate/migrator.go index 9f1b5222c..d5a72aec0 100644 --- a/migrate/migrator.go +++ b/migrate/migrator.go @@ -314,7 +314,7 @@ func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*Mig return []*MigrationFile{up, down}, nil } -func (m *Migrator) createSQL(ctx context.Context, fname string, transactional bool) (*MigrationFile, error) { +func (m *Migrator) createSQL(_ context.Context, fname string, transactional bool) (*MigrationFile, error) { fpath := filepath.Join(m.migrations.getDirectory(), fname) template := sqlTemplate