From 0565d2841bb7d4527834175eceef6b445be5e9f5 Mon Sep 17 00:00:00 2001 From: Yiling-J Date: Thu, 30 Jan 2025 20:57:08 +0800 Subject: [PATCH 1/3] fix cache close/get race --- internal/store.go | 33 ++++++++++++++++++++++++++------ internal/store_test.go | 43 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/internal/store.go b/internal/store.go index 002ff34..8f5be3a 100644 --- a/internal/store.go +++ b/internal/store.go @@ -51,6 +51,7 @@ type Shard[K comparable, V any] struct { vgroup *Group[K, V] // used in secondary cache counter uint mu *RBMutex + closed bool } func NewShard[K comparable, V any](doorkeeper bool) *Shard[K, V] { @@ -67,6 +68,9 @@ func NewShard[K comparable, V any](doorkeeper bool) *Shard[K, V] { } func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) { + if s.closed { + return + } s.hashmap[key] = entry if s.dookeeper != nil { ds := 20 * len(s.hashmap) @@ -77,11 +81,17 @@ func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) { } func (s *Shard[K, V]) get(key K) (entry *Entry[K, V], ok bool) { + if s.closed { + return nil, false + } entry, ok = s.hashmap[key] return } func (s *Shard[K, V]) delete(entry *Entry[K, V]) bool { + if s.closed { + return false + } var deleted bool exist, ok := s.hashmap[entry.key] if ok && exist == entry { @@ -389,6 +399,10 @@ func (s *Store[K, V]) setInternal(key K, value V, cost int64, expire int64, nvmC h, index := s.index(key) shard := s.shards[index] shard.mu.Lock() + if shard.closed { + shard.mu.Unlock() + return nil, nil, false + } return s.setShard(shard, h, key, value, cost, expire, nvmClean) @@ -738,18 +752,25 @@ func (s *Store[K, V]) Stats() Stats { return newStats(s.policy.hits.Value(), s.policy.misses.Value()) } +// Close waits for all current read and write operations to complete, +// then clears the hashmap and shuts down the maintenance goroutine. +// After the cache is closed, Get will always return (nil, false), +// and Set will have no effect. func (s *Store[K, V]) Close() { - for _, s := range s.shards { - tk := s.mu.RLock() - s.hashmap = nil - s.mu.RUnlock(tk) + s.mu.Lock() + s.closed = true + s.hashmap = map[K]*Entry[K, V]{} } + s.Wait() + for _, s := range s.shards { + s.mu.Unlock() + } + close(s.writeChan) s.policyMu.Lock() s.closed = true s.cancel() s.policyMu.Unlock() - close(s.writeChan) } func (s *Store[K, V]) getReadBufferIdx() int { @@ -859,7 +880,7 @@ func (s *Store[K, V]) processSecondary() { } } -// wait write chan, used in test +// Wait blocks until the write channel is drained. func (s *Store[K, V]) Wait() { s.writeChan <- WriteBufItem[K, V]{code: WAIT} <-s.waitChan diff --git a/internal/store_test.go b/internal/store_test.go index 4267fd7..6b56a97 100644 --- a/internal/store_test.go +++ b/internal/store_test.go @@ -2,6 +2,7 @@ package internal import ( "sync" + "sync/atomic" "testing" "time" @@ -201,3 +202,45 @@ func TestStore_SinkWritePolicyWeight(t *testing.T) { require.Equal(t, 8, int(store.policy.weightedSize)) } + +func TestStore_CloseRace(t *testing.T) { + store := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil) + + var wg sync.WaitGroup + var closed atomic.Bool + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + counter := i * 5 + countdown := -1 + defer wg.Done() + for { + // continue get/set 20 times after cache closed + if countdown == 0 { + return + } + if closed.Load() && countdown == -1 { + countdown = 20 + } + store.Get(counter) + store.Set(100, 100, 1, 0) + counter += i + if countdown > 0 { + countdown -= 1 + } + } + }(i) + } + wg.Add(1) + go func() { + defer wg.Done() + store.Close() + closed.Store(true) + }() + wg.Wait() + + _ = store.Set(100, 100, 1, 0) + v, ok := store.Get(100) + require.False(t, ok) + require.Equal(t, 0, v) +} From f6f64b2f99908b6ef51b791fcb12aca7674c118d Mon Sep 17 00:00:00 2001 From: Yiling-J Date: Fri, 31 Jan 2025 09:39:40 +0800 Subject: [PATCH 2/3] loading cache return err if cache is closed --- internal/store.go | 6 ++++++ internal/store_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/internal/store.go b/internal/store.go index 8f5be3a..34cfdee 100644 --- a/internal/store.go +++ b/internal/store.go @@ -28,6 +28,7 @@ const ( var ( VersionMismatch = errors.New("version mismatch") + ErrCacheClosed = errors.New("cache is closed") RoundedParallelism int ShardCount int StripedBufferSize int @@ -756,6 +757,7 @@ func (s *Store[K, V]) Stats() Stats { // then clears the hashmap and shuts down the maintenance goroutine. // After the cache is closed, Get will always return (nil, false), // and Set will have no effect. +// For loading cache, Get will return ErrCacheClosed after closing. func (s *Store[K, V]) Close() { for _, s := range s.shards { s.mu.Lock() @@ -1102,6 +1104,10 @@ func (s *LoadingStore[K, V]) Get(ctx context.Context, key K) (V, error) { loaded, err, _ := shard.group.Do(key, func() (Loaded[V], error) { // load and store should be atomic shard.mu.Lock() + if shard.closed { + shard.mu.Unlock() + return Loaded[V]{}, ErrCacheClosed + } // first try get from secondary cache if s.secondaryCache != nil { diff --git a/internal/store_test.go b/internal/store_test.go index 6b56a97..acb5c80 100644 --- a/internal/store_test.go +++ b/internal/store_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "sync" "sync/atomic" "testing" @@ -244,3 +245,50 @@ func TestStore_CloseRace(t *testing.T) { require.False(t, ok) require.Equal(t, 0, v) } + +func TestStore_CloseRaceLoadingCache(t *testing.T) { + store := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil) + loadingStore := NewLoadingStore(store) + loadingStore.loader = func(ctx context.Context, key int) (Loaded[int], error) { + return Loaded[int]{Value: 100, Cost: 1}, nil + } + ctx := context.TODO() + + var wg sync.WaitGroup + var closed atomic.Bool + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + counter := i * 5 + countdown := -1 + defer wg.Done() + for { + // continue get/set 20 times after cache closed + if countdown == 0 { + return + } + if closed.Load() && countdown == -1 { + countdown = 20 + } + _, err := loadingStore.Get(ctx, counter) + if countdown > 0 { + require.Equal(t, ErrCacheClosed, err) + } + counter += i + if countdown > 0 { + countdown -= 1 + } + } + }(i) + } + wg.Add(1) + go func() { + defer wg.Done() + loadingStore.Close() + closed.Store(true) + }() + wg.Wait() + + _, err := loadingStore.Get(ctx, 100) + require.Equal(t, ErrCacheClosed, err) +} From 1e68a836b1520fa64cf6690c0c8d301bc972f01c Mon Sep 17 00:00:00 2001 From: Yiling-J Date: Fri, 31 Jan 2025 20:49:01 +0800 Subject: [PATCH 3/3] minor improvement --- internal/store.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/store.go b/internal/store.go index 34cfdee..3cc800e 100644 --- a/internal/store.go +++ b/internal/store.go @@ -759,10 +759,10 @@ func (s *Store[K, V]) Stats() Stats { // and Set will have no effect. // For loading cache, Get will return ErrCacheClosed after closing. func (s *Store[K, V]) Close() { - for _, s := range s.shards { - s.mu.Lock() - s.closed = true - s.hashmap = map[K]*Entry[K, V]{} + for _, shard := range s.shards { + shard.mu.Lock() + shard.closed = true + shard.hashmap = map[K]*Entry[K, V]{} } s.Wait() for _, s := range s.shards {