Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compact: Simplify getting hashes in NewTreeWithState #1618

Merged
merged 2 commits into from
May 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions log/sequencer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,40 +133,35 @@ func NewSequencer(
}

func (s Sequencer) buildMerkleTreeFromStorageAtRoot(ctx context.Context, root *types.LogRootV1, tx storage.TreeTX) (*compact.Tree, error) {
mt, err := compact.NewTreeWithState(s.hasher, int64(root.TreeSize), func(ids []compact.NodeID) ([][]byte, error) {
storIDs := make([]storage.NodeID, len(ids))
for i, id := range ids {
nodeID, err := storage.NewNodeIDForTreeCoords(int64(id.Level), int64(id.Index), maxTreeDepth)
if err != nil {
return nil, fmt.Errorf("failed to create nodeID: %v", err)
}
storIDs[i] = nodeID
}

nodes, err := tx.GetMerkleNodes(ctx, int64(root.Revision), storIDs)
ids := compact.TreeNodes(root.TreeSize)
storIDs := make([]storage.NodeID, len(ids))
for i, id := range ids {
nodeID, err := storage.NewNodeIDForTreeCoords(int64(id.Level), int64(id.Index), maxTreeDepth)
if err != nil {
return nil, fmt.Errorf("failed to get Merkle nodes: %v", err)
}
if got, want := len(nodes), len(storIDs); got != want {
return nil, fmt.Errorf("failed to get %d nodes at rev %d, got %d", want, root.Revision, got)
}
for i, id := range storIDs {
if !nodes[i].NodeID.Equivalent(id) {
return nil, fmt.Errorf("node ID mismatch at %d", i)
}
return nil, fmt.Errorf("failed to create nodeID: %v", err)
}
storIDs[i] = nodeID
}

hashes := make([][]byte, len(nodes))
for i, node := range nodes {
hashes[i] = node.Hash
nodes, err := tx.GetMerkleNodes(ctx, int64(root.Revision), storIDs)
if err != nil {
return nil, fmt.Errorf("failed to get Merkle nodes: %v", err)
}
if got, want := len(nodes), len(storIDs); got != want {
return nil, fmt.Errorf("failed to get %d nodes at rev %d, got %d", want, root.Revision, got)
}
for i, id := range storIDs {
if !nodes[i].NodeID.Equivalent(id) {
return nil, fmt.Errorf("node ID mismatch at %d", i)
}
return hashes, nil
}, root.RootHash)
}

if err != nil {
return nil, fmt.Errorf("%x: %v", s.signer.KeyHint, err)
hashes := make([][]byte, len(nodes))
for i, node := range nodes {
hashes[i] = node.Hash
}
return mt, nil

return compact.NewTreeWithState(s.hasher, int64(root.TreeSize), hashes, root.RootHash)
}

func (s Sequencer) buildNodesFromNodeMap(nodeMap map[compact.NodeID][]byte, newVersion int64) ([]storage.Node, error) {
Expand Down Expand Up @@ -241,8 +236,12 @@ func (s Sequencer) initMerkleTreeFromStorage(ctx context.Context, currentRoot *t
return compact.NewTree(s.hasher), nil
}

// Initialize the compact tree state to match the latest root in the database
return s.buildMerkleTreeFromStorageAtRoot(ctx, currentRoot, tx)
// Initialize the compact tree state to match the latest root in the database.
mt, err := s.buildMerkleTreeFromStorageAtRoot(ctx, currentRoot, tx)
if err != nil {
pav-kv marked this conversation as resolved.
Show resolved Hide resolved
return nil, fmt.Errorf("%x: %v", s.signer.KeyHint, err)
}
return mt, err
}

// sequencingTask provides sequenced LogLeaf entries, and updates storage
Expand Down
8 changes: 4 additions & 4 deletions log/sequencer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ var (
// Nodes that will be loaded when updating the tree of size 21.
compactTree21 = []storage.Node{
{
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14}, PrefixLenBits: 64},
Hash: testonly.MustDecodeBase64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="),
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, PrefixLenBits: 60},
Hash: testonly.MustDecodeBase64("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC="),
NodeRevision: 5,
},
{
Expand All @@ -146,8 +146,8 @@ var (
NodeRevision: 5,
},
{
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, PrefixLenBits: 60},
Hash: testonly.MustDecodeBase64("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC="),
NodeID: storage.NodeID{Path: []uint8{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14}, PrefixLenBits: 64},
Hash: testonly.MustDecodeBase64("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="),
NodeRevision: 5,
},
}
Expand Down
61 changes: 28 additions & 33 deletions merkle/compact/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,42 +39,18 @@ func isPerfectTree(size int64) bool {
return size != 0 && (size&(size-1) == 0)
}

// GetNodesFunc is a function prototype which can look up particular nodes
// within a non-compact Merkle tree. Used by the compact Tree to populate
// itself with correct state when starting up with a non-empty tree.
type GetNodesFunc func(ids []NodeID) ([][]byte, error)

// NewTreeWithState creates a new compact Tree for the passed in size.
//
// This can fail if the nodes required to recreate the tree state cannot be
// fetched or the calculated root hash after population does not match the
// expected value.
// This can fail if the number of hashes does not correspond to the tree size,
// or the calculated root hash does not match the passed in expected value.
//
// getNodesFn will be called with the coordinates of internal Merkle tree nodes
// whose hash values are required to initialize the internal state of the
// compact Tree. The expectedRoot is the known-good tree root of the tree at
// the specified size, and is used to verify the initial state.
func NewTreeWithState(hasher hashers.LogHasher, size int64, getNodesFn GetNodesFunc, expectedRoot []byte) (*Tree, error) {
ids := make([]NodeID, 0, bits.OnesCount64(uint64(size)))
// Iterate over perfect subtrees along the right border of the tree. Those
// correspond to the bits of the tree size that are set to one.
for sz := uint64(size); sz != 0; sz &= sz - 1 {
level := uint(bits.TrailingZeros64(sz))
index := (sz - 1) >> level
ids = append(ids, NewNodeID(level, index))
}
hashes, err := getNodesFn(ids)
if err != nil {
return nil, fmt.Errorf("failed to fetch nodes: %v", err)
}
if got, want := len(hashes), len(ids); got != want {
return nil, fmt.Errorf("got %d hashes, needed %d", got, want)
}
// Note: Right border nodes of compact.Range are ordered from root to leaves.
for i, j := 0, len(hashes)-1; i < j; i, j = i+1, j-1 {
hashes[i], hashes[j] = hashes[j], hashes[i]
}

// hashes is the list of node hashes that comprise the compact tree. The list
// of the corresponding node IDs that the caller can use to retrieve these
// hashes can be obtained using the TreeNodes function.
//
// The expectedRoot is the known-good tree root of the tree at the specified
// size, and is used to verify the initial state.
func NewTreeWithState(hasher hashers.LogHasher, size int64, hashes [][]byte, expectedRoot []byte) (*Tree, error) {
fact := RangeFactory{Hash: hasher.HashChildren}
rng, err := fact.NewRange(0, uint64(size), hashes)
if err != nil {
Expand Down Expand Up @@ -210,3 +186,22 @@ func (t *Tree) getNodes() [][]byte {
}
return n
}

// TreeNodes returns the list of node IDs that comprise a compact tree, in the
// same order they are used in compact.Tree and compact.Range, i.e. ordered
// from upper to lower levels.
func TreeNodes(size uint64) []NodeID {
ids := make([]NodeID, 0, bits.OnesCount64(size))
// Iterate over perfect subtrees along the right border of the tree. Those
// correspond to the bits of the tree size that are set to one.
for sz := size; sz != 0; sz &= sz - 1 {
level := uint(bits.TrailingZeros64(sz))
index := (sz - 1) >> level
ids = append(ids, NewNodeID(level, index))
}
// Note: Right border nodes of compact.Range are ordered from root to leaves.
for i, j := 0, len(ids)-1; i < j; i, j = i+1, j-1 {
ids[i], ids[j] = ids[j], ids[i]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a thought but we could add the other ordering to Range if reversal is going to crop up a lot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good thought, we'll see if/when another use-case appears. That wouldn't be a trivial change though. Range implementation has a very specific ordering: hashes slice contains 2 stacks of hashes - the front one is ordered bottom-to-top, the back one is ordered top-to-bottom. Stack ordering comes handy when 2 ranges are merged, so that the slice is mutated with minimal copies/reallocations.

}
return ids
}
72 changes: 39 additions & 33 deletions merkle/compact/tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ package compact
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"math/bits"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -131,32 +131,22 @@ func TestAppendLeaf(t *testing.T) {
}
}

func failingGetNodesFunc(_ []NodeID) ([][]byte, error) {
return nil, errors.New("bang")
}

// This returns something that won't result in a valid root hash match, doesn't really
// matter what it is but it must be correct length for an SHA256 hash as if it was real
func fixedHashGetNodesFunc(ids []NodeID) ([][]byte, error) {
func fixedHashGetNodesFunc(ids []NodeID) [][]byte {
hashes := make([][]byte, len(ids))
for i := range ids {
hashes[i] = []byte("12345678901234567890123456789012")
}
return hashes, nil
}

func TestLoadingTreeFailsNodeFetch(t *testing.T) {
_, err := NewTreeWithState(rfc6962.DefaultHasher, 237, failingGetNodesFunc, []byte("notimportant"))

if err == nil || !strings.Contains(err.Error(), "bang") {
t.Errorf("Did not return correctly on failed node fetch: %v", err)
}
return hashes
}

func TestLoadingTreeFailsBadRootHash(t *testing.T) {
hashes := fixedHashGetNodesFunc(TreeNodes(237))

// Supply a root hash that can't possibly match the result of the SHA 256 hashing on our dummy
// data
_, err := NewTreeWithState(rfc6962.DefaultHasher, 237, fixedHashGetNodesFunc, []byte("nomatch!nomatch!nomatch!nomatch!"))
_, err := NewTreeWithState(rfc6962.DefaultHasher, 237, hashes, []byte("nomatch!nomatch!nomatch!nomatch!"))
if err == nil || !strings.HasPrefix(err.Error(), "root hash mismatch") {
t.Errorf("Did not return correct error on root mismatch: %v", err)
}
Expand All @@ -166,26 +156,17 @@ func TestCompactVsFullTree(t *testing.T) {
imt := merkle.NewInMemoryMerkleTree(rfc6962.DefaultHasher)
nodes := make(map[NodeID][]byte)

getNode := func(id NodeID) ([]byte, error) {
return nodes[id], nil
getHashes := func(ids []NodeID) [][]byte {
hashes := make([][]byte, len(ids))
for i, id := range ids {
hashes[i] = nodes[id]
}
return hashes
}

for i := int64(0); i < 1024; i++ {
cmt, err := NewTreeWithState(
rfc6962.DefaultHasher,
imt.LeafCount(),
func(ids []NodeID) ([][]byte, error) {
hashes := make([][]byte, len(ids))
for i, id := range ids {
var err error
hashes[i], err = getNode(id)
if err != nil {
return nil, err
}
}
return hashes, nil
}, imt.CurrentRoot().Hash())

hashes := getHashes(TreeNodes(uint64(imt.LeafCount())))
cmt, err := NewTreeWithState(rfc6962.DefaultHasher, imt.LeafCount(), hashes, imt.CurrentRoot().Hash())
if err != nil {
t.Errorf("interation %d: failed to create CMT with state: %v", i, err)
}
Expand Down Expand Up @@ -284,6 +265,31 @@ func TestRootHashForVariousTreeSizes(t *testing.T) {
}
}

func TestTreeNodes(t *testing.T) {
for _, tc := range []struct {
size uint64
want []NodeID
}{
{size: 0, want: []NodeID{}},
{size: 1, want: []NodeID{{Level: 0, Index: 0}}},
{size: 2, want: []NodeID{{Level: 1, Index: 0}}},
{size: 3, want: []NodeID{{Level: 1, Index: 0}, {Level: 0, Index: 2}}},
{size: 4, want: []NodeID{{Level: 2, Index: 0}}},
{size: 5, want: []NodeID{{Level: 2, Index: 0}, {Level: 0, Index: 4}}},
{size: 15, want: []NodeID{{Level: 3, Index: 0}, {Level: 2, Index: 2}, {Level: 1, Index: 6}, {Level: 0, Index: 14}}},
{size: 100, want: []NodeID{{Level: 6, Index: 0}, {Level: 5, Index: 2}, {Level: 2, Index: 24}}},
{size: 513, want: []NodeID{{Level: 9, Index: 0}, {Level: 0, Index: 512}}},
{size: uint64(1) << 63, want: []NodeID{{Level: 63, Index: 0}}},
{size: (uint64(1) << 63) + (uint64(1) << 57), want: []NodeID{{Level: 63, Index: 0}, {Level: 57, Index: 64}}},
} {
t.Run(fmt.Sprintf("size:%d", tc.size), func(t *testing.T) {
if got, want := TreeNodes(tc.size), tc.want; !reflect.DeepEqual(got, tc.want) {
t.Fatalf("TreeNodes: got %v, want %v", got, want)
}
})
}
}

func benchmarkAppendLeaf(b *testing.B, visit VisitFn) {
b.Helper()
const size = 1024
Expand Down