From db060385847cdce85ced6afdddfd5ac625aad5ab Mon Sep 17 00:00:00 2001 From: h5law <53987565+h5law@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:54:51 +0000 Subject: [PATCH] [Bug] Fix bug introduced by changing `BigEndian` to `LittleEndian` (#36) * chore: simplify error checking * chore: replace BigEndian encoding with LittleEndian in the SMST --- benchmarks/proof_sizes_test.go | 8 ++++---- docs/MerkleSumTrie.md | 2 +- fuzz_test.go | 4 +++- hasher.go | 6 +++--- proofs.go | 4 ++-- proofs_test.go | 16 ++++------------ smst.go | 6 +++--- smst_proofs_test.go | 12 ++++++------ smst_test.go | 2 +- smst_utils_test.go | 13 ++----------- smt_proofs_test.go | 4 ++-- utils.go | 4 ++-- 12 files changed, 33 insertions(+), 48 deletions(-) diff --git a/benchmarks/proof_sizes_test.go b/benchmarks/proof_sizes_test.go index d103abe..e020fad 100644 --- a/benchmarks/proof_sizes_test.go +++ b/benchmarks/proof_sizes_test.go @@ -43,7 +43,7 @@ func TestSMT_ProofSizes(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for i := 0; i < tc.trieSize; i++ { b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(i)) + binary.BigEndian.PutUint64(b, uint64(i)) require.NoError(t, trie.Update(b, b)) } require.NoError(t, trie.Commit()) @@ -55,7 +55,7 @@ func TestSMT_ProofSizes(t *testing.T) { minCompact := uint64(0) for i := 0; i < tc.trieSize; i++ { b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(i)) + binary.BigEndian.PutUint64(b, uint64(i)) proof, err := trie.Prove(b) require.NoError(t, err) require.NotNil(t, proof) @@ -132,7 +132,7 @@ func TestSMST_ProofSizes(t *testing.T) { t.Run(tc.name, func(t *testing.T) { for i := 0; i < tc.trieSize; i++ { b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(i)) + binary.BigEndian.PutUint64(b, uint64(i)) require.NoError(t, trie.Update(b, b, uint64(i))) } require.NoError(t, trie.Commit()) @@ -144,7 +144,7 @@ func TestSMST_ProofSizes(t *testing.T) { minCompact := uint64(0) for i := 0; i < tc.trieSize; i++ { b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(i)) + binary.BigEndian.PutUint64(b, uint64(i)) proof, err := trie.Prove(b) require.NoError(t, err) require.NotNil(t, proof) diff --git a/docs/MerkleSumTrie.md b/docs/MerkleSumTrie.md index cecf46e..a72017c 100644 --- a/docs/MerkleSumTrie.md +++ b/docs/MerkleSumTrie.md @@ -61,7 +61,7 @@ The majority of the code relating to the SMST can be found in: The sum for any node is encoded in a byte array with a fixed size (`[8]byte`) this allows for the sum to fully represent a `uint64` value in binary form. The golang `encoding/binary` package is used to encode the sum with -`binary.LittleEndian.PutUint64(sumBz[:], sum)` into a byte array `sumBz`. +`binary.BigEndian.PutUint64(sumBz[:], sum)` into a byte array `sumBz`. In order for the SMST to include the sum into a leaf node the SMT the SMST initialises the SMT with the `WithValueHasher(nil)` option so that the SMT does diff --git a/fuzz_test.go b/fuzz_test.go index 91a2233..347abeb 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -72,7 +72,7 @@ func FuzzSMT_DetectUnexpectedFailures(f *testing.F) { require.Equal(t, originalRoot, newRoot, "root changed while getting") case Update: value := make([]byte, 32) - binary.LittleEndian.PutUint64(value, uint64(i)) + binary.BigEndian.PutUint64(value, uint64(i)) err := trie.Update(key(), value) require.NoErrorf(t, err, "unknown error occured while updating") newRoot := trie.Root() @@ -93,6 +93,8 @@ func FuzzSMT_DetectUnexpectedFailures(f *testing.F) { } newRoot := trie.Root() require.Equal(t, originalRoot, newRoot, "root changed while proving") + default: + panic("unknown operation") } newRoot := trie.Root() diff --git a/hasher.go b/hasher.go index 08b3b2c..1b9b2cd 100644 --- a/hasher.go +++ b/hasher.go @@ -167,12 +167,12 @@ func encodeSumInner(leftData []byte, rightData []byte) []byte { leftSumBz := leftData[len(leftData)-sumSize:] rightSumBz := rightData[len(rightData)-sumSize:] if !bytes.Equal(leftSumBz, defaultSum[:]) { - leftSum = binary.LittleEndian.Uint64(leftSumBz) + leftSum = binary.BigEndian.Uint64(leftSumBz) } if !bytes.Equal(rightSumBz, defaultSum[:]) { - rightSum = binary.LittleEndian.Uint64(rightSumBz) + rightSum = binary.BigEndian.Uint64(rightSumBz) } - binary.LittleEndian.PutUint64(sum[:], leftSum+rightSum) + binary.BigEndian.PutUint64(sum[:], leftSum+rightSum) value = append(value, sum[:]...) return value } diff --git a/proofs.go b/proofs.go index d7e2f64..32983a9 100644 --- a/proofs.go +++ b/proofs.go @@ -285,7 +285,7 @@ func VerifyProof(proof *SparseMerkleProof, root, key, value []byte, spec *TrieSp // VerifySumProof verifies a Merkle proof for a sum trie. func VerifySumProof(proof *SparseMerkleProof, root, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) { var sumBz [sumSize]byte - binary.LittleEndian.PutUint64(sumBz[:], sum) + binary.BigEndian.PutUint64(sumBz[:], sum) valueHash := spec.digestValue(value) valueHash = append(valueHash, sumBz[:]...) if bytes.Equal(value, defaultValue) && sum == 0 { @@ -318,7 +318,7 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, nil, 0, spec) } sumBz := proof.ClosestValueHash[len(proof.ClosestValueHash)-sumSize:] - sum := binary.LittleEndian.Uint64(sumBz) + sum := binary.BigEndian.Uint64(sumBz) valueHash := proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize] return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec) } diff --git a/proofs_test.go b/proofs_test.go index e1b9f03..af2157b 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -193,13 +193,9 @@ func randomiseSumProof(proof *SparseMerkleProof) *SparseMerkleProof { func checkCompactEquivalence(t *testing.T, proof *SparseMerkleProof, base *TrieSpec) { t.Helper() compactedProof, err := CompactProof(proof, base) - if err != nil { - t.Fatalf("failed to compact proof: %v", err) - } + require.NoErrorf(t, err, "failed to compact proof: %v", err) decompactedProof, err := DecompactProof(compactedProof, base) - if err != nil { - t.Fatalf("failed to decompact proof: %v", err) - } + require.NoErrorf(t, err, "failed to decompact proof: %v", err) require.Equal(t, proof, decompactedProof) } @@ -207,12 +203,8 @@ func checkCompactEquivalence(t *testing.T, proof *SparseMerkleProof, base *TrieS func checkClosestCompactEquivalence(t *testing.T, proof *SparseMerkleClosestProof, spec *TrieSpec) { t.Helper() compactedProof, err := CompactClosestProof(proof, spec) - if err != nil { - t.Fatalf("failed to compact proof: %v", err) - } + require.NoErrorf(t, err, "failed to compact proof: %v", err) decompactedProof, err := DecompactClosestProof(compactedProof, spec) - if err != nil { - t.Fatalf("failed to decompact proof: %v", err) - } + require.NoErrorf(t, err, "failed to decompact proof: %v", err) require.Equal(t, proof, decompactedProof) } diff --git a/smst.go b/smst.go index 0e90422..9f571cc 100644 --- a/smst.go +++ b/smst.go @@ -59,7 +59,7 @@ func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { } var weightBz [sumSize]byte copy(weightBz[:], valueHash[len(valueHash)-sumSize:]) - weight := binary.LittleEndian.Uint64(weightBz[:]) + weight := binary.BigEndian.Uint64(weightBz[:]) return valueHash[:len(valueHash)-sumSize], weight, nil } @@ -69,7 +69,7 @@ func (smst *SMST) Get(key []byte) ([]byte, uint64, error) { func (smst *SMST) Update(key, value []byte, weight uint64) error { valueHash := smst.digestValue(value) var weightBz [sumSize]byte - binary.LittleEndian.PutUint64(weightBz[:], weight) + binary.BigEndian.PutUint64(weightBz[:], weight) valueHash = append(valueHash, weightBz[:]...) return smst.SMT.Update(key, valueHash) } @@ -109,5 +109,5 @@ func (smst *SMST) Sum() uint64 { var sumBz [sumSize]byte digest := smst.Root() copy(sumBz[:], digest[len(digest)-sumSize:]) - return binary.LittleEndian.Uint64(sumBz[:]) + return binary.BigEndian.Uint64(sumBz[:]) } diff --git a/smst_proofs_test.go b/smst_proofs_test.go index b9c7625..84c7cf5 100644 --- a/smst_proofs_test.go +++ b/smst_proofs_test.go @@ -101,7 +101,7 @@ func TestSMST_Proof_Operations(t *testing.T) { // Try proving a default value for a non-default leaf. var sum [sumSize]byte - binary.LittleEndian.PutUint64(sum[:], 5) + binary.BigEndian.PutUint64(sum[:], 5) tval := base.digestValue([]byte("testValue")) tval = append(tval, sum[:]...) _, leafData := base.th.digestSumLeaf(base.ph.Path([]byte("testKey2")), tval) @@ -321,7 +321,7 @@ func TestSMST_ProveClosest(t *testing.T) { require.NotEqual(t, proof, &SparseMerkleClosestProof{}) closestPath := sha256.Sum256([]byte("testKey2")) closestValueHash := []byte("testValue2") - binary.LittleEndian.PutUint64(sumBz[:], 24) + binary.BigEndian.PutUint64(sumBz[:], 24) closestValueHash = append(closestValueHash, sumBz[:]...) require.Equal(t, proof, &SparseMerkleClosestProof{ Path: path[:], @@ -347,7 +347,7 @@ func TestSMST_ProveClosest(t *testing.T) { require.NotEqual(t, proof, &SparseMerkleClosestProof{}) closestPath = sha256.Sum256([]byte("testKey4")) closestValueHash = []byte("testValue4") - binary.LittleEndian.PutUint64(sumBz[:], 30) + binary.BigEndian.PutUint64(sumBz[:], 30) closestValueHash = append(closestValueHash, sumBz[:]...) require.Equal(t, proof, &SparseMerkleClosestProof{ Path: path2[:], @@ -418,7 +418,7 @@ func TestSMST_ProveClosest_OneNode(t *testing.T) { closestPath := sha256.Sum256([]byte("foo")) closestValueHash := []byte("bar") var sumBz [sumSize]byte - binary.LittleEndian.PutUint64(sumBz[:], 5) + binary.BigEndian.PutUint64(sumBz[:], 5) closestValueHash = append(closestValueHash, sumBz[:]...) require.Equal(t, proof, &SparseMerkleClosestProof{ Path: path[:], @@ -450,8 +450,8 @@ func TestSMST_ProveClosest_Proof(t *testing.T) { smst256 = NewSparseMerkleSumTrie(smn, sha256.New()) smst512 = NewSparseMerkleSumTrie(smn, sha512.New()) - // insert 100000 key-value-sum triples - for i := 0; i < 100000; i++ { + // insert 100 key-value-sum triples + for i := 0; i < 100; i++ { s := strconv.Itoa(i) require.NoError(t, smst256.Update([]byte(s), []byte(s), uint64(i))) require.NoError(t, smst512.Update([]byte(s), []byte(s), uint64(i))) diff --git a/smst_test.go b/smst_test.go index 5246883..ca40dfa 100644 --- a/smst_test.go +++ b/smst_test.go @@ -460,7 +460,7 @@ func TestSMST_TotalSum(t *testing.T) { // Check root hash contains the correct hex sum root1 := smst.Root() sumBz := root1[len(root1)-sumSize:] - rootSum := binary.LittleEndian.Uint64(sumBz) + rootSum := binary.BigEndian.Uint64(sumBz) require.NoError(t, err) // Calculate total sum of the trie diff --git a/smst_utils_test.go b/smst_utils_test.go index ca37941..da26d01 100644 --- a/smst_utils_test.go +++ b/smst_utils_test.go @@ -23,7 +23,7 @@ func (smst *SMSTWithStorage) Update(key, value []byte, sum uint64) error { } valueHash := smst.digestValue(value) var sumBz [sumSize]byte - binary.LittleEndian.PutUint64(sumBz[:], sum) + binary.BigEndian.PutUint64(sumBz[:], sum) value = append(value, sumBz[:]...) return smst.preimages.Set(valueHash, value) } @@ -54,7 +54,7 @@ func (smst *SMSTWithStorage) GetValueSum(key []byte) ([]byte, uint64, error) { } var sumBz [sumSize]byte copy(sumBz[:], value[len(value)-sumSize:]) - storedSum := binary.LittleEndian.Uint64(sumBz[:]) + storedSum := binary.BigEndian.Uint64(sumBz[:]) if storedSum != sum { return nil, 0, fmt.Errorf("sum mismatch for %s: got %d, expected %d", string(key), storedSum, sum) } @@ -66,12 +66,3 @@ func (smst *SMSTWithStorage) Has(key []byte) (bool, error) { val, sum, err := smst.GetValueSum(key) return !bytes.Equal(defaultValue, val) || sum != 0, err } - -// ProveSumCompact generates a compacted Merkle proof for a key against the current root. -func ProveSumCompact(key []byte, smst SparseMerkleSumTrie) (*SparseCompactMerkleProof, error) { - proof, err := smst.Prove(key) - if err != nil { - return nil, err - } - return CompactProof(proof, smst.Spec()) -} diff --git a/smt_proofs_test.go b/smt_proofs_test.go index 5b1236c..cefed7a 100644 --- a/smt_proofs_test.go +++ b/smt_proofs_test.go @@ -400,8 +400,8 @@ func TestSMT_ProveClosest_Proof(t *testing.T) { smt256 = NewSparseMerkleTrie(smn, sha256.New()) smt512 = NewSparseMerkleTrie(smn, sha512.New()) - // insert 100000 key-value-sum triples - for i := 0; i < 100000; i++ { + // insert 100 key-value-sum triples + for i := 0; i < 100; i++ { s := strconv.Itoa(i) require.NoError(t, smt256.Update([]byte(s), []byte(s))) require.NoError(t, smt512.Update([]byte(s), []byte(s))) diff --git a/utils.go b/utils.go index ea4466f..cc5bca8 100644 --- a/utils.go +++ b/utils.go @@ -103,7 +103,7 @@ func minBytes(i int) int { // intToBytes converts an int to a byte slice func intToBytes(i int) []byte { b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(i)) + binary.BigEndian.PutUint64(b, uint64(i)) d := minBytes(i) return b[8-d:] } @@ -113,7 +113,7 @@ func bytesToInt(bz []byte) int { b := make([]byte, 8) // allocate space for a 64-bit unsigned integer d := 8 - len(bz) // determine how much padding is necessary copy(b[d:], bz) // copy over the non-zero bytes - u := binary.LittleEndian.Uint64(b) + u := binary.BigEndian.Uint64(b) return int(u) }