diff --git a/buffer.go b/buffer.go index d05b199..f8cf30a 100644 --- a/buffer.go +++ b/buffer.go @@ -42,6 +42,10 @@ func (s *bufferedReader) Read(p []byte) (int, error) { bn := copy(p, s.buffer.Bytes()[s.bufferRead:s.bufferSize]) s.bufferRead += bn return bn, s.lastErr + } else if !s.sniffing && s.buffer.Cap() != 0 { + // We don't need the buffer anymore. + // Reset it to release the internal slice. + s.buffer = bytes.Buffer{} } // If there is nothing more to return in the sniffed buffer, read from the diff --git a/cmux_test.go b/cmux_test.go index 2279ded..17945c8 100644 --- a/cmux_test.go +++ b/cmux_test.go @@ -15,6 +15,7 @@ package cmux import ( + "bytes" "errors" "fmt" "io" @@ -32,6 +33,7 @@ import ( "time" "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" ) const ( @@ -186,8 +188,8 @@ func runTestRPCClient(t *testing.T, addr net.Addr) { } const ( - handleHttp1Close = 1 - handleHttp1Request = 2 + handleHTTP1Close = 1 + handleHTTP1Request = 2 handleAnyClose = 3 handleAnyRequest = 4 ) @@ -208,11 +210,11 @@ func TestTimeout(t *testing.T) { go func() { con, err := http1.Accept() if err != nil { - result <- handleHttp1Close + result <- handleHTTP1Close } else { _, _ = con.Write([]byte("http1")) _ = con.Close() - result <- handleHttp1Request + result <- handleHTTP1Request } }() go func() { @@ -258,7 +260,7 @@ func TestTimeout(t *testing.T) { if a := <-result; a != handleAnyRequest { t.Fatal("testTimeout failed: any rule did not match") } - if a := <-result; a != handleHttp1Close { + if a := <-result; a != handleHTTP1Close { t.Fatal("testTimeout failed: no close an http rule") } } @@ -394,6 +396,72 @@ func TestHTTP2(t *testing.T) { } } +func TestHTTP2MatchHeaderField(t *testing.T) { + defer leakCheck(t)() + errCh := make(chan error) + defer func() { + select { + case err := <-errCh: + t.Fatal(err) + default: + } + }() + name := "name" + value := "value" + writer, reader := net.Pipe() + go func() { + if _, err := io.WriteString(writer, http2.ClientPreface); err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + if err := enc.WriteField(hpack.HeaderField{Name: name, Value: value}); err != nil { + t.Fatal(err) + } + framer := http2.NewFramer(writer, nil) + err := framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: 1, + BlockFragment: buf.Bytes(), + EndStream: true, + EndHeaders: true, + }) + if err != nil { + t.Fatal(err) + } + if err := writer.Close(); err != nil { + t.Fatal(err) + } + }() + + l := newChanListener() + l.connCh <- reader + muxl := New(l) + // Register a bogus matcher that only reads one byte. + muxl.Match(func(r io.Reader) bool { + var b [1]byte + _, _ = r.Read(b[:]) + return false + }) + // Create a matcher that cannot match the response. + muxl.Match(HTTP2HeaderField(name, "another"+value)) + // Then match with the expected field. + h2l := muxl.Match(HTTP2HeaderField(name, value)) + go safeServe(errCh, muxl) + muxedConn, err := h2l.Accept() + close(l.connCh) + if err != nil { + t.Fatal(err) + } + var b [len(http2.ClientPreface)]byte + // We have the sniffed buffer first... + if _, err := muxedConn.Read(b[:]); err == io.EOF { + t.Fatal(err) + } + if string(b[:]) != http2.ClientPreface { + t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface) + } +} + func TestHTTPGoRPC(t *testing.T) { defer leakCheck(t)() errCh := make(chan error) diff --git a/matchers.go b/matchers.go index 2e7428f..485ede8 100644 --- a/matchers.go +++ b/matchers.go @@ -144,10 +144,14 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool return false } + done := false framer := http2.NewFramer(w, r) hdec := hpack.NewDecoder(uint32(4<<10), func(hf hpack.HeaderField) { - if hf.Name == name && hf.Value == value { - matched = true + if hf.Name == name { + done = true + if hf.Value == value { + matched = true + } } }) for { @@ -161,17 +165,20 @@ func matchHTTP2Field(w io.Writer, r io.Reader, name, value string) (matched bool if err := framer.WriteSettings(); err != nil { return false } - case *http2.HeadersFrame: + case *http2.ContinuationFrame: if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { return false } - if matched { - return true - } - - if f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 { + done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 + case *http2.HeadersFrame: + if _, err := hdec.Write(f.HeaderBlockFragment()); err != nil { return false } + done = done || f.FrameHeader.Flags&http2.FlagHeadersEndHeaders != 0 + } + + if done { + return matched } } }