From 0c0754fc6a94c171733db0db93509c1877a0afd5 Mon Sep 17 00:00:00 2001
From: Gabriel Aszalos <gabriel.aszalos@datadoghq.com>
Date: Mon, 5 Mar 2018 15:36:52 +0100
Subject: [PATCH] cmd/trace-agent: use http.Server.Shutdown and context.Context

---
 Makefile                         |  1 +
 cmd/trace-agent/agent.go         | 16 ++++++++++------
 cmd/trace-agent/agent_test.go    | 20 ++++++++++++--------
 cmd/trace-agent/listener.go      | 25 +++++++------------------
 cmd/trace-agent/main.go          |  9 +++++----
 cmd/trace-agent/main_nix.go      |  7 ++++---
 cmd/trace-agent/main_windows.go  | 13 +++++++------
 cmd/trace-agent/receiver.go      | 18 +++++++++++-------
 cmd/trace-agent/receiver_test.go |  2 +-
 9 files changed, 58 insertions(+), 53 deletions(-)

diff --git a/Makefile b/Makefile
index 53abc2ca4..9ad14d207 100644
--- a/Makefile
+++ b/Makefile
@@ -23,6 +23,7 @@ install:
 
 ci:
 	# task used by CI
+	GOOS=windows go build ./cmd/trace-agent # ensure windows builds
 	go get -u github.com/golang/lint/golint/...
 	golint ./cmd/trace-agent ./filters ./fixtures ./info ./quantile ./quantizer ./sampler ./statsd ./watchdog ./writer
 	go test ./...
diff --git a/cmd/trace-agent/agent.go b/cmd/trace-agent/agent.go
index f299efbf0..d60d50e86 100644
--- a/cmd/trace-agent/agent.go
+++ b/cmd/trace-agent/agent.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"sync/atomic"
 	"time"
 
@@ -57,13 +58,14 @@ type Agent struct {
 	dynConf *config.DynamicConfig
 
 	// Used to synchronize on a clean exit
-	exit chan struct{}
+	ctx context.Context
 
 	die func(format string, args ...interface{})
 }
 
-// NewAgent returns a new Agent object, ready to be started
-func NewAgent(conf *config.AgentConfig, exit chan struct{}) *Agent {
+// NewAgent returns a new Agent object, ready to be started. It takes a context
+// which may be cancelled in order to gracefully stop the agent.
+func NewAgent(ctx context.Context, conf *config.AgentConfig) *Agent {
 	dynConf := config.NewDynamicConfig()
 
 	// inter-component channels
@@ -107,7 +109,7 @@ func NewAgent(conf *config.AgentConfig, exit chan struct{}) *Agent {
 		sampledTraceChan:   sampledTraceChan,
 		conf:               conf,
 		dynConf:            dynConf,
-		exit:               exit,
+		ctx:                ctx,
 		die:                die,
 	}
 }
@@ -140,9 +142,11 @@ func (a *Agent) Run() {
 			a.Process(t)
 		case <-watchdogTicker.C:
 			a.watchdog()
-		case <-a.exit:
+		case <-a.ctx.Done():
 			log.Info("exiting")
-			close(a.Receiver.exit)
+			if err := a.Receiver.Stop(); err != nil {
+				log.Error(err)
+			}
 			a.Concentrator.Stop()
 			a.TraceWriter.Stop()
 			a.StatsWriter.Stop()
diff --git a/cmd/trace-agent/agent_test.go b/cmd/trace-agent/agent_test.go
index c2d9908bf..ef095c1d5 100644
--- a/cmd/trace-agent/agent_test.go
+++ b/cmd/trace-agent/agent_test.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"net/http"
 	"runtime"
@@ -31,14 +32,15 @@ func TestWatchdog(t *testing.T) {
 	defaultMux := http.DefaultServeMux
 	http.DefaultServeMux = http.NewServeMux()
 
-	exit := make(chan struct{})
-	agent := NewAgent(conf, exit)
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+	agent := NewAgent(ctx, conf)
 
 	defer func() {
-		close(agent.exit)
+		cancelFunc()
 		// We need to manually close the receiver as the Run() func
 		// should have been broken and interrupted by the watchdog panic
-		close(agent.Receiver.exit)
+		agent.Receiver.Stop()
 		// we need to wait more than on second (time for StoppableListener.Accept
 		// to acknowledge the connection has been closed)
 		time.Sleep(2 * time.Second)
@@ -129,8 +131,9 @@ func BenchmarkAgentTraceProcessingWithWorstCaseFiltering(b *testing.B) {
 }
 
 func runTraceProcessingBenchmark(b *testing.B, c *config.AgentConfig) {
-	exit := make(chan struct{})
-	agent := NewAgent(c, exit)
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+	agent := NewAgent(ctx, c)
 	log.UseLogger(log.Disabled)
 
 	b.ResetTimer()
@@ -143,8 +146,9 @@ func runTraceProcessingBenchmark(b *testing.B, c *config.AgentConfig) {
 func BenchmarkWatchdog(b *testing.B) {
 	conf := config.NewDefaultAgentConfig()
 	conf.APIKey = "apikey_2"
-	exit := make(chan struct{})
-	agent := NewAgent(conf, exit)
+	ctx, cancelFunc := context.WithCancel(context.Background())
+	defer cancelFunc()
+	agent := NewAgent(ctx, conf)
 
 	b.ResetTimer()
 	b.ReportAllocs()
diff --git a/cmd/trace-agent/listener.go b/cmd/trace-agent/listener.go
index 0986c2a1b..df50c0a4f 100644
--- a/cmd/trace-agent/listener.go
+++ b/cmd/trace-agent/listener.go
@@ -9,28 +9,27 @@ import (
 	log "github.com/cihub/seelog"
 )
 
-// StoppableListener wraps a regular TCPListener with an exit channel so we can exit cleanly from the Serve() loop of our HTTP server
-type StoppableListener struct {
-	exit      chan struct{}
+// RateLimitedListener wraps a regular TCPListener with a connection lease: once the lease is exhausted, Accept rejects new connections until the lease is refreshed
+type RateLimitedListener struct {
 	connLease int32 // How many connections are available for this listener before rate-limiting kicks in
 	*net.TCPListener
 }
 
-// NewStoppableListener returns a new wrapped listener, which is non-initialized
-func NewStoppableListener(l net.Listener, exit chan struct{}, conns int) (*StoppableListener, error) {
+// NewRateLimitedListener wraps l, which must be a *net.TCPListener, with an initial lease of conns connections; it returns an error for any other listener type
+func NewRateLimitedListener(l net.Listener, conns int) (*RateLimitedListener, error) {
 	tcpL, ok := l.(*net.TCPListener)
 
 	if !ok {
 		return nil, errors.New("cannot wrap listener")
 	}
 
-	sl := &StoppableListener{exit: exit, connLease: int32(conns), TCPListener: tcpL}
+	sl := &RateLimitedListener{connLease: int32(conns), TCPListener: tcpL}
 
 	return sl, nil
 }
 
 // Refresh periodically refreshes the connection lease, and thus cancels any rate limits in place
-func (sl *StoppableListener) Refresh(conns int) {
+func (sl *RateLimitedListener) Refresh(conns int) {
 	for range time.Tick(30 * time.Second) {
 		atomic.StoreInt32(&sl.connLease, int32(conns))
 		log.Debugf("Refreshed the connection lease: %d conns available", conns)
@@ -51,7 +50,7 @@ func (e *RateLimitedError) Temporary() bool { return true }
 func (e *RateLimitedError) Timeout() bool { return false }
 
 // Accept reimplements the regular Accept but adds a check on the exit channel and returns if needed
-func (sl *StoppableListener) Accept() (net.Conn, error) {
+func (sl *RateLimitedListener) Accept() (net.Conn, error) {
 	if atomic.LoadInt32(&sl.connLease) <= 0 {
 		// we've reached our cap for this lease period, reject the request
 		return nil, &RateLimitedError{}
@@ -63,16 +62,6 @@ func (sl *StoppableListener) Accept() (net.Conn, error) {
 
 		newConn, err := sl.TCPListener.Accept()
 
-		//Check for the channel being closed
-		select {
-		case <-sl.exit:
-			log.Debug("stopping listener")
-			sl.TCPListener.Close()
-			return nil, errors.New("listener stopped")
-		default:
-			//If the channel is still open, continue as normal
-		}
-
 		if err != nil {
 			netErr, ok := err.(net.Error)
 
diff --git a/cmd/trace-agent/main.go b/cmd/trace-agent/main.go
index 18828d4e9..cc0aa7e44 100644
--- a/cmd/trace-agent/main.go
+++ b/cmd/trace-agent/main.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"fmt"
 	"math/rand"
 	"os"
@@ -24,14 +25,14 @@ import (
 )
 
 // handleSignal closes a channel to exit cleanly from routines
-func handleSignal(exit chan struct{}) {
+func handleSignal(onCancel func()) {
 	sigChan := make(chan os.Signal, 10)
 	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
 	for signo := range sigChan {
 		switch signo {
 		case syscall.SIGINT, syscall.SIGTERM:
 			log.Infof("received signal %d (%v)", signo, signo)
-			close(exit)
+			onCancel()
 			return
 		default:
 			log.Warnf("unhandled signal %d (%v)", signo, signo)
@@ -71,7 +72,7 @@ to your datadog.conf file.
 Exiting.`
 
 // runAgent is the entrypoint of our code
-func runAgent(exit chan struct{}) {
+func runAgent(ctx context.Context) {
 	// configure a default logger before anything so we can observe initialization
 	if opts.info || opts.version {
 		log.UseLogger(log.Disabled)
@@ -205,7 +206,7 @@ func runAgent(exit chan struct{}) {
 	// Seed rand
 	rand.Seed(time.Now().UTC().UnixNano())
 
-	agent := NewAgent(agentConf, exit)
+	agent := NewAgent(ctx, agentConf)
 
 	log.Infof("trace-agent running on host %s", agentConf.HostName)
 	agent.Run()
diff --git a/cmd/trace-agent/main_nix.go b/cmd/trace-agent/main_nix.go
index 6321b3802..ce9aa3aa5 100644
--- a/cmd/trace-agent/main_nix.go
+++ b/cmd/trace-agent/main_nix.go
@@ -3,6 +3,7 @@
 package main
 
 import (
+	"context"
 	"flag"
 	_ "net/http/pprof"
 
@@ -28,13 +29,13 @@ func init() {
 
 // main is the main application entry point
 func main() {
-	exit := make(chan struct{})
+	ctx, cancelFunc := context.WithCancel(context.Background())
 
 	// Handle stops properly
 	go func() {
 		defer watchdog.LogOnPanic()
-		handleSignal(exit)
+		handleSignal(cancelFunc)
 	}()
 
-	runAgent(exit)
+	runAgent(ctx)
 }
diff --git a/cmd/trace-agent/main_windows.go b/cmd/trace-agent/main_windows.go
index 4b74c58cf..2f9a514dd 100644
--- a/cmd/trace-agent/main_windows.go
+++ b/cmd/trace-agent/main_windows.go
@@ -3,6 +3,7 @@
 package main
 
 import (
+	"context"
 	"flag"
 	"fmt"
 	"os"
@@ -60,7 +61,7 @@ func (m *myservice) Execute(args []string, r <-chan svc.ChangeRequest, changes c
 	changes <- svc.Status{State: svc.StartPending}
 	changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
 
-	exit := make(chan struct{})
+	ctx, cancelFunc := context.WithCancel(context.Background())
 
 	go func() {
 		for {
@@ -75,7 +76,7 @@ func (m *myservice) Execute(args []string, r <-chan svc.ChangeRequest, changes c
 				case svc.Stop, svc.Shutdown:
 					elog.Info(0x40000006, ServiceName)
 					changes <- svc.Status{State: svc.StopPending}
-					close(exit)
+					cancelFunc()
 					return
 				default:
 					elog.Warning(0xc000000A, string(c.Cmd))
@@ -84,7 +85,7 @@ func (m *myservice) Execute(args []string, r <-chan svc.ChangeRequest, changes c
 		}
 	}()
 	elog.Info(0x40000003, ServiceName)
-	runAgent(exit)
+	runAgent(ctx)
 
 	changes <- svc.Status{State: svc.Stopped}
 	return
@@ -173,15 +174,15 @@ func main() {
 
 	// if we are an interactive session, then just invoke the agent on the command line.
 
-	exit := make(chan struct{})
+	ctx, cancelFunc := context.WithCancel(context.Background())
 	// Handle stops properly
 	go func() {
 		defer watchdog.LogOnPanic()
-		handleSignal(exit)
+		handleSignal(cancelFunc)
 	}()
 
 	// Invoke the Agent
-	runAgent(exit)
+	runAgent(ctx)
 }
 
 func startService() error {
diff --git a/cmd/trace-agent/receiver.go b/cmd/trace-agent/receiver.go
index b33992bbc..89346c60a 100644
--- a/cmd/trace-agent/receiver.go
+++ b/cmd/trace-agent/receiver.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -58,12 +59,11 @@ type HTTPReceiver struct {
 	services chan model.ServicesMetadata
 	conf     *config.AgentConfig
 	dynConf  *config.DynamicConfig
+	server   *http.Server
 
 	stats      *info.ReceiverStats
 	preSampler *sampler.PreSampler
 
-	exit chan struct{}
-
 	maxRequestBodyLength int64
 	debug                bool
 }
@@ -78,7 +78,6 @@ func NewHTTPReceiver(
 		dynConf:    dynConf,
 		stats:      info.NewReceiverStats(),
 		preSampler: sampler.NewPreSampler(conf.PreSampleRate),
-		exit:       make(chan struct{}),
 
 		traces:   traces,
 		services: services,
@@ -128,8 +127,7 @@ func (r *HTTPReceiver) Listen(addr, logExtra string) error {
 		return fmt.Errorf("cannot listen on %s: %v", addr, err)
 	}
 
-	stoppableListener, err := NewStoppableListener(listener, r.exit,
-		r.conf.ConnectionLimit)
+	stoppableListener, err := NewRateLimitedListener(listener, r.conf.ConnectionLimit)
 	if err != nil {
 		return fmt.Errorf("cannot create stoppable listener: %v", err)
 	}
@@ -139,7 +137,7 @@ func (r *HTTPReceiver) Listen(addr, logExtra string) error {
 		timeout = time.Duration(r.conf.ReceiverTimeout) * time.Second
 	}
 
-	server := http.Server{
+	r.server = &http.Server{
 		ReadTimeout:  time.Second * time.Duration(timeout),
 		WriteTimeout: time.Second * time.Duration(timeout),
 	}
@@ -152,12 +150,27 @@ func (r *HTTPReceiver) Listen(addr, logExtra string) error {
 	}()
 	go func() {
 		defer watchdog.LogOnPanic()
-		server.Serve(stoppableListener)
+		r.server.Serve(stoppableListener)
 	}()
 
 	return nil
 }
 
+// Stop gracefully shuts down the receiver's HTTP server. It stops accepting
+// new connections and waits up to 20 seconds for in-flight requests to
+// complete, returning the error reported by http.Server.Shutdown, if any.
+// It is a no-op when the server was never started by Listen.
+func (r *HTTPReceiver) Stop() error {
+	if r.server == nil {
+		return nil
+	}
+	// Release the timer held by the context as soon as Shutdown returns
+	// rather than leaking it until the deadline expires.
+	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
+	defer cancel()
+	return r.server.Shutdown(ctx)
+}
+
 func (r *HTTPReceiver) httpHandle(fn http.HandlerFunc) http.HandlerFunc {
 	return func(w http.ResponseWriter, req *http.Request) {
 		req.Body = model.NewLimitedReader(req.Body, r.maxRequestBodyLength)
diff --git a/cmd/trace-agent/receiver_test.go b/cmd/trace-agent/receiver_test.go
index 16461bd8c..0dda7f2fb 100644
--- a/cmd/trace-agent/receiver_test.go
+++ b/cmd/trace-agent/receiver_test.go
@@ -63,7 +63,7 @@ func TestReceiverRequestBodyLength(t *testing.T) {
 	go receiver.Run()
 
 	defer func() {
-		close(receiver.exit)
+		receiver.Stop()
 		// we need to wait more than on second (time for StoppableListener.Accept
 		// to acknowledge the connection has been closed)
 		time.Sleep(2 * time.Second)