diff --git a/trie/branchNode.go b/trie/branchNode.go index 11306eeb30..b918f8278c 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -309,24 +309,28 @@ func (bn *branchNode) tryGet(key []byte, currentDepth uint32, db common.TrieStor return child.tryGet(key, currentDepth+1, db) } -func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) { +func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) { if len(key) == 0 { - return nil, nil, ErrValueTooShort + return nil, ErrValueTooShort } childPos := key[firstByte] if childPosOutOfRange(childPos) { - return nil, nil, ErrChildPosOutOfRange + return nil, ErrChildPosOutOfRange } key = key[1:] - _, err := bn.resolveIfCollapsed(childPos, db) + if len(bn.EncodedChildren[childPos]) == 0 { + return nil, ErrNodeNotFound + } + childNode, encodedNode, err := getNodeFromDBAndDecode(bn.EncodedChildren[childPos], db, bn.marsh, bn.hasher) if err != nil { - return nil, nil, err + return nil, err } - if bn.children[childPos] == nil { - return nil, nil, ErrNodeNotFound - } - return bn.children[childPos], key, nil + return &nodeData{ + currentNode: childNode, + encodedNode: encodedNode, + hexKey: key, + }, nil } func (bn *branchNode) insert( diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index d643ca8846..f0454053e8 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -346,13 +346,17 @@ func TestBranchNode_getNext(t *testing.T) { nextNode, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher) childPos := byte(2) key := append([]byte{childPos}, []byte("dog")...) - - n, key, err := bn.getNext(key, nil) + db := testscommon.NewMemDbMock() + bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) + data, err := bn.getNext(key, db) + assert.NotNil(t, data) h1, _ := encodeNodeAndGetHash(nextNode) - h2, _ := encodeNodeAndGetHash(n) + h2, _ := encodeNodeAndGetHash(data.currentNode) + nextNodeBytes, _ := nextNode.getEncodedNode() + assert.Equal(t, nextNodeBytes, data.encodedNode) assert.Equal(t, h1, h2) - assert.Equal(t, []byte("dog"), key) + assert.Equal(t, []byte("dog"), data.hexKey) assert.Nil(t, err) } @@ -362,9 +366,8 @@ func TestBranchNode_getNextWrongKey(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) key := []byte("dog") - n, key, err := bn.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) + data, err := bn.getNext(key, nil) + assert.Nil(t, data) assert.Equal(t, ErrChildPosOutOfRange, err) } @@ -375,9 +378,8 @@ func TestBranchNode_getNextNilChild(t *testing.T) { nilChildPos := byte(4) key := append([]byte{nilChildPos}, []byte("dog")...) - n, key, err := bn.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) + data, err := bn.getNext(key, nil) + assert.Nil(t, data) assert.Equal(t, ErrNodeNotFound, err) } @@ -458,8 +460,8 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) { bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) bnHash := bn.getHash() - ln, _, _ := bn.getNext(key, db) - lnHash := ln.getHash() + nd, _ := bn.getNext(key, db) + lnHash := nd.currentNode.getHash() expectedHashes := [][]byte{lnHash, bnHash} goRoutinesManager := getTestGoroutinesManager() @@ -586,8 +588,8 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) { bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) bnHash := bn.getHash() - ln, _, _ := bn.getNext(lnKey, db) - lnHash := ln.getHash() + nd, _ := bn.getNext(lnKey, db) + lnHash := nd.currentNode.getHash() expectedHashes := [][]byte{lnHash, bnHash} goRoutinesManager := getTestGoroutinesManager() diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 51131ac489..acf5ca9c26 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -246,22 +246,26 @@ func (en *extensionNode) tryGet(key []byte, currentDepth uint32, db common.TrieS return child.tryGet(key, currentDepth+1, db) } -func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) { +func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) { keyTooShort := len(key) < len(en.Key) if keyTooShort { - return nil, nil, ErrNodeNotFound + return nil, ErrNodeNotFound } keysDontMatch := !bytes.Equal(en.Key, key[:len(en.Key)]) if keysDontMatch { - return nil, nil, ErrNodeNotFound + return nil, ErrNodeNotFound } - childNode, err := en.resolveIfCollapsed(db) + child, encodedChild, err := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher) if err != nil { - return nil, nil, err + return nil, err } key = key[len(en.Key):] - return childNode, key, nil + return &nodeData{ + currentNode: child, + encodedNode: encodedChild, + hexKey: key, + }, nil } func (en *extensionNode) insert( diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 233b43c85f..f67c99a825 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -275,7 +275,8 @@ func TestExtensionNode_getNext(t *testing.T) { t.Parallel() en, _ := getEnAndCollapsedEn() - nextNode, _ := getBnAndCollapsedBn(en.marsh, en.hasher) + db := testscommon.NewMemDbMock() + en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) enKey := []byte{100} bnKey := []byte{2} @@ -283,9 +284,12 @@ func TestExtensionNode_getNext(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) - n, newKey, err := en.getNext(key, nil) - assert.Equal(t, nextNode, n) - assert.Equal(t, key[1:], newKey) + data, err := en.getNext(key, db) + child, childBytes, _ := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher) + assert.NotNil(t, data) + assert.Equal(t, childBytes, data.encodedNode) + assert.Equal(t, child, data.currentNode) + assert.Equal(t, key[1:], data.hexKey) assert.Nil(t, err) } @@ -297,9 +301,8 @@ func TestExtensionNode_getNextWrongKey(t *testing.T) { lnKey := []byte("dog") key := append(bnKey, lnKey...) - n, key, err := en.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) + data, err := en.getNext(key, nil) + assert.Nil(t, data) assert.Equal(t, ErrNodeNotFound, err) } @@ -352,8 +355,8 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) { en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) enHash := en.getHash() - bn, _, _ := en.getNext(enKey, db) - bnHash := bn.getHash() + nd, _ := en.getNext(enKey, db) + bnHash := nd.currentNode.getHash() expectedHashes := [][]byte{bnHash, enHash} goRoutinesManager := getTestGoroutinesManager() @@ -461,9 +464,9 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) { en.setHash(getTestGoroutinesManager()) en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) - bn, key, _ := en.getNext(key, db) - ln, _, _ := bn.getNext(key, db) - expectedHashes := [][]byte{ln.getHash(), bn.getHash(), en.getHash()} + bnData, _ := en.getNext(key, db) + lnData, _ := bnData.currentNode.getNext(bnData.hexKey, db) + expectedHashes := [][]byte{lnData.currentNode.getHash(), bnData.currentNode.getHash(), en.getHash()} data := []core.TrieData{{Key: lnPathKey}} goRoutinesManager := getTestGoroutinesManager() diff --git a/trie/interface.go b/trie/interface.go index 6321894a22..2b291f074a 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -12,6 +12,12 @@ import ( vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) +type nodeData struct { + currentNode node + encodedNode []byte + hexKey []byte +} + type baseTrieNode interface { getHash() []byte setGivenHash([]byte) @@ -28,7 +34,7 @@ type node interface { setHash(goRoutinesManager common.TrieGoroutinesManager) getEncodedNode() ([]byte, error) tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error) - getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) + getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) insert(newData []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) node delete(data []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) (bool, node) reduceNode(pos int, db common.TrieStorageInteractor) (node, bool, error) diff --git a/trie/leafNode.go b/trie/leafNode.go index e14662e386..60cb79b1a1 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -144,11 +144,11 @@ func (ln *leafNode) tryGet(key []byte, currentDepth uint32, _ common.TrieStorage return nil, currentDepth, nil } -func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, []byte, error) { +func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (*nodeData, error) { if bytes.Equal(key, ln.Key) { - return nil, nil, nil + return nil, nil } - return nil, nil, ErrNodeNotFound + return nil, ErrNodeNotFound } func (ln *leafNode) insert( diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index afef263e49..ab505f154e 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -165,9 +165,8 @@ func TestLeafNode_getNext(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) key := []byte("dog") - n, key, err := ln.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) + data, err := ln.getNext(key, nil) + assert.Nil(t, data) assert.Nil(t, err) } @@ -177,9 +176,8 @@ func TestLeafNode_getNextWrongKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) wrongKey := append([]byte{2}, []byte("dog")...) - n, key, err := ln.getNext(wrongKey, nil) - assert.Nil(t, n) - assert.Nil(t, key) + data, err := ln.getNext(wrongKey, nil) + assert.Nil(t, data) assert.Equal(t, ErrNodeNotFound, err) } diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index 61f6d35be9..161010752d 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -637,35 +637,33 @@ func logMapWithTrace(message string, paramName string, hashes common.ModifiedHas // GetProof computes a Merkle proof for the node that is present at the given key func (tr *patriciaMerkleTrie) GetProof(key []byte, rootHash []byte) ([][]byte, []byte, error) { - //TODO refactor this function to avoid encoding the node after it is retrieved from the DB. - // The encoded node is actually the value from db, thus we can use the retrieved value directly - if len(key) == 0 || bytes.Equal(rootHash, common.EmptyTrieHash) { + if common.IsEmptyTrie(rootHash) { return nil, nil, ErrNilNode } - rootNode, _, err := getNodeFromDBAndDecode(rootHash, tr.trieStorage, tr.marshalizer, tr.hasher) + rootNode, encodedNode, err := getNodeFromDBAndDecode(rootHash, tr.trieStorage, tr.marshalizer, tr.hasher) if err != nil { return nil, nil, fmt.Errorf("trie get proof error: %w", err) } var proof [][]byte - hexKey := keyBytesToHex(key) - currentNode := rootNode + var errGet error + + data := &nodeData{ + currentNode: rootNode, + encodedNode: encodedNode, + hexKey: keyBytesToHex(key), + } for { - encodedNode, errGet := currentNode.getEncodedNode() - if errGet != nil { - return nil, nil, errGet - } - proof = append(proof, encodedNode) - value := currentNode.getValue() + proof = append(proof, data.encodedNode) + value := data.currentNode.getValue() - currentNode, hexKey, errGet = currentNode.getNext(hexKey, tr.trieStorage) + data, errGet = data.currentNode.getNext(data.hexKey, tr.trieStorage) if errGet != nil { return nil, nil, errGet } - - if currentNode == nil { + if data == nil { return proof, value, nil } } diff --git a/trie/rootManager.go b/trie/rootManager.go index 843c0879cf..148b9a523b 100644 --- a/trie/rootManager.go +++ b/trie/rootManager.go @@ -6,8 +6,6 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" ) -// TODO: add unit tests - type rootManager struct { root node oldHashes [][]byte diff --git a/trie/rootManager_test.go b/trie/rootManager_test.go new file mode 100644 index 0000000000..d2757cd62e --- /dev/null +++ b/trie/rootManager_test.go @@ -0,0 +1,130 @@ +package trie + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewRootManager(t *testing.T) { + t.Parallel() + + rm := NewRootManager() + assert.Nil(t, rm.root) + assert.Empty(t, rm.oldHashes) + assert.Empty(t, rm.oldRootHash) +} + +func TestRootManager_GetRootNode(t *testing.T) { + t.Parallel() + + bn := &branchNode{ + baseNode: &baseNode{ + hash: []byte{1, 2, 3}, + }, + } + rm := NewRootManager() + rm.root = bn + assert.Equal(t, bn, rm.GetRootNode()) +} + +func TestRootManager_SetNewRootNode(t *testing.T) { + t.Parallel() + + bn := &branchNode{ + baseNode: &baseNode{ + hash: []byte{1, 2, 3}, + }, + } + rm := NewRootManager() + rm.SetNewRootNode(bn) + assert.Equal(t, bn, rm.root) +} + +func TestRootManager_SetDataForRootChange(t *testing.T) { + t.Parallel() + + bn := &branchNode{ + baseNode: &baseNode{ + hash: []byte{1, 2, 3}, + }, + } + oldRootHash := []byte{4, 5, 6} + oldHashes := [][]byte{{7, 8, 9}, {10, 11, 12}} + rm := NewRootManager() + + rm.SetDataForRootChange(bn, oldRootHash, oldHashes) + assert.Equal(t, bn, rm.root) + assert.Equal(t, oldRootHash, rm.oldRootHash) + assert.Equal(t, oldHashes, rm.oldHashes) + + var newHash []byte + rm.SetDataForRootChange(bn, newHash, oldHashes) + assert.Equal(t, bn, rm.root) + assert.Equal(t, oldRootHash, rm.oldRootHash) + assert.Equal(t, append(oldHashes, oldHashes...), rm.oldHashes) +} + +func TestRootManager_ResetCollectedHashes(t *testing.T) { + t.Parallel() + + oldRootHash := []byte{4, 5, 6} + oldHashes := [][]byte{{7, 8, 9}, {10, 11, 12}} + rm := NewRootManager() + rm.oldRootHash = oldRootHash + rm.oldHashes = oldHashes + + rm.ResetCollectedHashes() + assert.Empty(t, rm.oldRootHash) + assert.Empty(t, rm.oldHashes) +} + +func TestRootManager_GetOldHashes(t *testing.T) { + t.Parallel() + + oldHashes := [][]byte{{7, 8, 9}, {10, 11, 12}} + rm := NewRootManager() + rm.oldHashes = oldHashes + assert.Equal(t, oldHashes, rm.GetOldHashes()) +} + +func TestRootManager_GetOldRootHash(t *testing.T) { + t.Parallel() + + oldRootHash := []byte{4, 5, 6} + rm := NewRootManager() + rm.oldRootHash = oldRootHash + assert.Equal(t, oldRootHash, rm.GetOldRootHash()) +} + +func TestRootManager_Concurrency(t *testing.T) { + t.Parallel() + + numOperations := 1000 + numMethods := 6 + rm := NewRootManager() + wg := sync.WaitGroup{} + wg.Add(numOperations) + + for i := 0; i < numOperations; i++ { + go func(index int) { + defer wg.Done() + switch index % numMethods { + case 0: + rm.GetRootNode() + case 1: + rm.SetNewRootNode(&extensionNode{baseNode: &baseNode{dirty: true}}) + case 2: + rm.SetDataForRootChange(&branchNode{baseNode: &baseNode{dirty: true}}, []byte{1, 2, 3}, [][]byte{{4, 5, 6}}) + case 3: + rm.ResetCollectedHashes() + case 4: + rm.GetOldHashes() + case 5: + rm.GetOldRootHash() + } + }(i) + } + wg.Wait() +}