diff --git a/cases.md b/cases.md new file mode 100644 index 0000000..e69de29 diff --git a/lazyClient.go b/lazyClient.go index f8d9001..a4e1815 100644 --- a/lazyClient.go +++ b/lazyClient.go @@ -6,16 +6,9 @@ import ( "sync" ) -// Multistream represents in essense a ReadWriteCloser, or a single -// communication wire which supports multiple streams on it. Each -// stream is identified by a protocol tag. -type Multistream interface { - io.ReadWriteCloser -} - // NewMSSelect returns a new Multistream which is able to perform // protocol selection with a MultistreamMuxer. -func NewMSSelect(c io.ReadWriteCloser, proto string) Multistream { +func NewMSSelect(c io.ReadWriteCloser, proto string) LazyConn { return &lazyClientConn{ protos: []string{ProtocolID, proto}, con: c, @@ -25,7 +18,7 @@ func NewMSSelect(c io.ReadWriteCloser, proto string) Multistream { // NewMultistream returns a multistream for the given protocol. This will not // perform any protocol selection. If you are using a MultistreamMuxer, use // NewMSSelect. -func NewMultistream(c io.ReadWriteCloser, proto string) Multistream { +func NewMultistream(c io.ReadWriteCloser, proto string) LazyConn { return &lazyClientConn{ protos: []string{proto}, con: c, @@ -139,5 +132,24 @@ func (l *lazyClientConn) Write(b []byte) (int, error) { // Close closes the underlying io.ReadWriteCloser func (l *lazyClientConn) Close() error { - return l.con.Close() + // We must flush the handshake on a "nice" close. + // Otherwise, if the other side is actually waiting for our close (i.e., + // reading until EOF), they may get an error even though we received the + // request. + flushErr := l.Flush() + // But we close anyways because close should always close. + closeErr := l.con.Close() + if flushErr != nil { + return flushErr + } + return closeErr +} + +// Flush sends the handshake. +func (l *lazyClientConn) Flush() error { + l.whandshakeOnce.Do(func() { + go l.rhandshakeOnce.Do(l.doReadHandshake) + l.doWriteHandshake() + }) + return l.werr } diff --git a/lazyServer.go b/lazyServer.go index d7501fa..02b0310 100644 --- a/lazyServer.go +++ b/lazyServer.go @@ -33,5 +33,17 @@ func (l *lazyServerConn) Read(b []byte) (int, error) { } func (l *lazyServerConn) Close() error { - return l.con.Close() + // We must flush the handshake on a "nice" close. + flushErr := l.Flush() + closeErr := l.con.Close() + if flushErr != nil { + return flushErr + } + return closeErr +} + +// Flush sends the handshake. +func (l *lazyServerConn) Flush() error { + l.waitForHandshake.Do(func() { panic("didn't initiate handshake") }) + return l.werr } diff --git a/multistream.go b/multistream.go index 8a671e5..7cfeff3 100644 --- a/multistream.go +++ b/multistream.go @@ -7,6 +7,7 @@ import ( "bufio" "bytes" "errors" + "io" "sync" @@ -51,6 +52,13 @@ func NewMultistreamMuxer() *MultistreamMuxer { return new(MultistreamMuxer) } +// LazyConn is the connection type returned by the lazy negotiation functions. +type LazyConn interface { + io.ReadWriteCloser + // Flush flushes the lazy negotiation, if any. + Flush() error +} + func writeUvarint(w io.Writer, i uint64) error { varintbuf := make([]byte, 16) n := varint.PutUvarint(varintbuf, i) @@ -201,7 +209,7 @@ func (msm *MultistreamMuxer) findHandler(proto string) *Handler { // a multistream, the protocol used, the handler and an error. It is lazy // because the write-handshake is performed on a subroutine, allowing this // to return before that handshake is completed. -func (msm *MultistreamMuxer) NegotiateLazy(rwc io.ReadWriteCloser) (io.ReadWriteCloser, string, HandlerFunc, error) { +func (msm *MultistreamMuxer) NegotiateLazy(rwc io.ReadWriteCloser) (LazyConn, string, HandlerFunc, error) { pval := make(chan string, 1) writeErr := make(chan error, 1) defer close(pval) diff --git a/multistream_test.go b/multistream_test.go index 3bd2b84..e61c7b0 100644 --- a/multistream_test.go +++ b/multistream_test.go @@ -109,7 +109,7 @@ func TestProtocolNegotiationLazy(t *testing.T) { mux.AddHandler("/b", nil) mux.AddHandler("/c", nil) - var ac Multistream + var ac LazyConn done := make(chan struct{}) go func() { m, selected, _, err := mux.NegotiateLazy(a)