diff --git a/.gitignore b/.gitignore index 8bc8742e..28ce465d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ /testdata/db/schema.sql /vendor dist +docker-compose.override.yml node_modules diff --git a/pkg/driver/postgres/postgres.go b/pkg/driver/postgres/postgres.go index 8c56d7b5..d0240db8 100644 --- a/pkg/driver/postgres/postgres.go +++ b/pkg/driver/postgres/postgres.go @@ -19,6 +19,7 @@ import ( func init() { dbmate.RegisterDriver(NewDriver, "postgres") dbmate.RegisterDriver(NewDriver, "postgresql") + dbmate.RegisterDriver(NewDriver, "redshift") } // Driver provides top level database functions @@ -71,11 +72,19 @@ func connectionString(u *url.URL) string { query.Del("port") } if port == "" { - port = "5432" + switch u.Scheme { + case "postgresql": + fallthrough + case "postgres": + port = "5432" + case "redshift": + port = "5439" + } } // generate output URL out, _ := url.Parse(u.String()) + out.Scheme = "postgres" out.Host = fmt.Sprintf("%s:%s", hostname, port) out.RawQuery = query.Encode() @@ -115,8 +124,10 @@ func (drv *Driver) openPostgresDB() (*sql.DB, error) { return nil, err } - // connect to postgres database - postgresURL.Path = "postgres" + // connect to postgres database, unless this is a Redshift connection + if drv.databaseURL.Scheme != "redshift" { + postgresURL.Path = "postgres" + } return sql.Open("postgres", postgresURL.String()) } @@ -425,6 +436,11 @@ func (drv *Driver) quotedMigrationsTableNameParts(db dbutil.Transaction) (string return "", "", err } + // Quote identifiers for Redshift + if drv.databaseURL.Scheme == "redshift" { + return pq.QuoteIdentifier(schema), pq.QuoteIdentifier(strings.Join(tableNameParts, ".")), nil + } + // quote all parts // use server rather than client to do this to avoid unnecessary quotes // (which would change schema.sql diff) diff --git a/pkg/driver/postgres/postgres_test.go b/pkg/driver/postgres/postgres_test.go index ed51f915..a9008cb1 100644 --- a/pkg/driver/postgres/postgres_test.go +++ b/pkg/driver/postgres/postgres_test.go @@ -2,6 +2,7 @@ package postgres import ( "database/sql" + "fmt" "net/url" "os" "runtime" @@ -21,6 +22,18 @@ func testPostgresDriver(t *testing.T) *Driver { return drv.(*Driver) } +func testRedshiftDriver(t *testing.T) *Driver { + url, ok := os.LookupEnv("REDSHIFT_TEST_URL") + if !ok { + t.Skip("skipping test, no REDSHIFT_TEST_URL provided") + } + u := dbutil.MustParseURL(url) + drv, err := dbmate.New(u).Driver() + require.NoError(t, err) + + return drv.(*Driver) +} + func prepTestPostgresDB(t *testing.T) *sql.DB { drv := testPostgresDriver(t) @@ -33,7 +46,23 @@ func prepTestPostgresDB(t *testing.T) *sql.DB { require.NoError(t, err) // connect database - db, err := sql.Open("postgres", drv.databaseURL.String()) + db, err := sql.Open("postgres", connectionString(drv.databaseURL)) + require.NoError(t, err) + + return db +} + +func prepRedshiftTestDB(t *testing.T, drv *Driver) *sql.DB { + // connect database + db, err := sql.Open("postgres", connectionString(drv.databaseURL)) + require.NoError(t, err) + + _, migrationsTable, err := drv.quotedMigrationsTableNameParts(db) + if err != nil { + t.Error(err) + } + + _, err = db.Exec(fmt.Sprintf("drop table if exists %s", migrationsTable)) require.NoError(t, err) return db @@ -80,6 +109,8 @@ func TestConnectionString(t *testing.T) { {"postgres:///foo?socket=/var/run/postgresql", "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"}, {"postgres://bob:secret@/foo?socket=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"}, {"postgres://bob:secret@/foo?host=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"}, + // redshift default port is 5439, not 5432 + {"redshift://myhost/foo", "postgres://myhost:5439/foo"}, } for _, c := range cases { @@ -367,6 +398,36 @@ func TestPostgresCreateMigrationsTable(t *testing.T) { }) } +func TestRedshiftCreateMigrationsTable(t *testing.T) { + if _, ok := os.LookupEnv("REDSHIFT_TEST_URL"); !ok { + t.Skip("skipping test, no REDSHIFT_TEST_URL provided") + } + + t.Run("default schema", func(t *testing.T) { + drv := testRedshiftDriver(t) + db := prepRedshiftTestDB(t, drv) + defer dbutil.MustClose(db) + + // migrations table should not exist + count := 0 + err := db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) + require.Error(t, err, "migrations table exists when it shouldn't") + require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error()) + + // create table + err = drv.CreateMigrationsTable(db) + require.NoError(t, err) + + // migrations table should exist + err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) + require.NoError(t, err) + + // create table should be idempotent + err = drv.CreateMigrationsTable(db) + require.NoError(t, err) + }) +} + func TestPostgresSelectMigrations(t *testing.T) { drv := testPostgresDriver(t) drv.migrationsTableName = "test_migrations"