Skip to content

Commit

Permalink
Make sure all in-flight streams close when ClientConn.Close() is call…
Browse files Browse the repository at this point in the history
…ed. (#1136)

* Make sure all in-flight streams close when ClientConn.Close() is called.

* added test
  • Loading branch information
MakMukhi authored Apr 21, 2017
1 parent 6d0e6b0 commit 2d949be
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
3 changes: 3 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
select {
case <-t.Error():
// Incur transport error, simply exit.
case <-cc.ctx.Done():
cs.finish(ErrClientConnClosing)
cs.closeTransportStream(ErrClientConnClosing)
case <-s.Done():
// TODO: The trace of the RPC is terminated here when there is no pending
// I/O, which is probably not the optimal solution.
Expand Down
35 changes: 35 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,41 @@ func testConcurrentServerStopAndGoAway(t *testing.T, e env) {
awaitNewConnLogOutput()
}

func TestClientConnCloseAfterGoAwayWithActiveStream(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
if e.name == "handler-tls" {
continue
}
testClientConnCloseAfterGoAwayWithActiveStream(t, e)
}
}

func testClientConnCloseAfterGoAwayWithActiveStream(t *testing.T, e env) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)

if _, err := tc.FullDuplexCall(context.Background()); err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want _, <nil>", tc, err)
}
done := make(chan struct{})
go func() {
te.srv.GracefulStop()
close(done)
}()
time.Sleep(time.Second)
cc.Close()
timeout := time.NewTimer(time.Second)
select {
case <-done:
case <-timeout.C:
t.Fatalf("Test timed-out.")
}
}

func TestFailFast(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
Expand Down

0 comments on commit 2d949be

Please sign in to comment.