From f56b1c3a9a9f693a27680d6a3729c5909e18731b Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 28 Aug 2024 03:31:48 +0530 Subject: [PATCH 1/2] add support for sending error codes on stream reset --- const.go | 21 ++++++++++++++++++++ session_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++++- stream.go | 28 ++++++++++++++++++-------- 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/const.go b/const.go index 7062231..716d085 100644 --- a/const.go +++ b/const.go @@ -57,6 +57,27 @@ func (e *GoAwayError) Is(target error) bool { return false } +// A StreamError is used for errors returned from Read and Write calls after the stream is Reset +type StreamError struct { + ErrorCode uint32 + Remote bool +} + +func (s *StreamError) Error() string { + if s.Remote { + return fmt.Sprintf("stream reset by remote, error code: %d", s.ErrorCode) + } + return fmt.Sprintf("stream reset, error code: %d", s.ErrorCode) +} + +func (s *StreamError) Is(target error) bool { + if target == ErrStreamReset { + return true + } + e, ok := target.(*StreamError) + return ok && *e == *s +} + var ( // ErrInvalidVersion means we received a frame with an // invalid version diff --git a/session_test.go b/session_test.go index df3e3c9..2c06abb 100644 --- a/session_test.go +++ b/session_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1571,6 +1572,56 @@ func TestStreamResetRead(t *testing.T) { wc.Wait() } +func TestStreamResetWithError(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + wc := new(sync.WaitGroup) + wc.Add(2) + go func() { + defer wc.Done() + stream, err := server.AcceptStream() + if err != nil { + t.Error(err) + } + + se := &StreamError{} + _, err = io.ReadAll(stream) + if !errors.As(err, &se) { + t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) + return + } + expected := &StreamError{Remote: true, ErrorCode: 42} + assert.Equal(t, se, expected) + }() + + stream, err := client.OpenStream(context.Background()) + if err != nil { + t.Error(err) + } + + go func() { + defer wc.Done() + + se := &StreamError{} + _, err := io.ReadAll(stream) + if !errors.As(err, &se) { + t.Errorf("exptected StreamError, got type:%T, err: %s", err, err) + return + } + expected := &StreamError{Remote: false, ErrorCode: 42} + assert.Equal(t, se, expected) + }() + + time.Sleep(1 * time.Second) + err = stream.ResetWithError(42) + if err != nil { + t.Fatal(err) + } + wc.Wait() +} + func TestLotsOfWritesWithStreamDeadline(t *testing.T) { config := testConf() config.EnableKeepAlive = false @@ -1809,7 +1860,7 @@ func TestMaxIncomingStreams(t *testing.T) { require.NoError(t, err) str.SetDeadline(time.Now().Add(time.Second)) _, err = str.Read([]byte{0}) - require.EqualError(t, err, "stream reset") + require.ErrorIs(t, err, ErrStreamReset) // Now close one of the streams. // This should then allow the client to open a new stream. diff --git a/stream.go b/stream.go index e1e5602..f6f32ec 100644 --- a/stream.go +++ b/stream.go @@ -42,6 +42,7 @@ type Stream struct { state streamState writeState, readState halfStreamState stateLock sync.Mutex + resetErr *StreamError recvBuf segmentedBuffer @@ -89,6 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.readState + resetErr := s.resetErr s.stateLock.Unlock() switch state { @@ -101,7 +103,7 @@ START: } // Closed, but we have data pending -> read. case halfReset: - return 0, ErrStreamReset + return 0, resetErr default: panic("unknown state") } @@ -147,6 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.writeState + resetErr := s.resetErr s.stateLock.Unlock() switch state { @@ -155,7 +158,7 @@ START: case halfClosed: return 0, ErrStreamClosed case halfReset: - return 0, ErrStreamReset + return 0, resetErr default: panic("unknown state") } @@ -250,13 +253,17 @@ func (s *Stream) sendClose() error { } // sendReset is used to send a RST -func (s *Stream) sendReset() error { - hdr := encode(typeWindowUpdate, flagRST, s.id, 0) +func (s *Stream) sendReset(errCode uint32) error { + hdr := encode(typeWindowUpdate, flagRST, s.id, errCode) return s.session.sendMsg(hdr, nil, nil) } // Reset resets the stream (forcibly closes the stream) func (s *Stream) Reset() error { + return s.ResetWithError(0) +} + +func (s *Stream) ResetWithError(errCode uint32) error { sendReset := false s.stateLock.Lock() switch s.state { @@ -281,10 +288,11 @@ func (s *Stream) Reset() error { s.readState = halfReset } s.state = streamFinished + s.resetErr = &StreamError{Remote: false, ErrorCode: errCode} s.notifyWaiting() s.stateLock.Unlock() if sendReset { - _ = s.sendReset() + _ = s.sendReset(errCode) } s.cleanup() return nil @@ -382,7 +390,7 @@ func (s *Stream) cleanup() { // processFlags is used to update the state of the stream // based on set flags, if any. Lock must be held -func (s *Stream) processFlags(flags uint16) { +func (s *Stream) processFlags(flags uint16, hdr header) { // Close the stream without holding the state lock var closeStream bool defer func() { @@ -425,6 +433,10 @@ func (s *Stream) processFlags(flags uint16) { s.writeState = halfReset } s.state = streamFinished + // Length in a window update frame with RST flag encodes an error code. + if hdr.MsgType() == typeWindowUpdate && s.resetErr == nil { + s.resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()} + } s.stateLock.Unlock() closeStream = true s.notifyWaiting() @@ -439,7 +451,7 @@ func (s *Stream) notifyWaiting() { // incrSendWindow updates the size of our send window func (s *Stream) incrSendWindow(hdr header, flags uint16) { - s.processFlags(flags) + s.processFlags(flags, hdr) // Increase window, unblock a sender atomic.AddUint32(&s.sendWindow, hdr.Length()) asyncNotify(s.sendNotifyCh) @@ -447,7 +459,7 @@ func (s *Stream) incrSendWindow(hdr header, flags uint16) { // readData is used to handle a data frame func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { - s.processFlags(flags) + s.processFlags(flags, hdr) // Check that our recv window is not exceeded length := hdr.Length() From 9190b780f8929b9a2ae6dcb0524ef0c64d9374f6 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 29 Aug 2024 01:00:53 +0530 Subject: [PATCH 2/2] fix err on conn close --- const.go | 2 +- session.go | 2 +- stream.go | 25 ++++++++++++++++--------- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/const.go b/const.go index 716d085..e1e9dc5 100644 --- a/const.go +++ b/const.go @@ -173,7 +173,7 @@ const ( // It's not an implementation choice, the value defined in the specification. initialStreamWindow = 256 * 1024 maxStreamWindow = 16 * 1024 * 1024 - goAwayWaitTime = 5 * time.Second + goAwayWaitTime = 50 * time.Millisecond ) const ( diff --git a/session.go b/session.go index 204b168..c9af6e0 100644 --- a/session.go +++ b/session.go @@ -334,7 +334,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro s.streamLock.Lock() defer s.streamLock.Unlock() for id, stream := range s.streams { - stream.forceClose() + stream.forceClose(fmt.Errorf("%w: connection closed: %w", ErrStreamReset, s.shutdownErr)) delete(s.streams, id) stream.memorySpan.Done() } diff --git a/stream.go b/stream.go index f6f32ec..e79562d 100644 --- a/stream.go +++ b/stream.go @@ -41,8 +41,8 @@ type Stream struct { state streamState writeState, readState halfStreamState + writeErr, readErr error stateLock sync.Mutex - resetErr *StreamError recvBuf segmentedBuffer @@ -90,7 +90,7 @@ func (s *Stream) Read(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.readState - resetErr := s.resetErr + resetErr := s.readErr s.stateLock.Unlock() switch state { @@ -149,7 +149,7 @@ func (s *Stream) write(b []byte) (n int, err error) { START: s.stateLock.Lock() state := s.writeState - resetErr := s.resetErr + resetErr := s.writeErr s.stateLock.Unlock() switch state { @@ -283,12 +283,13 @@ func (s *Stream) ResetWithError(errCode uint32) error { // If we've already sent/received an EOF, no need to reset that side. if s.writeState == halfOpen { s.writeState = halfReset + s.writeErr = &StreamError{Remote: false, ErrorCode: errCode} } if s.readState == halfOpen { s.readState = halfReset + s.readErr = &StreamError{Remote: false, ErrorCode: errCode} } s.state = streamFinished - s.resetErr = &StreamError{Remote: false, ErrorCode: errCode} s.notifyWaiting() s.stateLock.Unlock() if sendReset { @@ -344,6 +345,7 @@ func (s *Stream) CloseRead() error { panic("invalid state") } s.readState = halfReset + s.readErr = ErrStreamReset cleanup = s.writeState != halfOpen if cleanup { s.state = streamFinished @@ -365,13 +367,15 @@ func (s *Stream) Close() error { } // forceClose is used for when the session is exiting -func (s *Stream) forceClose() { +func (s *Stream) forceClose(err error) { s.stateLock.Lock() if s.readState == halfOpen { s.readState = halfReset + s.readErr = err } if s.writeState == halfOpen { s.writeState = halfReset + s.writeErr = err } s.state = streamFinished s.notifyWaiting() @@ -426,17 +430,20 @@ func (s *Stream) processFlags(flags uint16, hdr header) { } if flags&flagRST == flagRST { s.stateLock.Lock() + var resetErr error = ErrStreamReset + // Length in a window update frame with RST flag encodes an error code. + if hdr.MsgType() == typeWindowUpdate { + resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()} + } if s.readState == halfOpen { s.readState = halfReset + s.readErr = resetErr } if s.writeState == halfOpen { s.writeState = halfReset + s.writeErr = resetErr } s.state = streamFinished - // Length in a window update frame with RST flag encodes an error code. - if hdr.MsgType() == typeWindowUpdate && s.resetErr == nil { - s.resetErr = &StreamError{Remote: true, ErrorCode: hdr.Length()} - } s.stateLock.Unlock() closeStream = true s.notifyWaiting()