diff --git a/mdb_store.go b/mdb_store.go index 45b4704..4bcbddc 100644 --- a/mdb_store.go +++ b/mdb_store.go @@ -5,7 +5,6 @@ import ( "os" "path/filepath" - "github.com/bmatsuo/lmdb-go/exp/lmdbscan" "github.com/bmatsuo/lmdb-go/lmdb" "github.com/hashicorp/raft" ) @@ -91,7 +90,7 @@ func (m *MDBStore) initialize() error { // Create all the tables return m.env.Update(func(txn *lmdb.Txn) (err error) { - m.dbLogs, err = txn.OpenDBI(dbLogs, lmdb.Create) + m.dbLogs, err = txn.OpenDBI(dbLogs, lmdb.Create|lmdb.IntegerKey) if err != nil { return err } @@ -109,10 +108,12 @@ func (m *MDBStore) Close() error { return nil } +// FirstIndex returns the first index func (m *MDBStore) FirstIndex() (uint64, error) { return m.getIndex(nil, nil, lmdb.First) } +// LastIndex returns the last index func (m *MDBStore) LastIndex() (uint64, error) { return m.getIndex(nil, nil, lmdb.Last) } @@ -132,7 +133,11 @@ func (m *MDBStore) getIndex(k, v []byte, op uint) (uint64, error) { if lmdb.IsNotFound(err) { return nil } else if err == nil { - k64 = bytesToUint64(k) + kp, ok := lmdb.UintptrValue(k) + if !ok { + panic("key size") + } + k64 = uint64(kp) } return err @@ -140,13 +145,12 @@ func (m *MDBStore) getIndex(k, v []byte, op uint) (uint64, error) { return k64, err } -// Gets a log entry at a given index +// GetLog gets a log entry at a given index func (m *MDBStore) GetLog(index uint64, logOut *raft.Log) error { - key := uint64ToBytes(index) return m.env.View(func(txn *lmdb.Txn) error { txn.RawRead = true - val, err := txn.Get(m.dbLogs, key) + val, err := txn.GetValue(m.dbLogs, lmdb.Uintptr(uintptr(index))) if lmdb.IsNotFound(err) { return raft.ErrLogNotFound } else if err != nil { @@ -159,25 +163,26 @@ func (m *MDBStore) GetLog(index uint64, logOut *raft.Log) error { } -// Stores a log entry +// StoreLog stores a log entry func (m *MDBStore) StoreLog(log *raft.Log) error { return m.StoreLogs([]*raft.Log{log}) } -// Stores multiple log entries +// StoreLogs stores multiple log entries func (m *MDBStore) StoreLogs(logs []*raft.Log) error { // Start write txn return m.env.Update(func(txn *lmdb.Txn) error { for _, log := range logs { // Convert to an on-disk format - key := uint64ToBytes(log.Index) val, err := encodeMsgPack(log) if err != nil { return err } // Write to the table - if err := txn.Put(m.dbLogs, key, val.Bytes(), 0); err != nil { + k := lmdb.Uintptr(uintptr(log.Index)) + v := lmdb.Bytes(val.Bytes()) + if err := txn.PutValue(m.dbLogs, k, v, 0); err != nil { return err } } @@ -186,23 +191,35 @@ func (m *MDBStore) StoreLogs(logs []*raft.Log) error { }) } -// Deletes a range of log entries. The range is inclusive. +// DeleteRange deletes a range of log entries. The range is inclusive. func (m *MDBStore) DeleteRange(minIdx, maxIdx uint64) error { // Start write txn return m.env.Update(func(txn *lmdb.Txn) (err error) { txn.RawRead = true - s := lmdbscan.New(txn, m.dbLogs) - defer s.Close() - s.Set(uint64ToBytes(minIdx), nil, lmdb.SetKey) - for s.Scan() { - if maxIdx < bytesToUint64(s.Key()) { - break + cur, err := txn.OpenCursor(m.dbLogs) + if err != nil { + return err + } + start := lmdb.Uintptr(uintptr(minIdx)) + for k, _, err := cur.GetValue(start, lmdb.Bytes(nil), lmdb.SetKey); ; k, _, err = cur.GetValue(lmdb.Bytes(nil), lmdb.Bytes(nil), lmdb.Next) { + if err != nil { + if lmdb.IsNotFound(err) { + return nil + } + return err + } + + idx, ok := lmdb.UintptrValue(k) + if !ok { + panic("key size") + } + if maxIdx < uint64(idx) { + return nil } - if err := s.Cursor().Del(0); err != nil { + if err := cur.Del(0); err != nil { return err } } - return s.Err() }) }