diff --git a/nmt.go b/nmt.go index 2683957..9f52d03 100644 --- a/nmt.go +++ b/nmt.go @@ -482,6 +482,11 @@ func (n *NamespacedMerkleTree) MaxNamespace() (namespace.ID, error) { // encompasses the leaves within the range of [start, end). // Any errors returned by this method are irrecoverable and indicate an illegal state of the tree (n). func (n *NamespacedMerkleTree) computeRoot(start, end int) ([]byte, error) { + // in computeRoot, start may be equal to end which indicates an empty tree hence empty root. + // Due to this, we need to perform custom range check instead of using validateRange() in which start=end is considered invalid. + if start < 0 || start > end || end > len(n.leaves) { + return nil, fmt.Errorf("failed to compute root [%d, %d): %w", start, end, ErrInvalidRange) + } switch end - start { case 0: rootHash := n.treeHasher.EmptyRoot() diff --git a/nmt_test.go b/nmt_test.go index 0fc7677..1acb0b2 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -1019,17 +1019,23 @@ func Test_Root_Error(t *testing.T) { // Test_computeRoot_Error tests that the computeRoot method returns an error when the underlying tree is in an invalid state, such as when the leaves are not ordered by namespace ID or when a leaf is corrupt. func Test_computeRoot_Error(t *testing.T) { + nIDSize := 2 + nIDList := []byte{1, 2, 3, 4, 5, 6, 7, 8} + // create an NMT with 8 sequentially namespaced leaves, numbered from 1 to 8. - treeWithCorruptLeafHash := exampleNMT(2, 1, 2, 3, 4, 5, 6, 7, 8) + treeWithCorruptLeafHash := exampleNMT(nIDSize, nIDList...) // corrupt a leaf hash treeWithCorruptLeafHash.leafHashes[4] = treeWithCorruptLeafHash.leafHashes[4][:treeWithCorruptLeafHash.NamespaceSize()-1] // create an NMT with 8 sequentially namespaced leaves, numbered from 1 to 8. - treeWithUnorderedLeaves := exampleNMT(2, 1, 2, 3, 4, 5, 6, 7, 8) + treeWithUnorderedLeaves := exampleNMT(nIDSize, nIDList...) // swap the positions of the 4th and 5th leaves swap(treeWithUnorderedLeaves.leaves, 4, 5) swap(treeWithUnorderedLeaves.leafHashes, 4, 5) + // create an NMT with 8 sequentially namespaced leaves, numbered from 1 to 8. + validTree := exampleNMT(nIDSize, nIDList...) + tests := []struct { name string tree *NamespacedMerkleTree @@ -1037,12 +1043,15 @@ func Test_computeRoot_Error(t *testing.T) { wantErr bool errType error }{ - {"corrupt leaf hash: the entire tree", treeWithCorruptLeafHash, 0, 7, true, ErrInvalidNodeLen}, - {"corrupt leaf: from the corrupt node until the end of the tree", treeWithCorruptLeafHash, 4, 7, true, ErrInvalidNodeLen}, - {"corrupt leaf: the corrupt node and the node to its left", treeWithCorruptLeafHash, 3, 5, true, ErrInvalidNodeLen}, - {"unordered leaves: the entire tree", treeWithUnorderedLeaves, 0, 7, true, ErrUnorderedSiblings}, - {"unordered leaves: the unordered portion", treeWithUnorderedLeaves, 4, 6, true, ErrUnorderedSiblings}, - {"unordered leaves: a portion of the tree containing the unordered leaves", treeWithUnorderedLeaves, 3, 7, true, ErrUnorderedSiblings}, + {"invalid tree with corrupt leaf hash. Query: the entire tree", treeWithCorruptLeafHash, 0, 7, true, ErrInvalidNodeLen}, + {"invalid tree with corrupt leaf. Query: from the corrupt node until the end of the tree", treeWithCorruptLeafHash, 4, 7, true, ErrInvalidNodeLen}, + {"invalid tree with corrupt leaf. Query: the corrupt node and the node to its left", treeWithCorruptLeafHash, 3, 5, true, ErrInvalidNodeLen}, + {"invalid tree with unordered leaves. Query: the entire tree", treeWithUnorderedLeaves, 0, 7, true, ErrUnorderedSiblings}, + {"invalid tree with unordered leaves. Query: the unordered portion", treeWithUnorderedLeaves, 4, 6, true, ErrUnorderedSiblings}, + {"invalid tree with unordered leaves. Query: a portion of the tree containing the unordered leaves", treeWithUnorderedLeaves, 3, 7, true, ErrUnorderedSiblings}, + {"valid tree. Query: start < 0", validTree, -1, 1, true, ErrInvalidRange}, + {"valid tree. Query: start > end", validTree, 3, 1, true, ErrInvalidRange}, + {"valid tree. Query: end > total number of leaves", validTree, 3, len(validTree.leaves) + 1, true, ErrInvalidRange}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {