Skip to content

Commit

Permalink
Fix(writeBatch): Avoid deadlock in commit callback (hypermodeinc#1529)
Browse files Browse the repository at this point in the history
Signed-off-by: thomassong <[email protected]>
  • Loading branch information
mYmNeo committed Feb 13, 2023
1 parent 090fa70 commit 13a108f
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 15 deletions.
3 changes: 2 additions & 1 deletion backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ import (
"encoding/binary"
"io"

"github.com/golang/protobuf/proto"

"github.com/dgraph-io/badger/pb"
"github.com/dgraph-io/badger/y"
"github.com/golang/protobuf/proto"
)

// flushThreshold determines when a buffer will be flushed. When performing a
Expand Down
6 changes: 4 additions & 2 deletions backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ import (
"testing"
"time"

"github.com/dgraph-io/badger/pb"
"github.com/stretchr/testify/require"

"github.com/dgraph-io/badger/pb"
)

func TestBackupRestore1(t *testing.T) {
Expand Down Expand Up @@ -102,6 +103,7 @@ func TestBackupRestore1(t *testing.T) {
if err != nil {
return err
}
t.Logf("Got entry: %v\n", item.Version())
require.Equal(t, entries[count].key, item.Key())
require.Equal(t, entries[count].val, val)
require.Equal(t, entries[count].version, item.Version())
Expand All @@ -112,7 +114,7 @@ func TestBackupRestore1(t *testing.T) {
return nil
})
require.NoError(t, err)
require.Equal(t, db.orc.nextTs(), uint64(3))
require.Equal(t, 3, int(db.orc.nextTs()))
}

func TestBackupRestore2(t *testing.T) {
Expand Down
12 changes: 10 additions & 2 deletions batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ func (db *DB) NewWriteBatch() *WriteBatch {
return db.newWriteBatch(false)
}

// NewManagedWriteBatch returns a WriteBatch built with newWriteBatch(true),
// i.e. flagged as managed. It panics if the DB does not have managed
// transactions enabled (db.opt.managedTxns is false); call NewWriteBatch on
// such a DB instead.
func (db *DB) NewManagedWriteBatch() *WriteBatch {
	if db.opt.managedTxns {
		return db.newWriteBatch(true)
	}
	panic("cannot use NewManagedWriteBatch with managedDB=false. Use NewWriteBatch instead")
}

func (db *DB) newWriteBatch(isManaged bool) *WriteBatch {
return &WriteBatch{
db: db,
Expand Down Expand Up @@ -79,15 +88,14 @@ func (wb *WriteBatch) Cancel() {
wb.txn.Discard()
}

// The caller of this callback must hold the lock.
func (wb *WriteBatch) callback(err error) {
// sync.WaitGroup is thread-safe, so it doesn't need to be run inside wb.Lock.
defer wb.throttle.Done(err)
if err == nil {
return
}

wb.Lock()
defer wb.Unlock()
if wb.err != nil {
return
}
Expand Down
16 changes: 16 additions & 0 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package badger

import (
"fmt"
"io/ioutil"
"testing"
"time"

Expand Down Expand Up @@ -88,3 +89,18 @@ func TestFlushPanic(t *testing.T) {
})
})
}

// TestBatchErrDeadlock guards against a deadlock on the write-batch error
// path: on a DB opened in managed mode, flushing a managed batch must return
// an error rather than block, and the DB must still close cleanly afterwards.
func TestBatchErrDeadlock(t *testing.T) {
	tmpDir, err := ioutil.TempDir("", "badger-test")
	require.NoError(t, err)
	defer removeDir(tmpDir)

	db, err := OpenManaged(DefaultOptions(tmpDir))
	require.NoError(t, err)

	batch := db.NewManagedWriteBatch()
	entry := &Entry{Key: []byte("foo")}
	require.NoError(t, batch.SetEntry(entry))

	// The flush is expected to fail.
	// NOTE(review): presumably because no commit timestamp was supplied for
	// the managed batch — confirm against Flush/commit.
	flushErr := batch.Flush()
	require.Error(t, flushErr)

	require.NoError(t, db.Close())
}
26 changes: 17 additions & 9 deletions txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,14 @@ func (txn *Txn) commitAndSend() (func() error, error) {
return ret, nil
}

func (txn *Txn) commitPrecheck() {
func (txn *Txn) commitPrecheck() error {
if txn.commitTs == 0 && txn.db.opt.managedTxns {
panic("Commit cannot be called with managedDB=true. Use CommitAt.")
return errors.New("CommitTs cannot be zero. Please use commitAt instead")
}
if txn.discarded {
panic("Trying to commit a discarded txn")
return errors.New("Trying to commit a discarded txn")
}
return nil
}

// Commit commits the transaction, following these steps:
Expand All @@ -556,13 +557,16 @@ func (txn *Txn) commitPrecheck() {
// If error is nil, the transaction is successfully committed. In case of a non-nil error, the LSM
// tree won't be updated, so there's no need for any rollback.
func (txn *Txn) Commit() error {
txn.commitPrecheck() // Precheck before discarding txn.
defer txn.Discard()

if len(txn.writes) == 0 {
return nil // Nothing to do.
}

// Precheck before discarding txn.
if err := txn.commitPrecheck(); err != nil {
return err
}
defer txn.Discard()

txnCb, err := txn.commitAndSend()
if err != nil {
return err
Expand Down Expand Up @@ -601,9 +605,6 @@ func runTxnCallback(cb *txnCb) {
// so it is safe to increment sync.WaitGroup before calling CommitWith, and
// decrementing it in the callback; to block until all callbacks are run.
func (txn *Txn) CommitWith(cb func(error)) {
txn.commitPrecheck() // Precheck before discarding txn.
defer txn.Discard()

if cb == nil {
panic("Nil callback provided to CommitWith")
}
Expand All @@ -616,6 +617,13 @@ func (txn *Txn) CommitWith(cb func(error)) {
return
}

// Precheck before discarding txn.
if err := txn.commitPrecheck(); err != nil {
cb(err)
return
}
defer txn.Discard()

commitCb, err := txn.commitAndSend()
if err != nil {
go runTxnCallback(&txnCb{user: cb, err: err})
Expand Down
2 changes: 1 addition & 1 deletion txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ func TestManagedDB(t *testing.T) {
for i := 0; i <= 3; i++ {
require.NoError(t, txn.SetEntry(NewEntry(key(i), val(i))))
}
require.Panics(t, func() { txn.Commit() })
require.Error(t, txn.Commit())
require.NoError(t, txn.CommitAt(3, nil))

// Read data at t=2.
Expand Down

0 comments on commit 13a108f

Please sign in to comment.