Skip to content

Commit

Permalink
compact: Simplify getting hashes in NewTreeWithState (#1618)
Browse files Browse the repository at this point in the history
This change factors out getting the list of node IDs needed to initialize compact.Tree
into a separate exported function.

This eliminates GetNodesFunc callback which makes client code simpler, and also
allows a more direct initialization of compact.Range.
  • Loading branch information
pav-kv authored May 20, 2019
1 parent c3b56b5 commit 4c202cf
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 100 deletions.
59 changes: 29 additions & 30 deletions log/sequencer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,40 +133,35 @@ func NewSequencer(
}

func (s Sequencer) buildMerkleTreeFromStorageAtRoot(ctx context.Context, root *types.LogRootV1, tx storage.TreeTX) (*compact.Tree, error) {
mt, err := compact.NewTreeWithState(s.hasher, int64(root.TreeSize), func(ids []compact.NodeID) ([][]byte, error) {
storIDs := make([]storage.NodeID, len(ids))
for i, id := range ids {
nodeID, err := storage.NewNodeIDForTreeCoords(int64(id.Level), int64(id.Index), maxTreeDepth)
if err != nil {
return nil, fmt.Errorf("failed to create nodeID: %v", err)
}
storIDs[i] = nodeID
}

nodes, err := tx.GetMerkleNodes(ctx, int64(root.Revision), storIDs)
ids := compact.TreeNodes(root.TreeSize)
storIDs := make([]storage.NodeID, len(ids))
for i, id := range ids {
nodeID, err := storage.NewNodeIDForTreeCoords(int64(id.Level), int64(id.Index), maxTreeDepth)
if err != nil {
return nil, fmt.Errorf("failed to get Merkle nodes: %v", err)
}
if got, want := len(nodes), len(storIDs); got != want {
return nil, fmt.Errorf("failed to get %d nodes at rev %d, got %d", want, root.Revision, got)
}
for i, id := range storIDs {
if !nodes[i].NodeID.Equivalent(id) {
return nil, fmt.Errorf("node ID mismatch at %d", i)
}
return nil, fmt.Errorf("failed to create nodeID: %v", err)
}
storIDs[i] = nodeID
}

hashes := make([][]byte, len(nodes))
for i, node := range nodes {
hashes[i] = node.Hash
nodes, err := tx.GetMerkleNodes(ctx, int64(root.Revision), storIDs)
if err != nil {
return nil, fmt.Errorf("failed to get Merkle nodes: %v", err)
}
if got, want := len(nodes), len(storIDs); got != want {
return nil, fmt.Errorf("failed to get %d nodes at rev %d, got %d", want, root.Revision, got)
}
for i, id := range storIDs {
if !nodes[i].NodeID.Equivalent(id) {
return nil, fmt.Errorf("node ID mismatch at %d", i)
}
return hashes, nil
}, root.RootHash)
}

if err != nil {
return nil, fmt.Errorf("%x: %v", s.signer.KeyHint, err)
hashes := make([][]byte, len(nodes))
for i, node := range nodes {
hashes[i] = node.Hash
}
return mt, nil

return compact.NewTreeWithState(s.hasher, int64(root.TreeSize), hashes, root.RootHash)
}

func (s Sequencer) buildNodesFromNodeMap(nodeMap map[compact.NodeID][]byte, newVersion int64) ([]storage.Node, error) {
Expand Down Expand Up @@ -241,8 +236,12 @@ func (s Sequencer) initMerkleTreeFromStorage(ctx context.Context, currentRoot *t
return compact.NewTree(s.hasher), nil
}

// Initialize the compact tree state to match the latest root in the database
return s.buildMerkleTreeFromStorageAtRoot(ctx, currentRoot, tx)
// Initialize the compact tree state to match the latest root in the database.
mt, err := s.buildMerkleTreeFromStorageAtRoot(ctx, currentRoot, tx)
if err != nil {
return nil, fmt.Errorf("%x: %v", s.signer.KeyHint, err)
}
return mt, err
}

// sequencingTask provides sequenced LogLeaf entries, and updates storage
Expand Down
8 changes: 4 additions & 4 deletions log/sequencer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ var (
// Nodes that will be loaded when updating the tree of size 21.
compactTree21 = []storage.Node{
{
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14}, PrefixLenBits: 64},
Hash: testonly.MustDecodeBase64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="),
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, PrefixLenBits: 60},
Hash: testonly.MustDecodeBase64("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC="),
NodeRevision: 5,
},
{
Expand All @@ -146,8 +146,8 @@ var (
NodeRevision: 5,
},
{
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, PrefixLenBits: 60},
Hash: testonly.MustDecodeBase64("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC="),
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14}, PrefixLenBits: 64},
Hash: testonly.MustDecodeBase64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="),
NodeRevision: 5,
},
}
Expand Down
61 changes: 28 additions & 33 deletions merkle/compact/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,42 +39,18 @@ func isPerfectTree(size int64) bool {
return size != 0 && (size&(size-1) == 0)
}

// GetNodesFunc is a function prototype which can look up particular nodes
// within a non-compact Merkle tree. Used by the compact Tree to populate
// itself with correct state when starting up with a non-empty tree.
type GetNodesFunc func(ids []NodeID) ([][]byte, error)

// NewTreeWithState creates a new compact Tree for the passed in size.
//
// This can fail if the nodes required to recreate the tree state cannot be
// fetched or the calculated root hash after population does not match the
// expected value.
// This can fail if the number of hashes does not correspond to the tree size,
// or the calculated root hash does not match the passed in expected value.
//
// getNodesFn will be called with the coordinates of internal Merkle tree nodes
// whose hash values are required to initialize the internal state of the
// compact Tree. The expectedRoot is the known-good tree root of the tree at
// the specified size, and is used to verify the initial state.
func NewTreeWithState(hasher hashers.LogHasher, size int64, getNodesFn GetNodesFunc, expectedRoot []byte) (*Tree, error) {
ids := make([]NodeID, 0, bits.OnesCount64(uint64(size)))
// Iterate over perfect subtrees along the right border of the tree. Those
// correspond to the bits of the tree size that are set to one.
for sz := uint64(size); sz != 0; sz &= sz - 1 {
level := uint(bits.TrailingZeros64(sz))
index := (sz - 1) >> level
ids = append(ids, NewNodeID(level, index))
}
hashes, err := getNodesFn(ids)
if err != nil {
return nil, fmt.Errorf("failed to fetch nodes: %v", err)
}
if got, want := len(hashes), len(ids); got != want {
return nil, fmt.Errorf("got %d hashes, needed %d", got, want)
}
// Note: Right border nodes of compact.Range are ordered from root to leaves.
for i, j := 0, len(hashes)-1; i < j; i, j = i+1, j-1 {
hashes[i], hashes[j] = hashes[j], hashes[i]
}

// hashes is the list of node hashes that comprise the compact tree. The list
// of the corresponding node IDs that the caller can use to retrieve these
// hashes can be obtained using the TreeNodes function.
//
// The expectedRoot is the known-good tree root of the tree at the specified
// size, and is used to verify the initial state.
func NewTreeWithState(hasher hashers.LogHasher, size int64, hashes [][]byte, expectedRoot []byte) (*Tree, error) {
fact := RangeFactory{Hash: hasher.HashChildren}
rng, err := fact.NewRange(0, uint64(size), hashes)
if err != nil {
Expand Down Expand Up @@ -210,3 +186,22 @@ func (t *Tree) getNodes() [][]byte {
}
return n
}

// TreeNodes returns the list of node IDs that comprise a compact tree, in the
// same order they are used in compact.Tree and compact.Range, i.e. ordered
// from upper to lower levels.
func TreeNodes(size uint64) []NodeID {
ids := make([]NodeID, 0, bits.OnesCount64(size))
// Iterate over perfect subtrees along the right border of the tree. Those
// correspond to the bits of the tree size that are set to one.
for sz := size; sz != 0; sz &= sz - 1 {
level := uint(bits.TrailingZeros64(sz))
index := (sz - 1) >> level
ids = append(ids, NewNodeID(level, index))
}
// Note: Right border nodes of compact.Range are ordered from root to leaves.
for i, j := 0, len(ids)-1; i < j; i, j = i+1, j-1 {
ids[i], ids[j] = ids[j], ids[i]
}
return ids
}
72 changes: 39 additions & 33 deletions merkle/compact/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ package compact
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"math/bits"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -131,32 +131,22 @@ func TestAppendLeaf(t *testing.T) {
}
}

func failingGetNodesFunc(_ []NodeID) ([][]byte, error) {
return nil, errors.New("bang")
}

// This returns something that won't result in a valid root hash match, doesn't really
// matter what it is but it must be correct length for an SHA256 hash as if it was real
func fixedHashGetNodesFunc(ids []NodeID) ([][]byte, error) {
func fixedHashGetNodesFunc(ids []NodeID) [][]byte {
hashes := make([][]byte, len(ids))
for i := range ids {
hashes[i] = []byte("12345678901234567890123456789012")
}
return hashes, nil
}

func TestLoadingTreeFailsNodeFetch(t *testing.T) {
_, err := NewTreeWithState(rfc6962.DefaultHasher, 237, failingGetNodesFunc, []byte("notimportant"))

if err == nil || !strings.Contains(err.Error(), "bang") {
t.Errorf("Did not return correctly on failed node fetch: %v", err)
}
return hashes
}

func TestLoadingTreeFailsBadRootHash(t *testing.T) {
hashes := fixedHashGetNodesFunc(TreeNodes(237))

// Supply a root hash that can't possibly match the result of the SHA 256 hashing on our dummy
// data
_, err := NewTreeWithState(rfc6962.DefaultHasher, 237, fixedHashGetNodesFunc, []byte("nomatch!nomatch!nomatch!nomatch!"))
_, err := NewTreeWithState(rfc6962.DefaultHasher, 237, hashes, []byte("nomatch!nomatch!nomatch!nomatch!"))
if err == nil || !strings.HasPrefix(err.Error(), "root hash mismatch") {
t.Errorf("Did not return correct error on root mismatch: %v", err)
}
Expand All @@ -166,26 +156,17 @@ func TestCompactVsFullTree(t *testing.T) {
imt := merkle.NewInMemoryMerkleTree(rfc6962.DefaultHasher)
nodes := make(map[NodeID][]byte)

getNode := func(id NodeID) ([]byte, error) {
return nodes[id], nil
getHashes := func(ids []NodeID) [][]byte {
hashes := make([][]byte, len(ids))
for i, id := range ids {
hashes[i] = nodes[id]
}
return hashes
}

for i := int64(0); i < 1024; i++ {
cmt, err := NewTreeWithState(
rfc6962.DefaultHasher,
imt.LeafCount(),
func(ids []NodeID) ([][]byte, error) {
hashes := make([][]byte, len(ids))
for i, id := range ids {
var err error
hashes[i], err = getNode(id)
if err != nil {
return nil, err
}
}
return hashes, nil
}, imt.CurrentRoot().Hash())

hashes := getHashes(TreeNodes(uint64(imt.LeafCount())))
cmt, err := NewTreeWithState(rfc6962.DefaultHasher, imt.LeafCount(), hashes, imt.CurrentRoot().Hash())
if err != nil {
t.Errorf("interation %d: failed to create CMT with state: %v", i, err)
}
Expand Down Expand Up @@ -284,6 +265,31 @@ func TestRootHashForVariousTreeSizes(t *testing.T) {
}
}

func TestTreeNodes(t *testing.T) {
for _, tc := range []struct {
size uint64
want []NodeID
}{
{size: 0, want: []NodeID{}},
{size: 1, want: []NodeID{{Level: 0, Index: 0}}},
{size: 2, want: []NodeID{{Level: 1, Index: 0}}},
{size: 3, want: []NodeID{{Level: 1, Index: 0}, {Level: 0, Index: 2}}},
{size: 4, want: []NodeID{{Level: 2, Index: 0}}},
{size: 5, want: []NodeID{{Level: 2, Index: 0}, {Level: 0, Index: 4}}},
{size: 15, want: []NodeID{{Level: 3, Index: 0}, {Level: 2, Index: 2}, {Level: 1, Index: 6}, {Level: 0, Index: 14}}},
{size: 100, want: []NodeID{{Level: 6, Index: 0}, {Level: 5, Index: 2}, {Level: 2, Index: 24}}},
{size: 513, want: []NodeID{{Level: 9, Index: 0}, {Level: 0, Index: 512}}},
{size: uint64(1) << 63, want: []NodeID{{Level: 63, Index: 0}}},
{size: (uint64(1) << 63) + (uint64(1) << 57), want: []NodeID{{Level: 63, Index: 0}, {Level: 57, Index: 64}}},
} {
t.Run(fmt.Sprintf("size:%d", tc.size), func(t *testing.T) {
if got, want := TreeNodes(tc.size), tc.want; !reflect.DeepEqual(got, tc.want) {
t.Fatalf("TreeNodes: got %v, want %v", got, want)
}
})
}
}

func benchmarkAppendLeaf(b *testing.B, visit VisitFn) {
b.Helper()
const size = 1024
Expand Down

0 comments on commit 4c202cf

Please sign in to comment.