Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change CloseWithError to CloseWithErrorChan #123

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,25 +281,26 @@ func (s *Session) AcceptStream() (*Stream, error) {
// Close is used to close the session and all streams. It doesn't send a GoAway before
// closing the connection.
func (s *Session) Close() error {
return s.close(ErrSessionShutdown, false, goAwayNormal)
return <-s.close(ErrSessionShutdown, false, goAwayNormal)
}

// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
// CloseWithErrorChan is used to close the session and all streams after sending a GoAway message with errCode.
// Blocks for ConnectionWriteTimeout to write the GoAway message.
//
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
// receive buffer.
func (s *Session) CloseWithError(errCode uint32) error {
func (s *Session) CloseWithErrorChan(errCode uint32) chan error {
return s.close(&GoAwayError{Remote: false, ErrorCode: errCode}, true, errCode)
}

func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) error {
func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) chan error {
s.shutdownLock.Lock()
defer s.shutdownLock.Unlock()

errCh := make(chan error, 1)
if s.shutdown {
return nil
errCh <- nil
return errCh
}
s.shutdown = true
if s.shutdownErr == nil {
Expand All @@ -308,35 +309,43 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
close(s.shutdownCh)
s.stopKeepalive()

// Only send GoAway if we have an error code.
if sendGoAway && errCode != goAwayNormal {
// wait for write loop to exit
// We need to write the current frame completely before sending a goaway.
// This will wait for at most s.config.ConnectionWriteTimeout
<-s.sendDoneCh
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
s.conn.SetWriteDeadline(time.Time{})
}

s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh

resetErr := shutdownErr
resetErr := s.shutdownErr
if _, ok := resetErr.(*GoAwayError); !ok {
resetErr = fmt.Errorf("%w: connection closed: %w", ErrStreamReset, shutdownErr)
}

s.streamLock.Lock()
defer s.streamLock.Unlock()
for id, stream := range s.streams {
stream.forceClose(resetErr)
delete(s.streams, id)
stream.memorySpan.Done()
}
return nil
s.streamLock.Unlock()

if sendGoAway {
go func() {
// wait for write loop to exit
// We need to write the current frame completely before sending a goaway.
// This will wait for at most s.config.ConnectionWriteTimeout
<-s.sendDoneCh
ga := s.goAway(errCode)
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
}
s.conn.SetWriteDeadline(time.Time{})
s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh
errCh <- nil
}()
return errCh
}

errCh <- nil
s.conn.Close()
<-s.sendDoneCh
<-s.recvDoneCh
return errCh
}

// GoAway can be used to prevent accepting further
Expand Down Expand Up @@ -748,12 +757,10 @@ func (s *Session) handleStreamMessage(hdr header) error {
return err
}
}

// Get the stream
s.streamLock.Lock()
stream := s.streams[id]
s.streamLock.Unlock()

// If we do not have a stream, likely we sent a RST and/or closed the stream for reading.
if stream == nil {
// Drain any data on the wire
Expand Down Expand Up @@ -850,7 +857,6 @@ func (s *Session) incomingStream(id uint32) error {
return err
}
stream := newStream(s, id, streamSYNReceived, initialStreamWindow, span)

s.streamLock.Lock()
defer s.streamLock.Unlock()

Expand Down
3 changes: 2 additions & 1 deletion session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,8 @@ func TestCloseWithError(t *testing.T) {
defer client.Close()
defer server.Close()

if err := server.CloseWithError(42); err != nil {
errCh := server.CloseWithErrorChan(42)
if err := <-errCh; err != nil {
t.Fatalf("err: %v", err)
}

Expand Down