From f398cec1c3873efdf61ac0b94ebe06c657f0cf91 Mon Sep 17 00:00:00 2001 From: Vladimir Mihailenco Date: Sun, 17 Oct 2021 10:04:29 +0300 Subject: [PATCH] feat(pgdriver): add support for unix socket DSN --- driver/pgdriver/config.go | 47 ++++++++++++++++++-------- driver/pgdriver/config_test.go | 61 +++++++++++++++++++++++++--------- 2 files changed, 78 insertions(+), 30 deletions(-) diff --git a/driver/pgdriver/config.go b/driver/pgdriver/config.go index 2fa03147b..1af50e25b 100644 --- a/driver/pgdriver/config.go +++ b/driver/pgdriver/config.go @@ -187,30 +187,49 @@ func parseDSN(dsn string) ([]Option, error) { return nil, err } - if u.Scheme != "postgres" && u.Scheme != "postgresql" { - return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme) - } - + q := queryOptions{q: u.Query()} var opts []Option - if u.Host != "" { - addr := u.Host - if !strings.Contains(addr, ":") { - addr += ":5432" + switch u.Scheme { + case "postgres", "postgresql": + if u.Host != "" { + addr := u.Host + if !strings.Contains(addr, ":") { + addr += ":5432" + } + opts = append(opts, WithAddr(addr)) + } + + if len(u.Path) > 1 { + opts = append(opts, WithDatabase(u.Path[1:])) + } + + if host := q.string("host"); host != "" { + opts = append(opts, WithAddr(host)) + if host[0] == '/' { + opts = append(opts, WithNetwork("unix")) + } + } + case "unix": + if len(u.Path) == 0 { + return nil, fmt.Errorf("unix socket DSN requires a path: %s", dsn) } - opts = append(opts, WithAddr(addr)) + + opts = append(opts, WithNetwork("unix")) + if u.Host != "" { + opts = append(opts, WithDatabase(u.Host)) + } + opts = append(opts, WithAddr(u.Path)) + default: + return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme) } + if u.User != nil { opts = append(opts, WithUser(u.User.Username())) if password, ok := u.User.Password(); ok { opts = append(opts, WithPassword(password)) } } - if len(u.Path) > 1 { - opts = append(opts, WithDatabase(u.Path[1:])) - } - - q := queryOptions{q: u.Query()} if appName := q.string("application_name"); appName != "" { opts = append(opts, WithApplicationName(appName)) diff --git a/driver/pgdriver/config_test.go b/driver/pgdriver/config_test.go index f2ce9c492..616c20b1b 100644 --- a/driver/pgdriver/config_test.go +++ b/driver/pgdriver/config_test.go @@ -1,6 +1,7 @@ package pgdriver_test import ( + "fmt" "testing" "time" @@ -16,12 +17,12 @@ func TestParseDSN(t *testing.T) { tests := []Test{ { - dsn: "postgres://postgres:1@localhost:5432/testDatabase?sslmode=disable", + dsn: "postgres://user:password@localhost:5432/testDatabase?sslmode=disable", cfg: &pgdriver.Config{ Network: "tcp", Addr: "localhost:5432", - User: "postgres", - Password: "1", + User: "user", + Password: "password", Database: "testDatabase", DialTimeout: 5 * time.Second, ReadTimeout: 10 * time.Second, @@ -29,12 +30,12 @@ func TestParseDSN(t *testing.T) { }, }, { - dsn: "postgres://postgres:1@localhost:5432/testDatabase?sslmode=disable&dial_timeout=1&read_timeout=2s&write_timeout=3", + dsn: "postgres://user:password@localhost:5432/testDatabase?sslmode=disable&dial_timeout=1&read_timeout=2s&write_timeout=3", cfg: &pgdriver.Config{ Network: "tcp", Addr: "localhost:5432", - User: "postgres", - Password: "1", + User: "user", + Password: "password", Database: "testDatabase", DialTimeout: 1 * time.Second, ReadTimeout: 2 * time.Second, @@ -42,12 +43,12 @@ func TestParseDSN(t *testing.T) { }, }, { - dsn: "postgres://postgres:1@localhost:5432/testDatabase?search_path=foo", + dsn: "postgres://user:password@localhost:5432/testDatabase?search_path=foo", cfg: &pgdriver.Config{ Network: "tcp", Addr: "localhost:5432", - User: "postgres", - Password: "1", + User: "user", + Password: "password", Database: "testDatabase", ConnParams: map[string]interface{}{ "search_path": "foo", @@ -58,11 +59,11 @@ func TestParseDSN(t *testing.T) { }, }, { - dsn: "postgres://postgres:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable", + dsn: "postgres://user:password@app.xxx.us-east-1.rds.amazonaws.com:5432/test?sslmode=disable", cfg: &pgdriver.Config{ Network: "tcp", Addr: "app.xxx.us-east-1.rds.amazonaws.com:5432", - User: "postgres", + User: "user", Password: "password", Database: "test", DialTimeout: 5 * time.Second, @@ -70,14 +71,42 @@ func TestParseDSN(t *testing.T) { WriteTimeout: 5 * time.Second, }, }, + { + dsn: "postgres://user:password@/dbname?host=/var/run/postgresql/.s.PGSQL.5432", + cfg: &pgdriver.Config{ + Network: "unix", + Addr: "/var/run/postgresql/.s.PGSQL.5432", + User: "user", + Password: "password", + Database: "dbname", + DialTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 5 * time.Second, + }, + }, + { + dsn: "unix://user:pass@dbname/var/run/postgresql/.s.PGSQL.5432", + cfg: &pgdriver.Config{ + Network: "unix", + Addr: "/var/run/postgresql/.s.PGSQL.5432", + User: "user", + Password: "pass", + Database: "dbname", + DialTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 5 * time.Second, + }, + }, } - for _, test := range tests { - c := pgdriver.NewConnector(pgdriver.WithDSN(test.dsn)) + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + c := pgdriver.NewConnector(pgdriver.WithDSN(test.dsn)) - cfg := c.Config() - cfg.Dialer = nil + cfg := c.Config() + cfg.Dialer = nil - require.Equal(t, test.cfg, cfg) + require.Equal(t, test.cfg, cfg) + }) } }