-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathserver.go
112 lines (94 loc) · 2.19 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// Copyright(C) 2022 github.com/hidu All Rights Reserved.
// Author: hidu <[email protected]>
// Date: 2022/5/7
package hydra
import (
"context"
"log"
"net"
"os"
"os/signal"
"syscall"
"time"
"golang.org/x/sync/errgroup"
)
// Server 具备优雅关闭的 server
type Server interface {
// Serve 启动 server
Serve(net.Listener) error
// Shutdown 优雅关闭
Shutdown(context.Context) error
}
var _ Server = (*fixedListenerServer)(nil)
type fixedListenerServer struct {
Server
listener net.Listener
}
func (s *fixedListenerServer) Serve(_ net.Listener) error {
return s.Server.Serve(s.listener)
}
func (s *fixedListenerServer) Shutdown(ctx context.Context) error {
defer s.listener.Close()
return s.Server.Shutdown(ctx)
}
// Logger 日志接口定义
type Logger interface {
Println(v ...any)
Printf(format string, v ...any)
}
// Starter 优雅关闭 server
type Starter struct {
Server Server
Listener net.Listener
Timeout time.Duration
Signals []os.Signal
Logger Logger
}
func (sd *Starter) RunGrace() error {
ch := make(chan os.Signal, 1)
signal.Notify(ch, sd.getSignals()...)
g := &errgroup.Group{}
g.Go(func() error {
return sd.Server.Serve(sd.Listener)
})
g.Go(func() error {
return sd.shutdown(sd.Server, ch)
})
return g.Wait()
}
func (sd *Starter) shutdown(ser Server, ch <-chan os.Signal) error {
sig := <-ch
sd.getLogger().Println("[Starter] receive signal ", sig, ", start Shutdown")
ctx, cancel := context.WithTimeout(context.Background(), sd.getTimeout())
defer cancel()
start := time.Now()
err := ser.Shutdown(ctx)
sd.getLogger().Println("[Starter] Shutdown finish, err=", err, ", cost=", time.Since(start))
return err
}
func (sd *Starter) getSignals() []os.Signal {
if len(sd.Signals) == 0 {
return []os.Signal{os.Interrupt, syscall.SIGTERM}
}
return sd.Signals
}
func (sd *Starter) getTimeout() time.Duration {
if sd.Timeout > 0 {
return sd.Timeout
}
return time.Minute
}
func (sd *Starter) getLogger() Logger {
if sd.Logger != nil {
return sd.Logger
}
return log.Default()
}
func RunGrace(ser Server, l net.Listener, timeout time.Duration) error {
gs := &Starter{
Timeout: timeout,
Listener: l,
Server: ser,
}
return gs.RunGrace()
}