From a2500dc4bc835f81bf65cdfe3fcee43345aa2aa2 Mon Sep 17 00:00:00 2001 From: Matteo Merli Date: Tue, 24 Sep 2024 11:51:27 -0700 Subject: [PATCH] Implementation of disabling notifications on server side (#529) --- common/error_codes.go | 46 ++++++++-------- server/follower_controller.go | 20 ++++--- server/follower_controller_test.go | 4 +- server/kv/db.go | 88 +++++++++++++++++++++++++----- server/kv/db_notifications_test.go | 23 ++++++++ server/kv/db_test.go | 10 ++-- server/leader_controller.go | 11 +++- server/leader_controller_test.go | 33 ++++++++++- 8 files changed, 180 insertions(+), 55 deletions(-) diff --git a/common/error_codes.go b/common/error_codes.go index e3078edb..a8af6d20 100644 --- a/common/error_codes.go +++ b/common/error_codes.go @@ -20,29 +20,31 @@ import ( ) const ( - CodeNotInitialized codes.Code = 100 - CodeInvalidTerm codes.Code = 101 - CodeInvalidStatus codes.Code = 102 - CodeCancelled codes.Code = 103 - CodeAlreadyClosed codes.Code = 104 - CodeLeaderAlreadyConnected codes.Code = 105 - CodeNodeIsNotLeader codes.Code = 106 - CodeNodeIsNotFollower codes.Code = 107 - CodeInvalidSession codes.Code = 108 - CodeInvalidSessionTimeout codes.Code = 109 - CodeNamespaceNotFound codes.Code = 110 + CodeNotInitialized codes.Code = 100 + CodeInvalidTerm codes.Code = 101 + CodeInvalidStatus codes.Code = 102 + CodeCancelled codes.Code = 103 + CodeAlreadyClosed codes.Code = 104 + CodeLeaderAlreadyConnected codes.Code = 105 + CodeNodeIsNotLeader codes.Code = 106 + CodeNodeIsNotFollower codes.Code = 107 + CodeInvalidSession codes.Code = 108 + CodeInvalidSessionTimeout codes.Code = 109 + CodeNamespaceNotFound codes.Code = 110 + CodeNotificationsNotEnabled codes.Code = 111 ) var ( - ErrorNotInitialized = status.Error(CodeNotInitialized, "oxia: server not initialized yet") - ErrorCancelled = status.Error(CodeCancelled, "oxia: operation was cancelled") - ErrorInvalidTerm = status.Error(CodeInvalidTerm, "oxia: invalid term") - ErrorInvalidStatus = status.Error(CodeInvalidStatus, "oxia: invalid status") - ErrorLeaderAlreadyConnected = status.Error(CodeLeaderAlreadyConnected, "oxia: leader is already connected") - ErrorAlreadyClosed = status.Error(CodeAlreadyClosed, "oxia: node is shutting down") - ErrorNodeIsNotLeader = status.Error(CodeNodeIsNotLeader, "oxia: node is not leader for shard") - ErrorNodeIsNotFollower = status.Error(CodeNodeIsNotFollower, "oxia: node is not follower for shard") - ErrorInvalidSession = status.Error(CodeInvalidSession, "oxia: session not found") - ErrorInvalidSessionTimeout = status.Error(CodeInvalidSessionTimeout, "oxia: invalid session timeout") - ErrorNamespaceNotFound = status.Error(CodeNamespaceNotFound, "oxia: namespace not found") + ErrorNotInitialized = status.Error(CodeNotInitialized, "oxia: server not initialized yet") + ErrorCancelled = status.Error(CodeCancelled, "oxia: operation was cancelled") + ErrorInvalidTerm = status.Error(CodeInvalidTerm, "oxia: invalid term") + ErrorInvalidStatus = status.Error(CodeInvalidStatus, "oxia: invalid status") + ErrorLeaderAlreadyConnected = status.Error(CodeLeaderAlreadyConnected, "oxia: leader is already connected") + ErrorAlreadyClosed = status.Error(CodeAlreadyClosed, "oxia: node is shutting down") + ErrorNodeIsNotLeader = status.Error(CodeNodeIsNotLeader, "oxia: node is not leader for shard") + ErrorNodeIsNotFollower = status.Error(CodeNodeIsNotFollower, "oxia: node is not follower for shard") + ErrorInvalidSession = status.Error(CodeInvalidSession, "oxia: session not found") + ErrorInvalidSessionTimeout = status.Error(CodeInvalidSessionTimeout, "oxia: invalid session timeout") + ErrorNamespaceNotFound = status.Error(CodeNamespaceNotFound, "oxia: namespace not found") + ErrorNotificationsNotEnabled = status.Error(CodeNotificationsNotEnabled, "oxia: notifications not enabled on namespace") ) diff --git a/server/follower_controller.go b/server/follower_controller.go index e0501e49..4a25a420 100644 --- a/server/follower_controller.go +++ b/server/follower_controller.go @@ -90,10 +90,11 @@ type followerController struct { // Offset of the last entry appended and not fully synced yet on the wal lastAppendedOffset int64 - status proto.ServingStatus - wal wal.Wal - kvFactory kv.Factory - db kv.DB + status proto.ServingStatus + wal wal.Wal + kvFactory kv.Factory + db kv.DB + termOptions kv.TermOptions ctx context.Context cancel context.CancelFunc @@ -134,7 +135,7 @@ func NewFollowerController(config Config, namespace string, shardId int64, wf wa return nil, err } - if fc.term, err = fc.db.ReadTerm(); err != nil { + if fc.term, fc.termOptions, err = fc.db.ReadTerm(); err != nil { return nil, err } @@ -142,6 +143,8 @@ func NewFollowerController(config Config, namespace string, shardId int64, wf wa fc.status = proto.ServingStatus_FENCED } + fc.db.EnableNotifications(fc.termOptions.NotificationsEnabled) + commitOffset, err := fc.db.ReadCommitOffset() if err != nil { return nil, err @@ -281,10 +284,13 @@ func (fc *followerController) NewTerm(req *proto.NewTermRequest) (*proto.NewTerm } } - if err := fc.db.UpdateTerm(req.Term); err != nil { + fc.termOptions = kv.ToDbOption(req.Options) + if err := fc.db.UpdateTerm(req.Term, fc.termOptions); err != nil { return nil, err } + fc.db.EnableNotifications(fc.termOptions.NotificationsEnabled) + fc.term = req.Term fc.setLogger() fc.status = proto.ServingStatus_FENCED @@ -693,7 +699,7 @@ func (fc *followerController) handleSnapshot(stream proto.OxiaLogReplication_Sen } // The new term must be persisted, to avoid rolling it back - if err = newDb.UpdateTerm(fc.term); err != nil { + if err = newDb.UpdateTerm(fc.term, fc.termOptions); err != nil { fc.closeStreamNoMutex(errors.Wrap(err, "Failed to update term in db")) } diff --git a/server/follower_controller_test.go b/server/follower_controller_test.go index f8db3524..c93bf4d2 100644 --- a/server/follower_controller_test.go +++ b/server/follower_controller_test.go @@ -220,7 +220,7 @@ func TestFollower_RestoreCommitOffset(t *testing.T) { }}}, 9, 0, kv.NoOpCallback) assert.NoError(t, err) - assert.NoError(t, db.UpdateTerm(6)) + assert.NoError(t, db.UpdateTerm(6, kv.TermOptions{})) assert.NoError(t, db.Close()) fc, err := NewFollowerController(Config{}, common.DefaultNamespace, shardId, walFactory, kvFactory) @@ -466,7 +466,7 @@ func TestFollowerController_RejectEntriesWithDifferentTerm(t *testing.T) { db, err := kv.NewDB(common.DefaultNamespace, shardId, kvFactory, 1*time.Hour, common.SystemClock) assert.NoError(t, err) // Force a new term in the DB before opening - assert.NoError(t, db.UpdateTerm(5)) + assert.NoError(t, db.UpdateTerm(5, kv.TermOptions{})) assert.NoError(t, db.Close()) walFactory := wal.NewWalFactory(&wal.FactoryOptions{BaseWalDir: t.TempDir()}) diff --git a/server/kv/db.go b/server/kv/db.go index 34b32769..012dbada 100644 --- a/server/kv/db.go +++ b/server/kv/db.go @@ -16,6 +16,7 @@ package kv import ( "context" + "encoding/json" "fmt" "io" "log/slog" @@ -38,12 +39,14 @@ var ( ErrMissingPartitionKey = errors.New("oxia: sequential key operation requires partition key") ErrMissingSequenceDeltas = errors.New("oxia: sequential key operation missing some sequence deltas") ErrSequenceDeltaIsZero = errors.New("oxia: sequential key operation requires first delta do be > 0") + ErrNotificationsDisabled = errors.New("oxia: notifications disabled") ) const ( commitOffsetKey = common.InternalKeyPrefix + "commit-offset" commitLastVersionIdKey = common.InternalKeyPrefix + "last-version-id" termKey = common.InternalKeyPrefix + "term" + termOptionsKey = termKey + "-options" ) type UpdateOperationCallback interface { @@ -59,9 +62,15 @@ type RangeScanIterator interface { Next() bool } +type TermOptions struct { + NotificationsEnabled bool +} + type DB interface { io.Closer + EnableNotifications(enable bool) + ProcessWrite(b *proto.WriteRequest, commitOffset int64, timestamp uint64, updateOperationCallback UpdateOperationCallback) (*proto.WriteResponse, error) Get(request *proto.GetRequest) (*proto.GetResponse, error) List(request *proto.ListRequest) (KeyIterator, error) @@ -70,8 +79,8 @@ type DB interface { ReadNextNotifications(ctx context.Context, startOffset int64) ([]*proto.NotificationBatch, error) - UpdateTerm(newTerm int64) error - ReadTerm() (term int64, err error) + UpdateTerm(newTerm int64, options TermOptions) error + ReadTerm() (term int64, options TermOptions, err error) Snapshot() (Snapshot, error) @@ -87,8 +96,9 @@ func NewDB(namespace string, shardId int64, factory Factory, notificationRetenti labels := metrics.LabelsForShard(namespace, shardId) db := &db{ - kv: kv, - shardId: shardId, + kv: kv, + shardId: shardId, + notificationsEnabled: true, log: slog.With( slog.String("component", "db"), slog.String("namespace", namespace), @@ -136,6 +146,7 @@ type db struct { versionIdTracker atomic.Int64 notificationsTracker *notificationsTracker log *slog.Logger + notificationsEnabled bool putCounter metrics.Counter deleteCounter metrics.Counter @@ -153,6 +164,10 @@ func (d *db) Snapshot() (Snapshot, error) { return d.kv.Snapshot() } +func (d *db) EnableNotifications(enabled bool) { + d.notificationsEnabled = enabled +} + func (d *db) Close() error { return multierr.Combine( d.notificationsTracker.Close(), @@ -173,7 +188,10 @@ func now() uint64 { func (d *db) applyWriteRequest(b *proto.WriteRequest, batch WriteBatch, commitOffset int64, timestamp uint64, updateOperationCallback UpdateOperationCallback) (*notifications, *proto.WriteResponse, error) { res := &proto.WriteResponse{} - notifications := newNotifications(d.shardId, commitOffset, timestamp) + var notifications *notifications + if d.notificationsEnabled { + notifications = newNotifications(d.shardId, commitOffset, timestamp) + } d.putCounter.Add(len(b.Puts)) for _, putReq := range b.Puts { @@ -225,16 +243,20 @@ func (d *db) ProcessWrite(b *proto.WriteRequest, commitOffset int64, timestamp u return nil, err } - // Add the notifications to the batch as well - if err := d.addNotifications(batch, notifications); err != nil { - return nil, err + if notifications != nil { + // Add the notifications to the batch as well + if err := d.addNotifications(batch, notifications); err != nil { + return nil, err + } } if err := batch.Commit(); err != nil { return nil, err } - d.notificationsTracker.UpdatedCommitOffset(commitOffset) + if notifications != nil { + d.notificationsTracker.UpdatedCommitOffset(commitOffset) + } if err := batch.Close(); err != nil { return nil, err @@ -376,7 +398,7 @@ func (d *db) readASCIILong(key string) (int64, error) { return res, nil } -func (d *db) UpdateTerm(newTerm int64) error { +func (d *db) UpdateTerm(newTerm int64, options TermOptions) error { batch := d.kv.NewWriteBatch() if _, err := d.applyPut(batch, nil, &proto.PutRequest{ @@ -386,6 +408,17 @@ func (d *db) UpdateTerm(newTerm int64) error { return err } + serOptions, err := json.Marshal(options) + if err != nil { + return err + } + if _, err := d.applyPut(batch, nil, &proto.PutRequest{ + Key: termOptionsKey, + Value: serOptions, + }, now(), NoOpCallback, true); err != nil { + return err + } + if err := batch.Commit(); err != nil { return err } @@ -399,23 +432,36 @@ func (d *db) UpdateTerm(newTerm int64) error { return d.kv.Flush() } -func (d *db) ReadTerm() (term int64, err error) { +func (d *db) ReadTerm() (term int64, options TermOptions, err error) { getReq := &proto.GetRequest{ Key: termKey, IncludeValue: true, } gr, err := applyGet(d.kv, getReq) if err != nil { - return wal.InvalidTerm, err + return wal.InvalidTerm, TermOptions{}, err } if gr.Status == proto.Status_KEY_NOT_FOUND { - return wal.InvalidTerm, nil + return wal.InvalidTerm, TermOptions{}, nil } if _, err = fmt.Sscanf(string(gr.Value), "%d", &term); err != nil { - return wal.InvalidTerm, err + return wal.InvalidTerm, TermOptions{}, err + } + + if gr, err = applyGet(d.kv, &proto.GetRequest{Key: termOptionsKey, IncludeValue: true}); err != nil { + return wal.InvalidTerm, TermOptions{}, err + } + + if gr.Status == proto.Status_KEY_NOT_FOUND { + options = TermOptions{} + } else { + if err := json.Unmarshal(gr.Value, &options); err != nil { + return wal.InvalidTerm, TermOptions{}, err + } } - return term, nil + + return term, options, nil } func (d *db) applyPut(batch WriteBatch, notifications *notifications, putReq *proto.PutRequest, timestamp uint64, updateOperationCallback UpdateOperationCallback, internal bool) (*proto.PutResponse, error) { //nolint:revive @@ -690,6 +736,9 @@ func deserialize(value []byte, se *proto.StorageEntry) error { } func (d *db) ReadNextNotifications(ctx context.Context, startOffset int64) ([]*proto.NotificationBatch, error) { + if !d.notificationsEnabled { + return nil, ErrNotificationsDisabled + } return d.notificationsTracker.ReadNextNotifications(ctx, startOffset) } @@ -704,3 +753,12 @@ func (*noopCallback) OnDelete(WriteBatch, string) error { } var NoOpCallback UpdateOperationCallback = &noopCallback{} + +func ToDbOption(opt *proto.NewTermOptions) TermOptions { + to := TermOptions{NotificationsEnabled: true} + if opt != nil { + to.NotificationsEnabled = opt.EnableNotifications + } + + return to +} diff --git a/server/kv/db_notifications_test.go b/server/kv/db_notifications_test.go index 94747ee8..77879267 100644 --- a/server/kv/db_notifications_test.go +++ b/server/kv/db_notifications_test.go @@ -204,3 +204,26 @@ func TestDB_NotificationsCancelWait(t *testing.T) { assert.NoError(t, db.Close()) assert.NoError(t, factory.Close()) } + +func TestDB_NotificationsDisabled(t *testing.T) { + factory, err := NewPebbleKVFactory(testKVOptions) + assert.NoError(t, err) + db, err := NewDB(common.DefaultNamespace, 1, factory, 1*time.Hour, common.SystemClock) + assert.NoError(t, err) + + db.EnableNotifications(false) + t0 := now() + _, _ = db.ProcessWrite(&proto.WriteRequest{ + Puts: []*proto.PutRequest{{ + Key: "a", + Value: []byte("0"), + }}, + }, 0, t0, NoOpCallback) + + notifications, err := db.ReadNextNotifications(context.Background(), 0) + assert.Error(t, ErrNotificationsDisabled, err) + assert.Nil(t, notifications) + + assert.NoError(t, db.Close()) + assert.NoError(t, factory.Close()) +} diff --git a/server/kv/db_test.go b/server/kv/db_test.go index 37bc042f..56b85447 100644 --- a/server/kv/db_test.go +++ b/server/kv/db_test.go @@ -450,16 +450,18 @@ func TestDb_UpdateTerm(t *testing.T) { db, err := NewDB(common.DefaultNamespace, 1, factory, 0, common.SystemClock) assert.NoError(t, err) - term, err := db.ReadTerm() + term, options, err := db.ReadTerm() assert.NoError(t, err) assert.Equal(t, wal.InvalidOffset, term) + assert.Equal(t, TermOptions{}, options) - err = db.UpdateTerm(1) + err = db.UpdateTerm(1, TermOptions{NotificationsEnabled: true}) assert.NoError(t, err) - term, err = db.ReadTerm() + term, options, err = db.ReadTerm() assert.NoError(t, err) assert.EqualValues(t, 1, term) + assert.Equal(t, TermOptions{NotificationsEnabled: true}, options) assert.NoError(t, db.Close()) @@ -467,7 +469,7 @@ func TestDb_UpdateTerm(t *testing.T) { db, err = NewDB(common.DefaultNamespace, 1, factory, 0, common.SystemClock) assert.NoError(t, err) - term, err = db.ReadTerm() + term, _, err = db.ReadTerm() assert.NoError(t, err) assert.Equal(t, wal.InvalidOffset, term) diff --git a/server/leader_controller.go b/server/leader_controller.go index d881578b..dc33d29d 100644 --- a/server/leader_controller.go +++ b/server/leader_controller.go @@ -94,6 +94,7 @@ type leaderController struct { cancel context.CancelFunc wal wal.Wal db kv.DB + termOptions kv.TermOptions rpcClient ReplicationRpcProvider sessionManager SessionManager log *slog.Logger @@ -154,7 +155,7 @@ func NewLeaderController(config Config, namespace string, shardId int64, rpcClie return nil, err } - if lc.term, err = lc.db.ReadTerm(); err != nil { + if lc.term, lc.termOptions, err = lc.db.ReadTerm(); err != nil { return nil, err } @@ -162,6 +163,7 @@ func NewLeaderController(config Config, namespace string, shardId int64, rpcClie lc.status = proto.ServingStatus_FENCED } + lc.db.EnableNotifications(lc.termOptions.NotificationsEnabled) lc.setLogger() lc.log.Info("Created leader controller") return lc, nil @@ -224,10 +226,12 @@ func (lc *leaderController) NewTerm(req *proto.NewTermRequest) (*proto.NewTermRe return nil, common.ErrorInvalidStatus } - if err := lc.db.UpdateTerm(req.Term); err != nil { + lc.termOptions = kv.ToDbOption(req.Options) + if err := lc.db.UpdateTerm(req.Term, lc.termOptions); err != nil { return nil, err } + lc.db.EnableNotifications(lc.termOptions.NotificationsEnabled) lc.term = req.Term lc.setLogger() lc.status = proto.ServingStatus_FENCED @@ -966,6 +970,9 @@ func (lc *leaderController) appendToWalStreamRequest(request *proto.WriteRequest // //// func (lc *leaderController) GetNotifications(req *proto.NotificationsRequest, stream proto.OxiaClient_GetNotificationsServer) error { + if !lc.termOptions.NotificationsEnabled { + return common.ErrorNotificationsNotEnabled + } return startNotificationDispatcher(lc, req, stream) } diff --git a/server/leader_controller_test.go b/server/leader_controller_test.go index a73c96e3..33ee94fb 100644 --- a/server/leader_controller_test.go +++ b/server/leader_controller_test.go @@ -400,7 +400,7 @@ func TestLeaderController_FenceTerm(t *testing.T) { db, err := kv.NewDB(common.DefaultNamespace, shard, kvFactory, 1*time.Hour, common.SystemClock) assert.NoError(t, err) // Force a new term in the DB before opening - assert.NoError(t, db.UpdateTerm(5)) + assert.NoError(t, db.UpdateTerm(5, kv.TermOptions{})) assert.NoError(t, db.Close()) lc, err := NewLeaderController(Config{}, common.DefaultNamespace, shard, newMockRpcClient(), walFactory, kvFactory) @@ -447,7 +447,7 @@ func TestLeaderController_BecomeLeaderTerm(t *testing.T) { db, err := kv.NewDB(common.DefaultNamespace, shard, kvFactory, 1*time.Hour, common.SystemClock) assert.NoError(t, err) // Force a new term in the DB before opening - assert.NoError(t, db.UpdateTerm(5)) + assert.NoError(t, db.UpdateTerm(5, kv.TermOptions{})) assert.NoError(t, db.Close()) lc, err := NewLeaderController(Config{}, common.DefaultNamespace, shard, newMockRpcClient(), walFactory, kvFactory) @@ -589,7 +589,7 @@ func TestLeaderController_AddFollower_Truncate(t *testing.T) { assert.NoError(t, err) } - assert.NoError(t, db.UpdateTerm(5)) + assert.NoError(t, db.UpdateTerm(5, kv.TermOptions{})) assert.NoError(t, db.Close()) assert.NoError(t, walObject.Close()) @@ -1262,3 +1262,30 @@ func TestLeaderController_WriteStream(t *testing.T) { assert.NoError(t, kvFactory.Close()) assert.NoError(t, walFactory.Close()) } + +func TestLeaderController_NotificationsDisabled(t *testing.T) { + var shard int64 = 1 + + kvFactory, _ := kv.NewPebbleKVFactory(testKVOptions) + walFactory := newTestWalFactory(t) + + lc, _ := NewLeaderController(Config{}, common.DefaultNamespace, shard, newMockRpcClient(), walFactory, kvFactory) + _, _ = lc.NewTerm(&proto.NewTermRequest{Shard: shard, Term: 1, Options: &proto.NewTermOptions{EnableNotifications: false}}) + _, _ = lc.BecomeLeader(context.Background(), &proto.BecomeLeaderRequest{ + Shard: shard, + Term: 1, + ReplicationFactor: 1, + FollowerMaps: nil, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + stream := newMockGetNotificationsServer(ctx) + + err := lc.GetNotifications(&proto.NotificationsRequest{Shard: shard}, stream) + assert.ErrorIs(t, err, common.ErrorNotificationsNotEnabled) + + assert.NoError(t, lc.Close()) + assert.NoError(t, kvFactory.Close()) + assert.NoError(t, walFactory.Close()) +}