Skip to content

Commit

Permalink
Fix lifecycle race condition and prevent recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
xorkevin committed May 18, 2023
1 parent c4a4585 commit c263597
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 59 deletions.
38 changes: 22 additions & 16 deletions cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,32 @@ caching, object storage, emailing, message queues and more.`,
}
rootCmd.PersistentFlags().StringVar(&c.configFile, "config", "", fmt.Sprintf("config file (default is $XDG_CONFIG_HOME/%s/{%s|%s}.json for server and client respectively)", c.opts.Appname, c.opts.DefaultFile, c.opts.ClientDefault))

serveCmd := &cobra.Command{
Use: "serve",
Short: "starts the http server and runs all services",
Long: `Starts the http server and runs all services
if c.s != nil {
serveCmd := &cobra.Command{
Use: "serve",
Short: "starts the http server and runs all services",
Long: `Starts the http server and runs all services
The server first runs all init procedures for all services before starting.`,
Run: c.serve,
DisableAutoGenTag: true,
Run: c.serve,
DisableAutoGenTag: true,
}
rootCmd.AddCommand(serveCmd)
}
rootCmd.AddCommand(serveCmd)

setupCmd := &cobra.Command{
Use: "setup",
Short: "runs the setup procedures for all services",
Long: `Runs the setup procedures for all services
if c.c != nil {
setupCmd := &cobra.Command{
Use: "setup",
Short: "runs the setup procedures for all services",
Long: `Runs the setup procedures for all services
Calls the server setup endpoint.`,
Run: c.setup,
DisableAutoGenTag: true,
Run: c.setup,
DisableAutoGenTag: true,
}
setupCmd.PersistentFlags().StringVar(&c.cmdFlags.setupSecret, "secret", "", "setup secret")
rootCmd.AddCommand(setupCmd)
}
setupCmd.PersistentFlags().StringVar(&c.cmdFlags.setupSecret, "secret", "", "setup secret")
rootCmd.AddCommand(setupCmd)

docCmd := &cobra.Command{
Use: "doc",
Expand Down Expand Up @@ -106,7 +110,9 @@ Calls the server setup endpoint.`,
}
docCmd.AddCommand(docMdCmd)

c.addTrees(c.c.GetCmds(), rootCmd)
if c.c != nil {
c.addTrees(c.c.GetCmds(), rootCmd)
}

c.cmd = rootCmd
}
Expand Down
7 changes: 4 additions & 3 deletions cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,18 @@ func TestCmd(t *testing.T) {
assert := require.New(t)

var out bytes.Buffer
stderr := klog.NewSyncWriter(&out)

client := NewClient(Opts{
Appname: "govtest",
ClientPrefix: "govc",
ConfigReader: strings.NewReader(""),
ConfigReader: strings.NewReader("{}"),
LogWriter: io.Discard,
TermConfig: &TermConfig{
StdinFd: int(os.Stdin.Fd()),
Stdin: strings.NewReader("test input content"),
Stdout: io.Discard,
Stderr: klog.NewSyncWriter(&out),
Stderr: stderr,
Fsys: fstest.MapFS{},
WFsys: writefstest.MapFS{},
Exit: func(code int) {},
Expand All @@ -146,7 +147,7 @@ func TestCmd(t *testing.T) {
StdinFd: int(os.Stdin.Fd()),
Stdin: strings.NewReader(""),
Stdout: io.Discard,
Stderr: io.Discard,
Stderr: stderr,
Exit: func(code int) {},
},
}, nil, client)
Expand Down
17 changes: 3 additions & 14 deletions govinjector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ func TestInjector(t *testing.T) {

assert := require.New(t)

tabReplacer := strings.NewReplacer("\t", " ")

server := New(Opts{
Appname: "govtest",
Version: Version{
Expand All @@ -28,18 +26,9 @@ func TestInjector(t *testing.T) {
Description: "test gov server",
EnvPrefix: "gov",
ClientPrefix: "govc",
ConfigReader: strings.NewReader(tabReplacer.Replace(`
http:
addr: ':8080'
basepath: /api
setupsecret: setupsecret
`)),
VaultReader: strings.NewReader(tabReplacer.Replace(`
data:
setupsecret:
secret: setupsecret
`)),
LogWriter: io.Discard,
ConfigReader: strings.NewReader("{}"),
VaultReader: strings.NewReader("{}"),
LogWriter: io.Discard,
})

pathA := server.Injector()
Expand Down
2 changes: 1 addition & 1 deletion service/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ type (
}
)

func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.Manager[sqldbClient]) (*sqldbClient, error) {
func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.State[sqldbClient]) (*sqldbClient, error) {
var auth pgAuth
{
client := m.Load(ctx)
Expand Down
2 changes: 1 addition & 1 deletion service/events/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ type (
}
)

func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.Manager[kafkaClient]) (*kafkaClient, error) {
func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.State[kafkaClient]) (*kafkaClient, error) {
var secret secretAuth
{
client := m.Load(ctx)
Expand Down
2 changes: 1 addition & 1 deletion service/events/nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ type (
}
)

func (s *NatsService) handleGetClient(ctx context.Context, m *lifecycle.Manager[natsClient]) (*natsClient, error) {
func (s *NatsService) handleGetClient(ctx context.Context, m *lifecycle.State[natsClient]) (*natsClient, error) {
var auth natsauth
{
client := m.Load(ctx)
Expand Down
2 changes: 1 addition & 1 deletion service/kvstore/kvstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ type (
}
)

func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.Manager[kvstoreClient]) (*kvstoreClient, error) {
func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.State[kvstoreClient]) (*kvstoreClient, error) {
var auth redisauth
{
client := m.Load(ctx)
Expand Down
2 changes: 1 addition & 1 deletion service/mail/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func (s *Service) handlePing(ctx context.Context, m *lifecycle.Manager[mailSecre
m.Stop(ctx)
}

func (s *Service) handleGetSecrets(ctx context.Context, m *lifecycle.Manager[mailSecrets]) (*mailSecrets, error) {
func (s *Service) handleGetSecrets(ctx context.Context, m *lifecycle.State[mailSecrets]) (*mailSecrets, error) {
currentSecrets := m.Load(ctx)

var auth secretAuth
Expand Down
2 changes: 1 addition & 1 deletion service/objstore/objstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ type (
}
)

func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.Manager[objstoreClient]) (*objstoreClient, error) {
func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.State[objstoreClient]) (*objstoreClient, error) {
var auth minioauth
{
client := m.Load(ctx)
Expand Down
2 changes: 1 addition & 1 deletion service/pubsub/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ type (
}
)

func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.Manager[pubsubClient]) (*pubsubClient, error) {
func (s *Service) handleGetClient(ctx context.Context, m *lifecycle.State[pubsubClient]) (*pubsubClient, error) {
var auth natsauth
{
client := m.Load(ctx)
Expand Down
2 changes: 1 addition & 1 deletion service/user/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ type (
}
)

func (s *Service) getSecrets(ctx context.Context, m *lifecycle.Manager[tokenSigner]) (*tokenSigner, error) {
func (s *Service) getSecrets(ctx context.Context, m *lifecycle.State[tokenSigner]) (*tokenSigner, error) {
currentSigner := m.Load(ctx)
var tokenSecrets secretToken
if err := s.config.GetSecret(ctx, "tokensecret", s.keyrefresh, &tokenSecrets); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion service/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ func (s *Service) handlePing(ctx context.Context, m *lifecycle.Manager[otpCipher
m.Stop(ctx)
}

func (s *Service) handleGetCipher(ctx context.Context, m *lifecycle.Manager[otpCipher]) (*otpCipher, error) {
func (s *Service) handleGetCipher(ctx context.Context, m *lifecycle.State[otpCipher]) (*otpCipher, error) {
currentCipher := m.Load(ctx)
var otpsecrets secretOTP
if err := s.config.GetSecret(ctx, "otpkey", s.otprefresh, &otpsecrets); err != nil {
Expand Down
83 changes: 66 additions & 17 deletions util/lifecycle/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,44 @@ package lifecycle

import (
"context"
"errors"
"sync/atomic"
"time"

"xorkevin.dev/governor/util/ksync"
"xorkevin.dev/klog"
)

// ErrClosed is returned when the lifecycle has been closed
var ErrClosed errClosed

type (
errClosed struct{}
)

func (e errClosed) Error() string {
return "Lifecycle closed"
}

type (
// Manager is an interface to interact with a lifecycle
Manager[T any] struct {
l *Lifecycle[T]
}

// State is an interface to interact with lifecycle state
State[T any] struct {
l *Lifecycle[T]
}

// Lifecycle manages the lifecycle of connecting to external services
Lifecycle[T any] struct {
aclient atomic.Pointer[T]
sf ksync.SingleFlight[T]
doctx context.Context
constructorfn func(ctx context.Context, m *Manager[T]) (*T, error)
cancelctx func()
closed bool
constructorfn func(ctx context.Context, m *State[T]) (*T, error)
stopfn func(ctx context.Context, client *T)
heartbeatfn func(ctx context.Context, m *Manager[T])
hbinterval time.Duration
Expand All @@ -30,15 +49,18 @@ type (
// New creates a new [*Lifecycle]
func New[T any](
doctx context.Context,
constructorfn func(ctx context.Context, m *Manager[T]) (*T, error),
constructorfn func(ctx context.Context, s *State[T]) (*T, error),
stopfn func(ctx context.Context, client *T),
heartbeatfn func(ctx context.Context, m *Manager[T]),
hbinterval time.Duration,
) *Lifecycle[T] {
doctx, cancel := context.WithCancel(doctx)
return &Lifecycle[T]{
aclient: atomic.Pointer[T]{},
sf: ksync.SingleFlight[T]{},
doctx: doctx,
cancelctx: cancel,
closed: false,
constructorfn: constructorfn,
stopfn: stopfn,
heartbeatfn: heartbeatfn,
Expand All @@ -51,22 +73,45 @@ func (l *Lifecycle[T]) Load(ctx context.Context) *T {
return l.aclient.Load()
}

func (l *Lifecycle[T]) constructWithManager(ctx context.Context) (*T, error) {
m := &Manager[T]{
l: l,
func (l *Lifecycle[T]) constructWithState(ctx context.Context) (*T, error) {
if l.closed {
return nil, ErrClosed
}
return l.constructorfn(ctx, m)

return l.constructorfn(ctx, &State[T]{
l: l,
})
}

// Construct constructs an instance
func (l *Lifecycle[T]) Construct(ctx context.Context) (*T, error) {
return l.sf.Do(l.doctx, ctx, l.constructWithManager)
return l.sf.Do(l.doctx, ctx, l.constructWithState)
}

func (l *Lifecycle[T]) stop(ctx context.Context) {
l.stopfn(klog.ExtendCtx(context.Background(), ctx), l.aclient.Swap(nil))
}

func (l *Lifecycle[T]) closeConstruction(ctx context.Context) (*T, error) {
l.closed = true
return nil, ErrClosed
}

func (l *Lifecycle[T]) waitUntilClosed() {
for {
if _, err := l.sf.Do(context.Background(), context.Background(), l.closeConstruction); errors.Is(err, ErrClosed) {
return
}
}
}

// Heartbeat calls the heartbeat function at an interval and calls the stop
// function when the context is closed.
func (l *Lifecycle[T]) Heartbeat(ctx context.Context, wg *ksync.WaitGroup) {
defer wg.Done()
defer l.stop(ctx)
defer l.waitUntilClosed()
defer l.cancelctx()
ticker := time.NewTicker(l.hbinterval)
defer ticker.Stop()
m := &Manager[T]{
Expand All @@ -75,30 +120,34 @@ func (l *Lifecycle[T]) Heartbeat(ctx context.Context, wg *ksync.WaitGroup) {
for {
select {
case <-ctx.Done():
l.stopfn(klog.ExtendCtx(context.Background(), ctx), l.aclient.Swap(nil))
return
case <-ticker.C:
l.heartbeatfn(ctx, m)
}
}
}

// Construct constructs an instance
func (m *Manager[T]) Construct(ctx context.Context) (*T, error) {
return m.l.Construct(ctx)
}

// Stop stops and removes an instance
func (m *Manager[T]) Stop(ctx context.Context) {
m.l.stopfn(klog.ExtendCtx(context.Background(), ctx), m.l.aclient.Swap(nil))
func (m *State[T]) Stop(ctx context.Context) {
m.l.stop(ctx)
}

// Load returns the cached instance
func (m *Manager[T]) Load(ctx context.Context) *T {
func (m *State[T]) Load(ctx context.Context) *T {
return m.l.Load(ctx)
}

// Store stores the cached instance
func (m *Manager[T]) Store(client *T) {
func (m *State[T]) Store(client *T) {
m.l.aclient.Store(client)
}

// Construct constructs an instance
func (m *Manager[T]) Construct(ctx context.Context) (*T, error) {
return m.l.Construct(ctx)
}

// Stop stops and removes an instance
func (m *Manager[T]) Stop(ctx context.Context) {
m.l.stop(ctx)
}

0 comments on commit c263597

Please sign in to comment.