Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add StartWithListener method to api and rpc server #4465

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions rest/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sort"
"time"
Expand Down Expand Up @@ -330,6 +331,31 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
ng.conf.KeyFile, router, opts...)
}

func (ng *engine) startWithListener(listener net.Listener, router httpx.Router, opts ...StartOption) error {
if err := ng.bindRoutes(router); err != nil {
return err
}

// make sure user defined options overwrite default options
opts = append([]StartOption{ng.withTimeout()}, opts...)

if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
return internal.StartHttpWithListener(listener, router, opts...)
}

// make sure user defined options overwrite default options
opts = append([]StartOption{
func(svr *http.Server) {
if ng.tlsConfig != nil {
svr.TLSConfig = ng.tlsConfig
}
},
}, opts...)

return internal.StartHttpsWithListener(listener, ng.conf.CertFile,
ng.conf.KeyFile, router, opts...)
}

func (ng *engine) use(middleware Middleware) {
ng.middlewares = append(ng.middlewares, middleware)
}
Expand Down
38 changes: 38 additions & 0 deletions rest/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -429,6 +430,43 @@ func TestEngine_start(t *testing.T) {
})
}

func TestEngine_startWithListener(t *testing.T) {
logx.Disable()

t.Run("http", func(t *testing.T) {
ng := newEngine(RestConf{
Host: "localhost",
Port: -1,
})
address := fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)
listener, err := net.Listen("tcp", address)
assert.Error(t, err)
if listener != nil {
assert.Error(t, ng.startWithListener(listener, router.NewRouter()))
} else {
assert.Error(t, ng.start(router.NewRouter()))
}
})

t.Run("https", func(t *testing.T) {
ng := newEngine(RestConf{
Host: "localhost",
Port: -1,
CertFile: "foo",
KeyFile: "bar",
})
ng.tlsConfig = &tls.Config{}
address := fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port)
listener, err := net.Listen("tcp", address)
assert.Error(t, err)
if listener != nil {
assert.Error(t, ng.startWithListener(listener, router.NewRouter()))
} else {
assert.Error(t, ng.start(router.NewRouter()))
}
})
}

type mockedRouter struct {
}

Expand Down
46 changes: 46 additions & 0 deletions rest/internal/starter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/http"

"github.com/zeromicro/go-zero/core/logx"
Expand All @@ -23,6 +24,13 @@ func StartHttp(host string, port int, handler http.Handler, opts ...StartOption)
}, opts...)
}

// StartHttpWithListener starts a http server with listener.
func StartHttpWithListener(listener net.Listener, handler http.Handler, opts ...StartOption) error {
return startWithListener(listener, handler, func(svr *http.Server) error {
return svr.Serve(listener)
}, opts...)
}

// StartHttps starts a https server.
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
opts ...StartOption) error {
Expand All @@ -32,6 +40,15 @@ func StartHttps(host string, port int, certFile, keyFile string, handler http.Ha
}, opts...)
}

// StartHttpsWithListener starts a https server with listener.
func StartHttpsWithListener(listener net.Listener, certFile, keyFile string, handler http.Handler,
opts ...StartOption) error {
return startWithListener(listener, handler, func(svr *http.Server) error {
// certFile and keyFile are set in buildHttpsServer
return svr.ServeTLS(listener, certFile, keyFile)
}, opts...)
}

func start(host string, port int, handler http.Handler, run func(svr *http.Server) error,
opts ...StartOption) (err error) {
server := &http.Server{
Expand Down Expand Up @@ -59,3 +76,32 @@ func start(host string, port int, handler http.Handler, run func(svr *http.Serve
health.AddProbe(healthManager)
return run(server)
}

func startWithListener(listener net.Listener, handler http.Handler, run func(svr *http.Server) error,
opts ...StartOption) (err error) {

server := &http.Server{
Addr: fmt.Sprintf("%s", listener.Addr().String()),
Handler: handler,
}
for _, opt := range opts {
opt(server)
}

healthManager := health.NewHealthManager(fmt.Sprintf("%s-%s", probeNamePrefix, listener.Addr().String()))
waitForCalled := proc.AddShutdownListener(func() {
healthManager.MarkNotReady()
if e := server.Shutdown(context.Background()); e != nil {
logx.Error(e)
}
})
defer func() {
if errors.Is(err, http.ErrServerClosed) {
waitForCalled()
}
}()

healthManager.MarkReady()
health.AddProbe(healthManager)
return run(server)
}
18 changes: 18 additions & 0 deletions rest/internal/starter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,21 @@ func TestStartHttps(t *testing.T) {
assert.NotNil(t, err)
proc.WrapUp()
}

func TestStartHttpWithListener(t *testing.T) {
svr := httptest.NewUnstartedServer(http.NotFoundHandler())
err := StartHttpWithListener(svr.Listener, http.NotFoundHandler(), func(svr *http.Server) {
svr.IdleTimeout = 0
})
assert.NotNil(t, err)
proc.WrapUp()
}

func TestStartHttpsWithListener(t *testing.T) {
svr := httptest.NewUnstartedServer(http.NotFoundHandler())
err := StartHttpsWithListener(svr.Listener, "", "", http.NotFoundHandler(), func(svr *http.Server) {
svr.IdleTimeout = 0
})
assert.NotNil(t, err)
proc.WrapUp()
}
8 changes: 8 additions & 0 deletions rest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rest
import (
"crypto/tls"
"errors"
"net"
"net/http"
"path"
"time"
Expand Down Expand Up @@ -121,6 +122,13 @@ func (s *Server) Start() {
handleError(s.ngin.start(s.router))
}

// StartWithListener starts the Server with listener
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
func (s *Server) StartWithListener(listener net.Listener) {
handleError(s.ngin.startWithListener(listener, s.router))
}

// StartWithOpts starts the Server.
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
Expand Down
19 changes: 19 additions & 0 deletions rest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -124,6 +125,24 @@ Port: 0
})
svr.Stop()
}()

func() {
defer func() {
p := recover()
switch v := p.(type) {
case error:
assert.Equal(t, "foo", v.Error())
default:
t.Fail()
}
}()

address := fmt.Sprintf("%s:%d", cnf.Host, cnf.Port)
listener, err := net.Listen("tcp", address)
assert.Nil(t, err)
svr.StartWithListener(listener)
svr.Stop()
}()
}
}

Expand Down
9 changes: 9 additions & 0 deletions zrpc/internal/rpcpubserver.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package internal

import (
"net"
"os"
"strings"

Expand Down Expand Up @@ -53,6 +54,14 @@ func (s keepAliveServer) Start(fn RegisterFn) error {
return s.Server.Start(fn)
}

func (s keepAliveServer) StartWithListener(listener net.Listener, fn RegisterFn) error {
if err := s.registerEtcd(); err != nil {
return err
}

return s.Server.StartWithListener(listener, fn)
}

func figureOutListenOn(listenOn string) string {
fields := strings.Split(listenOn, ":")
if len(fields) == 0 {
Expand Down
4 changes: 4 additions & 0 deletions zrpc/internal/rpcpubserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ func TestNewRpcPubServer(t *testing.T) {
assert.NotPanics(t, func() {
s.Start(nil)
})

assert.NotPanics(t, func() {
s.StartWithListener(nil, nil)
})
}

func TestFigureOutListenOn(t *testing.T) {
Expand Down
30 changes: 30 additions & 0 deletions zrpc/internal/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ func (s *rpcServer) Start(register RegisterFn) error {
return server.Serve(lis)
}

func (s *rpcServer) StartWithListener(listener net.Listener, register RegisterFn) error {

unaryInterceptorOption := grpc.ChainUnaryInterceptor(s.unaryInterceptors...)
streamInterceptorOption := grpc.ChainStreamInterceptor(s.streamInterceptors...)

options := append(s.options, unaryInterceptorOption, streamInterceptorOption)
server := grpc.NewServer(options...)
register(server)

// register the health check service
if s.health != nil {
grpc_health_v1.RegisterHealthServer(server, s.health)
s.health.Resume()
}
s.healthManager.MarkReady()
health.AddProbe(s.healthManager)

// we need to make sure all others are wrapped up,
// so we do graceful stop at shutdown phase instead of wrap up phase
waitForCalled := proc.AddShutdownListener(func() {
if s.health != nil {
s.health.Shutdown()
}
server.GracefulStop()
})
defer waitForCalled()

return server.Serve(listener)
}

// WithRpcHealth returns a func that sets rpc health switch to a Server.
func WithRpcHealth(health bool) ServerOption {
return func(options *rpcServerOptions) {
Expand Down
20 changes: 18 additions & 2 deletions zrpc/internal/rpcserver_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package internal

import (
"net"
"sync"
"testing"
"time"
Expand All @@ -17,8 +18,8 @@ func TestRpcServer(t *testing.T) {
var wg, wgDone sync.WaitGroup
var grpcServer *grpc.Server
var lock sync.Mutex
wg.Add(1)
wgDone.Add(1)
wg.Add(2)
wgDone.Add(2)
go func() {
err := server.Start(func(server *grpc.Server) {
lock.Lock()
Expand All @@ -31,6 +32,21 @@ func TestRpcServer(t *testing.T) {
wgDone.Done()
}()

go func() {
listener, err := net.Listen("tcp", "localhost:54322")
assert.Nil(t, err)
serverWithListener := NewRpcServer(listener.Addr().String(), WithRpcHealth(true))
err = serverWithListener.StartWithListener(listener, func(server *grpc.Server) {
lock.Lock()
mock.RegisterDepositServiceServer(server, new(mock.DepositServer))
grpcServer = server
lock.Unlock()
wg.Done()
})
assert.Nil(t, err)
wgDone.Done()
}()

wg.Wait()
time.Sleep(100 * time.Millisecond)

Expand Down
2 changes: 2 additions & 0 deletions zrpc/internal/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package internal

import (
"net"
"time"

"google.golang.org/grpc"
Expand All @@ -21,6 +22,7 @@ type (
AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor)
SetName(string)
Start(register RegisterFn) error
StartWithListener(listener net.Listener, register RegisterFn) error
}

baseRpcServer struct {
Expand Down
11 changes: 11 additions & 0 deletions zrpc/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zrpc

import (
"net"
"time"

"github.com/zeromicro/go-zero/core/load"
Expand Down Expand Up @@ -92,6 +93,16 @@ func (rs *RpcServer) Start() {
}
}

// StartWithListener starts the RpcServer with listener.
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
func (rs *RpcServer) StartWithListener(listener net.Listener) {
if err := rs.server.StartWithListener(listener, rs.register); err != nil {
logx.Error(err)
panic(err)
}
}

// Stop stops the RpcServer.
func (rs *RpcServer) Stop() {
logx.Close()
Expand Down
Loading