diff --git a/integration/log.go b/integration/log.go index d0c87586c5..5245a6b643 100644 --- a/integration/log.go +++ b/integration/log.go @@ -65,9 +65,6 @@ func DefaultTestParameters(treeID int64) TestParameters { } } -// both sizes will be multiplied by the sequencer batch size before being passed to the log -// in cases where we're testing valid proofs -// TODO(Martin2112): This can be updated when the log supports proofs at arbitrary tree sizes type consistencyProofParams struct { size1 int64 size2 int64 @@ -157,19 +154,27 @@ func RunLogIntegration(client trillian.TrillianLogClient, params TestParameters) // Step 6 - Test some consistency proofs glog.Info("Testing consistency proofs") - // Make some proof requests that we know should not succeed + // Make some consistency proof requests that we know should not succeed for _, consistParams := range consistencyProofBadTestParams { - if err := checkConsistencyProof(consistParams, params.treeID, tree, client, params); err == nil { + if err := checkConsistencyProof(consistParams, params.treeID, tree, client, params, int64(params.queueBatchSize)); err == nil { return fmt.Errorf("log consistency for %v: unexpected proof returned", consistParams) } } - // Probe the log between some tree sizes we know are STHs and check the results against - // the in memory tree. + // Probe the log between some tree sizes we know are included and check the results against + // the in memory tree. Request proofs at both STH and non STH sizes unless batch size is one, + // when these would be equivalent requests. for _, consistParams := range consistencyProofTestParams { - if err := checkConsistencyProof(consistParams, params.treeID, tree, client, params); err != nil { + if err := checkConsistencyProof(consistParams, params.treeID, tree, client, params, int64(params.queueBatchSize)); err != nil { return fmt.Errorf("log consistency for %v: proof checks failed: %v", consistParams, err) } + + // Only do this if the batch size changes when halved + if params.queueBatchSize > 1 { + if err := checkConsistencyProof(consistParams, params.treeID, tree, client, params, int64(params.queueBatchSize/2)); err != nil { + return fmt.Errorf("log consistency for %v: proof checks failed (Non STH size): %v", consistParams, err) + } + } } return nil } @@ -385,8 +390,8 @@ func checkInclusionProofsAtIndex(index int64, logID int64, tree *merkle.InMemory resp, err := client.GetInclusionProof(ctx, &trillian.GetInclusionProofRequest{LogId: logID, LeafIndex: index, TreeSize: int64(treeSize)}) cancel() + // If the index is larger than the tree size we cannot have a valid proof if index >= treeSize { - // If the index is larger than the tree size we cannot have a valid proof if err == nil { return fmt.Errorf("log returned proof for index: %d, tree is only size %d", index, treeSize) } @@ -394,38 +399,31 @@ func checkInclusionProofsAtIndex(index int64, logID int64, tree *merkle.InMemory continue } - // If we're not at a valid STH tree size then we can't have a proof - if treeSize == 0 || (treeSize%int64(params.sequencerBatchSize)) != 0 { - if err == nil { - return fmt.Errorf("log returned proof at non STH size: %d", treeSize) - } - } else { - // Otherwise we should have a proof, to be compared against our memory tree - if err != nil || resp.Status.StatusCode != trillian.TrillianApiStatusCode_OK { - return fmt.Errorf("log returned no proof for index %d at size %d, which should have succeeded: %v", index, treeSize, err) - } + // Otherwise we should have a proof, to be compared against our memory 
tree + if err != nil || resp.Status.StatusCode != trillian.TrillianApiStatusCode_OK { + return fmt.Errorf("log returned no proof for index %d at size %d, which should have succeeded: %v", index, treeSize, err) + } - // Remember that the in memory tree uses 1 based leaf indices - path := tree.PathToRootAtSnapshot(index+1, treeSize) + // Remember that the in memory tree uses 1 based leaf indices + path := tree.PathToRootAtSnapshot(index+1, treeSize) - if err = compareLogAndTreeProof(resp.Proof, path); err != nil { - // The log and tree proof don't match, details in the error - return err - } + if err = compareLogAndTreeProof(resp.Proof, path); err != nil { + // The log and tree proof don't match, details in the error + return err } } return nil } -func checkConsistencyProof(consistParams consistencyProofParams, treeID int64, tree *merkle.InMemoryMerkleTree, client trillian.TrillianLogClient, params TestParameters) error { +func checkConsistencyProof(consistParams consistencyProofParams, treeID int64, tree *merkle.InMemoryMerkleTree, client trillian.TrillianLogClient, params TestParameters, batchSize int64) error { // We expect the proof request to succeed ctx, cancel := getRPCDeadlineContext(params) resp, err := client.GetConsistencyProof(ctx, &trillian.GetConsistencyProofRequest{ LogId: treeID, - FirstTreeSize: consistParams.size1 * int64(params.sequencerBatchSize), - SecondTreeSize: (consistParams.size2 * int64(params.sequencerBatchSize)), + FirstTreeSize: consistParams.size1 * int64(batchSize), + SecondTreeSize: (consistParams.size2 * int64(batchSize)), }) cancel() @@ -435,8 +433,8 @@ func checkConsistencyProof(consistParams consistencyProofParams, treeID int64, t // Get the proof from the memory tree proof := tree.SnapshotConsistency( - (consistParams.size1 * int64(params.sequencerBatchSize)), - (consistParams.size2 * int64(params.sequencerBatchSize))) + (consistParams.size1 * int64(batchSize)), + (consistParams.size2 * int64(batchSize))) // Compare the proofs, they should be identical return compareLogAndTreeProof(resp.Proof, proof) diff --git a/merkle/merkle_path.go b/merkle/merkle_path.go index dc0dc128e0..5ea3fc400c 100644 --- a/merkle/merkle_path.go +++ b/merkle/merkle_path.go @@ -15,14 +15,16 @@ package merkle import ( + "errors" "fmt" "github.com/golang/glog" "github.com/google/trillian/storage" ) -// Verbosity level for logging of debug related items +// Verbosity levels for logging of debug related items const vLevel = 2 +const vvLevel = 4 // NodeFetch bundles a nodeID with additional information on how to use the node to construct the // correct proof. @@ -37,35 +39,37 @@ func (n NodeFetch) Equivalent(other NodeFetch) bool { } // CalcInclusionProofNodeAddresses returns the tree node IDs needed to -// build an inclusion proof for a specified leaf and tree size. The maxBitLen parameter +// build an inclusion proof for a specified leaf and tree size. The snapshot parameter +// is the tree size being queried for, treeSize is the actual size of the tree at the revision +// we are using to fetch nodes (this can be > snapshot). The maxBitLen parameter // is copied into all the returned nodeIDs. 
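An illustrative usage sketch (not part of this patch) of the two calculators as changed here, assuming a tree whose stored size is 7 while proofs are requested at smaller snapshot sizes; the sizes, leaf index and import path follow the conventions used elsewhere in this diff:

package main

import (
	"fmt"

	"github.com/google/trillian/merkle"
)

func main() {
	// Inclusion proof for leaf 3 at snapshot 5; node addresses are computed
	// against the stored tree size 7, so some fetches may carry the Rehash flag.
	incl, err := merkle.CalcInclusionProofNodeAddresses(5, 3, 7, 64)
	if err != nil {
		panic(err)
	}

	// Consistency proof between snapshots 4 and 6 against the same stored size.
	cons, err := merkle.CalcConsistencyProofNodeAddresses(4, 6, 7, 64)
	if err != nil {
		panic(err)
	}

	fmt.Println(len(incl), len(cons))
}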
-func CalcInclusionProofNodeAddresses(treeSize, index int64, maxBitLen int) ([]NodeFetch, error) { - if index >= treeSize || index < 0 || treeSize < 1 || maxBitLen < 0 { - return []NodeFetch{}, fmt.Errorf("invalid params ts: %d index: %d, bitlen:%d", treeSize, index, maxBitLen) +func CalcInclusionProofNodeAddresses(snapshot, index, treeSize int64, maxBitLen int) ([]NodeFetch, error) { + if snapshot > treeSize || index >= snapshot || index < 0 || snapshot < 1 || maxBitLen <= 0 { + return nil, fmt.Errorf("invalid params s: %d index: %d ts: %d, bitlen:%d", snapshot, index, treeSize, maxBitLen) } - return pathFromNodeToRootAtSnapshot(index, 0, treeSize, maxBitLen) + return pathFromNodeToRootAtSnapshot(index, 0, snapshot, treeSize, maxBitLen) } // CalcConsistencyProofNodeAddresses returns the tree node IDs needed to -// build a consistency proof between two specified tree sizes. The maxBitLen parameter -// is copied into all the returned nodeIDs. The caller is responsible for checking that +// build a consistency proof between two specified tree sizes. snapshot1 and snapshot2 represent +// the two tree sizes for which consistency should be proved, treeSize is the actual size of the +// tree at the revision we are using to fetch nodes (this can be > snapshot2). The maxBitLen +// parameter is copied into all the returned nodeIDs. The caller is responsible for checking that // the input tree sizes correspond to valid tree heads. All returned NodeIDs are tree // coordinates within the new tree. It is assumed that they will be fetched from storage // at a revision corresponding to the STH associated with the treeSize parameter. -func CalcConsistencyProofNodeAddresses(previousTreeSize, treeSize int64, maxBitLen int) ([]NodeFetch, error) { - if previousTreeSize > treeSize || previousTreeSize < 1 || treeSize < 1 || maxBitLen <= 0 { - return []NodeFetch{}, fmt.Errorf("invalid params prior: %d treesize: %d, bitlen:%d", previousTreeSize, treeSize, maxBitLen) +func CalcConsistencyProofNodeAddresses(snapshot1, snapshot2, treeSize int64, maxBitLen int) ([]NodeFetch, error) { + if snapshot1 > snapshot2 || snapshot1 > treeSize || snapshot2 > treeSize || snapshot1 < 1 || snapshot2 < 1 || maxBitLen <= 0 { + return nil, fmt.Errorf("invalid params s1: %d s2: %d ts: %d, bitlen:%d", snapshot1, snapshot2, treeSize, maxBitLen) } - return snapshotConsistency(previousTreeSize, treeSize, maxBitLen) + return snapshotConsistency(snapshot1, snapshot2, treeSize, maxBitLen) } // snapshotConsistency does the calculation of consistency proof node addresses between // two snapshots. Based on the C++ code used by CT but adjusted to fit our situation. -// In particular the code does not need to handle the case where overwritten node hashes -// must be recursively computed because we have versioned nodes. -func snapshotConsistency(snapshot1, snapshot2 int64, maxBitLen int) ([]NodeFetch, error) { +func snapshotConsistency(snapshot1, snapshot2, treeSize int64, maxBitLen int) ([]NodeFetch, error) { proof := make([]NodeFetch, 0, bitLen(snapshot2)+1) glog.V(vLevel).Infof("snapshotConsistency: %d -> %d", snapshot1, snapshot2) @@ -76,13 +80,13 @@ func snapshotConsistency(snapshot1, snapshot2 int64, maxBitLen int) ([]NodeFetch // Compute the (compressed) path to the root of snapshot2. // Everything left of 'node' is equal in both trees; no need to record. 
for (node & 1) != 0 { - glog.V(vLevel).Infof("Move up: l:%d n:%d", level, node) + glog.V(vvLevel).Infof("Move up: l:%d n:%d", level, node) node >>= 1 level++ } if node != 0 { - glog.V(vLevel).Infof("Not root snapshot1: %d", node) + glog.V(vvLevel).Infof("Not root snapshot1: %d", node) // Not at the root of snapshot 1, record the node n, err := storage.NewNodeIDForTreeCoords(int64(level), node, maxBitLen) if err != nil { @@ -92,15 +96,15 @@ func snapshotConsistency(snapshot1, snapshot2 int64, maxBitLen int) ([]NodeFetch } // Now append the path from this node to the root of snapshot2. - p, err := pathFromNodeToRootAtSnapshot(node, level, snapshot2, maxBitLen) + p, err := pathFromNodeToRootAtSnapshot(node, level, snapshot2, treeSize, maxBitLen) if err != nil { return nil, err } return append(proof, p...), nil } -func pathFromNodeToRootAtSnapshot(node int64, level int, snapshot int64, maxBitLen int) ([]NodeFetch, error) { - glog.V(vLevel).Infof("pathFromNodeToRootAtSnapshot: N:%d, L:%d, S:%d", node, level, snapshot) +func pathFromNodeToRootAtSnapshot(node int64, level int, snapshot, treeSize int64, maxBitLen int) ([]NodeFetch, error) { + glog.V(vLevel).Infof("pathFromNodeToRootAtSnapshot: N:%d, L:%d, S:%d TS:%d", node, level, snapshot, treeSize) proof := make([]NodeFetch, 0, bitLen(snapshot)+1) if snapshot == 0 { @@ -115,7 +119,7 @@ func pathFromNodeToRootAtSnapshot(node int64, level int, snapshot int64, maxBitL sibling := node ^ 1 if sibling < lastNode { // The sibling is not the last node of the level in the snapshot tree - glog.V(vLevel).Infof("Not last: S:%d L:%d", sibling, level) + glog.V(vvLevel).Infof("Not last: S:%d L:%d", sibling, level) n, err := storage.NewNodeIDForTreeCoords(int64(level), sibling, maxBitLen) if err != nil { return nil, err @@ -123,24 +127,36 @@ func pathFromNodeToRootAtSnapshot(node int64, level int, snapshot int64, maxBitL proof = append(proof, NodeFetch{NodeID: n}) } else if sibling == lastNode { // The sibling is the last node of the level in the snapshot tree. - // In the C++ code we'd potentially recompute the node value here because we could be - // referencing a snapshot at a point before additional leaves were added to the tree causing - // some nodes to be overwritten. We have versioned tree nodes so this isn't necessary, - // we won't see any hashes written since the snapshot point. However we do have to account - // for missing levels in the tree. This can only occur on the rightmost tree nodes because - // this is the only area of the tree that is not fully populated. - glog.V(vLevel).Infof("Last: S:%d L:%d", sibling, level) - - // Account for non existent nodes - these can only be the rightmost node at an - // intermediate (non leaf) level in the tree so will always be a right sibling. - l, sibling := skipMissingLevels(snapshot, lastNode, level, node) - n, err := storage.NewNodeIDForTreeCoords(int64(l), sibling, maxBitLen) - if err != nil { - return nil, err + // We might need to recompute a previous hash value here. This can only occur on the + // rightmost tree nodes because this is the only area of the tree that is not fully populated. + glog.V(vvLevel).Infof("Last: S:%d L:%d", sibling, level) + + if snapshot == treeSize { + // No recomputation required as we're using the tree in its current state + // Account for non existent nodes - these can only be the rightmost node at an + // intermediate (non leaf) level in the tree so will always be a right sibling. 
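A purely arithmetic sketch (not part of this patch) of why the rightmost nodes may need recomputation when the requested snapshot is smaller than the stored tree: with the RFC 6962 hasher used elsewhere in this diff, the subtree hash over leaves 4..5 that a snapshot-6 proof needs is not the hash over leaves 4..6 that exists once a seventh leaf has been added, so the smaller snapshot's node has to be rebuilt from stored pieces. The leaf values are arbitrary:

	th := merkle.NewRFC6962TreeHasher(crypto.NewSHA256())
	l4, l5, l6 := th.HashLeaf([]byte("d")), th.HashLeaf([]byte("e")), th.HashLeaf([]byte("f"))

	// Node a snapshot-6 proof wants: the subtree over leaves 4..5.
	atSnapshot6 := th.HashChildren(l4, l5)
	// Subtree hash over leaves 4..6 once the tree has grown to size 7.
	atSize7 := th.HashChildren(th.HashChildren(l4, l5), l6)
	// atSnapshot6 != atSize7, hence the Rehash fetches recorded by the code above.
	_ = atSnapshot6
	_ = atSize7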
+ n, err := siblingIDSkipLevels(snapshot, lastNode, level, node, maxBitLen) + if err != nil { + return nil, err + } + proof = append(proof, NodeFetch{NodeID: n}) + } else { + // We need to recompute this node, as it was at the prior snapshot point. We record + // the additional fetches needed to do this later + rehashFetches, err := recomputePastSnapshot(snapshot, treeSize, level, maxBitLen) + if err != nil { + return nil, err + } + + // Extra check that the recomputation produced one node + if err = checkRecomputation(rehashFetches); err != nil { + return nil, err + } + + proof = append(proof, rehashFetches...) } - proof = append(proof, NodeFetch{NodeID: n}) } else { - glog.V(vLevel).Infof("Nonexistent: S:%d L:%d", sibling, level) + glog.V(vvLevel).Infof("Nonexistent: S:%d L:%d", sibling, level) } // Sibling > lastNode so does not exist, move up @@ -152,6 +168,117 @@ func pathFromNodeToRootAtSnapshot(node int64, level int, snapshot int64, maxBitL return proof, nil } +// recomputePastSnapshot does the work to recalculate nodes that need to be rehashed because the +// tree state at the snapshot size differs from the size we've stored it at. The calculations +// also need to take into account missing levels, see the tree diagrams in this file. +// If called with snapshot equal to the tree size returns empty. Otherwise, assuming no errors, +// the output of this should always be exactly one node after resolving any rehashing. +// Either a copy of one of the nodes in the tree or a rehashing of multiple nodes to a single +// result node with the value it would have had if the prior snapshot had been stored. +func recomputePastSnapshot(snapshot, treeSize int64, nodeLevel int, maxBitlen int) ([]NodeFetch, error) { + glog.V(vLevel).Infof("recompute s:%d ts:%d level:%d", snapshot, treeSize, nodeLevel) + + fetches := []NodeFetch{} + + if snapshot == treeSize { + // Nothing to do + return nil, nil + } else if snapshot > treeSize { + return nil, fmt.Errorf("recomputePastSnapshot: %d does not exist for tree of size %d", snapshot, treeSize) + } + + // We're recomputing the right hand path, the one to the last leaf + level := 0 + // This is the index of the last node in the snapshot + lastNode := snapshot - 1 + // This is the index of the last node that actually exists in the underlying tree + lastNodeAtLevel := treeSize - 1 + + // Work up towards the root. We may find the node we need without needing to rehash if + // it turns out that the tree is complete up to the level we're recalculating at this + // snapshot. + for (lastNode & 1) != 0 { + if nodeLevel == level { + // Then we want a copy of the node at this level + glog.V(vvLevel).Infof("copying l:%d ln:%d", level, lastNode) + nodeID, err := siblingIDSkipLevels(snapshot, lastNodeAtLevel, level, lastNode^1, maxBitlen) + if err != nil { + return nil, err + } + + glog.V(vvLevel).Infof("copy node at %s", nodeID.CoordString()) + return append(fetches, NodeFetch{Rehash: false, NodeID: nodeID}), nil + } + + // Left sibling and parent exist at this snapshot and don't need to be rehashed + glog.V(vvLevel).Infof("move up ln:%d level:%d", lastNode, level) + lastNode >>= 1 + lastNodeAtLevel >>= 1 + level++ + } + + glog.V(vvLevel).Infof("done ln:%d level:%d", lastNode, level) + + // lastNode is now the index of a left sibling with no right sibling. 
This is where the + // rehashing starts + savedNodeID, err := siblingIDSkipLevels(snapshot, lastNodeAtLevel, level, lastNode^1, maxBitlen) + glog.V(vvLevel).Infof("root for recompute is: %s", savedNodeID.CoordString()) + if err != nil { + return nil, err + } + + if nodeLevel == level { + glog.V(vvLevel).Info("emit root (1)") + return append(fetches, NodeFetch{Rehash: true, NodeID: savedNodeID}), nil + } + + rehash := false + subRootEmitted := false // whether we've added the recomputed subtree root to the path yet + + // Move towards the tree root (increasing level). Exit when we reach the root or the + // level that is being recomputed. Defer emitting the subtree root to the path until + // the appropriate point because we don't immediately know whether it's part of the + // rehashing. + for lastNode != 0 { + glog.V(vvLevel).Infof("in loop level:%d ln:%d lnal:%d", level, lastNode, lastNodeAtLevel) + + if (lastNode & 1) != 0 { + nodeID, err := siblingIDSkipLevels(snapshot, lastNodeAtLevel, level, (lastNode-1)^1, maxBitlen) + if err != nil { + return nil, err + } + + if !rehash && !subRootEmitted { + glog.V(vvLevel).Info("emit root (2)") + fetches = append(fetches, NodeFetch{Rehash: true, NodeID: savedNodeID}) + subRootEmitted = true + } + + glog.V(vvLevel).Infof("rehash with %s", nodeID.CoordString()) + fetches = append(fetches, NodeFetch{Rehash: true, NodeID: nodeID}) + rehash = true + } + + lastNode >>= 1 + lastNodeAtLevel >>= 1 + level++ + + if nodeLevel == level && !subRootEmitted { + glog.V(vvLevel).Info("emit root (3)") + return append(fetches, NodeFetch{Rehash: rehash, NodeID: savedNodeID}), nil + } + + // Exit early if we've gone far enough up the tree to hit the level we're recomputing + if level == nodeLevel { + glog.V(vvLevel).Infof("returning fetches early: %v", fetches) + return fetches, nil + } + } + + glog.V(vvLevel).Infof("returning fetches: %v", fetches) + return fetches, nil +} + // lastNodeWritten determines if the last node is present in storage for a given Merkle tree size // and level in the tree (0 = leaves, increasing towards the root). This is determined by // examining the bits of the last valid leaf index in a tree of the specified size. Zero bits @@ -245,8 +372,39 @@ func skipMissingLevels(snapshot, lastNode int64, level int, node int64) (int, in level-- sibling *= 2 lastNode = (snapshot - 1) >> uint(level) - glog.V(vLevel).Infof("Move down: S:%d L:%d LN:%d", sibling, level, lastNode) + glog.V(vvLevel).Infof("Move down: S:%d L:%d LN:%d", sibling, level, lastNode) } return level, sibling } + +// checkRecomputation carries out an additional check that the results of recomputePastSnapshot +// are valid. There must be at least one fetch. All fetches must have the same rehash state and if +// there is only one fetch then it must not be a rehash. If all checks pass then the fetches +// represent one node after rehashing is completed. +func checkRecomputation(fetches []NodeFetch) error { + switch len(fetches) { + case 0: + return errors.New("recomputePastSnapshot returned nothing") + case 1: + if fetches[0].Rehash { + return fmt.Errorf("recomputePastSnapshot returned invalid rehash: %v", fetches) + } + default: + for i := range fetches { + if i > 0 && fetches[i].Rehash != fetches[0].Rehash { + return fmt.Errorf("recomputePastSnapshot returned mismatched rehash nodes: %v", fetches) + } + } + } + + return nil +} + +// siblingIDSkipLevels creates a new NodeID for the supplied node, accounting for levels skipped +// in storage. 
Note that it returns an ID for the node sibling so care should be taken to pass the +// correct value for the node parameter. +func siblingIDSkipLevels(snapshot, lastNode int64, level int, node int64, maxBitLen int) (storage.NodeID, error) { + l, sibling := skipMissingLevels(snapshot, lastNode, level, node) + return storage.NewNodeIDForTreeCoords(int64(l), sibling, maxBitLen) +} diff --git a/merkle/merkle_path_test.go b/merkle/merkle_path_test.go index 27287a3745..dbf5a0894e 100644 --- a/merkle/merkle_path_test.go +++ b/merkle/merkle_path_test.go @@ -250,7 +250,7 @@ func TestBitLen(t *testing.T) { func TestCalcInclusionProofNodeAddresses(t *testing.T) { for _, testCase := range pathTests { - path, err := CalcInclusionProofNodeAddresses(testCase.treeSize, testCase.leafIndex, 64) + path, err := CalcInclusionProofNodeAddresses(testCase.treeSize, testCase.leafIndex, testCase.treeSize, 64) if err != nil { t.Fatalf("unexpected error calculating path %v: %v", testCase, err) @@ -262,7 +262,7 @@ func TestCalcInclusionProofNodeAddresses(t *testing.T) { func TestCalcInclusionProofNodeAddressesBadRanges(t *testing.T) { for _, testCase := range pathTestBad { - _, err := CalcInclusionProofNodeAddresses(testCase.treeSize, testCase.leafIndex, 64) + _, err := CalcInclusionProofNodeAddresses(testCase.treeSize, testCase.leafIndex, testCase.treeSize, 64) if err == nil { t.Fatalf("incorrectly accepted bad params: %v", testCase) @@ -271,7 +271,7 @@ func TestCalcInclusionProofNodeAddressesBadRanges(t *testing.T) { } func TestCalcInclusionProofNodeAddressesRejectsBadBitLen(t *testing.T) { - _, err := CalcInclusionProofNodeAddresses(7, 3, -64) + _, err := CalcInclusionProofNodeAddresses(7, 3, 7, -64) if err == nil { t.Fatal("incorrectly accepted -ve maxBitLen") @@ -280,7 +280,7 @@ func TestCalcInclusionProofNodeAddressesRejectsBadBitLen(t *testing.T) { func TestCalcConsistencyProofNodeAddresses(t *testing.T) { for _, testCase := range consistencyTests { - proof, err := CalcConsistencyProofNodeAddresses(testCase.priorTreeSize, testCase.treeSize, 64) + proof, err := CalcConsistencyProofNodeAddresses(testCase.priorTreeSize, testCase.treeSize, testCase.treeSize, 64) if err != nil { t.Fatalf("failed to calculate consistency proof from %d to %d: %v", testCase.priorTreeSize, testCase.treeSize, err) @@ -292,7 +292,7 @@ func TestCalcConsistencyProofNodeAddresses(t *testing.T) { func TestCalcConsistencyProofNodeAddressesBadInputs(t *testing.T) { for _, testCase := range consistencyTestsBad { - _, err := CalcConsistencyProofNodeAddresses(testCase.priorTreeSize, testCase.treeSize, 64) + _, err := CalcConsistencyProofNodeAddresses(testCase.priorTreeSize, testCase.treeSize, testCase.treeSize, 64) if err == nil { t.Fatalf("consistency path calculation accepted bad input: %v", testCase) @@ -301,8 +301,8 @@ func TestCalcConsistencyProofNodeAddressesBadInputs(t *testing.T) { } func TestCalcConsistencyProofNodeAddressesRejectsBadBitLen(t *testing.T) { - _, err := CalcConsistencyProofNodeAddresses(6, 7, -1) - _, err2 := CalcConsistencyProofNodeAddresses(6, 7, 0) + _, err := CalcConsistencyProofNodeAddresses(6, 7, 7, -1) + _, err2 := CalcConsistencyProofNodeAddresses(6, 7, 7, 0) if err == nil || err2 == nil { t.Fatalf("consistency path calculation accepted bad bitlen: %v %v", err, err2) @@ -341,7 +341,7 @@ func TestLastNodeWritten(t *testing.T) { func TestInclusionSucceedsUpToTreeSize(t *testing.T) { for ts := 1; ts < testUpToTreeSize; ts++ { for i := ts; i < ts; i++ { - if _, err := CalcInclusionProofNodeAddresses(int64(ts), 
int64(i), 64); err != nil { + if _, err := CalcInclusionProofNodeAddresses(int64(ts), int64(i), int64(ts), 64); err != nil { t.Errorf("CalcInclusionProofNodeAddresses(ts:%d, i:%d) = %v", ts, i, err) } } @@ -351,7 +351,7 @@ func TestInclusionSucceedsUpToTreeSize(t *testing.T) { func TestConsistencySucceedsUpToTreeSize(t *testing.T) { for s1 := 1; s1 < testUpToTreeSize; s1++ { for s2 := s1 + 1; s2 < testUpToTreeSize; s2++ { - if _, err := CalcConsistencyProofNodeAddresses(int64(s1), int64(s2), 64); err != nil { + if _, err := CalcConsistencyProofNodeAddresses(int64(s1), int64(s2), int64(s2), 64); err != nil { t.Errorf("CalcConsistencyProofNodeAddresses(%d, %d) = %v", s1, s2, err) } } diff --git a/server/log_rpc_server.go b/server/log_rpc_server.go index 2698318bed..7d9554dbea 100644 --- a/server/log_rpc_server.go +++ b/server/log_rpc_server.go @@ -15,7 +15,6 @@ package server import ( - "errors" "fmt" "github.com/google/trillian" @@ -35,9 +34,6 @@ import ( // TODO: There is no access control in the server yet and clients could easily modify // any tree. -// TODO(Martin2112): Remove this when the feature is fully implemented -var errRehashNotSupported = errors.New("proof request requires rehash but it's not implemented yet") - // Pass this as a fixed value to proof calculations. It's used as the max depth of the tree const proofMaxBitLen = 64 @@ -137,14 +133,7 @@ func (t *TrillianLogRPCServer) GetInclusionProof(ctx context.Context, req *trill return nil, err } - // TODO(Martin2112): Pass tree size as snapshot size to proof recomputation when implemented - // and remove this check. - if treeSize != req.TreeSize { - tx.Rollback() - return nil, errRehashNotSupported - } - - proof, err := getInclusionProofForLeafIndexAtRevision(tx, treeRevision, req.TreeSize, req.LeafIndex) + proof, err := getInclusionProofForLeafIndexAtRevision(tx, req.TreeSize, treeRevision, treeSize, req.LeafIndex) if err != nil { tx.Rollback() return nil, err @@ -186,12 +175,6 @@ func (t *TrillianLogRPCServer) GetInclusionProofByHash(ctx context.Context, req return nil, err } - // TODO(Martin2112): Pass tree size as snapshot size to proof recomputation when implemented - // and remove this check. - if treeSize != req.TreeSize { - return nil, errRehashNotSupported - } - // Find the leaf index of the supplied hash leafHashes := [][]byte{req.LeafHash} leaves, err := tx.GetLeavesByHash(leafHashes, req.OrderBySequence) @@ -203,7 +186,7 @@ func (t *TrillianLogRPCServer) GetInclusionProofByHash(ctx context.Context, req // TODO(Martin2112): Need to define a limit on number of results or some form of paging etc. 
proofs := make([]*trillian.Proof, 0, len(leaves)) for _, leaf := range leaves { - proof, err := getInclusionProofForLeafIndexAtRevision(tx, treeRevision, req.TreeSize, leaf.LeafIndex) + proof, err := getInclusionProofForLeafIndexAtRevision(tx, req.TreeSize, treeRevision, treeSize, leaf.LeafIndex) if err != nil { tx.Rollback() return nil, err @@ -239,46 +222,25 @@ func (t *TrillianLogRPCServer) GetConsistencyProof(ctx context.Context, req *tri return nil, fmt.Errorf("%s: second tree size (%d) must be > first tree size (%d)", util.LogIDPrefix(ctx), req.SecondTreeSize, req.FirstTreeSize) } - nodeIDs, err := merkle.CalcConsistencyProofNodeAddresses(req.FirstTreeSize, req.SecondTreeSize, proofMaxBitLen) - if err != nil { - return nil, err - } - tx, err := t.prepareReadOnlyStorageTx(ctx, req.LogId) if err != nil { return nil, err } - // We need to make sure that both the given sizes are actually STHs, though we don't use the - // first tree revision in fetches - // TODO(Martin2112): This fetch can be removed when rehashing is implemented - _, firstTreeSize, err := tx.GetTreeRevisionIncludingSize(req.FirstTreeSize) + secondTreeRevision, secondTreeSize, err := tx.GetTreeRevisionIncludingSize(req.SecondTreeSize) if err != nil { tx.Rollback() return nil, err } - // TODO(Martin2112): Pass tree size as snapshot size to proof recomputation when implemented - // and remove this check. - if firstTreeSize != req.FirstTreeSize { - return nil, errRehashNotSupported - } - - secondTreeRevision, secondTreeSize, err := tx.GetTreeRevisionIncludingSize(req.SecondTreeSize) + nodeFetches, err := merkle.CalcConsistencyProofNodeAddresses(req.FirstTreeSize, req.SecondTreeSize, secondTreeSize, proofMaxBitLen) if err != nil { - tx.Rollback() return nil, err } - // TODO(Martin2112): Pass tree size as snapshot size to proof recomputation when implemented - // and remove this check. - if secondTreeSize != req.SecondTreeSize { - return nil, errRehashNotSupported - } - // Do all the node fetches at the second tree revision, which is what the node ids were calculated // against. - proof, err := fetchNodesAndBuildProof(tx, secondTreeRevision, 0, nodeIDs) + proof, err := fetchNodesAndBuildProof(tx, secondTreeRevision, 0, nodeFetches) if err != nil { tx.Rollback() return nil, err @@ -404,13 +366,7 @@ func (t *TrillianLogRPCServer) GetEntryAndProof(ctx context.Context, req *trilli return nil, err } - // TODO(Martin2112): Pass tree size as snapshot size to proof recomputation when implemented - // and remove this check. - if treeSize != req.TreeSize { - return nil, errRehashNotSupported - } - - proof, err := getInclusionProofForLeafIndexAtRevision(tx, treeRevision, req.TreeSize, req.LeafIndex) + proof, err := getInclusionProofForLeafIndexAtRevision(tx, req.TreeSize, treeRevision, treeSize, req.LeafIndex) if err != nil { tx.Rollback() return nil, err @@ -529,10 +485,9 @@ func validateLeafHashes(leafHashes [][]byte) bool { // getInclusionProofForLeafIndexAtRevision is used by multiple handlers. It does the storage fetching // and makes additional checks on the returned proof. 
Returns a Proof suitable for inclusion in // an RPC response -func getInclusionProofForLeafIndexAtRevision(tx storage.ReadOnlyLogTX, treeRevision, treeSize, leafIndex int64) (trillian.Proof, error) { +func getInclusionProofForLeafIndexAtRevision(tx storage.ReadOnlyLogTX, snapshot, treeRevision, treeSize, leafIndex int64) (trillian.Proof, error) { // We have the tree size and leaf index so we know the nodes that we need to serve the proof - // TODO(Martin2112): Not sure about hardcoding maxBitLen here - proofNodeIDs, err := merkle.CalcInclusionProofNodeAddresses(treeSize, leafIndex, proofMaxBitLen) + proofNodeIDs, err := merkle.CalcInclusionProofNodeAddresses(snapshot, leafIndex, treeSize, proofMaxBitLen) if err != nil { return trillian.Proof{}, err } diff --git a/server/log_rpc_server_test.go b/server/log_rpc_server_test.go index be2c3c3506..76e14b8d57 100644 --- a/server/log_rpc_server_test.go +++ b/server/log_rpc_server_test.go @@ -714,8 +714,8 @@ func TestGetProofByHashWrongNodeReturned(t *testing.T) { _, err := server.GetInclusionProofByHash(context.Background(), &getInclusionProofByHashRequest7) - if err == nil || !strings.Contains(err.Error(), "expected node") || !strings.Contains(err.Error(), "at proof pos 1") { - t.Fatalf("get inclusion proof by hash returned no or wrong error when get nodes returns wrong count: %v", err) + if err == nil || !strings.Contains(err.Error(), "expected node ") { + t.Fatalf("get inclusion proof by hash returned no or wrong error when get nodes returns wrong node: %v", err) } } @@ -891,8 +891,8 @@ func TestGetProofByIndexWrongNodeReturned(t *testing.T) { _, err := server.GetInclusionProof(context.Background(), &getInclusionProofByIndexRequest7) - if err == nil || !strings.Contains(err.Error(), "expected node") || !strings.Contains(err.Error(), "at proof pos 1") { - t.Fatalf("get inclusion proof by index returned no or wrong error when get nodes returns wrong count: %v", err) + if err == nil || !strings.Contains(err.Error(), "expected node ") { + t.Fatalf("get inclusion proof by index returned no or wrong error when get nodes returns wrong node: %v", err) } } @@ -1274,29 +1274,12 @@ func TestGetConsistencyProofBeginTXFails(t *testing.T) { test.executeBeginFailsTest(t) } -func TestGetConsistencyProofGetTreeRevision1Fails(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - test := newParameterizedTest(ctrl, "GetConsistencyProof", readOnly, - func(t *storage.MockLogTX) { - t.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest25.FirstTreeSize).Return(int64(0), int64(0), errors.New("STORAGE")) - }, - func(s *TrillianLogRPCServer) error { - _, err := s.GetConsistencyProof(context.Background(), &getConsistencyProofRequest25) - return err - }) - - test.executeStorageFailureTest(t) -} - -func TestGetConsistencyProofGetTreeRevision2Fails(t *testing.T) { +func TestGetConsistencyProofGetTreeRevisionForSecondTreeSizeFails(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() test := newParameterizedTest(ctrl, "GetConsistencyProof", readOnly, func(t *storage.MockLogTX) { - t.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest25.FirstTreeSize).Return(int64(11), getConsistencyProofRequest25.FirstTreeSize, nil) t.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest25.SecondTreeSize).Return(int64(0), int64(0), errors.New("STORAGE")) }, func(s *TrillianLogRPCServer) error { @@ -1313,7 +1296,6 @@ func TestGetConsistencyProofGetNodesFails(t *testing.T) { test := newParameterizedTest(ctrl, 
"GetConsistencyProof", readOnly, func(t *storage.MockLogTX) { - t.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.FirstTreeSize).Return(int64(3), getConsistencyProofRequest7.FirstTreeSize, nil) t.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.SecondTreeSize).Return(int64(5), getConsistencyProofRequest7.SecondTreeSize, nil) t.EXPECT().GetMerkleNodes(int64(5), nodeIdsConsistencySize4ToSize7).Return([]storage.Node{}, errors.New("STORAGE")) }, @@ -1333,7 +1315,6 @@ func TestGetConsistencyProofGetNodesReturnsWrongCount(t *testing.T) { mockTx := storage.NewMockLogTX(ctrl) mockStorage.EXPECT().Snapshot(gomock.Any()).Return(mockTx, nil) - mockTx.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.FirstTreeSize).Return(int64(3), getConsistencyProofRequest7.FirstTreeSize, nil) mockTx.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.SecondTreeSize).Return(int64(5), getConsistencyProofRequest7.SecondTreeSize, nil) // The server expects one node from storage but we return two mockTx.EXPECT().GetMerkleNodes(int64(5), nodeIdsConsistencySize4ToSize7).Return([]storage.Node{{NodeRevision: 3}, {NodeRevision: 2}}, nil) @@ -1357,7 +1338,6 @@ func TestGetConsistencyProofGetNodesReturnsWrongNode(t *testing.T) { mockTx := storage.NewMockLogTX(ctrl) mockStorage.EXPECT().Snapshot(gomock.Any()).Return(mockTx, nil) - mockTx.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.FirstTreeSize).Return(int64(3), getConsistencyProofRequest7.FirstTreeSize, nil) mockTx.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.SecondTreeSize).Return(int64(5), getConsistencyProofRequest7.SecondTreeSize, nil) // Return an unexpected node that wasn't requested mockTx.EXPECT().GetMerkleNodes(int64(5), nodeIdsConsistencySize4ToSize7).Return([]storage.Node{{NodeID: testonly.MustCreateNodeIDForTreeCoords(1, 2, 64), NodeRevision: 3}}, nil) @@ -1368,7 +1348,7 @@ func TestGetConsistencyProofGetNodesReturnsWrongNode(t *testing.T) { _, err := server.GetConsistencyProof(context.Background(), &getConsistencyProofRequest7) - if err == nil || !strings.Contains(err.Error(), "at proof pos 0") { + if err == nil || !strings.Contains(err.Error(), "expected node ") { t.Fatalf("get consistency proof returned no or wrong error when get nodes returns wrong node: %v", err) } } @@ -1379,7 +1359,6 @@ func TestGetConsistencyProofCommitFails(t *testing.T) { test := newParameterizedTest(ctrl, "GetConsistencyProof", readOnly, func(t *storage.MockLogTX) { - t.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.FirstTreeSize).Return(int64(3), getConsistencyProofRequest7.FirstTreeSize, nil) t.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.SecondTreeSize).Return(int64(5), getConsistencyProofRequest7.SecondTreeSize, nil) t.EXPECT().GetMerkleNodes(int64(5), nodeIdsConsistencySize4ToSize7).Return([]storage.Node{{NodeID: testonly.MustCreateNodeIDForTreeCoords(2, 1, 64), NodeRevision: 3}}, nil) }, @@ -1399,7 +1378,6 @@ func TestGetConsistencyProof(t *testing.T) { mockTx := storage.NewMockLogTX(ctrl) mockStorage.EXPECT().Snapshot(gomock.Any()).Return(mockTx, nil) - mockTx.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.FirstTreeSize).Return(int64(3), getConsistencyProofRequest7.FirstTreeSize, nil) mockTx.EXPECT().GetTreeRevisionIncludingSize(getConsistencyProofRequest7.SecondTreeSize).Return(int64(5), getConsistencyProofRequest7.SecondTreeSize, nil) mockTx.EXPECT().GetMerkleNodes(int64(5), 
nodeIdsConsistencySize4ToSize7).Return([]storage.Node{{NodeID: testonly.MustCreateNodeIDForTreeCoords(2, 1, 64), NodeRevision: 3, Hash: []byte("nodehash")}}, nil) mockTx.EXPECT().Commit().Return(nil) diff --git a/server/proof_fetcher.go b/server/proof_fetcher.go index 7fcddd16d3..5bcbff7279 100644 --- a/server/proof_fetcher.go +++ b/server/proof_fetcher.go @@ -15,53 +15,123 @@ package server import ( - "errors" "fmt" "github.com/golang/protobuf/proto" "github.com/google/trillian" + "github.com/google/trillian/crypto" "github.com/google/trillian/merkle" "github.com/google/trillian/storage" ) // fetchNodesAndBuildProof is used by both inclusion and consistency proofs. It fetches the nodes // from storage and converts them into the proof proto that will be returned to the client. +// This includes rehashing where necessary to serve proofs for tree sizes between stored tree +// revisions. This code only relies on the NodeReader interface so can be tested without +// a complete storage implementation. func fetchNodesAndBuildProof(tx storage.NodeReader, treeRevision, leafIndex int64, proofNodeFetches []merkle.NodeFetch) (trillian.Proof, error) { - // TODO(Martin2112): Implement the rehashing. Currently just fetches the nodes and ignores this - proofNodeIDs := make([]storage.NodeID, 0, len(proofNodeFetches)) + proofNodes, err := fetchNodes(tx, treeRevision, proofNodeFetches) + if err != nil { + return trillian.Proof{}, err + } - for _, fetch := range proofNodeFetches { - proofNodeIDs = append(proofNodeIDs, fetch.NodeID) + r := newRehasher() + for i, node := range proofNodes { + r.process(node, proofNodeFetches[i]) + } - // TODO(Martin2112): Remove this when rehashing is implemented - if fetch.Rehash { - return trillian.Proof{}, errors.New("proof requires rehashing but it's not implemented yet") - } + return r.rehashedProof(leafIndex) +} + +// rehasher bundles the rehashing logic into a simple state machine +type rehasher struct { + th merkle.TreeHasher + rehashing bool + rehashNode storage.Node + proof []*trillian.Node + proofError error +} + +// init must be called before the rehasher is used or reused +func newRehasher() *rehasher { + return &rehasher{ + // TODO(Martin2112): TreeHasher must be selected based on log config. 
+ th: merkle.NewRFC6962TreeHasher(crypto.NewSHA256()), + } +} + +func (r *rehasher) process(node storage.Node, fetch merkle.NodeFetch) { + switch { + case !r.rehashing && fetch.Rehash: + // Start of a rehashing chain + r.startRehashing(node) + + case r.rehashing && !fetch.Rehash: + // End of a rehash chain, resulting in a rehashed proof node + r.endRehashing() + // And the current node needs to be added to the proof + r.emitNode(node) + + case r.rehashing && fetch.Rehash: + // Continue with rehashing, update the node we're recomputing + r.rehashNode.Hash = r.th.HashChildren(node.Hash, r.rehashNode.Hash) + + default: + // Not rehashing, just pass the node through + r.emitNode(node) + } +} + +func (r *rehasher) emitNode(node storage.Node) { + idBytes, err := proto.Marshal(node.NodeID.AsProto()) + if err != nil { + r.proofError = err + } + r.proof = append(r.proof, &trillian.Node{NodeId: idBytes, NodeHash: node.Hash, NodeRevision: node.NodeRevision}) +} + +func (r *rehasher) startRehashing(node storage.Node) { + r.rehashNode = storage.Node{Hash: node.Hash} + r.rehashing = true +} + +func (r *rehasher) endRehashing() { + if r.rehashing { + r.proof = append(r.proof, &trillian.Node{NodeHash: r.rehashNode.Hash}) + r.rehashing = false + } +} + +func (r *rehasher) rehashedProof(leafIndex int64) (trillian.Proof, error) { + r.endRehashing() + return trillian.Proof{LeafIndex: leafIndex, ProofNode: r.proof}, r.proofError +} + +// fetchNodes removes duplicates from the set of fetches and then passes the result to +// storage. +func fetchNodes(tx storage.NodeReader, treeRevision int64, fetches []merkle.NodeFetch) ([]storage.Node, error) { + // To start with we remove any duplicate fetches + proofNodeIDs := make([]storage.NodeID, 0, len(fetches)) + + for _, fetch := range fetches { + proofNodeIDs = append(proofNodeIDs, fetch.NodeID) } proofNodes, err := tx.GetMerkleNodes(treeRevision, proofNodeIDs) if err != nil { - return trillian.Proof{}, err + return nil, err } if len(proofNodes) != len(proofNodeIDs) { - return trillian.Proof{}, fmt.Errorf("expected %d nodes in proof but got %d", len(proofNodeIDs), len(proofNodes)) + return nil, fmt.Errorf("expected %d nodes from storage but got %d", len(proofNodeIDs), len(proofNodes)) } - proof := make([]*trillian.Node, 0, len(proofNodeIDs)) for i, node := range proofNodes { // additional check that the correct node was returned - if !node.NodeID.Equivalent(proofNodeIDs[i]) { - return trillian.Proof{}, fmt.Errorf("expected node %v at proof pos %d but got %v", proofNodeIDs[i], i, node.NodeID) - } - - idBytes, err := proto.Marshal(node.NodeID.AsProto()) - if err != nil { - return trillian.Proof{}, err + if !node.NodeID.Equivalent(fetches[i].NodeID) { + return []storage.Node{}, fmt.Errorf("expected node %v at proof pos %d but got %v", fetches[i], i, node.NodeID) } - - proof = append(proof, &trillian.Node{NodeId: idBytes, NodeHash: node.Hash, NodeRevision: node.NodeRevision}) } - return trillian.Proof{LeafIndex: leafIndex, ProofNode: proof}, nil + return proofNodes, nil } diff --git a/server/proof_fetcher_test.go b/server/proof_fetcher_test.go new file mode 100644 index 0000000000..3626cbd6ea --- /dev/null +++ b/server/proof_fetcher_test.go @@ -0,0 +1,268 @@ +package server + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/google/trillian" + "github.com/google/trillian/crypto" + "github.com/google/trillian/merkle" + "github.com/google/trillian/storage" + "github.com/google/trillian/storage/testonly" +) 
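Before the fixtures, a minimal standalone sketch (assumed helper name, not part of the tests) of the hash chaining the rehasher performs when it collapses a run of Rehash fetches into a single proof node:

// rehashChainSketch folds a run of fetched node hashes into the one hash the
// proof reports, mirroring rehasher.process: the accumulated value always
// becomes the right child under the next fetched node.
func rehashChainSketch(hashes [][]byte) []byte {
	th := merkle.NewRFC6962TreeHasher(crypto.NewSHA256())
	rehashed := hashes[0]
	for _, h := range hashes[1:] {
		rehashed = th.HashChildren(h, rehashed)
	}
	return rehashed
}

For the fixtures declared just below, rehashChainSketch([][]byte{h2, h3}) matches n2n3.NodeHash, i.e. th.HashChildren(h3, h2).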
+ +// rehashTest encapsulates one test case for the rehasher in isolation. Input data like the storage +// hashes and revisions can be arbitrary but the nodes should have distinct values +type rehashTest struct { + desc string + index int64 + nodes []storage.Node + fetches []merkle.NodeFetch + output trillian.Proof +} + +// An arbitrary tree revision to be used in tests +const testTreeRevision int64 = 3 + +// Raw hashes for dummy storage nodes +var h1 = th.HashLeaf([]byte("Hash 1")) +var h2 = th.HashLeaf([]byte("Hash 2")) +var h3 = th.HashLeaf([]byte("Hash 3")) +var h4 = th.HashLeaf([]byte("Hash 4")) +var h5 = th.HashLeaf([]byte("Hash 5")) + +// And the dummy nodes themselves +var sn1 = storage.Node{NodeID: storage.NewNodeIDFromHash(h1), Hash: h1, NodeRevision: 11} +var sn2 = storage.Node{NodeID: storage.NewNodeIDFromHash(h2), Hash: h2, NodeRevision: 22} +var sn3 = storage.Node{NodeID: storage.NewNodeIDFromHash(h3), Hash: h3, NodeRevision: 33} +var sn4 = storage.Node{NodeID: storage.NewNodeIDFromHash(h4), Hash: h4, NodeRevision: 44} +var sn5 = storage.Node{NodeID: storage.NewNodeIDFromHash(h5), Hash: h5, NodeRevision: 55} + +// And the output proof nodes expected for them if they are passed through without rehashing +var n1 = &trillian.Node{NodeHash: h1, NodeId: mustMarshalNodeID(sn1.NodeID), NodeRevision: sn1.NodeRevision} +var n2 = &trillian.Node{NodeHash: h2, NodeId: mustMarshalNodeID(sn2.NodeID), NodeRevision: sn2.NodeRevision} +var n3 = &trillian.Node{NodeHash: h3, NodeId: mustMarshalNodeID(sn3.NodeID), NodeRevision: sn3.NodeRevision} +var n4 = &trillian.Node{NodeHash: h4, NodeId: mustMarshalNodeID(sn4.NodeID), NodeRevision: sn4.NodeRevision} +var n5 = &trillian.Node{NodeHash: h5, NodeId: mustMarshalNodeID(sn5.NodeID), NodeRevision: sn5.NodeRevision} + +// Nodes containing composite hashes. 
They don't have node ids or revisions as they're recomputed +var n1n2 = &trillian.Node{NodeHash: th.HashChildren(h2, h1)} +var n2n3 = &trillian.Node{NodeHash: th.HashChildren(h3, h2)} +var n2n3n4 = &trillian.Node{NodeHash: th.HashChildren(h4, th.HashChildren(h3, h2))} +var n4n5 = &trillian.Node{NodeHash: th.HashChildren(h5, h4)} + +func TestRehasher(t *testing.T) { + var rehashTests = []rehashTest{ + { + desc: "no rehash", + index: 126, + nodes: []storage.Node{sn1, sn2, sn3}, + fetches: []merkle.NodeFetch{{Rehash: false}, {Rehash: false}, {Rehash: false}}, + output: trillian.Proof{ + LeafIndex: 126, + ProofNode: []*trillian.Node{n1, n2, n3}, + }, + }, + { + desc: "single rehash", + index: 999, + nodes: []storage.Node{sn1, sn2, sn3, sn4, sn5}, + fetches: []merkle.NodeFetch{{Rehash: false}, {Rehash: true}, {Rehash: true}, {Rehash: false}, {Rehash: false}}, + output: trillian.Proof{ + LeafIndex: 999, + ProofNode: []*trillian.Node{n1, n2n3, n4, n5}, + }, + }, + { + desc: "single rehash at end", + index: 11, + nodes: []storage.Node{sn1, sn2, sn3}, + fetches: []merkle.NodeFetch{{Rehash: false}, {Rehash: true}, {Rehash: true}}, + output: trillian.Proof{ + LeafIndex: 11, + ProofNode: []*trillian.Node{n1, n2n3}, + }, + }, + { + desc: "single rehash multiple nodes", + index: 23, + nodes: []storage.Node{sn1, sn2, sn3, sn4, sn5}, + fetches: []merkle.NodeFetch{{Rehash: false}, {Rehash: true}, {Rehash: true}, {Rehash: true}, {Rehash: false}}, + output: trillian.Proof{ + LeafIndex: 23, + ProofNode: []*trillian.Node{n1, n2n3n4, n5}, + }, + }, + { + desc: "multiple rehash", + index: 45, + nodes: []storage.Node{sn1, sn2, sn3, sn4, sn5}, + fetches: []merkle.NodeFetch{{Rehash: true}, {Rehash: true}, {Rehash: false}, {Rehash: true}, {Rehash: true}}, + output: trillian.Proof{ + LeafIndex: 45, + ProofNode: []*trillian.Node{n1n2, n3, n4n5}, + }, + }, + } + + for _, rehashTest := range rehashTests { + r := newRehasher() + for i, node := range rehashTest.nodes { + r.process(node, rehashTest.fetches[i]) + } + + want := rehashTest.output + got, err := r.rehashedProof(rehashTest.index) + + if err != nil { + t.Fatalf("rehash test %s unexpected error: %v", rehashTest.desc, err) + } + + if !proto.Equal(&got, &want) { + t.Errorf("rehash test %s:\ngot: %v\nwant: %v", rehashTest.desc, got, want) + } + } +} + +func TestTree32InclusionProofFetchAll(t *testing.T) { + for ts := 2; ts <= 32; ts++ { + mt := treeAtSize(ts) + r := testonly.NewMultiFakeNodeReaderFromLeaves([]testonly.LeafBatch{ + {TreeRevision: testTreeRevision, Leaves: expandLeaves(0, ts-1), ExpectedRoot: expectedRootAtSize(mt)}, + }) + + for s := int64(2); s <= int64(ts); s++ { + for l := int64(0); l < s; l++ { + fetches, err := merkle.CalcInclusionProofNodeAddresses(s, l, int64(ts), 64) + if err != nil { + t.Fatal(err) + } + + proof, err := fetchNodesAndBuildProof(r, testTreeRevision, int64(l), fetches) + if err != nil { + t.Fatal(err) + } + + // We use +1 here because of the 1 based leaf indexing of this implementation + refProof := mt.PathToRootAtSnapshot(l+1, s) + + if got, want := len(proof.ProofNode), len(refProof); got != want { + t.Fatalf("(%d, %d, %d): got proof len: %d, want: %d: %v\n%v", ts, s, l, got, want, fetches, refProof) + } + + for i := 0; i < len(proof.ProofNode); i++ { + if got, want := hex.EncodeToString(proof.ProofNode[i].NodeHash), hex.EncodeToString(refProof[i].Value.Hash()); got != want { + t.Fatalf("(%d, %d, %d): %d got proof node: %s, want: %s l:%d fetches: %v", ts, s, l, i, got, want, len(proof.ProofNode), fetches) + } + } + } + } + } 
+} + +func TestTree32InclusionProofFetchMultiBatch(t *testing.T) { + mt := treeAtSize(32) + // The reader is built up with multiple batches, 4 batches x 8 leaves each + r := testonly.NewMultiFakeNodeReaderFromLeaves([]testonly.LeafBatch{ + {TreeRevision: testTreeRevision, Leaves: expandLeaves(0, 7), ExpectedRoot: expectedRootAtSize(treeAtSize(8))}, + {TreeRevision: testTreeRevision + 1, Leaves: expandLeaves(8, 15), ExpectedRoot: expectedRootAtSize(treeAtSize(16))}, + {TreeRevision: testTreeRevision + 2, Leaves: expandLeaves(16, 23), ExpectedRoot: expectedRootAtSize(treeAtSize(24))}, + {TreeRevision: testTreeRevision + 3, Leaves: expandLeaves(24, 31), ExpectedRoot: expectedRootAtSize(mt)}, + }) + + for s := int64(2); s <= 32; s++ { + for l := int64(0); l < s; l++ { + fetches, err := merkle.CalcInclusionProofNodeAddresses(s, l, 32, 64) + if err != nil { + t.Fatal(err) + } + + // Use the highest tree revision that should be available from the node reader + proof, err := fetchNodesAndBuildProof(r, testTreeRevision+3, l, fetches) + if err != nil { + t.Fatal(err) + } + + // We use +1 here because of the 1 based leaf indexing of this implementation + refProof := mt.PathToRootAtSnapshot(l+1, s) + + if got, want := len(proof.ProofNode), len(refProof); got != want { + t.Fatalf("(%d, %d, %d): got proof len: %d, want: %d: %v\n%v", 32, s, l, got, want, fetches, refProof) + } + + for i := 0; i < len(proof.ProofNode); i++ { + if got, want := hex.EncodeToString(proof.ProofNode[i].NodeHash), hex.EncodeToString(refProof[i].Value.Hash()); got != want { + t.Fatalf("(%d, %d, %d): %d got proof node: %s, want: %s l:%d fetches: %v", 32, s, l, i, got, want, len(proof.ProofNode), fetches) + } + } + } + } +} + +func TestTree32ConsistencyProofFetchAll(t *testing.T) { + for ts := 2; ts <= 32; ts++ { + mt := treeAtSize(ts) + r := testonly.NewMultiFakeNodeReaderFromLeaves([]testonly.LeafBatch{ + {TreeRevision: testTreeRevision, Leaves: expandLeaves(0, ts-1), ExpectedRoot: expectedRootAtSize(mt)}, + }) + + for s1 := int64(2); s1 < int64(ts); s1++ { + for s2 := int64(s1 + 1); s2 < int64(ts); s2++ { + fetches, err := merkle.CalcConsistencyProofNodeAddresses(s1, s2, int64(ts), 64) + if err != nil { + t.Fatal(err) + } + + proof, err := fetchNodesAndBuildProof(r, testTreeRevision, int64(s1), fetches) + if err != nil { + t.Fatal(err) + } + + refProof := mt.SnapshotConsistency(s1, s2) + + if got, want := len(proof.ProofNode), len(refProof); got != want { + t.Fatalf("(%d, %d, %d): got proof len: %d, want: %d: %v\n%v", ts, s1, s2, got, want, fetches, refProof) + } + + for i := 0; i < len(proof.ProofNode); i++ { + if got, want := hex.EncodeToString(proof.ProofNode[i].NodeHash), hex.EncodeToString(refProof[i].Value.Hash()); got != want { + t.Fatalf("(%d, %d, %d): %d got proof node: %s, want: %s l:%d fetches: %v", ts, s1, s2, i, got, want, len(proof.ProofNode), fetches) + } + } + } + } + } +} + +func mustMarshalNodeID(nodeID storage.NodeID) []byte { + idBytes, err := proto.Marshal(nodeID.AsProto()) + if err != nil { + panic(err) + } + return idBytes +} + +func expandLeaves(n, m int) []string { + leaves := make([]string, 0, m-n+1) + for l := n; l <= m; l++ { + leaves = append(leaves, fmt.Sprintf("Leaf %d", l)) + } + return leaves +} + +// expectedRootAtSize uses the in memory tree, the tree built with Compact Merkle Tree should +// have the same root. 
+func expectedRootAtSize(mt *merkle.InMemoryMerkleTree) string { + return hex.EncodeToString(mt.CurrentRoot().Hash()) +} + +func treeAtSize(n int) *merkle.InMemoryMerkleTree { + leaves := expandLeaves(0, n-1) + mt := merkle.NewInMemoryMerkleTree(merkle.NewRFC6962TreeHasher(crypto.NewSHA256())) + for _, leaf := range leaves { + mt.AddLeaf([]byte(leaf)) + } + return mt +} diff --git a/storage/testonly/fake_node_reader.go b/storage/testonly/fake_node_reader.go index 75a0df6c3e..40bff469c4 100644 --- a/storage/testonly/fake_node_reader.go +++ b/storage/testonly/fake_node_reader.go @@ -51,10 +51,10 @@ func NewFakeNodeReader(mappings []NodeMapping, treeSize, treeRevision int64) *Fa return &FakeNodeReader{nodeMap: nodeMap, treeSize: treeSize, treeRevision: treeRevision} } -// GetTreeRevisionAtSize implements the corresponding NodeReader API. -func (f FakeNodeReader) GetTreeRevisionAtSize(treeSize int64) (int64, error) { - if f.treeSize != treeSize { - return int64(0), fmt.Errorf("GetTreeRevisionAtSize() got treeSize:%d, want: %d", treeSize, f.treeSize) +// GetTreeRevisionIncludingSize implements the corresponding NodeReader API. +func (f FakeNodeReader) GetTreeRevisionIncludingSize(treeSize int64) (int64, error) { + if f.treeSize < treeSize { + return int64(0), fmt.Errorf("GetTreeRevisionIncludingSize() got treeSize:%d, want: >= %d", treeSize, f.treeSize) } return f.treeRevision, nil @@ -169,15 +169,15 @@ func (m MultiFakeNodeReader) readerForNodeID(nodeID storage.NodeID, revision int return nil } -// GetTreeRevisionAtSize implements the corresponding NodeReader API. -func (m MultiFakeNodeReader) GetTreeRevisionAtSize(treeSize int64) (int64, error) { +// GetTreeRevisionIncludingSize implements the corresponding NodeReader API. +func (m MultiFakeNodeReader) GetTreeRevisionIncludingSize(treeSize int64) (int64, int64, error) { for i := len(m.readers) - 1; i >= 0; i-- { - if m.readers[i].treeSize == treeSize { - return m.readers[i].treeRevision, nil + if m.readers[i].treeSize >= treeSize { + return m.readers[i].treeRevision, m.readers[i].treeSize, nil } } - return int64(0), fmt.Errorf("want revision for tree size: %d but it doesn't exist", treeSize) + return int64(0), int64(0), fmt.Errorf("want revision for tree size: %d but it doesn't exist", treeSize) } // GetMerkleNodes implements the corresponding NodeReader API.
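As a closing illustration (not part of the patch), the pieces combine roughly as below for a proof served at a non-STH size. This reuses the helpers from the new proof_fetcher_test.go and arbitrary example sizes; it is a sketch of the flow, not of the production handler wiring:

func serveAtSmallerSnapshotSketch() (trillian.Proof, error) {
	// A fake node source holding a tree of 8 leaves written at one revision.
	mt := treeAtSize(8)
	reader := testonly.NewMultiFakeNodeReaderFromLeaves([]testonly.LeafBatch{
		{TreeRevision: testTreeRevision, Leaves: expandLeaves(0, 7), ExpectedRoot: expectedRootAtSize(mt)},
	})

	// Node addresses for an inclusion proof of leaf 2 at snapshot 6, computed
	// against the stored size 8; some fetches may be flagged for rehashing.
	fetches, err := merkle.CalcInclusionProofNodeAddresses(6, 2, 8, 64)
	if err != nil {
		return trillian.Proof{}, err
	}

	// Fetch the nodes and collapse any rehash chains into final proof nodes.
	return fetchNodesAndBuildProof(reader, testTreeRevision, 2, fetches)
}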