diff --git a/internal/services/configstore/action/action.go b/internal/services/configstore/action/action.go index 11a28cdfa..bf3eb6d1b 100644 --- a/internal/services/configstore/action/action.go +++ b/internal/services/configstore/action/action.go @@ -15,6 +15,8 @@ package action import ( + "sync" + "github.com/rs/zerolog" "github.com/sorintlab/errors" @@ -26,10 +28,11 @@ import ( ) type ActionHandler struct { - log zerolog.Logger - d *db.DB - lf lock.LockFactory - maintenanceMode bool + log zerolog.Logger + d *db.DB + lf lock.LockFactory + maintenanceMode bool + maintenanceModeMutex sync.Mutex } func NewActionHandler(log zerolog.Logger, d *db.DB, lf lock.LockFactory) *ActionHandler { @@ -41,10 +44,6 @@ func NewActionHandler(log zerolog.Logger, d *db.DB, lf lock.LockFactory) *Action } } -func (h *ActionHandler) SetMaintenanceMode(maintenanceMode bool) { - h.maintenanceMode = maintenanceMode -} - func (h *ActionHandler) ResolveObjectID(tx *sql.Tx, objectKind types.ObjectKind, ref string) (string, error) { switch objectKind { case types.ObjectKindProjectGroup: diff --git a/internal/services/configstore/action/maintenance.go b/internal/services/configstore/action/maintenance.go index 122f1a2ba..8f7910f8a 100644 --- a/internal/services/configstore/action/maintenance.go +++ b/internal/services/configstore/action/maintenance.go @@ -39,6 +39,20 @@ var ( maintenanceTableDDL = fmt.Sprintf("create table if not exists %s (enabled boolean not null, time timestamptz not null)", maintenanceTableName) ) +func (h *ActionHandler) IsMaintenanceMode() bool { + h.maintenanceModeMutex.Lock() + defer h.maintenanceModeMutex.Unlock() + + return h.maintenanceMode +} + +func (h *ActionHandler) SetMaintenanceMode(maintenanceMode bool) { + h.maintenanceModeMutex.Lock() + defer h.maintenanceModeMutex.Unlock() + + h.maintenanceMode = maintenanceMode +} + func isMaintenanceEnabled(d *db.DB, tx *sql.Tx) (bool, error) { var enabled *bool sb := sq.Select("enabled").From(maintenanceTableName) @@ -76,7 +90,7 @@ func (h *ActionHandler) IsMaintenanceEnabled(ctx context.Context) (bool, error) return enabled, nil } -func (h *ActionHandler) MaintenanceMode(ctx context.Context, enable bool) error { +func (h *ActionHandler) SetMaintenanceEnabled(ctx context.Context, enable bool) error { err := h.d.Do(ctx, func(tx *sql.Tx) error { if _, err := tx.Exec(maintenanceTableDDL); err != nil { return errors.Wrapf(err, "failed to create %s table", maintenanceTableName) diff --git a/internal/services/configstore/api/maintenance.go b/internal/services/configstore/api/maintenance.go index 7e66ad383..417a46d63 100644 --- a/internal/services/configstore/api/maintenance.go +++ b/internal/services/configstore/api/maintenance.go @@ -25,13 +25,13 @@ import ( ) type MaintenanceStatusHandler struct { - log zerolog.Logger - ah *action.ActionHandler - maintenanceRouter bool + log zerolog.Logger + ah *action.ActionHandler + currentMaintenanceMode bool } -func NewMaintenanceStatusHandler(log zerolog.Logger, ah *action.ActionHandler, maintenanceRouter bool) *MaintenanceStatusHandler { - return &MaintenanceStatusHandler{log: log, ah: ah, maintenanceRouter: maintenanceRouter} +func NewMaintenanceStatusHandler(log zerolog.Logger, ah *action.ActionHandler, currentMaintenanceMode bool) *MaintenanceStatusHandler { + return &MaintenanceStatusHandler{log: log, ah: ah, currentMaintenanceMode: currentMaintenanceMode} } func (h *MaintenanceStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -44,7 +44,7 @@ func (h *MaintenanceStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ return } - resp := csapitypes.MaintenanceStatusResponse{RequestedStatus: requestedStatus, CurrentStatus: h.maintenanceRouter} + resp := csapitypes.MaintenanceStatusResponse{RequestedStatus: requestedStatus, CurrentStatus: h.currentMaintenanceMode} if err := util.HTTPResponse(w, http.StatusOK, resp); err != nil { h.log.Err(err).Send() } @@ -70,7 +70,7 @@ func (h *MaintenanceModeHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques enable = false } - err := h.ah.MaintenanceMode(ctx, enable) + err := h.ah.SetMaintenanceEnabled(ctx, enable) if err != nil { h.log.Err(err).Send() util.HTTPError(w, err) diff --git a/internal/services/configstore/configstore.go b/internal/services/configstore/configstore.go index 11699a36a..729cb5681 100644 --- a/internal/services/configstore/configstore.go +++ b/internal/services/configstore/configstore.go @@ -40,14 +40,14 @@ import ( "agola.io/agola/internal/util" ) -func (s *Configstore) maintenanceModeWatcherLoop(ctx context.Context, runCtxCancel context.CancelFunc, maintenanceModeEnabled bool) { - s.log.Info().Msgf("maintenance mode watcher: maintenance mode enabled: %t", maintenanceModeEnabled) +func (s *Configstore) maintenanceModeWatcherLoop(ctx context.Context, runCtxCancel context.CancelFunc, maintenanceMode bool) { + s.log.Info().Msgf("maintenance mode watcher: maintenance mode enabled: %t", maintenanceMode) for { s.log.Debug().Msgf("maintenanceModeWatcherLoop") // at first watch restart from previous processed revision - if err := s.maintenanceModeWatcher(ctx, runCtxCancel, maintenanceModeEnabled); err != nil { + if err := s.maintenanceModeWatcher(ctx, runCtxCancel, maintenanceMode); err != nil { s.log.Err(err).Msgf("maintenance mode watcher error") } @@ -60,13 +60,13 @@ func (s *Configstore) maintenanceModeWatcherLoop(ctx context.Context, runCtxCanc } } -func (s *Configstore) maintenanceModeWatcher(ctx context.Context, runCtxCancel context.CancelFunc, maintenanceModeEnabled bool) error { +func (s *Configstore) maintenanceModeWatcher(ctx context.Context, runCtxCancel context.CancelFunc, maintenanceMode bool) error { maintenanceEnabled, err := s.ah.IsMaintenanceEnabled(ctx) if err != nil { return errors.WithStack(err) } - if maintenanceEnabled != maintenanceModeEnabled { + if maintenanceEnabled != maintenanceMode { s.log.Info().Msgf("maintenance mode changed to %t", maintenanceEnabled) runCtxCancel() } diff --git a/internal/services/configstore/configstore_test.go b/internal/services/configstore/configstore_test.go index facc33b6d..eb36def43 100644 --- a/internal/services/configstore/configstore_test.go +++ b/internal/services/configstore/configstore_test.go @@ -173,8 +173,6 @@ func TestExportImport(t *testing.T) { t.Logf("starting cs") go func() { _ = cs.Run(ctx) }() - time.Sleep(1 * time.Second) - var expectedRemoteSourcesCount int var expectedUsersCount int var expectedOrgsCount int @@ -188,18 +186,7 @@ func TestExportImport(t *testing.T) { } expectedRemoteSourcesCount++ - for i := 0; i < 10; i++ { - if _, err := cs.ah.CreateUser(ctx, &action.CreateUserRequest{UserName: fmt.Sprintf("user%d", i)}); err != nil { - t.Fatalf("unexpected err: %v", err) - } - expectedUsersCount++ - expectedProjectGroupsCount++ - } - - time.Sleep(5 * time.Second) - - // Do some more changes - for i := 10; i < 20; i++ { + for i := 0; i < 20; i++ { if _, err := cs.ah.CreateUser(ctx, &action.CreateUserRequest{UserName: fmt.Sprintf("user%d", i)}); err != nil { t.Fatalf("unexpected err: %v", err) } @@ -314,21 +301,33 @@ func TestExportImport(t *testing.T) { t.Fatalf("unexpected err: %v", err) } - if err := cs.ah.MaintenanceMode(ctx, true); err != nil { + if err := cs.ah.SetMaintenanceEnabled(ctx, true); err != nil { t.Fatalf("unexpected err: %v", err) } - time.Sleep(5 * time.Second) + _ = testutil.Wait(30*time.Second, func() (bool, error) { + if !cs.ah.IsMaintenanceMode() { + return false, nil + } + + return true, nil + }) if err := cs.ah.Import(ctx, &export); err != nil { t.Fatalf("unexpected err: %v", err) } - if err := cs.ah.MaintenanceMode(ctx, false); err != nil { + if err := cs.ah.SetMaintenanceEnabled(ctx, false); err != nil { t.Fatalf("unexpected err: %v", err) } - time.Sleep(5 * time.Second) + _ = testutil.Wait(30*time.Second, func() (bool, error) { + if cs.ah.IsMaintenanceMode() { + return false, nil + } + + return true, nil + }) newRemoteSources, err := getRemoteSources(ctx, cs) if err != nil {