From 1e2a9c2811231448d82ee47491363ff7a21128d6 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 4 Feb 2016 21:14:01 +0000 Subject: [PATCH] Add a ServeHTTP method to *grpc.Server This adds new http.Handler-based ServerTransport in the process, reusing the HTTP/2 server code in x/net/http2 or Go 1.6+. All end2end tests pass with this new ServerTransport. Fixes grpc/grpc-go#75 Also: Updates grpc/grpc-go#495 (lets user fix it with middleware in front) Updates grpc/grpc-go#468 (x/net/http2 validates) Updates grpc/grpc-go#147 (possible with x/net/http2) Updates grpc/grpc-go#104 (x/net/http2 does this) --- rpc_util.go | 2 +- server.go | 172 +++++++++++++++---- test/end2end_test.go | 192 +++++++++++++++++++--- test/race_test.go | 39 +++++ testdata/dir.go | 70 ++++++++ transport/handler_server.go | 319 ++++++++++++++++++++++++++++++++++++ transport/http_util.go | 3 +- 7 files changed, 741 insertions(+), 56 deletions(-) create mode 100644 test/race_test.go create mode 100644 testdata/dir.go create mode 100644 transport/handler_server.go diff --git a/rpc_util.go b/rpc_util.go index e98ddbcdc5a7..fadf3394d6ca 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -273,7 +273,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er case compressionNone: case compressionMade: if recvCompress == "" { - return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf) + return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress) } if dc == nil || recvCompress != dc.Type() { return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) diff --git a/server.go b/server.go index 1c42b6eff278..086ae0b19b14 100644 --- a/server.go +++ b/server.go @@ -39,6 +39,7 @@ import ( "fmt" "io" "net" + "net/http" "reflect" "runtime" "strings" @@ -46,6 +47,7 @@ import ( "time" "golang.org/x/net/context" + "golang.org/x/net/http2" "golang.org/x/net/trace" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -82,10 +84,11 @@ type service struct { // Server is a gRPC server to serve RPC requests. type Server struct { - opts options - mu sync.Mutex + opts options + + mu sync.Mutex // guards following lis map[net.Listener]bool - conns map[transport.ServerTransport]bool + conns map[io.Closer]bool m map[string]*service // service name -> service info events trace.EventLog } @@ -96,6 +99,7 @@ type options struct { cp Compressor dc Decompressor maxConcurrentStreams uint32 + useHandlerImpl bool // use http.Handler-based server } // A ServerOption sets options. @@ -149,7 +153,7 @@ func NewServer(opt ...ServerOption) *Server { s := &Server{ lis: make(map[net.Listener]bool), opts: opts, - conns: make(map[transport.ServerTransport]bool), + conns: make(map[io.Closer]bool), m: make(map[string]*service), } if EnableTracing { @@ -216,9 +220,17 @@ var ( ErrServerStopped = errors.New("grpc: the server has been stopped") ) +func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + creds, ok := s.opts.creds.(credentials.TransportAuthenticator) + if !ok { + return rawConn, nil, nil + } + return creds.ServerHandshake(rawConn) +} + // Serve accepts incoming connections on the listener lis, creating a new // ServerTransport and service goroutine for each. The service goroutines -// read gRPC request and then call the registered handlers to reply to them. +// read gRPC requests and then call the registered handlers to reply to them. // Service returns when lis.Accept fails. func (s *Server) Serve(lis net.Listener) error { s.mu.Lock() @@ -235,39 +247,54 @@ func (s *Server) Serve(lis net.Listener) error { delete(s.lis, lis) s.mu.Unlock() }() + listenerAddr := lis.Addr() for { - c, err := lis.Accept() + rawConn, err := lis.Accept() if err != nil { s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock() return err } - var authInfo credentials.AuthInfo - if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok { - var conn net.Conn - conn, authInfo, err = creds.ServerHandshake(c) - if err != nil { - s.mu.Lock() - s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err) - s.mu.Unlock() - grpclog.Println("grpc: Server.Serve failed to complete security handshake.") - continue - } - c = conn - } + // Start a new goroutine to deal with rawConn + // so we don't stall this Accept loop goroutine. + go s.handleRawConn(listenerAddr, rawConn) + } +} + +// handleRawConn is run in its own goroutine and handles a just-accepted +// connection that has not had any I/O performed on it yet. +func (s *Server) handleRawConn(listenerAddr net.Addr, rawConn net.Conn) { + conn, authInfo, err := s.useTransportAuthenticator(rawConn) + if err != nil { s.mu.Lock() - if s.conns == nil { - s.mu.Unlock() - c.Close() - return nil - } + s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) s.mu.Unlock() + grpclog.Println("grpc: Server.Serve failed to complete security handshake.") + rawConn.Close() + return + } - go s.serveNewHTTP2Transport(c, authInfo) + s.mu.Lock() + if s.conns == nil { + s.mu.Unlock() + conn.Close() + return + } + s.mu.Unlock() + + if s.opts.useHandlerImpl { + s.serveUsingHandler(listenerAddr, conn) + } else { + s.serveNewHTTP2Transport(conn, authInfo) } } +// serveNewHTTP2Transport sets up a new http/2 transport (using the +// gRPC http2 server transport in transport/http2_server.go) and +// serves streams on it. +// This is run in its own goroutine (it does network I/O in +// transport.NewServerTransport). func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo) if err != nil { @@ -299,6 +326,52 @@ func (s *Server) serveStreams(st transport.ServerTransport) { wg.Wait() } +var _ http.Handler = (*Server)(nil) + +// serveUsingHandler is the implementation of Serve(net.Listener) when +// TestingUseHandlerImpl has been configured. This lets the end2end +// tests exercise the ServeHTTP method as one of the environment types. +// +// conn is the *tls.Conn that's already been authenticated. +func (s *Server) serveUsingHandler(listenerAddr net.Addr, conn net.Conn) { + if !s.addConn(conn) { + conn.Close() + return + } + defer s.removeConn(conn) + connDone := make(chan struct{}) + hs := &http.Server{ + Handler: s, + ConnState: func(c net.Conn, cs http.ConnState) { + if cs == http.StateClosed { + close(connDone) + } + }, + } + if err := http2.ConfigureServer(hs, &http2.Server{ + MaxConcurrentStreams: s.opts.maxConcurrentStreams, + }); err != nil { + grpclog.Fatalf("grpc: http2.ConfigureServer: %v", err) + return + } + hs.Serve(&singleConnListener{addr: listenerAddr, conn: conn}) + <-connDone +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + st, err := transport.NewServerHandlerTransport(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !s.addConn(st) { + st.Close() + return + } + defer s.removeConn(st) + s.serveStreams(st) +} + // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. // If tracing is not enabled, it returns nil. func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { @@ -317,21 +390,21 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea return trInfo } -func (s *Server) addConn(st transport.ServerTransport) bool { +func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() if s.conns == nil { return false } - s.conns[st] = true + s.conns[c] = true return true } -func (s *Server) removeConn(st transport.ServerTransport) { +func (s *Server) removeConn(c io.Closer) { s.mu.Lock() defer s.mu.Unlock() if s.conns != nil { - delete(s.conns, st) + delete(s.conns, c) } } @@ -603,12 +676,14 @@ func (s *Server) Stop() { cs := s.conns s.conns = nil s.mu.Unlock() + for lis := range listeners { lis.Close() } for c := range cs { c.Close() } + s.mu.Lock() if s.events != nil { s.events.Finish() @@ -618,16 +693,24 @@ func (s *Server) Stop() { } // TestingCloseConns closes all exiting transports but keeps s.lis accepting new -// connections. This is for test only now. +// connections. +// This is only for tests and is subject to removal. func (s *Server) TestingCloseConns() { s.mu.Lock() for c := range s.conns { c.Close() + delete(s.conns, c) } - s.conns = make(map[transport.ServerTransport]bool) s.mu.Unlock() } +// TestingUseHandlerImpl enables the http.Handler-based server implementation. +// It must be called before Serve and requires TLS credentials. +// This is only for tests and is subject to removal. +func (s *Server) TestingUseHandlerImpl() { + s.opts.useHandlerImpl = true +} + // SendHeader sends header metadata. It may be called at most once from a unary // RPC handler. The ctx is the RPC handler's Context or one derived from it. func SendHeader(ctx context.Context, md metadata.MD) error { @@ -658,3 +741,30 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { } return stream.SetTrailer(md) } + +// singleConnListener is a net.Listener that yields a single conn. +type singleConnListener struct { + mu sync.Mutex + addr net.Addr + conn net.Conn // nil if done +} + +func (ln *singleConnListener) Addr() net.Addr { return ln.addr } + +func (ln *singleConnListener) Close() error { + ln.mu.Lock() + defer ln.mu.Unlock() + ln.conn = nil + return nil +} + +func (ln *singleConnListener) Accept() (net.Conn, error) { + ln.mu.Lock() + defer ln.mu.Unlock() + c := ln.conn + if c == nil { + return nil, io.EOF + } + ln.conn = nil + return c, nil +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 4abe16592607..203d4e5b64d0 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -34,12 +34,15 @@ package grpc_test import ( + "flag" "fmt" "io" "math" "net" "reflect" "runtime" + "sort" + "strings" "sync" "syscall" "testing" @@ -58,13 +61,21 @@ import ( ) var ( + // For headers: testMetadata = metadata.MD{ "key1": []string{"value1"}, "key2": []string{"value2"}, } + // For trailers: + testTrailerMetadata = metadata.MD{ + "tkey1": []string{"trailerValue1"}, + "tkey2": []string{"trailerValue2"}, + } testAppUA = "myApp1/1.0 myApp2/0.9" ) +var raceMode bool // set by race_test.go in race mode + type testServer struct { security string // indicate the authentication protocol used by this server. } @@ -74,7 +85,7 @@ func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.E // For testing purpose, returns an error if there is attached metadata other than // the user agent set by the client application. if _, ok := md["user-agent"]; !ok { - return nil, grpc.Errorf(codes.DataLoss, "got extra metadata") + return nil, grpc.Errorf(codes.DataLoss, "missing expected user-agent") } var str []string for _, entry := range md["user-agent"] { @@ -109,7 +120,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if err := grpc.SendHeader(ctx, md); err != nil { return nil, fmt.Errorf("grpc.SendHeader(%v, %v) = %v, want %v", ctx, md, err, nil) } - grpc.SetTrailer(ctx, md) + grpc.SetTrailer(ctx, testTrailerMetadata) } pr, ok := peer.FromContext(ctx) if !ok { @@ -267,6 +278,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ const tlsDir = "testdata/" func TestReconnectTimeout(t *testing.T) { + defer leakCheck(t)() lis, err := net.Listen("tcp", ":0") if err != nil { t.Fatalf("Failed to listen: %v", err) @@ -317,19 +329,41 @@ func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { } type env struct { - network string // The type of network such as tcp, unix, etc. - dialer func(addr string, timeout time.Duration) (net.Conn, error) - security string // The security protocol such as TLS, SSH, etc. + name string + network string // The type of network such as tcp, unix, etc. + dialer func(addr string, timeout time.Duration) (net.Conn, error) + security string // The security protocol such as TLS, SSH, etc. + httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS } +var ( + tcpClearEnv = env{name: "tcp-clear", network: "tcp"} + tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} + unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer} + unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"} + handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} + allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} +) + +var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.") + func listTestEnv() []env { + if *onlyEnv != "" { + for _, e := range allEnv { + if e.name == *onlyEnv { + return []env{e} + } + } + panic(fmt.Sprintf("invalid --only_env value %q", *onlyEnv)) + } if runtime.GOOS == "windows" { - return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}} + return []env{tcpClearEnv, tcpTLSEnv, handlerEnv} } - return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}} + return allEnv } func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) { + t.Logf("Running test in %s environment...", e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)} la := ":0" switch e.network { @@ -349,6 +383,9 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u sopts = append(sopts, grpc.Creds(creds)) } s = grpc.NewServer(sopts...) + if e.httpHandler { + s.TestingUseHandlerImpl() + } if hs != nil { healthpb.RegisterHealthServer(s, hs) } @@ -392,6 +429,7 @@ func tearDown(s *grpc.Server, cc *grpc.ClientConn) { } func TestTimeoutOnDeadServer(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testTimeoutOnDeadServer(t, e) } @@ -434,8 +472,8 @@ func testTimeoutOnDeadServer(t *testing.T, e env) { cc.Close() } -func healthCheck(t time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { - ctx, _ := context.WithTimeout(context.Background(), t) +func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { + ctx, _ := context.WithTimeout(context.Background(), d) hc := healthpb.NewHealthClient(cc) req := &healthpb.HealthCheckRequest{ Service: serviceName, @@ -444,6 +482,7 @@ func healthCheck(t time.Duration, cc *grpc.ClientConn, serviceName string) (*hea } func TestHealthCheckOnSuccess(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testHealthCheckOnSuccess(t, e) } @@ -461,6 +500,7 @@ func testHealthCheckOnSuccess(t *testing.T, e env) { } func TestHealthCheckOnFailure(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testHealthCheckOnFailure(t, e) } @@ -478,6 +518,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) { } func TestHealthCheckOff(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testHealthCheckOff(t, e) } @@ -487,12 +528,13 @@ func testHealthCheckOff(t *testing.T, e env) { s, addr := serverSetUp(t, true, nil, math.MaxUint32, nil, nil, e) cc := clientSetUp(t, addr, nil, nil, "", e) defer tearDown(s, cc) - if _, err := healthCheck(1*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.Health") { + if _, err := healthCheck(5*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.Health") { t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented) } } func TestHealthCheckServingStatus(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testHealthCheckServingStatus(t, e) } @@ -533,6 +575,7 @@ func testHealthCheckServingStatus(t *testing.T, e env) { } func TestEmptyUnaryWithUserAgent(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testEmptyUnaryWithUserAgent(t, e) } @@ -577,6 +620,7 @@ func testEmptyUnaryWithUserAgent(t *testing.T, e env) { } func TestFailedEmptyUnary(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testFailedEmptyUnary(t, e) } @@ -588,12 +632,14 @@ func testFailedEmptyUnary(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) ctx := metadata.NewContext(context.Background(), testMetadata) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, grpc.Errorf(codes.DataLoss, "got extra metadata")) + wantErr := grpc.Errorf(codes.DataLoss, "missing expected user-agent") + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != wantErr { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } } func TestLargeUnary(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testLargeUnary(t, e) } @@ -629,6 +675,7 @@ func testLargeUnary(t *testing.T, e env) { } func TestMetadataUnaryRPC(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testMetadataUnaryRPC(t, e) } @@ -657,11 +704,17 @@ func testMetadataUnaryRPC(t *testing.T, e env) { if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.Trailer(&trailer)); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } - if !reflect.DeepEqual(testMetadata, header) { + + // Ignore optional response headers that Servers may set: + if header != nil { + delete(header, "trailer") // RFC 2616 says server SHOULD (but optional) declare trailers + delete(header, "date") // the Date header is also optional + } + if !reflect.DeepEqual(header, testMetadata) { t.Fatalf("Received header metadata %v, want %v", header, testMetadata) } - if !reflect.DeepEqual(testMetadata, trailer) { - t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata) + if !reflect.DeepEqual(trailer, testTrailerMetadata) { + t.Fatalf("Received trailer metadata %v, want %v", trailer, testTrailerMetadata) } } @@ -695,6 +748,7 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup } func TestRetry(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testRetry(t, e) } @@ -709,9 +763,24 @@ func testRetry(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) var wg sync.WaitGroup + + numRPC := 1000 + rpcSpacing := 2 * time.Millisecond + if raceMode { + // The race detector has a limit on how many goroutines it can track. + // This test is near the upper limit, and goes over the limit + // depending on the environment (the http.Handler environment uses + // more goroutines) + t.Logf("Shortening test in race mode.") + numRPC /= 2 + rpcSpacing *= 2 + } + wg.Add(1) go func() { - time.Sleep(1 * time.Second) + // Halfway through starting RPCs, kill all connections: + time.Sleep(time.Duration(numRPC/2) * rpcSpacing) + // The server shuts down the network connection to make a // transport error which will be detected by the client side // code. @@ -719,8 +788,8 @@ func testRetry(t *testing.T, e env) { wg.Done() }() // All these RPCs should succeed eventually. - for i := 0; i < 1000; i++ { - time.Sleep(2 * time.Millisecond) + for i := 0; i < numRPC; i++ { + time.Sleep(rpcSpacing) wg.Add(1) go performOneRPC(t, tc, &wg) } @@ -728,6 +797,7 @@ func testRetry(t *testing.T, e env) { } func TestRPCTimeout(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testRPCTimeout(t, e) } @@ -762,6 +832,7 @@ func testRPCTimeout(t *testing.T, e env) { } func TestCancel(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testCancel(t, e) } @@ -794,6 +865,7 @@ func testCancel(t *testing.T, e env) { } func TestCancelNoIO(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testCancelNoIO(t, e) } @@ -847,6 +919,7 @@ var ( ) func TestNoService(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testNoService(t, e) } @@ -858,8 +931,10 @@ func testNoService(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) // Make sure setting ack has been sent. - time.Sleep(2 * time.Second) - stream, err := tc.FullDuplexCall(context.Background()) + time.Sleep(20 * time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream, err := tc.FullDuplexCall(ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } @@ -869,6 +944,7 @@ func testNoService(t *testing.T, e env) { } func TestPingPong(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testPingPong(t, e) } @@ -927,6 +1003,7 @@ func testPingPong(t *testing.T, e env) { } func TestMetadataStreamingRPC(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testMetadataStreamingRPC(t, e) } @@ -994,6 +1071,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) { } func TestServerStreaming(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testServerStreaming(t, e) } @@ -1047,6 +1125,7 @@ func testServerStreaming(t *testing.T, e env) { } func TestFailedServerStreaming(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testFailedServerStreaming(t, e) } @@ -1078,6 +1157,7 @@ func testFailedServerStreaming(t *testing.T, e env) { } func TestClientStreaming(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testClientStreaming(t, e) } @@ -1118,6 +1198,7 @@ func testClientStreaming(t *testing.T, e env) { } func TestExceedMaxStreamsLimit(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testExceedMaxStreamsLimit(t, e) } @@ -1129,13 +1210,16 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) { cc := clientSetUp(t, addr, nil, nil, "", e) tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) - _, err := tc.StreamingInputCall(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := tc.StreamingInputCall(ctx) if err != nil { t.Fatalf("%v.StreamingInputCall(_) = _, %v, want _, ", tc, err) } // Loop until receiving the new max stream setting from the server. for { - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() _, err := tc.StreamingInputCall(ctx) if err == nil { time.Sleep(time.Second) @@ -1149,6 +1233,7 @@ func testExceedMaxStreamsLimit(t *testing.T, e env) { } func TestCompressServerHasNoSupport(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testCompressServerHasNoSupport(t, e) } @@ -1202,6 +1287,7 @@ func testCompressServerHasNoSupport(t *testing.T, e env) { } func TestCompressOK(t *testing.T) { + defer leakCheck(t)() for _, e := range listTestEnv() { testCompressOK(t, e) } @@ -1228,10 +1314,12 @@ func testCompressOK(t *testing.T, e env) { t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) } // Streaming RPC - stream, err := tc.FullDuplexCall(context.Background()) + ctx, cancel := context.WithCancel(context.Background()) + stream, err := tc.FullDuplexCall(ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) } + defer cancel() respParam := []*testpb.ResponseParameters{ { Size: proto.Int32(31415), @@ -1253,3 +1341,61 @@ func testCompressOK(t *testing.T, e env) { t.Fatalf("%v.Recv() = %v, want ", stream, err) } } + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if strings.HasPrefix(stack, "testing.RunTests") { + continue + } + + if stack == "" || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "testing.Main(") || + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "interestingGoroutines") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, g) + } + sort.Strings(gs) + return +} + +func leakCheck(t testing.TB) func() { + orig := map[string]bool{} + for _, g := range interestingGoroutines() { + orig[g] = true + } + return func() { + t0 := time.Now() + for { + var leaked []string + for _, g := range interestingGoroutines() { + if !orig[g] { + leaked = append(leaked, g) + } + } + if len(leaked) == 0 { + return + } + if time.Now().Before(t0.Add(5 * time.Second)) { + time.Sleep(50 * time.Millisecond) + continue + } + for _, g := range leaked { + t.Errorf("Leaked goroutine: %v", g) + } + return + } + } +} diff --git a/test/race_test.go b/test/race_test.go new file mode 100644 index 000000000000..b3a7056c66ba --- /dev/null +++ b/test/race_test.go @@ -0,0 +1,39 @@ +// +build race + +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc_test + +func init() { + raceMode = true +} diff --git a/testdata/dir.go b/testdata/dir.go new file mode 100644 index 000000000000..5adaca5db182 --- /dev/null +++ b/testdata/dir.go @@ -0,0 +1,70 @@ +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package testdata + +import ( + "log" + "os" + "path/filepath" +) + +// Dir returns the path to the grpc testdata directory or fails. +func Dir() string { + v, err := goPackagePath("google.golang.org/grpc/testdata") + if err != nil { + log.Fatalf("Error finding google.golang.org/grpc/testdata directory: %v", err) + } + return v +} + +func goPackagePath(pkg string) (path string, err error) { + gp := os.Getenv("GOPATH") + if gp == "" { + return path, os.ErrNotExist + } + for _, p := range filepath.SplitList(gp) { + dir := filepath.Join(p, "src", filepath.FromSlash(pkg)) + fi, err := os.Stat(dir) + if os.IsNotExist(err) { + continue + } + if err != nil { + return "", err + } + if !fi.IsDir() { + continue + } + return dir, nil + } + return path, os.ErrNotExist +} diff --git a/transport/handler_server.go b/transport/handler_server.go new file mode 100644 index 000000000000..71f3a5821008 --- /dev/null +++ b/transport/handler_server.go @@ -0,0 +1,319 @@ +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// This file is the implementation of a gRPC server using HTTP/2 which +// uses the standard Go http2 Server implementation (via the +// http.Handler interface), rather than speaking low-level HTTP/2 +// frames itself. It is the implementation of *grpc.Server.ServeHTTP. + +package transport + +import ( + "errors" + "fmt" + "net" + "net/http" + "strings" + "sync" + "time" + + "golang.org/x/net/context" + "golang.org/x/net/http2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +// NewServerHandlerTransport returns a ServerTransport handling gRPC +// from inside an http.Handler. It requires that the http Server +// supports HTTP/2. +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) { + if r.ProtoMajor != 2 { + return nil, errors.New("gRPC requires HTTP/2") + } + if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + return nil, errors.New("transport: invalid request content-type") + } + if _, ok := w.(http.Flusher); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher") + } + if _, ok := w.(http.CloseNotifier); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier") + } + + st := &serverHandlerTransport{ + w: w, + r: r, + closedCh: make(chan struct{}), + wroteStatus: make(chan struct{}), + } + + if v := r.Header.Get("grpc-timeout"); v != "" { + to, err := timeoutDecode(v) + if err != nil { + return nil, StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err) + } + st.timeoutSet = true + st.timeout = to + } + + var metakv []string + for k, vv := range r.Header { + k = strings.ToLower(k) + if isReservedHeader(k) { + continue + } + for _, v := range vv { + if k == "user-agent" { + // user-agent is special. Copying logic of http_util.go. + if i := strings.LastIndex(v, " "); i == -1 { + // There is no application user agent string being set + continue + } else { + v = v[:i] + } + } + metakv = append(metakv, k, v) + + } + } + st.headerMD = metadata.Pairs(metakv...) + + return st, nil +} + +// serverHandlerTransport is an implementation of ServerTransport +// which replies to exactly one gRPC request (exactly one HTTP request), +// using the net/http.Handler interface. This http.Handler is guranteed +// at this point to be speaking over HTTP/2, so it's able to speak valid +// gRPC. +type serverHandlerTransport struct { + w http.ResponseWriter + r *http.Request + timeoutSet bool + timeout time.Duration + didCommonHeaders bool + + headerMD metadata.MD + + closeOnce sync.Once + closedCh chan struct{} // closed on Close + + wroteStatus chan struct{} // closed on WriteStatus + statusCode codes.Code // WriteStatus code; set before wroteStatus +} + +func (ht *serverHandlerTransport) Close() error { + ht.closeOnce.Do(ht.closeCloseChanOnce) + return nil +} + +func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) } + +func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.r.RemoteAddr) } + +type strAddr string + +func (a strAddr) Network() string { + if a != "" { + return "tcp" + } + return "" +} + +func (a strAddr) String() string { return string(a) } + +func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { + ht.writeCommonHeaders(s) + h := ht.w.Header() + h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) + if statusDesc != "" { + h.Set("Grpc-Message", statusDesc) + } + if md := s.Trailer(); len(md) > 0 { + for k, vv := range md { + for _, v := range vv { + // http2 ResponseWriter mechanism to + // send undeclared Trailers after the + // headers have possibly been written. + h.Add(http2.TrailerPrefix+k, v) + } + } + } + ht.statusCode = statusCode + close(ht.wroteStatus) + return nil +} + +// writeCommonHeaders sets common headers on the first call to Write or WriteHeader. +func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { + if ht.didCommonHeaders { + return + } + ht.didCommonHeaders = true + + h := ht.w.Header() + h["Date"] = nil // suppress Date to make tests happy; TODO: restore + h.Set("Content-Type", "application/grpc") + h.Add("Trailer", "Grpc-Status") + h.Add("Trailer", "Grpc-Message") + if s.sendCompress != "" { + h.Set("Grpc-Encoding", s.sendCompress) + } +} + +func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error { + ht.writeCommonHeaders(s) + ht.w.Write(data) + if !opts.Delay { + ht.w.(http.Flusher).Flush() + } + return nil +} + +func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { + ht.writeCommonHeaders(s) + h := ht.w.Header() + for k, vv := range md { + for _, v := range vv { + h.Add(k, v) + } + } + ht.w.WriteHeader(200) + ht.w.(http.Flusher).Flush() + return nil +} + +func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) { + // With this transport type there will be exactly 1 stream: this HTTP request. + + var ctx context.Context + var cancel context.CancelFunc + if ht.timeoutSet { + ctx, cancel = context.WithTimeout(context.Background(), ht.timeout) + } else { + ctx, cancel = context.WithCancel(context.Background()) + } + clientGone := ht.w.(http.CloseNotifier).CloseNotify() + go func() { + select { + case <-ht.closedCh: + case <-clientGone: + } + cancel() + }() + + r := ht.r + + s := &Stream{ + id: 0, // irrelevant + windowHandler: func(int) {}, // nothing + cancel: cancel, + buf: newRecvBuffer(), + st: ht, + method: r.URL.Path, + recvCompress: r.Header.Get("grpc-encoding"), + } + pr := &peer.Peer{ + Addr: ht.RemoteAddr(), + } + if ht.r.TLS != nil { + pr.AuthInfo = credentials.TLSInfo{*ht.r.TLS} + } + ctx = metadata.NewContext(ctx, ht.headerMD) + ctx = peer.NewContext(ctx, pr) + s.ctx = newContextWithStream(ctx, s) + s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf} + + rpcOver := make(chan struct{}) + + readerDone := make(chan struct{}) + go func() { + defer close(readerDone) + for { + buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership + n, err := r.Body.Read(buf) + select { + case <-rpcOver: + return + default: + } + if n > 0 { + s.buf.put(&recvMsg{data: buf[:n]}) + } + if err != nil { + s.buf.put(&recvMsg{err: err}) + break + } + } + }() + + runStream(s) + + var bodyTimer *time.Timer + + // Wait for either the RPC to be over, or for + // the RPC to have gotten its status written. + select { + case <-ctx.Done(): + case <-ht.wroteStatus: + switch ht.statusCode { + case codes.Unimplemented, codes.InvalidArgument: + // For these two error codes, sometimes the client + // closes the body, and sometimes it doesn't. We can't + // just always close it here, because then if the + // client does still happen to be writing, they get a + // RST_STREAM the client doesn't understand. + // + // TODO(bradfitz): fix the tests and/or client code to + // be consistent. I expect this to happen with the + // switch to the x/net/http2 Transport, actually. + // + // But this hack works in the meantime, forcing the + // body closed if it doesn't close itself soon. In + // practice this only needs to be about 10 + // milliseconds, but we'll be conservative. + // Plus this only impacts callers already doing buggy + // stuff, so the time doesn't really matter. + bodyTimer = time.AfterFunc(250*time.Millisecond, func() { + r.Body.Close() + }) + } + } + close(rpcOver) + + <-readerDone + if bodyTimer != nil { + bodyTimer.Stop() + } +} diff --git a/transport/http_util.go b/transport/http_util.go index f9d9fdf0afdc..62710862ea0d 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -120,7 +120,7 @@ type headerFrame interface { // reserved by gRPC protocol. Any other headers are classified as the // user-specified metadata. func isReservedHeader(hdr string) bool { - if hdr[0] == ':' { + if hdr != "" && hdr[0] == ':' { return true } switch hdr { @@ -130,6 +130,7 @@ func isReservedHeader(hdr string) bool { "grpc-message", "grpc-status", "grpc-timeout", + "trailer", "te": return true default: