From 129ef05ebc351cda7522c9933dab91a3e643468b Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Thu, 10 Dec 2020 17:51:33 -0700 Subject: [PATCH 01/13] WIP --- Makefile | 6 ------ bcda/api/requests.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index f728920b2..8bdc03d88 100644 --- a/Makefile +++ b/Makefile @@ -103,12 +103,6 @@ load-fixtures: $(MAKE) load-synthetic-suppression-data $(MAKE) load-fixtures-ssas - # Since we've rebuilt the databases, we need to restart the worker and api - # to ensure it picks up connections to the new db instance. - # TODO (BCDA-3710) we should be able to remove this line once we have connection pooling on the worker - docker-compose restart api worker ssas - sleep 5 - load-synthetic-cclf-data: docker-compose up -d api docker-compose up -d db diff --git a/bcda/api/requests.go b/bcda/api/requests.go index 4bd006d10..9ac771c7c 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -1,6 +1,7 @@ package api import ( + "context" "fmt" "net/url" "regexp" @@ -59,6 +60,37 @@ func NewHandler(resources []string, basePath string) *Handler { log.Fatal(err) } + // With que-go locked to pgx v3, we need a mechanism that will allow us to + // discard bad connections in the pgxpool (see: https://github.com/jackc/pgx/issues/494) + // This implementation is based off of the "fix" that is present in v4 + // (see: https://github.com/jackc/pgx/blob/v4.10.0/pgxpool/pool.go#L333) + + // If we implement a cleanup function on the handler, + // consider using context.WithCancel() to give us a hook to stop the goroutine + ctx := context.Background() + go func() { + ticker := time.NewTicker(10 * time.Second) + for { + select { + case <-ctx.Done(): + ticker.Stop() + log.Debug("Stopping pgxpool checker") + return + case <-ticker.C: + log.Debug("Sending ping") + conn, err := pgxpool.Acquire() + if err != nil { + log.Warnf("Failed to acquire connection %s", err.Error()) + return + } + if err := conn.Ping(ctx); err != nil { + log.Warnf("Failed to ping %s", err.Error()) + } + pgxpool.Release(conn) + } + } + }() + h.qc = que.NewClient(pgxpool) cutoffDuration := time.Duration(utils.GetEnvInt("CCLF_CUTOFF_DATE_DAYS", 45)*24) * time.Hour From ac84100432bcd4ef5e2ca5fb4651ecf0a6d6e815 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Thu, 10 Dec 2020 18:25:17 -0700 Subject: [PATCH 02/13] Refresh db connections --- bcda/api/requests.go | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/bcda/api/requests.go b/bcda/api/requests.go index 9ac771c7c..79624e1d2 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -46,6 +46,11 @@ func NewHandler(resources []string, basePath string) *Handler { h := &Handler{} + db := database.GetGORMDbConnection() + db.DB().SetMaxOpenConns(utils.GetEnvInt("BCDA_DB_MAX_OPEN_CONNS", 25)) + db.DB().SetMaxIdleConns(utils.GetEnvInt("BCDA_DB_MAX_IDLE_CONNS", 25)) + db.DB().SetConnMaxLifetime(time.Duration(utils.GetEnvInt("BCDA_DB_CONN_MAX_LIFETIME_MIN", 5)) * time.Minute) + queueDatabaseURL := os.Getenv("QUEUE_DATABASE_URL") pgxcfg, err := pgx.ParseURI(queueDatabaseURL) if err != nil { @@ -64,6 +69,7 @@ func NewHandler(resources []string, basePath string) *Handler { // discard bad connections in the pgxpool (see: https://github.com/jackc/pgx/issues/494) // This implementation is based off of the "fix" that is present in v4 // (see: https://github.com/jackc/pgx/blob/v4.10.0/pgxpool/pool.go#L333) + // Use the same approach to validate the connection associated with the gorm client // If we implement a cleanup function on the handler, // consider using context.WithCancel() to give us a hook to stop the goroutine @@ -78,15 +84,24 @@ func NewHandler(resources []string, basePath string) *Handler { return case <-ticker.C: log.Debug("Sending ping") - conn, err := pgxpool.Acquire() + c, err := pgxpool.Acquire() if err != nil { log.Warnf("Failed to acquire connection %s", err.Error()) - return } - if err := conn.Ping(ctx); err != nil { + if err := c.Ping(ctx); err != nil { log.Warnf("Failed to ping %s", err.Error()) } - pgxpool.Release(conn) + pgxpool.Release(c) + + c1, err := db.DB().Conn(ctx) + if err != nil { + log.Warnf("Failed to acquire connection %s", err.Error()) + continue + } + if err := c1.PingContext(ctx); err != nil { + log.Warnf("Failed to ping %s", err.Error()) + } + c1.Close() } } }() @@ -103,10 +118,6 @@ func NewHandler(resources []string, basePath string) *Handler { log.Fatalf("Failed to parse RUNOUT_CLAIM_THRU_DATE '%s'. Err: %s", runoutClaimThru, err.Error()) } - db := database.GetGORMDbConnection() - db.DB().SetMaxOpenConns(utils.GetEnvInt("BCDA_DB_MAX_OPEN_CONNS", 25)) - db.DB().SetMaxIdleConns(utils.GetEnvInt("BCDA_DB_MAX_IDLE_CONNS", 25)) - db.DB().SetConnMaxLifetime(time.Duration(utils.GetEnvInt("BCDA_DB_CONN_MAX_LIFETIME_MIN", 5)) * time.Minute) repository := postgres.NewRepository(db) h.svc = models.NewService(repository, cutoffDuration, utils.GetEnvInt("BCDA_SUPPRESSION_LOOKBACK_DAYS", 60), runoutCutoffDuration, runoutClaimThruDate, From 25d598424f402523cb43568b8321e2405f822de9 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 07:38:43 -0700 Subject: [PATCH 03/13] Log error on close --- bcda/api/requests.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bcda/api/requests.go b/bcda/api/requests.go index 79624e1d2..6cc717b33 100644 --- a/bcda/api/requests.go +++ b/bcda/api/requests.go @@ -101,7 +101,9 @@ func NewHandler(resources []string, basePath string) *Handler { if err := c1.PingContext(ctx); err != nil { log.Warnf("Failed to ping %s", err.Error()) } - c1.Close() + if err := c1.Close(); err != nil { + log.Warnf("Failed to close conection %s", err.Error()) + } } } }() From 6396f6e66c35da1576447fa3cc8fd76faca2fbcf Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 08:19:41 -0700 Subject: [PATCH 04/13] Updating lib/pg library to include connection fix --- Gopkg.lock | 7 +- vendor/github.com/lib/pq/.gitignore | 2 + vendor/github.com/lib/pq/.travis.sh | 13 - vendor/github.com/lib/pq/.travis.yml | 17 +- vendor/github.com/lib/pq/CONTRIBUTING.md | 29 - vendor/github.com/lib/pq/README.md | 77 +-- vendor/github.com/lib/pq/array.go | 139 +++++ vendor/github.com/lib/pq/buf.go | 2 +- vendor/github.com/lib/pq/conn.go | 500 ++++++++++++------ vendor/github.com/lib/pq/conn_go18.go | 50 +- vendor/github.com/lib/pq/connector.go | 98 +++- vendor/github.com/lib/pq/copy.go | 33 +- vendor/github.com/lib/pq/doc.go | 25 +- vendor/github.com/lib/pq/encode.go | 31 +- vendor/github.com/lib/pq/error.go | 13 +- vendor/github.com/lib/pq/go.mod | 3 + vendor/github.com/lib/pq/krb.go | 27 + vendor/github.com/lib/pq/notice.go | 71 +++ vendor/github.com/lib/pq/notify.go | 66 ++- vendor/github.com/lib/pq/scram/scram.go | 264 +++++++++ vendor/github.com/lib/pq/ssl.go | 8 +- vendor/github.com/lib/pq/ssl_go1.7.go | 14 - vendor/github.com/lib/pq/ssl_renegotiation.go | 8 - vendor/github.com/lib/pq/user_posix.go | 2 +- 24 files changed, 1148 insertions(+), 351 deletions(-) delete mode 100644 vendor/github.com/lib/pq/CONTRIBUTING.md create mode 100644 vendor/github.com/lib/pq/go.mod create mode 100644 vendor/github.com/lib/pq/krb.go create mode 100644 vendor/github.com/lib/pq/notice.go create mode 100644 vendor/github.com/lib/pq/scram/scram.go delete mode 100644 vendor/github.com/lib/pq/ssl_go1.7.go delete mode 100644 vendor/github.com/lib/pq/ssl_renegotiation.go diff --git a/Gopkg.lock b/Gopkg.lock index deeb4a0af..37ce2aa81 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -183,15 +183,17 @@ version = "v1.0.2" [[projects]] - digest = "1:2d370ad7a48d3523291ad53a04b9b1510c182880927f800b46221cc465427800" + digest = "1:1f56b7e0292f71c450ab897beb96e9978f22b46da14f2d1add4bcfee669c2d3f" name = "github.com/lib/pq" packages = [ ".", "hstore", "oid", + "scram", ] pruneopts = "UT" - revision = "90697d60dd844d5ef6ff15135d0203f65d2f53b8" + revision = "4604d39ddc9f62e2ca152114c73aec99b77a3468" + version = "v1.9.0" [[projects]] branch = "master" @@ -537,7 +539,6 @@ "github.com/jinzhu/gorm", "github.com/jinzhu/gorm/dialects/postgres", "github.com/lib/pq", - "github.com/newrelic/go-agent", "github.com/newrelic/go-agent/_integrations/nrlogrus", "github.com/newrelic/go-agent/v3/newrelic", "github.com/otiai10/copy", diff --git a/vendor/github.com/lib/pq/.gitignore b/vendor/github.com/lib/pq/.gitignore index 0f1d00e11..3243952a4 100644 --- a/vendor/github.com/lib/pq/.gitignore +++ b/vendor/github.com/lib/pq/.gitignore @@ -2,3 +2,5 @@ *.test *~ *.swp +.idea +.vscode \ No newline at end of file diff --git a/vendor/github.com/lib/pq/.travis.sh b/vendor/github.com/lib/pq/.travis.sh index a297dc452..ebf447030 100755 --- a/vendor/github.com/lib/pq/.travis.sh +++ b/vendor/github.com/lib/pq/.travis.sh @@ -70,17 +70,4 @@ postgresql_uninstall() { sudo rm -rf /var/lib/postgresql } -megacheck_install() { - # Lock megacheck version at $MEGACHECK_VERSION to prevent spontaneous - # new error messages in old code. - go get -d honnef.co/go/tools/... - git -C $GOPATH/src/honnef.co/go/tools/ checkout $MEGACHECK_VERSION - go install honnef.co/go/tools/cmd/megacheck - megacheck --version -} - -golint_install() { - go get github.com/golang/lint/golint -} - $1 diff --git a/vendor/github.com/lib/pq/.travis.yml b/vendor/github.com/lib/pq/.travis.yml index 18556e089..f378207f2 100644 --- a/vendor/github.com/lib/pq/.travis.yml +++ b/vendor/github.com/lib/pq/.travis.yml @@ -1,9 +1,8 @@ language: go go: - - 1.8.x - - 1.9.x - - 1.10.x + - 1.14.x + - 1.15.x - master sudo: true @@ -14,16 +13,12 @@ env: - PQGOSSLTESTS=1 - PQSSLCERTTEST_PATH=$PWD/certs - PGHOST=127.0.0.1 - - MEGACHECK_VERSION=2017.2.2 + - GODEBUG=x509ignoreCN=0 matrix: - PGVERSION=10 - PGVERSION=9.6 - PGVERSION=9.5 - PGVERSION=9.4 - - PGVERSION=9.3 - - PGVERSION=9.2 - - PGVERSION=9.1 - - PGVERSION=9.0 before_install: - ./.travis.sh postgresql_uninstall @@ -31,9 +26,9 @@ before_install: - ./.travis.sh postgresql_install - ./.travis.sh postgresql_configure - ./.travis.sh client_configure - - ./.travis.sh megacheck_install - - ./.travis.sh golint_install - go get golang.org/x/tools/cmd/goimports + - go get golang.org/x/lint/golint + - GO111MODULE=on go get honnef.co/go/tools/cmd/staticcheck@2020.1.3 before_script: - createdb pqgotest @@ -44,7 +39,7 @@ script: - > goimports -d -e $(find -name '*.go') | awk '{ print } END { exit NR == 0 ? 0 : 1 }' - go vet ./... - - megacheck -go 1.8 ./... + - staticcheck -go 1.13 ./... - golint ./... - PQTEST_BINARY_PARAMETERS=no go test -race -v ./... - PQTEST_BINARY_PARAMETERS=yes go test -race -v ./... diff --git a/vendor/github.com/lib/pq/CONTRIBUTING.md b/vendor/github.com/lib/pq/CONTRIBUTING.md deleted file mode 100644 index 84c937f15..000000000 --- a/vendor/github.com/lib/pq/CONTRIBUTING.md +++ /dev/null @@ -1,29 +0,0 @@ -## Contributing to pq - -`pq` has a backlog of pull requests, but contributions are still very -much welcome. You can help with patch review, submitting bug reports, -or adding new functionality. There is no formal style guide, but -please conform to the style of existing code and general Go formatting -conventions when submitting patches. - -### Patch review - -Help review existing open pull requests by commenting on the code or -proposed functionality. - -### Bug reports - -We appreciate any bug reports, but especially ones with self-contained -(doesn't depend on code outside of pq), minimal (can't be simplified -further) test cases. It's especially helpful if you can submit a pull -request with just the failing test case (you'll probably want to -pattern it after the tests in -[conn_test.go](https://github.com/lib/pq/blob/master/conn_test.go). - -### New functionality - -There are a number of pending patches for new functionality, so -additional feature patches will take a while to merge. Still, patches -are generally reviewed based on usefulness and complexity in addition -to time-in-queue, so if you have a knockout idea, take a shot. Feel -free to open an issue discussion your proposed patch beforehand. diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md index d71f3c2c3..c972a86a5 100644 --- a/vendor/github.com/lib/pq/README.md +++ b/vendor/github.com/lib/pq/README.md @@ -1,21 +1,11 @@ # pq - A pure Go postgres driver for Go's database/sql package -[![GoDoc](https://godoc.org/github.com/lib/pq?status.svg)](https://godoc.org/github.com/lib/pq) -[![Build Status](https://travis-ci.org/lib/pq.svg?branch=master)](https://travis-ci.org/lib/pq) +[![GoDoc](https://godoc.org/github.com/lib/pq?status.svg)](https://pkg.go.dev/github.com/lib/pq?tab=doc) ## Install go get github.com/lib/pq -## Docs - -For detailed documentation and basic usage examples, please see the package -documentation at . - -## Tests - -`go test` is used for testing. See [TESTS.md](TESTS.md) for more details. - ## Features * SSL @@ -29,67 +19,12 @@ documentation at . * Unix socket support * Notifications: `LISTEN`/`NOTIFY` * pgpass support +* GSS (Kerberos) auth -## Future / Things you can help with - -* Better COPY FROM / COPY TO (see discussion in #181) +## Tests -## Thank you (alphabetical) +`go test` is used for testing. See [TESTS.md](TESTS.md) for more details. -Some of these contributors are from the original library `bmizerany/pq.go` whose -code still exists in here. +## Status -* Andy Balholm (andybalholm) -* Ben Berkert (benburkert) -* Benjamin Heatwole (bheatwole) -* Bill Mill (llimllib) -* Bjørn Madsen (aeons) -* Blake Gentry (bgentry) -* Brad Fitzpatrick (bradfitz) -* Charlie Melbye (cmelbye) -* Chris Bandy (cbandy) -* Chris Gilling (cgilling) -* Chris Walsh (cwds) -* Dan Sosedoff (sosedoff) -* Daniel Farina (fdr) -* Eric Chlebek (echlebek) -* Eric Garrido (minusnine) -* Eric Urban (hydrogen18) -* Everyone at The Go Team -* Evan Shaw (edsrzf) -* Ewan Chou (coocood) -* Fazal Majid (fazalmajid) -* Federico Romero (federomero) -* Fumin (fumin) -* Gary Burd (garyburd) -* Heroku (heroku) -* James Pozdena (jpoz) -* Jason McVetta (jmcvetta) -* Jeremy Jay (pbnjay) -* Joakim Sernbrant (serbaut) -* John Gallagher (jgallagher) -* Jonathan Rudenberg (titanous) -* Joël Stemmer (jstemmer) -* Kamil Kisiel (kisielk) -* Kelly Dunn (kellydunn) -* Keith Rarick (kr) -* Kir Shatrov (kirs) -* Lann Martin (lann) -* Maciek Sakrejda (uhoh-itsmaciek) -* Marc Brinkmann (mbr) -* Marko Tiikkaja (johto) -* Matt Newberry (MattNewberry) -* Matt Robenolt (mattrobenolt) -* Martin Olsen (martinolsen) -* Mike Lewis (mikelikespie) -* Nicolas Patry (Narsil) -* Oliver Tonnhofer (olt) -* Patrick Hayes (phayes) -* Paul Hammond (paulhammond) -* Ryan Smith (ryandotsmith) -* Samuel Stauffer (samuel) -* Timothée Peignier (cyberdelia) -* Travis Cline (tmc) -* TruongSinh Tran-Nguyen (truongsinh) -* Yaismel Miranda (ympons) -* notedit (notedit) +This package is effectively in maintenance mode and is not actively developed. Small patches and features are only rarely reviewed and merged. We recommend using [pgx](https://github.com/jackc/pgx) which is actively maintained. diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go index e4933e227..405da2368 100644 --- a/vendor/github.com/lib/pq/array.go +++ b/vendor/github.com/lib/pq/array.go @@ -35,19 +35,31 @@ func Array(a interface{}) interface { return (*BoolArray)(&a) case []float64: return (*Float64Array)(&a) + case []float32: + return (*Float32Array)(&a) case []int64: return (*Int64Array)(&a) + case []int32: + return (*Int32Array)(&a) case []string: return (*StringArray)(&a) + case [][]byte: + return (*ByteaArray)(&a) case *[]bool: return (*BoolArray)(a) case *[]float64: return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) case *[]int64: return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) case *[]string: return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) } return GenericArray{a} @@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) { return "{}", nil } +// Float32Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float32Array []float32 + +// Scan implements the sql.Scanner interface. +func (a *Float32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float32Array", src) +} + +func (a *Float32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float32Array, len(elems)) + for i, v := range elems { + var x float64 + if x, err = strconv.ParseFloat(string(v), 32); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = float32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. type GenericArray struct{ A interface{} } @@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) { return "{}", nil } +// Int32Array represents a one-dimensional array of the PostgreSQL integer types. +type Int32Array []int32 + +// Scan implements the sql.Scanner interface. +func (a *Int32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int32Array", src) +} + +func (a *Int32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int32Array, len(elems)) + for i, v := range elems { + var x int + if x, err = strconv.Atoi(string(v)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = int32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, int64(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, int64(a[i]), 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // StringArray represents a one-dimensional array of the PostgreSQL character types. type StringArray []string diff --git a/vendor/github.com/lib/pq/buf.go b/vendor/github.com/lib/pq/buf.go index 666b0012a..4b0a0a8f7 100644 --- a/vendor/github.com/lib/pq/buf.go +++ b/vendor/github.com/lib/pq/buf.go @@ -66,7 +66,7 @@ func (b *writeBuf) int16(n int) { } func (b *writeBuf) string(s string) { - b.buf = append(b.buf, (s + "\000")...) + b.buf = append(append(b.buf, s...), '\000') } func (b *writeBuf) byte(c byte) { diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 43c8df29f..db0b6cef5 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -2,7 +2,9 @@ package pq import ( "bufio" + "context" "crypto/md5" + "crypto/sha256" "database/sql" "database/sql/driver" "encoding/binary" @@ -16,10 +18,12 @@ import ( "path/filepath" "strconv" "strings" + "sync/atomic" "time" "unicode" "github.com/lib/pq/oid" + "github.com/lib/pq/scram" ) // Common error types @@ -35,13 +39,18 @@ var ( errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) +// Compile time validation that our types implement the expected interfaces +var ( + _ driver.Driver = Driver{} +) + // Driver is the Postgres database driver. type Driver struct{} // Open opens a new connection to the database. name is a connection string. // Most users should only use it through database/sql package from the standard // library. -func (d *Driver) Open(name string) (driver.Conn, error) { +func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } @@ -89,13 +98,25 @@ type Dialer interface { DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) } -type defaultDialer struct{} +// DialerContext is the context-aware dialer interface. +type DialerContext interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} -func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) { - return net.Dial(ntw, addr) +type defaultDialer struct { + d net.Dialer } -func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout(ntw, addr, timeout) + +func (d defaultDialer) Dial(network, address string) (net.Conn, error) { + return d.d.Dial(network, address) +} +func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.DialContext(ctx, network, address) +} +func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.d.DialContext(ctx, network, address) } type conn struct { @@ -121,7 +142,7 @@ type conn struct { // If true, this connection is bad and all public-facing functions should // return ErrBadConn. - bad bool + bad *atomic.Value // If set, this connection should never use the binary format when // receiving query results from prepared statements. Only provided for @@ -134,6 +155,15 @@ type conn struct { // If true this connection is in the middle of a COPY inCopy bool + + // If not nil, notices will be synchronously sent here + noticeHandler func(*Error) + + // If not nil, notifications will be synchronously sent here + notificationHandler func(*Notification) + + // GSSAPI context + gss GSS } // Handle driver-side settings in parsed connection string. @@ -244,90 +274,38 @@ func (cn *conn) writeBuf(b byte) *writeBuf { } } -// Open opens a new connection to the database. name is a connection string. +// Open opens a new connection to the database. dsn is a connection string. // Most users should only use it through database/sql package from the standard // library. -func Open(name string) (_ driver.Conn, err error) { - return DialOpen(defaultDialer{}, name) +func Open(dsn string) (_ driver.Conn, err error) { + return DialOpen(defaultDialer{}, dsn) } // DialOpen opens a new connection to the database using a dialer. -func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { +func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + c.dialer = d + return c.open(context.Background()) +} + +func (c *Connector) open(ctx context.Context) (cn *conn, err error) { // Handle any panics during connection initialization. Note that we // specifically do *not* want to use errRecover(), as that would turn any // connection errors into ErrBadConns, hiding the real error message from // the user. defer errRecoverNoErrBadConn(&err) - o := make(values) - - // A number of defaults are applied here, in this order: - // - // * Very low precedence defaults applied in every situation - // * Environment variables - // * Explicitly passed connection information - o["host"] = "localhost" - o["port"] = "5432" - // N.B.: Extra float digits should be set to 3, but that breaks - // Postgres 8.4 and older, where the max is 2. - o["extra_float_digits"] = "2" - for k, v := range parseEnviron(os.Environ()) { - o[k] = v - } - - if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { - name, err = ParseURL(name) - if err != nil { - return nil, err - } - } - - if err := parseOpts(name, o); err != nil { - return nil, err - } - - // Use the "fallback" application name if necessary - if fallback, ok := o["fallback_application_name"]; ok { - if _, ok := o["application_name"]; !ok { - o["application_name"] = fallback - } - } - - // We can't work with any client_encoding other than UTF-8 currently. - // However, we have historically allowed the user to set it to UTF-8 - // explicitly, and there's no reason to break such programs, so allow that. - // Note that the "options" setting could also set client_encoding, but - // parsing its value is not worth it. Instead, we always explicitly send - // client_encoding as a separate run-time parameter, which should override - // anything set in options. - if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") - } - o["client_encoding"] = "UTF8" - // DateStyle needs a similar treatment. - if datestyle, ok := o["datestyle"]; ok { - if datestyle != "ISO, MDY" { - panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", - "ISO, MDY", datestyle)) - } - } else { - o["datestyle"] = "ISO, MDY" - } - - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if _, ok := o["user"]; !ok { - u, err := userCurrent() - if err != nil { - return nil, err - } - o["user"] = u - } + o := c.opts - cn := &conn{ + bad := &atomic.Value{} + bad.Store(false) + cn = &conn{ opts: o, - dialer: d, + dialer: c.dialer, + bad: bad, } err = cn.handleDriverSettings(o) if err != nil { @@ -335,13 +313,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } cn.handlePgpass(o) - cn.c, err = dial(d, o) + cn.c, err = dial(ctx, c.dialer, o) if err != nil { return nil, err } err = cn.ssl(o) if err != nil { + if cn.c != nil { + cn.c.Close() + } return nil, err } @@ -364,12 +345,8 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { return cn, err } -func dial(d Dialer, o values) (net.Conn, error) { - ntw, addr := network(o) - // SSL is not necessary or supported over UNIX domain sockets - if ntw == "unix" { - o["sslmode"] = "disable" - } +func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { + network, address := network(o) // Zero or not specified means wait indefinitely. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { @@ -378,19 +355,30 @@ func dial(d Dialer, o values) (net.Conn, error) { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) } duration := time.Duration(seconds) * time.Second + // connect_timeout should apply to the entire connection establishment // procedure, so we both use a timeout for the TCP connection // establishment and set a deadline for doing the initial handshake. // The deadline is then reset after startup() is done. deadline := time.Now().Add(duration) - conn, err := d.DialTimeout(ntw, addr, duration) + var conn net.Conn + if dctx, ok := d.(DialerContext); ok { + ctx, cancel := context.WithTimeout(ctx, duration) + defer cancel() + conn, err = dctx.DialContext(ctx, network, address) + } else { + conn, err = d.DialTimeout(network, address, duration) + } if err != nil { return nil, err } err = conn.SetDeadline(deadline) return conn, err } - return d.Dial(ntw, addr) + if dctx, ok := d.(DialerContext); ok { + return dctx.DialContext(ctx, network, address) + } + return d.Dial(network, address) } func network(o values) (string, string) { @@ -522,9 +510,22 @@ func (cn *conn) isInTransaction() bool { cn.txnStatus == txnStatusInFailedTransaction } +func (cn *conn) setBad() { + if cn.bad != nil { + cn.bad.Store(true) + } +} + +func (cn *conn) getBad() bool { + if cn.bad != nil { + return cn.bad.Load().(bool) + } + return false +} + func (cn *conn) checkIsInTransaction(intxn bool) { if cn.isInTransaction() != intxn { - cn.bad = true + cn.setBad() errorf("unexpected transaction status %v", cn.txnStatus) } } @@ -534,7 +535,7 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -545,11 +546,11 @@ func (cn *conn) begin(mode string) (_ driver.Tx, err error) { return nil, err } if commandTag != "BEGIN" { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { - cn.bad = true + cn.setBad() return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil @@ -563,7 +564,7 @@ func (cn *conn) closeTxn() { func (cn *conn) Commit() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) @@ -576,7 +577,7 @@ func (cn *conn) Commit() (err error) { // would get the same behaviour if you issued a COMMIT in a failed // transaction, so it's also the least surprising thing to do here. if cn.txnStatus == txnStatusInFailedTransaction { - if err := cn.Rollback(); err != nil { + if err := cn.rollback(); err != nil { return err } return ErrInFailedTransaction @@ -585,12 +586,12 @@ func (cn *conn) Commit() (err error) { _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } if commandTag != "COMMIT" { - cn.bad = true + cn.setBad() return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) @@ -599,16 +600,19 @@ func (cn *conn) Commit() (err error) { func (cn *conn) Rollback() (err error) { defer cn.closeTxn() - if cn.bad { + if cn.getBad() { return driver.ErrBadConn } defer cn.errRecover(&err) + return cn.rollback() +} +func (cn *conn) rollback() (err error) { cn.checkIsInTransaction(true) _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { - cn.bad = true + cn.setBad() } return err } @@ -648,7 +652,7 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err case 'T', 'D': // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -670,7 +674,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected message %q in simple query execution", t) } if res == nil { @@ -681,8 +685,11 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. - if t == 'C' && res.colNames == nil { + if t == 'C' { res.result, res.tag = cn.parseComplete(r.string()) + if res.colNames != nil { + return + } } res.done = true case 'Z': @@ -694,7 +701,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { err = parseError(r) case 'D': if res == nil { - cn.bad = true + cn.setBad() errorf("unexpected DataRow in simple query execution") } // the query didn't fail; kick off to Next @@ -704,12 +711,12 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it res = &rows{cn: cn} - res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) + res.rowsHeader = parsePortalRowDescribe(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: - cn.bad = true + cn.setBad() errorf("unknown response for simple query: %q", t) } } @@ -802,7 +809,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt { } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -841,7 +848,7 @@ func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } if cn.inCopy { @@ -861,23 +868,21 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { cn.readParseResponse() cn.readBindResponse() rows := &rows{cn: cn} - rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() + rows.rowsHeader = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil } st := cn.prepareTo(query, "") st.exec(args) return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + cn: cn, + rowsHeader: st.rowsHeader, }, nil } // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { - if cn.bad { + if cn.getBad() { return nil, driver.ErrBadConn } defer cn.errRecover(&err) @@ -911,9 +916,20 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } +type safeRetryError struct { + Err error +} + +func (se *safeRetryError) Error() string { + return se.Err.Error() +} + func (cn *conn) send(m *writeBuf) { - _, err := cn.c.Write(m.wrap()) + n, err := cn.c.Write(m.wrap()) if err != nil { + if n == 0 { + err = &safeRetryError{Err: err} + } panic(err) } } @@ -938,7 +954,7 @@ func (cn *conn) sendSimpleMessage(typ byte) (err error) { // the message yourself. func (cn *conn) saveMessage(typ byte, buf *readBuf) { if cn.saveMessageType != 0 { - cn.bad = true + cn.setBad() errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ @@ -992,12 +1008,17 @@ func (cn *conn) recv() (t byte, r *readBuf) { if err != nil { panic(err) } - switch t { case 'E': panic(parseError(r)) case 'N': - // ignore + if n := cn.noticeHandler; n != nil { + n(parseError(r)) + } + case 'A': + if n := cn.notificationHandler; n != nil { + n(recvNotification(r)) + } default: return } @@ -1014,8 +1035,14 @@ func (cn *conn) recv1Buf(r *readBuf) byte { } switch t { - case 'A', 'N': - // ignore + case 'A': + if n := cn.notificationHandler; n != nil { + n(recvNotification(r)) + } + case 'N': + if n := cn.noticeHandler; n != nil { + n(parseError(r)) + } case 'S': cn.processParameterStatus(r) default: @@ -1083,7 +1110,10 @@ func isDriverSetting(key string) bool { return true case "binary_parameters": return true - + case "krbsrvname": + return true + case "krbspn": + return true default: return false } @@ -1163,6 +1193,108 @@ func (cn *conn) auth(r *readBuf, o values) { if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } + case 7: // GSSAPI, startup + if newGss == nil { + errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") + } + cli, err := newGss() + if err != nil { + errorf("kerberos error: %s", err.Error()) + } + + var token []byte + + if spn, ok := o["krbspn"]; ok { + // Use the supplied SPN if provided.. + token, err = cli.GetInitTokenFromSpn(spn) + } else { + // Allow the kerberos service name to be overridden + service := "postgres" + if val, ok := o["krbsrvname"]; ok { + service = val + } + + token, err = cli.GetInitToken(o["host"], service) + } + + if err != nil { + errorf("failed to get Kerberos ticket: %q", err) + } + + w := cn.writeBuf('p') + w.bytes(token) + cn.send(w) + + // Store for GSSAPI continue message + cn.gss = cli + + case 8: // GSSAPI continue + + if cn.gss == nil { + errorf("GSSAPI protocol error") + } + + b := []byte(*r) + + done, tokOut, err := cn.gss.Continue(b) + if err == nil && !done { + w := cn.writeBuf('p') + w.bytes(tokOut) + cn.send(w) + } + + // Errors fall through and read the more detailed message + // from the server.. + + case 10: + sc := scram.NewClient(sha256.New, o["user"], o["password"]) + sc.Step(nil) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + scOut := sc.Out() + + w := cn.writeBuf('p') + w.string("SCRAM-SHA-256") + w.int32(len(scOut)) + w.bytes(scOut) + cn.send(w) + + t, r := cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 11 { + errorf("unexpected authentication response: %q", t) + } + + nextStep := r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + + scOut = sc.Out() + w = cn.writeBuf('p') + w.bytes(scOut) + cn.send(w) + + t, r = cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 12 { + errorf("unexpected authentication response: %q", t) + } + + nextStep = r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + default: errorf("unknown authentication response: %d", code) } @@ -1180,12 +1312,10 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1} var colFmtDataAllText = []byte{0, 0} type stmt struct { - cn *conn - name string - colNames []string - colFmts []format + cn *conn + name string + rowsHeader colFmtData []byte - colTyps []fieldDesc paramTyps []oid.Oid closed bool } @@ -1194,7 +1324,7 @@ func (st *stmt) Close() (err error) { if st.closed { return nil } - if st.cn.bad { + if st.cn.getBad() { return driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1208,14 +1338,14 @@ func (st *stmt) Close() (err error) { t, _ := st.cn.recv1() if t != '3' { - st.cn.bad = true + st.cn.setBad() errorf("unexpected close response: %q", t) } st.closed = true t, r := st.cn.recv1() if t != 'Z' { - st.cn.bad = true + st.cn.setBad() errorf("expected ready for query, but got: %q", t) } st.cn.processReadyForQuery(r) @@ -1224,22 +1354,20 @@ func (st *stmt) Close() (err error) { } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) st.exec(v) return &rows{ - cn: st.cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + cn: st.cn, + rowsHeader: st.rowsHeader, }, nil } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { - if st.cn.bad { + if st.cn.getBad() { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) @@ -1326,7 +1454,7 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { - cn.bad = true + cn.setBad() errorf("unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] @@ -1338,22 +1466,28 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { - cn.bad = true + cn.setBad() errorf("could not parse commandTag: %s", err) } return driver.RowsAffected(n), commandTag } -type rows struct { - cn *conn - finish func() +type rowsHeader struct { colNames []string colTyps []fieldDesc colFmts []format - done bool - rb readBuf - result driver.Result - tag string +} + +type rows struct { + cn *conn + finish func() + rowsHeader + done bool + rb readBuf + result driver.Result + tag string + + next *rowsHeader } func (rs *rows) Close() error { @@ -1399,7 +1533,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } conn := rs.cn - if conn.bad { + if conn.getBad() { return driver.ErrBadConn } defer conn.errRecover(&err) @@ -1424,7 +1558,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'D': n := rs.rb.int16() if err != nil { - conn.bad = true + conn.setBad() errorf("unexpected DataRow after error %s", err) } if n < len(dest) { @@ -1440,7 +1574,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } return case 'T': - rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + next := parsePortalRowDescribe(&rs.rb) + rs.next = &next return io.EOF default: errorf("unexpected message after execute: %q", t) @@ -1449,10 +1584,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } func (rs *rows) HasNextResultSet() bool { - return !rs.done + hasNext := rs.next != nil && !rs.done + return hasNext } func (rs *rows) NextResultSet() error { + if rs.next == nil { + return io.EOF + } + rs.rowsHeader = *rs.next + rs.next = nil return nil } @@ -1475,6 +1616,39 @@ func QuoteIdentifier(name string) string { return `"` + strings.Replace(name, `"`, `""`, -1) + `"` } +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} + func md5s(s string) string { h := md5.New() h.Write([]byte(s)) @@ -1579,7 +1753,7 @@ func (cn *conn) readReadyForQuery() { cn.processReadyForQuery(r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message %q; expected ReadyForQuery", t) } } @@ -1599,7 +1773,7 @@ func (cn *conn) readParseResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Parse response %q", t) } } @@ -1624,25 +1798,25 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe statement response %q", t) } } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { +func (cn *conn) readPortalDescribeResponse() rowsHeader { t, r := cn.recv1() switch t { case 'T': return parsePortalRowDescribe(r) case 'n': - return nil, nil, nil + return rowsHeader{} case 'E': err := parseError(r) cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Describe response %q", t) } panic("not reached") @@ -1658,7 +1832,7 @@ func (cn *conn) readBindResponse() { cn.readReadyForQuery() panic(err) default: - cn.bad = true + cn.setBad() errorf("unexpected Bind response %q", t) } } @@ -1685,7 +1859,7 @@ func (cn *conn) postExecuteWorkaround() { cn.saveMessage(t, r) return default: - cn.bad = true + cn.setBad() errorf("unexpected message during extended query execution: %q", t) } } @@ -1698,7 +1872,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co switch t { case 'C': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected CommandComplete after error %s", err) } res, commandTag = cn.parseComplete(r.string()) @@ -1712,7 +1886,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co err = parseError(r) case 'T', 'D', 'I': if err != nil { - cn.bad = true + cn.setBad() errorf("unexpected %q after error %s", t, err) } if t == 'I' { @@ -1720,7 +1894,7 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } // ignore any results default: - cn.bad = true + cn.setBad() errorf("unknown %s response: %q", protocolState, t) } } @@ -1742,11 +1916,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { +func parsePortalRowDescribe(r *readBuf) rowsHeader { n := r.int16() - colNames = make([]string, n) - colFmts = make([]format, n) - colTyps = make([]fieldDesc, n) + colNames := make([]string, n) + colFmts := make([]format, n) + colTyps := make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) @@ -1755,7 +1929,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } - return + return rowsHeader{ + colNames: colNames, + colFmts: colFmts, + colTyps: colTyps, + } } // parseEnviron tries to mimic some of libpq's environment handling diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go index a5254f2b4..8cab67c9d 100644 --- a/vendor/github.com/lib/pq/conn_go18.go +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -1,5 +1,3 @@ -// +build go1.8 - package pq import ( @@ -9,6 +7,8 @@ import ( "fmt" "io" "io/ioutil" + "sync/atomic" + "time" ) // Implement the "QueryerContext" interface @@ -76,20 +76,51 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, return tx, nil } +func (cn *conn) Ping(ctx context.Context) error { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + rows, err := cn.simpleQuery(";") + if err != nil { + return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger + } + rows.Close() + return nil +} + func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { - finished := make(chan struct{}) + finished := make(chan struct{}, 1) go func() { select { case <-done: - _ = cn.cancel() - finished <- struct{}{} + select { + case finished <- struct{}{}: + default: + // We raced with the finish func, let the next query handle this with the + // context. + return + } + + // Set the connection state to bad so it does not get reused. + cn.setBad() + + // At this point the function level context is canceled, + // so it must not be used for the additional network + // request to cancel the query. + // Create a new context to pass into the dial. + ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + _ = cn.cancel(ctxCancel) case <-finished: } }() return func() { select { case <-finished: + cn.setBad() + cn.Close() case finished <- struct{}{}: } } @@ -97,16 +128,19 @@ func (cn *conn) watchCancel(ctx context.Context) func() { return nil } -func (cn *conn) cancel() error { - c, err := dial(cn.dialer, cn.opts) +func (cn *conn) cancel(ctx context.Context) error { + c, err := dial(ctx, cn.dialer, cn.opts) if err != nil { return err } defer c.Close() { + bad := &atomic.Value{} + bad.Store(false) can := conn{ - c: c, + c: c, + bad: bad, } err = can.ssl(cn.opts) if err != nil { diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go index 9e66eb5df..d7d472615 100644 --- a/vendor/github.com/lib/pq/connector.go +++ b/vendor/github.com/lib/pq/connector.go @@ -1,10 +1,12 @@ -// +build go1.10 - package pq import ( "context" "database/sql/driver" + "errors" + "fmt" + "os" + "strings" ) // Connector represents a fixed configuration for the pq driver with a given @@ -14,30 +16,100 @@ import ( // // See https://golang.org/pkg/database/sql/driver/#Connector. // See https://golang.org/pkg/database/sql/#OpenDB. -type connector struct { - name string +type Connector struct { + opts values + dialer Dialer } // Connect returns a connection to the database using the fixed configuration // of this Connector. Context is not used. -func (c *connector) Connect(_ context.Context) (driver.Conn, error) { - return (&Driver{}).Open(c.name) +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { + return c.open(ctx) } -// Driver returnst the underlying driver of this Connector. -func (c *connector) Driver() driver.Driver { +// Driver returns the underlying driver of this Connector. +func (c *Connector) Driver() driver.Driver { return &Driver{} } -var _ driver.Connector = &connector{} - // NewConnector returns a connector for the pq driver in a fixed configuration -// with the given name. The returned connector can be used to create any number +// with the given dsn. The returned connector can be used to create any number // of equivalent Conn's. The returned connector is intended to be used with // database/sql.OpenDB. // // See https://golang.org/pkg/database/sql/driver/#Connector. // See https://golang.org/pkg/database/sql/#OpenDB. -func NewConnector(name string) (driver.Connector, error) { - return &connector{name: name}, nil +func NewConnector(dsn string) (*Connector, error) { + var err error + o := make(values) + + // A number of defaults are applied here, in this order: + // + // * Very low precedence defaults applied in every situation + // * Environment variables + // * Explicitly passed connection information + o["host"] = "localhost" + o["port"] = "5432" + // N.B.: Extra float digits should be set to 3, but that breaks + // Postgres 8.4 and older, where the max is 2. + o["extra_float_digits"] = "2" + for k, v := range parseEnviron(os.Environ()) { + o[k] = v + } + + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + dsn, err = ParseURL(dsn) + if err != nil { + return nil, err + } + } + + if err := parseOpts(dsn, o); err != nil { + return nil, err + } + + // Use the "fallback" application name if necessary + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback + } + } + + // We can't work with any client_encoding other than UTF-8 currently. + // However, we have historically allowed the user to set it to UTF-8 + // explicitly, and there's no reason to break such programs, so allow that. + // Note that the "options" setting could also set client_encoding, but + // parsing its value is not worth it. Instead, we always explicitly send + // client_encoding as a separate run-time parameter, which should override + // anything set in options. + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { + return nil, errors.New("client_encoding must be absent or 'UTF8'") + } + o["client_encoding"] = "UTF8" + // DateStyle needs a similar treatment. + if datestyle, ok := o["datestyle"]; ok { + if datestyle != "ISO, MDY" { + return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle) + } + } else { + o["datestyle"] = "ISO, MDY" + } + + // If a user is not provided by any other means, the last + // resort is to use the current operating system provided user + // name. + if _, ok := o["user"]; !ok { + u, err := userCurrent() + if err != nil { + return nil, err + } + o["user"] = u + } + + // SSL is not necessary or supported over UNIX domain sockets + if network, _ := network(o); network == "unix" { + o["sslmode"] = "disable" + } + + return &Connector{opts: o, dialer: defaultDialer{}}, nil } diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index 345c2398f..bb3cbd7b9 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -49,6 +49,7 @@ type copyin struct { buffer []byte rowData chan []byte done chan bool + driver.Result closed bool @@ -151,8 +152,12 @@ func (ci *copyin) resploop() { switch t { case 'C': // complete + res, _ := ci.cn.parseComplete(r.string()) + ci.setResult(res) case 'N': - // NoticeResponse + if n := ci.cn.noticeHandler; n != nil { + n(parseError(&r)) + } case 'Z': ci.cn.processReadyForQuery(&r) ci.done <- true @@ -171,13 +176,13 @@ func (ci *copyin) resploop() { func (ci *copyin) setBad() { ci.Lock() - ci.cn.bad = true + ci.cn.setBad() ci.Unlock() } func (ci *copyin) isBad() bool { ci.Lock() - b := ci.cn.bad + b := ci.cn.getBad() ci.Unlock() return b } @@ -199,6 +204,22 @@ func (ci *copyin) setError(err error) { ci.Unlock() } +func (ci *copyin) setResult(result driver.Result) { + ci.Lock() + ci.Result = result + ci.Unlock() +} + +func (ci *copyin) getResult() driver.Result { + ci.Lock() + result := ci.Result + ci.Unlock() + if result == nil { + return driver.RowsAffected(0) + } + return result +} + func (ci *copyin) NumInput() int { return -1 } @@ -229,7 +250,11 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { } if len(v) == 0 { - return nil, ci.Close() + if err := ci.Close(); err != nil { + return driver.RowsAffected(0), err + } + + return ci.getResult(), nil } numValues := len(v) diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index a1b029713..b57184801 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -239,7 +239,30 @@ for more information). Note that the channel name will be truncated to 63 bytes by the PostgreSQL server. You can find a complete, working example of Listener usage at -http://godoc.org/github.com/lib/pq/example/listen. +https://godoc.org/github.com/lib/pq/example/listen. + +Kerberos Support + + +If you need support for Kerberos authentication, add the following to your main +package: + + import "github.com/lib/pq/auth/kerberos" + + func init() { + pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) + } + +This package is in a separate module so that users who don't need Kerberos +don't have to download unnecessary dependencies. + +When imported, additional connection string parameters are supported: + + * krbsrvname - GSS (Kerberos) service name when constructing the + SPN (default is `postgres`). This will be combined with the host + to form the full SPN: `krbsrvname/host`. + * krbspn - GSS (Kerberos) SPN. This takes priority over + `krbsrvname` if present. */ package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index 3b0d365f2..c4dafe270 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "math" + "regexp" "strconv" "strings" "sync" @@ -16,6 +17,8 @@ import ( "github.com/lib/pq/oid" ) +var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`) + func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { switch v := x.(type) { case []byte: @@ -117,11 +120,10 @@ func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interfa } return i case oid.T_float4, oid.T_float8: - bits := 64 - if typ == oid.T_float4 { - bits = 32 - } - f, err := strconv.ParseFloat(string(s), bits) + // We always use 64 bit parsing, regardless of whether the input text is for + // a float4 or float8, because clients expect float64s for all float datatypes + // and returning a 32-bit parsed float64 produces lossy results. + f, err := strconv.ParseFloat(string(s), 64) if err != nil { errorf("%s", err) } @@ -203,10 +205,27 @@ func mustParse(f string, typ oid.Oid, s []byte) time.Time { str[len(str)-3] == ':' { f += ":00" } + // Special case for 24:00 time. + // Unfortunately, golang does not parse 24:00 as a proper time. + // In this case, we want to try "round to the next day", to differentiate. + // As such, we find if the 24:00 time matches at the beginning; if so, + // we default it back to 00:00 but add a day later. + var is2400Time bool + switch typ { + case oid.T_timetz, oid.T_time: + if matches := time2400Regex.FindStringSubmatch(str); matches != nil { + // Concatenate timezone information at the back. + str = "00:00:00" + str[len(matches[1]):] + is2400Time = true + } + } t, err := time.Parse(f, str) if err != nil { errorf("decode: %s", err) } + if is2400Time { + t = t.Add(24 * time.Hour) + } return t } @@ -488,7 +507,7 @@ func FormatTimestamp(t time.Time) []byte { b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) _, offset := t.Zone() - offset = offset % 60 + offset %= 60 if offset != 0 { // RFC3339Nano already printed the minus sign if offset < 0 { diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index 96aae29c6..c19c349f1 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -478,13 +478,13 @@ func errRecoverNoErrBadConn(err *error) { } } -func (c *conn) errRecover(err *error) { +func (cn *conn) errRecover(err *error) { e := recover() switch v := e.(type) { case nil: // Do nothing case runtime.Error: - c.bad = true + cn.setBad() panic(v) case *Error: if v.Fatal() { @@ -493,8 +493,11 @@ func (c *conn) errRecover(err *error) { *err = v } case *net.OpError: - c.bad = true + cn.setBad() *err = v + case *safeRetryError: + cn.setBad() + *err = driver.ErrBadConn case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn @@ -503,13 +506,13 @@ func (c *conn) errRecover(err *error) { } default: - c.bad = true + cn.setBad() panic(fmt.Sprintf("unknown error: %#v", e)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. if *err == driver.ErrBadConn { - c.bad = true + cn.setBad() } } diff --git a/vendor/github.com/lib/pq/go.mod b/vendor/github.com/lib/pq/go.mod new file mode 100644 index 000000000..b5a5639ab --- /dev/null +++ b/vendor/github.com/lib/pq/go.mod @@ -0,0 +1,3 @@ +module github.com/lib/pq + +go 1.13 diff --git a/vendor/github.com/lib/pq/krb.go b/vendor/github.com/lib/pq/krb.go new file mode 100644 index 000000000..408ec01f9 --- /dev/null +++ b/vendor/github.com/lib/pq/krb.go @@ -0,0 +1,27 @@ +package pq + +// NewGSSFunc creates a GSS authentication provider, for use with +// RegisterGSSProvider. +type NewGSSFunc func() (GSS, error) + +var newGss NewGSSFunc + +// RegisterGSSProvider registers a GSS authentication provider. For example, if +// you need to use Kerberos to authenticate with your server, add this to your +// main package: +// +// import "github.com/lib/pq/auth/kerberos" +// +// func init() { +// pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() }) +// } +func RegisterGSSProvider(newGssArg NewGSSFunc) { + newGss = newGssArg +} + +// GSS provides GSSAPI authentication (e.g., Kerberos). +type GSS interface { + GetInitToken(host string, service string) ([]byte, error) + GetInitTokenFromSpn(spn string) ([]byte, error) + Continue(inToken []byte) (done bool, outToken []byte, err error) +} diff --git a/vendor/github.com/lib/pq/notice.go b/vendor/github.com/lib/pq/notice.go new file mode 100644 index 000000000..01dd8c723 --- /dev/null +++ b/vendor/github.com/lib/pq/notice.go @@ -0,0 +1,71 @@ +// +build go1.10 + +package pq + +import ( + "context" + "database/sql/driver" +) + +// NoticeHandler returns the notice handler on the given connection, if any. A +// runtime panic occurs if c is not a pq connection. This is rarely used +// directly, use ConnectorNoticeHandler and ConnectorWithNoticeHandler instead. +func NoticeHandler(c driver.Conn) func(*Error) { + return c.(*conn).noticeHandler +} + +// SetNoticeHandler sets the given notice handler on the given connection. A +// runtime panic occurs if c is not a pq connection. A nil handler may be used +// to unset it. This is rarely used directly, use ConnectorNoticeHandler and +// ConnectorWithNoticeHandler instead. +// +// Note: Notice handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func SetNoticeHandler(c driver.Conn, handler func(*Error)) { + c.(*conn).noticeHandler = handler +} + +// NoticeHandlerConnector wraps a regular connector and sets a notice handler +// on it. +type NoticeHandlerConnector struct { + driver.Connector + noticeHandler func(*Error) +} + +// Connect calls the underlying connector's connect method and then sets the +// notice handler. +func (n *NoticeHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { + c, err := n.Connector.Connect(ctx) + if err == nil { + SetNoticeHandler(c, n.noticeHandler) + } + return c, err +} + +// ConnectorNoticeHandler returns the currently set notice handler, if any. If +// the given connector is not a result of ConnectorWithNoticeHandler, nil is +// returned. +func ConnectorNoticeHandler(c driver.Connector) func(*Error) { + if c, ok := c.(*NoticeHandlerConnector); ok { + return c.noticeHandler + } + return nil +} + +// ConnectorWithNoticeHandler creates or sets the given handler for the given +// connector. If the given connector is a result of calling this function +// previously, it is simply set on the given connector and returned. Otherwise, +// this returns a new connector wrapping the given one and setting the notice +// handler. A nil notice handler may be used to unset it. +// +// The returned connector is intended to be used with database/sql.OpenDB. +// +// Note: Notice handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func ConnectorWithNoticeHandler(c driver.Connector, handler func(*Error)) *NoticeHandlerConnector { + if c, ok := c.(*NoticeHandlerConnector); ok { + c.noticeHandler = handler + return c + } + return &NoticeHandlerConnector{Connector: c, noticeHandler: handler} +} diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go index 947d189f4..5c421fdb8 100644 --- a/vendor/github.com/lib/pq/notify.go +++ b/vendor/github.com/lib/pq/notify.go @@ -4,6 +4,8 @@ package pq // This module contains support for Postgres LISTEN/NOTIFY. import ( + "context" + "database/sql/driver" "errors" "fmt" "sync" @@ -29,6 +31,61 @@ func recvNotification(r *readBuf) *Notification { return &Notification{bePid, channel, extra} } +// SetNotificationHandler sets the given notification handler on the given +// connection. A runtime panic occurs if c is not a pq connection. A nil handler +// may be used to unset it. +// +// Note: Notification handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func SetNotificationHandler(c driver.Conn, handler func(*Notification)) { + c.(*conn).notificationHandler = handler +} + +// NotificationHandlerConnector wraps a regular connector and sets a notification handler +// on it. +type NotificationHandlerConnector struct { + driver.Connector + notificationHandler func(*Notification) +} + +// Connect calls the underlying connector's connect method and then sets the +// notification handler. +func (n *NotificationHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { + c, err := n.Connector.Connect(ctx) + if err == nil { + SetNotificationHandler(c, n.notificationHandler) + } + return c, err +} + +// ConnectorNotificationHandler returns the currently set notification handler, if any. If +// the given connector is not a result of ConnectorWithNotificationHandler, nil is +// returned. +func ConnectorNotificationHandler(c driver.Connector) func(*Notification) { + if c, ok := c.(*NotificationHandlerConnector); ok { + return c.notificationHandler + } + return nil +} + +// ConnectorWithNotificationHandler creates or sets the given handler for the given +// connector. If the given connector is a result of calling this function +// previously, it is simply set on the given connector and returned. Otherwise, +// this returns a new connector wrapping the given one and setting the notification +// handler. A nil notification handler may be used to unset it. +// +// The returned connector is intended to be used with database/sql.OpenDB. +// +// Note: Notification handlers are executed synchronously by pq meaning commands +// won't continue to be processed until the handler returns. +func ConnectorWithNotificationHandler(c driver.Connector, handler func(*Notification)) *NotificationHandlerConnector { + if c, ok := c.(*NotificationHandlerConnector); ok { + c.notificationHandler = handler + return c + } + return &NotificationHandlerConnector{Connector: c, notificationHandler: handler} +} + const ( connStateIdle int32 = iota connStateExpectResponse @@ -174,8 +231,12 @@ func (l *ListenerConn) listenerConnLoop() (err error) { } l.replyChan <- message{t, nil} - case 'N', 'S': + case 'S': // ignore + case 'N': + if n := l.cn.noticeHandler; n != nil { + n(parseError(r)) + } default: return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t) } @@ -725,6 +786,9 @@ func (l *Listener) Close() error { } l.isClosed = true + // Unblock calls to Listen() + l.reconnectCond.Broadcast() + return nil } diff --git a/vendor/github.com/lib/pq/scram/scram.go b/vendor/github.com/lib/pq/scram/scram.go new file mode 100644 index 000000000..477216b60 --- /dev/null +++ b/vendor/github.com/lib/pq/scram/scram.go @@ -0,0 +1,264 @@ +// Copyright (c) 2014 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. +// +// http://tools.ietf.org/html/rfc5802 +// +package scram + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "encoding/base64" + "fmt" + "hash" + "strconv" + "strings" +) + +// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). +// +// A Client may be used within a SASL conversation with logic resembling: +// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } +// +type Client struct { + newHash func() hash.Hash + + user string + pass string + step int + out bytes.Buffer + err error + + clientNonce []byte + serverNonce []byte + saltedPass []byte + authMsg bytes.Buffer +} + +// NewClient returns a new SCRAM-* client with the provided hash algorithm. +// +// For SCRAM-SHA-256, for example, use: +// +// client := scram.NewClient(sha256.New, user, pass) +// +func NewClient(newHash func() hash.Hash, user, pass string) *Client { + c := &Client{ + newHash: newHash, + user: user, + pass: pass, + } + c.out.Grow(256) + c.authMsg.Grow(256) + return c +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) Out() []byte { + if c.out.Len() == 0 { + return nil + } + return c.out.Bytes() +} + +// Err returns the error that occurred, or nil if there were no errors. +func (c *Client) Err() error { + return c.err +} + +// SetNonce sets the client nonce to the provided value. +// If not set, the nonce is generated automatically out of crypto/rand on the first step. +func (c *Client) SetNonce(nonce []byte) { + c.clientNonce = nonce +} + +var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") + +// Step processes the incoming data from the server and makes the +// next round of data for the server available via Client.Out. +// Step returns false if there are no errors and more data is +// still expected. +func (c *Client) Step(in []byte) bool { + c.out.Reset() + if c.step > 2 || c.err != nil { + return false + } + c.step++ + switch c.step { + case 1: + c.err = c.step1(in) + case 2: + c.err = c.step2(in) + case 3: + c.err = c.step3(in) + } + return c.step > 2 || c.err != nil +} + +func (c *Client) step1(in []byte) error { + if len(c.clientNonce) == 0 { + const nonceLen = 16 + buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) + if _, err := rand.Read(buf[:nonceLen]); err != nil { + return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err) + } + c.clientNonce = buf[nonceLen:] + b64.Encode(c.clientNonce, buf[:nonceLen]) + } + c.authMsg.WriteString("n=") + escaper.WriteString(&c.authMsg, c.user) + c.authMsg.WriteString(",r=") + c.authMsg.Write(c.clientNonce) + + c.out.WriteString("n,,") + c.out.Write(c.authMsg.Bytes()) + return nil +} + +var b64 = base64.StdEncoding + +func (c *Client) step2(in []byte) error { + c.authMsg.WriteByte(',') + c.authMsg.Write(in) + + fields := bytes.Split(in, []byte(",")) + if len(fields) != 3 { + return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in) + } + if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0]) + } + if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1]) + } + if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + + c.serverNonce = fields[0][2:] + if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { + return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) + } + + salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) + n, err := b64.Decode(salt, fields[1][2:]) + if err != nil { + return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1]) + } + salt = salt[:n] + iterCount, err := strconv.Atoi(string(fields[2][2:])) + if err != nil { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + c.saltPassword(salt, iterCount) + + c.authMsg.WriteString(",c=biws,r=") + c.authMsg.Write(c.serverNonce) + + c.out.WriteString("c=biws,r=") + c.out.Write(c.serverNonce) + c.out.WriteString(",p=") + c.out.Write(c.clientProof()) + return nil +} + +func (c *Client) step3(in []byte) error { + var isv, ise bool + var fields = bytes.Split(in, []byte(",")) + if len(fields) == 1 { + isv = bytes.HasPrefix(fields[0], []byte("v=")) + ise = bytes.HasPrefix(fields[0], []byte("e=")) + } + if ise { + return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:]) + } else if !isv { + return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in) + } + if !bytes.Equal(c.serverSignature(), fields[0][2:]) { + return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:]) + } + return nil +} + +func (c *Client) saltPassword(salt []byte, iterCount int) { + mac := hmac.New(c.newHash, []byte(c.pass)) + mac.Write(salt) + mac.Write([]byte{0, 0, 0, 1}) + ui := mac.Sum(nil) + hi := make([]byte, len(ui)) + copy(hi, ui) + for i := 1; i < iterCount; i++ { + mac.Reset() + mac.Write(ui) + mac.Sum(ui[:0]) + for j, b := range ui { + hi[j] ^= b + } + } + c.saltedPass = hi +} + +func (c *Client) clientProof() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Client Key")) + clientKey := mac.Sum(nil) + hash := c.newHash() + hash.Write(clientKey) + storedKey := hash.Sum(nil) + mac = hmac.New(c.newHash, storedKey) + mac.Write(c.authMsg.Bytes()) + clientProof := mac.Sum(nil) + for i, b := range clientKey { + clientProof[i] ^= b + } + clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) + b64.Encode(clientProof64, clientProof) + return clientProof64 +} + +func (c *Client) serverSignature() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Server Key")) + serverKey := mac.Sum(nil) + + mac = hmac.New(c.newHash, serverKey) + mac.Write(c.authMsg.Bytes()) + serverSignature := mac.Sum(nil) + + encoded := make([]byte, b64.EncodedLen(len(serverSignature))) + b64.Encode(encoded, serverSignature) + return encoded +} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go index e1a326a0d..d90208455 100644 --- a/vendor/github.com/lib/pq/ssl.go +++ b/vendor/github.com/lib/pq/ssl.go @@ -58,7 +58,13 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { if err != nil { return nil, err } - sslRenegotiation(&tlsConf) + + // Accept renegotiation requests initiated by the backend. + // + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but + // the default configuration of older versions has it enabled. Redshift + // also initiates renegotiations and cannot be reconfigured. + tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient return func(conn net.Conn) (net.Conn, error) { client := tls.Client(conn, &tlsConf) diff --git a/vendor/github.com/lib/pq/ssl_go1.7.go b/vendor/github.com/lib/pq/ssl_go1.7.go deleted file mode 100644 index d7ba43b32..000000000 --- a/vendor/github.com/lib/pq/ssl_go1.7.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build go1.7 - -package pq - -import "crypto/tls" - -// Accept renegotiation requests initiated by the backend. -// -// Renegotiation was deprecated then removed from PostgreSQL 9.5, but -// the default configuration of older versions has it enabled. Redshift -// also initiates renegotiations and cannot be reconfigured. -func sslRenegotiation(conf *tls.Config) { - conf.Renegotiation = tls.RenegotiateFreelyAsClient -} diff --git a/vendor/github.com/lib/pq/ssl_renegotiation.go b/vendor/github.com/lib/pq/ssl_renegotiation.go deleted file mode 100644 index 85ed5e437..000000000 --- a/vendor/github.com/lib/pq/ssl_renegotiation.go +++ /dev/null @@ -1,8 +0,0 @@ -// +build !go1.7 - -package pq - -import "crypto/tls" - -// Renegotiation is not supported by crypto/tls until Go 1.7. -func sslRenegotiation(*tls.Config) {} diff --git a/vendor/github.com/lib/pq/user_posix.go b/vendor/github.com/lib/pq/user_posix.go index bf982524f..a51019205 100644 --- a/vendor/github.com/lib/pq/user_posix.go +++ b/vendor/github.com/lib/pq/user_posix.go @@ -1,6 +1,6 @@ // Package pq is a pure Go Postgres driver for the database/sql package. -// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris rumprun +// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd plan9 solaris rumprun package pq From 60f4f15fa42d83ca1d67cb939a4b8f344b6b99d1 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 08:24:30 -0700 Subject: [PATCH 05/13] Ensure db is started successfully before starting application --- Makefile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 8bdc03d88..033e56008 100644 --- a/Makefile +++ b/Makefile @@ -137,7 +137,10 @@ docker-build: docker-bootstrap: $(MAKE) docker-build docker-compose up --exit-code-from openapi openapi - docker-compose up -d + # Let the databases start up so we can successfully start up the application + docker-compose up -d queue db + sleep 5 + docker-compose up -d api worker sleep 40 $(MAKE) load-fixtures From 635b24357e08919177ea23b2030f736c080a9f2d Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 09:05:50 -0700 Subject: [PATCH 06/13] Sleep for longer to ensure db is started up --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 033e56008..1f2794678 100644 --- a/Makefile +++ b/Makefile @@ -139,9 +139,9 @@ docker-bootstrap: docker-compose up --exit-code-from openapi openapi # Let the databases start up so we can successfully start up the application docker-compose up -d queue db - sleep 5 + sleep 20 docker-compose up -d api worker - sleep 40 + sleep 20 $(MAKE) load-fixtures api-shell: From e503d690c91bbe0f5f15e2ed7770292470b23399 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 09:32:30 -0700 Subject: [PATCH 07/13] Add debugging to see what is running --- Makefile | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 1f2794678..04a08e080 100644 --- a/Makefile +++ b/Makefile @@ -103,6 +103,8 @@ load-fixtures: $(MAKE) load-synthetic-suppression-data $(MAKE) load-fixtures-ssas + docker-compose log api worker + load-synthetic-cclf-data: docker-compose up -d api docker-compose up -d db @@ -136,12 +138,13 @@ docker-build: docker-bootstrap: $(MAKE) docker-build - docker-compose up --exit-code-from openapi openapi + # docker-compose up --exit-code-from openapi openapi # Let the databases start up so we can successfully start up the application docker-compose up -d queue db sleep 20 - docker-compose up -d api worker + docker-compose restart api worker sleep 20 + docker-compose log api worker $(MAKE) load-fixtures api-shell: From 11ee78795f24ef13b6bfb4aec23a30fa46c1c348 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 09:40:33 -0700 Subject: [PATCH 08/13] Revert "Add debugging to see what is running" This reverts commit e503d690c91bbe0f5f15e2ed7770292470b23399. --- Makefile | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 04a08e080..1f2794678 100644 --- a/Makefile +++ b/Makefile @@ -103,8 +103,6 @@ load-fixtures: $(MAKE) load-synthetic-suppression-data $(MAKE) load-fixtures-ssas - docker-compose log api worker - load-synthetic-cclf-data: docker-compose up -d api docker-compose up -d db @@ -138,13 +136,12 @@ docker-build: docker-bootstrap: $(MAKE) docker-build - # docker-compose up --exit-code-from openapi openapi + docker-compose up --exit-code-from openapi openapi # Let the databases start up so we can successfully start up the application docker-compose up -d queue db sleep 20 - docker-compose restart api worker + docker-compose up -d api worker sleep 20 - docker-compose log api worker $(MAKE) load-fixtures api-shell: From cf4dc65bb5a81988f0f56b296b1130d6822a69a3 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 09:42:28 -0700 Subject: [PATCH 09/13] Start all containers --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 1f2794678..9bbd09326 100644 --- a/Makefile +++ b/Makefile @@ -140,8 +140,9 @@ docker-bootstrap: # Let the databases start up so we can successfully start up the application docker-compose up -d queue db sleep 20 - docker-compose up -d api worker + docker-compose up -d sleep 20 + docker-compose logs api worker ssas $(MAKE) load-fixtures api-shell: From d7ac0922b5b2d1104aa66f3b18838f53fbec11a9 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 09:59:01 -0700 Subject: [PATCH 10/13] Double sleeptime to ensure app and worker have successfully started --- Makefile | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 9bbd09326..0d922c05f 100644 --- a/Makefile +++ b/Makefile @@ -139,10 +139,9 @@ docker-bootstrap: docker-compose up --exit-code-from openapi openapi # Let the databases start up so we can successfully start up the application docker-compose up -d queue db - sleep 20 + sleep 5 docker-compose up -d - sleep 20 - docker-compose logs api worker ssas + sleep 40 $(MAKE) load-fixtures api-shell: From 5d1334bdc49574f5af9836ba0ea2c15e6785bdf5 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 10:18:47 -0700 Subject: [PATCH 11/13] Lets wait even longer --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 0d922c05f..7e350f0c0 100644 --- a/Makefile +++ b/Makefile @@ -136,12 +136,12 @@ docker-build: docker-bootstrap: $(MAKE) docker-build - docker-compose up --exit-code-from openapi openapi # Let the databases start up so we can successfully start up the application docker-compose up -d queue db sleep 5 docker-compose up -d - sleep 40 + docker-compose up --exit-code-from openapi openapi + sleep 120 $(MAKE) load-fixtures api-shell: From b244e8526f758b497acbcd2c8088ae75d5d93958 Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 10:37:21 -0700 Subject: [PATCH 12/13] Retool --- Makefile | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 7e350f0c0..780307f10 100644 --- a/Makefile +++ b/Makefile @@ -81,6 +81,8 @@ performance-test: docker-compose -f docker-compose.test.yml run --rm -w /go/src/github.com/CMSgov/bcda-app/test/performance_test tests sh performance_test.sh test: + # Ensure instances are running as expected + docker-compose restart api worker ssas $(MAKE) lint $(MAKE) unit-test $(MAKE) postman env=local @@ -136,12 +138,9 @@ docker-build: docker-bootstrap: $(MAKE) docker-build - # Let the databases start up so we can successfully start up the application - docker-compose up -d queue db - sleep 5 - docker-compose up -d docker-compose up --exit-code-from openapi openapi - sleep 120 + docker-compose up -d + sleep 40 $(MAKE) load-fixtures api-shell: From ce28188452d629891744214ea0e89e37100207ba Mon Sep 17 00:00:00 2001 From: Martin Trang Date: Fri, 11 Dec 2020 13:11:18 -0700 Subject: [PATCH 13/13] Cleanup --- Makefile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 780307f10..ef7875862 100644 --- a/Makefile +++ b/Makefile @@ -81,8 +81,6 @@ performance-test: docker-compose -f docker-compose.test.yml run --rm -w /go/src/github.com/CMSgov/bcda-app/test/performance_test tests sh performance_test.sh test: - # Ensure instances are running as expected - docker-compose restart api worker ssas $(MAKE) lint $(MAKE) unit-test $(MAKE) postman env=local @@ -105,6 +103,10 @@ load-fixtures: $(MAKE) load-synthetic-suppression-data $(MAKE) load-fixtures-ssas + # Ensure components are started as expected + docker-compose restart api worker ssas + sleep 5 + load-synthetic-cclf-data: docker-compose up -d api docker-compose up -d db