Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor GetProof func and add some missing unit tests #6734

Merged
merged 3 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions trie/branchNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,24 +309,28 @@ func (bn *branchNode) tryGet(key []byte, currentDepth uint32, db common.TrieStor
return child.tryGet(key, currentDepth+1, db)
}

func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) {
func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) {
if len(key) == 0 {
return nil, nil, ErrValueTooShort
return nil, ErrValueTooShort
}
childPos := key[firstByte]
if childPosOutOfRange(childPos) {
return nil, nil, ErrChildPosOutOfRange
return nil, ErrChildPosOutOfRange
}
key = key[1:]
_, err := bn.resolveIfCollapsed(childPos, db)
if len(bn.EncodedChildren[childPos]) == 0 {
return nil, ErrNodeNotFound
}
childNode, encodedNode, err := getNodeFromDBAndDecode(bn.EncodedChildren[childPos], db, bn.marsh, bn.hasher)
if err != nil {
return nil, nil, err
return nil, err
}

if bn.children[childPos] == nil {
return nil, nil, ErrNodeNotFound
}
return bn.children[childPos], key, nil
return &nodeData{
currentNode: childNode,
encodedNode: encodedNode,
hexKey: key,
}, nil
}

func (bn *branchNode) insert(
Expand Down
30 changes: 16 additions & 14 deletions trie/branchNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,17 @@ func TestBranchNode_getNext(t *testing.T) {
nextNode, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher)
childPos := byte(2)
key := append([]byte{childPos}, []byte("dog")...)

n, key, err := bn.getNext(key, nil)
db := testscommon.NewMemDbMock()
bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
data, err := bn.getNext(key, db)
assert.NotNil(t, data)

h1, _ := encodeNodeAndGetHash(nextNode)
h2, _ := encodeNodeAndGetHash(n)
h2, _ := encodeNodeAndGetHash(data.currentNode)
nextNodeBytes, _ := nextNode.getEncodedNode()
assert.Equal(t, nextNodeBytes, data.encodedNode)
assert.Equal(t, h1, h2)
assert.Equal(t, []byte("dog"), key)
assert.Equal(t, []byte("dog"), data.hexKey)
assert.Nil(t, err)
}

Expand All @@ -362,9 +366,8 @@ func TestBranchNode_getNextWrongKey(t *testing.T) {
bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher())
key := []byte("dog")

n, key, err := bn.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := bn.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrChildPosOutOfRange, err)
}

Expand All @@ -375,9 +378,8 @@ func TestBranchNode_getNextNilChild(t *testing.T) {
nilChildPos := byte(4)
key := append([]byte{nilChildPos}, []byte("dog")...)

n, key, err := bn.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := bn.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down Expand Up @@ -458,8 +460,8 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) {

bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bnHash := bn.getHash()
ln, _, _ := bn.getNext(key, db)
lnHash := ln.getHash()
nd, _ := bn.getNext(key, db)
lnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{lnHash, bnHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down Expand Up @@ -586,8 +588,8 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) {

bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bnHash := bn.getHash()
ln, _, _ := bn.getNext(lnKey, db)
lnHash := ln.getHash()
nd, _ := bn.getNext(lnKey, db)
lnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{lnHash, bnHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down
16 changes: 10 additions & 6 deletions trie/extensionNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,22 +246,26 @@ func (en *extensionNode) tryGet(key []byte, currentDepth uint32, db common.TrieS
return child.tryGet(key, currentDepth+1, db)
}

func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) {
func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) {
keyTooShort := len(key) < len(en.Key)
if keyTooShort {
return nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}
keysDontMatch := !bytes.Equal(en.Key, key[:len(en.Key)])
if keysDontMatch {
return nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}
childNode, err := en.resolveIfCollapsed(db)
child, encodedChild, err := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher)
if err != nil {
return nil, nil, err
return nil, err
}

key = key[len(en.Key):]
return childNode, key, nil
return &nodeData{
currentNode: child,
encodedNode: encodedChild,
hexKey: key,
}, nil
}

func (en *extensionNode) insert(
Expand Down
27 changes: 15 additions & 12 deletions trie/extensionNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,21 @@ func TestExtensionNode_getNext(t *testing.T) {
t.Parallel()

en, _ := getEnAndCollapsedEn()
nextNode, _ := getBnAndCollapsedBn(en.marsh, en.hasher)
db := testscommon.NewMemDbMock()
en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)

enKey := []byte{100}
bnKey := []byte{2}
lnKey := []byte("dog")
key := append(enKey, bnKey...)
key = append(key, lnKey...)

n, newKey, err := en.getNext(key, nil)
assert.Equal(t, nextNode, n)
assert.Equal(t, key[1:], newKey)
data, err := en.getNext(key, db)
child, childBytes, _ := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher)
assert.NotNil(t, data)
assert.Equal(t, childBytes, data.encodedNode)
assert.Equal(t, child, data.currentNode)
assert.Equal(t, key[1:], data.hexKey)
assert.Nil(t, err)
}

Expand All @@ -297,9 +301,8 @@ func TestExtensionNode_getNextWrongKey(t *testing.T) {
lnKey := []byte("dog")
key := append(bnKey, lnKey...)

n, key, err := en.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := en.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down Expand Up @@ -352,8 +355,8 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) {

en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
enHash := en.getHash()
bn, _, _ := en.getNext(enKey, db)
bnHash := bn.getHash()
nd, _ := en.getNext(enKey, db)
bnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{bnHash, enHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down Expand Up @@ -461,9 +464,9 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) {
en.setHash(getTestGoroutinesManager())

en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bn, key, _ := en.getNext(key, db)
ln, _, _ := bn.getNext(key, db)
expectedHashes := [][]byte{ln.getHash(), bn.getHash(), en.getHash()}
bnData, _ := en.getNext(key, db)
lnData, _ := bnData.currentNode.getNext(bnData.hexKey, db)
expectedHashes := [][]byte{lnData.currentNode.getHash(), bnData.currentNode.getHash(), en.getHash()}
data := []core.TrieData{{Key: lnPathKey}}

goRoutinesManager := getTestGoroutinesManager()
Expand Down
8 changes: 7 additions & 1 deletion trie/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import (
vmcommon "github.com/multiversx/mx-chain-vm-common-go"
)

type nodeData struct {
currentNode node
encodedNode []byte
hexKey []byte
}

type baseTrieNode interface {
getHash() []byte
setGivenHash([]byte)
Expand All @@ -28,7 +34,7 @@ type node interface {
setHash(goRoutinesManager common.TrieGoroutinesManager)
getEncodedNode() ([]byte, error)
tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error)
getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error)
getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error)
insert(newData []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) node
delete(data []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) (bool, node)
reduceNode(pos int, db common.TrieStorageInteractor) (node, bool, error)
Expand Down
6 changes: 3 additions & 3 deletions trie/leafNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ func (ln *leafNode) tryGet(key []byte, currentDepth uint32, _ common.TrieStorage
return nil, currentDepth, nil
}

func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, []byte, error) {
func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (*nodeData, error) {
if bytes.Equal(key, ln.Key) {
return nil, nil, nil
return nil, nil
}
return nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}

func (ln *leafNode) insert(
Expand Down
10 changes: 4 additions & 6 deletions trie/leafNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,8 @@ func TestLeafNode_getNext(t *testing.T) {
ln := getLn(getTestMarshalizerAndHasher())
key := []byte("dog")

n, key, err := ln.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := ln.getNext(key, nil)
assert.Nil(t, data)
assert.Nil(t, err)
}

Expand All @@ -177,9 +176,8 @@ func TestLeafNode_getNextWrongKey(t *testing.T) {
ln := getLn(getTestMarshalizerAndHasher())
wrongKey := append([]byte{2}, []byte("dog")...)

n, key, err := ln.getNext(wrongKey, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := ln.getNext(wrongKey, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down
28 changes: 13 additions & 15 deletions trie/patriciaMerkleTrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,35 +637,33 @@ func logMapWithTrace(message string, paramName string, hashes common.ModifiedHas

// GetProof computes a Merkle proof for the node that is present at the given key
func (tr *patriciaMerkleTrie) GetProof(key []byte, rootHash []byte) ([][]byte, []byte, error) {
//TODO refactor this function to avoid encoding the node after it is retrieved from the DB.
// The encoded node is actually the value from db, thus we can use the retrieved value directly
if len(key) == 0 || bytes.Equal(rootHash, common.EmptyTrieHash) {
if common.IsEmptyTrie(rootHash) {
return nil, nil, ErrNilNode
}

rootNode, _, err := getNodeFromDBAndDecode(rootHash, tr.trieStorage, tr.marshalizer, tr.hasher)
rootNode, encodedNode, err := getNodeFromDBAndDecode(rootHash, tr.trieStorage, tr.marshalizer, tr.hasher)
if err != nil {
return nil, nil, fmt.Errorf("trie get proof error: %w", err)
}

var proof [][]byte
hexKey := keyBytesToHex(key)
currentNode := rootNode
var errGet error

data := &nodeData{
currentNode: rootNode,
encodedNode: encodedNode,
hexKey: keyBytesToHex(key),
}

for {
encodedNode, errGet := currentNode.getEncodedNode()
if errGet != nil {
return nil, nil, errGet
}
proof = append(proof, encodedNode)
value := currentNode.getValue()
proof = append(proof, data.encodedNode)
value := data.currentNode.getValue()

currentNode, hexKey, errGet = currentNode.getNext(hexKey, tr.trieStorage)
data, errGet = data.currentNode.getNext(data.hexKey, tr.trieStorage)
if errGet != nil {
return nil, nil, errGet
}

if currentNode == nil {
if data == nil {
return proof, value, nil
}
}
Expand Down
2 changes: 0 additions & 2 deletions trie/rootManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"github.com/multiversx/mx-chain-core-go/core/check"
)

// TODO: add unit tests

type rootManager struct {
root node
oldHashes [][]byte
Expand Down
Loading