diff --git a/connection.go b/connection.go index a4ae8cc36..aa27953b1 100644 --- a/connection.go +++ b/connection.go @@ -581,43 +581,63 @@ func pack(h *smallWBuf, enc *encoder, reqid uint32, return } -func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) (err error) { +func (conn *Connection) writeRequest(w *bufio.Writer, req Request) error { var packet smallWBuf - req := newAuthRequest(conn.opts.User, string(scramble)) - err = pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema) + err := pack(&packet, newEncoder(&packet), 0, req, ignoreStreamId, conn.Schema) if err != nil { - return errors.New("auth: pack error " + err.Error()) + return fmt.Errorf("pack error: %w", err) } - if err := write(w, packet.b); err != nil { - return errors.New("auth: write error " + err.Error()) + if err = write(w, packet.b); err != nil { + return fmt.Errorf("write error: %w", err) } if err = w.Flush(); err != nil { - return errors.New("auth: flush error " + err.Error()) + return fmt.Errorf("flush error: %w", err) } - return + return err +} + +func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) error { + req := newAuthRequest(conn.opts.User, string(scramble)) + + err := conn.writeRequest(w, req) + if err != nil { + return fmt.Errorf("auth: %w", err) + } + + return nil } -func (conn *Connection) readAuthResponse(r io.Reader) (err error) { +func (conn *Connection) readResponse(r io.Reader) (Response, error) { respBytes, err := conn.read(r) if err != nil { - return errors.New("auth: read error " + err.Error()) + return Response{}, fmt.Errorf("read error: %w", err) } + resp := Response{buf: smallBuf{b: respBytes}} err = resp.decodeHeader(conn.dec) if err != nil { - return errors.New("auth: decode response header error " + err.Error()) + return resp, fmt.Errorf("decode response header error: %w", err) } err = resp.decodeBody() if err != nil { switch err.(type) { case Error: - return err + return resp, err default: - return errors.New("auth: decode response body error " + err.Error()) + return resp, fmt.Errorf("decode response body error: %w", err) } } - return + return resp, nil +} + +func (conn *Connection) readAuthResponse(r io.Reader) error { + _, err := conn.readResponse(r) + if err != nil { + return fmt.Errorf("auth: %w", err) + } + + return nil } func (conn *Connection) createConnection(reconnect bool) (err error) {