Skip to content

Commit

Permalink
Shutting down HTTP and gRPC API servers when the runtime shuts down (d…
Browse files Browse the repository at this point in the history
  • Loading branch information
pkedy authored Oct 26, 2021
1 parent 328e8df commit 5e8d727
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 50 deletions.
43 changes: 27 additions & 16 deletions pkg/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package grpc
import (
"crypto/tls"
"fmt"
"io"
"net"
"sync"
"time"
Expand Down Expand Up @@ -39,6 +40,7 @@ const (

// Server is an interface for the dapr gRPC server.
type Server interface {
io.Closer
StartNonBlocking() error
}

Expand All @@ -48,8 +50,7 @@ type server struct {
tracingSpec config.TracingSpec
metricSpec config.MetricSpec
authenticator auth.Authenticator
listeners []net.Listener
srv *grpc_go.Server
servers []*grpc_go.Server
renewMutex *sync.Mutex
signedCert *auth.SignedCertificate
tlsCert tls.Certificate
Expand Down Expand Up @@ -127,30 +128,40 @@ func (s *server) StartNonBlocking() error {
if len(listeners) == 0 {
return errors.Errorf("could not listen on any endpoint")
}
s.listeners = listeners

server, err := s.getGRPCServer()
if err != nil {
return err
}
s.srv = server
for _, listener := range listeners {
// server is created in a loop because each instance
// has a handle on the underlying listener.
server, err := s.getGRPCServer()
if err != nil {
return err
}
s.servers = append(s.servers, server)

if s.kind == internalServer {
internalv1pb.RegisterServiceInvocationServer(server, s.api)
} else if s.kind == apiServer {
runtimev1pb.RegisterDaprServer(server, s.api)
}
if s.kind == internalServer {
internalv1pb.RegisterServiceInvocationServer(server, s.api)
} else if s.kind == apiServer {
runtimev1pb.RegisterDaprServer(server, s.api)
}

for _, listener := range listeners {
go func(l net.Listener) {
go func(server *grpc_go.Server, l net.Listener) {
if err := server.Serve(l); err != nil {
s.logger.Fatalf("gRPC serve error: %v", err)
}
}(listener)
}(server, listener)
}
return nil
}

func (s *server) Close() error {
for _, server := range s.servers {
// This calls `Close()` on the underlying listener.
server.GracefulStop()
}

return nil
}

func (s *server) generateWorkloadCert() error {
s.logger.Info("sending workload csr request to sentry")
signedCert, err := s.authenticator.CreateSignedWorkloadCert(s.config.AppID, s.config.NameSpace, s.config.TrustDomain)
Expand Down
15 changes: 15 additions & 0 deletions pkg/grpc/server_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package grpc

import (
"fmt"
"sync"
"testing"
"time"

"github.com/phayes/freeport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dapr/kit/logger"

"github.com/dapr/dapr/pkg/config"
dapr_testing "github.com/dapr/dapr/pkg/testing"
)

func TestCertRenewal(t *testing.T) {
Expand Down Expand Up @@ -85,3 +89,14 @@ func TestGetMiddlewareOptions(t *testing.T) {
assert.Equal(t, 1, len(serverOption))
})
}

func TestClose(t *testing.T) {
port, err := freeport.GetFreePort()
require.NoError(t, err)
serverConfig := NewServerConfig("test", "127.0.0.1", port, []string{"127.0.0.1"}, "test", "test", 4, "", 4)
a := &api{}
server := NewAPIServer(a, serverConfig, config.TracingSpec{}, config.MetricSpec{}, config.APISpec{}, nil)
require.NoError(t, server.StartNonBlocking())
dapr_testing.WaitForListeningAddress(t, 5*time.Second, fmt.Sprintf("127.0.0.1:%d", port))
assert.NoError(t, server.Close())
}
12 changes: 8 additions & 4 deletions pkg/health/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,13 @@ func TestResponses(t *testing.T) {
server := httptest.NewServer(&testServer{
statusCode: 200,
})
defer server.Close()

ch := StartEndpointHealthCheck(server.URL, WithInterval(time.Second*1), WithFailureThreshold(1))
for {
healthy := <-ch
assert.True(t, healthy)
server.Close()

return
}
})
Expand All @@ -119,12 +120,13 @@ func TestResponses(t *testing.T) {
server := httptest.NewServer(&testServer{
statusCode: 201,
})
defer server.Close()

ch := StartEndpointHealthCheck(server.URL, WithInterval(time.Second*1), WithFailureThreshold(1), WithSuccessStatusCode(201))
for {
healthy := <-ch
assert.True(t, healthy)
server.Close()

return
}
})
Expand All @@ -133,12 +135,13 @@ func TestResponses(t *testing.T) {
server := httptest.NewServer(&testServer{
statusCode: 500,
})
defer server.Close()

ch := StartEndpointHealthCheck(server.URL, WithInterval(time.Second*1), WithFailureThreshold(1))
for {
healthy := <-ch
assert.False(t, healthy)
server.Close()

return
}
})
Expand All @@ -148,6 +151,7 @@ func TestResponses(t *testing.T) {
statusCode: 500,
}
server := httptest.NewServer(test)
defer server.Close()

ch := StartEndpointHealthCheck(server.URL, WithInterval(time.Second*1), WithFailureThreshold(1))
count := 0
Expand All @@ -159,7 +163,7 @@ func TestResponses(t *testing.T) {
test.statusCode = 200
} else {
assert.True(t, healthy)
server.Close()

return
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/http/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1678,7 +1678,7 @@ func TestAPIToken(t *testing.T) {
token := "1234"

os.Setenv("DAPR_API_TOKEN", token)
defer os.Clearenv()
defer os.Unsetenv("DAPR_API_TOKEN")

fakeHeaderMetadata := map[string][]string{
"Accept-Encoding": {"gzip"},
Expand Down
57 changes: 45 additions & 12 deletions pkg/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package http

import (
"fmt"
"io"
"net"
"net/http"
"net/url"
Expand All @@ -15,6 +16,7 @@ import (

cors "github.com/AdhityaRamadhanus/fasthttpcors"
routing "github.com/fasthttp/router"
"github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/valyala/fasthttp"
"github.com/valyala/fasthttp/pprofhandler"
Expand All @@ -34,6 +36,7 @@ const protocol = "http"

// Server is an interface for the Dapr HTTP server.
type Server interface {
io.Closer
StartNonBlocking() error
}

Expand All @@ -44,7 +47,7 @@ type server struct {
pipeline http_middleware.Pipeline
api API
apiSpec config.APISpec
listeners []net.Listener
servers []*fasthttp.Server
profilingListeners []net.Listener
}

Expand All @@ -71,13 +74,6 @@ func (s *server) StartNonBlocking() error {
handler = s.useMetrics(handler)
handler = s.useTracing(handler)

customServer := &fasthttp.Server{
Handler: handler,
MaxRequestBodySize: s.config.MaxRequestBodySize * 1024 * 1024,
ReadBufferSize: s.config.ReadBufferSize * 1024,
StreamRequestBody: s.config.StreamRequestBody,
}

var listeners []net.Listener
var profilingListeners []net.Listener
if s.config.UnixDomainSocket != "" {
Expand All @@ -101,10 +97,21 @@ func (s *server) StartNonBlocking() error {
return errors.Errorf("could not listen on any endpoint")
}

s.listeners = listeners
for _, listener := range listeners {
// customServer is created in a loop because each instance
// has a handle on the underlying listener.
customServer := &fasthttp.Server{
Handler: handler,
MaxRequestBodySize: s.config.MaxRequestBodySize * 1024 * 1024,
ReadBufferSize: s.config.ReadBufferSize * 1024,
StreamRequestBody: s.config.StreamRequestBody,
}
s.servers = append(s.servers, customServer)

go func(l net.Listener) {
log.Fatal(customServer.Serve(l))
if err := customServer.Serve(l); err != nil {
log.Fatal(err)
}
}(listener)
}

Expand All @@ -117,9 +124,12 @@ func (s *server) StartNonBlocking() error {
Handler: publicHandler,
MaxRequestBodySize: s.config.MaxRequestBodySize * 1024 * 1024,
}
s.servers = append(s.servers, healthServer)

go func() {
log.Fatal(healthServer.ListenAndServe(fmt.Sprintf(":%d", *s.config.PublicPort)))
if err := healthServer.ListenAndServe(fmt.Sprintf(":%d", *s.config.PublicPort)); err != nil {
log.Fatal(err)
}
}()
}

Expand All @@ -140,15 +150,38 @@ func (s *server) StartNonBlocking() error {

s.profilingListeners = profilingListeners
for _, listener := range profilingListeners {
// profServer is created in a loop because each instance
// has a handle on the underlying listener.
profServer := &fasthttp.Server{
Handler: pprofhandler.PprofHandler,
MaxRequestBodySize: s.config.MaxRequestBodySize * 1024 * 1024,
}
s.servers = append(s.servers, profServer)

go func(l net.Listener) {
log.Fatal(fasthttp.Serve(l, pprofhandler.PprofHandler))
if err := profServer.Serve(l); err != nil {
log.Fatal(err)
}
}(listener)
}
}

return nil
}

func (s *server) Close() error {
var merr error

for _, ln := range s.servers {
// This calls `Close()` on the underlying listener.
if err := ln.Shutdown(); err != nil {
merr = multierror.Append(merr, err)
}
}

return merr
}

func (s *server) useTracing(next fasthttp.RequestHandler) fasthttp.RequestHandler {
if diag_utils.IsTracingEnabled(s.tracingSpec.SamplingRate) {
log.Infof("enabled tracing http middleware")
Expand Down
16 changes: 16 additions & 0 deletions pkg/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@ import (
"runtime"
"strings"
"testing"
"time"

"github.com/fasthttp/router"
"github.com/phayes/freeport"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"

"github.com/dapr/dapr/pkg/config"
"github.com/dapr/dapr/pkg/cors"
http_middleware "github.com/dapr/dapr/pkg/middleware/http"
dapr_testing "github.com/dapr/dapr/pkg/testing"
)

type mockHost struct {
Expand Down Expand Up @@ -637,3 +642,14 @@ func TestAliasRoute(t *testing.T) {
assert.Equal(t, 1, len(routes[router.MethodWild]))
})
}

func TestClose(t *testing.T) {
port, err := freeport.GetFreePort()
require.NoError(t, err)
serverConfig := NewServerConfig("test", "127.0.0.1", port, []string{"127.0.0.1"}, nil, 0, "", false, 4, "", 4, false)
a := &api{}
server := NewServer(a, serverConfig, config.TracingSpec{}, config.MetricSpec{}, http_middleware.Pipeline{}, config.APISpec{})
require.NoError(t, server.StartNonBlocking())
dapr_testing.WaitForListeningAddress(t, 5*time.Second, fmt.Sprintf("127.0.0.1:%d", port))
assert.NoError(t, server.Close())
}
Loading

0 comments on commit 5e8d727

Please sign in to comment.