diff --git a/storage/watchable_store.go b/storage/watchable_store.go
index ebad2a64de9f..c6f855fdebac 100644
--- a/storage/watchable_store.go
+++ b/storage/watchable_store.go
@@ -17,6 +17,7 @@ package storage
 import (
 	"fmt"
 	"log"
+	"math"
 	"sync"
 	"time"
 
@@ -241,6 +242,59 @@ func (s *watchableStore) syncWatchingsLoop() {
 	}
 }
 
+// RangeAllUnsynced ranges over all unsynced watchings
+// and returns all key-value pairs and the next revision.
+func (s *watchableStore) RangeAllUnsynced() ([]storagepb.KeyValue, int64, error) {
+	totalLimit := 0
+	minRev, maxRev := int64(math.MaxInt64), int64(math.MinInt64)
+	for w := range s.unsynced {
+		if w.cur > 0 && w.cur <= s.store.compactMainRev {
+			log.Printf("storage: %v", ErrCompacted)
+			delete(s.unsynced, w)
+			continue
+		}
+		if w.cur > s.store.currentRev.main {
+			log.Printf("storage: %v", ErrFutureRev)
+			delete(s.unsynced, w)
+			continue
+		}
+		totalLimit += cap(w.ch) - len(w.ch)
+		if minRev >= w.cur {
+			minRev = w.cur
+		}
+		if maxRev <= w.cur {
+			maxRev = w.cur
+		}
+	}
+
+	min, max := newRevBytes(), newRevBytes()
+	revToBytes(revision{main: minRev}, min)
+	revToBytes(revision{main: maxRev, sub: maxRev}, max)
+
+	s.store.mu.Lock()
+	defer s.store.mu.Unlock()
+
+	tx := s.store.b.BatchTx()
+
+	tx.Lock()
+	defer tx.Unlock()
+
+	kvs := []storagepb.KeyValue{}
+
+	_, vs := tx.UnsafeRange(keyBucketName, min, max, int64(totalLimit))
+
+	for _, vi := range vs {
+		var kv storagepb.KeyValue
+		if err := kv.Unmarshal(vi); err != nil {
+			return nil, 0, fmt.Errorf("storage: cannot unmarshal event: %v", err)
+		}
+
+		kvs = append(kvs, kv)
+	}
+
+	return kvs, s.store.currentRev.main + 1, nil
+}
+
 // syncWatchings syncs the watchings in the unsyncd map.
 func (s *watchableStore) syncWatchings() {
 	_, curRev, _ := s.store.Range(nil, nil, 0, 0)
diff --git a/storage/watchable_store_test.go b/storage/watchable_store_test.go
index 9fe9600dc63c..8d84e0fcbfb1 100644
--- a/storage/watchable_store_test.go
+++ b/storage/watchable_store_test.go
@@ -15,6 +15,8 @@ package storage
 
 import (
+	"bytes"
+	"fmt"
 	"os"
 	"testing"
 )
@@ -114,6 +116,56 @@ func TestCancelUnsynced(t *testing.T) {
 	}
 }
 
+// TestRangeAllUnsynced populates unsynced watchings to test whether
+// RangeAllUnsynced correctly returns key-value pairs and nextRev.
+func TestRangeAllUnsynced(t *testing.T) {
+	s := &watchableStore{
+		store:    newStore(tmpPath),
+		unsynced: make(map[*watching]struct{}),
+		synced:   make(map[string]map[*watching]struct{}),
+	}
+
+	defer func() {
+		s.store.Close()
+		os.Remove(tmpPath)
+	}()
+
+	watcherN := 10
+
+	keys := make([][]byte, watcherN)
+	vals := make([][]byte, watcherN)
+	for i := 0; i < watcherN; i++ {
+		keys[i] = []byte(fmt.Sprintf("%d_Foo", i+1))
+		vals[i] = []byte(fmt.Sprintf("%d_Bar", i+1))
+	}
+
+	for i := 1; i <= watcherN; i++ {
+		s.Put(keys[i-1], vals[i-1])
+		w := s.NewWatcher()
+		// use a non-zero start revision to keep the watchings unsynced
+		w.Watch(keys[i-1], false, int64(i))
+	}
+
+	kvs, nextRev, err := s.RangeAllUnsynced()
+	if err != nil {
+		t.Error(err)
+	}
+	if len(kvs) != watcherN {
+		t.Errorf("len(kvs) = %d, want = %d", len(kvs), watcherN)
+	}
+	if nextRev != int64(watcherN+1) {
+		t.Errorf("nextRev = %d, want = %d", nextRev, watcherN+1)
+	}
+	for i, v := range kvs {
+		if !bytes.Equal(keys[i], v.Key) {
+			t.Errorf("v.Key = %s, want = %s", v.Key, keys[i])
+		}
+		if !bytes.Equal(vals[i], v.Value) {
+			t.Errorf("v.Value = %s, want = %s", v.Value, vals[i])
+		}
+	}
+}
+
 // TestSyncWatchings populates unsynced watching map and
 // tests syncWatchings method to see if it correctly sends
 // events to channel of unsynced watchings and moves these
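Note (reviewer sketch, not part of the patch): RangeAllUnsynced only reads the key-value pairs out of the backend; a caller still has to deliver them to each watching and advance its cursor. A minimal, hypothetical consumer inside the same package could look like the sketch below. It uses only identifiers the patch itself touches (s.unsynced, w.cur, kv); the function name trySyncAll and the elided delivery step are placeholders, not this PR's API.

// trySyncAll is a hypothetical caller of RangeAllUnsynced (not in this
// patch): it fetches everything the unsynced watchings may have missed,
// lets each watching process the pairs, and marks it caught up.
func (s *watchableStore) trySyncAll() {
	kvs, nextRev, err := s.RangeAllUnsynced()
	if err != nil {
		log.Printf("storage: %v", err)
		return
	}
	for w := range s.unsynced {
		for _, kv := range kvs {
			// Delivery is elided: matching kv.Key against the watching's
			// key/prefix and sending on w.ch depends on the event type,
			// which this patch does not define.
			_ = kv
		}
		// Every revision below nextRev has now been ranged, so the
		// watching can resume from nextRev.
		w.cur = nextRev
	}
}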