diff --git a/rpc_util.go b/rpc_util.go index e98ddbcdc5a7..fadf3394d6ca 100644 --- a/rpc_util.go +++ b/rpc_util.go @@ -273,7 +273,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er case compressionNone: case compressionMade: if recvCompress == "" { - return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf) + return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress) } if dc == nil || recvCompress != dc.Type() { return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress) diff --git a/server.go b/server.go index bcd6196025a5..eb56b3443d75 100644 --- a/server.go +++ b/server.go @@ -39,6 +39,7 @@ import ( "fmt" "io" "net" + "net/http" "reflect" "runtime" "strings" @@ -46,6 +47,7 @@ import ( "time" "golang.org/x/net/context" + "golang.org/x/net/http2" "golang.org/x/net/trace" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -82,10 +84,11 @@ type service struct { // Server is a gRPC server to serve RPC requests. type Server struct { - opts options - mu sync.Mutex + opts options + + mu sync.Mutex // guards following lis map[net.Listener]bool - conns map[transport.ServerTransport]bool + conns map[io.Closer]bool m map[string]*service // service name -> service info events trace.EventLog } @@ -96,6 +99,7 @@ type options struct { cp Compressor dc Decompressor maxConcurrentStreams uint32 + useHandlerImpl bool // use http.Handler-based server } // A ServerOption sets options. @@ -149,7 +153,7 @@ func NewServer(opt ...ServerOption) *Server { s := &Server{ lis: make(map[net.Listener]bool), opts: opts, - conns: make(map[transport.ServerTransport]bool), + conns: make(map[io.Closer]bool), m: make(map[string]*service), } if EnableTracing { @@ -216,9 +220,17 @@ var ( ErrServerStopped = errors.New("grpc: the server has been stopped") ) +func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + creds, ok := s.opts.creds.(credentials.TransportAuthenticator) + if !ok { + return rawConn, nil, nil + } + return creds.ServerHandshake(rawConn) +} + // Serve accepts incoming connections on the listener lis, creating a new // ServerTransport and service goroutine for each. The service goroutines -// read gRPC request and then call the registered handlers to reply to them. +// read gRPC requests and then call the registered handlers to reply to them. // Service returns when lis.Accept fails. func (s *Server) Serve(lis net.Listener) error { s.mu.Lock() @@ -235,39 +247,54 @@ func (s *Server) Serve(lis net.Listener) error { delete(s.lis, lis) s.mu.Unlock() }() + listenerAddr := lis.Addr() for { - c, err := lis.Accept() + rawConn, err := lis.Accept() if err != nil { s.mu.Lock() s.printf("done serving; Accept = %v", err) s.mu.Unlock() return err } - var authInfo credentials.AuthInfo - if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok { - var conn net.Conn - conn, authInfo, err = creds.ServerHandshake(c) - if err != nil { - s.mu.Lock() - s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err) - s.mu.Unlock() - grpclog.Println("grpc: Server.Serve failed to complete security handshake.") - continue - } - c = conn - } + // Start a new goroutine to deal with rawConn + // so we don't stall this Accept loop goroutine. + go s.handleRawConn(listenerAddr, rawConn) + } +} + +// handleRawConn is run in its own goroutine and handles a just-accepted +// connection that has not had any I/O performed on it yet. +func (s *Server) handleRawConn(listenerAddr net.Addr, rawConn net.Conn) { + conn, authInfo, err := s.useTransportAuthenticator(rawConn) + if err != nil { s.mu.Lock() - if s.conns == nil { - s.mu.Unlock() - c.Close() - return nil - } + s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) s.mu.Unlock() + grpclog.Println("grpc: Server.Serve failed to complete security handshake.") + rawConn.Close() + return + } - go s.serveNewHTTP2Transport(c, authInfo) + s.mu.Lock() + if s.conns == nil { + s.mu.Unlock() + conn.Close() + return + } + s.mu.Unlock() + + if s.opts.useHandlerImpl { + s.serveUsingHandler(listenerAddr, conn) + } else { + s.serveNewHTTP2Transport(conn, authInfo) } } +// serveNewHTTP2Transport sets up a new http/2 transport (using the +// gRPC http2 server transport in transport/http2_server.go) and +// serves streams on it. +// This is run in its own goroutine (it does network I/O in +// transport.NewServerTransport). func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo) if err != nil { @@ -299,6 +326,59 @@ func (s *Server) serveStreams(st transport.ServerTransport) { wg.Wait() } +var _ http.Handler = (*Server)(nil) + +// serveUsingHandler is called from handleRawConn when s is configured +// to handle requests via the http.Handler interface. It sets up a +// net/http.Server to handle the just-accepted conn. The http.Server +// is configured to route all incoming requests (all HTTP/2 streams) +// to ServeHTTP, which creates a new ServerTransport for each stream. +// serveUsingHandler blocks until conn closes. +// +// This codepath is only used when Server.TestingUseHandlerImpl has +// been configured. This lets the end2end tests exercise the ServeHTTP +// method as one of the environment types. +// +// conn is the *tls.Conn that's already been authenticated. +func (s *Server) serveUsingHandler(listenerAddr net.Addr, conn net.Conn) { + if !s.addConn(conn) { + conn.Close() + return + } + defer s.removeConn(conn) + connDone := make(chan struct{}) + hs := &http.Server{ + Handler: s, + ConnState: func(c net.Conn, cs http.ConnState) { + if cs == http.StateClosed { + close(connDone) + } + }, + } + if err := http2.ConfigureServer(hs, &http2.Server{ + MaxConcurrentStreams: s.opts.maxConcurrentStreams, + }); err != nil { + grpclog.Fatalf("grpc: http2.ConfigureServer: %v", err) + return + } + hs.Serve(&singleConnListener{addr: listenerAddr, conn: conn}) + <-connDone +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + st, err := transport.NewServerHandlerTransport(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !s.addConn(st) { + st.Close() + return + } + defer s.removeConn(st) + s.serveStreams(st) +} + // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. // If tracing is not enabled, it returns nil. func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { @@ -317,21 +397,21 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea return trInfo } -func (s *Server) addConn(st transport.ServerTransport) bool { +func (s *Server) addConn(c io.Closer) bool { s.mu.Lock() defer s.mu.Unlock() if s.conns == nil { return false } - s.conns[st] = true + s.conns[c] = true return true } -func (s *Server) removeConn(st transport.ServerTransport) { +func (s *Server) removeConn(c io.Closer) { s.mu.Lock() defer s.mu.Unlock() if s.conns != nil { - delete(s.conns, st) + delete(s.conns, c) } } @@ -606,12 +686,14 @@ func (s *Server) Stop() { cs := s.conns s.conns = nil s.mu.Unlock() + for lis := range listeners { lis.Close() } for c := range cs { c.Close() } + s.mu.Lock() if s.events != nil { s.events.Finish() @@ -621,16 +703,24 @@ func (s *Server) Stop() { } // TestingCloseConns closes all exiting transports but keeps s.lis accepting new -// connections. This is for test only now. +// connections. +// This is only for tests and is subject to removal. func (s *Server) TestingCloseConns() { s.mu.Lock() for c := range s.conns { c.Close() + delete(s.conns, c) } - s.conns = make(map[transport.ServerTransport]bool) s.mu.Unlock() } +// TestingUseHandlerImpl enables the http.Handler-based server implementation. +// It must be called before Serve and requires TLS credentials. +// This is only for tests and is subject to removal. +func (s *Server) TestingUseHandlerImpl() { + s.opts.useHandlerImpl = true +} + // SendHeader sends header metadata. It may be called at most once from a unary // RPC handler. The ctx is the RPC handler's Context or one derived from it. func SendHeader(ctx context.Context, md metadata.MD) error { @@ -661,3 +751,30 @@ func SetTrailer(ctx context.Context, md metadata.MD) error { } return stream.SetTrailer(md) } + +// singleConnListener is a net.Listener that yields a single conn. +type singleConnListener struct { + mu sync.Mutex + addr net.Addr + conn net.Conn // nil if done +} + +func (ln *singleConnListener) Addr() net.Addr { return ln.addr } + +func (ln *singleConnListener) Close() error { + ln.mu.Lock() + defer ln.mu.Unlock() + ln.conn = nil + return nil +} + +func (ln *singleConnListener) Accept() (net.Conn, error) { + ln.mu.Lock() + defer ln.mu.Unlock() + c := ln.conn + if c == nil { + return nil, io.EOF + } + ln.conn = nil + return c, nil +} diff --git a/test/end2end_test.go b/test/end2end_test.go index 946df32034ca..0bb72891e698 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -329,10 +329,11 @@ func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { } type env struct { - name string - network string // The type of network such as tcp, unix, etc. - dialer func(addr string, timeout time.Duration) (net.Conn, error) - security string // The security protocol such as TLS, SSH, etc. + name string + network string // The type of network such as tcp, unix, etc. + dialer func(addr string, timeout time.Duration) (net.Conn, error) + security string // The security protocol such as TLS, SSH, etc. + httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS } func (e env) runnable() bool { @@ -347,10 +348,11 @@ var ( tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"} unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer} unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"} - allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv} + handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true} + allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv} ) -var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', or 'unix-tls' to only run the tests for that environment. Empty means all.") +var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.") func listTestEnv() (envs []env) { if *onlyEnv != "" { @@ -393,6 +395,9 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u sopts = append(sopts, grpc.Creds(creds)) } s = grpc.NewServer(sopts...) + if e.httpHandler { + s.TestingUseHandlerImpl() + } if hs != nil { healthpb.RegisterHealthServer(s, hs) } @@ -720,7 +725,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) { t.Fatalf("Received header metadata %v, want %v", header, testMetadata) } if !reflect.DeepEqual(trailer, testTrailerMetadata) { - t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata) + t.Fatalf("Received trailer metadata %v, want %v", trailer, testTrailerMetadata) } } @@ -1030,11 +1035,13 @@ func testMetadataStreamingRPC(t *testing.T, e env) { if e.security == "tls" { delete(headerMD, "transport_security_type") } + delete(headerMD, "trailer") // ignore if present if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { t.Errorf("#1 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) } // test the cached value. headerMD, err = stream.Header() + delete(headerMD, "trailer") // ignore if present if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { t.Errorf("#2 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) } diff --git a/transport/handler_server.go b/transport/handler_server.go new file mode 100644 index 000000000000..5d1bffee2861 --- /dev/null +++ b/transport/handler_server.go @@ -0,0 +1,329 @@ +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// This file is the implementation of a gRPC server using HTTP/2 which +// uses the standard Go http2 Server implementation (via the +// http.Handler interface), rather than speaking low-level HTTP/2 +// frames itself. It is the implementation of *grpc.Server.ServeHTTP. + +package transport + +import ( + "errors" + "fmt" + "net" + "net/http" + "strings" + "sync" + "time" + + "golang.org/x/net/context" + "golang.org/x/net/http2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" +) + +// NewServerHandlerTransport returns a ServerTransport handling gRPC +// from inside an http.Handler. It requires that the http Server +// supports HTTP/2. +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTransport, error) { + if r.ProtoMajor != 2 { + return nil, errors.New("gRPC requires HTTP/2") + } + if r.Method != "POST" { + return nil, errors.New("invalid gRPC request method") + } + if !strings.Contains(r.Header.Get("Content-Type"), "application/grpc") { + return nil, errors.New("invalid gRPC request content-type") + } + if _, ok := w.(http.Flusher); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher") + } + if _, ok := w.(http.CloseNotifier); !ok { + return nil, errors.New("gRPC requires a ResponseWriter supporting http.CloseNotifier") + } + + st := &serverHandlerTransport{ + rw: w, + req: r, + closedCh: make(chan struct{}), + wroteStatus: make(chan struct{}), + } + + if v := r.Header.Get("grpc-timeout"); v != "" { + to, err := timeoutDecode(v) + if err != nil { + return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err) + } + st.timeoutSet = true + st.timeout = to + } + + var metakv []string + for k, vv := range r.Header { + k = strings.ToLower(k) + if isReservedHeader(k) { + continue + } + for _, v := range vv { + if k == "user-agent" { + // user-agent is special. Copying logic of http_util.go. + if i := strings.LastIndex(v, " "); i == -1 { + // There is no application user agent string being set + continue + } else { + v = v[:i] + } + } + metakv = append(metakv, k, v) + + } + } + st.headerMD = metadata.Pairs(metakv...) + + return st, nil +} + +// serverHandlerTransport is an implementation of ServerTransport +// which replies to exactly one gRPC request (exactly one HTTP request), +// using the net/http.Handler interface. This http.Handler is guranteed +// at this point to be speaking over HTTP/2, so it's able to speak valid +// gRPC. +type serverHandlerTransport struct { + rw http.ResponseWriter + req *http.Request + timeoutSet bool + timeout time.Duration + didCommonHeaders bool + + headerMD metadata.MD + + closeOnce sync.Once + closedCh chan struct{} // closed on Close + + wroteStatus chan struct{} // closed on WriteStatus +} + +func (ht *serverHandlerTransport) Close() error { + ht.closeOnce.Do(ht.closeCloseChanOnce) + return nil +} + +func (ht *serverHandlerTransport) closeCloseChanOnce() { close(ht.closedCh) } + +func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) } + +// strAddr is a net.Addr backed by either a TCP "ip:port" string, or +// the empty string if unknown. +type strAddr string + +func (a strAddr) Network() string { + if a != "" { + // Per the documentation on net/http.Request.RemoteAddr, if this is + // set, it's set to the IP:port of the peer (hence, TCP): + // https://golang.org/pkg/net/http/#Request + // + // If we want to support Unix sockets later, we can + // add our own grpc-specific convention within the + // grpc codebase to set RemoteAddr to a different + // format, or probably better: we can attach it to the + // context and use that from serverHandlerTransport.RemoteAddr. + return "tcp" + } + return "" +} + +func (a strAddr) String() string { return string(a) } + +func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { + ht.writeCommonHeaders(s) + + // And flush, in case no header or body has been sent yet. + // This forces a separation of headers and trailers if this is the + // first call (for example, in end2end tests's TestNoService). + ht.rw.(http.Flusher).Flush() + + h := ht.rw.Header() + h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) + if statusDesc != "" { + h.Set("Grpc-Message", statusDesc) + } + if md := s.Trailer(); len(md) > 0 { + for k, vv := range md { + for _, v := range vv { + // http2 ResponseWriter mechanism to + // send undeclared Trailers after the + // headers have possibly been written. + h.Add(http2.TrailerPrefix+k, v) + } + } + } + close(ht.wroteStatus) + return nil +} + +// writeCommonHeaders sets common headers on the first write +// call (Write, WriteHeader, or WriteStatus). +func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { + if ht.didCommonHeaders { + return + } + ht.didCommonHeaders = true + + h := ht.rw.Header() + h["Date"] = nil // suppress Date to make tests happy; TODO: restore + h.Set("Content-Type", "application/grpc") + + // Predeclare trailers we'll set later in WriteStatus (after the body). + // This is a SHOULD in the HTTP RFC, and the way you add (known) + // Trailers per the net/http.ResponseWriter contract. + // See https://golang.org/pkg/net/http/#ResponseWriter + // and https://golang.org/pkg/net/http/#example_ResponseWriter_trailers + h.Add("Trailer", "Grpc-Status") + h.Add("Trailer", "Grpc-Message") + + if s.sendCompress != "" { + h.Set("Grpc-Encoding", s.sendCompress) + } +} + +func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error { + ht.writeCommonHeaders(s) + ht.rw.Write(data) + if !opts.Delay { + ht.rw.(http.Flusher).Flush() + } + return nil +} + +func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { + ht.writeCommonHeaders(s) + h := ht.rw.Header() + for k, vv := range md { + for _, v := range vv { + h.Add(k, v) + } + } + ht.rw.WriteHeader(200) + ht.rw.(http.Flusher).Flush() + return nil +} + +func (ht *serverHandlerTransport) HandleStreams(runStream func(*Stream)) { + // With this transport type there will be exactly 1 stream: this HTTP request. + + var ctx context.Context + var cancel context.CancelFunc + if ht.timeoutSet { + ctx, cancel = context.WithTimeout(context.Background(), ht.timeout) + } else { + ctx, cancel = context.WithCancel(context.Background()) + } + + // clientGone receives a single value if peer is gone, either + // because the underlying connection is dead or because the + // peer sends an http2 RST_STREAM. + clientGone := ht.rw.(http.CloseNotifier).CloseNotify() + go func() { + select { + case <-ht.closedCh: + case <-clientGone: + } + cancel() + }() + + req := ht.req + + s := &Stream{ + id: 0, // irrelevant + windowHandler: func(int) {}, // nothing + cancel: cancel, + buf: newRecvBuffer(), + st: ht, + method: req.URL.Path, + recvCompress: req.Header.Get("grpc-encoding"), + } + pr := &peer.Peer{ + Addr: ht.RemoteAddr(), + } + if req.TLS != nil { + pr.AuthInfo = credentials.TLSInfo{*req.TLS} + } + ctx = metadata.NewContext(ctx, ht.headerMD) + ctx = peer.NewContext(ctx, pr) + s.ctx = newContextWithStream(ctx, s) + s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf} + + // requestOver is closed when either the request's context is done + // or the status has been written via WriteStatus. + requestOver := make(chan struct{}) + + // readerDone is closed when the Body.Read-ing goroutine exits. + readerDone := make(chan struct{}) + go func() { + defer close(readerDone) + for { + buf := make([]byte, 1024) // TODO: minimize garbage, optimize recvBuffer code/ownership + n, err := req.Body.Read(buf) + select { + case <-requestOver: + return + default: + } + if n > 0 { + s.buf.put(&recvMsg{data: buf[:n]}) + } + if err != nil { + s.buf.put(&recvMsg{err: err}) + break + } + } + }() + + // runStream is provided by the *grpc.Server.serveStreams. + // It starts a goroutine handling s and exits immediately. + runStream(s) + + // Wait for the stream to be done. It is considered done when + // either its context is done, or we've written its status. + select { + case <-ctx.Done(): + case <-ht.wroteStatus: + } + close(requestOver) + + // Wait for reading goroutine to finish. + req.Body.Close() + <-readerDone +} diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go new file mode 100644 index 000000000000..faa95af964de --- /dev/null +++ b/transport/handler_server_test.go @@ -0,0 +1,386 @@ +/* + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "testing" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) { + type testCase struct { + name string + req *http.Request + wantErr string + modrw func(http.ResponseWriter) http.ResponseWriter + check func(*serverHandlerTransport, *testCase) error + } + tests := []testCase{ + { + name: "http/1.1", + req: &http.Request{ + ProtoMajor: 1, + ProtoMinor: 1, + }, + wantErr: "gRPC requires HTTP/2", + }, + { + name: "bad method", + req: &http.Request{ + ProtoMajor: 2, + Method: "GET", + Header: http.Header{}, + RequestURI: "/", + }, + wantErr: "invalid gRPC request method", + }, + { + name: "bad content type", + req: &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": {"application/foo"}, + }, + RequestURI: "/service/foo.bar", + }, + wantErr: "invalid gRPC request content-type", + }, + { + name: "not flusher", + req: &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": {"application/grpc"}, + }, + RequestURI: "/service/foo.bar", + }, + modrw: func(w http.ResponseWriter) http.ResponseWriter { + // Return w without its Flush method + type onlyCloseNotifier interface { + http.ResponseWriter + http.CloseNotifier + } + return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)} + }, + wantErr: "gRPC requires a ResponseWriter supporting http.Flusher", + }, + { + name: "not closenotifier", + req: &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": {"application/grpc"}, + }, + RequestURI: "/service/foo.bar", + }, + modrw: func(w http.ResponseWriter) http.ResponseWriter { + // Return w without its CloseNotify method + type onlyFlusher interface { + http.ResponseWriter + http.Flusher + } + return struct{ onlyFlusher }{w.(onlyFlusher)} + }, + wantErr: "gRPC requires a ResponseWriter supporting http.CloseNotifier", + }, + { + name: "valid", + req: &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": {"application/grpc"}, + }, + URL: &url.URL{ + Path: "/service/foo.bar", + }, + RequestURI: "/service/foo.bar", + }, + check: func(t *serverHandlerTransport, tt *testCase) error { + if t.req != tt.req { + return fmt.Errorf("t.req = %p; want %p", t.req, tt.req) + } + if t.rw == nil { + return errors.New("t.rw = nil; want non-nil") + } + return nil + }, + }, + { + name: "with timeout", + req: &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": []string{"application/grpc"}, + "Grpc-Timeout": {"200m"}, + }, + URL: &url.URL{ + Path: "/service/foo.bar", + }, + RequestURI: "/service/foo.bar", + }, + check: func(t *serverHandlerTransport, tt *testCase) error { + if !t.timeoutSet { + return errors.New("timeout not set") + } + if want := 200 * time.Millisecond; t.timeout != want { + return fmt.Errorf("timeout = %v; want %v", t.timeout, want) + } + return nil + }, + }, + { + name: "with bad timeout", + req: &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": []string{"application/grpc"}, + "Grpc-Timeout": {"tomorrow"}, + }, + URL: &url.URL{ + Path: "/service/foo.bar", + }, + RequestURI: "/service/foo.bar", + }, + wantErr: `stream error: code = 13 desc = "malformed time-out: transport: timeout unit is not recognized: \"tomorrow\""`, + }, + { + name: "with metadata", + req: &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": []string{"application/grpc"}, + "meta-foo": {"foo-val"}, + "meta-bar": {"bar-val1", "bar-val2"}, + "user-agent": {"x/y a/b"}, + }, + URL: &url.URL{ + Path: "/service/foo.bar", + }, + RequestURI: "/service/foo.bar", + }, + check: func(ht *serverHandlerTransport, tt *testCase) error { + want := metadata.MD{ + "meta-bar": {"bar-val1", "bar-val2"}, + "user-agent": {"x/y"}, + "meta-foo": {"foo-val"}, + } + if !reflect.DeepEqual(ht.headerMD, want) { + return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want) + } + return nil + }, + }, + } + + for _, tt := range tests { + rw := newTestHandlerResponseWriter() + if tt.modrw != nil { + rw = tt.modrw(rw) + } + got, gotErr := NewServerHandlerTransport(rw, tt.req) + if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) { + t.Errorf("%s: error = %v; want %q", tt.name, gotErr, tt.wantErr) + continue + } + if gotErr != nil { + continue + } + if tt.check != nil { + if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil { + t.Errorf("%s: %v", tt.name, err) + } + } + } +} + +type testHandlerResponseWriter struct { + *httptest.ResponseRecorder + closeNotify chan bool +} + +func (w testHandlerResponseWriter) CloseNotify() <-chan bool { return w.closeNotify } +func (w testHandlerResponseWriter) Flush() {} + +func newTestHandlerResponseWriter() http.ResponseWriter { + return testHandlerResponseWriter{ + ResponseRecorder: httptest.NewRecorder(), + closeNotify: make(chan bool, 1), + } +} + +type handleStreamTest struct { + t *testing.T + bodyw *io.PipeWriter + req *http.Request + rw testHandlerResponseWriter + ht *serverHandlerTransport +} + +func newHandleStreamTest(t *testing.T) *handleStreamTest { + bodyr, bodyw := io.Pipe() + req := &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": {"application/grpc"}, + }, + URL: &url.URL{ + Path: "/service/foo.bar", + }, + RequestURI: "/service/foo.bar", + Body: bodyr, + } + rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) + ht, err := NewServerHandlerTransport(rw, req) + if err != nil { + t.Fatal(err) + } + return &handleStreamTest{ + t: t, + bodyw: bodyw, + ht: ht.(*serverHandlerTransport), + rw: rw, + } +} + +func TestHandlerTransport_HandleStreams(t *testing.T) { + st := newHandleStreamTest(t) + st.ht.HandleStreams(func(s *Stream) { + if want := "/service/foo.bar"; s.method != want { + t.Errorf("stream method = %q; want %q", s.method, want) + } + st.bodyw.Close() // no body + st.ht.WriteStatus(s, codes.OK, "") + }) + wantHeader := http.Header{ + "Date": nil, + "Content-Type": {"application/grpc"}, + "Trailer": {"Grpc-Status", "Grpc-Message"}, + "Grpc-Status": {"0"}, + } + if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { + t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader) + } +} + +// Tests that codes.Unimplemented will close the body, per comment in handler_server.go. +func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) { + handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented") +} + +// Tests that codes.InvalidArgument will close the body, per comment in handler_server.go. +func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { + handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg") +} + +func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { + st := newHandleStreamTest(t) + st.ht.HandleStreams(func(s *Stream) { + st.ht.WriteStatus(s, statusCode, msg) + }) + wantHeader := http.Header{ + "Date": nil, + "Content-Type": {"application/grpc"}, + "Trailer": {"Grpc-Status", "Grpc-Message"}, + "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, + "Grpc-Message": {msg}, + } + if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { + t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) + } +} + +func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { + bodyr, bodyw := io.Pipe() + req := &http.Request{ + ProtoMajor: 2, + Method: "POST", + Header: http.Header{ + "Content-Type": {"application/grpc"}, + "Grpc-Timeout": {"200m"}, + }, + URL: &url.URL{ + Path: "/service/foo.bar", + }, + RequestURI: "/service/foo.bar", + Body: bodyr, + } + rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) + ht, err := NewServerHandlerTransport(rw, req) + if err != nil { + t.Fatal(err) + } + ht.HandleStreams(func(s *Stream) { + defer bodyw.Close() + select { + case <-s.ctx.Done(): + case <-time.After(5 * time.Second): + t.Errorf("timeout waiting for ctx.Done") + return + } + err := s.ctx.Err() + if err != context.DeadlineExceeded { + t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded) + return + } + ht.WriteStatus(s, codes.DeadlineExceeded, "too slow") + }) + wantHeader := http.Header{ + "Date": nil, + "Content-Type": {"application/grpc"}, + "Trailer": {"Grpc-Status", "Grpc-Message"}, + "Grpc-Status": {"4"}, + "Grpc-Message": {"too slow"}, + } + if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { + t.Errorf("Header+Trailer Map: %#v; want %#v", rw.HeaderMap, wantHeader) + } +}