diff --git a/CHANGELOG-PENDING.md b/CHANGELOG-PENDING.md index 7bc32c5..9b4419f 100644 --- a/CHANGELOG-PENDING.md +++ b/CHANGELOG-PENDING.md @@ -7,6 +7,7 @@ Month, DD, YYYY ### BREAKING CHANGES - [go package] (Link to PR) Description @username +- [smt](https://github.com/celestiaorg/smt/pull/64) Adds support for raw key instead of hashed key [@SweeXordious](https://github.com/SweeXordious) ### FEATURES diff --git a/bench_test.go b/bench_test.go index fa5c960..2e8ad87 100644 --- a/bench_test.go +++ b/bench_test.go @@ -2,29 +2,40 @@ package smt import ( "crypto/sha256" + "fmt" "strconv" "testing" ) func BenchmarkSparseMerkleTree_Update(b *testing.B) { - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(9) + smt := NewSparseMerkleTree(smn, smv, hasher) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - s := strconv.Itoa(i) - _, _ = smt.Update([]byte(s), []byte(s)) + s := fmt.Sprintf("%09d", i) + _, err := smt.Update([]byte(s), []byte(s)) + if err != nil { + b.Error(err) + } } } func BenchmarkSparseMerkleTree_Delete(b *testing.B) { - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(9) + smt := NewSparseMerkleTree(smn, smv, hasher) for i := 0; i < 100000; i++ { - s := strconv.Itoa(i) - _, _ = smt.Update([]byte(s), []byte(s)) + s := fmt.Sprintf("%09d", i) + _, err := smt.Update([]byte(s), []byte(s)) + if err != nil { + b.Error(err) + } } b.ResetTimer() diff --git a/bulk_test.go b/bulk_test.go index 9442c11..1d44f1d 100644 --- a/bulk_test.go +++ b/bulk_test.go @@ -22,8 +22,11 @@ func TestSparseMerkleTree(t *testing.T) { // Test all tree operations in bulk, with specified ratio probabilities of insert, update and delete. func bulkOperations(t *testing.T, operations int, insert int, update int, delete int) { - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + keyLen := 16 + rand.Intn(32) + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(keyLen) + smt := NewSparseMerkleTree(smn, smv, hasher) max := insert + update + delete kv := make(map[string]string) @@ -31,7 +34,6 @@ func bulkOperations(t *testing.T, operations int, insert int, update int, delete for i := 0; i < operations; i++ { n := rand.Intn(max) if n < insert { // Insert - keyLen := 16 + rand.Intn(32) key := make([]byte, keyLen) rand.Read(key) @@ -93,14 +95,14 @@ func bulkCheckAll(t *testing.T, smt *SparseMerkleTree, kv *map[string]string) { if err != nil { t.Errorf("error: %v", err) } - if !VerifyProof(proof, smt.Root(), []byte(k), []byte(v), smt.th.hasher) { + if !VerifyProof(proof, smt.Root(), []byte(k), []byte(v), smt.th.hasher, smt.values.GetKeySize()) { t.Error("Merkle proof failed to verify") } compactProof, err := smt.ProveCompact([]byte(k)) if err != nil { t.Errorf("error: %v", err) } - if !VerifyCompactProof(compactProof, smt.Root(), []byte(k), []byte(v), smt.th.hasher) { + if !VerifyCompactProof(compactProof, smt.Root(), []byte(k), []byte(v), smt.th.hasher, smt.values.GetKeySize()) { t.Error("Merkle proof failed to verify") } @@ -114,12 +116,12 @@ func bulkCheckAll(t *testing.T, smt *SparseMerkleTree, kv *map[string]string) { if v2 == "" { continue } - commonPrefix := countCommonPrefix(smt.th.path([]byte(k)), smt.th.path([]byte(k2))) + commonPrefix := countCommonPrefix([]byte(k), []byte(k2)) if commonPrefix != smt.depth() && commonPrefix > largestCommonPrefix { largestCommonPrefix = commonPrefix } } - sideNodes, _, _, _, err := smt.sideNodesForRoot(smt.th.path([]byte(k)), smt.Root(), false) + sideNodes, _, _, _, err := smt.sideNodesForRoot([]byte(k), smt.Root(), false) if err != nil { t.Errorf("error: %v", err) } diff --git a/deepsubtree.go b/deepsubtree.go index 38e5f63..8055ac3 100644 --- a/deepsubtree.go +++ b/deepsubtree.go @@ -28,13 +28,16 @@ func NewDeepSparseMerkleSubTree(nodes, values MapStore, hasher hash.Hash, root [ // If the leaf may be updated (e.g. during a state transition fraud proof), // an updatable proof should be used. See SparseMerkleTree.ProveUpdatable. func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []byte, value []byte) error { - result, updates := verifyProofWithUpdates(proof, dsmst.Root(), key, value, dsmst.th.hasher) + if len(key) != dsmst.values.GetKeySize() { + return ErrWrongKeySize + } + result, updates := verifyProofWithUpdates(proof, dsmst.Root(), key, value, dsmst.th.hasher, dsmst.values.GetKeySize()) if !result { return ErrBadProof } if !bytes.Equal(value, defaultValue) { // Membership proof. - if err := dsmst.values.Set(dsmst.th.path(key), value); err != nil { + if err := dsmst.values.Set(key, value); err != nil { return err } } @@ -64,6 +67,9 @@ func (dsmst *DeepSparseMerkleSubTree) AddBranch(proof SparseMerkleProof, key []b // Use if a key was _not_ previously added with AddBranch, otherwise use Get. // Errors if the key cannot be reached by descending. func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) { + if len(key) != smt.values.GetKeySize() { + return nil, ErrWrongKeySize + } // Get tree's root root := smt.Root() @@ -72,7 +78,6 @@ func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) { return defaultValue, nil } - path := smt.th.path(key) currentHash := root for i := 0; i < smt.depth(); i++ { currentData, err := smt.nodes.Get(currentHash) @@ -80,13 +85,13 @@ func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) { return nil, err } else if smt.th.isLeaf(currentData) { // We've reached the end. Is this the actual leaf? - p, _ := smt.th.parseLeaf(currentData) - if !bytes.Equal(path, p) { + p, _ := smt.th.parseLeaf(currentData, smt.values.GetKeySize()) + if !bytes.Equal(key, p) { // Nope. Therefore the key is actually empty. return defaultValue, nil } // Otherwise, yes. Return the value. - value, err := smt.values.Get(path) + value, err := smt.values.Get(key) if err != nil { return nil, err } @@ -94,7 +99,7 @@ func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) { } leftNode, rightNode := smt.th.parseNode(currentData) - if getBitAtFromMSB(path, i) == right { + if getBitAtFromMSB(key, i) == right { currentHash = rightNode } else { currentHash = leftNode @@ -109,7 +114,7 @@ func (smt *SparseMerkleTree) GetDescend(key []byte) ([]byte, error) { // The following lines of code should only be reached if the path is 256 // nodes high, which should be very unlikely if the underlying hash function // is collision-resistant. - value, err := smt.values.Get(path) + value, err := smt.values.Get(key) if err != nil { return nil, err } diff --git a/deepsubtree_test.go b/deepsubtree_test.go index b3a6235..9c5ac4f 100644 --- a/deepsubtree_test.go +++ b/deepsubtree_test.go @@ -7,14 +7,59 @@ import ( "testing" ) +func TestDeepSubTreeKeySizeChecks(t *testing.T) { + hasher := sha256.New() + keySize := len([]byte("testKey1")) + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(keySize) + smt := NewSparseMerkleTree(smn, smv, hasher) + + _, err := smt.Update([]byte("testKey1"), []byte("testValue1")) + if err != nil { + t.Errorf("couldn't update smt. exception: %v", err) + } + + proof, err := smt.Prove([]byte("testKey1")) + if err != nil { + t.Errorf("couldn't prove existing key. Actual exception: %v", err) + } + + smn, _ = NewSimpleMap(hasher.Size()) + smv, _ = NewSimpleMap(keySize) + dsmst := NewDeepSparseMerkleSubTree(smn, smv, hasher, smt.Root()) + + err = dsmst.AddBranch(proof, randomBytes(keySize+1), []byte("testValue1")) + if err != ErrWrongKeySize { + t.Errorf("should have complained of `keySize + 1` when adding branch. Actual exception: %v", err) + } + + err = dsmst.AddBranch(proof, randomBytes(keySize-1), []byte("testValue1")) + if err != ErrWrongKeySize { + t.Errorf("should have complained of `keySize - 1` when adding branch. Actual exception: %v", err) + } + + _, err = dsmst.GetDescend(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have complained of `keySize + 1` when getting descend. Actual exception: %v", err) + } + + _, err = dsmst.GetDescend(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have complained of `keySize - 1` when getting descend. Actual exception: %v", err) + } +} + func TestDeepSparseMerkleSubTreeBasic(t *testing.T) { - smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New()) + hasher := sha256.New() + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(len([]byte("testKey1"))) + smt := NewSparseMerkleTree(smn, smv, hasher) - smt.Update([]byte("testKey1"), []byte("testValue1")) - smt.Update([]byte("testKey2"), []byte("testValue2")) - smt.Update([]byte("testKey3"), []byte("testValue3")) - smt.Update([]byte("testKey4"), []byte("testValue4")) - smt.Update([]byte("testKey6"), []byte("testValue6")) + _, _ = smt.Update([]byte("testKey1"), []byte("testValue1")) + _, _ = smt.Update([]byte("testKey2"), []byte("testValue2")) + _, _ = smt.Update([]byte("testKey3"), []byte("testValue3")) + _, _ = smt.Update([]byte("testKey4"), []byte("testValue4")) + _, _ = smt.Update([]byte("testKey6"), []byte("testValue6")) originalRoot := make([]byte, len(smt.Root())) copy(originalRoot, smt.Root()) @@ -23,7 +68,9 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) { proof2, _ := smt.ProveUpdatable([]byte("testKey2")) proof5, _ := smt.ProveUpdatable([]byte("testKey5")) - dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root()) + smn, _ = NewSimpleMap(hasher.Size()) + smv, _ = NewSimpleMap(len([]byte("testKey1"))) + dsmst := NewDeepSparseMerkleSubTree(smn, smv, hasher, smt.Root()) err := dsmst.AddBranch(proof1, []byte("testKey1"), []byte("testValue1")) if err != nil { t.Errorf("returned error when adding branch to deep subtree: %v", err) @@ -141,17 +188,22 @@ func TestDeepSparseMerkleSubTreeBasic(t *testing.T) { } func TestDeepSparseMerkleSubTreeBadInput(t *testing.T) { - smt := NewSparseMerkleTree(NewSimpleMap(), NewSimpleMap(), sha256.New()) + hasher := sha256.New() + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(len([]byte("testKey1"))) + smt := NewSparseMerkleTree(smn, smv, hasher) - smt.Update([]byte("testKey1"), []byte("testValue1")) - smt.Update([]byte("testKey2"), []byte("testValue2")) - smt.Update([]byte("testKey3"), []byte("testValue3")) - smt.Update([]byte("testKey4"), []byte("testValue4")) + _, _ = smt.Update([]byte("testKey1"), []byte("testValue1")) + _, _ = smt.Update([]byte("testKey2"), []byte("testValue2")) + _, _ = smt.Update([]byte("testKey3"), []byte("testValue3")) + _, _ = smt.Update([]byte("testKey4"), []byte("testValue4")) badProof, _ := smt.Prove([]byte("testKey1")) badProof.SideNodes[0][0] = byte(0) - dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), sha256.New(), smt.Root()) + smn, _ = NewSimpleMap(hasher.Size()) + smv, _ = NewSimpleMap(len([]byte("testKey1"))) + dsmst := NewDeepSparseMerkleSubTree(smn, smv, hasher, smt.Root()) err := dsmst.AddBranch(badProof, []byte("testKey1"), []byte("testValue1")) if !errors.Is(err, ErrBadProof) { t.Error("did not return ErrBadProof for bad proof input") diff --git a/fuzz/delete/fuzz.go b/fuzz/delete/fuzz.go index d0c6650..e199366 100644 --- a/fuzz/delete/fuzz.go +++ b/fuzz/delete/fuzz.go @@ -3,7 +3,6 @@ package delete import ( "bytes" "crypto/sha256" - "github.com/celestiaorg/smt" ) @@ -17,11 +16,14 @@ func Fuzz(data []byte) int { return -1 } - smn, smv := smt.NewSimpleMap(), smt.NewSimpleMap() - tree := smt.NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + keySize := 10 + smn, _ := smt.NewSimpleMap(hasher.Size()) + smv, _ := smt.NewSimpleMap(keySize) + tree := smt.NewSparseMerkleTree(smn, smv, hasher) for i := 0; i < len(splits)-1; i += 2 { key, value := splits[i], splits[i+1] - tree.Update(key, value) + _, _ = tree.Update(key, value) } deleteKey := splits[len(splits)-1] diff --git a/fuzz/fuzz.go b/fuzz/fuzz.go index 0ad1cb1..1d91dc5 100644 --- a/fuzz/fuzz.go +++ b/fuzz/fuzz.go @@ -9,18 +9,22 @@ import ( "github.com/celestiaorg/smt" ) +// Fuzz FIXME func Fuzz(input []byte) int { if len(input) < 100 { return 0 } - smn, smv := smt.NewSimpleMap(), smt.NewSimpleMap() - tree := smt.NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + keySize := 10 + smn, _ := smt.NewSimpleMap(hasher.Size()) + smv, _ := smt.NewSimpleMap(keySize) + tree := smt.NewSparseMerkleTree(smn, smv, hasher) r := bytes.NewReader(input) var keys [][]byte key := func() []byte { if readByte(r) < math.MaxUint8/2 { k := make([]byte, readByte(r)/2) - r.Read(k) + _, _ = r.Read(k) keys = append(keys, k) return k } @@ -37,17 +41,17 @@ func Fuzz(input []byte) int { op := op(int(b) % int(Noop)) switch op { case Get: - tree.Get(key()) + _, _ = tree.Get(key()) case Update: value := make([]byte, 32) binary.BigEndian.PutUint64(value, uint64(i)) - tree.Update(key(), value) + _, _ = tree.Update(key(), value) case Delete: - tree.Delete(key()) + _, _ = tree.Delete(key()) case Prove: - tree.Prove(key()) + _, _ = tree.Prove(key()) case Has: - tree.Has(key()) + _, _ = tree.Has(key()) } } return 1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8522afa --- /dev/null +++ b/go.sum @@ -0,0 +1,43 @@ +github.com/Julusian/godocdown v0.0.0-20170816220326-6d19f8ff2df8/go.mod h1:INZr5t32rG59/5xeltqoCJoNY7e5x/3xoY9WSWVWg74= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dvyukov/go-fuzz v0.0.0-20210914135545-4980593459a1 h1:YQOLTC8zvFaNSEuMexG0i7pY26bOksnQFsSJfGclo54= +github.com/dvyukov/go-fuzz v0.0.0-20210914135545-4980593459a1/go.mod h1:11Gm+ccJnvAhCNLlf5+cS9KjtbaD5I5zaZpFMsTHWTw= +github.com/dvyukov/go-fuzz-corpus v0.0.0-20190920191254-c42c1b2914c7 h1:dECcCtiFwubOBYviX4zj9ggx9ucw2lLEGlU3O+Q7tuk= +github.com/dvyukov/go-fuzz-corpus v0.0.0-20190920191254-c42c1b2914c7/go.mod h1:aSQy4yPbyjNrXv8ANgj9BQacfss3bbmXTon1mFcwc1k= +github.com/elazarl/go-bindata-assetfs v1.0.1 h1:m0kkaHRKEu7tUIUFVwhGGGYClXvyl4RE03qmvRTNfbw= +github.com/elazarl/go-bindata-assetfs v1.0.1/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/robertkrimen/godocdown v0.0.0-20130622164427-0bfa04905481/go.mod h1:C9WhFzY47SzYBIvzFqSvHIR6ROgDo4TtdTuRaOMjF/s= +github.com/stephens2424/writerset v1.0.2 h1:znRLgU6g8RS5euYRcy004XeE4W+Tu44kALzy7ghPif8= +github.com/stephens2424/writerset v1.0.2/go.mod h1:aS2JhsMn6eA7e82oNmW4rfsgAOp9COBTTl8mzkwADnc= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.5.1 h1:OJxoQ/rynoF0dcCdI7cLPktw/hR2cueqYfjm43oqK38= +golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.8 h1:P1HhGGuLW4aAclzjtmJdf0mJOjVUZUzOTqkAkWL+l6w= +golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/helpers_test.go b/helpers_test.go new file mode 100644 index 0000000..2a5a383 --- /dev/null +++ b/helpers_test.go @@ -0,0 +1,11 @@ +package smt + +import ( + "math/rand" +) + +func randomBytes(length int) []byte { + b := make([]byte, length) + rand.Read(b) + return b +} diff --git a/mapstore.go b/mapstore.go index 5fe381b..2b6c3b7 100644 --- a/mapstore.go +++ b/mapstore.go @@ -1,6 +1,7 @@ package smt import ( + "errors" "fmt" ) @@ -9,6 +10,7 @@ type MapStore interface { Get(key []byte) ([]byte, error) // Get gets the value for a key. Set(key []byte, value []byte) error // Set updates the value for a key. Delete(key []byte) error // Delete deletes a key. + GetKeySize() int // Gets the key size for the map store. } // InvalidKeyError is thrown when a key that does not exist is being accessed. @@ -20,34 +22,59 @@ func (e *InvalidKeyError) Error() string { return fmt.Sprintf("invalid key: %x", e.Key) } +// ErrWrongKeySize is returned when a key has a different size than the key size. +var ErrWrongKeySize = errors.New("wrong key size") + +// ErrUnsupportedKeySize is returned when the map store is initialized by a key smaller than 1. +var ErrUnsupportedKeySize = errors.New("key size should be greater or equal to 1") + // SimpleMap is a simple in-memory map. type SimpleMap struct { - m map[string][]byte + m map[string][]byte + keySize int } // NewSimpleMap creates a new empty SimpleMap. -func NewSimpleMap() *SimpleMap { - return &SimpleMap{ - m: make(map[string][]byte), +func NewSimpleMap(keySize int) (*SimpleMap, error) { + if keySize < 1 { + return nil, ErrUnsupportedKeySize } + return &SimpleMap{ + m: make(map[string][]byte), + keySize: keySize, + }, nil } // Get gets the value for a key. func (sm *SimpleMap) Get(key []byte) ([]byte, error) { + if err := sm.checkKeySize(key); err != nil { + return nil, err + } if value, ok := sm.m[string(key)]; ok { return value, nil } return nil, &InvalidKeyError{Key: key} } +// GetKeySize gets the key size of the map store. +func (sm *SimpleMap) GetKeySize() int { + return sm.keySize +} + // Set updates the value for a key. func (sm *SimpleMap) Set(key []byte, value []byte) error { + if err := sm.checkKeySize(key); err != nil { + return err + } sm.m[string(key)] = value return nil } // Delete deletes a key. func (sm *SimpleMap) Delete(key []byte) error { + if err := sm.checkKeySize(key); err != nil { + return err + } _, ok := sm.m[string(key)] if ok { delete(sm.m, string(key)) @@ -55,3 +82,10 @@ func (sm *SimpleMap) Delete(key []byte) error { } return &InvalidKeyError{Key: key} } + +func (sm *SimpleMap) checkKeySize(key []byte) error { + if len(key) != sm.keySize { + return ErrWrongKeySize + } + return nil +} diff --git a/mapstore_test.go b/mapstore_test.go index ea945cc..5eada0b 100644 --- a/mapstore_test.go +++ b/mapstore_test.go @@ -2,48 +2,92 @@ package smt import ( "bytes" - "crypto/sha256" "testing" ) func TestSimpleMap(t *testing.T) { - sm := NewSimpleMap() - h := sha256.New() - var value []byte - var err error - - h.Write([]byte("test")) + sm, _ := NewSimpleMap(len([]byte("test1"))) // Tests for Get. - _, err = sm.Get(h.Sum(nil)) + _, err := sm.Get([]byte("test1")) if err == nil { t.Error("did not return an error when getting a non-existent key") } // Tests for Put. - err = sm.Set(h.Sum(nil), []byte("hello")) + err = sm.Set([]byte("test1"), []byte("hello")) if err != nil { - t.Error("updating a key returned an error") + t.Errorf("updating a key returned an error : %v", err) } - value, err = sm.Get(h.Sum(nil)) + value, err := sm.Get([]byte("test1")) if err != nil { - t.Error("getting a key returned an error") + t.Errorf("getting a key returned an error : %v", err) } if !bytes.Equal(value, []byte("hello")) { t.Error("failed to update key") } // Tests for Del. - err = sm.Delete(h.Sum(nil)) + err = sm.Delete([]byte("test1")) if err != nil { - t.Error("deleting a key returned an error") + t.Errorf("deleting a key returned an error : %v", err) } - _, err = sm.Get(h.Sum(nil)) + _, err = sm.Get([]byte("test1")) if err == nil { t.Error("failed to delete key") } - err = sm.Delete([]byte("nonexistent")) + err = sm.Delete([]byte("test2")) if err == nil { t.Error("deleting a key did not return an error on a non-existent key") } } + +func TestSimpleMapKeySize(t *testing.T) { + _, err := NewSimpleMap(0) + if err != ErrUnsupportedKeySize { + t.Errorf("didn't throw ErrUnsupportedKeySize when initializing with 0 as key size : %v", err) + } + + _, err = NewSimpleMap(1) + if err != nil { + t.Errorf("shouldn't throw an exception when initializing with 1 as key size : %v", err) + } + + sm, err := NewSimpleMap(len([]byte("test1"))) + if err != nil { + t.Errorf("shouldn't throw an exception when initializing with 5 as key size : %v", err) + } + + // Tests for setting wrong key size. + err = sm.Set([]byte("test11"), []byte("hello")) + if err != ErrWrongKeySize { + t.Errorf("didn't throw ErrWrongKeySize when setting a bigger key size : %v", err) + } + + err = sm.Set([]byte("test"), []byte("hello")) + if err != ErrWrongKeySize { + t.Errorf("didn't throw ErrWrongKeySize when setting a smaller key size : %v", err) + } + + // Tests for getting wrong key size. + _, err = sm.Get([]byte("test11")) + if err != ErrWrongKeySize { + t.Errorf("didn't throw ErrWrongKeySize when getting a bigger key size : %v", err) + } + + _, err = sm.Get([]byte("test")) + if err != ErrWrongKeySize { + t.Errorf("didn't throw ErrWrongKeySize when getting a smaller key size : %v", err) + } + + // Tests for getting wrong key size. + err = sm.Delete([]byte("test11")) + if err != ErrWrongKeySize { + t.Errorf("didn't throw ErrWrongKeySize when deleting a bigger key size : %v", err) + } + + err = sm.Delete([]byte("test")) + if err != ErrWrongKeySize { + t.Errorf("didn't throw ErrWrongKeySize when deleting a smaller key size : %v", err) + } +} diff --git a/proofs.go b/proofs.go index 372ca8c..b0b77bc 100644 --- a/proofs.go +++ b/proofs.go @@ -21,16 +21,16 @@ type SparseMerkleProof struct { SiblingData []byte } -func (proof *SparseMerkleProof) sanityCheck(th *treeHasher) bool { +func (proof *SparseMerkleProof) sanityCheck(th *treeHasher, keySize int) bool { // Do a basic sanity check on the proof, so that a malicious proof cannot // cause the verifier to fatally exit (e.g. due to an index out-of-range // error) or cause a CPU DoS attack. // Check that the number of supplied sidenodes does not exceed the maximum possible. - if len(proof.SideNodes) > th.pathSize()*8 || + if len(proof.SideNodes) > keySize*8 || // Check that leaf data for non-membership proofs is the correct size. - (proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) != len(leafPrefix)+th.pathSize()+th.hasher.Size()) { + (proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) != len(leafPrefix)+keySize+th.hasher.Size()) { return false } @@ -74,7 +74,7 @@ type SparseCompactMerkleProof struct { SiblingData []byte } -func (proof *SparseCompactMerkleProof) sanityCheck(th *treeHasher) bool { +func (proof *SparseCompactMerkleProof) sanityCheck(th *treeHasher, keySize int) bool { // Do a basic sanity check on the proof on the fields of the proof specific to // the compact proof only. // @@ -82,7 +82,7 @@ func (proof *SparseCompactMerkleProof) sanityCheck(th *treeHasher) bool { // de-compacted proof should be executed. // Compact proofs: check that NumSideNodes is within the right range. - if proof.NumSideNodes < 0 || proof.NumSideNodes > th.pathSize()*8 || + if proof.NumSideNodes < 0 || proof.NumSideNodes > keySize*8 || // Compact proofs: check that the length of the bit mask is as expected // according to NumSideNodes. @@ -98,16 +98,18 @@ func (proof *SparseCompactMerkleProof) sanityCheck(th *treeHasher) bool { } // VerifyProof verifies a Merkle proof. -func VerifyProof(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) bool { - result, _ := verifyProofWithUpdates(proof, root, key, value, hasher) +func VerifyProof(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash, keySize int) bool { + if len(key) != keySize { + return false + } + result, _ := verifyProofWithUpdates(proof, root, key, value, hasher, keySize) return result } -func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) (bool, [][][]byte) { +func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash, keySize int) (bool, [][][]byte) { th := newTreeHasher(hasher) - path := th.path(key) - if !proof.sanityCheck(th) { + if !proof.sanityCheck(th, keySize) { return false, nil } @@ -119,8 +121,8 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va if proof.NonMembershipLeafData == nil { // Leaf is a placeholder value. currentHash = th.placeholder() } else { // Leaf is an unrelated leaf. - actualPath, valueHash := th.parseLeaf(proof.NonMembershipLeafData) - if bytes.Equal(actualPath, path) { + actualPath, valueHash := th.parseLeaf(proof.NonMembershipLeafData, keySize) + if bytes.Equal(actualPath, key) { // This is not an unrelated leaf; non-membership proof failed. return false, nil } @@ -132,7 +134,7 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va } } else { // Membership proof. valueHash := th.digest(value) - currentHash, currentData = th.digestLeaf(path, valueHash) + currentHash, currentData = th.digestLeaf(key, valueHash) update := make([][]byte, 2) update[0], update[1] = currentHash, currentData updates = append(updates, update) @@ -140,10 +142,12 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va // Recompute root. for i := 0; i < len(proof.SideNodes); i++ { - node := make([]byte, th.pathSize()) - copy(node, proof.SideNodes[i]) + node := make([]byte, hasher.Size()) + if copy(node, proof.SideNodes[i]) != len(proof.SideNodes[i]) { + return false, nil + } - if getBitAtFromMSB(path, len(proof.SideNodes)-1-i) == right { + if getBitAtFromMSB(key, len(proof.SideNodes)-1-i) == right { currentHash, currentData = th.digestNode(node, currentHash) } else { currentHash, currentData = th.digestNode(currentHash, node) @@ -158,19 +162,22 @@ func verifyProofWithUpdates(proof SparseMerkleProof, root []byte, key []byte, va } // VerifyCompactProof verifies a compacted Merkle proof. -func VerifyCompactProof(proof SparseCompactMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash) bool { - decompactedProof, err := DecompactProof(proof, hasher) +func VerifyCompactProof(proof SparseCompactMerkleProof, root []byte, key []byte, value []byte, hasher hash.Hash, keySize int) bool { + if len(key) != keySize { + return false + } + decompactedProof, err := DecompactProof(proof, hasher, keySize) if err != nil { return false } - return VerifyProof(decompactedProof, root, key, value, hasher) + return VerifyProof(decompactedProof, root, key, value, hasher, keySize) } // CompactProof compacts a proof, to reduce its size. -func CompactProof(proof SparseMerkleProof, hasher hash.Hash) (SparseCompactMerkleProof, error) { +func CompactProof(proof SparseMerkleProof, hasher hash.Hash, keySize int) (SparseCompactMerkleProof, error) { th := newTreeHasher(hasher) - if !proof.sanityCheck(th) { + if !proof.sanityCheck(th, keySize) { return SparseCompactMerkleProof{}, ErrBadProof } @@ -196,10 +203,10 @@ func CompactProof(proof SparseMerkleProof, hasher hash.Hash) (SparseCompactMerkl } // DecompactProof decompacts a proof, so that it can be used for VerifyProof. -func DecompactProof(proof SparseCompactMerkleProof, hasher hash.Hash) (SparseMerkleProof, error) { +func DecompactProof(proof SparseCompactMerkleProof, hasher hash.Hash, keySize int) (SparseMerkleProof, error) { th := newTreeHasher(hasher) - if !proof.sanityCheck(th) { + if !proof.sanityCheck(th, keySize) { return SparseMerkleProof{}, ErrBadProof } diff --git a/proofs_test.go b/proofs_test.go index bdfd000..b1f8100 100644 --- a/proofs_test.go +++ b/proofs_test.go @@ -8,6 +8,64 @@ import ( "testing" ) +func TestProofsKeySizeChecks(t *testing.T) { + hasher := sha256.New() + keySize := len([]byte("testKey1")) + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(keySize) + smt := NewSparseMerkleTree(smn, smv, hasher) + + _, err := smt.Update([]byte("testKey1"), []byte("testValue1")) + if err != nil { + t.Errorf("couldn't update smt. exception: %v", err) + } + + _, err = smt.Update([]byte("testKey2"), []byte("testValue2")) + if err != nil { + t.Errorf("couldn't update smt. exception: %v", err) + } + + proof, err := smt.Prove([]byte("testKey1")) + if err != nil { + t.Errorf("couldn't prove existing key. Actual exception: %v", err) + } + + proved := VerifyProof(proof, smt.Root(), randomBytes(keySize+1), []byte("testValue1"), hasher, smt.values.GetKeySize()) + if proved { + t.Errorf("shouldn't have been able to verify prove a `keySize + 1`.") + } + + proved = VerifyProof(proof, smt.Root(), randomBytes(keySize-1), []byte("testValue1"), hasher, smt.values.GetKeySize()) + if proved { + t.Errorf("shouldn't have been able to verify prove a `keySize - 1`.") + } + + _, err = smt.ProveCompact(randomBytes(keySize + 1)) + if err == nil { + t.Errorf("shouldn't have been able to prove compact for a `keySize + 1`.") + } + + _, err = smt.ProveCompact(randomBytes(keySize - 1)) + if err == nil { + t.Errorf("shouldn't have been able to prove compact for a `keySize - 1`.") + } + + compactProof, err := smt.ProveCompact([]byte("testKey1")) + if err != nil { + t.Errorf("couldn't prove compact existing key: %v", err) + } + + proved = VerifyCompactProof(compactProof, smt.Root(), randomBytes(keySize+1), []byte("testValue1"), sha256.New(), smt.values.GetKeySize()) + if proved { + t.Errorf("shouldn't have been able to verify compact proof for a `keySize + 1`.") + } + + proved = VerifyCompactProof(compactProof, smt.Root(), randomBytes(keySize-1), []byte("testValue1"), sha256.New(), smt.values.GetKeySize()) + if proved { + t.Errorf("shouldn't have been able to verify compact proof for a `keySize - 1`.") + } +} + // Test base case Merkle proof operations. func TestProofsBasic(t *testing.T) { var smn, smv *SimpleMap @@ -17,105 +75,107 @@ func TestProofsBasic(t *testing.T) { var root []byte var err error - smn, smv = NewSimpleMap(), NewSimpleMap() - smt = NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + smn, _ = NewSimpleMap(hasher.Size()) + smv, _ = NewSimpleMap(len([]byte("testKey1"))) + smt = NewSparseMerkleTree(smn, smv, hasher) // Generate and verify a proof on an empty key. proof, err = smt.Prove([]byte("testKey3")) - checkCompactEquivalence(t, proof, smt.th.hasher) + checkCompactEquivalence(t, proof, smt.th.hasher, smt.values.GetKeySize()) if err != nil { t.Error("error returned when trying to prove inclusion on empty key") } - result = VerifyProof(proof, bytes.Repeat([]byte{0}, smt.th.hasher.Size()), []byte("testKey3"), defaultValue, smt.th.hasher) + result = VerifyProof(proof, bytes.Repeat([]byte{0}, smt.th.hasher.Size()), []byte("testKey3"), defaultValue, smt.th.hasher, smt.values.GetKeySize()) if !result { t.Error("valid proof on empty key failed to verify") } - result = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } // Add a key, generate and verify a Merkle proof. - root, _ = smt.Update([]byte("testKey"), []byte("testValue")) - proof, err = smt.Prove([]byte("testKey")) - checkCompactEquivalence(t, proof, smt.th.hasher) + root, _ = smt.Update([]byte("testKey1"), []byte("testValue1")) + proof, err = smt.Prove([]byte("testKey1")) + checkCompactEquivalence(t, proof, smt.th.hasher, smt.values.GetKeySize()) if err != nil { t.Error("error returned when trying to prove inclusion") } - result = VerifyProof(proof, root, []byte("testKey"), []byte("testValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if !result { t.Error("valid proof failed to verify") } - result = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey1"), []byte("badValue"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } // Add a key, generate and verify both Merkle proofs. - root, _ = smt.Update([]byte("testKey2"), []byte("testValue")) - proof, err = smt.Prove([]byte("testKey")) - checkCompactEquivalence(t, proof, smt.th.hasher) + root, _ = smt.Update([]byte("testKey2"), []byte("testValue1")) + proof, err = smt.Prove([]byte("testKey1")) + checkCompactEquivalence(t, proof, smt.th.hasher, smt.values.GetKeySize()) if err != nil { t.Error("error returned when trying to prove inclusion") } - result = VerifyProof(proof, root, []byte("testKey"), []byte("testValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if !result { t.Error("valid proof failed to verify") } - result = VerifyProof(proof, root, []byte("testKey"), []byte("badValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey1"), []byte("badValue"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } - result = VerifyProof(randomiseProof(proof), root, []byte("testKey"), []byte("testValue"), smt.th.hasher) + result = VerifyProof(randomiseProof(proof), root, []byte("testKey1"), []byte("testKey1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } proof, err = smt.Prove([]byte("testKey2")) - checkCompactEquivalence(t, proof, smt.th.hasher) + checkCompactEquivalence(t, proof, smt.th.hasher, smt.values.GetKeySize()) if err != nil { t.Error("error returned when trying to prove inclusion") } - result = VerifyProof(proof, root, []byte("testKey2"), []byte("testValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey2"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if !result { t.Error("valid proof failed to verify") } - result = VerifyProof(proof, root, []byte("testKey2"), []byte("badValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey2"), []byte("badValue"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } - result = VerifyProof(randomiseProof(proof), root, []byte("testKey2"), []byte("testValue"), smt.th.hasher) + result = VerifyProof(randomiseProof(proof), root, []byte("testKey2"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } // Try proving a default value for a non-default leaf. th := newTreeHasher(smt.th.hasher) - _, leafData := th.digestLeaf(th.path([]byte("testKey2")), th.digest([]byte("testValue"))) + _, leafData := th.digestLeaf([]byte("testKey2"), th.digest([]byte("testValue1"))) proof = SparseMerkleProof{ SideNodes: proof.SideNodes, NonMembershipLeafData: leafData, } - result = VerifyProof(proof, root, []byte("testKey2"), defaultValue, smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey2"), defaultValue, smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } // Generate and verify a proof on an empty key. proof, err = smt.Prove([]byte("testKey3")) - checkCompactEquivalence(t, proof, smt.th.hasher) + checkCompactEquivalence(t, proof, smt.th.hasher, smt.values.GetKeySize()) if err != nil { t.Error("error returned when trying to prove inclusion on empty key") } - result = VerifyProof(proof, root, []byte("testKey3"), defaultValue, smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey3"), defaultValue, smt.th.hasher, smt.values.GetKeySize()) if !result { t.Error("valid proof on empty key failed to verify") } - result = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey3"), []byte("badValue"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } - result = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultValue, smt.th.hasher) + result = VerifyProof(randomiseProof(proof), root, []byte("testKey3"), defaultValue, smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } @@ -123,30 +183,32 @@ func TestProofsBasic(t *testing.T) { // Test sanity check cases for non-compact proofs. func TestProofsSanityCheck(t *testing.T) { - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(len([]byte("testKey1"))) + smt := NewSparseMerkleTree(smn, smv, hasher) th := &smt.th - smt.Update([]byte("testKey1"), []byte("testValue1")) - smt.Update([]byte("testKey2"), []byte("testValue2")) - smt.Update([]byte("testKey3"), []byte("testValue3")) + _, _ = smt.Update([]byte("testKey1"), []byte("testValue1")) + _, _ = smt.Update([]byte("testKey2"), []byte("testValue2")) + _, _ = smt.Update([]byte("testKey3"), []byte("testValue3")) root, _ := smt.Update([]byte("testKey4"), []byte("testValue4")) // Case: invalid number of sidenodes. proof, _ := smt.Prove([]byte("testKey1")) - sideNodes := make([][]byte, smt.th.pathSize()*8+1) + sideNodes := make([][]byte, len([]byte("testKey1"))*8+1) for i := range sideNodes { sideNodes[i] = proof.SideNodes[0] } proof.SideNodes = sideNodes - if proof.sanityCheck(th) { + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - result := VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher) + result := VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } - _, err := CompactProof(proof, smt.th.hasher) + _, err := CompactProof(proof, smt.th.hasher, smt.values.GetKeySize()) if err == nil { t.Error("did not return error when compacting a malformed proof") } @@ -154,14 +216,14 @@ func TestProofsSanityCheck(t *testing.T) { // Case: incorrect size for NonMembershipLeafData. proof, _ = smt.Prove([]byte("testKey1")) proof.NonMembershipLeafData = make([]byte, 1) - if proof.sanityCheck(th) { + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } - _, err = CompactProof(proof, smt.th.hasher) + _, err = CompactProof(proof, smt.th.hasher, smt.values.GetKeySize()) if err == nil { t.Error("did not return error when compacting a malformed proof") } @@ -169,14 +231,14 @@ func TestProofsSanityCheck(t *testing.T) { // Case: unexpected sidenode size. proof, _ = smt.Prove([]byte("testKey1")) proof.SideNodes[0] = make([]byte, 1) - if proof.sanityCheck(th) { + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } - _, err = CompactProof(proof, smt.th.hasher) + _, err = CompactProof(proof, smt.th.hasher, smt.values.GetKeySize()) if err == nil { t.Error("did not return error when compacting a malformed proof") } @@ -184,14 +246,14 @@ func TestProofsSanityCheck(t *testing.T) { // Case: incorrect non-nil sibling data proof, _ = smt.ProveUpdatable([]byte("testKey1")) proof.SiblingData = smt.th.digest(proof.SiblingData) - if proof.sanityCheck(th) { + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher) + result = VerifyProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } - _, err = CompactProof(proof, smt.th.hasher) + _, err = CompactProof(proof, smt.th.hasher, smt.values.GetKeySize()) if err == nil { t.Error("did not return error when compacting a malformed proof") } @@ -199,26 +261,28 @@ func TestProofsSanityCheck(t *testing.T) { // Test sanity check cases for compact proofs. func TestCompactProofsSanityCheck(t *testing.T) { - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(len([]byte("testKey1"))) + smt := NewSparseMerkleTree(smn, smv, hasher) th := &smt.th - smt.Update([]byte("testKey1"), []byte("testValue1")) - smt.Update([]byte("testKey2"), []byte("testValue2")) - smt.Update([]byte("testKey3"), []byte("testValue3")) + _, _ = smt.Update([]byte("testKey1"), []byte("testValue1")) + _, _ = smt.Update([]byte("testKey2"), []byte("testValue2")) + _, _ = smt.Update([]byte("testKey3"), []byte("testValue3")) root, _ := smt.Update([]byte("testKey4"), []byte("testValue4")) // Case (compact proofs): NumSideNodes out of range. proof, _ := smt.ProveCompact([]byte("testKey1")) proof.NumSideNodes = -1 - if proof.sanityCheck(th) { + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - proof.NumSideNodes = th.pathSize()*8 + 1 - if proof.sanityCheck(th) { + proof.NumSideNodes = len([]byte("testKey1"))*8 + 1 + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - result := VerifyCompactProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher) + result := VerifyCompactProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } @@ -226,10 +290,10 @@ func TestCompactProofsSanityCheck(t *testing.T) { // Case (compact proofs): unexpected bit mask length. proof, _ = smt.ProveCompact([]byte("testKey1")) proof.NumSideNodes = 10 - if proof.sanityCheck(th) { + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - result = VerifyCompactProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher) + result = VerifyCompactProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } @@ -237,10 +301,10 @@ func TestCompactProofsSanityCheck(t *testing.T) { // Case (compact proofs): unexpected number of sidenodes for number of side nodes. proof, _ = smt.ProveCompact([]byte("testKey1")) proof.SideNodes = append(proof.SideNodes, proof.SideNodes...) - if proof.sanityCheck(th) { + if proof.sanityCheck(th, smt.values.GetKeySize()) { t.Error("sanity check incorrectly passed") } - result = VerifyCompactProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher) + result = VerifyCompactProof(proof, root, []byte("testKey1"), []byte("testValue1"), smt.th.hasher, smt.values.GetKeySize()) if result { t.Error("invalid proof verification returned true") } @@ -259,12 +323,12 @@ func randomiseProof(proof SparseMerkleProof) SparseMerkleProof { } // Check that a non-compact proof is equivalent to the proof returned when it is compacted and de-compacted. -func checkCompactEquivalence(t *testing.T, proof SparseMerkleProof, hasher hash.Hash) { - compactedProof, err := CompactProof(proof, hasher) +func checkCompactEquivalence(t *testing.T, proof SparseMerkleProof, hasher hash.Hash, keySize int) { + compactedProof, err := CompactProof(proof, hasher, keySize) if err != nil { t.Errorf("failed to compact proof %v", err) } - decompactedProof, err := DecompactProof(compactedProof, hasher) + decompactedProof, err := DecompactProof(compactedProof, hasher, keySize) if err != nil { t.Errorf("failed to decompact proof %v", err) } diff --git a/smt.go b/smt.go index 5dcf5b1..aebef57 100644 --- a/smt.go +++ b/smt.go @@ -61,7 +61,7 @@ func (smt *SparseMerkleTree) SetRoot(root []byte) { } func (smt *SparseMerkleTree) depth() int { - return smt.th.pathSize() * 8 + return smt.values.GetKeySize() * 8 } // Get gets the value of a key from the tree. @@ -74,8 +74,7 @@ func (smt *SparseMerkleTree) Get(key []byte) ([]byte, error) { return defaultValue, nil } - path := smt.th.path(key) - value, err := smt.values.Get(path) + value, err := smt.values.Get(key) if err != nil { var invalidKeyError *InvalidKeyError @@ -83,10 +82,9 @@ func (smt *SparseMerkleTree) Get(key []byte) ([]byte, error) { if errors.As(err, &invalidKeyError) { // If key isn't found, return default value return defaultValue, nil - } else { - // Otherwise percolate up any other error - return nil, err } + // Otherwise, percolate up any other error + return nil, err } return value, nil } @@ -115,8 +113,7 @@ func (smt *SparseMerkleTree) Delete(key []byte) ([]byte, error) { // UpdateForRoot sets a new value for a key in the tree at a specific root, and returns the new root. func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte) ([]byte, error) { - path := smt.th.path(key) - sideNodes, pathNodes, oldLeafData, _, err := smt.sideNodesForRoot(path, root, false) + sideNodes, pathNodes, oldLeafData, _, err := smt.sideNodesForRoot(key, root, false) if err != nil { return nil, err } @@ -124,18 +121,17 @@ func (smt *SparseMerkleTree) UpdateForRoot(key []byte, value []byte, root []byte var newRoot []byte if bytes.Equal(value, defaultValue) { // Delete operation. - newRoot, err = smt.deleteWithSideNodes(path, sideNodes, pathNodes, oldLeafData) + newRoot, err = smt.deleteWithSideNodes(key, sideNodes, pathNodes, oldLeafData) if errors.Is(err, errKeyAlreadyEmpty) { // This key is already empty; return the old root. return root, nil } - if err := smt.values.Delete(path); err != nil { + if err := smt.values.Delete(key); err != nil { return nil, err } - } else { // Insert or update operation. - newRoot, err = smt.updateWithSideNodes(path, value, sideNodes, pathNodes, oldLeafData) + newRoot, err = smt.updateWithSideNodes(key, value, sideNodes, pathNodes, oldLeafData) } return newRoot, err } @@ -145,13 +141,13 @@ func (smt *SparseMerkleTree) DeleteForRoot(key, root []byte) ([]byte, error) { return smt.UpdateForRoot(key, defaultValue, root) } -func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte, pathNodes [][]byte, oldLeafData []byte) ([]byte, error) { +func (smt *SparseMerkleTree) deleteWithSideNodes(key []byte, sideNodes [][]byte, pathNodes [][]byte, oldLeafData []byte) ([]byte, error) { if bytes.Equal(pathNodes[0], smt.th.placeholder()) { // This key is already empty as it is a placeholder; return an error. return nil, errKeyAlreadyEmpty } - actualPath, _ := smt.th.parseLeaf(oldLeafData) - if !bytes.Equal(path, actualPath) { + actualKey, _ := smt.th.parseLeaf(oldLeafData, smt.values.GetKeySize()) + if !bytes.Equal(key, actualKey) { // This key is already empty as a different key was found its place; return an error. return nil, errKeyAlreadyEmpty } @@ -193,7 +189,7 @@ func (smt *SparseMerkleTree) deleteWithSideNodes(path []byte, sideNodes [][]byte nonPlaceholderReached = true } - if getBitAtFromMSB(path, len(sideNodes)-1-i) == right { + if getBitAtFromMSB(key, len(sideNodes)-1-i) == right { currentHash, currentData = smt.th.digestNode(sideNode, currentData) } else { currentHash, currentData = smt.th.digestNode(currentData, sideNode) @@ -231,7 +227,7 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side commonPrefixCount = smt.depth() } else { var actualPath []byte - actualPath, oldValueHash = smt.th.parseLeaf(oldLeafData) + actualPath, oldValueHash = smt.th.parseLeaf(oldLeafData, smt.values.GetKeySize()) commonPrefixCount = countCommonPrefix(path, actualPath) } if commonPrefixCount != smt.depth() { @@ -312,6 +308,9 @@ func (smt *SparseMerkleTree) updateWithSideNodes(path []byte, value []byte, side // // If the leaf is a placeholder, the leaf data is nil. func (smt *SparseMerkleTree) sideNodesForRoot(path []byte, root []byte, getSiblingData bool) ([][]byte, [][]byte, []byte, []byte, error) { + if len(path) != smt.values.GetKeySize() { + return nil, nil, nil, nil, ErrWrongKeySize + } // Side nodes for the path. Nodes are inserted in reverse order, then the // slice is reversed at the end. sideNodes := make([][]byte, 0, smt.depth()) @@ -406,8 +405,7 @@ func (smt *SparseMerkleTree) ProveUpdatableForRoot(key []byte, root []byte) (Spa } func (smt *SparseMerkleTree) doProveForRoot(key []byte, root []byte, isUpdatable bool) (SparseMerkleProof, error) { - path := smt.th.path(key) - sideNodes, pathNodes, leafData, siblingData, err := smt.sideNodesForRoot(path, root, isUpdatable) + sideNodes, pathNodes, leafData, siblingData, err := smt.sideNodesForRoot(key, root, isUpdatable) if err != nil { return SparseMerkleProof{}, err } @@ -423,8 +421,8 @@ func (smt *SparseMerkleTree) doProveForRoot(key []byte, root []byte, isUpdatable // value, we do not need to add anything else to the proof. var nonMembershipLeafData []byte if !bytes.Equal(pathNodes[0], smt.th.placeholder()) { - actualPath, _ := smt.th.parseLeaf(leafData) - if !bytes.Equal(actualPath, path) { + actualKey, _ := smt.th.parseLeaf(leafData, smt.values.GetKeySize()) + if !bytes.Equal(actualKey, key) { // This is a non-membership proof that involves showing a different leaf. // Add the leaf data to the proof. nonMembershipLeafData = leafData @@ -452,6 +450,6 @@ func (smt *SparseMerkleTree) ProveCompactForRoot(key []byte, root []byte) (Spars if err != nil { return SparseCompactMerkleProof{}, err } - compactedProof, err := CompactProof(proof, smt.th.hasher) + compactedProof, err := CompactProof(proof, smt.th.hasher, smt.values.GetKeySize()) return compactedProof, err } diff --git a/smt_test.go b/smt_test.go index 43f4b39..9ccf1b5 100644 --- a/smt_test.go +++ b/smt_test.go @@ -3,28 +3,180 @@ package smt import ( "bytes" "crypto/sha256" - "hash" "math/rand" "testing" ) +func TestSparseMerkleTreeKeySizeChecks(t *testing.T) { + hasher := sha256.New() + keySize := len([]byte("testKey1")) + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(keySize) + smt := NewSparseMerkleTree(smn, smv, hasher) + + _, _ = smt.Update([]byte("testKey1"), []byte("testValue1")) + _, _ = smt.Update([]byte("testKey2"), []byte("testValue2")) + + _, err := smt.Get(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when getting `keySize + 1`. Actual exception: %v", err) + } + _, err = smt.Get(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when getting `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.Update(randomBytes(keySize+1), []byte("testValue1")) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when updating `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.Update(randomBytes(keySize-1), []byte("testValue1")) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when updating `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.Prove(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when proving `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.Prove(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when proving `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.Delete(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when deleting `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.Delete(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when deleting `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.DeleteForRoot(randomBytes(keySize+1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when delete for root `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.DeleteForRoot(randomBytes(keySize-1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when delete for root `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.GetDescend(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when get descend for `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.GetDescend(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when get descend for `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.Has(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when has `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.Has(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when has `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.HasDescend(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when has descend `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.HasDescend(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when has descend `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.ProveCompact(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove compact for `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.ProveCompact(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove compact for `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.ProveCompactForRoot(randomBytes(keySize+1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove compact for root for `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.ProveCompactForRoot(randomBytes(keySize-1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove compact for root for `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.ProveForRoot(randomBytes(keySize+1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove for root for `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.ProveForRoot(randomBytes(keySize-1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove for root for `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.ProveUpdatable(randomBytes(keySize + 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove updatable for `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.ProveUpdatable(randomBytes(keySize - 1)) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove updatable for `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.ProveUpdatableForRoot(randomBytes(keySize+1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove updatable for root for `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.ProveUpdatableForRoot(randomBytes(keySize-1), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when prove updatable for root for `keySize - 1`. Actual exception: %v", err) + } + + _, err = smt.UpdateForRoot(randomBytes(keySize+1), []byte("testValue1"), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when updating for root for `keySize + 1`. Actual exception: %v", err) + } + + _, err = smt.UpdateForRoot(randomBytes(keySize-1), []byte("testValue1"), smt.Root()) + if err != ErrWrongKeySize { + t.Errorf("should have returned wrong key size exception when updating for root for `keySize - 1`. Actual exception: %v", err) + } +} + // Test base case tree update operations with a few keys. func TestSparseMerkleTreeUpdateBasic(t *testing.T) { - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + keySize := len([]byte("testKey1")) + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(keySize) + smt := NewSparseMerkleTree(smn, smv, hasher) var value []byte var has bool var err error // Test getting an empty key. - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting empty key: %v", err) } if !bytes.Equal(defaultValue, value) { t.Error("did not get default value when getting empty key") } - has, err = smt.Has([]byte("testKey")) + has, err = smt.Has([]byte("testKey1")) if err != nil { t.Errorf("returned error when checking presence of empty key: %v", err) } @@ -33,18 +185,18 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { } // Test updating the empty key. - _, err = smt.Update([]byte("testKey"), []byte("testValue")) + _, err = smt.Update([]byte("testKey1"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty key: %v", err) } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) } - if !bytes.Equal([]byte("testValue"), value) { + if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty key") } - has, err = smt.Has([]byte("testKey")) + has, err = smt.Has([]byte("testKey1")) if err != nil { t.Errorf("returned error when checking presence of non-empty key: %v", err) } @@ -53,11 +205,11 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { } // Test updating the non-empty key. - _, err = smt.Update([]byte("testKey"), []byte("testValue2")) + _, err = smt.Update([]byte("testKey1"), []byte("testValue2")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) } @@ -65,22 +217,24 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { t.Error("did not get correct value when getting non-empty key") } - // Test updating a second empty key where the path for both keys share the - // first 2 bits (when using SHA256). - _, err = smt.Update([]byte("foo"), []byte("testValue")) + // Test updating a second empty key where the path differs in the first 8 bits + differentTestKey := make([]byte, keySize) + copy(differentTestKey, "testKey1") + differentTestKey[0] = differentTestKey[0] << 2 + _, err = smt.Update(differentTestKey, []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty second key: %v", err) } - value, err = smt.Get([]byte("foo")) + value, err = smt.Get(differentTestKey) if err != nil { t.Errorf("returned error when getting non-empty second key: %v", err) } - if !bytes.Equal([]byte("testValue"), value) { + if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty second key") } // Test updating a third empty key. - _, err = smt.Update([]byte("testKey2"), []byte("testValue")) + _, err = smt.Update([]byte("testKey2"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty third key: %v", err) } @@ -88,10 +242,10 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { if err != nil { t.Errorf("returned error when getting non-empty third key: %v", err) } - if !bytes.Equal([]byte("testValue"), value) { + if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty third key") } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) } @@ -101,7 +255,7 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { // Test that a tree can be imported from a MapStore. smt2 := ImportSparseMerkleTree(smn, smv, sha256.New(), smt.Root()) - value, err = smt2.Get([]byte("testKey")) + value, err = smt2.Get([]byte("testKey1")) if err != nil { t.Error("returned error when getting non-empty key") } @@ -112,26 +266,28 @@ func TestSparseMerkleTreeUpdateBasic(t *testing.T) { // Test known tree ops func TestSparseMerkleTreeKnown(t *testing.T) { - h := newDummyHasher(sha256.New()) - smn, smv := NewSimpleMap(), NewSimpleMap() + h := sha256.New() + keySize := 16 + smn, _ := NewSimpleMap(h.Size()) + smv, _ := NewSimpleMap(keySize) smt := NewSparseMerkleTree(smn, smv, h) var value []byte var err error - baseKey := make([]byte, h.Size()+4) - key1 := make([]byte, h.Size()+4) + baseKey := make([]byte, keySize) + key1 := make([]byte, keySize) copy(key1, baseKey) key1[4] = byte(0b00000000) - key2 := make([]byte, h.Size()+4) + key2 := make([]byte, keySize) copy(key2, baseKey) key2[4] = byte(0b01000000) - key3 := make([]byte, h.Size()+4) + key3 := make([]byte, keySize) copy(key3, baseKey) key3[4] = byte(0b10000000) - key4 := make([]byte, h.Size()+4) + key4 := make([]byte, keySize) copy(key4, baseKey) key4[4] = byte(0b11000000) - key5 := make([]byte, h.Size()+4) + key5 := make([]byte, keySize) copy(key5, baseKey) key5[4] = byte(0b11010000) @@ -197,7 +353,10 @@ func TestSparseMerkleTreeKnown(t *testing.T) { proof3, _ := smt.Prove(key3) proof4, _ := smt.Prove(key4) proof5, _ := smt.Prove(key5) - dsmst := NewDeepSparseMerkleSubTree(NewSimpleMap(), NewSimpleMap(), h, smt.Root()) + + smn, _ = NewSimpleMap(h.Size()) + smv, _ = NewSimpleMap(keySize) + dsmst := NewDeepSparseMerkleSubTree(smn, smv, h, smt.Root()) err = dsmst.AddBranch(proof1, key1, []byte("testValue1")) if err != nil { t.Errorf("returned error when adding branch to deep subtree: %v", err) @@ -222,25 +381,23 @@ func TestSparseMerkleTreeKnown(t *testing.T) { // Test tree operations when two leafs are immediate neighbors. func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { - h := newDummyHasher(sha256.New()) - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, h) + hasher := sha256.New() + keySize := hasher.Size() + smn, _ := NewSimpleMap(keySize) + smv, _ := NewSimpleMap(keySize) + smt := NewSparseMerkleTree(smn, smv, hasher) var value []byte var err error // Make two neighboring keys. - // - // The dummy hash function expects keys to prefixed with four bytes of 0, - // which will cause it to return the preimage itself as the digest, without - // the first four bytes. - key1 := make([]byte, h.Size()+4) + key1 := make([]byte, keySize) rand.Read(key1) key1[0], key1[1], key1[2], key1[3] = byte(0), byte(0), byte(0), byte(0) - key1[h.Size()+4-1] = byte(0) - key2 := make([]byte, h.Size()+4) + key1[keySize-1] = byte(0) + key2 := make([]byte, keySize) copy(key2, key1) // We make key2's least significant bit different than key1's - key2[h.Size()+4-1] = byte(1) + key2[keySize-1] = byte(1) _, err = smt.Update(key1, []byte("testValue1")) if err != nil { @@ -277,42 +434,44 @@ func TestSparseMerkleTreeMaxHeightCase(t *testing.T) { // Test base case tree delete operations with a few keys. func TestSparseMerkleTreeDeleteBasic(t *testing.T) { - smn, smv := NewSimpleMap(), NewSimpleMap() - smt := NewSparseMerkleTree(smn, smv, sha256.New()) + hasher := sha256.New() + smn, _ := NewSimpleMap(hasher.Size()) + smv, _ := NewSimpleMap(len([]byte("testKey1"))) + smt := NewSparseMerkleTree(smn, smv, hasher) // Testing inserting, deleting a key, and inserting it again. - _, err := smt.Update([]byte("testKey"), []byte("testValue")) + _, err := smt.Update([]byte("testKey1"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty key: %v", err) } root1 := smt.Root() - _, err = smt.Update([]byte("testKey"), defaultValue) + _, err = smt.Update([]byte("testKey1"), defaultValue) if err != nil { t.Errorf("returned error when deleting key: %v", err) } - value, err := smt.Get([]byte("testKey")) + value, err := smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting deleted key: %v", err) } if !bytes.Equal(defaultValue, value) { t.Error("did not get default value when getting deleted key") } - has, err := smt.Has([]byte("testKey")) + has, err := smt.Has([]byte("testKey1")) if err != nil { t.Errorf("returned error when checking existence of deleted key: %v", err) } if has { - t.Error("returned 'true' when checking existernce of deleted key") + t.Error("returned 'true' when checking existence of deleted key") } - _, err = smt.Update([]byte("testKey"), []byte("testValue")) + _, err = smt.Update([]byte("testKey1"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty key: %v", err) } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) } - if !bytes.Equal([]byte("testValue"), value) { + if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty key") } if !bytes.Equal(root1, smt.Root()) { @@ -320,7 +479,7 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) { } // Test inserting and deleting a second key. - _, err = smt.Update([]byte("testKey2"), []byte("testValue")) + _, err = smt.Update([]byte("testKey2"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty second key: %v", err) } @@ -335,44 +494,48 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) { if !bytes.Equal(defaultValue, value) { t.Error("did not get default value when getting deleted key") } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) } - if !bytes.Equal([]byte("testValue"), value) { + if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty key") } if !bytes.Equal(root1, smt.Root()) { t.Error("tree root is not as expected after deleting second key") } - // Test inserting and deleting a different second key, when the the first 2 - // bits of the path for the two keys in the tree are the same (when using SHA256). - _, err = smt.Update([]byte("foo"), []byte("testValue")) + // Test inserting and deleting a different second key, when the first 2 + // bits of the path for the two keys in the tree are the same. + differentTestKey := make([]byte, len([]byte("testKey1"))) + copy(differentTestKey, "testKey1") + differentTestKey[0] = byte(0b1000000) + countCommonPrefix([]byte("testKey1"), differentTestKey) + _, err = smt.Update(differentTestKey, []byte("testValue1")) if err != nil { t.Errorf("unable to update key: %v", err) } - value, err = smt.Get([]byte("foo")) + value, err = smt.Get(differentTestKey) if err != nil { t.Errorf("returned error when updating empty second key: %v", err) } - _, err = smt.Update([]byte("foo"), defaultValue) + _, err = smt.Update(differentTestKey, defaultValue) if err != nil { t.Errorf("returned error when deleting key: %v", err) } - value, err = smt.Get([]byte("foo")) + value, err = smt.Get(differentTestKey) if err != nil { t.Errorf("returned error when getting deleted key: %v", err) } if !bytes.Equal(defaultValue, value) { t.Error("did not get default value when getting deleted key") } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) } - if !bytes.Equal([]byte("testValue"), value) { + if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty key") } if !bytes.Equal(root1, smt.Root()) { @@ -380,91 +543,43 @@ func TestSparseMerkleTreeDeleteBasic(t *testing.T) { } // Testing inserting, deleting a key, and inserting it again, using Delete - _, err = smt.Update([]byte("testKey"), []byte("testValue")) + _, err = smt.Update([]byte("testKey1"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty key: %v", err) } root1 = smt.Root() - _, err = smt.Delete([]byte("testKey")) + _, err = smt.Delete([]byte("testKey1")) if err != nil { t.Errorf("returned error when deleting key: %v", err) } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting deleted key: %v", err) } if !bytes.Equal(defaultValue, value) { t.Error("did not get default value when getting deleted key") } - has, err = smt.Has([]byte("testKey")) + has, err = smt.Has([]byte("testKey1")) if err != nil { t.Errorf("returned error when checking existence of deleted key: %v", err) } if has { - t.Error("returned 'true' when checking existernce of deleted key") + t.Error("returned 'true' when checking existence of deleted key") } - _, err = smt.Update([]byte("testKey"), []byte("testValue")) + _, err = smt.Update([]byte("testKey1"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty key: %v", err) } - value, err = smt.Get([]byte("testKey")) + value, err = smt.Get([]byte("testKey1")) if err != nil { t.Errorf("returned error when getting non-empty key: %v", err) } - if !bytes.Equal([]byte("testValue"), value) { + if !bytes.Equal([]byte("testValue1"), value) { t.Error("did not get correct value when getting non-empty key") } if !bytes.Equal(root1, smt.Root()) { t.Error("tree root is not as expected after re-inserting key after deletion") } - -} - -// dummyHasher is a dummy hasher for tests, where the digest of keys is equivalent to the preimage. -type dummyHasher struct { - baseHasher hash.Hash - data []byte -} - -func newDummyHasher(baseHasher hash.Hash) hash.Hash { - return &dummyHasher{ - baseHasher: baseHasher, - } -} - -func (h *dummyHasher) Write(data []byte) (int, error) { - h.data = append(h.data, data...) - return len(data), nil -} - -func (h *dummyHasher) Sum(prefix []byte) []byte { - preimage := make([]byte, len(h.data)) - copy(preimage, h.data) - preimage = append(prefix, preimage...) - - var digest []byte - // Keys should be prefixed with four bytes of value 0. - if bytes.Equal(preimage[:4], []byte{0, 0, 0, 0}) && len(preimage) == h.Size()+4 { - digest = preimage[4:] - } else { - h.baseHasher.Write(preimage) - digest = h.baseHasher.Sum(nil) - h.baseHasher.Reset() - } - - return digest -} - -func (h *dummyHasher) Reset() { - h.data = nil -} - -func (h *dummyHasher) Size() int { - return h.baseHasher.Size() -} - -func (h *dummyHasher) BlockSize() int { - return h.Size() } func TestOrphanRemoval(t *testing.T) { @@ -476,9 +591,11 @@ func TestOrphanRemoval(t *testing.T) { } setup := func() { - smn, smv = NewSimpleMap(), NewSimpleMap() - smt = NewSparseMerkleTree(smn, smv, sha256.New()) - _, err = smt.Update([]byte("testKey"), []byte("testValue")) + hasher := sha256.New() + smn, _ = NewSimpleMap(hasher.Size()) + smv, _ = NewSimpleMap(len([]byte("testKey1"))) + smt = NewSparseMerkleTree(smn, smv, hasher) + _, err = smt.Update([]byte("testKey1"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating empty key: %v", err) } @@ -490,7 +607,7 @@ func TestOrphanRemoval(t *testing.T) { t.Run("delete 1", func(t *testing.T) { setup() - _, err = smt.Delete([]byte("testKey")) + _, err = smt.Delete([]byte("testKey1")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } @@ -501,7 +618,7 @@ func TestOrphanRemoval(t *testing.T) { t.Run("overwrite 1", func(t *testing.T) { setup() - _, err = smt.Update([]byte("testKey"), []byte("testValue2")) + _, err = smt.Update([]byte("testKey1"), []byte("testValue2")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } @@ -515,9 +632,18 @@ func TestOrphanRemoval(t *testing.T) { newKey string count int } + + newTestKey1 := make([]byte, len([]byte("testKey1"))) + newTestKey2 := make([]byte, len([]byte("testKey1"))) + copy(newTestKey1, "testKey1") + copy(newTestKey2, "testKey1") + + newTestKey1[0] = byte(0b10000000) // key having zero common prefix with `testKey1` + newTestKey2[0] = byte(0b01000000) // key having two common prefixes with `testKey1` + cases := []testCase{ - {"testKey2", 3}, // common prefix = 0, root + 2 leaves - {"foo", 5}, // common prefix = 2, root + 2 node branch + 2 leaves + {string(newTestKey1), 3}, // common prefix = 0, root + 2 leaves + {string(newTestKey2), 5}, // common prefix = 2, root + 2 node branch + 2 leaves } t.Run("delete multiple", func(t *testing.T) { @@ -530,7 +656,7 @@ func TestOrphanRemoval(t *testing.T) { if tc.count != nodeCount() { t.Errorf("expected %d nodes after insertion, got: %d", tc.count, nodeCount()) } - _, err = smt.Delete([]byte("testKey")) + _, err = smt.Delete([]byte("testKey1")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } @@ -549,14 +675,14 @@ func TestOrphanRemoval(t *testing.T) { t.Run("overwrite and delete", func(t *testing.T) { setup() - _, err = smt.Update([]byte("testKey"), []byte("testValue2")) + _, err = smt.Update([]byte("testKey1"), []byte("testValue2")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } if 1 != nodeCount() { t.Errorf("expected 1 nodes after insertion, got: %d", nodeCount()) } - _, err = smt.Delete([]byte("testKey")) + _, err = smt.Delete([]byte("testKey1")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } @@ -571,7 +697,7 @@ func TestOrphanRemoval(t *testing.T) { t.Errorf("returned error when updating non-empty key: %v", err) } if tc.count != nodeCount() { - t.Errorf("expected 1 nodes after insertion, got: %d", nodeCount()) + t.Errorf("expected %d nodes after insertion, got: %d", tc.count, nodeCount()) } _, err = smt.Update([]byte(tc.newKey), []byte("testValue3")) if err != nil { @@ -580,7 +706,7 @@ func TestOrphanRemoval(t *testing.T) { if tc.count != nodeCount() { t.Errorf("expected %d nodes after insertion, got: %d", tc.count, nodeCount()) } - _, err = smt.Delete([]byte("testKey")) + _, err = smt.Delete([]byte("testKey1")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } @@ -600,11 +726,11 @@ func TestOrphanRemoval(t *testing.T) { t.Run("delete duplicate value", func(t *testing.T) { setup() - _, err = smt.Update([]byte("testKey2"), []byte("testValue")) + _, err = smt.Update([]byte("testKey2"), []byte("testValue1")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } - _, err = smt.Delete([]byte("testKey")) + _, err = smt.Delete([]byte("testKey1")) if err != nil { t.Errorf("returned error when updating non-empty key: %v", err) } diff --git a/treehasher.go b/treehasher.go index eda798e..d510b5a 100644 --- a/treehasher.go +++ b/treehasher.go @@ -14,8 +14,10 @@ type treeHasher struct { } func newTreeHasher(hasher hash.Hash) *treeHasher { - th := treeHasher{hasher: hasher} - th.zeroValue = make([]byte, th.pathSize()) + th := treeHasher{ + hasher: hasher, + } + th.zeroValue = make([]byte, th.hasher.Size()) return &th } @@ -27,25 +29,19 @@ func (th *treeHasher) digest(data []byte) []byte { return sum } -func (th *treeHasher) path(key []byte) []byte { - return th.digest(key) -} - func (th *treeHasher) digestLeaf(path []byte, leafData []byte) ([]byte, []byte) { value := make([]byte, 0, len(leafPrefix)+len(path)+len(leafData)) value = append(value, leafPrefix...) value = append(value, path...) value = append(value, leafData...) - th.hasher.Write(value) - sum := th.hasher.Sum(nil) - th.hasher.Reset() + sum := th.digest(value) return sum, value } -func (th *treeHasher) parseLeaf(data []byte) ([]byte, []byte) { - return data[len(leafPrefix) : th.pathSize()+len(leafPrefix)], data[len(leafPrefix)+th.pathSize():] +func (th *treeHasher) parseLeaf(data []byte, keySize int) ([]byte, []byte) { + return data[len(leafPrefix) : keySize+len(leafPrefix)], data[len(leafPrefix)+keySize:] } func (th *treeHasher) isLeaf(data []byte) bool { @@ -58,19 +54,13 @@ func (th *treeHasher) digestNode(leftData []byte, rightData []byte) ([]byte, []b value = append(value, leftData...) value = append(value, rightData...) - th.hasher.Write(value) - sum := th.hasher.Sum(nil) - th.hasher.Reset() + sum := th.digest(value) return sum, value } func (th *treeHasher) parseNode(data []byte) ([]byte, []byte) { - return data[len(nodePrefix) : th.pathSize()+len(nodePrefix)], data[len(nodePrefix)+th.pathSize():] -} - -func (th *treeHasher) pathSize() int { - return th.hasher.Size() + return data[len(nodePrefix) : th.hasher.Size()+len(nodePrefix)], data[len(nodePrefix)+th.hasher.Size():] } func (th *treeHasher) placeholder() []byte {