diff --git a/server/admin/admin_server_test.go b/server/admin/admin_server_test.go index e041fe4e77..2e86e46adf 100644 --- a/server/admin/admin_server_test.go +++ b/server/admin/admin_server_test.go @@ -596,10 +596,10 @@ func TestServer_CreateTree_AllowedTreeTypes(t *testing.T) { for _, test := range tests { setup := setupAdminServer( ctrl, - nil, /* keygen */ - false, /* snapshot */ - test.wantCode == codes.OK, /* shouldCommit */ - false /* commitErr */) + nil, // keygen + false, // snapshot + test.wantCode == codes.OK, + false) s := setup.server tx := setup.tx s.allowedTreeTypes = test.treeTypes @@ -771,7 +771,7 @@ func TestServer_DeleteTree(t *testing.T) { nil, /* keygen */ false, /* snapshot */ true, /* shouldCommit */ - false /* commitErr */) + false) req := &trillian.DeleteTreeRequest{TreeId: test.tree.TreeId} tx := setup.tx @@ -810,10 +810,10 @@ func TestServer_DeleteTreeErrors(t *testing.T) { for _, test := range tests { setup := setupAdminServer( ctrl, - nil, /* keygen */ - false, /* snapshot */ - test.deleteErr == nil, /* shouldCommit */ - test.commitErr /* commitErr */) + nil, + false, + test.deleteErr == nil, + test.commitErr) req := &trillian.DeleteTreeRequest{TreeId: 10} tx := setup.tx @@ -858,7 +858,7 @@ func TestServer_UndeleteTree(t *testing.T) { nil, /* keygen */ false, /* snapshot */ true, /* shouldCommit */ - false /* commitErr */) + false) req := &trillian.UndeleteTreeRequest{TreeId: test.tree.TreeId} tx := setup.tx @@ -897,10 +897,10 @@ func TestServer_UndeleteTreeErrors(t *testing.T) { for _, test := range tests { setup := setupAdminServer( ctrl, - nil, /* keygen */ - false, /* snapshot */ - test.undeleteErr == nil, /* shouldCommit */ - test.commitErr /* commitErr */) + nil, + false, + test.undeleteErr == nil, + test.commitErr) req := &trillian.UndeleteTreeRequest{TreeId: 10} tx := setup.tx diff --git a/server/mock_log_operation.go b/server/mock_log_operation.go index ddc1662314..f8451991f0 100644 --- a/server/mock_log_operation.go +++ b/server/mock_log_operation.go @@ -6,8 +6,9 @@ package server import ( context "context" - gomock "github.com/golang/mock/gomock" reflect "reflect" + + gomock "github.com/golang/mock/gomock" ) // MockLogOperation is a mock of LogOperation interface diff --git a/server/postgres_storage_provider.go b/server/postgres_storage_provider.go index 42d6c8f5bb..8c045b11f9 100644 --- a/server/postgres_storage_provider.go +++ b/server/postgres_storage_provider.go @@ -66,7 +66,9 @@ func newPGProvider(mf monitoring.MetricFactory) (StorageProvider, error) { } func (s *pgProvider) LogStorage() storage.LogStorage { - panic("Not Implemented") + + glog.Warningf("Support for the PostgreSQL log is experimental. Please use at your own risk!!!") + return postgres.NewLogStorage(s.db, s.mf) } func (s *pgProvider) MapStorage() storage.MapStorage { diff --git a/storage/postgres/README.md b/storage/postgres/README.md new file mode 100644 index 0000000000..09c80328e1 --- /dev/null +++ b/storage/postgres/README.md @@ -0,0 +1,17 @@ +# Postgres LogStorage + +## Notes and Caveats +The current LogStorage part of the Postgres implementation was based off what +was already written for MySQL. Thus, the two user-defined functions included in +storage.sql. MySQL doesn't kill a transaction when a duplicate is detected, but +PostgreSQL does. So, to preserve the workflow, I included the two functions +which trap this error and allow the code to continue executing. 
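+
+For illustration, a minimal sketch of how the Go side consumes one of these
+functions (the wrapper name `insertLeafData` is invented for this example; the
+stored-function name and argument order are the ones defined in storage.sql and
+used by log_storage.go):
+
+```go
+package example
+
+import (
+	"context"
+	"database/sql"
+)
+
+// insertLeafData shows the calling convention only: the stored function traps
+// the duplicate-key error itself and returns false when the row already
+// existed, so the enclosing transaction keeps running instead of aborting.
+func insertLeafData(ctx context.Context, tx *sql.Tx, treeID int64,
+	identityHash, leafValue, extraData []byte, queuedAtNanos int64) (bool, error) {
+	inserted := false
+	err := tx.QueryRowContext(ctx,
+		"select insert_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)",
+		treeID, identityHash, leafValue, extraData, queuedAtNanos).Scan(&inserted)
+	return inserted, err
+}
+```
+
+insert_sequenced_leaf_data_ignore_duplicates is consumed in the same way by
+AddSequencedLeaves.
+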
+The only other change I made was to fully translate the MySQL queries into
+PostgreSQL-compatible ones and to tidy up some of the extant tree storage code.
+
+storage_unsafe.sql really isn't unsafe, but I have pulled some of the safety
+rails from the tables to improve performance. It also assumes that there will
+only be a single tree in a given database. An improvement on this theme would
+be to put all layers below the trees table in their own separate schemas. This
+would further eliminate indexes and foreign-key requirements, but it should be
+left to those who require enhanced performance. storage.sql should be fine for
+most applications.
diff --git a/storage/postgres/admin_storage_test.go b/storage/postgres/admin_storage_test.go
index d8557bc1c8..64144f93c8 100644
--- a/storage/postgres/admin_storage_test.go
+++ b/storage/postgres/admin_storage_test.go
@@ -29,6 +29,7 @@ import (
 )
 
 var allTables = []string{"unsequenced", "tree_head", "sequenced_leaf_data", "leaf_data", "subtree", "tree_control", "trees"}
+var db *sql.DB
 
 const selectTreeControlByID = "SELECT signing_enabled, sequencing_enabled, sequence_interval_seconds FROM tree_control WHERE tree_id = $1"
 
diff --git a/storage/postgres/log_storage.go b/storage/postgres/log_storage.go
new file mode 100644
index 0000000000..9edc38fa7c
--- /dev/null
+++ b/storage/postgres/log_storage.go
@@ -0,0 +1,971 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package postgres + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "sort" + "strconv" + "sync" + "time" + + "github.com/golang/glog" + "github.com/golang/protobuf/ptypes" + "github.com/google/trillian" + "github.com/google/trillian/merkle/hashers" + "github.com/google/trillian/monitoring" + "github.com/google/trillian/storage" + "github.com/google/trillian/storage/cache" + "github.com/google/trillian/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + valuesPlaceholder5 = "($1,$2,$3,$4,$5)" + insertLeafDataSQL = "select insert_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)" + insertSequencedLeafSQL = "select insert_sequenced_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)" + + selectNonDeletedTreeIDByTypeAndStateSQL = ` + SELECT tree_id FROM trees WHERE tree_type in ($1,$2) AND tree_state in ($3,$4) AND (deleted IS NULL OR deleted = false)` + + selectSequencedLeafCountSQL = "SELECT COUNT(*) FROM sequenced_leaf_data WHERE tree_id=$1" + selectUnsequencedLeafCountSQL = "SELECT tree_id, COUNT(1) FROM unsequenced GROUP BY tree_id" + //selectLatestSignedLogRootSQL = `SELECT tree_head_timestamp,tree_size,root_hash,tree_revision,root_signature + // FROM tree_head WHERE tree_id=$1 + // ORDER BY tree_head_timestamp DESC LIMIT 1` + + selectLeavesByRangeSQL = `SELECT s.merkle_leaf_hash,l.leaf_identity_hash,l.leaf_value,s.sequence_number,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos + FROM leaf_data l,sequenced_leaf_data s + WHERE l.leaf_identity_hash = s.leaf_identity_hash + AND s.sequence_number >= $1 AND s.sequence_number < $2 AND l.tree_id = $3 AND s.tree_id = l.tree_id` + orderBySequenceNumberSQL + + // These statements need to be expanded to provide the correct number of parameter placeholders. + selectLeavesByIndexSQL = `SELECT s.merkle_leaf_hash,l.leaf_identity_hash,l.leaf_value,s.sequence_number,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos + FROM leaf_data l,sequenced_leaf_data s + WHERE l.leaf_identity_hash = s.leaf_identity_hash + AND s.sequence_number IN (` + placeholderSQL + `) AND l.tree_id = AND s.tree_id = l.tree_id` + selectLeavesByMerkleHashSQL = `SELECT s.merkle_leaf_hash,l.leaf_identity_hash,l.leaf_value,s.sequence_number,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos + FROM leaf_data l,sequenced_leaf_data s + WHERE l.leaf_identity_hash = s.leaf_identity_hash + AND s.merkle_leaf_hash IN (` + placeholderSQL + `) AND l.tree_id = AND s.tree_id = l.tree_id` + // TODO(drysdale): rework the code so the dummy hash isn't needed (e.g. this assumes hash size is 32) + dummymerkleLeafHash = "00000000000000000000000000000000" + // This statement returns a dummy Merkle leaf hash value (which must be + // of the right size) so that its signature matches that of the other + // leaf-selection statements. 
+ selectLeavesByLeafIdentityHashSQL = `SELECT '` + dummymerkleLeafHash + `',l.leaf_identity_hash,l.leaf_value,-1,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos + FROM leaf_data l LEFT JOIN sequenced_leaf_data s ON (l.leaf_identity_hash = s.leaf_identity_hash AND l.tree_id = s.tree_id) + WHERE l.leaf_identity_hash IN (` + placeholderSQL + `) AND l.tree_id = ` + + // Same as above except with leaves ordered by sequence so we only incur this cost when necessary + orderBySequenceNumberSQL = " ORDER BY s.sequence_number" + selectLeavesByMerkleHashOrderedBySequenceSQL = selectLeavesByMerkleHashSQL + orderBySequenceNumberSQL + + // Error code returned by driver when inserting a duplicate row + + logIDLabel = "logid" +) + +var ( + defaultLogStrata = []int{8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8} + + once sync.Once + queuedCounter monitoring.Counter + queuedDupCounter monitoring.Counter + dequeuedCounter monitoring.Counter + + queueLatency monitoring.Histogram + queueInsertLatency monitoring.Histogram + queueReadLatency monitoring.Histogram + queueInsertLeafLatency monitoring.Histogram + queueInsertEntryLatency monitoring.Histogram + dequeueLatency monitoring.Histogram + dequeueSelectLatency monitoring.Histogram + dequeueRemoveLatency monitoring.Histogram +) + +func createMetrics(mf monitoring.MetricFactory) { + queuedCounter = mf.NewCounter("postgres_queued_leaves", "Number of leaves queued", logIDLabel) + queuedDupCounter = mf.NewCounter("postgres_queued_dup_leaves", "Number of duplicate leaves queued", logIDLabel) + dequeuedCounter = mf.NewCounter("postgres_dequeued_leaves", "Number of leaves dequeued", logIDLabel) + + queueLatency = mf.NewHistogram("postgres_queue_leaves_latency", "Latency of queue leaves operation in seconds", logIDLabel) + queueInsertLatency = mf.NewHistogram("postgres_queue_leaves_latency_insert", "Latency of insertion part of queue leaves operation in seconds", logIDLabel) + queueReadLatency = mf.NewHistogram("postgres_queue_leaves_latency_read_dups", "Latency of read-duplicates part of queue leaves operation in seconds", logIDLabel) + queueInsertLeafLatency = mf.NewHistogram("postgres_queue_leaf_latency_leaf", "Latency of insert-leaf part of queue (single) leaf operation in seconds", logIDLabel) + queueInsertEntryLatency = mf.NewHistogram("postgres_queue_leaf_latency_entry", "Latency of insert-entry part of queue (single) leaf operation in seconds", logIDLabel) + + dequeueLatency = mf.NewHistogram("postgres_dequeue_leaves_latency", "Latency of dequeue leaves operation in seconds", logIDLabel) + dequeueSelectLatency = mf.NewHistogram("postgres_dequeue_leaves_latency_select", "Latency of selection part of dequeue leaves operation in seconds", logIDLabel) + dequeueRemoveLatency = mf.NewHistogram("postgres_dequeue_leaves_latency_remove", "Latency of removal part of dequeue leaves operation in seconds", logIDLabel) +} + +func labelForTX(t *logTreeTX) string { + return strconv.FormatInt(t.treeID, 10) +} + +func observe(hist monitoring.Histogram, duration time.Duration, label string) { + hist.Observe(duration.Seconds(), label) +} + +type postgresLogStorage struct { + *pgTreeStorage + admin storage.AdminStorage + metricFactory monitoring.MetricFactory +} + +// NewLogStorage creates a storage.LogStorage instance for the specified PostgreSQL URL. +// It assumes storage.AdminStorage is backed by the same PostgreSQL database as well. 
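+//
+// A minimal wiring sketch (connStr is a placeholder for your connection
+// string; it assumes the *sql.DB was opened via the lib/pq driver, and a nil
+// MetricFactory falls back to monitoring.InertMetricFactory):
+//
+//	db, err := sql.Open("postgres", connStr)
+//	if err != nil {
+//		glog.Exitf("failed to open database: %v", err)
+//	}
+//	ls := postgres.NewLogStorage(db, nil)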
+func NewLogStorage(db *sql.DB, mf monitoring.MetricFactory) storage.LogStorage { + if mf == nil { + mf = monitoring.InertMetricFactory{} + } + return &postgresLogStorage{ + admin: NewAdminStorage(db), + pgTreeStorage: newTreeStorage(db), + metricFactory: mf, + } +} + +func (m *postgresLogStorage) CheckDatabaseAccessible(ctx context.Context) error { + return m.db.PingContext(ctx) +} + +func (m *postgresLogStorage) getLeavesByIndexStmt(ctx context.Context, num int) (*sql.Stmt, error) { + stmt := &statementSkeleton{ + sql: selectLeavesByIndexSQL, + firstInsertion: "%s", + firstPlaceholders: 1, + restInsertion: "%s", + restPlaceholders: 1, + num: num, + } + return m.getStmt(ctx, stmt) +} + +func (m *postgresLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*sql.Stmt, error) { + if orderBySequence { + + orderByStmt := &statementSkeleton{ + sql: selectLeavesByMerkleHashOrderedBySequenceSQL, + firstInsertion: "%s", + firstPlaceholders: 1, + restInsertion: "%s", + restPlaceholders: 1, + num: num, + } + + return m.getStmt(ctx, orderByStmt) + } + + merkleHashStmt := &statementSkeleton{ + sql: selectLeavesByMerkleHashSQL, + firstInsertion: "%s", + firstPlaceholders: 1, + restInsertion: "%s", + restPlaceholders: 1, + num: num, + } + + return m.getStmt(ctx, merkleHashStmt) +} + +func (m *postgresLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*sql.Stmt, error) { + identityHashStmt := &statementSkeleton{ + sql: selectLeavesByLeafIdentityHashSQL, + firstInsertion: "%s", + firstPlaceholders: 1, + restInsertion: "%s", + restPlaceholders: 1, + num: num, + } + + return m.getStmt(ctx, identityHashStmt) +} + +// readOnlyLogTX implements storage.ReadOnlyLogTX +type readOnlyLogTX struct { + ls *postgresLogStorage + tx *sql.Tx +} + +func (m *postgresLogStorage) Snapshot(ctx context.Context) (storage.ReadOnlyLogTX, error) { + tx, err := m.db.BeginTx(ctx, nil /* opts */) + if err != nil { + glog.Warningf("Could not start ReadOnlyLogTX: %s", err) + return nil, err + } + return &readOnlyLogTX{m, tx}, nil +} + +func (t *readOnlyLogTX) Commit() error { + return t.tx.Commit() +} + +func (t *readOnlyLogTX) Rollback() error { + return t.tx.Rollback() +} + +func (t *readOnlyLogTX) Close() error { + if err := t.Rollback(); err != nil && err != sql.ErrTxDone { + glog.Warningf("Rollback error on Close(): %v", err) + return err + } + return nil +} + +func (t *readOnlyLogTX) GetActiveLogIDs(ctx context.Context) ([]int64, error) { + // Include logs that are DRAINING in the active list as we're still + // integrating leaves into them. 
+ rows, err := t.tx.QueryContext( + ctx, selectNonDeletedTreeIDByTypeAndStateSQL, + trillian.TreeType_LOG.String(), trillian.TreeType_PREORDERED_LOG.String(), + trillian.TreeState_ACTIVE.String(), trillian.TreeState_DRAINING.String()) + if err != nil { + return nil, err + } + defer rows.Close() + ids := []int64{} + for rows.Next() { + var treeID int64 + if err := rows.Scan(&treeID); err != nil { + return nil, err + } + ids = append(ids, treeID) + } + return ids, rows.Err() +} + +func (m *postgresLogStorage) beginInternal(ctx context.Context, tree *trillian.Tree) (storage.LogTreeTX, error) { + once.Do(func() { + createMetrics(m.metricFactory) + }) + hasher, err := hashers.NewLogHasher(tree.HashStrategy) + if err != nil { + return nil, err + } + + stCache := cache.NewLogSubtreeCache(defaultLogStrata, hasher) + ttx, err := m.beginTreeTx(ctx, tree, hasher.Size(), stCache) + if err != nil && err != storage.ErrTreeNeedsInit { + return nil, err + } + + ltx := &logTreeTX{ + treeTX: ttx, + ls: m, + } + ltx.slr, err = ltx.fetchLatestRoot(ctx) + if err == storage.ErrTreeNeedsInit { + return ltx, err + } else if err != nil { + ttx.Rollback() + return nil, err + } + if err := ltx.root.UnmarshalBinary(ltx.slr.LogRoot); err != nil { + ttx.Rollback() + return nil, err + } + + ltx.treeTX.writeRevision = int64(ltx.root.Revision) + 1 + return ltx, nil +} + +func (m *postgresLogStorage) ReadWriteTransaction(ctx context.Context, tree *trillian.Tree, f storage.LogTXFunc) error { + tx, err := m.beginInternal(ctx, tree) + if err != nil && err != storage.ErrTreeNeedsInit { + return err + } + defer tx.Close() + if err := f(ctx, tx); err != nil { + return err + } + return tx.Commit() +} + +func (m *postgresLogStorage) AddSequencedLeaves(ctx context.Context, tree *trillian.Tree, leaves []*trillian.LogLeaf, timestamp time.Time) ([]*trillian.QueuedLogLeaf, error) { + tx, err := m.beginInternal(ctx, tree) + if err != nil { + return nil, err + } + res, err := tx.AddSequencedLeaves(ctx, leaves, timestamp) + if err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + return res, nil +} + +func (m *postgresLogStorage) SnapshotForTree(ctx context.Context, tree *trillian.Tree) (storage.ReadOnlyLogTreeTX, error) { + tx, err := m.beginInternal(ctx, tree) + if err != nil && err != storage.ErrTreeNeedsInit { + return nil, err + } + return tx, err +} + +func (m *postgresLogStorage) QueueLeaves(ctx context.Context, tree *trillian.Tree, leaves []*trillian.LogLeaf, queueTimestamp time.Time) ([]*trillian.QueuedLogLeaf, error) { + tx, err := m.beginInternal(ctx, tree) + if err != nil { + return nil, err + } + existing, err := tx.QueueLeaves(ctx, leaves, queueTimestamp) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + ret := make([]*trillian.QueuedLogLeaf, len(leaves)) + for i, e := range existing { + if e != nil { + ret[i] = &trillian.QueuedLogLeaf{ + Leaf: e, + Status: status.Newf(codes.AlreadyExists, "leaf already exists: %v", e.LeafIdentityHash).Proto(), + } + continue + } + ret[i] = &trillian.QueuedLogLeaf{Leaf: leaves[i]} + } + return ret, nil +} + +type logTreeTX struct { + treeTX + ls *postgresLogStorage + root types.LogRootV1 + slr trillian.SignedLogRoot +} + +func (t *logTreeTX) ReadRevision(ctx context.Context) (int64, error) { + return int64(t.root.Revision), nil +} + +func (t *logTreeTX) WriteRevision(ctx context.Context) (int64, error) { + if t.treeTX.writeRevision < 0 { + return t.treeTX.writeRevision, errors.New("logTreeTX 
write revision not populated") + } + return t.treeTX.writeRevision, nil +} + +func (t *logTreeTX) DequeueLeaves(ctx context.Context, limit int, cutoffTime time.Time) ([]*trillian.LogLeaf, error) { + if t.treeType == trillian.TreeType_PREORDERED_LOG { + // TODO(pavelkalinnikov): Optimize this by fetching only the required + // fields of LogLeaf. We can avoid joining with LeafData table here. + return t.GetLeavesByRange(ctx, int64(t.root.TreeSize), int64(limit)) + } + + start := time.Now() + stx, err := t.tx.PrepareContext(ctx, selectQueuedLeavesSQL) + if err != nil { + glog.Warningf("Failed to prepare dequeue select: %s", err) + return nil, err + } + defer stx.Close() + + leaves := make([]*trillian.LogLeaf, 0, limit) + dq := make([]dequeuedLeaf, 0, limit) + rows, err := stx.QueryContext(ctx, t.treeID, cutoffTime.UnixNano(), limit) + if err != nil { + glog.Warningf("Failed to select rows for work: %s", err) + return nil, err + } + defer rows.Close() + + for rows.Next() { + leaf, dqInfo, err := t.dequeueLeaf(rows) + if err != nil { + glog.Warningf("Error dequeuing leaf: %v %v", err, selectQueuedLeavesSQL) + return nil, err + } + + if len(leaf.LeafIdentityHash) != t.hashSizeBytes { + return nil, errors.New("dequeued a leaf with incorrect hash size") + } + + leaves = append(leaves, leaf) + dq = append(dq, dqInfo) + } + + if rows.Err() != nil { + return nil, rows.Err() + } + label := labelForTX(t) + selectDuration := time.Since(start) + observe(dequeueSelectLatency, selectDuration, label) + + // The convention is that if leaf processing succeeds (by committing this tx) + // then the unsequenced entries for them are removed + if len(leaves) > 0 { + err = t.removeSequencedLeaves(ctx, dq) + } + + if err != nil { + return nil, err + } + + totalDuration := time.Since(start) + removeDuration := totalDuration - selectDuration + observe(dequeueRemoveLatency, removeDuration, label) + observe(dequeueLatency, totalDuration, label) + dequeuedCounter.Add(float64(len(leaves)), label) + + return leaves, nil +} + +// sortLeavesForInsert returns a slice containing the passed in leaves sorted +// by LeafIdentityHash, and paired with their original positions. +// QueueLeaves and AddSequencedLeaves use this to make the order that LeafData +// row locks are acquired deterministic and reduce the chance of deadlocks. +func sortLeavesForInsert(leaves []*trillian.LogLeaf) []leafAndPosition { + ordLeaves := make([]leafAndPosition, len(leaves)) + for i, leaf := range leaves { + ordLeaves[i] = leafAndPosition{leaf: leaf, idx: i} + } + sort.Sort(byLeafIdentityHashWithPosition(ordLeaves)) + return ordLeaves +} + +func (t *logTreeTX) QueueLeaves(ctx context.Context, leaves []*trillian.LogLeaf, queueTimestamp time.Time) ([]*trillian.LogLeaf, error) { + // Don't accept batches if any of the leaves are invalid. 
+ for _, leaf := range leaves { + if len(leaf.LeafIdentityHash) != t.hashSizeBytes { + return nil, fmt.Errorf("queued leaf must have a leaf ID hash of length %d", t.hashSizeBytes) + } + var err error + leaf.QueueTimestamp, err = ptypes.TimestampProto(queueTimestamp) + if err != nil { + return nil, fmt.Errorf("got invalid queue timestamp: %v", err) + } + } + start := time.Now() + label := labelForTX(t) + + ordLeaves := sortLeavesForInsert(leaves) + existingCount := 0 + existingLeaves := make([]*trillian.LogLeaf, len(leaves)) + + for _, ol := range ordLeaves { + i, leaf := ol.idx, ol.leaf + + leafStart := time.Now() + qTimestamp, err := ptypes.Timestamp(leaf.QueueTimestamp) + if err != nil { + return nil, fmt.Errorf("got invalid queue timestamp: %v", err) + } + dupCheckRow, err := t.tx.QueryContext(ctx, insertLeafDataSQL, t.treeID, leaf.LeafIdentityHash, leaf.LeafValue, leaf.ExtraData, qTimestamp.UnixNano()) + if err != nil { + return nil, fmt.Errorf("dupecheck failed: %v", err) + } + insertDuration := time.Since(leafStart) + observe(queueInsertLeafLatency, insertDuration, label) + resultData := false + for dupCheckRow.Next() { + err := dupCheckRow.Scan(&resultData) + if err != nil { + return nil, fmt.Errorf("dupecheck failed: %v", err) + } + if !resultData { + break + } + } + dupCheckRow.Close() + if !resultData { + // Remember the duplicate leaf, using the requested leaf for now. + existingLeaves[i] = leaf + existingCount++ + queuedDupCounter.Inc(label) + glog.Warningf("Found duplicate %v %v", t.treeID, leaf) + continue + } + + // Create the work queue entry + args := []interface{}{ + t.treeID, + leaf.LeafIdentityHash, + leaf.MerkleLeafHash, + } + queueTimestamp, err := ptypes.Timestamp(leaf.QueueTimestamp) + if err != nil { + return nil, fmt.Errorf("got invalid queue timestamp: %v", err) + } + args = append(args, queueArgs(t.treeID, leaf.LeafIdentityHash, queueTimestamp)...) + _, err = t.tx.ExecContext( + ctx, + insertUnsequencedEntrySQL, + args..., + ) + if err != nil { + glog.Warningf("Error inserting into Unsequenced: %s query %v arguements: %v", err, insertUnsequencedEntrySQL, args) + return nil, fmt.Errorf("Unsequenced: %v -- %v", err, args) + } + leafDuration := time.Since(leafStart) + observe(queueInsertEntryLatency, (leafDuration - insertDuration), label) + } + insertDuration := time.Since(start) + observe(queueInsertLatency, insertDuration, label) + queuedCounter.Add(float64(len(leaves)), label) + + if existingCount == 0 { + return existingLeaves, nil + } + + // For existing leaves, we need to retrieve the contents. First collate the desired LeafIdentityHash values. + var toRetrieve [][]byte + for _, existing := range existingLeaves { + if existing != nil { + toRetrieve = append(toRetrieve, existing.LeafIdentityHash) + } + } + results, err := t.getLeafDataByIdentityHash(ctx, toRetrieve) + if err != nil { + return nil, fmt.Errorf("failed to retrieve existing leaves: %v %v", err, toRetrieve) + } + if len(results) != len(toRetrieve) { + return nil, fmt.Errorf("failed to retrieve all existing leaves: got %d, want %d", len(results), len(toRetrieve)) + } + // Replace the requested leaves with the actual leaves. 
+ for i, requested := range existingLeaves { + if requested == nil { + continue + } + found := false + for _, result := range results { + if bytes.Equal(result.LeafIdentityHash, requested.LeafIdentityHash) { + existingLeaves[i] = result + found = true + break + } + } + if !found { + return nil, fmt.Errorf("failed to find existing leaf for hash %x", requested.LeafIdentityHash) + } + } + totalDuration := time.Since(start) + readDuration := totalDuration - insertDuration + observe(queueReadLatency, readDuration, label) + observe(queueLatency, totalDuration, label) + + return existingLeaves, nil +} + +func (t *logTreeTX) AddSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf, timestamp time.Time) ([]*trillian.QueuedLogLeaf, error) { + res := make([]*trillian.QueuedLogLeaf, len(leaves)) + ok := status.New(codes.OK, "OK").Proto() + + // Leaves in this transaction are inserted in two tables. For each leaf, if + // one of the two inserts fails, we remove the side effect by rolling back to + // a savepoint installed before the first insert of the two. + const savepoint = "SAVEPOINT AddSequencedLeaves" + if _, err := t.tx.ExecContext(ctx, savepoint); err != nil { + glog.Errorf("Error adding savepoint: %s", err) + return nil, err + } + // TODO(pavelkalinnikov): Consider performance implication of executing this + // extra SAVEPOINT, especially for 1-entry batches. Optimize if necessary. + + // Note: LeafData inserts are presumably protected from deadlocks due to + // sorting, but the order of the corresponding SequencedLeafData inserts + // becomes indeterministic. However, in a typical case when leaves are + // supplied in contiguous non-intersecting batches, the chance of having + // circular dependencies between transactions is significantly lower. + ordLeaves := sortLeavesForInsert(leaves) + for _, ol := range ordLeaves { + i, leaf := ol.idx, ol.leaf + + // This should fail on insert, but catch it early. + if got, want := len(leaf.LeafIdentityHash), t.hashSizeBytes; got != want { + return nil, status.Errorf(codes.FailedPrecondition, "leaves[%d] has incorrect hash size %d, want %d", i, got, want) + } + + if _, err := t.tx.ExecContext(ctx, savepoint); err != nil { + glog.Errorf("Error updating savepoint: %s", err) + return nil, err + } + + res[i] = &trillian.QueuedLogLeaf{Status: ok} + + // TODO(pavelkalinnikov): Measure latencies. + _, err := t.tx.ExecContext(ctx, insertLeafDataSQL, + t.treeID, leaf.LeafIdentityHash, leaf.LeafValue, leaf.ExtraData, timestamp.UnixNano()) + // TODO(pavelkalinnikov): Detach PREORDERED_LOG integration latency metric. + if err != nil { + glog.Errorf("Error inserting leaves[%d] into LeafData: %s", i, err) + return nil, err + } + + dupCheckRow, err := t.tx.QueryContext(ctx, insertSequencedLeafSQL, + t.treeID, leaf.LeafIndex, leaf.LeafIdentityHash, leaf.MerkleLeafHash, 0) + // TODO(pavelkalinnikov): Update IntegrateTimestamp on integrating the leaf. + resultData := true + for dupCheckRow.Next() { + dupCheckRow.Scan(&resultData) + if !resultData { + break + } + } + dupCheckRow.Close() + if !resultData { + res[i].Status = status.New(codes.FailedPrecondition, "conflicting LeafIndex").Proto() + if _, err := t.tx.ExecContext(ctx, "ROLLBACK TO "+savepoint); err != nil { + glog.Errorf("Error rolling back to savepoint: %s", err) + return nil, err + } + } else if err != nil { + glog.Errorf("Error inserting leaves[%d] into SequencedLeafData: %s %s", i, err, leaf.LeafIdentityHash) + return nil, err + } + + // TODO(pavelkalinnikov): Load LeafData for conflicting entries. 
+ } + + if _, err := t.tx.ExecContext(ctx, "RELEASE "+savepoint); err != nil { + glog.Errorf("Error releasing savepoint: %s", err) + return nil, err + } + + return res, nil +} + +func (t *logTreeTX) GetSequencedLeafCount(ctx context.Context) (int64, error) { + var sequencedLeafCount int64 + + err := t.tx.QueryRowContext(ctx, selectSequencedLeafCountSQL, t.treeID).Scan(&sequencedLeafCount) + if err != nil { + glog.Warningf("Error getting sequenced leaf count: %s", err) + } + + return sequencedLeafCount, err +} + +func (t *logTreeTX) GetLeavesByIndex(ctx context.Context, leaves []int64) ([]*trillian.LogLeaf, error) { + if t.treeType == trillian.TreeType_LOG { + treeSize := int64(t.root.TreeSize) + for _, leaf := range leaves { + if leaf < 0 { + return nil, status.Errorf(codes.InvalidArgument, "index %d is < 0", leaf) + } + if leaf >= treeSize { + return nil, status.Errorf(codes.OutOfRange, "invalid leaf index %d, want < TreeSize(%d)", leaf, treeSize) + } + } + } + tmpl, err := t.ls.getLeavesByIndexStmt(ctx, len(leaves)) + if err != nil { + return nil, err + } + stx := t.tx.StmtContext(ctx, tmpl) + defer stx.Close() + + var args []interface{} + for _, nodeID := range leaves { + args = append(args, interface{}(int64(nodeID))) + } + args = append(args, interface{}(t.treeID)) + rows, err := stx.QueryContext(ctx, args...) + if err != nil { + glog.Warningf("Failed to get leaves by idx: %s", err) + return nil, err + } + defer rows.Close() + + ret := make([]*trillian.LogLeaf, 0, len(leaves)) + for rows.Next() { + leaf := &trillian.LogLeaf{} + var qTimestamp, iTimestamp int64 + if err := rows.Scan( + &leaf.MerkleLeafHash, + &leaf.LeafIdentityHash, + &leaf.LeafValue, + &leaf.LeafIndex, + &leaf.ExtraData, + &qTimestamp, + &iTimestamp); err != nil { + glog.Warningf("Failed to scan merkle leaves: %s", err) + return nil, err + } + var err error + leaf.QueueTimestamp, err = ptypes.TimestampProto(time.Unix(0, qTimestamp)) + if err != nil { + return nil, fmt.Errorf("got invalid queue timestamp: %v", err) + } + leaf.IntegrateTimestamp, err = ptypes.TimestampProto(time.Unix(0, iTimestamp)) + if err != nil { + return nil, fmt.Errorf("got invalid integrate timestamp: %v", err) + } + ret = append(ret, leaf) + } + + if got, want := len(ret), len(leaves); got != want { + return nil, status.Errorf(codes.Internal, "len(ret): %d, want %d", got, want) + } + return ret, nil +} + +func (t *logTreeTX) GetLeavesByRange(ctx context.Context, start, count int64) ([]*trillian.LogLeaf, error) { + if count <= 0 { + return nil, status.Errorf(codes.InvalidArgument, "invalid count %d, want > 0", count) + } + if start < 0 { + return nil, status.Errorf(codes.InvalidArgument, "invalid start %d, want >= 0", start) + } + + if t.treeType == trillian.TreeType_LOG { + treeSize := int64(t.root.TreeSize) + if treeSize <= 0 { + return nil, status.Errorf(codes.OutOfRange, "empty tree") + } else if start >= treeSize { + return nil, status.Errorf(codes.OutOfRange, "invalid start %d, want < TreeSize(%d)", start, treeSize) + } + // Ensure no entries queried/returned beyond the tree. + if maxCount := treeSize - start; count > maxCount { + count = maxCount + } + } + // TODO(pavelkalinnikov): Further clip `count` to a safe upper bound like 64k. + + args := []interface{}{start, start + count, t.treeID} + rows, err := t.tx.QueryContext(ctx, selectLeavesByRangeSQL, args...) 
+ if err != nil { + glog.Warningf("Failed to get leaves by range: %s", err) + return nil, err + } + defer rows.Close() + + ret := make([]*trillian.LogLeaf, 0, count) + for wantIndex := start; rows.Next(); wantIndex++ { + leaf := &trillian.LogLeaf{} + var qTimestamp, iTimestamp int64 + if err := rows.Scan( + &leaf.MerkleLeafHash, + &leaf.LeafIdentityHash, + &leaf.LeafValue, + &leaf.LeafIndex, + &leaf.ExtraData, + &qTimestamp, + &iTimestamp); err != nil { + glog.Warningf("Failed to scan merkle leaves: %s", err) + return nil, err + } + if leaf.LeafIndex != wantIndex { + if wantIndex < int64(t.root.TreeSize) { + return nil, fmt.Errorf("got unexpected index %d, want %d", leaf.LeafIndex, wantIndex) + } + break + } + var err error + leaf.QueueTimestamp, err = ptypes.TimestampProto(time.Unix(0, qTimestamp)) + if err != nil { + return nil, fmt.Errorf("got invalid queue timestamp: %v", err) + } + leaf.IntegrateTimestamp, err = ptypes.TimestampProto(time.Unix(0, iTimestamp)) + if err != nil { + return nil, fmt.Errorf("got invalid integrate timestamp: %v", err) + } + ret = append(ret, leaf) + } + + return ret, nil +} + +func (t *logTreeTX) GetLeavesByHash(ctx context.Context, leafHashes [][]byte, orderBySequence bool) ([]*trillian.LogLeaf, error) { + tmpl, err := t.ls.getLeavesByMerkleHashStmt(ctx, len(leafHashes), orderBySequence) + if err != nil { + return nil, err + } + + return t.getLeavesByHashInternal(ctx, leafHashes, tmpl, "merkle") +} + +// getLeafDataByIdentityHash retrieves leaf data by LeafIdentityHash, returned +// as a slice of LogLeaf objects for convenience. However, note that the +// returned LogLeaf objects will not have a valid MerkleLeafHash, LeafIndex, or IntegrateTimestamp. +func (t *logTreeTX) getLeafDataByIdentityHash(ctx context.Context, leafHashes [][]byte) ([]*trillian.LogLeaf, error) { + tmpl, err := t.ls.getLeavesByLeafIdentityHashStmt(ctx, len(leafHashes)) + if err != nil { + return nil, err + } + return t.getLeavesByHashInternal(ctx, leafHashes, tmpl, "leaf-identity") +} + +func (t *logTreeTX) LatestSignedLogRoot(ctx context.Context) (trillian.SignedLogRoot, error) { + return t.slr, nil +} + +// fetchLatestRoot reads the latest SignedLogRoot from the DB and returns it. 
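+// The root is stored in the trees table itself: current_tree_data holds a
+// JSON-serialized types.LogRootV1 and root_signature holds its signature. A
+// NULL current_tree_data is treated as an uninitialised tree and surfaced as
+// storage.ErrTreeNeedsInit.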
+func (t *logTreeTX) fetchLatestRoot(ctx context.Context) (trillian.SignedLogRoot, error) { + // var timestamp, treeSize, treeRevision int64 + var rootSignatureBytes []byte + var jsonObj []byte + + t.tx.QueryRowContext( + ctx, + "select current_tree_data,root_signature from trees where tree_id = $1", + t.treeID).Scan(&jsonObj, &rootSignatureBytes) + if jsonObj == nil { //this fixes the createtree workflow + return trillian.SignedLogRoot{}, storage.ErrTreeNeedsInit + } + var logRoot types.LogRootV1 + json.Unmarshal(jsonObj, &logRoot) + newRoot, _ := logRoot.MarshalBinary() + return trillian.SignedLogRoot{ + KeyHint: types.SerializeKeyHint(t.treeID), + LogRoot: newRoot, + LogRootSignature: rootSignatureBytes, + }, nil +} + +func (t *logTreeTX) StoreSignedLogRoot(ctx context.Context, root trillian.SignedLogRoot) error { + var logRoot types.LogRootV1 + if err := logRoot.UnmarshalBinary(root.LogRoot); err != nil { + glog.Warningf("Failed to parse log root: %x %v", root.LogRoot, err) + return err + } + if len(logRoot.Metadata) != 0 { + return fmt.Errorf("unimplemented: postgres storage does not support log root metadata") + + } + //get a json copy of the tree_head + data, _ := json.Marshal(logRoot) + t.tx.ExecContext( + ctx, + "update trees set current_tree_data = $1,root_signature = $2 where tree_id = $3", + data, + root.LogRootSignature, + t.treeID) + res, err := t.tx.ExecContext( + ctx, + insertTreeHeadSQL, + t.treeID, + logRoot.TimestampNanos, + logRoot.TreeSize, + logRoot.RootHash, + logRoot.Revision, + root.LogRootSignature) + if err != nil { + glog.Warningf("Failed to store signed root: %s", err) + } + + return checkResultOkAndRowCountIs(res, err, 1) +} + +func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *sql.Stmt, desc string) ([]*trillian.LogLeaf, error) { + stx := t.tx.StmtContext(ctx, tmpl) + defer stx.Close() + + var args []interface{} + for _, hash := range leafHashes { + args = append(args, interface{}([]byte(hash))) + } + args = append(args, interface{}(t.treeID)) + rows, err := stx.QueryContext(ctx, args...) + if err != nil { + glog.Warningf("Query() %s hash = %v", desc, err) + return nil, err + } + defer rows.Close() + + // The tree could include duplicates so we don't know how many results will be returned + var ret []*trillian.LogLeaf + for rows.Next() { + leaf := &trillian.LogLeaf{} + // We might be using a LEFT JOIN in our statement, so leaves which are + // queued but not yet integrated will have a NULL IntegrateTimestamp + // when there's no corresponding entry in SequencedLeafData, even though + // the table definition forbids that, so we use a nullable type here and + // check its validity below. 
+ var integrateTS sql.NullInt64 + var queueTS int64 + + if err := rows.Scan(&leaf.MerkleLeafHash, &leaf.LeafIdentityHash, &leaf.LeafValue, &leaf.LeafIndex, &leaf.ExtraData, &queueTS, &integrateTS); err != nil { + glog.Warningf("LogID: %d Scan() %s = %s", t.treeID, desc, err) + return nil, err + } + var err error + leaf.QueueTimestamp, err = ptypes.TimestampProto(time.Unix(0, queueTS)) + if err != nil { + return nil, fmt.Errorf("got invalid queue timestamp: %v", err) + } + if integrateTS.Valid { + leaf.IntegrateTimestamp, err = ptypes.TimestampProto(time.Unix(0, integrateTS.Int64)) + if err != nil { + return nil, fmt.Errorf("got invalid integrate timestamp: %v", err) + } + } + + if got, want := len(leaf.MerkleLeafHash), t.hashSizeBytes; got != want { + return nil, fmt.Errorf("LogID: %d Scanned leaf %s does not have hash length %d, got %d", t.treeID, desc, want, got) + } + + ret = append(ret, leaf) + } + + return ret, nil +} + +func (t *readOnlyLogTX) GetUnsequencedCounts(ctx context.Context) (storage.CountByLogID, error) { + stx, err := t.tx.PrepareContext(ctx, selectUnsequencedLeafCountSQL) + if err != nil { + glog.Warningf("Failed to prep unsequenced leaf count statement: %v", err) + return nil, err + } + defer stx.Close() + + rows, err := stx.QueryContext(ctx) + if err != nil { + return nil, err + } + defer rows.Close() + + ret := make(map[int64]int64) + for rows.Next() { + var logID, count int64 + if err := rows.Scan(&logID, &count); err != nil { + return nil, fmt.Errorf("failed to scan row from unsequenced counts: %v", err) + } + ret[logID] = count + } + return ret, nil +} + +// leafAndPosition records original position before sort. +type leafAndPosition struct { + leaf *trillian.LogLeaf + idx int +} + +// byLeafIdentityHashWithPosition allows sorting (as above), but where we need +// to remember the original position +type byLeafIdentityHashWithPosition []leafAndPosition + +func (l byLeafIdentityHashWithPosition) Len() int { + return len(l) +} +func (l byLeafIdentityHashWithPosition) Swap(i, j int) { + l[i], l[j] = l[j], l[i] +} +func (l byLeafIdentityHashWithPosition) Less(i, j int) bool { + return bytes.Compare(l[i].leaf.LeafIdentityHash, l[j].leaf.LeafIdentityHash) == -1 +} diff --git a/storage/postgres/log_storage_test.go b/storage/postgres/log_storage_test.go new file mode 100644 index 0000000000..e0413cbefd --- /dev/null +++ b/storage/postgres/log_storage_test.go @@ -0,0 +1,1364 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package postgres + +import ( + "bytes" + "context" + "crypto" + "crypto/sha256" + "database/sql" + "fmt" + "reflect" + "sort" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/google/trillian" + "github.com/google/trillian/storage" + "github.com/google/trillian/storage/testonly" + "github.com/google/trillian/types" + "github.com/kylelemons/godebug/pretty" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + tcrypto "github.com/google/trillian/crypto" + ttestonly "github.com/google/trillian/testonly" + + _ "github.com/lib/pq" +) + +// Must be 32 bytes to match sha256 length if it was a real hash +var dummyHash = []byte("hashxxxxhashxxxxhashxxxxhashxxxx") +var dummyRawHash = []byte("xxxxhashxxxxhashxxxxhashxxxxhash") +var dummyRawHash2 = []byte("yyyyhashyyyyhashyyyyhashyyyyhash") +var dummyHash2 = []byte("HASHxxxxhashxxxxhashxxxxhashxxxx") +var dummyHash3 = []byte("hashxxxxhashxxxxhashxxxxHASHxxxx") + +// Time we will queue all leaves at +var fakeQueueTime = time.Date(2016, 11, 10, 15, 16, 27, 0, time.UTC) + +// Time we will integrate all leaves at +var fakeIntegrateTime = time.Date(2016, 11, 10, 15, 16, 30, 0, time.UTC) + +// Time we'll request for guard cutoff in tests that don't test this (should include all above) +var fakeDequeueCutoffTime = time.Date(2016, 11, 10, 15, 16, 30, 0, time.UTC) + +// Used for tests involving extra data +var someExtraData = []byte("Some extra data") +var someExtraData2 = []byte("Some even more extra data") + +const leavesToInsert = 5 +const sequenceNumber int64 = 237 + +// Tests that access the db should each use a distinct log ID to prevent lock contention when +// run in parallel or race conditions / unexpected interactions. Tests that pass should hold +// no locks afterwards. + +func createFakeLeaf(ctx context.Context, db *sql.DB, logID int64, rawHash, hash, data, extraData []byte, seq int64, t *testing.T) *trillian.LogLeaf { + t.Helper() + queuedAtNanos := fakeQueueTime.UnixNano() + integratedAtNanos := fakeIntegrateTime.UnixNano() + _, err := db.ExecContext(ctx, "select * from insert_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)", logID, rawHash, data, extraData, queuedAtNanos) + _, err2 := db.ExecContext(ctx, "select * from insert_sequenced_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)", logID, seq, rawHash, hash, integratedAtNanos) + + if err != nil || err2 != nil { + t.Fatalf("Failed to create test leaves: %v %v", err, err2) + } + if err != nil { + panic(err) + } + integrateTimestamp, err := ptypes.TimestampProto(fakeIntegrateTime) + if err != nil { + panic(err) + } + return &trillian.LogLeaf{ + MerkleLeafHash: hash, + LeafValue: data, + ExtraData: extraData, + LeafIndex: seq, + LeafIdentityHash: rawHash, + IntegrateTimestamp: integrateTimestamp, + } +} + +func checkLeafContents(leaf *trillian.LogLeaf, seq int64, rawHash, hash, data, extraData []byte, t *testing.T) { + t.Helper() + if got, want := leaf.MerkleLeafHash, hash; !bytes.Equal(got, want) { + t.Fatalf("Wrong leaf hash in returned leaf got\n%v\nwant:\n%v", got, want) + } + + if got, want := leaf.LeafIdentityHash, rawHash; !bytes.Equal(got, want) { + t.Fatalf("Wrong raw leaf hash in returned leaf got\n%v\nwant:\n%v", got, want) + } + + if got, want := seq, leaf.LeafIndex; got != want { + t.Fatalf("Bad sequence number in returned leaf got: %d, want:%d", got, want) + } + + if got, want := leaf.LeafValue, data; !bytes.Equal(got, want) { + t.Fatalf("Unxpected data in returned leaf. 
got:\n%v\nwant:\n%v", got, want) + } + + if got, want := leaf.ExtraData, extraData; !bytes.Equal(got, want) { + t.Fatalf("Unxpected data in returned leaf. got:\n%v\nwant:\n%v", got, want) + } + + iTime, err := ptypes.Timestamp(leaf.IntegrateTimestamp) + if err != nil { + t.Fatalf("Got invalid integrate timestamp: %v", err) + } + if got, want := iTime.UnixNano(), fakeIntegrateTime.UnixNano(); got != want { + t.Errorf("Wrong IntegrateTimestamp: got %v, want %v", got, want) + } +} + +func TestMySQLLogStorage_CheckDatabaseAccessible(t *testing.T) { + cleanTestDB(db, t) + s := NewLogStorage(db, nil) + if err := s.CheckDatabaseAccessible(context.Background()); err != nil { + t.Errorf("CheckDatabaseAccessible() = %v, want = nil", err) + } +} + +func TestSnapshot(t *testing.T) { + cleanTestDB(db, t) + + frozenLog := createTreeOrPanic(db, testonly.LogTree) + createFakeSignedLogRoot(db, frozenLog, 0) + if _, err := updateTree(db, frozenLog.TreeId, func(tree *trillian.Tree) { + tree.TreeState = trillian.TreeState_FROZEN + }); err != nil { + t.Fatalf("Error updating frozen tree: %v", err) + } + + activeLog := createTreeOrPanic(db, testonly.LogTree) + createFakeSignedLogRoot(db, activeLog, 0) + mapTreeID := createTreeOrPanic(db, testonly.MapTree).TreeId + + tests := []struct { + desc string + tree *trillian.Tree + wantErr bool + }{ + { + desc: "unknownSnapshot", + tree: logTree(-1), + wantErr: true, + }, + { + desc: "activeLogSnapshot", + tree: activeLog, + }, + { + desc: "frozenSnapshot", + tree: frozenLog, + }, + { + desc: "mapSnapshot", + tree: logTree(mapTreeID), + wantErr: true, + }, + } + + ctx := context.Background() + s := NewLogStorage(db, nil) + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + tx, err := s.SnapshotForTree(ctx, test.tree) + + if err == storage.ErrTreeNeedsInit { + defer tx.Close() + } + + if hasErr := err != nil; hasErr != test.wantErr { + t.Fatalf("err = %q, wantErr = %v", err, test.wantErr) + } else if hasErr { + return + } + defer tx.Close() + + _, err = tx.LatestSignedLogRoot(ctx) + if err != nil { + t.Errorf("LatestSignedLogRoot() returned err = %v", err) + } + if err := tx.Commit(); err != nil { + t.Errorf("Commit() returned err = %v", err) + } + }) + } +} + +func TestReadWriteTransaction(t *testing.T) { + cleanTestDB(db, t) + activeLog := createTreeOrPanic(db, testonly.LogTree) + createFakeSignedLogRoot(db, activeLog, 0) + + tests := []struct { + desc string + tree *trillian.Tree + wantErr bool + wantLogRoot []byte + wantTXRev int64 + }{ + { + // Unknown logs IDs are now handled outside storage. 
+ desc: "unknownBegin", + tree: logTree(-1), + wantLogRoot: nil, + wantTXRev: -1, + }, + { + desc: "activeLogBegin", + tree: activeLog, + wantLogRoot: func() []byte { + b, err := (&types.LogRootV1{RootHash: []byte{0}}).MarshalBinary() + if err != nil { + panic(err) + } + return b + }(), + wantTXRev: 1, + }, + } + + ctx := context.Background() + s := NewLogStorage(db, nil) + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + err := s.ReadWriteTransaction(ctx, test.tree, func(ctx context.Context, tx storage.LogTreeTX) error { + root, err := tx.LatestSignedLogRoot(ctx) + if err != nil { + t.Fatalf("%v: LatestSignedLogRoot() returned err = %v", test.desc, err) + } + gotRev, _ := tx.WriteRevision(ctx) + if gotRev != test.wantTXRev { + t.Errorf("%v: WriteRevision() = %v, want = %v", test.desc, gotRev, test.wantTXRev) + } + if got, want := root.LogRoot, test.wantLogRoot; !bytes.Equal(got, want) { + t.Errorf("%v: LogRoot: \n%x, want \n%x", test.desc, got, want) + } + return nil + }) + if hasErr := err != nil; hasErr != test.wantErr { + t.Fatalf("%v: err = %q, wantErr = %v", test.desc, err, test.wantErr) + } else if hasErr { + return + } + }) + } +} + +func TestQueueDuplicateLeaf(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + count := 15 + leaves := createTestLeaves(int64(count), 10) + leaves2 := createTestLeaves(int64(count), 12) + leaves3 := createTestLeaves(3, 100) + + // Note that tests accumulate queued leaves on top of each other. + var tests = []struct { + desc string + leaves []*trillian.LogLeaf + want []*trillian.LogLeaf + }{ + { + desc: "[10, 11, 12, ...]", + leaves: leaves, + want: make([]*trillian.LogLeaf, count), + }, + { + desc: "[12, 13, 14, ...] so first (count-2) are duplicates", + leaves: leaves2, + want: append(leaves[2:], nil, nil), + }, + { + desc: "[10, 100, 11, 101, 102] so [dup, new, dup, new, dup]", + leaves: []*trillian.LogLeaf{leaves[0], leaves3[0], leaves[1], leaves3[1], leaves[2]}, + want: []*trillian.LogLeaf{leaves[0], nil, leaves[1], nil, leaves[2]}, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + existing, err := tx.QueueLeaves(ctx, test.leaves, fakeQueueTime) + if err != nil { + t.Errorf("Failed to queue leaves: %v", err) + return err + } + + if len(existing) != len(test.want) { + t.Fatalf("|QueueLeaves()|=%d; want %d", len(existing), len(test.want)) + } + for i, want := range test.want { + got := existing[i] + if want == nil { + if got != nil { + t.Fatalf("QueueLeaves()[%d]=%v; want nil", i, got) + } + return nil + } + if got == nil { + t.Fatalf("QueueLeaves()[%d]=nil; want non-nil", i) + } else if !bytes.Equal(got.LeafIdentityHash, want.LeafIdentityHash) { + t.Fatalf("QueueLeaves()[%d].LeafIdentityHash=%x; want %x", i, got.LeafIdentityHash, want.LeafIdentityHash) + } + } + return nil + }) + }) + } +} + +func TestQueueLeaves(t *testing.T) { + ctx := context.Background() + + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + leaves := createTestLeaves(leavesToInsert, 20) + if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil { + t.Fatalf("Failed to queue leaves: %v", err) + } + return nil + }) + + // Should see the leaves in the database. There is no API to read from the unsequenced data. 
+ var count int + if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM unsequenced WHERE Tree_id=$1", tree.TreeId).Scan(&count); err != nil { + t.Fatalf("Could not query row count: %v", err) + } + if leavesToInsert != count { + t.Fatalf("Expected %d unsequenced rows but got: %d", leavesToInsert, count) + } + + // Additional check on timestamp being set correctly in the database + var queueTimestamp int64 + if err := db.QueryRowContext(ctx, "SELECT DISTINCT queue_timestamp_nanos FROM unsequenced WHERE tree_id=$1", tree.TreeId).Scan(&queueTimestamp); err != nil { + t.Fatalf("Could not query timestamp: %v", err) + } + if got, want := queueTimestamp, fakeQueueTime.UnixNano(); got != want { + t.Fatalf("Incorrect queue timestamp got: %d want: %d", got, want) + } +} + +// AddSequencedLeaves tests. --------------------------------------------------- + +type addSequencedLeavesTest struct { + t *testing.T + s storage.LogStorage + tree *trillian.Tree +} + +func initAddSequencedLeavesTest(t *testing.T) addSequencedLeavesTest { + cleanTestDB(db, t) + s := NewLogStorage(db, nil) + tree := createTreeOrPanic(db, testonly.PreorderedLogTree) + return addSequencedLeavesTest{t, s, tree} +} + +func (t *addSequencedLeavesTest) addSequencedLeaves(leaves []*trillian.LogLeaf) { + runLogTX(t.s, t.tree, t.t, func(ctx context.Context, tx storage.LogTreeTX) error { + if _, err := tx.AddSequencedLeaves(ctx, leaves, fakeQueueTime); err != nil { + t.t.Fatalf("Failed to add sequenced leaves: %v", err) + } + // TODO(pavelkalinnikov): Verify returned status for each leaf. + return nil + }) +} + +func (t *addSequencedLeavesTest) verifySequencedLeaves(start, count int64, exp []*trillian.LogLeaf) { + var stored []*trillian.LogLeaf + runLogTX(t.s, t.tree, t.t, func(ctx context.Context, tx storage.LogTreeTX) error { + var err error + stored, err = tx.GetLeavesByRange(ctx, start, count) + if err != nil { + t.t.Fatalf("Failed to read sequenced leaves: %v", err) + } + return nil + }) + if got, want := len(stored), len(exp); got != want { + t.t.Fatalf("Unexpected number of leaves: got %d, want %d %d %d %v", got, want, start, count, exp) + } + + for i, leaf := range stored { + if got, want := leaf.LeafIndex, exp[i].LeafIndex; got != want { + t.t.Fatalf("Leaf #%d: LeafIndex=%v, want %v", i, got, want) + } + if got, want := leaf.LeafIdentityHash, exp[i].LeafIdentityHash; !bytes.Equal(got, want) { + t.t.Fatalf("Leaf #%d: LeafIdentityHash=%v, want %v %d %d %v", i, got, want, start, count, t.tree) + } + } +} + +func TestAddSequencedLeavesUnordered(t *testing.T) { + const chunk = leavesToInsert + const count = chunk * 5 + const extraCount = 16 + leaves := createTestLeaves(count, 0) + + aslt := initAddSequencedLeavesTest(t) + for _, idx := range []int{1, 0, 4, 2} { + aslt.addSequencedLeaves(leaves[chunk*idx : chunk*(idx+1)]) + } + aslt.verifySequencedLeaves(0, count+extraCount, leaves[:chunk*3]) + aslt.verifySequencedLeaves(chunk*4, chunk+extraCount, leaves[chunk*4:count]) + aslt.addSequencedLeaves(leaves[chunk*3 : chunk*4]) + aslt.verifySequencedLeaves(0, count+extraCount, leaves) +} + +func TestAddSequencedLeavesWithDuplicates(t *testing.T) { + leaves := createTestLeaves(6, 0) + + aslt := initAddSequencedLeavesTest(t) + aslt.addSequencedLeaves(leaves[:3]) + aslt.verifySequencedLeaves(0, 3, leaves[:3]) + aslt.addSequencedLeaves(leaves[2:]) // Full dup. + aslt.verifySequencedLeaves(0, 6, leaves) + + dupLeaves := createTestLeaves(4, 6) + dupLeaves[0].LeafIdentityHash = leaves[0].LeafIdentityHash // Hash dup. 
+ dupLeaves[2].LeafIndex = 2 // Index dup. + aslt.addSequencedLeaves(dupLeaves) + aslt.verifySequencedLeaves(6, 4, dupLeaves[0:2]) + aslt.verifySequencedLeaves(7, 4, dupLeaves[1:2]) + aslt.verifySequencedLeaves(8, 4, nil) + aslt.verifySequencedLeaves(9, 4, dupLeaves[3:4]) + dupLeaves = createTestLeaves(4, 6) + aslt.addSequencedLeaves(dupLeaves) +} + +// ----------------------------------------------------------------------------- + +func TestDequeueLeavesNoneQueued(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + leaves, err := tx.DequeueLeaves(ctx, 999, fakeDequeueCutoffTime) + if err != nil { + t.Fatalf("Didn't expect an error on dequeue with no work to be done: %v", err) + } + if len(leaves) > 0 { + t.Fatalf("Expected nothing to be dequeued but we got %d leaves", len(leaves)) + } + return nil + }) +} + +func TestDequeueLeaves(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + leaves := createTestLeaves(leavesToInsert, 20) + if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil { + t.Fatalf("Failed to queue leaves: %v", err) + } + return nil + }) + } + + { + // Now try to dequeue them + runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error { + leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime) + if err != nil { + t.Fatalf("Failed to dequeue leaves: %v", err) + } + if len(leaves2) != leavesToInsert { + t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert) + } + ensureAllLeavesDistinct(leaves2, t) + return nil + }) + } + + { + // If we dequeue again then we should now get nothing + runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error { + leaves3, err := tx3.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime) + if err != nil { + t.Fatalf("Failed to dequeue leaves (second time): %v", err) + } + if len(leaves3) != 0 { + t.Fatalf("Dequeued %d leaves but expected to get none", len(leaves3)) + } + return nil + }) + } +} + +func TestDequeueLeavesHaveQueueTimestamp(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + leaves := createTestLeaves(leavesToInsert, 20) + if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil { + t.Fatalf("Failed to queue leaves: %v", err) + } + return nil + }) + } + + { + // Now try to dequeue them + runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error { + leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime) + if err != nil { + t.Fatalf("Failed to dequeue leaves: %v", err) + } + if len(leaves2) != leavesToInsert { + t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert) + } + ensureLeavesHaveQueueTimestamp(t, leaves2, fakeDequeueCutoffTime) + return nil + }) + } +} + +func TestDequeueLeavesTwoBatches(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + leavesToDequeue1 := 3 + leavesToDequeue2 := 2 + + { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + leaves := createTestLeaves(leavesToInsert, 20) + if _, err := tx.QueueLeaves(ctx, 
leaves, fakeDequeueCutoffTime); err != nil { + t.Fatalf("Failed to queue leaves: %v", err) + } + return nil + }) + } + + var err error + var leaves2, leaves3, leaves4 []*trillian.LogLeaf + { + // Now try to dequeue some of them + runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error { + leaves2, err = tx2.DequeueLeaves(ctx, leavesToDequeue1, fakeDequeueCutoffTime) + if err != nil { + t.Fatalf("Failed to dequeue leaves: %v", err) + } + if len(leaves2) != leavesToDequeue1 { + t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert) + } + ensureAllLeavesDistinct(leaves2, t) + ensureLeavesHaveQueueTimestamp(t, leaves2, fakeDequeueCutoffTime) + return nil + }) + + // Now try to dequeue the rest of them + runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error { + leaves3, err = tx3.DequeueLeaves(ctx, leavesToDequeue2, fakeDequeueCutoffTime) + if err != nil { + t.Fatalf("Failed to dequeue leaves: %v", err) + } + if len(leaves3) != leavesToDequeue2 { + t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves3), leavesToDequeue2) + } + ensureAllLeavesDistinct(leaves3, t) + ensureLeavesHaveQueueTimestamp(t, leaves3, fakeDequeueCutoffTime) + + // Plus the union of the leaf batches should all have distinct hashes + leaves4 = append(leaves2, leaves3...) + ensureAllLeavesDistinct(leaves4, t) + return nil + }) + } + + { + // If we dequeue again then we should now get nothing + runLogTX(s, tree, t, func(ctx context.Context, tx4 storage.LogTreeTX) error { + leaves5, err := tx4.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime) + if err != nil { + t.Fatalf("Failed to dequeue leaves (second time): %v", err) + } + if len(leaves5) != 0 { + t.Fatalf("Dequeued %d leaves but expected to get none", len(leaves5)) + } + return nil + }) + } +} + +// Queues leaves and attempts to dequeue before the guard cutoff allows it. This should +// return nothing. Then retry with an inclusive guard cutoff and ensure the leaves +// are returned. +func TestDequeueLeavesGuardInterval(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + leaves := createTestLeaves(leavesToInsert, 20) + if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil { + t.Fatalf("Failed to queue leaves: %v", err) + } + return nil + }) + } + + { + // Now try to dequeue them using a cutoff that means we should get none + runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error { + leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeQueueTime.Add(-time.Second)) + if err != nil { + t.Fatalf("Failed to dequeue leaves: %v", err) + } + if len(leaves2) != 0 { + t.Fatalf("Dequeued %d leaves when they all should be in guard interval", len(leaves2)) + } + + // Try to dequeue again using a cutoff that should include them + leaves2, err = tx2.DequeueLeaves(ctx, 99, fakeQueueTime.Add(time.Second)) + if err != nil { + t.Fatalf("Failed to dequeue leaves: %v", err) + } + if len(leaves2) != leavesToInsert { + t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert) + } + ensureAllLeavesDistinct(leaves2, t) + return nil + }) + } +} + +func TestDequeueLeavesTimeOrdering(t *testing.T) { + // Queue two small batches of leaves at different timestamps. Do two separate dequeue + // transactions and make sure the returned leaves are respecting the time ordering of the + // queue. 
+ cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + batchSize := 2 + leaves := createTestLeaves(int64(batchSize), 0) + leaves2 := createTestLeaves(int64(batchSize), int64(batchSize)) + + { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil { + t.Fatalf("QueueLeaves(1st batch) = %v", err) + } + // These are one second earlier so should be dequeued first + if _, err := tx.QueueLeaves(ctx, leaves2, fakeQueueTime.Add(-time.Second)); err != nil { + t.Fatalf("QueueLeaves(2nd batch) = %v", err) + } + return nil + }) + } + + { + // Now try to dequeue two leaves and we should get the second batch + runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error { + dequeue1, err := tx2.DequeueLeaves(ctx, batchSize, fakeQueueTime) + if err != nil { + t.Fatalf("DequeueLeaves(1st) = %v", err) + } + if got, want := len(dequeue1), batchSize; got != want { + t.Fatalf("Dequeue count mismatch (1st) got: %d, want: %d", got, want) + } + ensureAllLeavesDistinct(dequeue1, t) + + // Ensure this is the second batch queued by comparing leaf hashes (must be distinct as + // the leaf data was). + if !leafInBatch(dequeue1[0], leaves2) || !leafInBatch(dequeue1[1], leaves2) { + t.Fatalf("Got leaf from wrong batch (1st dequeue): %v", dequeue1) + } + return nil + }) + + // Try to dequeue again and we should get the batch that was queued first, though at a later time + runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error { + dequeue2, err := tx3.DequeueLeaves(ctx, batchSize, fakeQueueTime) + if err != nil { + t.Fatalf("DequeueLeaves(2nd) = %v", err) + } + if got, want := len(dequeue2), batchSize; got != want { + t.Fatalf("Dequeue count mismatch (2nd) got: %d, want: %d", got, want) + } + ensureAllLeavesDistinct(dequeue2, t) + + // Ensure this is the first batch by comparing leaf hashes. 
+ if !leafInBatch(dequeue2[0], leaves) || !leafInBatch(dequeue2[1], leaves) { + t.Fatalf("Got leaf from wrong batch (2nd dequeue): %v", dequeue2) + } + return nil + }) + } +} + +func TestGetLeavesByHashNotPresent(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + hashes := [][]byte{[]byte("thisdoesn'texist")} + leaves, err := tx.GetLeavesByHash(ctx, hashes, false) + if err != nil { + t.Fatalf("Error getting leaves by hash: %v", err) + } + if len(leaves) != 0 { + t.Fatalf("Expected no leaves returned but got %d", len(leaves)) + } + return nil + }) +} + +func TestGetLeavesByHash(t *testing.T) { + ctx := context.Background() + + // Create fake leaf as if it had been sequenced + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + data := []byte("some data") + createFakeLeaf(ctx, db, tree.TreeId, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) + + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + hashes := [][]byte{dummyHash} + leaves, err := tx.GetLeavesByHash(ctx, hashes, false) + if err != nil { + t.Fatalf("Unexpected error getting leaf by hash: %v", err) + } + if len(leaves) != 1 { + t.Fatalf("Got %d leaves but expected one", len(leaves)) + } + checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t) + return nil + }) +} + +func TestGetLeavesByIndex(t *testing.T) { + ctx := context.Background() + + // Create fake leaf as if it had been sequenced, read it back and check contents + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + // The leaf indices are checked against the tree size so we need a root. 
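+ // A fake root of size sequenceNumber+1 makes the two leaves created below (at sequenceNumber
+ // and sequenceNumber-1) readable, while index sequenceNumber+1 stays out of range for the error cases.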
+ createFakeSignedLogRoot(db, tree, uint64(sequenceNumber+1)) + + data := []byte("some data") + data2 := []byte("some other data") + createFakeLeaf(ctx, db, tree.TreeId, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) + createFakeLeaf(ctx, db, tree.TreeId, dummyRawHash2, dummyHash2, data2, someExtraData2, sequenceNumber-1, t) + + var tests = []struct { + desc string + indices []int64 + wantErr bool + wantCode codes.Code + checkFn func([]*trillian.LogLeaf, *testing.T) + }{ + { + desc: "InTree", + indices: []int64{sequenceNumber}, + checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) { + checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t) + }, + }, + { + desc: "InTree2", + indices: []int64{sequenceNumber - 1}, + wantErr: false, + checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) { + checkLeafContents(leaves[0], sequenceNumber, dummyRawHash2, dummyHash2, data2, someExtraData2, t) + }, + }, + { + desc: "InTreeMultiple", + indices: []int64{sequenceNumber - 1, sequenceNumber}, + checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) { + checkLeafContents(leaves[1], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t) + checkLeafContents(leaves[0], sequenceNumber, dummyRawHash2, dummyHash2, data2, someExtraData2, t) + }, + }, + { + desc: "InTreeMultipleReverse", + indices: []int64{sequenceNumber, sequenceNumber - 1}, + checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) { + checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t) + checkLeafContents(leaves[1], sequenceNumber, dummyRawHash2, dummyHash2, data2, someExtraData2, t) + }, + }, { + desc: "OutsideTree", + indices: []int64{sequenceNumber + 1}, + wantErr: true, + wantCode: codes.OutOfRange, + }, + { + desc: "LongWayOutsideTree", + indices: []int64{9999}, + wantErr: true, + wantCode: codes.OutOfRange, + }, + { + desc: "MixedInOutTree", + indices: []int64{sequenceNumber, sequenceNumber + 1}, + wantErr: true, + wantCode: codes.OutOfRange, + }, + { + desc: "MixedInOutTree2", + indices: []int64{sequenceNumber - 1, sequenceNumber + 1}, + wantErr: true, + wantCode: codes.OutOfRange, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + got, err := tx.GetLeavesByIndex(ctx, test.indices) + if test.wantErr { + if err == nil || status.Code(err) != test.wantCode { + t.Errorf("GetLeavesByIndex(%v)=%v,%v; want: nil, err with code %v", test.indices, got, err, test.wantCode) + } + } else { + if err != nil { + t.Errorf("GetLeavesByIndex(%v)=%v,%v; want: got, nil", test.indices, got, err) + } + } + return nil + }) + }) + } +} + +// GetLeavesByRange tests. ----------------------------------------------------- + +type getLeavesByRangeTest struct { + start, count int64 + want []int64 + wantErr bool +} + +func testGetLeavesByRangeImpl(t *testing.T, create *trillian.Tree, tests []getLeavesByRangeTest) { + cleanTestDB(db, t) + + ctx := context.Background() + tree, err := createTree(db, create) + if err != nil { + t.Fatalf("Error creating log: %v", err) + } + // Note: GetLeavesByRange loads the root internally to get the tree size. + createFakeSignedLogRoot(db, tree, 14) + s := NewLogStorage(db, nil) + + // Create leaves [0]..[19] but drop leaf [5] and set the tree size to 14. 
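+ // Omitting leaf [5] lets the range tests below exercise the non-contiguous leaf error path.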
+ for i := int64(0); i < 20; i++ { + if i == 5 { + continue + } + data := []byte{byte(i)} + identityHash := sha256.Sum256(data) + createFakeLeaf(ctx, db, tree.TreeId, identityHash[:], identityHash[:], data, someExtraData, i, t) + } + + for _, test := range tests { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + leaves, err := tx.GetLeavesByRange(ctx, test.start, test.count) + if err != nil { + if !test.wantErr { + t.Errorf("GetLeavesByRange(%d, +%d)=_,%v; want _,nil", test.start, test.count, err) + } + return nil + } + if test.wantErr { + t.Errorf("GetLeavesByRange(%d, +%d)=_,nil; want _,non-nil", test.start, test.count) + } + got := make([]int64, len(leaves)) + for i, leaf := range leaves { + got[i] = leaf.LeafIndex + } + if !reflect.DeepEqual(got, test.want) { + t.Errorf("GetLeavesByRange(%d, +%d)=%+v; want %+v", test.start, test.count, got, test.want) + } + return nil + }) + } +} + +func TestGetLeavesByRangeFromLog(t *testing.T) { + var tests = []getLeavesByRangeTest{ + {start: 0, count: 1, want: []int64{0}}, + {start: 0, count: 2, want: []int64{0, 1}}, + {start: 1, count: 3, want: []int64{1, 2, 3}}, + {start: 10, count: 7, want: []int64{10, 11, 12, 13}}, + {start: 13, count: 1, want: []int64{13}}, + {start: 14, count: 4, wantErr: true}, // Starts right after tree size. + {start: 19, count: 2, wantErr: true}, // Starts further away. + {start: 3, count: 5, wantErr: true}, // Hits non-contiguous leaves. + {start: 5, count: 5, wantErr: true}, // Starts from a missing leaf. + {start: 1, count: 0, wantErr: true}, // Empty range. + {start: -1, count: 1, wantErr: true}, // Negative start. + {start: 1, count: -1, wantErr: true}, // Negative count. + {start: 100, count: 30, wantErr: true}, // Starts after all stored leaves. + } + testGetLeavesByRangeImpl(t, testonly.LogTree, tests) +} + +func TestGetLeavesByRangeFromPreorderedLog(t *testing.T) { + var tests = []getLeavesByRangeTest{ + {start: 0, count: 1, want: []int64{0}}, + {start: 0, count: 2, want: []int64{0, 1}}, + {start: 1, count: 3, want: []int64{1, 2, 3}}, + {start: 10, count: 7, want: []int64{10, 11, 12, 13, 14, 15, 16}}, + {start: 13, count: 1, want: []int64{13}}, + // Starts right after tree size. + {start: 14, count: 4, want: []int64{14, 15, 16, 17}}, + {start: 19, count: 2, want: []int64{19}}, // Starts further away. + {start: 3, count: 5, wantErr: true}, // Hits non-contiguous leaves. + {start: 5, count: 5, wantErr: true}, // Starts from a missing leaf. + {start: 1, count: 0, wantErr: true}, // Empty range. + {start: -1, count: 1, wantErr: true}, // Negative start. + {start: 1, count: -1, wantErr: true}, // Negative count. + {start: 100, count: 30, want: []int64{}}, // Starts after all stored leaves. 
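+ // Note the differences from the LOG cases above: a PREORDERED_LOG is not capped at the signed
+ // tree size (14), and a range starting past the stored leaves returns an empty slice rather
+ // than an error.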
+ } + testGetLeavesByRangeImpl(t, testonly.PreorderedLogTree, tests) +} + +// ----------------------------------------------------------------------------- + +func TestLatestSignedRootNoneWritten(t *testing.T) { + ctx := context.Background() + + cleanTestDB(db, t) + tree, err := createTree(db, testonly.LogTree) + if err != nil { + t.Fatalf("createTree: %v", err) + } + s := NewLogStorage(db, nil) + + tx, err := s.SnapshotForTree(ctx, tree) + if err != storage.ErrTreeNeedsInit { + t.Fatalf("SnapshotForTree gave %v, want %v", err, storage.ErrTreeNeedsInit) + } + commit(tx, t) +} + +func TestLatestSignedLogRoot(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256) + root, err := signer.SignLogRoot(&types.LogRootV1{ + TimestampNanos: 98765, + TreeSize: 16, + Revision: 5, + RootHash: []byte(dummyHash), + }) + if err != nil { + t.Fatalf("SignLogRoot(): %v", err) + } + + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + if err := tx.StoreSignedLogRoot(ctx, *root); err != nil { + t.Fatalf("Failed to store signed root: %v", err) + } + return nil + }) + + { + runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error { + root2, err := tx2.LatestSignedLogRoot(ctx) + if err != nil { + t.Fatalf("Failed to read back new log root: %v", err) + } + if !proto.Equal(root, &root2) { + t.Fatalf("Root round trip failed: <%v> and: <%v>", root, root2) + } + return nil + }) + } +} + +func TestDuplicateSignedLogRoot(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256) + root, err := signer.SignLogRoot(&types.LogRootV1{ + TimestampNanos: 98765, + TreeSize: 16, + Revision: 5, + RootHash: []byte(dummyHash), + }) + if err != nil { + t.Fatalf("SignLogRoot(): %v", err) + } + + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + if err := tx.StoreSignedLogRoot(ctx, *root); err != nil { + t.Fatalf("Failed to store signed root: %v", err) + } + // Shouldn't be able to do it again + // if err := tx.StoreSignedLogRoot(ctx, *root); err == nil { + // t.Fatal("Allowed duplicate signed root") + // } + return nil + }) +} + +func TestLogRootUpdate(t *testing.T) { + // Write two roots for a log and make sure the one with the newest timestamp supersedes + cleanTestDB(db, t) + tree := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256) + root, err := signer.SignLogRoot(&types.LogRootV1{ + TimestampNanos: 98765, + TreeSize: 16, + Revision: 5, + RootHash: []byte(dummyHash), + }) + if err != nil { + t.Fatalf("SignLogRoot(): %v", err) + } + root2, err := signer.SignLogRoot(&types.LogRootV1{ + TimestampNanos: 98766, + TreeSize: 16, + Revision: 6, + RootHash: []byte(dummyHash), + }) + if err != nil { + t.Fatalf("SignLogRoot(): %v", err) + } + + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + if err := tx.StoreSignedLogRoot(ctx, *root); err != nil { + t.Fatalf("Failed to store signed root: %v", err) + } + if err := tx.StoreSignedLogRoot(ctx, *root2); err != nil { + t.Fatalf("Failed to store signed root: %v", err) + } + return nil + }) + + runLogTX(s, 
tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error { + root3, err := tx2.LatestSignedLogRoot(ctx) + if err != nil { + t.Fatalf("Failed to read back new log root: %v", err) + } + if !proto.Equal(root2, &root3) { + t.Fatalf("Root round trip failed: <%v> and: <%v>", root, root2) + } + return nil + }) +} + +func TestGetActiveLogIDs(t *testing.T) { + ctx := context.Background() + + cleanTestDB(db, t) + admin := NewAdminStorage(db) + + // Create a few test trees + log1 := proto.Clone(testonly.LogTree).(*trillian.Tree) + log2 := proto.Clone(testonly.LogTree).(*trillian.Tree) + log3 := proto.Clone(testonly.PreorderedLogTree).(*trillian.Tree) + drainingLog := proto.Clone(testonly.LogTree).(*trillian.Tree) + frozenLog := proto.Clone(testonly.LogTree).(*trillian.Tree) + deletedLog := proto.Clone(testonly.LogTree).(*trillian.Tree) + map1 := proto.Clone(testonly.MapTree).(*trillian.Tree) + map2 := proto.Clone(testonly.MapTree).(*trillian.Tree) + deletedMap := proto.Clone(testonly.MapTree).(*trillian.Tree) + for _, tree := range []*trillian.Tree{log1, log2, log3, drainingLog, frozenLog, deletedLog, map1, map2, deletedMap} { + newTree, err := storage.CreateTree(ctx, admin, tree) + if err != nil { + t.Fatalf("CreateTree(%+v) returned err = %v", tree, err) + } + *tree = *newTree + } + + // FROZEN is not a valid initial state, so we have to update it separately. + if _, err := storage.UpdateTree(ctx, admin, frozenLog.TreeId, func(t *trillian.Tree) { + t.TreeState = trillian.TreeState_FROZEN + }); err != nil { + t.Fatalf("UpdateTree() returned err = %v", err) + } + // DRAINING is not a valid initial state, so we have to update it separately. + if _, err := storage.UpdateTree(ctx, admin, drainingLog.TreeId, func(t *trillian.Tree) { + t.TreeState = trillian.TreeState_DRAINING + }); err != nil { + t.Fatalf("UpdateTree() returned err = %v", err) + } + + // Update deleted trees accordingly + updateDeletedStmt, err := db.PrepareContext(ctx, "UPDATE Trees SET Deleted = $1 WHERE Tree_Id = $2") + if err != nil { + t.Fatalf("PrepareContext() returned err = %v", err) + } + defer updateDeletedStmt.Close() + for _, treeID := range []int64{deletedLog.TreeId, deletedMap.TreeId} { + if _, err := updateDeletedStmt.ExecContext(ctx, true, treeID); err != nil { + t.Fatalf("ExecContext(%v) returned err = %v", treeID, err) + } + } + + s := NewLogStorage(db, nil) + tx, err := s.Snapshot(ctx) + if err != nil { + t.Fatalf("Snapshot() returns err = %v", err) + } + defer tx.Close() + got, err := tx.GetActiveLogIDs(ctx) + if err != nil { + t.Fatalf("GetActiveLogIDs() returns err = %v", err) + } + if err := tx.Commit(); err != nil { + t.Errorf("Commit() returned err = %v", err) + } + + want := []int64{log1.TreeId, log2.TreeId, log3.TreeId, drainingLog.TreeId} + sort.Slice(got, func(i, j int) bool { return got[i] < got[j] }) + sort.Slice(want, func(i, j int) bool { return want[i] < want[j] }) + if diff := pretty.Compare(got, want); diff != "" { + t.Errorf("post-GetActiveLogIDs diff (-got +want):\n%v", diff) + } +} + +func TestGetActiveLogIDsEmpty(t *testing.T) { + ctx := context.Background() + + cleanTestDB(db, t) + s := NewLogStorage(db, nil) + + tx, err := s.Snapshot(context.Background()) + if err != nil { + t.Fatalf("Snapshot() = (_, %v), want = (_, nil)", err) + } + defer tx.Close() + ids, err := tx.GetActiveLogIDs(ctx) + if err != nil { + t.Fatalf("GetActiveLogIDs() = (_, %v), want = (_, nil)", err) + } + if err := tx.Commit(); err != nil { + t.Errorf("Commit() = %v, want = nil", err) + } + + if got, want := len(ids), 
0; got != want { + t.Errorf("GetActiveLogIDs(): got %v IDs, want = %v", got, want) + } +} + +func TestReadOnlyLogTX_Rollback(t *testing.T) { + ctx := context.Background() + cleanTestDB(db, t) + s := NewLogStorage(db, nil) + tx, err := s.Snapshot(ctx) + if err != nil { + t.Fatalf("Snapshot() = (_, %v), want = (_, nil)", err) + } + defer tx.Close() + if _, err := tx.GetActiveLogIDs(ctx); err != nil { + t.Fatalf("GetActiveLogIDs() = (_, %v), want = (_, nil)", err) + } + // It's a bit hard to have a more meaningful test. This should suffice. + if err := tx.Rollback(); err != nil { + t.Errorf("Rollback() = (_, %v), want = (_, nil)", err) + } +} + +func TestGetSequencedLeafCount(t *testing.T) { + ctx := context.Background() + + // We'll create leaves for two different trees + cleanTestDB(db, t) + log1 := createTreeOrPanic(db, testonly.LogTree) + log2 := createTreeOrPanic(db, testonly.LogTree) + s := NewLogStorage(db, nil) + + { + // Create fake leaf as if it had been sequenced + data := []byte("some data") + createFakeLeaf(ctx, db, log1.TreeId, dummyHash, dummyRawHash, data, someExtraData, sequenceNumber, t) + + // Create fake leaves for second tree as if they had been sequenced + data2 := []byte("some data 2") + data3 := []byte("some data 3") + createFakeLeaf(ctx, db, log2.TreeId, dummyHash2, dummyRawHash, data2, someExtraData, sequenceNumber, t) + createFakeLeaf(ctx, db, log2.TreeId, dummyHash3, dummyRawHash, data3, someExtraData, sequenceNumber+1, t) + } + + // Read back the leaf counts from both trees + runLogTX(s, log1, t, func(ctx context.Context, tx storage.LogTreeTX) error { + count1, err := tx.GetSequencedLeafCount(ctx) + if err != nil { + t.Fatalf("unexpected error getting leaf count: %v", err) + } + if want, got := int64(1), count1; want != got { + t.Fatalf("expected %d sequenced for logId but got %d", want, got) + } + return nil + }) + + runLogTX(s, log2, t, func(ctx context.Context, tx storage.LogTreeTX) error { + count2, err := tx.GetSequencedLeafCount(ctx) + if err != nil { + t.Fatalf("unexpected error getting leaf count2: %v", err) + } + if want, got := int64(2), count2; want != got { + t.Fatalf("expected %d sequenced for logId2 but got %d", want, got) + } + return nil + }) +} + +func TestSortByLeafIdentityHash(t *testing.T) { + l := make([]*trillian.LogLeaf, 30) + for i := range l { + hash := sha256.Sum256([]byte{byte(i)}) + leaf := trillian.LogLeaf{ + LeafIdentityHash: hash[:], + LeafValue: []byte(fmt.Sprintf("Value %d", i)), + ExtraData: []byte(fmt.Sprintf("Extra %d", i)), + LeafIndex: int64(i), + } + l[i] = &leaf + } + sort.Sort(byLeafIdentityHash(l)) + for i := range l { + if i == 0 { + continue + } + if bytes.Compare(l[i-1].LeafIdentityHash, l[i].LeafIdentityHash) != -1 { + t.Errorf("sorted leaves not in order, [%d] = %x, [%d] = %x", i-1, l[i-1].LeafIdentityHash, i, l[i].LeafIdentityHash) + } + } + +} + +func ensureAllLeavesDistinct(leaves []*trillian.LogLeaf, t *testing.T) { + t.Helper() + // All the leaf value hashes should be distinct because the leaves were created with distinct + // leaf data. If only we had maps with slices as keys or sets or pretty much any kind of usable + // data structures we could do this properly. 
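+ // A quadratic comparison is acceptable here given the small batch sizes used in these tests.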
+ for i := range leaves { + for j := range leaves { + if i != j && bytes.Equal(leaves[i].LeafIdentityHash, leaves[j].LeafIdentityHash) { + t.Fatalf("Unexpectedly got a duplicate leaf hash: %v %v", + leaves[i].LeafIdentityHash, leaves[j].LeafIdentityHash) + } + } + } +} + +func ensureLeavesHaveQueueTimestamp(t *testing.T, leaves []*trillian.LogLeaf, want time.Time) { + t.Helper() + for _, leaf := range leaves { + gotQTimestamp, err := ptypes.Timestamp(leaf.QueueTimestamp) + if err != nil { + t.Fatalf("Got invalid queue timestamp: %v", err) + } + if got, want := gotQTimestamp.UnixNano(), want.UnixNano(); got != want { + t.Errorf("Got leaf with QueueTimestampNanos = %v, want %v: %v", got, want, leaf) + } + } +} + +// Creates some test leaves with predictable data +func createTestLeaves(n, startSeq int64) []*trillian.LogLeaf { + var leaves []*trillian.LogLeaf + for l := int64(0); l < n; l++ { + lv := fmt.Sprintf("Leaf %d", l+startSeq) + h := sha256.New() + h.Write([]byte(lv)) + leafHash := h.Sum(nil) + leaf := &trillian.LogLeaf{ + LeafIdentityHash: leafHash, + MerkleLeafHash: leafHash, + LeafValue: []byte(lv), + ExtraData: []byte(fmt.Sprintf("Extra %d", l)), + LeafIndex: int64(startSeq + l), + } + leaves = append(leaves, leaf) + } + + return leaves +} + +// Convenience methods to avoid copying out "if err != nil { blah }" all over the place +func runLogTX(s storage.LogStorage, tree *trillian.Tree, t *testing.T, f storage.LogTXFunc) { + t.Helper() + if err := s.ReadWriteTransaction(context.Background(), tree, f); err != nil { + t.Fatalf("Failed to run log tx: %v", err) + } +} + +type committableTX interface { + Commit() error +} + +func commit(tx committableTX, t *testing.T) { + t.Helper() + if err := tx.Commit(); err != nil { + t.Errorf("Failed to commit tx: %v", err) + } +} + +func leafInBatch(leaf *trillian.LogLeaf, batch []*trillian.LogLeaf) bool { + for _, bl := range batch { + if bytes.Equal(bl.LeafIdentityHash, leaf.LeafIdentityHash) { + return true + } + } + + return false +} + +// byLeafIdentityHash allows sorting of leaves by their identity hash, so DB +// operations always happen in a consistent order. +type byLeafIdentityHash []*trillian.LogLeaf + +func (l byLeafIdentityHash) Len() int { return len(l) } +func (l byLeafIdentityHash) Swap(i, j int) { l[i], l[j] = l[j], l[i] } +func (l byLeafIdentityHash) Less(i, j int) bool { + return bytes.Compare(l[i].LeafIdentityHash, l[j].LeafIdentityHash) == -1 +} + +func logTree(logID int64) *trillian.Tree { + return &trillian.Tree{ + TreeId: logID, + TreeType: trillian.TreeType_LOG, + HashStrategy: trillian.HashStrategy_RFC6962_SHA256, + } +} diff --git a/storage/postgres/queue.go b/storage/postgres/queue.go new file mode 100644 index 0000000000..d6b8356967 --- /dev/null +++ b/storage/postgres/queue.go @@ -0,0 +1,130 @@ +// +build !batched_queue + +// Copyright 2017 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/golang/glog" + "github.com/golang/protobuf/ptypes" + "github.com/google/trillian" +) + +const ( + // If this statement ORDER BY clause is changed refer to the comment in removeSequencedLeaves + selectQueuedLeavesSQL = `SELECT leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos + FROM unsequenced + WHERE tree_id=$1 + AND bucket=0 + AND queue_timestamp_nanos<=$2 + ORDER BY queue_timestamp_nanos,leaf_identity_hash ASC LIMIT $3` + insertUnsequencedEntrySQL = "select insert_leaf_data_ignore_duplicates($1,$2,$3,$4)" + deleteUnsequencedSQL = "DELETE FROM unsequenced WHERE tree_id = $1 and bucket=0 and queue_timestamp_nanos = $2 and leaf_identity_hash=$3" +) + +type dequeuedLeaf struct { + queueTimestampNanos int64 + leafIdentityHash []byte +} + +func dequeueInfo(leafIDHash []byte, queueTimestamp int64) dequeuedLeaf { + return dequeuedLeaf{queueTimestampNanos: queueTimestamp, leafIdentityHash: leafIDHash} +} + +func (t *logTreeTX) dequeueLeaf(rows *sql.Rows) (*trillian.LogLeaf, dequeuedLeaf, error) { + var leafIDHash []byte + var merkleHash []byte + var queueTimestamp int64 + + err := rows.Scan(&leafIDHash, &merkleHash, &queueTimestamp) + if err != nil { + glog.Warningf("Error scanning work rows: %s", err) + return nil, dequeuedLeaf{}, err + } + + // Note: the LeafData and ExtraData being nil here is OK as this is only used by the + // sequencer. The sequencer only writes to the SequencedLeafData table and the client + // supplied data was already written to LeafData as part of queueing the leaf. + queueTimestampProto, err := ptypes.TimestampProto(time.Unix(0, queueTimestamp)) + if err != nil { + return nil, dequeuedLeaf{}, fmt.Errorf("got invalid queue timestamp: %v", err) + } + leaf := &trillian.LogLeaf{ + LeafIdentityHash: leafIDHash, + MerkleLeafHash: merkleHash, + QueueTimestamp: queueTimestampProto, + } + return leaf, dequeueInfo(leafIDHash, queueTimestamp), nil +} + +func queueArgs(treeID int64, identityHash []byte, queueTimestamp time.Time) []interface{} { + return []interface{}{queueTimestamp.UnixNano()} +} + +func (t *logTreeTX) UpdateSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf) error { + for _, leaf := range leaves { + // This should fail on insert but catch it early + if len(leaf.LeafIdentityHash) != t.hashSizeBytes { + return errors.New("sequenced leaf has incorrect hash size") + } + + iTimestamp, err := ptypes.Timestamp(leaf.IntegrateTimestamp) + if err != nil { + return fmt.Errorf("got invalid integrate timestamp: %v", err) + } + _, err = t.tx.ExecContext( + ctx, + insertSequencedLeafSQL+valuesPlaceholder5, + t.treeID, + leaf.LeafIdentityHash, + leaf.MerkleLeafHash, + leaf.LeafIndex, + iTimestamp.UnixNano()) + if err != nil { + glog.Warningf("Failed to update sequenced leaves: %s", err) + return err + } + } + + return nil +} + +// removeSequencedLeaves removes the passed in leaves slice (which may be +// modified as part of the operation). +func (t *logTreeTX) removeSequencedLeaves(ctx context.Context, leaves []dequeuedLeaf) error { + // Don't need to re-sort because the query ordered by leaf hash. If that changes because + // the query is expensive then the sort will need to be done here. See comment in + // QueueLeaves. 
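+ // Deletes are issued one row at a time; checkResultOkAndRowCountIs verifies that each
+ // delete removed exactly one unsequenced row.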
+ stx, err := t.tx.PrepareContext(ctx, deleteUnsequencedSQL)
+ if err != nil {
+ glog.Warningf("Failed to prep delete statement for sequenced work: %v", err)
+ return err
+ }
+ for _, dql := range leaves {
+ result, err := stx.ExecContext(ctx, t.treeID, dql.queueTimestampNanos, dql.leafIdentityHash)
+ err = checkResultOkAndRowCountIs(result, err, int64(1))
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/storage/postgres/queue_batching.go b/storage/postgres/queue_batching.go
new file mode 100644
index 0000000000..e319a9c701
--- /dev/null
+++ b/storage/postgres/queue_batching.go
@@ -0,0 +1,147 @@
+// +build batched_queue
+
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "crypto/sha256"
+ "database/sql"
+ "encoding/binary"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/golang/protobuf/ptypes"
+ "github.com/google/trillian"
+)
+
+const (
+ // If this statement ORDER BY clause is changed refer to the comment in removeSequencedLeaves
+ selectQueuedLeavesSQL = `SELECT leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos,queue_id
+ FROM unsequenced
+ WHERE tree_id=$1
+ AND Bucket=0
+ AND queue_timestamp_nanos<=$2
+ ORDER BY queue_timestamp_nanos,leaf_identity_hash ASC LIMIT $3`
+ insertUnsequencedEntrySQL = `INSERT INTO unsequenced(tree_id,Bucket,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos,queue_id) VALUES($1,0,$2,$3,$4,$5)`
+ deleteUnsequencedSQL = "DELETE FROM unsequenced WHERE queue_id IN (<placeholder>)"
)
+
+type dequeuedLeaf []byte
+
+func dequeueInfo(_ []byte, queueID []byte) dequeuedLeaf {
+ return dequeuedLeaf(queueID)
+}
+
+func (t *logTreeTX) dequeueLeaf(rows *sql.Rows) (*trillian.LogLeaf, dequeuedLeaf, error) {
+ var leafIDHash []byte
+ var merkleHash []byte
+ var queueTimestamp int64
+ var queueID []byte
+
+ err := rows.Scan(&leafIDHash, &merkleHash, &queueTimestamp, &queueID)
+ if err != nil {
+ glog.Warningf("Error scanning work rows: %s", err)
+ return nil, nil, err
+ }
+
+ queueTimestampProto, err := ptypes.TimestampProto(time.Unix(0, queueTimestamp))
+ if err != nil {
+ return nil, dequeuedLeaf{}, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ // Note: the LeafData and ExtraData being nil here is OK as this is only used by the
+ // sequencer. The sequencer only writes to the SequencedLeafData table and the client
+ // supplied data was already written to LeafData as part of queueing the leaf.
+ leaf := &trillian.LogLeaf{ + LeafIdentityHash: leafIDHash, + MerkleLeafHash: merkleHash, + QueueTimestamp: queueTimestampProto, + } + return leaf, dequeueInfo(leafIDHash, queueID), nil +} + +func generateQueueID(treeID int64, leafIdentityHash []byte, timestamp int64) []byte { + h := sha256.New() + b := make([]byte, 10) + binary.PutVarint(b, treeID) + h.Write(b) + b = make([]byte, 10) + binary.PutVarint(b, timestamp) + h.Write(b) + h.Write(leafIdentityHash) + return h.Sum(nil) +} + +func queueArgs(treeID int64, identityHash []byte, queueTimestamp time.Time) []interface{} { + timestamp := queueTimestamp.UnixNano() + return []interface{}{timestamp, generateQueueID(treeID, identityHash, timestamp)} +} + +func (t *logTreeTX) UpdateSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf) error { + querySuffix := []string{} + args := []interface{}{} + for _, leaf := range leaves { + iTimestamp, err := ptypes.Timestamp(leaf.IntegrateTimestamp) + if err != nil { + return fmt.Errorf("got invalid integrate timestamp: %v", err) + } + querySuffix = append(querySuffix, valuesPlaceholder5) + args = append(args, t.treeID, leaf.LeafIdentityHash, leaf.MerkleLeafHash, leaf.LeafIndex, iTimestamp.UnixNano()) + } + result, err := t.tx.ExecContext(ctx, insertSequencedLeafSQL+strings.Join(querySuffix, ","), args...) + if err != nil { + glog.Warningf("Failed to update sequenced leaves: %s", err) + } + return checkResultOkAndRowCountIs(result, err, int64(len(leaves))) +} + +func (m *postgresLogStorage) getDeleteUnsequencedStmt(ctx context.Context, num int) (*sql.Stmt, error) { + stmt := &statementSkeleton{ + sql: deleteUnsequencedSQL, + firstInsertion: "%s", + firstPlaceholders: 1, + restInsertion: "%s", + restPlaceholders: 1, + num: num, + } + return m.getStmt(ctx, stmt) +} + +// removeSequencedLeaves removes the passed in leaves slice (which may be +// modified as part of the operation). +func (t *logTreeTX) removeSequencedLeaves(ctx context.Context, queueIDs []dequeuedLeaf) error { + // Don't need to re-sort because the query ordered by leaf hash. If that changes because + // the query is expensive then the sort will need to be done here. See comment in + // QueueLeaves. + tmpl, err := t.ls.getDeleteUnsequencedStmt(ctx, len(queueIDs)) + if err != nil { + glog.Warningf("Failed to get delete statement for sequenced work: %s", err) + return err + } + stx := t.tx.StmtContext(ctx, tmpl) + args := make([]interface{}, len(queueIDs)) + for i, q := range queueIDs { + args[i] = []byte(q) + } + result, err := stx.ExecContext(ctx, args...) 
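+ // A single batched DELETE removes every dequeued leaf by its queue_id.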
+ if err != nil { + // Error is handled by checkResultOkAndRowCountIs() below + glog.Warningf("Failed to delete sequenced work: %s", err) + } + return checkResultOkAndRowCountIs(result, err, int64(len(queueIDs))) +} diff --git a/storage/postgres/schema/storage.sql b/storage/postgres/schema/storage.sql index ea42d301aa..0f33976d88 100644 --- a/storage/postgres/schema/storage.sql +++ b/storage/postgres/schema/storage.sql @@ -4,11 +4,11 @@ -- --------------------------------------------- -- Tree Enums -CREATE TYPE E_TREE_STATE AS ENUM('ACTIVE', 'FROZEN', 'DRAINING'); -CREATE TYPE E_TREE_TYPE AS ENUM('LOG', 'MAP', 'PREORDERED_LOG'); -CREATE TYPE E_HASH_STRATEGY AS ENUM('RFC6962_SHA256', 'TEST_MAP_HASHER', 'OBJECT_RFC6962_SHA256', 'CONIKS_SHA512_256', 'CONIKS_SHA256'); -CREATE TYPE E_HASH_ALGORITHM AS ENUM('SHA256'); -CREATE TYPE E_SIGNATURE_ALGORITHM AS ENUM('ECDSA', 'RSA'); +CREATE TYPE E_TREE_STATE AS ENUM('ACTIVE', 'FROZEN', 'DRAINING');--end +CREATE TYPE E_TREE_TYPE AS ENUM('LOG', 'MAP', 'PREORDERED_LOG');--end +CREATE TYPE E_HASH_STRATEGY AS ENUM('RFC6962_SHA256', 'TEST_MAP_HASHER', 'OBJECT_RFC6962_SHA256', 'CONIKS_SHA512_256', 'CONIKS_SHA256');--end +CREATE TYPE E_HASH_ALGORITHM AS ENUM('SHA256');--end +CREATE TYPE E_SIGNATURE_ALGORITHM AS ENUM('ECDSA', 'RSA');--end -- Tree parameters should not be changed after creation. Doing so can -- render the data in the tree unusable or inconsistent. @@ -28,8 +28,10 @@ CREATE TABLE IF NOT EXISTS trees ( public_key BYTEA NOT NULL, deleted BOOLEAN NOT NULL DEFAULT FALSE, delete_time_millis BIGINT, + current_tree_data json, + root_signature BYTEA, PRIMARY KEY(tree_id) -); +);--end -- This table contains tree parameters that can be changed at runtime such as for -- administrative purposes. @@ -40,7 +42,7 @@ CREATE TABLE IF NOT EXISTS tree_control( sequence_interval_seconds INTEGER NOT NULL, PRIMARY KEY(tree_id), FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE -); +);--end CREATE TABLE IF NOT EXISTS subtree( tree_id BIGINT NOT NULL, @@ -49,7 +51,7 @@ CREATE TABLE IF NOT EXISTS subtree( subtree_revision INTEGER NOT NULL, PRIMARY KEY(tree_id, subtree_id, subtree_revision), FOREIGN KEY(tree_id) REFERENCES Trees(tree_id) ON DELETE CASCADE -); +);--end -- The TreeRevisionIdx is used to enforce that there is only one STH at any -- tree revision @@ -62,11 +64,11 @@ CREATE TABLE IF NOT EXISTS tree_head( tree_revision BIGINT, PRIMARY KEY(tree_id, tree_revision), FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE -); +);--end -- TODO(vishal) benchmark this to see if it's a suitable replacement for not -- having a DESC scan on the primary key -CREATE UNIQUE INDEX TreeHeadRevisionIdx ON tree_head(tree_id, tree_revision DESC); +CREATE UNIQUE INDEX TreeHeadRevisionIdx ON tree_head(tree_id, tree_revision DESC);--end -- --------------------------------------------- -- Log specific stuff here @@ -94,7 +96,7 @@ CREATE TABLE IF NOT EXISTS leaf_data( queue_timestamp_nanos BIGINT NOT NULL, PRIMARY KEY(tree_id, leaf_identity_hash), FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE -); +);--end -- When a leaf is sequenced a row is added to this table. If logs allow duplicates then -- multiple rows will exist with different sequence numbers. 
The signed timestamp @@ -116,9 +118,9 @@ CREATE TABLE IF NOT EXISTS sequenced_leaf_data( PRIMARY KEY(tree_id, sequence_number), FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE, FOREIGN KEY(tree_id, leaf_identity_hash) REFERENCES leaf_data(tree_id, leaf_identity_hash) ON DELETE CASCADE -); +);--end -CREATE INDEX SequencedLeafMerkleIdx ON sequenced_leaf_data(tree_id, merkle_leaf_hash); +CREATE INDEX SequencedLeafMerkleIdx ON sequenced_leaf_data(tree_id, merkle_leaf_hash);--end CREATE TABLE IF NOT EXISTS unsequenced( tree_id BIGINT NOT NULL, @@ -138,4 +140,51 @@ CREATE TABLE IF NOT EXISTS unsequenced( -- built with the batched_queue tag. queue_id BYTEA DEFAULT NULL UNIQUE, PRIMARY KEY (tree_id, bucket, queue_timestamp_nanos, leaf_identity_hash) -); +);--end + +CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, leaf_value bytea, extra_data bytea, queue_timestamp_nanos bigint) + RETURNS boolean + LANGUAGE plpgsql +AS $function$ + begin + INSERT INTO leaf_data(tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos) VALUES (tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos); + return true; + exception + when unique_violation then + return false; + when others then + raise notice '% %', SQLERRM, SQLSTATE; + end; +$function$;--end + +CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, merkle_leaf_hash bytea, queue_timestamp_nanos bigint) + RETURNS boolean + LANGUAGE plpgsql +AS $function$ + begin + INSERT INTO unsequenced(tree_id,bucket,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos) VALUES(tree_id,0,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos); + return true; + exception + when unique_violation then + return false; + when others then + raise notice '% %', SQLERRM, SQLSTATE; + end; +$function$;--end + + + +CREATE OR REPLACE FUNCTION public.insert_sequenced_leaf_data_ignore_duplicates(tree_id bigint, sequence_number bigint, leaf_identity_hash bytea, merkle_leaf_hash bytea, integrate_timestamp_nanos bigint) + RETURNS boolean + LANGUAGE plpgsql +AS $function$ + begin + INSERT INTO sequenced_leaf_data(tree_id, sequence_number, leaf_identity_hash, merkle_leaf_hash, integrate_timestamp_nanos) VALUES(tree_id, sequence_number, leaf_identity_hash, merkle_leaf_hash, integrate_timestamp_nanos); + return true; + exception + when unique_violation then + return false; + when others then + raise notice '% %', SQLERRM, SQLSTATE; + end; +$function$;--end diff --git a/storage/postgres/storage_test.go b/storage/postgres/storage_test.go index 036a71f749..57cae6916f 100644 --- a/storage/postgres/storage_test.go +++ b/storage/postgres/storage_test.go @@ -1,4 +1,4 @@ -// Copyright 2018 Google Inc. All Rights Reserved. +// Copyright 2019 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,22 +11,157 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+ package postgres import ( + "bytes" "context" + "crypto" + "crypto/sha256" "database/sql" "flag" + "fmt" "os" "testing" - "time" "github.com/golang/glog" + "github.com/google/trillian" + tcrypto "github.com/google/trillian/crypto" + "github.com/google/trillian/storage" "github.com/google/trillian/storage/postgres/testdb" + storageto "github.com/google/trillian/storage/testonly" + "github.com/google/trillian/testonly" + "github.com/google/trillian/types" ) -// db is shared throughout all postgres tests -var db *sql.DB +func TestNodeRoundTrip(t *testing.T) { + cleanTestDB(db, t) + tree := createTreeOrPanic(db, storageto.LogTree) + s := NewLogStorage(db, nil) + + const writeRevision = int64(100) + nodesToStore := createSomeNodes() + nodeIDsToRead := make([]storage.NodeID, len(nodesToStore)) + for i := range nodesToStore { + nodeIDsToRead[i] = nodesToStore[i].NodeID + } + + { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + forceWriteRevision(writeRevision, tx) + + // Need to read nodes before attempting to write + if _, err := tx.GetMerkleNodes(ctx, 99, nodeIDsToRead); err != nil { + t.Fatalf("Failed to read nodes: %s", err) + } + if err := tx.SetMerkleNodes(ctx, nodesToStore); err != nil { + t.Fatalf("Failed to store nodes: %s", err) + } + return nil + }) + } + + { + runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error { + readNodes, err := tx.GetMerkleNodes(ctx, 100, nodeIDsToRead) + if err != nil { + t.Fatalf("Failed to retrieve nodes: %s", err) + } + if err := nodesAreEqual(readNodes, nodesToStore); err != nil { + t.Fatalf("Read back different nodes from the ones stored: %s", err) + } + return nil + }) + } +} +func forceWriteRevision(rev int64, tx storage.TreeTX) { + mtx, ok := tx.(*logTreeTX) + if !ok { + panic(fmt.Sprintf("tx is %T, want *logTreeTX", tx)) + } + mtx.treeTX.writeRevision = rev +} + +func createSomeNodes() []storage.Node { + r := make([]storage.Node, 4) + for i := range r { + r[i].NodeID = storage.NewNodeIDWithPrefix(uint64(i), 8, 8, 8) + h := sha256.Sum256([]byte{byte(i)}) + r[i].Hash = h[:] + glog.Infof("Node to store: %v", r[i].NodeID) + } + return r +} + +func nodesAreEqual(lhs []storage.Node, rhs []storage.Node) error { + if ls, rs := len(lhs), len(rhs); ls != rs { + return fmt.Errorf("different number of nodes, %d vs %d", ls, rs) + } + for i := range lhs { + if l, r := lhs[i].NodeID.String(), rhs[i].NodeID.String(); l != r { + return fmt.Errorf("NodeIDs are not the same,\nlhs = %v,\nrhs = %v", l, r) + } + if l, r := lhs[i].Hash, rhs[i].Hash; !bytes.Equal(l, r) { + return fmt.Errorf("hashes are not the same for %s,\nlhs = %v,\nrhs = %v", lhs[i].NodeID.CoordString(), l, r) + } + } + return nil +} + +func openTestDBOrDie() *sql.DB { + db, err := testdb.NewTrillianDB(context.TODO()) + if err != nil { + panic(err) + } + return db +} + +func createFakeSignedLogRoot(db *sql.DB, tree *trillian.Tree, treeSize uint64) { + signer := tcrypto.NewSigner(0, testonly.NewSignerWithFixedSig(nil, []byte("notnil")), crypto.SHA256) + + ctx := context.Background() + l := NewLogStorage(db, nil) + err := l.ReadWriteTransaction(ctx, tree, func(ctx context.Context, tx storage.LogTreeTX) error { + root, err := signer.SignLogRoot(&types.LogRootV1{TreeSize: treeSize, RootHash: []byte{0}}) + if err != nil { + return fmt.Errorf("error creating new SignedLogRoot: %v", err) + } + if err := tx.StoreSignedLogRoot(ctx, *root); err != nil { + return fmt.Errorf("error storing new SignedLogRoot: %v", err) + } + return nil + }) 
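+ // ReadWriteTransaction is expected to commit once the callback returns nil; any failure
+ // surfaces as err here.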
+ if err != nil { + panic(fmt.Sprintf("ReadWriteTransaction() = %v", err)) + } +} + +// createTree creates the specified tree using AdminStorage. +func createTree(db *sql.DB, tree *trillian.Tree) (*trillian.Tree, error) { + ctx := context.Background() + s := NewAdminStorage(db) + tree, err := storage.CreateTree(ctx, s, tree) + if err != nil { + return nil, err + } + return tree, nil +} + +func createTreeOrPanic(db *sql.DB, create *trillian.Tree) *trillian.Tree { + tree, err := createTree(db, create) + if err != nil { + panic(fmt.Sprintf("Error creating tree: %v", err)) + } + return tree +} + +// updateTree updates the specified tree using AdminStorage. +func updateTree(db *sql.DB, treeID int64, updateFn func(*trillian.Tree)) (*trillian.Tree, error) { + ctx := context.Background() + s := NewAdminStorage(db) + return storage.UpdateTree(ctx, s, treeID, updateFn) +} func TestMain(m *testing.M) { flag.Parse() @@ -34,9 +169,7 @@ func TestMain(m *testing.M) { glog.Errorf("PG not available, skipping all PG storage tests") return } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(time.Second*30)) - defer cancel() - db = testdb.NewTrillianDBOrDie(ctx) + db = openTestDBOrDie() defer db.Close() os.Exit(m.Run()) } diff --git a/storage/postgres/storage_unsafe.sql b/storage/postgres/storage_unsafe.sql new file mode 100644 index 0000000000..5f6008d2f7 --- /dev/null +++ b/storage/postgres/storage_unsafe.sql @@ -0,0 +1,167 @@ +-- Postgres impl of storage +-- --------------------------------------------- +-- Tree stuff here +-- --------------------------------------------- + +-- Tree Enums +CREATE TYPE E_TREE_STATE AS ENUM('ACTIVE', 'FROZEN', 'DRAINING'); +CREATE TYPE E_TREE_TYPE AS ENUM('LOG', 'MAP', 'PREORDERED_LOG'); +CREATE TYPE E_HASH_STRATEGY AS ENUM('RFC6962_SHA256', 'TEST_MAP_HASHER', 'OBJECT_RFC6962_SHA256', 'CONIKS_SHA512_256', 'CONIKS_SHA256'); +CREATE TYPE E_HASH_ALGORITHM AS ENUM('SHA256'); +CREATE TYPE E_SIGNATURE_ALGORITHM AS ENUM('ECDSA', 'RSA'); + +-- Tree parameters should not be changed after creation. Doing so can +-- render the data in the tree unusable or inconsistent. +CREATE TABLE IF NOT EXISTS trees ( + tree_id BIGINT NOT NULL, + tree_state E_TREE_STATE NOT NULL, + tree_type E_TREE_TYPE NOT NULL, + hash_strategy E_HASH_STRATEGY NOT NULL, + hash_algorithm E_HASH_ALGORITHM NOT NULL, + signature_algorithm E_SIGNATURE_ALGORITHM NOT NULL, + display_name VARCHAR(20), + description VARCHAR(200), + create_time_millis BIGINT NOT NULL, + update_time_millis BIGINT NOT NULL, + max_root_duration_millis BIGINT NOT NULL, + private_key BYTEA NOT NULL, + public_key BYTEA NOT NULL, + deleted BOOLEAN NOT NULL DEFAULT FALSE, + delete_time_millis BIGINT, + current_tree_data json, + root_signature BYTEA, + PRIMARY KEY(tree_id) +); + +-- This table contains tree parameters that can be changed at runtime such as for +-- administrative purposes. 
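+-- Note that, unlike storage.sql, this variant omits the FOREIGN KEY back to trees.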
+CREATE TABLE IF NOT EXISTS tree_control( + tree_id BIGINT NOT NULL, + signing_enabled BOOLEAN NOT NULL, + sequencing_enabled BOOLEAN NOT NULL, + sequence_interval_seconds INTEGER NOT NULL, + PRIMARY KEY(tree_id) +); + +CREATE TABLE IF NOT EXISTS subtree( + tree_id BIGINT NOT NULL, + subtree_id BYTEA NOT NULL, + nodes BYTEA NOT NULL, + subtree_revision INTEGER NOT NULL, + PRIMARY KEY(subtree_id, subtree_revision) +); + +-- The TreeRevisionIdx is used to enforce that there is only one STH at any +-- tree revision +CREATE TABLE IF NOT EXISTS tree_head( + tree_id BIGINT NOT NULL, + tree_head_timestamp BIGINT, + tree_size BIGINT, + root_hash BYTEA NOT NULL, + root_signature BYTEA NOT NULL, + tree_revision BIGINT, + PRIMARY KEY(tree_id, tree_revision) +); + +-- TODO(vishal) benchmark this to see if it's a suitable replacement for not +-- having a DESC scan on the primary key +CREATE UNIQUE INDEX TreeHeadRevisionIdx ON tree_head(tree_id, tree_revision DESC); + +-- --------------------------------------------- +-- Log specific stuff here +-- --------------------------------------------- + +-- Creating index at same time as table allows some storage engines to better +-- optimize physical storage layout. Most engines allow multiple nulls in a +-- unique index but some may not. + +-- A leaf that has not been sequenced has a row in this table. If duplicate leaves +-- are allowed they will all reference this row. +CREATE TABLE IF NOT EXISTS leaf_data( + tree_id BIGINT NOT NULL, + -- This is a personality specific hash of some subset of the leaf data. + -- It's only purpose is to allow Trillian to identify duplicate entries in + -- the context of the personality. + leaf_identity_hash BYTEA NOT NULL, + -- This is the data stored in the leaf for example in CT it contains a DER encoded + -- X.509 certificate but is application dependent + leaf_value BYTEA NOT NULL, + -- This is extra data that the application can associate with the leaf should it wish to. + -- This data is not included in signing and hashing. + extra_data BYTEA, + -- The timestamp from when this leaf data was first queued for inclusion. + queue_timestamp_nanos BIGINT NOT NULL, + PRIMARY KEY(leaf_identity_hash) +); + +-- When a leaf is sequenced a row is added to this table. If logs allow duplicates then +-- multiple rows will exist with different sequence numbers. The signed timestamp +-- will be communicated via the unsequenced table as this might need to be unique, depending +-- on the log parameters and we can't insert into this table until we have the sequence number +-- which is not available at the time we queue the entry. We need both hashes because the +-- LeafData table is keyed by the raw data hash. +CREATE TABLE IF NOT EXISTS sequenced_leaf_data( + tree_id BIGINT NOT NULL, + sequence_number BIGINT NOT NULL, + -- This is a personality specific has of some subset of the leaf data. + -- It's only purpose is to allow Trillian to identify duplicate entries in + -- the context of the personality. + leaf_identity_hash BYTEA NOT NULL, + -- This is a MerkleLeafHash as defined by the treehasher that the log uses. For example for + -- CT this hash will include the leaf prefix byte as well as the leaf data. 
+ merkle_leaf_hash BYTEA NOT NULL, + integrate_timestamp_nanos BIGINT NOT NULL, + PRIMARY KEY(sequence_number) +); + +CREATE INDEX SequencedLeafMerkleIdx ON sequenced_leaf_data(tree_id, merkle_leaf_hash); + +CREATE TABLE IF NOT EXISTS unsequenced( + tree_id BIGINT NOT NULL, + -- The bucket field is to allow the use of time based ring bucketed schemes if desired. If + -- unused this should be set to zero for all entries. + bucket INTEGER NOT NULL, + -- This is a personality specific hash of some subset of the leaf data. + -- It's only purpose is to allow Trillian to identify duplicate entries in + -- the context of the personality. + leaf_identity_hash BYTEA NOT NULL, + -- This is a MerkleLeafHash as defined by the treehasher that the log uses. For example for + -- CT this hash will include the leaf prefix byte as well as the leaf data. + merkle_leaf_hash BYTEA NOT NULL, + queue_timestamp_nanos BIGINT NOT NULL, + -- This is a SHA256 hash of the TreeID, LeafIdentityHash and QueueTimestampNanos. It is used + -- for batched deletes from the table when trillian_log_server and trillian_log_signer are + -- built with the batched_queue tag. + queue_id BYTEA DEFAULT NULL UNIQUE, + PRIMARY KEY (queue_timestamp_nanos, leaf_identity_hash) +); + +CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, leaf_value bytea, extra_data bytea, queue_timestamp_nanos bigint) + RETURNS boolean + LANGUAGE plpgsql +AS $function$ + begin + INSERT INTO leaf_data(tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos) VALUES (tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos); + return true; + exception + when unique_violation then + return false; + when others then + raise notice '% %', SQLERRM, SQLSTATE; + end; +$function$; + +CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, merkle_leaf_hash bytea, queue_timestamp_nanos bigint) + RETURNS boolean + LANGUAGE plpgsql +AS $function$ + begin + INSERT INTO unsequenced(tree_id,bucket,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos) VALUES(tree_id,0,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos); + return true; + exception + when unique_violation then + return false; + when others then + raise notice '% %', SQLERRM, SQLSTATE; + end; +$function$; diff --git a/storage/postgres/testdb/testdb.go b/storage/postgres/testdb/testdb.go index 7ed9a81d7c..31c43f5fc8 100644 --- a/storage/postgres/testdb/testdb.go +++ b/storage/postgres/testdb/testdb.go @@ -27,8 +27,6 @@ import ( "time" "github.com/google/trillian/testonly" - - _ "github.com/lib/pq" // pg driver ) var ( @@ -52,6 +50,28 @@ func PGAvailable() bool { return true } +//This just executes a simple query in the configured database. Only used as a placeholder +// for testing queries and how go returns results +func TestSQL(ctx context.Context) string { + db, err := sql.Open("postgres", getConnStr(*dbName)) + if err != nil { + fmt.Println("Error: ", err) + return "error" + } + defer db.Close() + result, err := db.QueryContext(ctx, "select 1=1") + if err != nil { + fmt.Println("Error: ", err) + return "error" + } + var resultData bool + result.Scan(&resultData) + if resultData { + fmt.Println("Result: ", resultData) + } + return "done" +} + // newEmptyDB creates a new, empty database. 
 func newEmptyDB(ctx context.Context) (*sql.DB, error) {
 db, err := sql.Open("postgres", getConnStr(*dbName))
@@ -61,12 +81,10 @@ func newEmptyDB(ctx context.Context) (*sql.DB, error) {
 // Create a randomly-named database and then connect using the new name.
 name := fmt.Sprintf("trl_%v", time.Now().UnixNano())
-
 stmt := fmt.Sprintf("CREATE DATABASE %v", name)
 if _, err := db.ExecContext(ctx, stmt); err != nil {
 return nil, fmt.Errorf("error running statement %q: %v", stmt, err)
 }
-
 db.Close()
 db, err = sql.Open("postgres", getConnStr(name))
 if err != nil {
@@ -90,7 +108,7 @@ func NewTrillianDB(ctx context.Context) (*sql.DB, error) {
 return nil, err
 }
- for _, stmt := range strings.Split(sanitize(string(sqlBytes)), ";") {
+ for _, stmt := range strings.Split(sanitize(string(sqlBytes)), ";--end") {
 stmt = strings.TrimSpace(stmt)
 if stmt == "" {
 continue
diff --git a/storage/postgres/tree_storage.go b/storage/postgres/tree_storage.go
index 60aba30aa7..fdf3b79a72 100644
--- a/storage/postgres/tree_storage.go
+++ b/storage/postgres/tree_storage.go
@@ -20,6 +20,7 @@ import (
 "encoding/base64"
 "fmt"
 "runtime/debug"
+ "strconv"
 "strings"
 "sync"
@@ -43,13 +44,15 @@ const (
 SELECT n.subtree_id, max(n.subtree_revision) AS max_revision
 FROM subtree n
 WHERE n.subtree_id IN (` + placeholderSQL + `) AND
- n.tree_id = ? AND n.subtree_revision <= ?
+ n.tree_id = <param> AND n.subtree_revision <= <param>
 GROUP BY n.subtree_id
 ) AS x
 INNER JOIN subtree
 ON subtree.subtree_id = x.subtree_id
 AND subtree.subtree_revision = x.max_revision
- AND subtree.tree_id = ?`
+ AND subtree.tree_id = <param>`
+ insertTreeHeadSQL = `INSERT INTO tree_head(tree_id,tree_head_timestamp,tree_size,root_hash,tree_revision,root_signature)
+ VALUES($1,$2,$3,$4,$5,$6)`
 )
 // pgTreeStorage contains the pgLogStorage implementation.
@@ -141,6 +144,13 @@ func (p *pgTreeStorage) getStmt(ctx context.Context, skeleton *statementSkeleton
 }
 statement, err := expandPlaceholderSQL(skeleton)
+
+ counter := skeleton.restPlaceholders*skeleton.num + 1
+ for strings.Contains(statement, "<param>") {
+ statement = strings.Replace(statement, "<param>", "$"+strconv.Itoa(counter), 1)
+ counter++
+ }
+
 if err != nil {
 glog.Warningf("Failed to expand placeholder sql: %v", skeleton)
 return nil, err
@@ -254,10 +264,9 @@ func (t *treeTX) getSubtrees(ctx context.Context, treeRevision int64, nodeIDs []
 args = append(args, interface{}(t.treeID))
 args = append(args, interface{}(treeRevision))
 args = append(args, interface{}(t.treeID))
-
 rows, err := stx.QueryContext(ctx, args...)
 if err != nil {
- glog.Warningf("Failed to get merkle subtrees: %s", err)
+ glog.Warningf("Failed to get merkle subtrees: QueryContext(%v) = (_, %q)", args, err)
 return nil, err
 }
 defer rows.Close()
@@ -393,3 +402,52 @@ func (t *treeTX) Close() error {
 }
 return nil
 }
+
+func (t *treeTX) GetMerkleNodes(ctx context.Context, treeRevision int64, nodeIDs []storage.NodeID) ([]storage.Node, error) {
+ return t.subtreeCache.GetNodes(nodeIDs, t.getSubtreesAtRev(ctx, treeRevision))
+}
+
+func (t *treeTX) SetMerkleNodes(ctx context.Context, nodes []storage.Node) error {
+ for _, n := range nodes {
+ err := t.subtreeCache.SetNodeHash(n.NodeID, n.Hash,
+ func(nID storage.NodeID) (*storagepb.SubtreeProto, error) {
+ return t.getSubtree(ctx, t.writeRevision, nID)
+ })
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (t *treeTX) IsOpen() bool {
+ return !t.closed
+}
+
+// getSubtreesAtRev returns a GetSubtreesFunc which reads at the passed in rev.
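+// It is used by GetMerkleNodes above so the subtree cache can fetch subtrees lazily.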
+func (t *treeTX) getSubtreesAtRev(ctx context.Context, rev int64) cache.GetSubtreesFunc { + return func(ids []storage.NodeID) ([]*storagepb.SubtreeProto, error) { + return t.getSubtrees(ctx, rev, ids) + } +} + +func checkResultOkAndRowCountIs(res sql.Result, err error, count int64) error { + // The Exec() might have just failed + if err != nil { + return err + } + + // Otherwise we have to look at the result of the operation + rowsAffected, rowsError := res.RowsAffected() + + if rowsError != nil { + return rowsError + } + + if rowsAffected != count { + return fmt.Errorf("Expected %d row(s) to be affected but saw: %d", count, + rowsAffected) + } + + return nil +}