Skip to content

Commit

Permalink
Add a ServeHTTP method to *grpc.Server
Browse files Browse the repository at this point in the history
This adds new http.Handler-based ServerTransport in the process,
reusing the HTTP/2 server code in x/net/http2 or Go 1.6+.

All end2end tests pass with this new ServerTransport.

Fixes grpc#75

Also:
Updates grpc#495 (lets user fix it with middleware in front)
Updates grpc#468 (x/net/http2 validates)
Updates grpc#147 (possible with x/net/http2)
Updates grpc#104 (x/net/http2 does this)
  • Loading branch information
bradfitz committed Feb 11, 2016
1 parent 3c4302b commit 81d512c
Show file tree
Hide file tree
Showing 5 changed files with 900 additions and 39 deletions.
2 changes: 1 addition & 1 deletion rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
case compressionNone:
case compressionMade:
if recvCompress == "" {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
return transport.StreamErrorf(codes.InvalidArgument, "grpc: invalid grpc-encoding %q with compression enabled", recvCompress)
}
if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
Expand Down
179 changes: 148 additions & 31 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ import (
"fmt"
"io"
"net"
"net/http"
"reflect"
"runtime"
"strings"
"sync"
"time"

"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/trace"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -82,10 +84,11 @@ type service struct {

// Server is a gRPC server to serve RPC requests.
type Server struct {
opts options
mu sync.Mutex
opts options

mu sync.Mutex // guards following
lis map[net.Listener]bool
conns map[transport.ServerTransport]bool
conns map[io.Closer]bool
m map[string]*service // service name -> service info
events trace.EventLog
}
Expand All @@ -96,6 +99,7 @@ type options struct {
cp Compressor
dc Decompressor
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}

// A ServerOption sets options.
Expand Down Expand Up @@ -149,7 +153,7 @@ func NewServer(opt ...ServerOption) *Server {
s := &Server{
lis: make(map[net.Listener]bool),
opts: opts,
conns: make(map[transport.ServerTransport]bool),
conns: make(map[io.Closer]bool),
m: make(map[string]*service),
}
if EnableTracing {
Expand Down Expand Up @@ -216,9 +220,17 @@ var (
ErrServerStopped = errors.New("grpc: the server has been stopped")
)

func (s *Server) useTransportAuthenticator(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
creds, ok := s.opts.creds.(credentials.TransportAuthenticator)
if !ok {
return rawConn, nil, nil
}
return creds.ServerHandshake(rawConn)
}

// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines
// read gRPC request and then call the registered handlers to reply to them.
// read gRPC requests and then call the registered handlers to reply to them.
// Service returns when lis.Accept fails.
func (s *Server) Serve(lis net.Listener) error {
s.mu.Lock()
Expand All @@ -235,39 +247,54 @@ func (s *Server) Serve(lis net.Listener) error {
delete(s.lis, lis)
s.mu.Unlock()
}()
listenerAddr := lis.Addr()
for {
c, err := lis.Accept()
rawConn, err := lis.Accept()
if err != nil {
s.mu.Lock()
s.printf("done serving; Accept = %v", err)
s.mu.Unlock()
return err
}
var authInfo credentials.AuthInfo
if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
var conn net.Conn
conn, authInfo, err = creds.ServerHandshake(c)
if err != nil {
s.mu.Lock()
s.errorf("ServerHandshake(%q) failed: %v", c.RemoteAddr(), err)
s.mu.Unlock()
grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
continue
}
c = conn
}
// Start a new goroutine to deal with rawConn
// so we don't stall this Accept loop goroutine.
go s.handleRawConn(listenerAddr, rawConn)
}
}

// handleRawConn is run in its own goroutine and handles a just-accepted
// connection that has not had any I/O performed on it yet.
func (s *Server) handleRawConn(listenerAddr net.Addr, rawConn net.Conn) {
conn, authInfo, err := s.useTransportAuthenticator(rawConn)
if err != nil {
s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
c.Close()
return nil
}
s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
s.mu.Unlock()
grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
rawConn.Close()
return
}

go s.serveNewHTTP2Transport(c, authInfo)
s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
conn.Close()
return
}
s.mu.Unlock()

if s.opts.useHandlerImpl {
s.serveUsingHandler(listenerAddr, conn)
} else {
s.serveNewHTTP2Transport(conn, authInfo)
}
}

// serveNewHTTP2Transport sets up a new http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go) and
// serves streams on it.
// This is run in its own goroutine (it does network I/O in
// transport.NewServerTransport).
func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) {
st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
if err != nil {
Expand Down Expand Up @@ -299,6 +326,59 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
wg.Wait()
}

var _ http.Handler = (*Server)(nil)

// serveUsingHandler is called from handleRawConn when s is configured
// to handle requests via the http.Handler interface. It sets up a
// net/http.Server to handle the just-accepted conn. The http.Server
// is configured to route all incoming requests (all HTTP/2 streams)
// to ServeHTTP, which creates a new ServerTransport for each stream.
// serveUsingHandler blocks until conn closes.
//
// This codepath is only used when Server.TestingUseHandlerImpl has
// been configured. This lets the end2end tests exercise the ServeHTTP
// method as one of the environment types.
//
// conn is the *tls.Conn that's already been authenticated.
func (s *Server) serveUsingHandler(listenerAddr net.Addr, conn net.Conn) {
if !s.addConn(conn) {
conn.Close()
return
}
defer s.removeConn(conn)
connDone := make(chan struct{})
hs := &http.Server{
Handler: s,
ConnState: func(c net.Conn, cs http.ConnState) {
if cs == http.StateClosed {
close(connDone)
}
},
}
if err := http2.ConfigureServer(hs, &http2.Server{
MaxConcurrentStreams: s.opts.maxConcurrentStreams,
}); err != nil {
grpclog.Fatalf("grpc: http2.ConfigureServer: %v", err)
return
}
hs.Serve(&singleConnListener{addr: listenerAddr, conn: conn})
<-connDone
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if !s.addConn(st) {
st.Close()
return
}
defer s.removeConn(st)
s.serveStreams(st)
}

// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
// If tracing is not enabled, it returns nil.
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
Expand All @@ -317,21 +397,21 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
return trInfo
}

func (s *Server) addConn(st transport.ServerTransport) bool {
func (s *Server) addConn(c io.Closer) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns == nil {
return false
}
s.conns[st] = true
s.conns[c] = true
return true
}

func (s *Server) removeConn(st transport.ServerTransport) {
func (s *Server) removeConn(c io.Closer) {
s.mu.Lock()
defer s.mu.Unlock()
if s.conns != nil {
delete(s.conns, st)
delete(s.conns, c)
}
}

Expand Down Expand Up @@ -606,12 +686,14 @@ func (s *Server) Stop() {
cs := s.conns
s.conns = nil
s.mu.Unlock()

for lis := range listeners {
lis.Close()
}
for c := range cs {
c.Close()
}

s.mu.Lock()
if s.events != nil {
s.events.Finish()
Expand All @@ -621,16 +703,24 @@ func (s *Server) Stop() {
}

// TestingCloseConns closes all exiting transports but keeps s.lis accepting new
// connections. This is for test only now.
// connections.
// This is only for tests and is subject to removal.
func (s *Server) TestingCloseConns() {
s.mu.Lock()
for c := range s.conns {
c.Close()
delete(s.conns, c)
}
s.conns = make(map[transport.ServerTransport]bool)
s.mu.Unlock()
}

// TestingUseHandlerImpl enables the http.Handler-based server implementation.
// It must be called before Serve and requires TLS credentials.
// This is only for tests and is subject to removal.
func (s *Server) TestingUseHandlerImpl() {
s.opts.useHandlerImpl = true
}

// SendHeader sends header metadata. It may be called at most once from a unary
// RPC handler. The ctx is the RPC handler's Context or one derived from it.
func SendHeader(ctx context.Context, md metadata.MD) error {
Expand Down Expand Up @@ -661,3 +751,30 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
}
return stream.SetTrailer(md)
}

// singleConnListener is a net.Listener that yields a single conn.
type singleConnListener struct {
mu sync.Mutex
addr net.Addr
conn net.Conn // nil if done
}

func (ln *singleConnListener) Addr() net.Addr { return ln.addr }

func (ln *singleConnListener) Close() error {
ln.mu.Lock()
defer ln.mu.Unlock()
ln.conn = nil
return nil
}

func (ln *singleConnListener) Accept() (net.Conn, error) {
ln.mu.Lock()
defer ln.mu.Unlock()
c := ln.conn
if c == nil {
return nil, io.EOF
}
ln.conn = nil
return c, nil
}
21 changes: 14 additions & 7 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,11 @@ func unixDialer(addr string, timeout time.Duration) (net.Conn, error) {
}

type env struct {
name string
network string // The type of network such as tcp, unix, etc.
dialer func(addr string, timeout time.Duration) (net.Conn, error)
security string // The security protocol such as TLS, SSH, etc.
name string
network string // The type of network such as tcp, unix, etc.
dialer func(addr string, timeout time.Duration) (net.Conn, error)
security string // The security protocol such as TLS, SSH, etc.
httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS
}

func (e env) runnable() bool {
Expand All @@ -347,10 +348,11 @@ var (
tcpTLSEnv = env{name: "tcp-tls", network: "tcp", security: "tls"}
unixClearEnv = env{name: "unix-clear", network: "unix", dialer: unixDialer}
unixTLSEnv = env{name: "unix-tls", network: "unix", dialer: unixDialer, security: "tls"}
allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv}
handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true}
allEnv = []env{tcpClearEnv, tcpTLSEnv, unixClearEnv, unixTLSEnv, handlerEnv}
)

var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', or 'unix-tls' to only run the tests for that environment. Empty means all.")
var onlyEnv = flag.String("only_env", "", "If non-empty, one of 'tcp-clear', 'tcp-tls', 'unix-clear', 'unix-tls', or 'handler-tls' to only run the tests for that environment. Empty means all.")

func listTestEnv() (envs []env) {
if *onlyEnv != "" {
Expand Down Expand Up @@ -393,6 +395,9 @@ func serverSetUp(t *testing.T, servON bool, hs *health.HealthServer, maxStream u
sopts = append(sopts, grpc.Creds(creds))
}
s = grpc.NewServer(sopts...)
if e.httpHandler {
s.TestingUseHandlerImpl()
}
if hs != nil {
healthpb.RegisterHealthServer(s, hs)
}
Expand Down Expand Up @@ -720,7 +725,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
t.Fatalf("Received header metadata %v, want %v", header, testMetadata)
}
if !reflect.DeepEqual(trailer, testTrailerMetadata) {
t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata)
t.Fatalf("Received trailer metadata %v, want %v", trailer, testTrailerMetadata)
}
}

Expand Down Expand Up @@ -1030,11 +1035,13 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
if e.security == "tls" {
delete(headerMD, "transport_security_type")
}
delete(headerMD, "trailer") // ignore if present
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#1 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
// test the cached value.
headerMD, err = stream.Header()
delete(headerMD, "trailer") // ignore if present
if err != nil || !reflect.DeepEqual(testMetadata, headerMD) {
t.Errorf("#2 %v.Header() = %v, %v, want %v, <nil>", stream, headerMD, err, testMetadata)
}
Expand Down
Loading

0 comments on commit 81d512c

Please sign in to comment.