diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go index 38fc4a8f3702..29c8909248a0 100644 --- a/grpclb/grpclb_test.go +++ b/grpclb/grpclb_test.go @@ -44,6 +44,7 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -384,7 +385,8 @@ func TestGRPCLB(t *testing.T) { creds := serverNameCheckCreds{ expected: besn, } - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) @@ -421,7 +423,8 @@ func TestDropRequest(t *testing.T) { creds := serverNameCheckCreds{ expected: besn, } - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) @@ -471,7 +474,8 @@ func TestDropRequestFailedNonFailFast(t *testing.T) { creds := serverNameCheckCreds{ expected: besn, } - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) @@ -479,7 +483,8 @@ func TestDropRequestFailedNonFailFast(t *testing.T) { t.Fatalf("Failed to dial to the backend %v", err) } testC := testpb.NewTestServiceClient(cc) - ctx, _ = context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded) } @@ -521,7 +526,8 @@ func TestServerExpiration(t *testing.T) { creds := serverNameCheckCreds{ expected: besn, } - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds)) @@ -578,7 +584,8 @@ func TestBalancerDisconnects(t *testing.T) { creds := serverNameCheckCreds{ expected: besn, } - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() resolver := &testNameResolver{ addrs: lbAddrs[:2], } @@ -646,11 +653,14 @@ func (failPreRPCCred) RequireTransportSecurity() bool { return false } -func TestGRPCLBStatsUnary(t *testing.T) { - var ( - countNormalRPC = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting. - countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting. - ) +func checkStats(stats *lbpb.ClientStats, expected *lbpb.ClientStats) error { + if !proto.Equal(stats, expected) { + return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected) + } + return nil +} + +func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbpb.ClientStats { tss, cleanup, err := newLoadBalancer(3) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) @@ -658,157 +668,247 @@ func TestGRPCLBStatsUnary(t *testing.T) { defer cleanup() tss.ls.sls = []*lbpb.ServerList{{ Servers: []*lbpb.Server{{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: true, - }, { - IpAddress: tss.beIPs[1], - Port: int32(tss.bePorts[1]), - LoadBalanceToken: lbToken, - DropForRateLimiting: true, - }, { IpAddress: tss.beIPs[2], Port: int32(tss.bePorts[2]), LoadBalanceToken: lbToken, - DropForLoadBalancing: false, + DropForLoadBalancing: dropForLoadBalancing, + DropForRateLimiting: dropForRateLimiting, }}, }} tss.ls.intervals = []time.Duration{0} tss.ls.statsDura = 100 * time.Millisecond - creds := serverNameCheckCreds{ - expected: besn, - } - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) + creds := serverNameCheckCreds{expected: besn} + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ addrs: []string{tss.lbAddr}, })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{})) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } - testC := testpb.NewTestServiceClient(cc) - // The first non-failfast RPC succeeds, all connections are up. - if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { - t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) - } - for i := 0; i < countNormalRPC-1; i++ { - testC.EmptyCall(context.Background(), &testpb.Empty{}) - } - for i := 0; i < countFailedToSend; i++ { - grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc) - } - cc.Close() + defer cc.Close() + runRPCs(cc) time.Sleep(1 * time.Second) tss.ls.mu.Lock() - if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) { - t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend) - } - if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) { - t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend) - } - if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 { - t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend) - } - if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 { - t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend) - } - if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 { - t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend) - } - if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 { - t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC) - } + stats := tss.ls.stats tss.ls.mu.Unlock() + return stats } -func TestGRPCLBStatsStreaming(t *testing.T) { - var ( - countNormalRPC = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting. - countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting. - ) - tss, cleanup, err := newLoadBalancer(3) - if err != nil { - t.Fatalf("failed to create new load balancer: %v", err) - } - defer cleanup() - tss.ls.sls = []*lbpb.ServerList{{ - Servers: []*lbpb.Server{{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: true, - }, { - IpAddress: tss.beIPs[1], - Port: int32(tss.bePorts[1]), - LoadBalanceToken: lbToken, - DropForRateLimiting: true, - }, { - IpAddress: tss.beIPs[2], - Port: int32(tss.bePorts[2]), - LoadBalanceToken: lbToken, - DropForLoadBalancing: false, - }}, - }} - tss.ls.intervals = []time.Duration{0} - tss.ls.statsDura = 100 * time.Millisecond - creds := serverNameCheckCreds{ - expected: besn, - } - ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) - cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{ - addrs: []string{tss.lbAddr}, - })), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{})) - if err != nil { - t.Fatalf("Failed to dial to the backend %v", err) - } - testC := testpb.NewTestServiceClient(cc) - // The first non-failfast RPC succeeds, all connections are up. - var stream testpb.TestService_FullDuplexCallClient - stream, err = testC.FullDuplexCall(context.Background(), grpc.FailFast(false)) - if err != nil { - t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, ", testC, err) - } - for { - if _, err = stream.Recv(); err == io.EOF { - break +const countRPC = 40 + +func TestGRPCLBStatsUnarySuccess(t *testing.T) { + stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + // The first non-failfast RPC succeeds, all connections are up. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) } + for i := 0; i < countRPC-1; i++ { + testC.EmptyCall(context.Background(), &testpb.Empty{}) + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC), + NumCallsFinished: int64(countRPC), + NumCallsFinishedKnownReceived: int64(countRPC), + }); err != nil { + t.Fatal(err) } - for i := 0; i < countNormalRPC-1; i++ { - stream, err = testC.FullDuplexCall(context.Background()) - if err == nil { - // Wait for stream to end if err is nil. - for { - if _, err = stream.Recv(); err == io.EOF { +} + +func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) { + c := 0 + stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + for { + c++ + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if strings.Contains(err.Error(), "drops requests") { break } } } + for i := 0; i < countRPC; i++ { + testC.EmptyCall(context.Background(), &testpb.Empty{}) + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC + c), + NumCallsFinished: int64(countRPC + c), + NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1), + NumCallsFinishedWithClientFailedToSend: int64(c - 1), + }); err != nil { + t.Fatal(err) } - for i := 0; i < countFailedToSend; i++ { - grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend") - } - cc.Close() +} - time.Sleep(1 * time.Second) - tss.ls.mu.Lock() - if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) { - t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend) +func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) { + c := 0 + stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + for { + c++ + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if strings.Contains(err.Error(), "drops requests") { + break + } + } + } + for i := 0; i < countRPC; i++ { + testC.EmptyCall(context.Background(), &testpb.Empty{}) + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC + c), + NumCallsFinished: int64(countRPC + c), + NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1), + NumCallsFinishedWithClientFailedToSend: int64(c - 1), + }); err != nil { + t.Fatal(err) } - if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) { - t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend) +} + +func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) { + stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + // The first non-failfast RPC succeeds, all connections are up. + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } + for i := 0; i < countRPC-1; i++ { + grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc) + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC), + NumCallsFinished: int64(countRPC), + NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1), + NumCallsFinishedKnownReceived: 1, + }); err != nil { + t.Fatal(err) } - if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 { - t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend) +} + +func TestGRPCLBStatsStreamingSuccess(t *testing.T) { + stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + // The first non-failfast RPC succeeds, all connections are up. + stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false)) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, ", testC, err) + } + for { + if _, err = stream.Recv(); err == io.EOF { + break + } + } + for i := 0; i < countRPC-1; i++ { + stream, err = testC.FullDuplexCall(context.Background()) + if err == nil { + // Wait for stream to end if err is nil. + for { + if _, err = stream.Recv(); err == io.EOF { + break + } + } + } + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC), + NumCallsFinished: int64(countRPC), + NumCallsFinishedKnownReceived: int64(countRPC), + }); err != nil { + t.Fatal(err) } - if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 { - t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend) +} + +func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) { + c := 0 + stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + for { + c++ + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if strings.Contains(err.Error(), "drops requests") { + break + } + } + } + for i := 0; i < countRPC; i++ { + testC.FullDuplexCall(context.Background()) + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC + c), + NumCallsFinished: int64(countRPC + c), + NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1), + NumCallsFinishedWithClientFailedToSend: int64(c - 1), + }); err != nil { + t.Fatal(err) } - if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 { - t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend) +} + +func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) { + c := 0 + stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + for { + c++ + if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { + if strings.Contains(err.Error(), "drops requests") { + break + } + } + } + for i := 0; i < countRPC; i++ { + testC.FullDuplexCall(context.Background()) + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC + c), + NumCallsFinished: int64(countRPC + c), + NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1), + NumCallsFinishedWithClientFailedToSend: int64(c - 1), + }); err != nil { + t.Fatal(err) } - if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 { - t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC) +} + +func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) { + stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) { + testC := testpb.NewTestServiceClient(cc) + // The first non-failfast RPC succeeds, all connections are up. + stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false)) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, ", testC, err) + } + for { + if _, err = stream.Recv(); err == io.EOF { + break + } + } + for i := 0; i < countRPC-1; i++ { + grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend") + } + }) + + if err := checkStats(&stats, &lbpb.ClientStats{ + NumCallsStarted: int64(countRPC), + NumCallsFinished: int64(countRPC), + NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1), + NumCallsFinishedKnownReceived: 1, + }); err != nil { + t.Fatal(err) } - tss.ls.mu.Unlock() }