Skip to content

Commit

Permalink
Add Redshift support (#488)
Browse files Browse the repository at this point in the history
This small fix checks that the database is Redshift and make client-side
quoting instead of server quoting.

---------

Co-authored-by: Dossy Shiobara <[email protected]>
  • Loading branch information
aterekhov-plr and dossy authored Jan 9, 2024
1 parent 8dfb7fd commit c415157
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
/testdata/db/schema.sql
/vendor
dist
docker-compose.override.yml
node_modules
22 changes: 19 additions & 3 deletions pkg/driver/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
func init() {
dbmate.RegisterDriver(NewDriver, "postgres")
dbmate.RegisterDriver(NewDriver, "postgresql")
dbmate.RegisterDriver(NewDriver, "redshift")
}

// Driver provides top level database functions
Expand Down Expand Up @@ -71,11 +72,19 @@ func connectionString(u *url.URL) string {
query.Del("port")
}
if port == "" {
port = "5432"
switch u.Scheme {
case "postgresql":
fallthrough
case "postgres":
port = "5432"
case "redshift":
port = "5439"
}
}

// generate output URL
out, _ := url.Parse(u.String())
out.Scheme = "postgres"
out.Host = fmt.Sprintf("%s:%s", hostname, port)
out.RawQuery = query.Encode()

Expand Down Expand Up @@ -115,8 +124,10 @@ func (drv *Driver) openPostgresDB() (*sql.DB, error) {
return nil, err
}

// connect to postgres database
postgresURL.Path = "postgres"
// connect to postgres database, unless this is a Redshift connection
if drv.databaseURL.Scheme != "redshift" {
postgresURL.Path = "postgres"
}

return sql.Open("postgres", postgresURL.String())
}
Expand Down Expand Up @@ -425,6 +436,11 @@ func (drv *Driver) quotedMigrationsTableNameParts(db dbutil.Transaction) (string
return "", "", err
}

// Quote identifiers for Redshift
if drv.databaseURL.Scheme == "redshift" {
return pq.QuoteIdentifier(schema), pq.QuoteIdentifier(strings.Join(tableNameParts, ".")), nil
}

// quote all parts
// use server rather than client to do this to avoid unnecessary quotes
// (which would change schema.sql diff)
Expand Down
63 changes: 62 additions & 1 deletion pkg/driver/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package postgres

import (
"database/sql"
"fmt"
"net/url"
"os"
"runtime"
Expand All @@ -21,6 +22,18 @@ func testPostgresDriver(t *testing.T) *Driver {
return drv.(*Driver)
}

func testRedshiftDriver(t *testing.T) *Driver {
url, ok := os.LookupEnv("REDSHIFT_TEST_URL")
if !ok {
t.Skip("skipping test, no REDSHIFT_TEST_URL provided")
}
u := dbutil.MustParseURL(url)
drv, err := dbmate.New(u).Driver()
require.NoError(t, err)

return drv.(*Driver)
}

func prepTestPostgresDB(t *testing.T) *sql.DB {
drv := testPostgresDriver(t)

Expand All @@ -33,7 +46,23 @@ func prepTestPostgresDB(t *testing.T) *sql.DB {
require.NoError(t, err)

// connect database
db, err := sql.Open("postgres", drv.databaseURL.String())
db, err := sql.Open("postgres", connectionString(drv.databaseURL))
require.NoError(t, err)

return db
}

func prepRedshiftTestDB(t *testing.T, drv *Driver) *sql.DB {
// connect database
db, err := sql.Open("postgres", connectionString(drv.databaseURL))
require.NoError(t, err)

_, migrationsTable, err := drv.quotedMigrationsTableNameParts(db)
if err != nil {
t.Error(err)
}

_, err = db.Exec(fmt.Sprintf("drop table if exists %s", migrationsTable))
require.NoError(t, err)

return db
Expand Down Expand Up @@ -80,6 +109,8 @@ func TestConnectionString(t *testing.T) {
{"postgres:///foo?socket=/var/run/postgresql", "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
{"postgres://bob:secret@/foo?socket=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
{"postgres://bob:secret@/foo?host=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
// redshift default port is 5439, not 5432
{"redshift://myhost/foo", "postgres://myhost:5439/foo"},
}

for _, c := range cases {
Expand Down Expand Up @@ -367,6 +398,36 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
})
}

func TestRedshiftCreateMigrationsTable(t *testing.T) {
if _, ok := os.LookupEnv("REDSHIFT_TEST_URL"); !ok {
t.Skip("skipping test, no REDSHIFT_TEST_URL provided")
}

t.Run("default schema", func(t *testing.T) {
drv := testRedshiftDriver(t)
db := prepRedshiftTestDB(t, drv)
defer dbutil.MustClose(db)

// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
require.Error(t, err, "migrations table exists when it shouldn't")
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())

// create table
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)

// migrations table should exist
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
require.NoError(t, err)

// create table should be idempotent
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
}

func TestPostgresSelectMigrations(t *testing.T) {
drv := testPostgresDriver(t)
drv.migrationsTableName = "test_migrations"
Expand Down

0 comments on commit c415157

Please sign in to comment.