Skip to content

Commit

Permalink
fix: propagate body stream error to close function (#1743) (#1757)
Browse files Browse the repository at this point in the history
* fix: propagate body stream error to close function (#1743)

* fix: http test

* fix: close body stream with error in encoding functions

* fix: lint

---------

Co-authored-by: Max Denushev <[email protected]>
  • Loading branch information
mdenushev and Max Denushev authored Apr 22, 2024
1 parent e88bd48 commit 57b9352
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 31 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2975,12 +2975,12 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
resp.bodyStream = newCloseReaderWithError(rbs, func(wErr error) error {
hc.releaseReader(br)
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn || resp.ConnectionClose() {
if closeConn || resp.ConnectionClose() || wErr != nil {
hc.closeConn(cc)
} else {
hc.releaseConn(cc)
Expand Down
76 changes: 48 additions & 28 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,26 +321,31 @@ func (resp *Response) BodyStream() io.Reader {
}

func (resp *Response) CloseBodyStream() error {
return resp.closeBodyStream()
return resp.closeBodyStream(nil)
}

type ReadCloserWithError interface {
io.Reader
CloseWithError(err error) error
}

type closeReader struct {
io.Reader
closeFunc func() error
closeFunc func(err error) error
}

func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser {
func newCloseReaderWithError(r io.Reader, closeFunc func(err error) error) ReadCloserWithError {
if r == nil {
panic(`BUG: reader is nil`)
}
return &closeReader{Reader: r, closeFunc: closeFunc}
}

func (c *closeReader) Close() error {
func (c *closeReader) CloseWithError(err error) error {
if c.closeFunc == nil {
return nil
}
return c.closeFunc()
return c.closeFunc(err)
}

// BodyWriter returns writer for populating request body.
Expand Down Expand Up @@ -394,7 +399,7 @@ func (resp *Response) Body() []byte {
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
Expand Down Expand Up @@ -618,7 +623,7 @@ func (req *Request) BodyWriteTo(w io.Writer) error {
func (resp *Response) BodyWriteTo(w io.Writer) error {
if resp.bodyStream != nil {
_, err := copyZeroAlloc(w, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(err) //nolint:errcheck
return err
}
_, err := w.Write(resp.bodyBytes())
Expand All @@ -629,29 +634,29 @@ func (resp *Response) BodyWriteTo(w io.Writer) error {
//
// It is safe re-using p after the function returns.
func (resp *Response) AppendBody(p []byte) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().Write(p) //nolint:errcheck
}

// AppendBodyString appends s to response body.
func (resp *Response) AppendBodyString(s string) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().WriteString(s) //nolint:errcheck
}

// SetBody sets response body.
//
// It is safe re-using body argument after the function returns.
func (resp *Response) SetBody(body []byte) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.Write(body) //nolint:errcheck
}

// SetBodyString sets response body.
func (resp *Response) SetBodyString(body string) {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.WriteString(body) //nolint:errcheck
Expand All @@ -660,7 +665,7 @@ func (resp *Response) SetBodyString(body string) {
// ResetBody resets response body.
func (resp *Response) ResetBody() {
resp.bodyRaw = nil
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
if resp.body != nil {
if resp.keepBodyBuffer {
resp.body.Reset()
Expand Down Expand Up @@ -700,7 +705,7 @@ func (resp *Response) ReleaseBody(size int) {
return
}
if cap(resp.body.B) > size {
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(nil) //nolint:errcheck
resp.body = nil
}
}
Expand Down Expand Up @@ -734,7 +739,7 @@ func (resp *Response) SwapBody(body []byte) []byte {
if resp.bodyStream != nil {
bb.Reset()
_, err := copyZeroAlloc(bb, resp.bodyStream)
resp.closeBodyStream() //nolint:errcheck
resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bb.Reset()
bb.SetString(err.Error())
Expand Down Expand Up @@ -1725,10 +1730,13 @@ func (resp *Response) brotliBody(level int) {
wf: zw,
bw: sw,
}
copyZeroAlloc(fw, bs) //nolint:errcheck
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessBrotliWriter(zw, level)
if bsc, ok := bs.(io.Closer); ok {
bsc.Close()
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
Expand Down Expand Up @@ -1780,10 +1788,13 @@ func (resp *Response) gzipBody(level int) {
wf: zw,
bw: sw,
}
copyZeroAlloc(fw, bs) //nolint:errcheck
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessGzipWriter(zw, level)
if bsc, ok := bs.(io.Closer); ok {
bsc.Close()
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
Expand Down Expand Up @@ -1835,10 +1846,13 @@ func (resp *Response) deflateBody(level int) {
wf: zw,
bw: sw,
}
copyZeroAlloc(fw, bs) //nolint:errcheck
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessDeflateWriter(zw, level)
if bsc, ok := bs.(io.Closer); ok {
bsc.Close()
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
Expand Down Expand Up @@ -1887,10 +1901,13 @@ func (resp *Response) zstdBody(level int) {
wf: zw,
bw: sw,
}
copyZeroAlloc(fw, bs) //nolint:errcheck
_, wErr := copyZeroAlloc(fw, bs)
releaseStacklessZstdWriter(zw, level)
if bsc, ok := bs.(io.Closer); ok {
bsc.Close()
switch v := bs.(type) {
case io.Closer:
v.Close()
case ReadCloserWithError:
v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
Expand Down Expand Up @@ -2053,7 +2070,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error
}
}
}
errc := resp.closeBodyStream()
errc := resp.closeBodyStream(err)
if err == nil {
err = errc
}
Expand All @@ -2075,14 +2092,17 @@ func (req *Request) closeBodyStream() error {
return err
}

func (resp *Response) closeBodyStream() error {
func (resp *Response) closeBodyStream(wErr error) error {
if resp.bodyStream == nil {
return nil
}
var err error
if bsc, ok := resp.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
if bsc, ok := resp.bodyStream.(ReadCloserWithError); ok {
err = bsc.CloseWithError(wErr)
}
if bsr, ok := resp.bodyStream.(*requestStream); ok {
releaseRequestStream(bsr)
}
Expand Down
2 changes: 1 addition & 1 deletion http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2943,7 +2943,7 @@ func TestResponseBodyStream(t *testing.T) {
t.Fatalf("parse response find err: %v", err)
}
defer func() {
if err := response.closeBodyStream(); err != nil {
if err := response.closeBodyStream(nil); err != nil {
t.Fatalf("close body stream err: %v", err)
}
}()
Expand Down

0 comments on commit 57b9352

Please sign in to comment.