From 4c202cf89104e6fff99441894096bc9af178daae Mon Sep 17 00:00:00 2001 From: Pavel Kalinnikov Date: Mon, 20 May 2019 15:52:54 +0100 Subject: [PATCH] compact: Simplify getting hashes in NewTreeWithState (#1618) This change factors out getting the list of node IDs needed to initialize compact.Tree into a separate exported function. This eliminates GetNodesFunc callback which makes client code simpler, and also allows a more direct initialization of compact.Range. --- log/sequencer.go | 59 +++++++++++++++--------------- log/sequencer_test.go | 8 ++--- merkle/compact/tree.go | 61 +++++++++++++++---------------- merkle/compact/tree_test.go | 72 ++++++++++++++++++++----------------- 4 files changed, 100 insertions(+), 100 deletions(-) diff --git a/log/sequencer.go b/log/sequencer.go index 3c802d5c99..e323bb1764 100644 --- a/log/sequencer.go +++ b/log/sequencer.go @@ -133,40 +133,35 @@ func NewSequencer( } func (s Sequencer) buildMerkleTreeFromStorageAtRoot(ctx context.Context, root *types.LogRootV1, tx storage.TreeTX) (*compact.Tree, error) { - mt, err := compact.NewTreeWithState(s.hasher, int64(root.TreeSize), func(ids []compact.NodeID) ([][]byte, error) { - storIDs := make([]storage.NodeID, len(ids)) - for i, id := range ids { - nodeID, err := storage.NewNodeIDForTreeCoords(int64(id.Level), int64(id.Index), maxTreeDepth) - if err != nil { - return nil, fmt.Errorf("failed to create nodeID: %v", err) - } - storIDs[i] = nodeID - } - - nodes, err := tx.GetMerkleNodes(ctx, int64(root.Revision), storIDs) + ids := compact.TreeNodes(root.TreeSize) + storIDs := make([]storage.NodeID, len(ids)) + for i, id := range ids { + nodeID, err := storage.NewNodeIDForTreeCoords(int64(id.Level), int64(id.Index), maxTreeDepth) if err != nil { - return nil, fmt.Errorf("failed to get Merkle nodes: %v", err) - } - if got, want := len(nodes), len(storIDs); got != want { - return nil, fmt.Errorf("failed to get %d nodes at rev %d, got %d", want, root.Revision, got) - } - for i, id := range storIDs { - if !nodes[i].NodeID.Equivalent(id) { - return nil, fmt.Errorf("node ID mismatch at %d", i) - } + return nil, fmt.Errorf("failed to create nodeID: %v", err) } + storIDs[i] = nodeID + } - hashes := make([][]byte, len(nodes)) - for i, node := range nodes { - hashes[i] = node.Hash + nodes, err := tx.GetMerkleNodes(ctx, int64(root.Revision), storIDs) + if err != nil { + return nil, fmt.Errorf("failed to get Merkle nodes: %v", err) + } + if got, want := len(nodes), len(storIDs); got != want { + return nil, fmt.Errorf("failed to get %d nodes at rev %d, got %d", want, root.Revision, got) + } + for i, id := range storIDs { + if !nodes[i].NodeID.Equivalent(id) { + return nil, fmt.Errorf("node ID mismatch at %d", i) } - return hashes, nil - }, root.RootHash) + } - if err != nil { - return nil, fmt.Errorf("%x: %v", s.signer.KeyHint, err) + hashes := make([][]byte, len(nodes)) + for i, node := range nodes { + hashes[i] = node.Hash } - return mt, nil + + return compact.NewTreeWithState(s.hasher, int64(root.TreeSize), hashes, root.RootHash) } func (s Sequencer) buildNodesFromNodeMap(nodeMap map[compact.NodeID][]byte, newVersion int64) ([]storage.Node, error) { @@ -241,8 +236,12 @@ func (s Sequencer) initMerkleTreeFromStorage(ctx context.Context, currentRoot *t return compact.NewTree(s.hasher), nil } - // Initialize the compact tree state to match the latest root in the database - return s.buildMerkleTreeFromStorageAtRoot(ctx, currentRoot, tx) + // Initialize the compact tree state to match the latest root in the database. + mt, err := s.buildMerkleTreeFromStorageAtRoot(ctx, currentRoot, tx) + if err != nil { + return nil, fmt.Errorf("%x: %v", s.signer.KeyHint, err) + } + return mt, err } // sequencingTask provides sequenced LogLeaf entries, and updates storage diff --git a/log/sequencer_test.go b/log/sequencer_test.go index 5206aa7582..2a2616c64b 100644 --- a/log/sequencer_test.go +++ b/log/sequencer_test.go @@ -136,8 +136,8 @@ var ( // Nodes that will be loaded when updating the tree of size 21. compactTree21 = []storage.Node{ { - NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14}, PrefixLenBits: 64}, - Hash: testonly.MustDecodeBase64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="), + NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, PrefixLenBits: 60}, + Hash: testonly.MustDecodeBase64("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC="), NodeRevision: 5, }, { @@ -146,8 +146,8 @@ var ( NodeRevision: 5, }, { - NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, PrefixLenBits: 60}, - Hash: testonly.MustDecodeBase64("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC="), + NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14}, PrefixLenBits: 64}, + Hash: testonly.MustDecodeBase64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="), NodeRevision: 5, }, } diff --git a/merkle/compact/tree.go b/merkle/compact/tree.go index 21067006ae..74973270d2 100644 --- a/merkle/compact/tree.go +++ b/merkle/compact/tree.go @@ -39,42 +39,18 @@ func isPerfectTree(size int64) bool { return size != 0 && (size&(size-1) == 0) } -// GetNodesFunc is a function prototype which can look up particular nodes -// within a non-compact Merkle tree. Used by the compact Tree to populate -// itself with correct state when starting up with a non-empty tree. -type GetNodesFunc func(ids []NodeID) ([][]byte, error) - // NewTreeWithState creates a new compact Tree for the passed in size. // -// This can fail if the nodes required to recreate the tree state cannot be -// fetched or the calculated root hash after population does not match the -// expected value. +// This can fail if the number of hashes does not correspond to the tree size, +// or the calculated root hash does not match the passed in expected value. // -// getNodesFn will be called with the coordinates of internal Merkle tree nodes -// whose hash values are required to initialize the internal state of the -// compact Tree. The expectedRoot is the known-good tree root of the tree at -// the specified size, and is used to verify the initial state. -func NewTreeWithState(hasher hashers.LogHasher, size int64, getNodesFn GetNodesFunc, expectedRoot []byte) (*Tree, error) { - ids := make([]NodeID, 0, bits.OnesCount64(uint64(size))) - // Iterate over perfect subtrees along the right border of the tree. Those - // correspond to the bits of the tree size that are set to one. - for sz := uint64(size); sz != 0; sz &= sz - 1 { - level := uint(bits.TrailingZeros64(sz)) - index := (sz - 1) >> level - ids = append(ids, NewNodeID(level, index)) - } - hashes, err := getNodesFn(ids) - if err != nil { - return nil, fmt.Errorf("failed to fetch nodes: %v", err) - } - if got, want := len(hashes), len(ids); got != want { - return nil, fmt.Errorf("got %d hashes, needed %d", got, want) - } - // Note: Right border nodes of compact.Range are ordered from root to leaves. - for i, j := 0, len(hashes)-1; i < j; i, j = i+1, j-1 { - hashes[i], hashes[j] = hashes[j], hashes[i] - } - +// hashes is the list of node hashes that comprise the compact tree. The list +// of the corresponding node IDs that the caller can use to retrieve these +// hashes can be obtained using the TreeNodes function. +// +// The expectedRoot is the known-good tree root of the tree at the specified +// size, and is used to verify the initial state. +func NewTreeWithState(hasher hashers.LogHasher, size int64, hashes [][]byte, expectedRoot []byte) (*Tree, error) { fact := RangeFactory{Hash: hasher.HashChildren} rng, err := fact.NewRange(0, uint64(size), hashes) if err != nil { @@ -210,3 +186,22 @@ func (t *Tree) getNodes() [][]byte { } return n } + +// TreeNodes returns the list of node IDs that comprise a compact tree, in the +// same order they are used in compact.Tree and compact.Range, i.e. ordered +// from upper to lower levels. +func TreeNodes(size uint64) []NodeID { + ids := make([]NodeID, 0, bits.OnesCount64(size)) + // Iterate over perfect subtrees along the right border of the tree. Those + // correspond to the bits of the tree size that are set to one. + for sz := size; sz != 0; sz &= sz - 1 { + level := uint(bits.TrailingZeros64(sz)) + index := (sz - 1) >> level + ids = append(ids, NewNodeID(level, index)) + } + // Note: Right border nodes of compact.Range are ordered from root to leaves. + for i, j := 0, len(ids)-1; i < j; i, j = i+1, j-1 { + ids[i], ids[j] = ids[j], ids[i] + } + return ids +} diff --git a/merkle/compact/tree_test.go b/merkle/compact/tree_test.go index 707a0a5b16..e7b6059e80 100644 --- a/merkle/compact/tree_test.go +++ b/merkle/compact/tree_test.go @@ -17,9 +17,9 @@ package compact import ( "bytes" "encoding/base64" - "errors" "fmt" "math/bits" + "reflect" "strings" "testing" @@ -131,32 +131,22 @@ func TestAppendLeaf(t *testing.T) { } } -func failingGetNodesFunc(_ []NodeID) ([][]byte, error) { - return nil, errors.New("bang") -} - // This returns something that won't result in a valid root hash match, doesn't really // matter what it is but it must be correct length for an SHA256 hash as if it was real -func fixedHashGetNodesFunc(ids []NodeID) ([][]byte, error) { +func fixedHashGetNodesFunc(ids []NodeID) [][]byte { hashes := make([][]byte, len(ids)) for i := range ids { hashes[i] = []byte("12345678901234567890123456789012") } - return hashes, nil -} - -func TestLoadingTreeFailsNodeFetch(t *testing.T) { - _, err := NewTreeWithState(rfc6962.DefaultHasher, 237, failingGetNodesFunc, []byte("notimportant")) - - if err == nil || !strings.Contains(err.Error(), "bang") { - t.Errorf("Did not return correctly on failed node fetch: %v", err) - } + return hashes } func TestLoadingTreeFailsBadRootHash(t *testing.T) { + hashes := fixedHashGetNodesFunc(TreeNodes(237)) + // Supply a root hash that can't possibly match the result of the SHA 256 hashing on our dummy // data - _, err := NewTreeWithState(rfc6962.DefaultHasher, 237, fixedHashGetNodesFunc, []byte("nomatch!nomatch!nomatch!nomatch!")) + _, err := NewTreeWithState(rfc6962.DefaultHasher, 237, hashes, []byte("nomatch!nomatch!nomatch!nomatch!")) if err == nil || !strings.HasPrefix(err.Error(), "root hash mismatch") { t.Errorf("Did not return correct error on root mismatch: %v", err) } @@ -166,26 +156,17 @@ func TestCompactVsFullTree(t *testing.T) { imt := merkle.NewInMemoryMerkleTree(rfc6962.DefaultHasher) nodes := make(map[NodeID][]byte) - getNode := func(id NodeID) ([]byte, error) { - return nodes[id], nil + getHashes := func(ids []NodeID) [][]byte { + hashes := make([][]byte, len(ids)) + for i, id := range ids { + hashes[i] = nodes[id] + } + return hashes } for i := int64(0); i < 1024; i++ { - cmt, err := NewTreeWithState( - rfc6962.DefaultHasher, - imt.LeafCount(), - func(ids []NodeID) ([][]byte, error) { - hashes := make([][]byte, len(ids)) - for i, id := range ids { - var err error - hashes[i], err = getNode(id) - if err != nil { - return nil, err - } - } - return hashes, nil - }, imt.CurrentRoot().Hash()) - + hashes := getHashes(TreeNodes(uint64(imt.LeafCount()))) + cmt, err := NewTreeWithState(rfc6962.DefaultHasher, imt.LeafCount(), hashes, imt.CurrentRoot().Hash()) if err != nil { t.Errorf("interation %d: failed to create CMT with state: %v", i, err) } @@ -284,6 +265,31 @@ func TestRootHashForVariousTreeSizes(t *testing.T) { } } +func TestTreeNodes(t *testing.T) { + for _, tc := range []struct { + size uint64 + want []NodeID + }{ + {size: 0, want: []NodeID{}}, + {size: 1, want: []NodeID{{Level: 0, Index: 0}}}, + {size: 2, want: []NodeID{{Level: 1, Index: 0}}}, + {size: 3, want: []NodeID{{Level: 1, Index: 0}, {Level: 0, Index: 2}}}, + {size: 4, want: []NodeID{{Level: 2, Index: 0}}}, + {size: 5, want: []NodeID{{Level: 2, Index: 0}, {Level: 0, Index: 4}}}, + {size: 15, want: []NodeID{{Level: 3, Index: 0}, {Level: 2, Index: 2}, {Level: 1, Index: 6}, {Level: 0, Index: 14}}}, + {size: 100, want: []NodeID{{Level: 6, Index: 0}, {Level: 5, Index: 2}, {Level: 2, Index: 24}}}, + {size: 513, want: []NodeID{{Level: 9, Index: 0}, {Level: 0, Index: 512}}}, + {size: uint64(1) << 63, want: []NodeID{{Level: 63, Index: 0}}}, + {size: (uint64(1) << 63) + (uint64(1) << 57), want: []NodeID{{Level: 63, Index: 0}, {Level: 57, Index: 64}}}, + } { + t.Run(fmt.Sprintf("size:%d", tc.size), func(t *testing.T) { + if got, want := TreeNodes(tc.size), tc.want; !reflect.DeepEqual(got, tc.want) { + t.Fatalf("TreeNodes: got %v, want %v", got, want) + } + }) + } +} + func benchmarkAppendLeaf(b *testing.B, visit VisitFn) { b.Helper() const size = 1024