diff --git a/rest/engine.go b/rest/engine.go index e57786caf205..c3c7e1f033f5 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "sort" "time" @@ -330,6 +331,31 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error { ng.conf.KeyFile, router, opts...) } +func (ng *engine) startWithListener(listener net.Listener, router httpx.Router, opts ...StartOption) error { + if err := ng.bindRoutes(router); err != nil { + return err + } + + // make sure user defined options overwrite default options + opts = append([]StartOption{ng.withTimeout()}, opts...) + + if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { + return internal.StartHttpWithListener(listener, router, opts...) + } + + // make sure user defined options overwrite default options + opts = append([]StartOption{ + func(svr *http.Server) { + if ng.tlsConfig != nil { + svr.TLSConfig = ng.tlsConfig + } + }, + }, opts...) + + return internal.StartHttpsWithListener(listener, ng.conf.CertFile, + ng.conf.KeyFile, router, opts...) +} + func (ng *engine) use(middleware Middleware) { ng.middlewares = append(ng.middlewares, middleware) } diff --git a/rest/engine_test.go b/rest/engine_test.go index 4f86d2173efd..12fc7c4fedd9 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "net/http/httptest" "os" @@ -429,6 +430,43 @@ func TestEngine_start(t *testing.T) { }) } +func TestEngine_startWithListener(t *testing.T) { + logx.Disable() + + t.Run("http", func(t *testing.T) { + ng := newEngine(RestConf{ + Host: "localhost", + Port: -1, + }) + address := fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port) + listener, err := net.Listen("tcp", address) + assert.Error(t, err) + if listener != nil { + assert.Error(t, ng.startWithListener(listener, router.NewRouter())) + } else { + assert.Error(t, ng.start(router.NewRouter())) + } + }) + + t.Run("https", func(t *testing.T) { + ng := newEngine(RestConf{ + Host: "localhost", + Port: -1, + CertFile: "foo", + KeyFile: "bar", + }) + ng.tlsConfig = &tls.Config{} + address := fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port) + listener, err := net.Listen("tcp", address) + assert.Error(t, err) + if listener != nil { + assert.Error(t, ng.startWithListener(listener, router.NewRouter())) + } else { + assert.Error(t, ng.start(router.NewRouter())) + } + }) +} + type mockedRouter struct { } diff --git a/rest/internal/starter.go b/rest/internal/starter.go index 174303342b7e..a2533bf57d71 100644 --- a/rest/internal/starter.go +++ b/rest/internal/starter.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "github.com/zeromicro/go-zero/core/logx" @@ -23,6 +24,13 @@ func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) }, opts...) } +// StartHttpWithListener starts a http server with listener. +func StartHttpWithListener(listener net.Listener, handler http.Handler, opts ...StartOption) error { + return startWithListener(listener, handler, func(svr *http.Server) error { + return svr.Serve(listener) + }, opts...) +} + // StartHttps starts a https server. func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler, opts ...StartOption) error { @@ -32,6 +40,15 @@ func StartHttps(host string, port int, certFile, keyFile string, handler http.Ha }, opts...) } +// StartHttpsWithListener starts a https server with listener. +func StartHttpsWithListener(listener net.Listener, certFile, keyFile string, handler http.Handler, + opts ...StartOption) error { + return startWithListener(listener, handler, func(svr *http.Server) error { + // certFile and keyFile are set in buildHttpsServer + return svr.ServeTLS(listener, certFile, keyFile) + }, opts...) +} + func start(host string, port int, handler http.Handler, run func(svr *http.Server) error, opts ...StartOption) (err error) { server := &http.Server{ @@ -59,3 +76,32 @@ func start(host string, port int, handler http.Handler, run func(svr *http.Serve health.AddProbe(healthManager) return run(server) } + +func startWithListener(listener net.Listener, handler http.Handler, run func(svr *http.Server) error, + opts ...StartOption) (err error) { + + server := &http.Server{ + Addr: fmt.Sprintf("%s", listener.Addr().String()), + Handler: handler, + } + for _, opt := range opts { + opt(server) + } + + healthManager := health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, listener.Addr().String())) + waitForCalled := proc.AddShutdownListener(func() { + healthManager.MarkNotReady() + if e := server.Shutdown(context.Background()); e != nil { + logx.Error(e) + } + }) + defer func() { + if errors.Is(err, http.ErrServerClosed) { + waitForCalled() + } + }() + + healthManager.MarkReady() + health.AddProbe(healthManager) + return run(server) +} diff --git a/rest/internal/starter_test.go b/rest/internal/starter_test.go index a54c215f9f56..9e59ddfaa8f4 100644 --- a/rest/internal/starter_test.go +++ b/rest/internal/starter_test.go @@ -34,3 +34,21 @@ func TestStartHttps(t *testing.T) { assert.NotNil(t, err) proc.WrapUp() } + +func TestStartHttpWithListener(t *testing.T) { + svr := httptest.NewUnstartedServer(http.NotFoundHandler()) + err := StartHttpWithListener(svr.Listener, http.NotFoundHandler(), func(svr *http.Server) { + svr.IdleTimeout = 0 + }) + assert.NotNil(t, err) + proc.WrapUp() +} + +func TestStartHttpsWithListener(t *testing.T) { + svr := httptest.NewUnstartedServer(http.NotFoundHandler()) + err := StartHttpsWithListener(svr.Listener, "", "", http.NotFoundHandler(), func(svr *http.Server) { + svr.IdleTimeout = 0 + }) + assert.NotNil(t, err) + proc.WrapUp() +} diff --git a/rest/server.go b/rest/server.go index b1e5487bd8a5..c55419f1f95c 100644 --- a/rest/server.go +++ b/rest/server.go @@ -3,6 +3,7 @@ package rest import ( "crypto/tls" "errors" + "net" "net/http" "path" "time" @@ -121,6 +122,13 @@ func (s *Server) Start() { handleError(s.ngin.start(s.router)) } +// StartWithListener starts the Server with listener +// Graceful shutdown is enabled by default. +// Use proc.SetTimeToForceQuit to customize the graceful shutdown period. +func (s *Server) StartWithListener(listener net.Listener) { + handleError(s.ngin.startWithListener(listener, s.router)) +} + // StartWithOpts starts the Server. // Graceful shutdown is enabled by default. // Use proc.SetTimeToForceQuit to customize the graceful shutdown period. diff --git a/rest/server_test.go b/rest/server_test.go index 9a92d58f8203..3a298b48c826 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/fs" + "net" "net/http" "net/http/httptest" "os" @@ -124,6 +125,24 @@ Port: 0 }) svr.Stop() }() + + func() { + defer func() { + p := recover() + switch v := p.(type) { + case error: + assert.Equal(t, "foo", v.Error()) + default: + t.Fail() + } + }() + + address := fmt.Sprintf("%s:%d", cnf.Host, cnf.Port) + listener, err := net.Listen("tcp", address) + assert.Nil(t, err) + svr.StartWithListener(listener) + svr.Stop() + }() } } diff --git a/zrpc/internal/rpcpubserver.go b/zrpc/internal/rpcpubserver.go index 70b481323d92..4c795ad7c271 100644 --- a/zrpc/internal/rpcpubserver.go +++ b/zrpc/internal/rpcpubserver.go @@ -1,6 +1,7 @@ package internal import ( + "net" "os" "strings" @@ -53,6 +54,14 @@ func (s keepAliveServer) Start(fn RegisterFn) error { return s.Server.Start(fn) } +func (s keepAliveServer) StartWithListener(listener net.Listener, fn RegisterFn) error { + if err := s.registerEtcd(); err != nil { + return err + } + + return s.Server.StartWithListener(listener, fn) +} + func figureOutListenOn(listenOn string) string { fields := strings.Split(listenOn, ":") if len(fields) == 0 { diff --git a/zrpc/internal/rpcpubserver_test.go b/zrpc/internal/rpcpubserver_test.go index cc36e4653357..ec009820736c 100644 --- a/zrpc/internal/rpcpubserver_test.go +++ b/zrpc/internal/rpcpubserver_test.go @@ -18,6 +18,10 @@ func TestNewRpcPubServer(t *testing.T) { assert.NotPanics(t, func() { s.Start(nil) }) + + assert.NotPanics(t, func() { + s.StartWithListener(nil, nil) + }) } func TestFigureOutListenOn(t *testing.T) { diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index c1302f03ae52..7d94c4cc4b20 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -78,6 +78,36 @@ func (s *rpcServer) Start(register RegisterFn) error { return server.Serve(lis) } +func (s *rpcServer) StartWithListener(listener net.Listener, register RegisterFn) error { + + unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.unaryInterceptors...) + streamInterceptorOption := grpc.ChainStreamInterceptor(s.streamInterceptors...) + + options := append(s.options, unaryInterceptorOption, streamInterceptorOption) + server := grpc.NewServer(options...) + register(server) + + // register the health check service + if s.health != nil { + grpc_health_v1.RegisterHealthServer(server, s.health) + s.health.Resume() + } + s.healthManager.MarkReady() + health.AddProbe(s.healthManager) + + // we need to make sure all others are wrapped up, + // so we do graceful stop at shutdown phase instead of wrap up phase + waitForCalled := proc.AddShutdownListener(func() { + if s.health != nil { + s.health.Shutdown() + } + server.GracefulStop() + }) + defer waitForCalled() + + return server.Serve(listener) +} + // WithRpcHealth returns a func that sets rpc health switch to a Server. func WithRpcHealth(health bool) ServerOption { return func(options *rpcServerOptions) { diff --git a/zrpc/internal/rpcserver_test.go b/zrpc/internal/rpcserver_test.go index 696dae68713c..b415f15eba8a 100644 --- a/zrpc/internal/rpcserver_test.go +++ b/zrpc/internal/rpcserver_test.go @@ -1,6 +1,7 @@ package internal import ( + "net" "sync" "testing" "time" @@ -17,8 +18,8 @@ func TestRpcServer(t *testing.T) { var wg, wgDone sync.WaitGroup var grpcServer *grpc.Server var lock sync.Mutex - wg.Add(1) - wgDone.Add(1) + wg.Add(2) + wgDone.Add(2) go func() { err := server.Start(func(server *grpc.Server) { lock.Lock() @@ -31,6 +32,21 @@ func TestRpcServer(t *testing.T) { wgDone.Done() }() + go func() { + listener, err := net.Listen("tcp", "localhost:54322") + assert.Nil(t, err) + serverWithListener := NewRpcServer(listener.Addr().String(), WithRpcHealth(true)) + err = serverWithListener.StartWithListener(listener, func(server *grpc.Server) { + lock.Lock() + mock.RegisterDepositServiceServer(server, new(mock.DepositServer)) + grpcServer = server + lock.Unlock() + wg.Done() + }) + assert.Nil(t, err) + wgDone.Done() + }() + wg.Wait() time.Sleep(100 * time.Millisecond) diff --git a/zrpc/internal/server.go b/zrpc/internal/server.go index fc9eea0cbb50..40f4972deb72 100644 --- a/zrpc/internal/server.go +++ b/zrpc/internal/server.go @@ -1,6 +1,7 @@ package internal import ( + "net" "time" "google.golang.org/grpc" @@ -21,6 +22,7 @@ type ( AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) SetName(string) Start(register RegisterFn) error + StartWithListener(listener net.Listener, register RegisterFn) error } baseRpcServer struct { diff --git a/zrpc/server.go b/zrpc/server.go index 813fc358d298..d8463514331a 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -1,6 +1,7 @@ package zrpc import ( + "net" "time" "github.com/zeromicro/go-zero/core/load" @@ -92,6 +93,16 @@ func (rs *RpcServer) Start() { } } +// StartWithListener starts the RpcServer with listener. +// Graceful shutdown is enabled by default. +// Use proc.SetTimeToForceQuit to customize the graceful shutdown period. +func (rs *RpcServer) StartWithListener(listener net.Listener) { + if err := rs.server.StartWithListener(listener, rs.register); err != nil { + logx.Error(err) + panic(err) + } +} + // Stop stops the RpcServer. func (rs *RpcServer) Stop() { logx.Close() diff --git a/zrpc/server_test.go b/zrpc/server_test.go index e42e379a6cd1..03c344c34f1d 100644 --- a/zrpc/server_test.go +++ b/zrpc/server_test.go @@ -2,6 +2,7 @@ package zrpc import ( "context" + "net" "testing" "time" @@ -56,6 +57,48 @@ func TestServer(t *testing.T) { svr.Stop() } +func TestServer_StartWithListener(t *testing.T) { + DontLogContentForMethod("foo") + SetServerSlowThreshold(time.Second) + svr := MustNewServer(RpcServerConf{ + ServiceConf: service.ServiceConf{ + Log: logx.LogConf{ + ServiceName: "foo", + Mode: "console", + }, + }, + ListenOn: "localhost:8081", + Etcd: discov.EtcdConf{}, + Auth: false, + Redis: redis.RedisKeyConf{}, + StrictControl: false, + Timeout: 0, + CpuThreshold: 0, + Middlewares: ServerMiddlewaresConf{ + Trace: true, + Recover: true, + Stat: true, + Prometheus: true, + Breaker: true, + }, + MethodTimeouts: []MethodTimeoutConf{ + { + FullMethod: "/foo", + Timeout: time.Second, + }, + }, + }, func(server *grpc.Server) { + }) + svr.AddOptions(grpc.ConnectionTimeout(time.Hour)) + svr.AddUnaryInterceptors(serverinterceptors.UnaryRecoverInterceptor) + svr.AddStreamInterceptors(serverinterceptors.StreamRecoverInterceptor) + + listener, err := net.Listen("tcp", "localhost:8081") + assert.Nil(t, err) + go svr.StartWithListener(listener) + svr.Stop() +} + func TestServerError(t *testing.T) { _, err := NewServer(RpcServerConf{ ServiceConf: service.ServiceConf{ @@ -159,6 +202,10 @@ func (m *mockedServer) Start(_ internal.RegisterFn) error { return nil } +func (m *mockedServer) StartWithListener(_ net.Listener, _ internal.RegisterFn) error { + return nil +} + func Test_setupUnaryInterceptors(t *testing.T) { tests := []struct { name string