Fix data race in Get/Set after Close is called #60
Changes from all commits
@@ -28,6 +28,7 @@ const (

var (
    VersionMismatch    = errors.New("version mismatch")
    ErrCacheClosed     = errors.New("cache is closed")
    RoundedParallelism int
    ShardCount         int
    StripedBufferSize  int

@@ -51,6 +52,7 @@ type Shard[K comparable, V any] struct {
    vgroup  *Group[K, V] // used in secondary cache
    counter uint
    mu      *RBMutex
    closed  bool
}

func NewShard[K comparable, V any](doorkeeper bool) *Shard[K, V] {

@@ -67,6 +69,9 @@ func NewShard[K comparable, V any](doorkeeper bool) *Shard[K, V] {
}

func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) {
    if s.closed {
        return
    }
    s.hashmap[key] = entry
    if s.dookeeper != nil {
        ds := 20 * len(s.hashmap)

@@ -77,11 +82,17 @@ func (s *Shard[K, V]) set(key K, entry *Entry[K, V]) {
}

func (s *Shard[K, V]) get(key K) (entry *Entry[K, V], ok bool) {
    if s.closed {
        return nil, false
    }
    entry, ok = s.hashmap[key]
    return
}

func (s *Shard[K, V]) delete(entry *Entry[K, V]) bool {
    if s.closed {
        return false
    }
    var deleted bool
    exist, ok := s.hashmap[entry.key]
    if ok && exist == entry {

@@ -389,6 +400,10 @@ func (s *Store[K, V]) setInternal(key K, value V, cost int64, expire int64, nvmC
    h, index := s.index(key)
    shard := s.shards[index]
    shard.mu.Lock()
    if shard.closed {
        shard.mu.Unlock()
        return nil, nil, false
    }

    return s.setShard(shard, h, key, value, cost, expire, nvmClean)

@@ -738,18 +753,26 @@ func (s *Store[K, V]) Stats() Stats {
    return newStats(s.policy.hits.Value(), s.policy.misses.Value())
}

// Close waits for all current read and write operations to complete,
// then clears the hashmap and shuts down the maintenance goroutine.
Review comment on lines +756 to +757 (a suggested change to this doc comment was proposed):
Reply: I prefer my original version here. For each shard, …
// After the cache is closed, Get will always return (nil, false),
// and Set will have no effect.
// For loading cache, Get will return ErrCacheClosed after closing.
func (s *Store[K, V]) Close() {
    for _, shard := range s.shards {
        shard.mu.Lock()
        shard.closed = true
        shard.hashmap = map[K]*Entry[K, V]{}
    }
    s.Wait()
    for _, s := range s.shards {
        tk := s.mu.RLock()
        s.hashmap = nil
        s.mu.RUnlock(tk)
        s.mu.Unlock()
    }
    close(s.writeChan)
    s.policyMu.Lock()
    s.closed = true
    s.cancel()
    s.policyMu.Unlock()
    close(s.writeChan)
}

func (s *Store[K, V]) getReadBufferIdx() int {

@@ -859,7 +882,7 @@ func (s *Store[K, V]) processSecondary() {
    }
}

// wait write chan, used in test
// Wait blocks until the write channel is drained.
func (s *Store[K, V]) Wait() {
    s.writeChan <- WriteBufItem[K, V]{code: WAIT}
    <-s.waitChan
@@ -1081,6 +1104,10 @@ func (s *LoadingStore[K, V]) Get(ctx context.Context, key K) (V, error) {
    loaded, err, _ := shard.group.Do(key, func() (Loaded[V], error) {
        // load and store should be atomic
        shard.mu.Lock()
        if shard.closed {
            shard.mu.Unlock()
Review comment:
I checked out the code locally and I see that the locking & unlocking code gets weird, in that sometimes a function locks but doesn't unlock (hopefully because it calls a function that does the unlock). You risk introducing deadlock scenarios with this. My recommendation (not for this PR) is that you at least document each function (for example, whether it expects the caller to already hold the shard mutex, and whether it releases the mutex itself before returning).
I also recommend that you use this pattern to prevent the issue completely (at the expense that you will be holding the mutex for a longer time).
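The pattern in question is presumably the standard acquire-then-defer idiom, where the unlock is registered immediately after the lock so that every return path releases the mutex. A minimal sketch against the shard fields shown in this diff; the helper name lockedGet is illustrative, not part of the PR:

// lockedGet holds the shard mutex for the whole call; the deferred Unlock
// runs on every return path, so no early return can leak the lock.
func (s *Shard[K, V]) lockedGet(key K) (entry *Entry[K, V], ok bool) {
    s.mu.Lock()
    defer s.mu.Unlock()
    if s.closed {
        return nil, false
    }
    entry, ok = s.hashmap[key]
    return entry, ok
}

The trade-off, as the comment notes, is that the mutex stays held until the function returns rather than being released as early as possible.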
            return Loaded[V]{}, ErrCacheClosed
        }

        // first try get from secondary cache
        if s.secondaryCache != nil {
@@ -1,7 +1,9 @@
package internal

import (
    "context"
    "sync"
    "sync/atomic"
    "testing"
    "time"

@@ -201,3 +203,92 @@ func TestStore_SinkWritePolicyWeight(t *testing.T) {
    require.Equal(t, 8, int(store.policy.weightedSize))
}

func TestStore_CloseRace(t *testing.T) {
    store := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil)
Review comment: Not for this PR, but holy cow that is a lot of parameters 🤣 are all of those mandatory?! If they aren't, I suggest using the "functional options pattern" described here [scroll to option 3]. My team has been using that with quite a lot of success.
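For reference, a minimal self-contained sketch of the functional options pattern; the option names and the NewConfiguredStore constructor below are illustrative, not the actual NewStore API:

type storeConfig struct {
    capacity   int
    doorkeeper bool
}

// Option mutates the config; callers pass only the options they care about.
type Option func(*storeConfig)

func WithCapacity(n int) Option     { return func(c *storeConfig) { c.capacity = n } }
func WithDoorkeeper(on bool) Option { return func(c *storeConfig) { c.doorkeeper = on } }

// NewConfiguredStore applies the options over sensible defaults, so new
// parameters can be added later without breaking existing call sites.
func NewConfiguredStore(opts ...Option) *storeConfig {
    cfg := &storeConfig{capacity: 1000}
    for _, opt := range opts {
        opt(cfg)
    }
    return cfg
}

// usage: NewConfiguredStore(WithCapacity(500), WithDoorkeeper(true))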
    var wg sync.WaitGroup
    var closed atomic.Bool
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(i int) {
            counter := i * 5
            countdown := -1
            defer wg.Done()
            for {
                // continue get/set 20 times after cache closed
                if countdown == 0 {
                    return
                }
                if closed.Load() && countdown == -1 {
                    countdown = 20
                }
                store.Get(counter)
                store.Set(100, 100, 1, 0)
                counter += i
                if countdown > 0 {
                    countdown -= 1
                }
            }
        }(i)
    }
    wg.Add(1)
    go func() {
        defer wg.Done()
        store.Close()
        closed.Store(true)
    }()
    wg.Wait()

    _ = store.Set(100, 100, 1, 0)
    v, ok := store.Get(100)
    require.False(t, ok)
    require.Equal(t, 0, v)
}

func TestStore_CloseRaceLoadingCache(t *testing.T) {
    store := NewStore[int, int](1000, false, true, nil, nil, nil, 0, 0, nil)
    loadingStore := NewLoadingStore(store)
    loadingStore.loader = func(ctx context.Context, key int) (Loaded[int], error) {
        return Loaded[int]{Value: 100, Cost: 1}, nil
    }
    ctx := context.TODO()

    var wg sync.WaitGroup
    var closed atomic.Bool
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(i int) {
            counter := i * 5
            countdown := -1
            defer wg.Done()
            for {
                // continue get/set 20 times after cache closed
                if countdown == 0 {
                    return
                }
                if closed.Load() && countdown == -1 {
                    countdown = 20
                }
                _, err := loadingStore.Get(ctx, counter)
                if countdown > 0 {
                    require.Equal(t, ErrCacheClosed, err)
                }
                counter += i
                if countdown > 0 {
                    countdown -= 1
                }
            }
        }(i)
    }
    wg.Add(1)
    go func() {
        defer wg.Done()
        loadingStore.Close()
        closed.Store(true)
    }()
    wg.Wait()

    _, err := loadingStore.Get(ctx, 100)
    require.Equal(t, ErrCacheClosed, err)
}
Review comment: shouldn't this be an atomic.Bool? since you can have multiple goroutines trying to read & write to it? the downside is that it will make all operations slightly slower. it could also be a bool but guarded with a RWMutex.

Author reply: The closed field is protected by the shard mutex, and as you can see, I check it immediately before or after accessing the shard hashmap, so the mutex is guaranteed to already be held at that point. Meanwhile, the Close method also first acquires the shard mutex before updating the closed field, making it safe to use a basic bool type here.

Review comment: I think it'd be good if you documented which fields are guarded by mu. For example: https://github.com/openfga/openfga/blob/94dd78f03dc99e51f46e7f71deaaea878e8340c9/pkg/storage/memory/memory.go#L125-L128 This doesn't change anything at runtime but it's useful documentation.
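Along those lines, a trimmed sketch of how the Shard struct from this diff could annotate that invariant (only fields visible in the diff are shown; the comment wording is illustrative, not part of the PR):

// Hold mu (write lock) before touching hashmap or closed, as
// set/get/delete and Close do in this PR.
type Shard[K comparable, V any] struct {
    hashmap map[K]*Entry[K, V] // guarded by mu
    mu      *RBMutex
    closed  bool // guarded by mu; set to true by Close
}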