Fix data race in Get/Set after Close is called #60

Merged 3 commits on Jan 31, 2025
39 changes: 33 additions & 6 deletions internal/store.go
@@ -28,6 +28,7 @@ const (

var (
VersionMismatch = errors.New("version mismatch")
ErrCacheClosed = errors.New("cache is closed")
RoundedParallelism int
ShardCount int
StripedBufferSize int
@@ -51,6 +52,7 @@ type Shard[K comparable, V any] struct {
vgroup *Group[K, V] // used in secondary cache
counter uint
mu *RBMutex
closed bool
@miparnisari commented on Jan 30, 2025:

Shouldn't this be an atomic.Bool, since you can have multiple goroutines trying to read and write it?

The downside is that it would make all operations slightly slower.

It could also be a plain bool guarded by an RWMutex.
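
For reference, a minimal sketch of the atomic.Bool alternative being discussed; the shard type and fields here are illustrative, not the PR's actual code. The flag can be checked without holding the shard mutex, at the cost of an atomic load on every operation, and the hashmap itself would still need the mutex as it does today.

package main

import (
	"fmt"
	"sync/atomic"
)

// shard is an illustrative type, not the PR's Shard; it only demonstrates the flag.
type shard struct {
	closed atomic.Bool // safe to Load/Store from multiple goroutines without a mutex
}

func (s *shard) isClosed() bool {
	return s.closed.Load()
}

func main() {
	var s shard
	fmt.Println(s.isClosed()) // false
	s.closed.Store(true)
	fmt.Println(s.isClosed()) // true
}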

The repository owner (PR author) replied:

The closed field is protected by the shard mutex, and as you can see, I check it immediately before or after accessing the shard hashmap, so the mutex is guaranteed to already be held at that point. Meanwhile, the Close method also first acquires the shard mutex before updating the closed field, making it safe to use a basic bool type here.

A follow-up review comment:

I think it'd be good if you documented which fields are guarded by mu. For example: https://github.com/openfga/openfga/blob/94dd78f03dc99e51f46e7f71deaaea878e8340c9/pkg/storage/memory/memory.go#L125-L128

This doesn't change anything at runtime but it's useful documentation.
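
As a rough illustration of that convention (type and field names here are hypothetical, not copied from this repository), the doc comment simply states which fields the mutex guards:

package cache

import "sync"

// exampleShard shows the documentation style the comment refers to.
type exampleShard struct {
	mu sync.RWMutex

	// The fields below are guarded by mu.
	hashmap map[string]int
	closed  bool
}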

}

func NewShard[K comparable, V any](doorkeeper bool) *Shard[K, V] {
@@ -67,6 +69,9 @@ func NewShard[K comparable, V any](doorkeeper bool) *Shard[K, V] {
}

func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) {
if s.closed {
return
}
s.hashmap[key] = entry
if s.dookeeper != nil {
ds := 20 * len(s.hashmap)
@@ -77,11 +82,17 @@ func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) {
}

func (s *Shard[K, V]) get(key K) (entry *Entry[K, V], ok bool) {
if s.closed {
return nil, false
}
entry, ok = s.hashmap[key]
return
}

func (s *Shard[K, V]) delete(entry *Entry[K, V]) bool {
if s.closed {
return false
}
var deleted bool
exist, ok := s.hashmap[entry.key]
if ok && exist == entry {
@@ -389,6 +400,10 @@ func (s *Store[K, V]) setInternal(key K, value V, cost int64, expire int64, nvmC
h, index := s.index(key)
shard := s.shards[index]
shard.mu.Lock()
if shard.closed {
shard.mu.Unlock()
return nil, nil, false
}

return s.setShard(shard, h, key, value, cost, expire, nvmClean)

@@ -738,18 +753,26 @@ func (s *Store[K, V]) Stats() Stats {
return newStats(s.policy.hits.Value(), s.policy.misses.Value())
}

// Close waits for all current read and write operations to complete,
// then clears the hashmap and shuts down the maintenance goroutine.
Comment on lines +756 to +757

Suggested change:
- // Close waits for all current read and write operations to complete,
- // then clears the hashmap and shuts down the maintenance goroutine.
+ // Close clears the hashmap and shuts down the maintenance goroutine,
+ // then waits for all current read and write operations to complete.

The repository owner (PR author) replied:

I prefer my original version here. For each shard, shard.mu.Lock() is called first, which ensures that the current read/write operations holding the mutex are completed before closing.

// After the cache is closed, Get will always return (nil, false),
// and Set will have no effect.
// For loading cache, Get will return ErrCacheClosed after closing.
func (s *Store[K, V]) Close() {

for _, shard := range s.shards {
shard.mu.Lock()
shard.closed = true
shard.hashmap = map[K]*Entry[K, V]{}
}
s.Wait()
for _, s := range s.shards {
tk := s.mu.RLock()
s.hashmap = nil
s.mu.RUnlock(tk)
s.mu.Unlock()
}
close(s.writeChan)
s.policyMu.Lock()
s.closed = true
s.cancel()
s.policyMu.Unlock()
close(s.writeChan)
}

func (s *Store[K, V]) getReadBufferIdx() int {
@@ -859,7 +882,7 @@ func (s *Store[K, V]) processSecondary() {
}
}

// wait write chan, used in test
// Wait blocks until the write channel is drained.
func (s *Store[K, V]) Wait() {
s.writeChan <- WriteBufItem[K, V]{code: WAIT}
<-s.waitChan
@@ -1081,6 +1104,10 @@ func (s *LoadingStore[K, V]) Get(ctx context.Context, key K) (V, error) {
loaded, err, _ := shard.group.Do(key, func() (Loaded[V], error) {
// load and store should be atomic
shard.mu.Lock()
if shard.closed {
shard.mu.Unlock()

A review comment on this change:

I checked out the code locally and I see that the locking & unlocking code gets weird, in that sometimes a function locks but doesn't unlock (hopefully because it calls a function that does the unlock). You risk introducing deadlock scenarios with this.

My recommendation (not for this PR) is that you at least document each function:

  • if it unlocks a mutex
  • if it assumes that the mutex has already been locked by the caller.

I also recommend that you use this pattern:

mutex.Lock()
defer mutex.Unlock()

to prevent the issue completely (at the expense that you will be holding the mutex for a longer time).
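
A small sketch of both recommendations, with hypothetical names not taken from this PR: a doc comment (or a ...Locked suffix) for functions that assume the caller holds the mutex, and Lock with a deferred Unlock for functions that own the pair.

package cache

import "sync"

type store struct {
	mu     sync.Mutex
	closed bool
}

// markClosedLocked assumes s.mu is already held by the caller and does not unlock it.
func (s *store) markClosedLocked() {
	s.closed = true
}

// Close owns the lock/unlock pair; defer guarantees the unlock on every return path,
// at the cost of holding the mutex for the whole function body.
func (s *store) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.markClosedLocked()
}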

return Loaded[V]{}, ErrCacheClosed
}

// first try get from secondary cache
if s.secondaryCache != nil {
91 changes: 91 additions & 0 deletions internal/store_test.go
@@ -1,7 +1,9 @@
package internal

import (
"context"
"sync"
"sync/atomic"
"testing"
"time"

@@ -201,3 +203,92 @@ func TestStore_SinkWritePolicyWeight(t *testing.T) {
require.Equal(t, 8, int(store.policy.weightedSize))

}

func TestStore_CloseRace(t *testing.T) {
store := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil)
@miparnisari commented on Jan 31, 2025:

Not for this PR, but holy cow that is a lot of parameters 🤣 are all of those mandatory?! If they aren't, I suggest using the "functional options pattern" described here [scroll to option 3]. My team has been using that with quite a lot of success.
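
For illustration, a minimal sketch of the functional options pattern the comment refers to; the constructor, config fields, and option names below are hypothetical, not theine's actual API.

package cache

type config struct {
	capacity   int
	doorkeeper bool
}

type Option func(*config)

func WithCapacity(n int) Option     { return func(c *config) { c.capacity = n } }
func WithDoorkeeper(on bool) Option { return func(c *config) { c.doorkeeper = on } }

// newStore replaces a long positional parameter list with defaults plus options.
func newStore(opts ...Option) *config {
	c := &config{capacity: 1000, doorkeeper: false}
	for _, opt := range opts {
		opt(c)
	}
	return c
}

A caller would then write, for example, newStore(WithCapacity(500), WithDoorkeeper(true)) and omit any option it doesn't care about.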


var wg sync.WaitGroup
var closed atomic.Bool
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
counter := i * 5
countdown := -1
defer wg.Done()
for {
// continue get/set 20 times after cache closed
if countdown == 0 {
return
}
if closed.Load() && countdown == -1 {
countdown = 20
}
store.Get(counter)
store.Set(100, 100, 1, 0)
counter += i
if countdown > 0 {
countdown -= 1
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
store.Close()
closed.Store(true)
}()
wg.Wait()

_ = store.Set(100, 100, 1, 0)
v, ok := store.Get(100)
require.False(t, ok)
require.Equal(t, 0, v)
}

func TestStore_CloseRaceLoadingCache(t *testing.T) {
store := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil)
loadingStore := NewLoadingStore(store)
loadingStore.loader = func(ctx context.Context, key int) (Loaded[int], error) {
return Loaded[int]{Value: 100, Cost: 1}, nil
}
ctx := context.TODO()

var wg sync.WaitGroup
var closed atomic.Bool
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
counter := i * 5
countdown := -1
defer wg.Done()
for {
// continue get/set 20 times after cache closed
if countdown == 0 {
return
}
if closed.Load() && countdown == -1 {
countdown = 20
}
_, err := loadingStore.Get(ctx, counter)
if countdown > 0 {
require.Equal(t, ErrCacheClosed, err)
}
counter += i
if countdown > 0 {
countdown -= 1
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
loadingStore.Close()
closed.Store(true)
}()
wg.Wait()

_, err := loadingStore.Get(ctx, 100)
require.Equal(t, ErrCacheClosed, err)
}