Skip to content

Commit

Permalink
allow retry on netErrors in safe situations
Browse files Browse the repository at this point in the history
In some situations the sql package does not retry a pq operation when it
should.
One of the situations is #870. When a
postgresql-server is restarted and after the restart is finished an
operation is triggered on the already established connection, it failed
with an broken pipe error in some circumstances.

The sql package does not retry the operation and instead fail because
the pq driver does not return driver.ErrBadConn for network errors.
The driver must not return ErrBadConn when the server might have already
executed the operation. This would cause that sql package is retrying it
and the operation would be run multiple times by the postgresql server.
In some situations it's safe to return ErrBadConn on network errors.
This is the case when it's ensured that the server did not receive the
message that triggers the operation.

This commit introduces a netErrorNoWrite error. This error should be
used when network operations panic when it's safe to retry the
operation.
When errRecover() receives this error it returns ErrBadConn() and marks
the connection as bad.
A mustSendRetryable() function is introduced that wraps a netOpError in
an netErrorNoWrite when panicing.
mustSendRetryable() is called in situations when the send that triggers
the operation failed.
  • Loading branch information
fho committed May 28, 2019
1 parent 2ff3cb3 commit 89cb5db
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
39 changes: 27 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ func (cn *conn) gname() string {
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
cn.mustSendRetryable(b)

for {
t, r := cn.recv1()
Expand Down Expand Up @@ -632,7 +632,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {

b := cn.writeBuf('Q')
b.string(q)
cn.send(b)
cn.mustSendRetryable(b)

for {
t, r := cn.recv1()
Expand Down Expand Up @@ -765,7 +765,7 @@ func (cn *conn) prepareTo(q, stmtName string) *stmt {
b.string(st.name)

b.next('S')
cn.send(b)
cn.mustSendRetryable(b)

cn.readParseResponse()
st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
Expand Down Expand Up @@ -882,13 +882,28 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
return r, err
}

func (cn *conn) send(m *writeBuf) {
func (cn *conn) send(m *writeBuf) error {
_, err := cn.c.Write(m.wrap())
return err
}

func (cn *conn) mustSend(m *writeBuf) {
err := cn.send(m)
if err != nil {
panic(err)
}
}

func (cn *conn) mustSendRetryable(m *writeBuf) {
err := cn.send(m)
if err != nil {
if _, ok := err.(*net.OpError); ok {
err = &netErrorNoWrite{err}
}
panic(err)
}
}

func (cn *conn) sendStartupPacket(m *writeBuf) error {
_, err := cn.c.Write((m.wrap())[1:])
return err
Expand Down Expand Up @@ -1109,7 +1124,7 @@ func (cn *conn) auth(r *readBuf, o values) {
case 3:
w := cn.writeBuf('p')
w.string(o["password"])
cn.send(w)
cn.mustSend(w)

t, r := cn.recv()
if t != 'R' {
Expand All @@ -1123,7 +1138,7 @@ func (cn *conn) auth(r *readBuf, o values) {
s := string(r.next(4))
w := cn.writeBuf('p')
w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
cn.send(w)
cn.mustSend(w)

t, r := cn.recv()
if t != 'R' {
Expand All @@ -1145,7 +1160,7 @@ func (cn *conn) auth(r *readBuf, o values) {
w.string("SCRAM-SHA-256")
w.int32(len(scOut))
w.bytes(scOut)
cn.send(w)
cn.mustSend(w)

t, r := cn.recv()
if t != 'R' {
Expand All @@ -1165,7 +1180,7 @@ func (cn *conn) auth(r *readBuf, o values) {
scOut = sc.Out()
w = cn.writeBuf('p')
w.bytes(scOut)
cn.send(w)
cn.mustSend(w)

t, r = cn.recv()
if t != 'R' {
Expand Down Expand Up @@ -1219,9 +1234,9 @@ func (st *stmt) Close() (err error) {
w := st.cn.writeBuf('C')
w.byte('S')
w.string(st.name)
st.cn.send(w)
st.cn.mustSend(w)

st.cn.send(st.cn.writeBuf('S'))
st.cn.mustSend(st.cn.writeBuf('S'))

t, _ := st.cn.recv1()
if t != '3' {
Expand Down Expand Up @@ -1299,7 +1314,7 @@ func (st *stmt) exec(v []driver.Value) {
w.int32(0)

w.next('S')
cn.send(w)
cn.mustSend(w)

cn.readBindResponse()
cn.postExecuteWorkaround()
Expand Down Expand Up @@ -1601,7 +1616,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
b.int32(0)

b.next('S')
cn.send(b)
cn.mustSendRetryable(b)
}

func (cn *conn) processParameterStatus(r *readBuf) {
Expand Down
15 changes: 15 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,18 @@ func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}

// NetErrorNoWrite is a network error that occured before a message that
// indicates the operation to execute was transfered to the server.
// These operations are safe to retry. This error should be replaced with
// driver.ErrBadConn before it's passed to the caller.
type netErrorNoWrite struct {
Err error
}

func (e *netErrorNoWrite) Error() string {
return "netErrorNoWrite: " + e.Err.Error()
}

// TODO(ainar-g) Rename to errorf after removing panics.
func fmterrorf(s string, args ...interface{}) error {
return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))
Expand Down Expand Up @@ -492,6 +504,9 @@ func (c *conn) errRecover(err *error) {
} else {
*err = v
}
case *netErrorNoWrite:
c.bad = true
*err = driver.ErrBadConn
case *net.OpError:
c.bad = true
*err = v
Expand Down

0 comments on commit 89cb5db

Please sign in to comment.