From 241801b7f64410a02610bb1c4fa797ee4fc94f1f Mon Sep 17 00:00:00 2001
From: Simba <1531315@qq.com>
Date: Wed, 19 Jan 2022 21:05:11 +0800
Subject: [PATCH] Gracefully exit the program when the lease expired (#2655)

This PR lets Trillian proactively listen on the "LeaseKeepAliveResponse" channel returned by KeepAlive in the etcd client. When an interruption of the automatic lease renewal is detected, the program exits gracefully by canceling the context.

Fixes #2654,#2249

Co-authored-by: Simba Peng <1531315@qq.com>
Co-authored-by: Martin Hutchinson <mhutchinson@gmail.com>
---
 cmd/internal/serverutil/main.go | 123 +++++++++++++++++++++++++++-----
 cmd/trillian_log_server/main.go |  10 ++-
 cmd/trillian_log_signer/main.go |   2 +-
 3 files changed, 114 insertions(+), 21 deletions(-)

diff --git a/cmd/internal/serverutil/main.go b/cmd/internal/serverutil/main.go
index 2f92952e5a..03abe10dd4 100644
--- a/cmd/internal/serverutil/main.go
+++ b/cmd/internal/serverutil/main.go
@@ -17,6 +17,7 @@ package serverutil
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"net"
 	"net/http"
@@ -28,10 +29,10 @@ import (
 	"github.com/google/trillian/monitoring"
 	"github.com/google/trillian/server/admin"
 	"github.com/google/trillian/server/interceptor"
-	"github.com/google/trillian/util"
 	"github.com/google/trillian/util/clock"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 	"go.etcd.io/etcd/client/v3/naming/endpoints"
+	"golang.org/x/sync/errgroup"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/reflection"
@@ -126,25 +127,54 @@ func (m *Main) Run(ctx context.Context) error {
 	trillian.RegisterTrillianAdminServer(srv, admin.New(m.Registry, m.AllowedTreeTypes))
 	reflection.Register(srv)
 
+	g, ctx := errgroup.WithContext(ctx)
+
 	if endpoint := m.HTTPEndpoint; endpoint != "" {
 		http.Handle("/metrics", promhttp.Handler())
 		http.HandleFunc("/healthz", m.healthz)
 
-		go func() {
+		s := &http.Server{
+			Addr: endpoint,
+		}
+
+		run := func() error {
 			glog.Infof("HTTP server starting on %v", endpoint)
 
 			var err error
 			// Let http.ListenAndServeTLS handle the error case when only one of the flags is set.
 			if m.TLSCertFile != "" || m.TLSKeyFile != "" {
-				err = http.ListenAndServeTLS(endpoint, m.TLSCertFile, m.TLSKeyFile, nil)
+				err = s.ListenAndServeTLS(m.TLSCertFile, m.TLSKeyFile)
 			} else {
-				err = http.ListenAndServe(endpoint, nil)
+				err = s.ListenAndServe()
 			}
 
 			if err != nil {
-				glog.Errorf("HTTP server stopped: %v", err)
+				if errors.Is(err, http.ErrServerClosed) {
+					return nil
+				}
+
+				err = fmt.Errorf("HTTP server stopped: %v", err)
+			}
+
+			return err
+		}
+
+		shutdown := func() {
+			glog.Infof("Stopping HTTP server...")
+			glog.Flush()
+
+			// 15 second exit time limit
+			ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+			defer cancel()
+
+			if err := s.Shutdown(ctx); err != nil {
+				glog.Errorf("Failed to http server shutdown: %v", err)
 			}
-		}()
+		}
+
+		g.Go(func() error {
+			return srvRun(ctx, run, shutdown)
+		})
 	}
 
 	glog.Infof("RPC server starting on %v", m.RPCEndpoint)
@@ -152,10 +182,9 @@ func (m *Main) Run(ctx context.Context) error {
 	if err != nil {
 		return err
 	}
-	go util.AwaitSignal(ctx, srv.Stop)
 
 	if m.TreeGCEnabled {
-		go func() {
+		g.Go(func() error {
 			glog.Info("Deleted tree GC started")
 			gc := admin.NewDeletedTreeGC(
 				m.Registry.AdminStorage,
@@ -163,20 +192,36 @@ func (m *Main) Run(ctx context.Context) error {
 				m.TreeDeleteMinInterval,
 				m.Registry.MetricFactory)
 			gc.Run(ctx)
-		}()
+			return nil
+		})
 	}
 
-	if err := srv.Serve(lis); err != nil {
-		glog.Errorf("RPC server terminated: %v", err)
+	run := func() error {
+		if err := srv.Serve(lis); err != nil {
+			return fmt.Errorf("RPC server terminated: %v", err)
+		}
+
+		return nil
 	}
 
-	glog.Infof("Stopping server, about to exit")
-	glog.Flush()
+	shutdown := func() {
+		glog.Infof("Stopping RPC server...")
+		glog.Flush()
+
+		srv.GracefulStop()
+	}
+
+	g.Go(func() error {
+		return srvRun(ctx, run, shutdown)
+	})
+
+	// wait for all jobs to exit gracefully
+	err = g.Wait()
 
 	// Give things a few seconds to tidy up
 	time.Sleep(time.Second * 5)
 
-	return nil
+	return err
 }
 
 // newGRPCServer starts a new Trillian gRPC server.
@@ -207,10 +252,11 @@ func (m *Main) newGRPCServer() (*grpc.Server, error) {
 	return s, nil
 }
 
-// AnnounceSelf announces this binary's presence to etcd.  Returns a function that
+// AnnounceSelf announces this binary's presence to etcd. This calls the cancel
+// function if the keepalive lease with etcd expires.  Returns a function that
 // should be called on process exit.
 // AnnounceSelf does nothing if client is nil.
-func AnnounceSelf(ctx context.Context, client *clientv3.Client, etcdService, endpoint string) func() {
+func AnnounceSelf(ctx context.Context, client *clientv3.Client, etcdService, endpoint string, cancel func()) func() {
 	if client == nil {
 		return func() {}
 	}
@@ -220,7 +266,12 @@ func AnnounceSelf(ctx context.Context, client *clientv3.Client, etcdService, end
 	if err != nil {
 		glog.Exitf("Failed to get lease from etcd: %v", err)
 	}
-	client.KeepAlive(ctx, leaseRsp.ID)
+
+	keepAliveRspCh, err := client.KeepAlive(ctx, leaseRsp.ID)
+	if err != nil {
+		glog.Exitf("Failed to keep lease alive from etcd: %v", err)
+	}
+	go listenKeepAliveRsp(ctx, keepAliveRspCh, cancel)
 
 	em, err := endpoints.NewManager(client, etcdService)
 	if err != nil {
@@ -238,3 +289,41 @@ func AnnounceSelf(ctx context.Context, client *clientv3.Client, etcdService, end
 		client.Revoke(ctx, leaseRsp.ID)
 	}
 }
+
+// listenKeepAliveRsp drains the `keepAliveRspCh` channel until ctx is done, and calls the
+// cancel function if the channel is closed, which is treated as the lease having expired.
+func listenKeepAliveRsp(ctx context.Context, keepAliveRspCh <-chan *clientv3.LeaseKeepAliveResponse, cancel func()) {
+	for {
+		select {
+		case <-ctx.Done():
+			glog.Infof("listenKeepAliveRsp canceled: %v", ctx.Err())
+			return
+		case _, ok := <-keepAliveRspCh:
+			if !ok {
+				glog.Errorf("listenKeepAliveRsp canceled: unexpected lease expired")
+				cancel()
+				return
+			}
+		}
+	}
+}
+
+// srvRun runs the server and calls `shutdown` when the context has been cancelled, then waits for `run` to return.
+func srvRun(ctx context.Context, run func() error, shutdown func()) error {
+	exit := make(chan struct{})
+	var err error
+	go func() {
+		defer close(exit)
+		err = run()
+	}()
+
+	select {
+	case <-ctx.Done():
+		shutdown()
+		// wait for run to return
+		<-exit
+	case <-exit:
+	}
+
+	return err
+}
diff --git a/cmd/trillian_log_server/main.go b/cmd/trillian_log_server/main.go
index ec75c70297..c8c7978933 100644
--- a/cmd/trillian_log_server/main.go
+++ b/cmd/trillian_log_server/main.go
@@ -40,6 +40,7 @@ import (
 	"github.com/google/trillian/quota/etcd/quotapb"
 	"github.com/google/trillian/server"
 	"github.com/google/trillian/storage"
+	"github.com/google/trillian/util"
 	"github.com/google/trillian/util/clock"
 	clientv3 "go.etcd.io/etcd/client/v3"
 	"google.golang.org/grpc"
@@ -91,7 +92,9 @@ func main() {
 		}
 	}
 
-	ctx := context.Background()
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go util.AwaitSignal(ctx, cancel)
 
 	var options []grpc.ServerOption
 	mf := prometheus.MetricFactory{}
@@ -124,10 +127,11 @@ func main() {
 	}
 
 	// Announce our endpoints to etcd if so configured.
-	unannounce := serverutil.AnnounceSelf(ctx, client, *etcdService, *rpcEndpoint)
+	unannounce := serverutil.AnnounceSelf(ctx, client, *etcdService, *rpcEndpoint, cancel)
 	defer unannounce()
+
 	if *httpEndpoint != "" {
-		unannounceHTTP := serverutil.AnnounceSelf(ctx, client, *etcdHTTPService, *httpEndpoint)
+		unannounceHTTP := serverutil.AnnounceSelf(ctx, client, *etcdHTTPService, *httpEndpoint, cancel)
 		defer unannounceHTTP()
 	}
 
diff --git a/cmd/trillian_log_signer/main.go b/cmd/trillian_log_signer/main.go
index 976752ae29..f2f8c54168 100644
--- a/cmd/trillian_log_signer/main.go
+++ b/cmd/trillian_log_signer/main.go
@@ -150,7 +150,7 @@ func main() {
 	// Start HTTP server (optional)
 	if *httpEndpoint != "" {
 		// Announce our endpoint to etcd if so configured.
-		unannounceHTTP := serverutil.AnnounceSelf(ctx, client, *etcdHTTPService, *httpEndpoint)
+		unannounceHTTP := serverutil.AnnounceSelf(ctx, client, *etcdHTTPService, *httpEndpoint, cancel)
 		defer unannounceHTTP()
 	}