diff --git a/multiplex.go b/multiplex.go
index 2108b00..0bc697e 100644
--- a/multiplex.go
+++ b/multiplex.go
@@ -34,6 +34,11 @@ var ErrTwoInitiators = errors.New("two initiators")
 // In this case, we close the connection to be safe.
 var ErrInvalidState = errors.New("received an unexpected message from the peer")
 
+var (
+	NewStreamTimeout   = time.Minute
+	ResetStreamTimeout = 2 * time.Minute
+)
+
 // +1 for initiator
 const (
 	newStreamTag = 0
@@ -140,6 +145,14 @@ func (mp *Multiplex) IsClosed() bool {
 }
 
 func (mp *Multiplex) sendMsg(ctx context.Context, header uint64, data []byte) error {
+	buf := pool.Get(len(data) + 20)
+	defer pool.Put(buf)
+
+	n := 0
+	n += binary.PutUvarint(buf[n:], header)
+	n += binary.PutUvarint(buf[n:], uint64(len(data)))
+	n += copy(buf[n:], data)
+
 	select {
 	case tkn := <-mp.wrTkn:
 		defer func() { mp.wrTkn <- tkn }()
@@ -147,25 +160,22 @@ func (mp *Multiplex) sendMsg(ctx context.Context, header uint64, data []byte) er
 		return ctx.Err()
 	}
 
+	if mp.isShutdown() {
+		return ErrShutdown
+	}
+
 	dl, hasDl := ctx.Deadline()
 	if hasDl {
 		if err := mp.con.SetWriteDeadline(dl); err != nil {
 			return err
 		}
 	}
 
-	buf := pool.Get(len(data) + 20)
-	defer pool.Put(buf)
-
-	n := 0
-	n += binary.PutUvarint(buf[n:], header)
-	n += binary.PutUvarint(buf[n:], uint64(len(data)))
-	n += copy(buf[n:], data)
-
 	written, err := mp.con.Write(buf[:n])
-	if err != nil && written > 0 {
-		// Bail. We've written partial message and can't do anything
+	if err != nil && (written > 0 || isFatalNetworkError(err)) {
+		// Bail. We've written a partial message or hit a fatal error and can't do anything
 		// about this.
-		mp.con.Close()
+		mp.closeNoWait()
 		return err
 	}
@@ -214,7 +224,10 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
 	mp.channels[s.id] = s
 	mp.chLock.Unlock()
 
-	err := mp.sendMsg(context.Background(), header, []byte(name))
+	ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout)
+	defer cancel()
+
+	err := mp.sendMsg(ctx, header, []byte(name))
 	if err != nil {
 		return nil, err
 	}
@@ -316,20 +329,16 @@ func (mp *Multiplex) handleIncoming() {
 		}
 
 		msch.clLock.Lock()
-		// Honestly, this check should never be true... It means we've leaked.
-		// However, this is an error on *our* side so we shouldn't just bail.
 		isClosed := msch.isClosed()
-		if isClosed && msch.closedRemote {
-			msch.clLock.Unlock()
-			log.Errorf("leaked a completely closed stream")
-			continue
-		}
 
 		if !msch.closedRemote {
 			close(msch.reset)
+			msch.closedRemote = true
+		}
+
+		if !isClosed {
+			msch.doCloseLocal()
 		}
-		msch.closedRemote = true
-		msch.doCloseLocal()
 
 		msch.clLock.Unlock()
@@ -370,7 +379,7 @@ func (mp *Multiplex) handleIncoming() {
 			// This is a perfectly valid case when we reset
 			// and forget about the stream.
 			log.Debugf("message for non-existant stream, dropping data: %d", ch)
-			go mp.sendMsg(context.Background(), ch.header(resetTag), nil)
+			go mp.sendResetMsg(ch.header(resetTag), false)
 			continue
 		}
@@ -382,7 +391,7 @@ func (mp *Multiplex) handleIncoming() {
 			pool.Put(b)
 			log.Warningf("Received data from remote after stream was closed by them. (len = %d)", len(b))
-			go mp.sendMsg(context.Background(), msch.id.header(resetTag), nil)
+			go mp.sendResetMsg(msch.id.header(resetTag), false)
 			continue
 		}
@@ -415,6 +424,30 @@ func (mp *Multiplex) handleIncoming() {
 	}
 }
 
+func (mp *Multiplex) isShutdown() bool {
+	select {
+	case <-mp.shutdown:
+		return true
+	default:
+		return false
+	}
+}
+
+func (mp *Multiplex) sendResetMsg(header uint64, hard bool) {
+	ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
+	defer cancel()
+
+	err := mp.sendMsg(ctx, header, nil)
+	if err != nil && !mp.isShutdown() {
+		if hard {
+			log.Warningf("error sending reset message: %s; killing connection", err.Error())
+			mp.Close()
+		} else {
+			log.Debugf("error sending reset message: %s", err.Error())
+		}
+	}
+}
+
 func (mp *Multiplex) readNextHeader() (uint64, uint64, error) {
 	h, err := binary.ReadUvarint(mp.buf)
 	if err != nil {
@@ -452,3 +485,11 @@ func (mp *Multiplex) readNext() ([]byte, error) {
 
 	return buf[:n], nil
 }
+
+func isFatalNetworkError(err error) bool {
+	nerr, ok := err.(net.Error)
+	if ok {
+		return !(nerr.Timeout() || nerr.Temporary())
+	}
+	return false
+}
diff --git a/stream.go b/stream.go
index dfd70d7..034e74a 100644
--- a/stream.go
+++ b/stream.go
@@ -180,7 +180,10 @@ func (s *Stream) isClosed() bool {
 }
 
 func (s *Stream) Close() error {
-	err := s.mp.sendMsg(context.Background(), s.id.header(closeTag), nil)
+	ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
+	defer cancel()
+
+	err := s.mp.sendMsg(ctx, s.id.header(closeTag), nil)
 
 	if s.isClosed() {
 		return nil
@@ -198,6 +201,11 @@ func (s *Stream) Close() error {
 		s.mp.chLock.Unlock()
 	}
 
+	if err != nil && !s.mp.isShutdown() {
+		log.Warningf("Error closing stream: %s; killing connection", err.Error())
+		s.mp.Close()
+	}
+
 	return err
 }
@@ -222,7 +230,7 @@ func (s *Stream) Reset() error {
 	s.doCloseLocal()
 	s.closedRemote = true
 
-	go s.mp.sendMsg(context.Background(), s.id.header(resetTag), nil)
+	go s.mp.sendResetMsg(s.id.header(resetTag), true)
 
 	s.clLock.Unlock()
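
Note on the new timeouts: NewStreamTimeout and ResetStreamTimeout are exported package-level variables, so embedders can bound how long the muxer waits on new-stream, close, and reset control messages before giving up. A minimal sketch of tuning them at startup, assuming the usual go-mplex import path (the diff itself doesn't name the module):

package main

import (
	"time"

	mplex "github.com/libp2p/go-mplex" // assumed import path
)

func main() {
	// Plain package-level vars with no locking, so set them once at
	// startup, before any Multiplex is created.
	mplex.NewStreamTimeout = 10 * time.Second   // bounds NewNamedStream's control write
	mplex.ResetStreamTimeout = 30 * time.Second // bounds Close/Reset control writes
}

Because sendMsg turns the context deadline into the connection's write deadline, these values directly bound how long a stalled peer can block the writer.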
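The write path now also distinguishes recoverable write errors from fatal ones: a timeout or temporary net.Error on a zero-byte write leaves the connection usable, while a partial write or a fatal error tears it down via closeNoWait. A standalone sketch of that classification (timeoutError is a hypothetical stand-in for the error a net.Conn returns on a missed write deadline):

package main

import (
	"fmt"
	"net"
)

// Same logic as isFatalNetworkError in the diff: timeouts and temporary
// conditions are survivable; any other net.Error is fatal.
func isFatalNetworkError(err error) bool {
	nerr, ok := err.(net.Error)
	if ok {
		return !(nerr.Timeout() || nerr.Temporary())
	}
	return false
}

// timeoutError imitates a deadline-exceeded error from a net.Conn.
type timeoutError struct{}

func (timeoutError) Error() string   { return "i/o timeout" }
func (timeoutError) Timeout() bool   { return true }
func (timeoutError) Temporary() bool { return true }

func main() {
	deadlineMiss := &net.OpError{Op: "write", Err: timeoutError{}}
	fmt.Println(isFatalNetworkError(deadlineMiss))       // false: the connection survives
	fmt.Println(isFatalNetworkError(fmt.Errorf("boom"))) // false: not a net.Error at all
}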