diff --git a/etcdserver/api/v3rpc/watch.go b/etcdserver/api/v3rpc/watch.go index 208c453ffa56..f326d710e4cd 100644 --- a/etcdserver/api/v3rpc/watch.go +++ b/etcdserver/api/v3rpc/watch.go @@ -194,23 +194,29 @@ func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) { close(sws.ctrlStream) case <-stream.Context().Done(): - err = stream.Context().Err() - if err == context.Canceled { - err = rpctypes.ErrGRPCWatchCanceled - // Try to determine a more specific cancellation error. Use the stream context - // and local leader state to detect leader loss. If the reason is inconclusive, - // assume a client cancellation. - if md, hasMetadata := metadata.FromIncomingContext(stream.Context()); hasMetadata { - if rl := md[rpctypes.MetadataRequireLeaderKey]; len(rl) > 0 && rl[0] == rpctypes.MetadataHasLeader { - if sws.sg.Leader() == types.ID(raft.None) { - err = rpctypes.ErrGRPCNoLeader - } + err = getClientStreamClosureError(stream.Context(), sws.sg) + } + + sws.close() + return err +} + +// getClientStreamClosureError tries to determine a more specific client stream +// cancellation error. If the context error appears to be a cancellation, use +// the stream context and local leader state to detect leader loss. If the +// reason is inconclusive, assume a client cancellation. +func getClientStreamClosureError(ctx context.Context, statusGetter etcdserver.RaftStatusGetter) error { + err := ctx.Err() + if err == context.Canceled { + err = rpctypes.ErrGRPCWatchCanceled + if md, hasMetadata := metadata.FromIncomingContext(ctx); hasMetadata { + if rl := md[rpctypes.MetadataRequireLeaderKey]; len(rl) > 0 && rl[0] == rpctypes.MetadataHasLeader { + if statusGetter.Leader() == types.ID(raft.None) { + err = rpctypes.ErrGRPCNoLeader } } } } - - sws.close() return err } diff --git a/etcdserver/api/v3rpc/watch_test.go b/etcdserver/api/v3rpc/watch_test.go index f507f5eabd88..ce7b8f99bf3e 100644 --- a/etcdserver/api/v3rpc/watch_test.go +++ b/etcdserver/api/v3rpc/watch_test.go @@ -16,11 +16,19 @@ package v3rpc import ( "bytes" + "context" + "fmt" "math" "testing" + "google.golang.org/grpc/metadata" + + "go.etcd.io/etcd/v3/etcdserver" + "go.etcd.io/etcd/v3/etcdserver/api/v3rpc/rpctypes" pb "go.etcd.io/etcd/v3/etcdserver/etcdserverpb" "go.etcd.io/etcd/v3/mvcc/mvccpb" + "go.etcd.io/etcd/v3/pkg/types" + "go.etcd.io/etcd/v3/raft" ) func TestSendFragment(t *testing.T) { @@ -93,3 +101,75 @@ func createResponse(dataSize, events int) (resp *pb.WatchResponse) { } return resp } + +func TestWatchCloseErrorDetection(t *testing.T) { + unknownError := fmt.Errorf("unknown") + + tests := map[string]struct{ + ctx context.Context + leaderRequired bool + hasLeader bool + expect error + }{ + "no error": { + ctx: context.TODO(), leaderRequired: false, hasLeader: true, expect: nil, + }, + "generic cancellation": { + ctx: cancelledContext(), leaderRequired: false, hasLeader: true, expect: rpctypes.ErrGRPCWatchCanceled, + }, + "leader required and present": { + ctx: cancelledContext(), leaderRequired: true, hasLeader: true, expect: rpctypes.ErrGRPCWatchCanceled, + }, + "leader required but missing": { + ctx: cancelledContext(), leaderRequired: true, hasLeader: false, expect: rpctypes.ErrGRPCNoLeader, + }, + "unknown error": { + ctx: errorContext(unknownError), leaderRequired: false, hasLeader: true, expect: unknownError, + }, + } + + for name, test := range tests { + if test.leaderRequired { + md := metadata.New(map[string]string{rpctypes.MetadataRequireLeaderKey: rpctypes.MetadataHasLeader}) + test.ctx = metadata.NewIncomingContext(test.ctx, md) + } + + rsg := fakeRaftStatusGetter{leader: 1} + if !test.hasLeader { + rsg.leader = types.ID(raft.None) + } + + actual := getClientStreamClosureError(test.ctx, rsg) + if test.expect != actual { + t.Errorf("test %q expected %v, got %v", name, test.expect, actual) + } + } +} + +func cancelledContext() context.Context { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + return ctx +} + +func errorContext(err error) context.Context { + return genericErrorContext{err: err} +} + +type genericErrorContext struct { + context.Context + err error +} + +func (c genericErrorContext) Err() error { + return c.err +} + +type fakeRaftStatusGetter struct { + etcdserver.RaftStatusGetter + leader types.ID +} + +func (f fakeRaftStatusGetter) Leader() types.ID { + return f.leader +}