Skip to content

Commit

Permalink
[Code Health] refactor: SMST#Root(), #Sum(), & #Count() (#51)
Browse files Browse the repository at this point in the history
Signed-off-by: Bryan White <[email protected]>
  • Loading branch information
bryanchriswhite authored Jul 17, 2024
1 parent 6c22c94 commit 3de80fe
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 112 deletions.
89 changes: 65 additions & 24 deletions root.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,11 @@ import (
"fmt"
)

const (
// These are intentionally exposed to allow for for testing and custom
// implementations of downstream applications.
SmtRootSizeBytes = 32
SmstRootSizeBytes = SmtRootSizeBytes + sumSizeBytes + countSizeBytes
)

// MustSum returns the uint64 sum of the merkle root, it checks the length of the
// merkle root and if it is no the same as the size of the SMST's expected
// root hash it will panic.
func (r MerkleRoot) MustSum() uint64 {
sum, err := r.Sum()
func (root MerkleSumRoot) MustSum() uint64 {
sum, err := root.Sum()
if err != nil {
panic(err)
}
Expand All @@ -27,28 +20,76 @@ func (r MerkleRoot) MustSum() uint64 {
// Sum returns the uint64 sum of the merkle root, it checks the length of the
// merkle root and if it is no the same as the size of the SMST's expected
// root hash it will return an error.
func (r MerkleRoot) Sum() (uint64, error) {
if len(r)%SmtRootSizeBytes == 0 {
return 0, fmt.Errorf("root#sum: not a merkle sum trie")
func (root MerkleSumRoot) Sum() (uint64, error) {
if err := root.validateBasic(); err != nil {
return 0, err
}

firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
return root.sum(), nil
}

// MustCount returns the uint64 count of the merkle root, a cryptographically secure
// count of the number of non-empty leafs in the tree. It panics if the root length
// is invalid.
func (root MerkleSumRoot) MustCount() uint64 {
count, err := root.Count()
if err != nil {
panic(err)
}

var sumBz [sumSizeBytes]byte
copy(sumBz[:], []byte(r)[firstSumByteIdx:firstCountByteIdx])
return binary.BigEndian.Uint64(sumBz[:]), nil
return count
}

// Count returns the uint64 count of the merkle root, a cryptographically secure
// count of the number of non-empty leafs in the tree.
func (r MerkleRoot) Count() uint64 {
if len(r)%SmtRootSizeBytes == 0 {
panic("root#sum: not a merkle sum trie")
// count of the number of non-empty leafs in the tree. It returns an error if the
// root length is invalid.
func (root MerkleSumRoot) Count() (uint64, error) {
if err := root.validateBasic(); err != nil {
return 0, err
}

_, firstCountByteIdx := getFirstMetaByteIdx([]byte(r))
return root.count(), nil
}

// DigestSize returns the length of the digest portion of the root.
func (root MerkleSumRoot) DigestSize() int {
return len(root) - countSizeBytes - sumSizeBytes
}

// HasDigestSize returns true if the root digest size is the same as
// that of the size of the given hasher.
func (root MerkleSumRoot) HasDigestSize(size int) bool {
return root.DigestSize() == size
}

var countBz [countSizeBytes]byte
copy(countBz[:], []byte(r)[firstCountByteIdx:])
return binary.BigEndian.Uint64(countBz[:])
// validateBasic returns an error if the root digest size is not a power of two.
func (root MerkleSumRoot) validateBasic() error {
if !isPowerOfTwo(root.DigestSize()) {
return fmt.Errorf("MerkleSumRoot#validateBasic: invalid root length")
}

return nil
}

// sum returns the sum of the node stored in the root.
func (root MerkleSumRoot) sum() uint64 {
firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(root)

return binary.BigEndian.Uint64(root[firstSumByteIdx:firstCountByteIdx])
}

// count returns the count of the node stored in the root.
func (root MerkleSumRoot) count() uint64 {
_, firstCountByteIdx := getFirstMetaByteIdx(root)

return binary.BigEndian.Uint64(root[firstCountByteIdx:])
}

// isPowerOfTwo function returns true if the input n is a power of 2
func isPowerOfTwo(n int) bool {
// A power of 2 has only one bit set in its binary representation
if n <= 0 {
return false
}
return (n & (n - 1)) == 0
}
100 changes: 59 additions & 41 deletions root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,82 @@ import (
"github.com/pokt-network/smt/kvstore/simplemap"
)

func TestMerkleRoot_TrieTypes(t *testing.T) {
func TestMerkleSumRoot_SumAndCountSuccess(t *testing.T) {
tests := []struct {
desc string
sumTree bool
hasher hash.Hash
expectedPanic string
desc string
hasher hash.Hash
}{
{
desc: "successfully: gets sum of sha256 hasher SMST",
sumTree: true,
hasher: sha256.New(),
expectedPanic: "",
desc: "sha256 hasher",
hasher: sha256.New(),
},
{
desc: "successfully: gets sum of sha512 hasher SMST",
sumTree: true,
hasher: sha512.New(),
expectedPanic: "",
desc: "sha512 hasher",
hasher: sha512.New(),
},
}

nodeStore := simplemap.NewSimpleMap()
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, nodeStore.ClearAll())
})
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
for i := uint64(0); i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
}

sum, sumErr := trie.Sum()
require.NoError(t, sumErr)

count, countErr := trie.Count()
require.NoError(t, countErr)

require.EqualValues(t, uint64(45), sum)
require.EqualValues(t, uint64(10), count)
})
}
}

func TestMekleRoot_SumAndCountError(t *testing.T) {
tests := []struct {
desc string
hasher hash.Hash
}{
{
desc: "failure: panics for sha256 hasher SMT",
sumTree: false,
hasher: sha256.New(),
expectedPanic: "roo#sum: not a merkle sum trie",
desc: "sha256 hasher",
hasher: sha256.New(),
},
{
desc: "failure: panics for sha512 hasher SMT",
sumTree: false,
hasher: sha512.New(),
expectedPanic: "roo#sum: not a merkle sum trie",
desc: "sha512 hasher",
hasher: sha512.New(),
},
}

nodeStore := simplemap.NewSimpleMap()
for _, tt := range tests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Cleanup(func() {
require.NoError(t, nodeStore.ClearAll())
})
if tt.sumTree {
trie := smt.NewSparseMerkleSumTrie(nodeStore, tt.hasher)
for i := uint64(0); i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
}
require.NotNil(t, trie.Sum())
require.EqualValues(t, 45, trie.Sum())
require.EqualValues(t, 10, trie.Count())

return
}
trie := smt.NewSparseMerkleTrie(nodeStore, tt.hasher)
for i := 0; i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i))))
}
if panicStr := recover(); panicStr != nil {
require.Equal(t, tt.expectedPanic, panicStr)
trie := smt.NewSparseMerkleSumTrie(nodeStore, test.hasher)
for i := uint64(0); i < 10; i++ {
require.NoError(t, trie.Update([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i)), i))
}

root := trie.Root()

// Mangle the root bytes.
root = root[:len(root)-1]

sum, sumErr := root.Sum()
require.Error(t, sumErr)
require.Equal(t, uint64(0), sum)

count, countErr := root.Count()
require.Error(t, countErr)
require.Equal(t, uint64(0), count)
})
}
}
50 changes: 30 additions & 20 deletions smst.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package smt
import (
"bytes"
"encoding/binary"
"fmt"
"hash"

"github.com/pokt-network/smt/kvstore"
Expand Down Expand Up @@ -170,39 +171,48 @@ func (smst *SMST) Commit() error {
}

// Root returns the root hash of the trie with the total sum bytes appended
func (smst *SMST) Root() MerkleRoot {
return smst.SMT.Root() // [digest]+[binary sum]
func (smst *SMST) Root() MerkleSumRoot {
return MerkleSumRoot(smst.SMT.Root()) // [digest]+[binary sum]+[binary count]
}

// Sum returns the sum of the entire trie stored in the root.
// MustSum returns the sum of the entire trie stored in the root.
// If the tree is not a sum tree, it will panic.
func (smst *SMST) Sum() uint64 {
rootDigest := []byte(smst.Root())
func (smst *SMST) MustSum() uint64 {
sum, err := smst.Sum()
if err != nil {
panic(err)
}
return sum
}

// Sum returns the sum of the entire trie stored in the root.
// If the tree is not a sum tree, it will return an error.
func (smst *SMST) Sum() (uint64, error) {
if !smst.Spec().sumTrie {
panic("SMST: not a merkle sum trie")
return 0, fmt.Errorf("SMST: not a merkle sum trie")
}

firstSumByteIdx, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)
return smst.Root().Sum()
}

var sumBz [sumSizeBytes]byte
copy(sumBz[:], rootDigest[firstSumByteIdx:firstCountByteIdx])
return binary.BigEndian.Uint64(sumBz[:])
// MustCount returns the number of non-empty nodes in the entire trie stored in the root.
// If the tree is not a sum tree, it will panic.
func (smst *SMST) MustCount() uint64 {
count, err := smst.Count()
if err != nil {
panic(err)
}
return count
}

// Count returns the number of non-empty nodes in the entire trie stored in the root.
func (smst *SMST) Count() uint64 {
rootDigest := []byte(smst.Root())

// If the tree is not a sum tree, it will return an error.
func (smst *SMST) Count() (uint64, error) {
if !smst.Spec().sumTrie {
panic("SMST: not a merkle sum trie")
return 0, fmt.Errorf("SMST: not a merkle sum trie")
}

_, firstCountByteIdx := getFirstMetaByteIdx(rootDigest)

var countBz [countSizeBytes]byte
copy(countBz[:], rootDigest[firstCountByteIdx:])
return binary.BigEndian.Uint64(countBz[:])
return smst.Root().Count()
}

// getFirstMetaByteIdx returns the index of the first count byte and the first sum byte
Expand All @@ -211,5 +221,5 @@ func (smst *SMST) Count() uint64 {
func getFirstMetaByteIdx(data []byte) (firstSumByteIdx, firstCountByteIdx int) {
firstCountByteIdx = len(data) - countSizeBytes
firstSumByteIdx = firstCountByteIdx - sumSizeBytes
return
return firstSumByteIdx, firstCountByteIdx
}
6 changes: 3 additions & 3 deletions smst_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestExampleSMST(t *testing.T) {
_ = trie.Commit()

// Calculate the total sum of the trie
_ = trie.Sum() // 20
_ = trie.MustSum() // 20

// Generate a Merkle proof for "foo"
proof1, _ := trie.Prove([]byte("foo"))
Expand All @@ -52,8 +52,8 @@ func TestExampleSMST(t *testing.T) {
require.False(t, valid_false1)

// Verify the total sum of the trie
require.EqualValues(t, 20, trie.Sum())
require.EqualValues(t, 20, trie.MustSum())

// Verify the number of non-empty leafs in the trie
require.EqualValues(t, 3, trie.Count())
require.EqualValues(t, 3, trie.MustCount())
}
Loading

0 comments on commit 3de80fe

Please sign in to comment.