Skip to content

Commit

Permalink
Merge pull request #209 from go-gormigrate/refactor-tests
Browse files Browse the repository at this point in the history
Refactor integration-test: avoid hardcoding driver features in the test, keep it in dialect container instead
  • Loading branch information
avakarev authored Jun 3, 2023
2 parents 21a38fc + caeca82 commit 905c727
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 70 deletions.
2 changes: 1 addition & 1 deletion integration-test/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.20

require (
github.com/glebarez/sqlite v1.8.0
github.com/go-gormigrate/gormigrate/v2 v2.0.0-00010101000000-000000000000
github.com/go-gormigrate/gormigrate/v2 v2.1.0
github.com/joho/godotenv v1.5.1
github.com/stretchr/testify v1.8.4
gorm.io/driver/mysql v1.5.1
Expand Down
105 changes: 51 additions & 54 deletions integration-test/gormigrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,40 @@ import (
"github.com/go-gormigrate/gormigrate/v2"
)

var databases []database
var dialects dialectList

type database struct {
dialect string
driver gorm.Dialector
type dialect struct {
name string
driver gorm.Dialector
// Not all databases support transactional DDL statements
supportsAtomicDDL bool
}

type dialectList []dialect

func (dl dialectList) WithTransactionSupport() dialectList {
filtered := dialectList{}
for _, d := range dl {
if d.supportsAtomicDDL {
filtered = append(filtered, d)
}
}
return filtered
}

func (dl dialectList) forEachDB(t *testing.T, fn func(gormdb *gorm.DB)) {
for _, dia := range dl {
// Ensure defers are not stacked up for each DB
func(dia dialect) {
db, err := gorm.Open(dia.driver, &gorm.Config{})
require.NoError(t, err, "Could not connect to database %s, %v", dia.name, err)

// ensure database is clean before running test
assert.NoError(t, db.Migrator().DropTable("migrations", "people", "pets"))

fn(db)
}(dia)
}
}

var migrations = []*gormigrate.Migration{
Expand Down Expand Up @@ -83,7 +112,7 @@ type Book struct {
}

func TestMigration(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, migrations)

err := m.Migrate()
Expand All @@ -107,7 +136,7 @@ func TestMigration(t *testing.T) {
}

func TestMigrateTo(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, extendedMigrations)

err := m.MigrateTo("201608301430")
Expand All @@ -120,7 +149,7 @@ func TestMigrateTo(t *testing.T) {
}

func TestRollbackTo(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, extendedMigrations)

// First, apply all migrations.
Expand All @@ -144,7 +173,7 @@ func TestRollbackTo(t *testing.T) {
// If initSchema is defined, but no migrations are provided,
// then initSchema is executed.
func TestInitSchemaNoMigrations(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, []*gormigrate.Migration{})
m.InitSchema(func(tx *gorm.DB) error {
if err := tx.AutoMigrate(&Person{}); err != nil {
Expand All @@ -167,7 +196,7 @@ func TestInitSchemaNoMigrations(t *testing.T) {
// then initSchema is executed and the migration IDs are stored,
// even though the relevant migrations are not applied.
func TestInitSchemaWithMigrations(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, migrations)
m.InitSchema(func(tx *gorm.DB) error {
if err := tx.AutoMigrate(&Person{}); err != nil {
Expand All @@ -190,7 +219,7 @@ func TestInitSchemaAlreadyInitialised(t *testing.T) {
gorm.Model
}

forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, []*gormigrate.Migration{})

// Migrate with empty initialisation
Expand Down Expand Up @@ -222,7 +251,7 @@ func TestInitSchemaExistingMigrations(t *testing.T) {
gorm.Model
}

forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, migrations)

// Migrate without initialisation
Expand All @@ -244,7 +273,7 @@ func TestInitSchemaExistingMigrations(t *testing.T) {
}

func TestMigrationIDDoesNotExist(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, gormigrate.DefaultOptions, migrations)
assert.Equal(t, gormigrate.ErrMigrationIDDoesNotExist, m.MigrateTo("1234"))
assert.Equal(t, gormigrate.ErrMigrationIDDoesNotExist, m.RollbackTo("1234"))
Expand All @@ -254,7 +283,7 @@ func TestMigrationIDDoesNotExist(t *testing.T) {
}

func TestMissingID(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
migrationsMissingID := []*gormigrate.Migration{
{
Migrate: func(tx *gorm.DB) error {
Expand All @@ -269,7 +298,7 @@ func TestMissingID(t *testing.T) {
}

func TestReservedID(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
migrationsReservedID := []*gormigrate.Migration{
{
ID: "SCHEMA_INIT",
Expand All @@ -286,7 +315,7 @@ func TestReservedID(t *testing.T) {
}

func TestDuplicatedID(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
migrationsDuplicatedID := []*gormigrate.Migration{
{
ID: "201705061500",
Expand All @@ -309,7 +338,7 @@ func TestDuplicatedID(t *testing.T) {
}

func TestEmptyMigrationList(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
t.Run("with empty list", func(t *testing.T) {
m := gormigrate.New(db, gormigrate.DefaultOptions, []*gormigrate.Migration{})
err := m.Migrate()
Expand All @@ -328,7 +357,7 @@ func TestMigration_WithUseTransactions(t *testing.T) {
options := gormigrate.DefaultOptions
options.UseTransaction = true

forEachDatabase(t, func(db *gorm.DB) {
dialects.WithTransactionSupport().forEachDB(t, func(db *gorm.DB) {
m := gormigrate.New(db, options, migrations)

err := m.Migrate()
Expand All @@ -348,26 +377,26 @@ func TestMigration_WithUseTransactions(t *testing.T) {
assert.False(t, db.Migrator().HasTable(&Person{}))
assert.False(t, db.Migrator().HasTable(&Pet{}))
assert.Equal(t, int64(0), tableCount(t, db, "migrations"))
}, "postgres", "sqlite", "sqlitego", "sqlserver")
})
}

func TestMigration_WithUseTransactionsShouldRollback(t *testing.T) {
options := gormigrate.DefaultOptions
options.UseTransaction = true

forEachDatabase(t, func(db *gorm.DB) {
dialects.WithTransactionSupport().forEachDB(t, func(db *gorm.DB) {
assert.True(t, true)
m := gormigrate.New(db, options, failingMigration)

// Migration should return an error and not leave around a Book table
err := m.Migrate()
assert.Error(t, err)
assert.False(t, db.Migrator().HasTable(&Book{}))
}, "postgres", "sqlite", "sqlitego", "sqlserver")
})
}

func TestUnexpectedMigrationEnabled(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
options := gormigrate.DefaultOptions
options.ValidateUnknownMigrations = true
m := gormigrate.New(db, options, migrations)
Expand All @@ -383,7 +412,7 @@ func TestUnexpectedMigrationEnabled(t *testing.T) {
}

func TestUnexpectedMigrationDisabled(t *testing.T) {
forEachDatabase(t, func(db *gorm.DB) {
dialects.forEachDB(t, func(db *gorm.DB) {
options := gormigrate.DefaultOptions
options.ValidateUnknownMigrations = false
m := gormigrate.New(db, options, migrations)
Expand All @@ -402,35 +431,3 @@ func tableCount(t *testing.T, db *gorm.DB, tableName string) (count int64) {
assert.NoError(t, db.Table(tableName).Count(&count).Error)
return
}

func forEachDatabase(t *testing.T, fn func(database *gorm.DB), dialects ...string) {
if len(databases) == 0 {
panic("No database chosen for testing!")
}

for _, database := range databases {
if len(dialects) > 0 && !contains(dialects, database.dialect) {
t.Skipf("test is not supported by [%s] dialect", database.dialect)
}

// Ensure defers are not stacked up for each DB
func() {
db, err := gorm.Open(database.driver, &gorm.Config{})
require.NoError(t, err, "Could not connect to database %s, %v", database.dialect, err)

// ensure tables do not exists
assert.NoError(t, db.Migrator().DropTable("migrations", "people", "pets"))

fn(db)
}()
}
}

func contains(haystack []string, needle string) bool {
for _, straw := range haystack {
if straw == needle {
return true
}
}
return false
}
10 changes: 7 additions & 3 deletions integration-test/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ import (
)

func init() {
databases = append(databases, database{
dialect: "mysql",
driver: mysql.Open(os.Getenv("MYSQL_DSN")),
dialects = append(dialects, dialect{
name: "mysql",
driver: mysql.Open(os.Getenv("MYSQL_DSN")),
// mysql/mariadb causes implicit commits in transactional DDL statements, see for details:
// https://mariadb.com/kb/en/sql-statements-that-cause-an-implicit-commit
// https://dev.mysql.com/doc/refman/8.0/en/atomic-ddl.html
supportsAtomicDDL: false,
})
}
7 changes: 4 additions & 3 deletions integration-test/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
)

func init() {
databases = append(databases, database{
dialect: "postgres",
driver: postgres.Open(os.Getenv("POSTGRES_DSN")),
dialects = append(dialects, dialect{
name: "postgres",
driver: postgres.Open(os.Getenv("POSTGRES_DSN")),
supportsAtomicDDL: true,
})
}
7 changes: 4 additions & 3 deletions integration-test/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
)

func init() {
databases = append(databases, database{
dialect: "sqlite",
driver: sqlite.Open(os.Getenv("SQLITE_DSN")),
dialects = append(dialects, dialect{
name: "sqlite",
driver: sqlite.Open(os.Getenv("SQLITE_DSN")),
supportsAtomicDDL: true,
})
}
7 changes: 4 additions & 3 deletions integration-test/sqlitego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
)

func init() {
databases = append(databases, database{
dialect: "sqlitego",
driver: sqlite.Open(os.Getenv("SQLITE_DSN")),
dialects = append(dialects, dialect{
name: "sqlitego",
driver: sqlite.Open(os.Getenv("SQLITE_DSN")),
supportsAtomicDDL: true,
})
}
7 changes: 4 additions & 3 deletions integration-test/sqlserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (
)

func init() {
databases = append(databases, database{
dialect: "sqlserver",
driver: sqlserver.Open(os.Getenv("SQLSERVER_DSN")),
dialects = append(dialects, dialect{
name: "sqlserver",
driver: sqlserver.Open(os.Getenv("SQLSERVER_DSN")),
supportsAtomicDDL: true,
})
}

0 comments on commit 905c727

Please sign in to comment.