diff --git a/examples/helloworld/greeter_client/main.go b/examples/helloworld/greeter_client/main.go index 1e02ab154da3..af5a510652b4 100644 --- a/examples/helloworld/greeter_client/main.go +++ b/examples/helloworld/greeter_client/main.go @@ -34,12 +34,14 @@ package main import ( + "crypto/tls" + "flag" "log" - "os" - pb "google.golang.org/grpc/examples/helloworld/helloworld" "golang.org/x/net/context" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + pb "google.golang.org/grpc/examples/helloworld/helloworld" ) const ( @@ -47,9 +49,19 @@ const ( defaultName = "world" ) +var withInsecureTLS = flag.Bool("insecure_tls", false, "Use Insecure TLS; suitable for hitting the greeter_server in -use_http mode") + func main() { + flag.Parse() + // Set up a connection to the server. - conn, err := grpc.Dial(address, grpc.WithInsecure()) + var opts []grpc.DialOption + if *withInsecureTLS { + opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: true}))) + } else { + opts = append(opts, grpc.WithInsecure()) + } + conn, err := grpc.Dial(address, opts...) if err != nil { log.Fatalf("did not connect: %v", err) } @@ -58,8 +70,8 @@ func main() { // Contact the server and print out its response. name := defaultName - if len(os.Args) > 1 { - name = os.Args[1] + if flag.NArg() > 0 { + name = flag.Arg(0) } r, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: name}) if err != nil { diff --git a/examples/helloworld/greeter_server/main.go b/examples/helloworld/greeter_server/main.go index 66010a512060..53733fbea802 100644 --- a/examples/helloworld/greeter_server/main.go +++ b/examples/helloworld/greeter_server/main.go @@ -34,16 +34,21 @@ package main import ( + "flag" "log" "net" + "net/http" + "path/filepath" - pb "google.golang.org/grpc/examples/helloworld/helloworld" "golang.org/x/net/context" "google.golang.org/grpc" + pb "google.golang.org/grpc/examples/helloworld/helloworld" + "google.golang.org/grpc/testdata" ) -const ( - port = ":50051" +var ( + listen = flag.String("listen", "localhost:50051", "address to listen on") + httpMode = flag.Bool("use_http", false, "Use net/http integration mode; enables TLS and requires -insecure_tls on the client side") ) // server is used to implement helloworld.GreeterServer. @@ -55,11 +60,25 @@ func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloRe } func main() { - lis, err := net.Listen("tcp", port) - if err != nil { - log.Fatalf("failed to listen: %v", err) - } + flag.Parse() + s := grpc.NewServer() pb.RegisterGreeterServer(s, &server{}) - s.Serve(lis) + + log.Printf("Running hello server on %s ...", *listen) + if *httpMode { + // Running a gRPC server when you need to integrate with an existing + // net/http server on the same port: (using HTTP routing as needed) + http.Handle("/", s) + log.Fatal(http.ListenAndServeTLS(*listen, file("server1.pem"), file("server1.key"), http.DefaultServeMux)) + } else { + // Running a gRPC server on its own port, without net/http integration: + lis, err := net.Listen("tcp", *listen) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + log.Fatal(s.Serve(lis)) + } } + +func file(base string) string { return filepath.Join(testdata.Dir(), base) } diff --git a/examples/route_guide/server/server.go b/examples/route_guide/server/server.go index 09b3942d191d..3c524f7e337a 100644 --- a/examples/route_guide/server/server.go +++ b/examples/route_guide/server/server.go @@ -43,8 +43,10 @@ import ( "fmt" "io" "io/ioutil" + "log" "math" "net" + "net/http" "time" "golang.org/x/net/context" @@ -59,7 +61,8 @@ import ( ) var ( - tls = flag.Bool("tls", false, "Connection uses TLS if true, else plain TCP") + useTLS = flag.Bool("tls", false, "Connection uses TLS if true, else plain TCP") + useHTTP = flag.Bool("http", false, "Use the ServeHTTP Transport; requires tls") certFile = flag.String("cert_file", "testdata/server1.pem", "The TLS cert file") keyFile = flag.String("key_file", "testdata/server1.key", "The TLS key file") jsonDBFile = flag.String("json_db_file", "testdata/route_guide_db.json", "A json file containing a list of features") @@ -221,12 +224,11 @@ func newServer() *routeGuideServer { func main() { flag.Parse() - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) - if err != nil { - grpclog.Fatalf("failed to listen: %v", err) + if *useHTTP && !*useTLS { + log.Fatalf("-http flag requires -tls") } var opts []grpc.ServerOption - if *tls { + if *useTLS { creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) if err != nil { grpclog.Fatalf("Failed to generate credentials %v", err) @@ -235,5 +237,15 @@ func main() { } grpcServer := grpc.NewServer(opts...) pb.RegisterRouteGuideServer(grpcServer, newServer()) - grpcServer.Serve(lis) + log.Printf("Listening on port %d ...", *port) + if *useHTTP { + http.Handle("/", grpcServer) + log.Fatal(http.ListenAndServeTLS(fmt.Sprintf(":%d", *port), *certFile, *keyFile, nil)) + } else { + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) + if err != nil { + grpclog.Fatalf("failed to listen: %v", err) + } + grpcServer.Serve(lis) + } } diff --git a/rpc_util.go b/rpc_util.go index e98ddbcdc5a7..fadf3394d6ca 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -273,7 +273,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er case compressionNone: case compressionMade: if recvCompress == "" { - return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf) + return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress) } if dc == nil || recvCompress != dc.Type() { return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) diff --git a/server.go b/server.go index 1c42b6eff278..577c031cc7a0 100644 --- a/server.go +++ b/server.go @@ -39,6 +39,7 @@ import ( "fmt" "io" "net" + "net/http" "reflect" "runtime" "strings" @@ -46,6 +47,7 @@ import ( "time" "golang.org/x/net/context" + "golang.org/x/net/http2" "golang.org/x/net/trace" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -82,10 +84,12 @@ type service struct { // Server is a gRPC server to serve RPC requests. type Server struct { - opts options - mu sync.Mutex + opts options + + mu sync.Mutex // guards following lis map[net.Listener]bool conns map[transport.ServerTransport]bool + hconns map[net.Conn]bool // handler conns; used only for http.Handler-based transprot m map[string]*service // service name -> service info events trace.EventLog } @@ -96,6 +100,7 @@ type options struct { cp Compressor dc Decompressor maxConcurrentStreams uint32 + useHandlerImpl bool // use http.Handler-based server } // A ServerOption sets options. @@ -216,9 +221,17 @@ var ( ErrServerStopped = errors.New("grpc: the server has been stopped") ) +func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + creds, ok := s.opts.creds.(credentials.TransportAuthenticator) + if !ok { + return rawConn, nil, nil + } + return creds.ServerHandshake(rawConn) +} + // Serve accepts incoming connections on the listener lis, creating a new // ServerTransport and service goroutine for each. The service goroutines -// read gRPC request and then call the registered handlers to reply to them. +// read gRPC requests and then call the registered handlers to reply to them. // Service returns when lis.Accept fails. func (s *Server) Serve(lis net.Listener) error { s.mu.Lock() @@ -235,39 +248,47 @@ func (s *Server) Serve(lis net.Listener) error { delete(s.lis, lis) s.mu.Unlock() }() + if s.opts.useHandlerImpl { + return s.serveUsingHandler(lis) + } for { - c, err := lis.Accept() + rawConn, err := lis.Accept() if err != nil { s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock() return err } - var authInfo credentials.AuthInfo - if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok { - var conn net.Conn - conn, authInfo, err = creds.ServerHandshake(c) - if err != nil { - s.mu.Lock() - s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err) - s.mu.Unlock() - grpclog.Println("grpc: Server.Serve failed to complete security handshake.") - continue - } - c = conn + conn, authInfo, err := s.useTransportAuthenticator(rawConn) + if err != nil { + s.handleFailedConnAuthentication(rawConn, err) + continue } s.mu.Lock() if s.conns == nil { s.mu.Unlock() - c.Close() + conn.Close() return nil } s.mu.Unlock() - go s.serveNewHTTP2Transport(c, authInfo) + go s.serveNewHTTP2Transport(conn, authInfo) } } +func (s *Server) handleFailedConnAuthentication(rawConn net.Conn, err error) { + s.mu.Lock() + s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) + s.mu.Unlock() + grpclog.Println("grpc: Server.Serve failed to complete security handshake.") + rawConn.Close() +} + +// serveNewHTTP2Transport sets up a new http/2 transport (using the +// gRPC http2 server transport in transport/http2_server.go) and +// serves streams on it. +// This is run in its own goroutine (it does network I/O in +// transport.NewServerTransport). func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo) if err != nil { @@ -299,6 +320,135 @@ func (s *Server) serveStreams(st transport.ServerTransport) { wg.Wait() } +var _ http.Handler = (*Server)(nil) + +// serveUsingHandler is the implementation of Serve(net.Listener) when +// TestingUseHandlerImpl has been configured. This lets the end2end +// tests exercise the ServeHTTP method as one of the environment types. +func (s *Server) serveUsingHandler(lis net.Listener) error { + hs := &http.Server{ + Handler: s, + ConnState: s.onConnStateChange, + } + if err := http2.ConfigureServer(hs, &http2.Server{ + MaxConcurrentStreams: s.opts.maxConcurrentStreams, + }); err != nil { + return err + } + hlis := &handlerListener{ + s: s, + Listener: lis, + acceptc: make(chan interface{}, 1), + closedc: make(chan struct{}), + } + go hlis.acceptLoop() + return hs.Serve(hlis) +} + +// onConnStateChange is the net/http.Server.ConnState state change +// hook used by the http.Handler-based transport, to track which +// inbound TCP or TLS connections are live. Note that these are not +// ServerTransports. Each received HTTP request (each ServeHTTP call) +// is a ServerTransport for exactly 1 stream. This on the other hand +// tracks the underlying connections. +func (s *Server) onConnStateChange(c net.Conn, state http.ConnState) { + if state != http.StateNew && state != http.StateClosed { + // Ignore transitions between idle and active. + return + } + s.mu.Lock() + defer s.mu.Unlock() + if s.hconns == nil { + s.hconns = make(map[net.Conn]bool) + } + if state == http.StateNew { + s.hconns[c] = true + } else { + delete(s.hconns, c) + } +} + +type handlerListener struct { + s *Server + net.Listener // embedded for Addr + acceptc chan interface{} // of conn or error + + closedc chan struct{} // closed on close + closeOnce sync.Once + closeErr error +} + +func (hl *handlerListener) Close() error { + hl.closeOnce.Do(hl.closeOnceFunc) + return hl.closeErr +} + +func (hl *handlerListener) closeOnceFunc() { + hl.closeErr = hl.Listener.Close() + close(hl.closedc) +} + +func (hl *handlerListener) Accept() (net.Conn, error) { + select { + case v := <-hl.acceptc: + if c, ok := v.(net.Conn); ok { + return c, nil + } + return nil, v.(error) + case <-hl.closedc: + return nil, errors.New("listener closed") + } +} + +// acceptLoop runs in its own goroutine and accepts conns and sets up +// TLS, feeding successful connections to Accept. +func (hl *handlerListener) acceptLoop() { + for { + rawConn, err := hl.Listener.Accept() + if err != nil { + select { + case hl.acceptc <- err: + case <-hl.closedc: + return + } + continue + } + go hl.authenticateConn(rawConn) + } +} + +// authenticateConn runs in a goroutine separate from the handlerListener's accept loop +// and sets up TLS (or whatever the TransportAuthenticator does) and sends successfully upgraded +// Conns along the channel for Accept to return. +func (hl *handlerListener) authenticateConn(rawConn net.Conn) { + // Discarding authInfo because it's just the *tls.Conn's + // ConnectionState, which we can recover later in ServeHTTP. + conn, _, err := hl.s.useTransportAuthenticator(rawConn) + if err != nil { + hl.s.handleFailedConnAuthentication(rawConn, err) + return + } + select { + case hl.acceptc <- conn: + case <-hl.closedc: + conn.Close() + } +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + st, err := transport.NewServerHandlerTransport(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !s.addConn(st) { + st.Close() + return + } + defer s.removeConn(st) + s.serveStreams(st) +} + // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. // If tracing is not enabled, it returns nil. func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { @@ -602,13 +752,18 @@ func (s *Server) Stop() { s.lis = nil cs := s.conns s.conns = nil + for c := range s.hconns { + go c.Close() + } s.mu.Unlock() + for lis := range listeners { lis.Close() } for c := range cs { c.Close() } + s.mu.Lock() if s.events != nil { s.events.Finish() @@ -618,7 +773,8 @@ func (s *Server) Stop() { } // TestingCloseConns closes all exiting transports but keeps s.lis accepting new -// connections. This is for test only now. +// connections. +// This is only for tests and is subject to removal. func (s *Server) TestingCloseConns() { s.mu.Lock() for c := range s.conns { @@ -628,6 +784,13 @@ func (s *Server) TestingCloseConns() { s.mu.Unlock() } +// TestingUseHandlerImpl enables the http.Handler-based server implementation. +// It must be called before Serve and requires TLS credentials. +// This is only for tests and is subject to removal. +func (s *Server) TestingUseHandlerImpl() { + s.opts.useHandlerImpl = true +} + // SendHeader sends header metadata. It may be called at most once from a unary // RPC handler. The ctx is the RPC handler's Context or one derived from it. func SendHeader(ctx context.Context, md metadata.MD) error { diff --git a/test/end2end_test.go b/test/end2end_test.go index d0e4ea2c73ea..89de613d8304 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -38,8 +38,10 @@ import ( "io" "math" "net" + "os" "reflect" "runtime" + "strconv" "sync" "syscall" "testing" @@ -58,10 +60,16 @@ import ( ) var ( + // For headers: testMetadata = metadata.MD{ "key1": []string{"value1"}, "key2": []string{"value2"}, } + // For trailers: + testTrailerMetadata = metadata.MD{ + "tkey1": []string{"trailerValue1"}, + "tkey2": []string{"trailerValue2"}, + } testAppUA = "myApp1/1.0 myApp2/0.9" ) @@ -74,7 +82,7 @@ func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.E // For testing purpose, returns an error if there is attached metadata other than // the user agent set by the client application. if _, ok := md["user-agent"]; !ok { - return nil, grpc.Errorf(codes.DataLoss, "got extra metadata") + return nil, grpc.Errorf(codes.DataLoss, "missing expected user-agent") } var str []string for _, entry := range md["user-agent"] { @@ -109,7 +117,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if err := grpc.SendHeader(ctx, md); err != nil { return nil, fmt.Errorf("grpc.SendHeader(%v, %v) = %v, want %v", ctx, md, err, nil) } - grpc.SetTrailer(ctx, md) + grpc.SetTrailer(ctx, testTrailerMetadata) } pr, ok := peer.FromContext(ctx) if !ok { @@ -315,19 +323,36 @@ func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { } type env struct { - network string // The type of network such as tcp, unix, etc. - dialer func(addr string, timeout time.Duration) (net.Conn, error) - security string // The security protocol such as TLS, SSH, etc. + name string + network string // The type of network such as tcp, unix, etc. + dialer func(addr string, timeout time.Duration) (net.Conn, error) + security string // The security protocol such as TLS, SSH, etc. + httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS } +var ( + tcpEnv = env{name: "TCP", network: "tcp"} + tlsEnv = env{name: "TLS", network: "tcp", security: "tls"} + unixSocketEnv = env{name: "Unix", network: "unix", dialer: unixDialer} + handlerEnv = env{name: "Handler", network: "tcp", security: "tls", httpHandler: true} +) + +// Environment hack to only test the http.Handler-based code paths. +var onlyHandler, _ = strconv.ParseBool(os.Getenv("GRPC_TEST_HANDLER_ONLY")) + func listTestEnv() []env { + if onlyHandler { + return []env{handlerEnv} + } if runtime.GOOS == "windows" { - return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}} + return []env{tcpEnv, tlsEnv} } - return []env{{"tcp", nil, ""}, {"tcp", nil, "tls"}, {"unix", unixDialer, ""}, {"unix", unixDialer, "tls"}} + return []env{tcpEnv} + return []env{tcpEnv, tlsEnv, handlerEnv, unixSocketEnv} } func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream uint32, cp grpc.Compressor, dc grpc.Decompressor, e env) (s *grpc.Server, addr string) { + t.Logf("Running test in %s environment...", e.name) sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream), grpc.RPCCompressor(cp), grpc.RPCDecompressor(dc)} la := ":0" switch e.network { @@ -347,6 +372,9 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u sopts = append(sopts, grpc.Creds(creds)) } s = grpc.NewServer(sopts...) + if e.httpHandler { + s.TestingUseHandlerImpl() + } if hs != nil { healthpb.RegisterHealthServer(s, hs) } @@ -586,8 +614,9 @@ func testFailedEmptyUnary(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) ctx := metadata.NewContext(context.Background(), testMetadata) - if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, grpc.Errorf(codes.DataLoss, "got extra metadata")) + wantErr := grpc.Errorf(codes.DataLoss, "missing expected user-agent") + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != wantErr { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) } } @@ -655,11 +684,17 @@ func testMetadataUnaryRPC(t *testing.T, e env) { if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.Trailer(&trailer)); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } - if !reflect.DeepEqual(testMetadata, header) { + + // Ignore optional response headers that Servers may set: + if header != nil { + delete(header, "trailer") // RFC 2616 says server SHOULD (but optional) declare trailers + delete(header, "date") // the Date header is also optional + } + if !reflect.DeepEqual(header, testMetadata) { t.Fatalf("Received header metadata %v, want %v", header, testMetadata) } - if !reflect.DeepEqual(testMetadata, trailer) { - t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata) + if !reflect.DeepEqual(trailer, testTrailerMetadata) { + t.Fatalf("Received trailer metadata %v, want %v", trailer, testTrailerMetadata) } } @@ -853,7 +888,7 @@ func testNoService(t *testing.T, e env) { tc := testpb.NewTestServiceClient(cc) defer tearDown(s, cc) // Make sure setting ack has been sent. - time.Sleep(2*time.Second) + time.Sleep(2 * time.Second) stream, err := tc.FullDuplexCall(context.Background()) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) diff --git a/testdata/dir.go b/testdata/dir.go new file mode 100644 index 000000000000..5adaca5db182 --- /dev/null +++ b/testdata/dir.go @@ -0,0 +1,70 @@ +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package testdata + +import ( + "log" + "os" + "path/filepath" +) + +// Dir returns the path to the grpc testdata directory or fails. +func Dir() string { + v, err := goPackagePath("google.golang.org/grpc/testdata") + if err != nil { + log.Fatalf("Error finding google.golang.org/grpc/testdata directory: %v", err) + } + return v +} + +func goPackagePath(pkg string) (path string, err error) { + gp := os.Getenv("GOPATH") + if gp == "" { + return path, os.ErrNotExist + } + for _, p := range filepath.SplitList(gp) { + dir := filepath.Join(p, "src", filepath.FromSlash(pkg)) + fi, err := os.Stat(dir) + if os.IsNotExist(err) { + continue + } + if err != nil { + return "", err + } + if !fi.IsDir() { + continue + } + return dir, nil + } + return path, os.ErrNotExist +} diff --git a/transport/handler_server.go b/transport/handler_server.go new file mode 100644 index 000000000000..8fbb3fa6b738 --- /dev/null +++ b/transport/handler_server.go @@ -0,0 +1,280 @@ +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// This file is the implementation of a gRPC server using HTTP/2 which +// uses the standard Go http2 Server implementation (via the +// http.Handler interface), rather than speaking low-level HTTP/2 +// frames itself. It is the implementation of *grpc.Server.ServeHTTP. + +package transport + +import ( + "errors" + "fmt" + "net" + "net/http" + "strings" + "sync" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +// NewServerHandlerTransport returns a ServerTransport handling gRPC +// from inside an http.Handler. It requires that the http Server +// supports HTTP/2. +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) { + if r.ProtoMajor != 2 { + return nil, errors.New("gRPC requires HTTP/2") + } + if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + return nil, errors.New("transport: invalid request content-type") + } + if _, ok := w.(http.Flusher); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher") + } + if _, ok := w.(http.CloseNotifier); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier") + } + + st := &serverHandlerTransport{ + w: w, + r: r, + closedCh: make(chan struct{}), + wroteStatus: make(chan struct{}, 1), + } + + if v := r.Header.Get("grpc-timeout"); v != "" { + to, err := timeoutDecode(v) + if err != nil { + return nil, StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err) + } + st.timeoutSet = true + st.timeout = to + } + + var metakv []string + for k, vv := range r.Header { + k = strings.ToLower(k) + if isReservedHeader(k) { + continue + } + for _, v := range vv { + if k == "user-agent" { + // user-agent is special. Copying logic of http_util.go. + if i := strings.LastIndex(v, " "); i == -1 { + // There is no application user agent string being set + continue + } else { + v = v[:i] + } + } + metakv = append(metakv, k, v) + } + } + st.headerMD = metadata.Pairs(metakv...) + + return st, nil +} + +type serverHandlerTransport struct { + w http.ResponseWriter + r *http.Request + timeoutSet bool + timeout time.Duration + didCommonHeaders bool + + headerMD metadata.MD + + closeOnce sync.Once + closedCh chan struct{} // closed on Close + + wroteStatus chan struct{} +} + +func (ht *serverHandlerTransport) Close() error { + ht.closeOnce.Do(ht.closeCloseChanOnce) + return nil +} + +func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) } + +func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.r.RemoteAddr) } + +type strAddr string + +func (a strAddr) Network() string { + if a != "" { + return "tcp" + } + return "" +} + +func (a strAddr) String() string { return string(a) } + +func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { + ht.writeCommonHeaders(s) + h := ht.w.Header() + h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) + if statusDesc != "" { + h.Set("Grpc-Message", statusDesc) + } + if md := s.Trailer(); len(md) > 0 { + for k, vv := range md { + for _, v := range vv { + // http2 ResponseWriter mechanism to + // send undeclared Trailers after the + // headers have been written (and thus + // the trailer field name isn't in the + // "trailers" Header): + h.Add("Trailer\x00"+k, v) + } + } + } + ht.wroteStatus <- struct{}{} + return nil +} + +// writeCommonHeaders sets common headers on the first call to Write or WriteHeader. +func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { + if ht.didCommonHeaders { + return + } + ht.didCommonHeaders = true + + h := ht.w.Header() + h["Date"] = nil // suppress Date to make tests happy; TODO: restore + h.Set("Content-Type", "application/grpc") + h.Add("Trailer", "Grpc-Status") + h.Add("Trailer", "Grpc-Message") + if s.sendCompress != "" { + h.Set("Grpc-Encoding", s.sendCompress) + } +} + +func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error { + ht.writeCommonHeaders(s) + ht.w.Write(data) + if !opts.Delay { + ht.w.(http.Flusher).Flush() + } + return nil +} + +func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { + ht.writeCommonHeaders(s) + h := ht.w.Header() + for k, vv := range md { + for _, v := range vv { + h.Add(k, v) + } + } + ht.w.WriteHeader(200) + ht.w.(http.Flusher).Flush() + return nil +} + +func (ht *serverHandlerTransport) HandleStreams(fn func(*Stream)) { + // With this transport type there will be exactly 1 stream: this HTTP request. + + var ctx context.Context + var cancel context.CancelFunc + if ht.timeoutSet { + ctx, cancel = context.WithTimeout(context.Background(), ht.timeout) + } else { + ctx, cancel = context.WithCancel(context.Background()) + } + clientGone := ht.w.(http.CloseNotifier).CloseNotify() + go func() { + select { + case <-ht.closedCh: + case <-clientGone: + } + cancel() + }() + + r := ht.r + + s := &Stream{ + id: 0, // irrelevant + windowHandler: func(int) {}, // nothing + cancel: cancel, + buf: newRecvBuffer(), + st: ht, + method: r.URL.Path, + recvCompress: r.Header.Get("grpc-encoding"), + } + pr := &peer.Peer{ + Addr: ht.RemoteAddr(), + } + if ht.r.TLS != nil { + pr.AuthInfo = credentials.TLSInfo{*ht.r.TLS} + } + ctx = metadata.NewContext(ctx, ht.headerMD) + ctx = peer.NewContext(ctx, pr) + s.ctx = newContextWithStream(ctx, s) + s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf} + + rpcOver := make(chan struct{}) + defer close(rpcOver) + defer r.Body.Close() // unblock goroutine's Read, if applicable + + go func() { + for { + buf := make([]byte, 1024) // TODO: minimize garbagef, optimize recvBuffer code/ownership + n, err := r.Body.Read(buf) + select { + case <-rpcOver: + println(fmt.Sprintf("HandleStream over but: ... Read = %v, %v", n, err)) + return + default: + } + if n > 0 { + s.buf.put(&recvMsg{data: buf[:n]}) + } + if err != nil { + s.buf.put(&recvMsg{err: err}) + break + } + } + }() + + fn(s) + select { + case <-ctx.Done(): + case <-ht.wroteStatus: + } +} diff --git a/transport/http_util.go b/transport/http_util.go index f9d9fdf0afdc..62710862ea0d 100644 --- a/transport/http_util.go +++ b/transport/http_util.go @@ -120,7 +120,7 @@ type headerFrame interface { // reserved by gRPC protocol. Any other headers are classified as the // user-specified metadata. func isReservedHeader(hdr string) bool { - if hdr[0] == ':' { + if hdr != "" && hdr[0] == ':' { return true } switch hdr { @@ -130,6 +130,7 @@ func isReservedHeader(hdr string) bool { "grpc-message", "grpc-status", "grpc-timeout", + "trailer", "te": return true default: