diff --git a/serverConn.go b/serverConn.go index 34b8ef9..5f9954e 100644 --- a/serverConn.go +++ b/serverConn.go @@ -264,7 +264,7 @@ func (sc *serverConn) handleStreams() { streamPool.Put(strm) if sc.debug { - sc.logger.Printf("Stream destroyed %d\n", strmID) + sc.logger.Printf("Stream destroyed %d. Open streams: %d\n", strmID, openStreams) } } @@ -272,6 +272,8 @@ loop: for { select { case <-sc.maxRequestTimer.C: + reqTimerArmed = false + deleteUntil := 0 for _, strm := range strms { // the request is due if the startedAt time + maxRequestTime is in the past @@ -308,6 +310,10 @@ loop: when := strm.startedAt.Add(sc.maxRequestTime).Sub(time.Now()) // if the time is negative or zero it triggers imm sc.maxRequestTimer.Reset(when) + + if sc.debug { + sc.logger.Printf("Next request will timeout in %f seconds\n", when.Seconds()) + } } } case fr, ok := <-sc.reader: @@ -380,9 +386,17 @@ loop: sc.createStream(sc.c, fr.Type(), strm) + if sc.debug { + sc.logger.Printf("Stream %d created. Open streams: %d\n", strm.ID(), openStreams) + } + if !reqTimerArmed && sc.maxRequestTime > 0 { reqTimerArmed = true sc.maxRequestTimer.Reset(sc.maxRequestTime) + + if sc.debug { + sc.logger.Printf("Next request will timeout in %f seconds\n", sc.maxRequestTime.Seconds()) + } } } diff --git a/server_test.go b/server_test.go index 4e0cce7..8b69dd2 100644 --- a/server_test.go +++ b/server_test.go @@ -39,7 +39,7 @@ func getConn(s *Server) (*Conn, net.Listener, error) { return nc, ln, nc.doHandshake() } -func makeHeaders(id uint32, enc *HPACK, endStream bool, hs map[string]string) *FrameHeader { +func makeHeaders(id uint32, enc *HPACK, endHeaders, endStream bool, hs map[string]string) *FrameHeader { fr := AcquireFrameHeader() fr.SetStream(id) @@ -56,7 +56,7 @@ func makeHeaders(id uint32, enc *HPACK, endStream bool, hs map[string]string) *F h.SetPadding(false) h.SetEndStream(endStream) - h.SetEndHeaders(true) + h.SetEndHeaders(endHeaders) return fr } @@ -89,21 +89,21 @@ func testIssue52(t *testing.T) { msg := []byte("Hello world, how are you doing?") - h1 := makeHeaders(3, c.enc, false, map[string]string{ + h1 := makeHeaders(3, c.enc, true, false, map[string]string{ string(StringAuthority): "localhost", string(StringMethod): "POST", string(StringPath): "/hello/world", string(StringScheme): "https", "Content-Length": strconv.Itoa(len(msg)), }) - h2 := makeHeaders(9, c.enc, false, map[string]string{ + h2 := makeHeaders(9, c.enc, true, false, map[string]string{ string(StringAuthority): "localhost", string(StringMethod): "POST", string(StringPath): "/hello/world", string(StringScheme): "https", "Content-Length": strconv.Itoa(len(msg)), }) - h3 := makeHeaders(7, c.enc, true, map[string]string{ + h3 := makeHeaders(7, c.enc, true, true, map[string]string{ string(StringAuthority): "localhost", string(StringMethod): "GET", string(StringPath): "/hello/world", @@ -153,3 +153,78 @@ func testIssue52(t *testing.T) { t.Fatalf("expected EOF, got %s", err) } } + +func TestIssue27(t *testing.T) { + s := &Server{ + s: &fasthttp.Server{ + Handler: func(ctx *fasthttp.RequestCtx) { + io.WriteString(ctx, "Hello world") + }, + ReadTimeout: time.Second * 1, + }, + cnf: ServerConfig{ + Debug: false, + }, + } + + c, ln, err := getConn(s) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer ln.Close() + + msg := []byte("Hello world, how are you doing?") + + h1 := makeHeaders(3, c.enc, true, false, map[string]string{ + string(StringAuthority): "localhost", + string(StringMethod): "POST", + string(StringPath): "/hello/world", + string(StringScheme): "https", + "Content-Length": strconv.Itoa(len(msg)), + }) + h2 := makeHeaders(5, c.enc, true, false, map[string]string{ + string(StringAuthority): "localhost", + string(StringMethod): "POST", + string(StringPath): "/hello/world", + string(StringScheme): "https", + "Content-Length": strconv.Itoa(len(msg)), + }) + h3 := makeHeaders(7, c.enc, false, false, map[string]string{ + string(StringAuthority): "localhost", + string(StringMethod): "GET", + string(StringPath): "/hello/world", + string(StringScheme): "https", + "Content-Length": strconv.Itoa(len(msg)), + }) + + c.writeFrame(h1) + c.writeFrame(h2) + + time.Sleep(time.Second) + c.writeFrame(h3) + + id := uint32(3) + + for i := 0; i < 3; i++ { + fr, err := c.readNext() + if err != nil { + t.Fatal(err) + } + + if fr.Stream() != id { + t.Fatalf("Expecting update on stream %d, got %d", id, fr.Stream()) + } + + if fr.Type() != FrameResetStream { + t.Fatalf("Expecting Reset, got %s", fr.Type()) + } + + rst := fr.Body().(*RstStream) + if rst.Code() != StreamCanceled { + t.Fatalf("Expecting StreamCanceled, got %s", rst.Code()) + } + + id += 2 + } +}