diff --git a/session.go b/session.go index e229730..785ef86 100644 --- a/session.go +++ b/session.go @@ -281,25 +281,26 @@ func (s *Session) AcceptStream() (*Stream, error) { // Close is used to close the session and all streams. It doesn't send a GoAway before // closing the connection. func (s *Session) Close() error { - return s.close(ErrSessionShutdown, false, goAwayNormal) + return <-s.close(ErrSessionShutdown, false, goAwayNormal) } -// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode. +// CloseWithErrorChan is used to close the session and all streams after sending a GoAway message with errCode. // Blocks for ConnectionWriteTimeout to write the GoAway message. // // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn. // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel // receive buffer. -func (s *Session) CloseWithError(errCode uint32) error { +func (s *Session) CloseWithErrorChan(errCode uint32) chan error { return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode) } -func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error { +func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) chan error { s.shutdownLock.Lock() defer s.shutdownLock.Unlock() - + errCh := make(chan error, 1) if s.shutdown { - return nil + errCh <- nil + return errCh } s.shutdown = true if s.shutdownErr == nil { @@ -308,35 +309,43 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro close(s.shutdownCh) s.stopKeepalive() - // Only send GoAway if we have an error code. - if sendGoAway && errCode != goAwayNormal { - // wait for write loop to exit - // We need to write the current frame completely before sending a goaway. - // This will wait for at most s.config.ConnectionWriteTimeout - <-s.sendDoneCh - ga := s.goAway(errCode) - if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { - _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here - } - s.conn.SetWriteDeadline(time.Time{}) - } - - s.conn.Close() - <-s.sendDoneCh - <-s.recvDoneCh - - resetErr := shutdownErr + resetErr := s.shutdownErr if _, ok := resetErr.(*GoAwayError); !ok { resetErr = fmt.Errorf("%w: connection closed: %w", ErrStreamReset, shutdownErr) } + s.streamLock.Lock() - defer s.streamLock.Unlock() for id, stream := range s.streams { stream.forceClose(resetErr) delete(s.streams, id) stream.memorySpan.Done() } - return nil + s.streamLock.Unlock() + + if sendGoAway { + go func() { + // wait for write loop to exit + // We need to write the current frame completely before sending a goaway. + // This will wait for at most s.config.ConnectionWriteTimeout + <-s.sendDoneCh + ga := s.goAway(errCode) + if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil { + _, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here + } + s.conn.SetWriteDeadline(time.Time{}) + s.conn.Close() + <-s.sendDoneCh + <-s.recvDoneCh + errCh <- nil + }() + return errCh + } + + errCh <- nil + s.conn.Close() + <-s.sendDoneCh + <-s.recvDoneCh + return errCh } // GoAway can be used to prevent accepting further @@ -748,12 +757,10 @@ func (s *Session) handleStreamMessage(hdr header) error { return err } } - // Get the stream s.streamLock.Lock() stream := s.streams[id] s.streamLock.Unlock() - // If we do not have a stream, likely we sent a RST and/or closed the stream for reading. if stream == nil { // Drain any data on the wire @@ -850,7 +857,6 @@ func (s *Session) incomingStream(id uint32) error { return err } stream := newStream(s, id, streamSYNReceived, initialStreamWindow, span) - s.streamLock.Lock() defer s.streamLock.Unlock() diff --git a/session_test.go b/session_test.go index 6d3bce0..605eb75 100644 --- a/session_test.go +++ b/session_test.go @@ -671,7 +671,8 @@ func TestCloseWithError(t *testing.T) { defer client.Close() defer server.Close() - if err := server.CloseWithError(42); err != nil { + errCh := server.CloseWithErrorChan(42) + if err := <-errCh; err != nil { t.Fatalf("err: %v", err) }