diff --git a/clickhouse.go b/clickhouse.go index d9dbc61..14c7d29 100644 --- a/clickhouse.go +++ b/clickhouse.go @@ -82,12 +82,6 @@ func (h *clickhouse) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction return tx.Commit() } -// splitter is a batchSplitter interface implementation. We need it for -// ClickHouseDB because clickhouse doesn't support multi-statements. -func (*clickhouse) splitter() []byte { - return []byte(";\n") -} - func (h *clickhouse) cleanTableQuery(tableName string) string { if h.cleanTableFn == nil { return h.baseHelper.cleanTableQuery(tableName) diff --git a/clickhouse_test.go b/clickhouse_test.go index c23de30..997af7f 100644 --- a/clickhouse_test.go +++ b/clickhouse_test.go @@ -1,5 +1,4 @@ //go:build clickhouse -// +build clickhouse package testfixtures @@ -11,10 +10,7 @@ import ( ) func TestClickhouse(t *testing.T) { - testLoader( - t, - "clickhouse", - os.Getenv("CLICKHOUSE_CONN_STRING"), - "testdata/schema/clickhouse.sql", - ) + db := openDB(t, "clickhouse", os.Getenv("CLICKHOUSE_CONN_STRING")) + loadSchemaInBatchesBySplitter(t, db, "testdata/schema/clickhouse.sql", []byte(";\n")) + testLoader(t, db, "clickhouse") } diff --git a/cockroachdb_test.go b/cockroachdb_test.go index 50882a0..6e7543e 100644 --- a/cockroachdb_test.go +++ b/cockroachdb_test.go @@ -1,5 +1,4 @@ //go:build cockroachdb -// +build cockroachdb package testfixtures @@ -13,11 +12,12 @@ import ( func TestCockroachDB(t *testing.T) { for _, dialect := range []string{"postgres", "pgx"} { + db := openDB(t, dialect, os.Getenv("CRDB_CONN_STRING")) + loadSchemaInOneQuery(t, db, "testdata/schema/cockroachdb.sql") testLoader( t, + db, dialect, - os.Getenv("CRDB_CONN_STRING"), - "testdata/schema/cockroachdb.sql", DangerousSkipTestDatabaseCheck(), UseDropConstraint(), ) diff --git a/helper.go b/helper.go index 1e00b85..1c19e89 100644 --- a/helper.go +++ b/helper.go @@ -34,16 +34,6 @@ type queryable interface { QueryRow(string, ...interface{}) *sql.Row } -// batchSplitter is an interface with method which returns byte slice for -// splitting SQL batches. This need to split sql statements and run its -// separately. -// -// For Microsoft SQL Server batch splitter is "GO". For details see -// https://docs.microsoft.com/en-us/sql/t-sql/language-elements/sql-server-utilities-statements-go -type batchSplitter interface { //nolint - splitter() []byte -} - var ( _ helper = &clickhouse{} _ helper = &spanner{} diff --git a/mysql_test.go b/mysql_test.go index a48b4d4..a6b51a2 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -1,4 +1,4 @@ -// +build mysql +//go:build mysql package testfixtures @@ -10,10 +10,7 @@ import ( ) func TestMySQL(t *testing.T) { - testLoader( - t, - "mysql", - os.Getenv("MYSQL_CONN_STRING"), - "testdata/schema/mysql.sql", - ) + db := openDB(t, "mysql", os.Getenv("MYSQL_CONN_STRING")) + loadSchemaInOneQuery(t, db, "testdata/schema/mysql.sql") + testLoader(t, db, "mysql") } diff --git a/postgresql_test.go b/postgresql_test.go index 9c862d2..f980d19 100644 --- a/postgresql_test.go +++ b/postgresql_test.go @@ -1,4 +1,4 @@ -// +build postgresql +//go:build postgresql package testfixtures @@ -11,36 +11,28 @@ import ( ) func TestPostgreSQL(t *testing.T) { - for _, dialect := range []string{"postgres", "pgx"} { - testLoader( - t, - dialect, - os.Getenv("PG_CONN_STRING"), - "testdata/schema/postgresql.sql", - ) - } + testPostgreSQL(t) } func TestPostgreSQLWithAlterConstraint(t *testing.T) { - for _, dialect := range []string{"postgres", "pgx"} { - testLoader( - t, - dialect, - os.Getenv("PG_CONN_STRING"), - "testdata/schema/postgresql.sql", - UseAlterConstraint(), - ) - } + testPostgreSQL(t, UseAlterConstraint()) } func TestPostgreSQLWithDropConstraint(t *testing.T) { + testPostgreSQL(t, UseDropConstraint()) +} + +func testPostgreSQL(t *testing.T, additionalOptions ...func(*Loader) error) { + t.Helper() for _, dialect := range []string{"postgres", "pgx"} { + db := openDB(t, dialect, os.Getenv("PG_CONN_STRING")) + loadSchemaInOneQuery(t, db, "testdata/schema/postgresql.sql") testLoader( t, + db, dialect, - os.Getenv("PG_CONN_STRING"), - "testdata/schema/postgresql.sql", - UseDropConstraint(), + additionalOptions..., ) } + } diff --git a/spanner.go b/spanner.go index 783c511..28fd6e3 100644 --- a/spanner.go +++ b/spanner.go @@ -84,12 +84,6 @@ func (h *spanner) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) ( return h.dropAndRecreateConstraints(db, loadFn) } -// splitter is a batchSplitter interface implementation. We need it for -// spanner because spanner doesn't support multi-statements. -func (*spanner) splitter() []byte { - return []byte(";\n") -} - func (h *spanner) cleanTableQuery(tableName string) string { if h.cleanTableFn == nil { return h.baseHelper.cleanTableQuery(tableName) diff --git a/spanner_test.go b/spanner_test.go index 532201f..e87d4b9 100644 --- a/spanner_test.go +++ b/spanner_test.go @@ -18,13 +18,9 @@ import ( func TestSpanner(t *testing.T) { prepareSpannerDB(t) - testLoader( - t, - "spanner", - os.Getenv("SPANNER_CONN_STRING"), - "testdata/schema/spanner.sql", - DangerousSkipTestDatabaseCheck(), - ) + db := openDB(t, "spanner", os.Getenv("SPANNER_CONN_STRING")) + loadSchemaInBatchesBySplitter(t, db, "testdata/schema/spanner.sql", []byte(";\n")) + testLoader(t, db, "spanner", DangerousSkipTestDatabaseCheck()) } func prepareSpannerDB(t *testing.T) { diff --git a/sqlite_test.go b/sqlite_test.go index b18699e..e8d01d9 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1,4 +1,4 @@ -// +build sqlite +//go:build sqlite package testfixtures @@ -10,10 +10,7 @@ import ( ) func TestSQLite(t *testing.T) { - testLoader( - t, - "sqlite3", - os.Getenv("SQLITE_CONN_STRING"), - "testdata/schema/sqlite.sql", - ) + db := openDB(t, "sqlite3", os.Getenv("SQLITE_CONN_STRING")) + loadSchemaInOneQuery(t, db, "testdata/schema/sqlite.sql") + testLoader(t, db, "sqlite3") } diff --git a/sqlserver.go b/sqlserver.go index 9a950d3..2a69601 100644 --- a/sqlserver.go +++ b/sqlserver.go @@ -143,11 +143,3 @@ func (h *sqlserver) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) return tx.Commit() } - -// splitter is a batchSplitter interface implementation. We need it for -// SQL Server because commands like a `CREATE SCHEMA...` and a `CREATE TABLE...` -// could not be executed in the same batch. -// See https://docs.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms175502(v=sql.105)#rules-for-using-batches -func (*sqlserver) splitter() []byte { - return []byte("GO\n") -} diff --git a/sqlserver_test.go b/sqlserver_test.go index 46d9e44..70356a8 100644 --- a/sqlserver_test.go +++ b/sqlserver_test.go @@ -1,4 +1,4 @@ -// +build sqlserver +//go:build sqlserver package testfixtures @@ -10,21 +10,21 @@ import ( ) func TestSQLServer(t *testing.T) { - testLoader( - t, - "sqlserver", - os.Getenv("SQLSERVER_CONN_STRING"), - "testdata/schema/sqlserver.sql", - DangerousSkipTestDatabaseCheck(), - ) + testSQLServer(t, "sqlserver") } func TestDeprecatedMssql(t *testing.T) { + testSQLServer(t, "mssql") +} + +func testSQLServer(t *testing.T, dialect string) { + t.Helper() + db := openDB(t, dialect, os.Getenv("SQLSERVER_CONN_STRING")) + loadSchemaInBatchesBySplitter(t, db, "testdata/schema/sqlserver.sql", []byte("GO\n")) testLoader( t, - "mssql", - os.Getenv("SQLSERVER_CONN_STRING"), - "testdata/schema/sqlserver.sql", + db, + dialect, DangerousSkipTestDatabaseCheck(), ) } diff --git a/testfixtures_test.go b/testfixtures_test.go index 05b4ba3..a62bf5d 100644 --- a/testfixtures_test.go +++ b/testfixtures_test.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql" "embed" + "errors" "fmt" "os" "testing" @@ -26,64 +27,70 @@ func TestFixtureFile(t *testing.T) { func TestRequiredOptions(t *testing.T) { t.Run("DatabaseIsRequired", func(t *testing.T) { _, err := New() - if err != errDatabaseIsRequired { + if !errors.Is(err, errDatabaseIsRequired) { t.Error("should return an error if database if not given") } }) t.Run("DialectIsRequired", func(t *testing.T) { _, err := New(Database(&sql.DB{})) - if err != errDialectIsRequired { + if !errors.Is(err, errDialectIsRequired) { t.Error("should return an error if dialect if not given") } }) } -func testLoader(t *testing.T, dialect, connStr, schemaFilePath string, additionalOptions ...func(*Loader) error) { //nolint +func openDB(t *testing.T, dialect, connStr string) *sql.DB { //nolint:unused + t.Helper() db, err := sql.Open(dialect, connStr) if err != nil { t.Errorf("failed to open database: %v", err) - return } - defer db.Close() + t.Cleanup(func() { + _ = db.Close() + }) if err := db.Ping(); err != nil { t.Errorf("failed to connect to database: %v", err) - return } + return db +} +func loadSchemaInOneQuery(t *testing.T, db *sql.DB, schemaFilePath string) { //nolint:unused + t.Helper() schema, err := os.ReadFile(schemaFilePath) if err != nil { t.Errorf("cannot read schema file: %v", err) return } - helper, err := helperForDialect(dialect) + loadSchemaInBatches(t, db, [][]byte{schema}) +} + +func loadSchemaInBatchesBySplitter(t *testing.T, db *sql.DB, schemaFilePath string, splitter []byte) { //nolint:unused + t.Helper() + schema, err := os.ReadFile(schemaFilePath) if err != nil { - t.Errorf("cannot get helper: %v", err) - return - } - if err := helper.init(db); err != nil { - t.Errorf("cannot init helper: %v", err) + t.Errorf("cannot read schema file: %v", err) return } + batches := bytes.Split(schema, splitter) + loadSchemaInBatches(t, db, batches) +} - var batches [][]byte - if h, ok := helper.(batchSplitter); ok { - batches = append(batches, bytes.Split(schema, h.splitter())...) - } else { - batches = append(batches, schema) - } - +func loadSchemaInBatches(t *testing.T, db *sql.DB, batches [][]byte) { //nolint:unused + t.Helper() for _, b := range batches { if len(b) == 0 { continue } - if _, err = db.Exec(string(b)); err != nil { + if _, err := db.Exec(string(b)); err != nil { t.Errorf("cannot load schema: %v", err) return } } +} +func testLoader(t *testing.T, db *sql.DB, dialect string, additionalOptions ...func(*Loader) error) { //nolint:unused t.Run("LoadFromDirectory", func(t *testing.T) { options := append( []func(*Loader) error{ @@ -524,18 +531,18 @@ func testLoader(t *testing.T, dialect, connStr, schemaFilePath string, additiona // sequence issues. var sql string - switch helper.paramType() { - case paramTypeDollar: + switch dialect { + case "postgres", "pgx", "clickhouse": sql = "INSERT INTO posts (title, content, created_at, updated_at) VALUES ($1, $2, $3, $4)" - case paramTypeQuestion: + case "mysql", "sqlite3", "mssql": sql = "INSERT INTO posts (title, content, created_at, updated_at) VALUES (?, ?, ?, ?)" - case paramTypeAtSign: + case "sqlserver", "spanner": sql = "INSERT INTO posts (title, content, created_at, updated_at) VALUES (@p1, @p2, @p3, @p4)" default: - panic("unrecognized param type") + t.Fatalf("undefined param type for %s dialect, modify switch statement", dialect) } - _, err = db.Exec(sql, "Post title", "Post content", time.Now(), time.Now()) + _, err := db.Exec(sql, "Post title", "Post content", time.Now(), time.Now()) if err != nil { t.Errorf("cannot insert post: %v", err) } @@ -553,7 +560,7 @@ func assertFixturesLoaded(t *testing.T, l *Loader) { //nolint func assertCount(t *testing.T, l *Loader, table string, expectedCount int) { //nolint count := 0 - sql := fmt.Sprintf("SELECT COUNT(*) FROM %s", l.helper.quoteKeyword(table)) + sql := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) row := l.db.QueryRow(sql) if err := row.Scan(&count); err != nil {