diff --git a/server/admin/admin_server_test.go b/server/admin/admin_server_test.go
index e041fe4e77..2e86e46adf 100644
--- a/server/admin/admin_server_test.go
+++ b/server/admin/admin_server_test.go
@@ -596,10 +596,10 @@ func TestServer_CreateTree_AllowedTreeTypes(t *testing.T) {
for _, test := range tests {
setup := setupAdminServer(
ctrl,
- nil, /* keygen */
- false, /* snapshot */
- test.wantCode == codes.OK, /* shouldCommit */
- false /* commitErr */)
+ nil, // keygen
+ false, // snapshot
+ test.wantCode == codes.OK, // shouldCommit
+ false) // commitErr
s := setup.server
tx := setup.tx
s.allowedTreeTypes = test.treeTypes
@@ -771,7 +771,7 @@ func TestServer_DeleteTree(t *testing.T) {
nil, /* keygen */
false, /* snapshot */
true, /* shouldCommit */
- false /* commitErr */)
+ false)
req := &trillian.DeleteTreeRequest{TreeId: test.tree.TreeId}
tx := setup.tx
@@ -810,10 +810,10 @@ func TestServer_DeleteTreeErrors(t *testing.T) {
for _, test := range tests {
setup := setupAdminServer(
ctrl,
- nil, /* keygen */
- false, /* snapshot */
- test.deleteErr == nil, /* shouldCommit */
- test.commitErr /* commitErr */)
+ nil,
+ false,
+ test.deleteErr == nil,
+ test.commitErr)
req := &trillian.DeleteTreeRequest{TreeId: 10}
tx := setup.tx
@@ -858,7 +858,7 @@ func TestServer_UndeleteTree(t *testing.T) {
nil, /* keygen */
false, /* snapshot */
true, /* shouldCommit */
- false /* commitErr */)
+ false)
req := &trillian.UndeleteTreeRequest{TreeId: test.tree.TreeId}
tx := setup.tx
@@ -897,10 +897,10 @@ func TestServer_UndeleteTreeErrors(t *testing.T) {
for _, test := range tests {
setup := setupAdminServer(
ctrl,
- nil, /* keygen */
- false, /* snapshot */
- test.undeleteErr == nil, /* shouldCommit */
- test.commitErr /* commitErr */)
+ nil,
+ false,
+ test.undeleteErr == nil,
+ test.commitErr)
req := &trillian.UndeleteTreeRequest{TreeId: 10}
tx := setup.tx
diff --git a/server/mock_log_operation.go b/server/mock_log_operation.go
index ddc1662314..f8451991f0 100644
--- a/server/mock_log_operation.go
+++ b/server/mock_log_operation.go
@@ -6,8 +6,9 @@ package server
import (
context "context"
- gomock "github.com/golang/mock/gomock"
reflect "reflect"
+
+ gomock "github.com/golang/mock/gomock"
)
// MockLogOperation is a mock of LogOperation interface
diff --git a/server/postgres_storage_provider.go b/server/postgres_storage_provider.go
index 42d6c8f5bb..8c045b11f9 100644
--- a/server/postgres_storage_provider.go
+++ b/server/postgres_storage_provider.go
@@ -66,7 +66,9 @@ func newPGProvider(mf monitoring.MetricFactory) (StorageProvider, error) {
}
func (s *pgProvider) LogStorage() storage.LogStorage {
- panic("Not Implemented")
+ glog.Warningf("Support for the PostgreSQL log storage is experimental. Please use at your own risk!")
+ return postgres.NewLogStorage(s.db, s.mf)
}
func (s *pgProvider) MapStorage() storage.MapStorage {
diff --git a/storage/postgres/README.md b/storage/postgres/README.md
new file mode 100644
index 0000000000..09c80328e1
--- /dev/null
+++ b/storage/postgres/README.md
@@ -0,0 +1,17 @@
+# Postgres LogStorage
+
+## Notes and Caveats
+The LogStorage part of this Postgres implementation is based on the existing
+MySQL implementation, hence the two user-defined functions included in
+storage.sql. MySQL does not abort a transaction when a duplicate key is
+detected, but PostgreSQL does, so to preserve the original workflow I included
+two functions that trap this error and allow the code to continue executing.
+The only other changes were to translate the MySQL queries into
+PostgreSQL-compatible ones and to tidy up some of the existing tree storage
+code.
+
+storage_unsafe.sql is not really unsafe, but I have removed some of the safety
+rails from the tables to improve performance. It also assumes that there will
+be only a single tree in a given database. A further improvement on this theme
+would be to place everything below the trees table in its own separate schema,
+which would eliminate more indexes and foreign key requirements, but that is
+left to those who need the extra performance. storage.sql should be fine for
+most applications.
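+
+For reference, the rough shape of these helper functions is sketched below. This
+is an illustrative sketch only; the authoritative definitions, including the
+exact column lists and error handling, live in storage.sql.
+
+```sql
+-- Sketch of a duplicate-tolerant insert helper. It returns true when the row
+-- was inserted and false when a duplicate was detected, without aborting the
+-- enclosing transaction (see storage.sql for the real definition).
+CREATE OR REPLACE FUNCTION insert_leaf_data_ignore_duplicates(
+    _tree_id BIGINT, _leaf_identity_hash BYTEA, _leaf_value BYTEA,
+    _extra_data BYTEA, _queue_timestamp_nanos BIGINT)
+RETURNS BOOLEAN AS $$
+BEGIN
+  INSERT INTO leaf_data(tree_id, leaf_identity_hash, leaf_value, extra_data, queue_timestamp_nanos)
+  VALUES (_tree_id, _leaf_identity_hash, _leaf_value, _extra_data, _queue_timestamp_nanos);
+  RETURN true;
+EXCEPTION WHEN unique_violation THEN
+  RETURN false;
+END;
+$$ LANGUAGE plpgsql;
+```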
diff --git a/storage/postgres/admin_storage_test.go b/storage/postgres/admin_storage_test.go
index d8557bc1c8..64144f93c8 100644
--- a/storage/postgres/admin_storage_test.go
+++ b/storage/postgres/admin_storage_test.go
@@ -29,6 +29,7 @@ import (
)
var allTables = []string{"unsequenced", "tree_head", "sequenced_leaf_data", "leaf_data", "subtree", "tree_control", "trees"}
+var db *sql.DB
const selectTreeControlByID = "SELECT signing_enabled, sequencing_enabled, sequence_interval_seconds FROM tree_control WHERE tree_id = $1"
diff --git a/storage/postgres/log_storage.go b/storage/postgres/log_storage.go
new file mode 100644
index 0000000000..9edc38fa7c
--- /dev/null
+++ b/storage/postgres/log_storage.go
@@ -0,0 +1,971 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "bytes"
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sort"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/golang/protobuf/ptypes"
+ "github.com/google/trillian"
+ "github.com/google/trillian/merkle/hashers"
+ "github.com/google/trillian/monitoring"
+ "github.com/google/trillian/storage"
+ "github.com/google/trillian/storage/cache"
+ "github.com/google/trillian/types"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+const (
+ valuesPlaceholder5 = "($1,$2,$3,$4,$5)"
+ insertLeafDataSQL = "select insert_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)"
+ insertSequencedLeafSQL = "select insert_sequenced_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)"
+
+ selectNonDeletedTreeIDByTypeAndStateSQL = `
+ SELECT tree_id FROM trees WHERE tree_type in ($1,$2) AND tree_state in ($3,$4) AND (deleted IS NULL OR deleted = false)`
+
+ selectSequencedLeafCountSQL = "SELECT COUNT(*) FROM sequenced_leaf_data WHERE tree_id=$1"
+ selectUnsequencedLeafCountSQL = "SELECT tree_id, COUNT(1) FROM unsequenced GROUP BY tree_id"
+ //selectLatestSignedLogRootSQL = `SELECT tree_head_timestamp,tree_size,root_hash,tree_revision,root_signature
+ // FROM tree_head WHERE tree_id=$1
+ // ORDER BY tree_head_timestamp DESC LIMIT 1`
+
+ selectLeavesByRangeSQL = `SELECT s.merkle_leaf_hash,l.leaf_identity_hash,l.leaf_value,s.sequence_number,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos
+ FROM leaf_data l,sequenced_leaf_data s
+ WHERE l.leaf_identity_hash = s.leaf_identity_hash
+ AND s.sequence_number >= $1 AND s.sequence_number < $2 AND l.tree_id = $3 AND s.tree_id = l.tree_id` + orderBySequenceNumberSQL
+
+ // These statements need to be expanded to provide the correct number of parameter placeholders.
+ selectLeavesByIndexSQL = `SELECT s.merkle_leaf_hash,l.leaf_identity_hash,l.leaf_value,s.sequence_number,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos
+ FROM leaf_data l,sequenced_leaf_data s
+ WHERE l.leaf_identity_hash = s.leaf_identity_hash
+ AND s.sequence_number IN (` + placeholderSQL + `) AND l.tree_id = AND s.tree_id = l.tree_id`
+ selectLeavesByMerkleHashSQL = `SELECT s.merkle_leaf_hash,l.leaf_identity_hash,l.leaf_value,s.sequence_number,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos
+ FROM leaf_data l,sequenced_leaf_data s
+ WHERE l.leaf_identity_hash = s.leaf_identity_hash
+ AND s.merkle_leaf_hash IN (` + placeholderSQL + `) AND l.tree_id = AND s.tree_id = l.tree_id`
+ // TODO(drysdale): rework the code so the dummy hash isn't needed (e.g. this assumes hash size is 32)
+ dummymerkleLeafHash = "00000000000000000000000000000000"
+ // This statement returns a dummy Merkle leaf hash value (which must be
+ // of the right size) so that its signature matches that of the other
+ // leaf-selection statements.
+ selectLeavesByLeafIdentityHashSQL = `SELECT '` + dummymerkleLeafHash + `',l.leaf_identity_hash,l.leaf_value,-1,l.extra_data,l.queue_timestamp_nanos,s.integrate_timestamp_nanos
+ FROM leaf_data l LEFT JOIN sequenced_leaf_data s ON (l.leaf_identity_hash = s.leaf_identity_hash AND l.tree_id = s.tree_id)
+ WHERE l.leaf_identity_hash IN (` + placeholderSQL + `) AND l.tree_id = `
+
+ // Same as above except with leaves ordered by sequence so we only incur this cost when necessary
+ orderBySequenceNumberSQL = " ORDER BY s.sequence_number"
+ selectLeavesByMerkleHashOrderedBySequenceSQL = selectLeavesByMerkleHashSQL + orderBySequenceNumberSQL
+
+ logIDLabel = "logid"
+)
+
+var (
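+ // defaultLogStrata is the subtree cache strata configuration used for logs:
+ // 32 strata of depth 8, covering the full 256-bit node ID space.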
+ defaultLogStrata = []int{8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}
+
+ once sync.Once
+ queuedCounter monitoring.Counter
+ queuedDupCounter monitoring.Counter
+ dequeuedCounter monitoring.Counter
+
+ queueLatency monitoring.Histogram
+ queueInsertLatency monitoring.Histogram
+ queueReadLatency monitoring.Histogram
+ queueInsertLeafLatency monitoring.Histogram
+ queueInsertEntryLatency monitoring.Histogram
+ dequeueLatency monitoring.Histogram
+ dequeueSelectLatency monitoring.Histogram
+ dequeueRemoveLatency monitoring.Histogram
+)
+
+func createMetrics(mf monitoring.MetricFactory) {
+ queuedCounter = mf.NewCounter("postgres_queued_leaves", "Number of leaves queued", logIDLabel)
+ queuedDupCounter = mf.NewCounter("postgres_queued_dup_leaves", "Number of duplicate leaves queued", logIDLabel)
+ dequeuedCounter = mf.NewCounter("postgres_dequeued_leaves", "Number of leaves dequeued", logIDLabel)
+
+ queueLatency = mf.NewHistogram("postgres_queue_leaves_latency", "Latency of queue leaves operation in seconds", logIDLabel)
+ queueInsertLatency = mf.NewHistogram("postgres_queue_leaves_latency_insert", "Latency of insertion part of queue leaves operation in seconds", logIDLabel)
+ queueReadLatency = mf.NewHistogram("postgres_queue_leaves_latency_read_dups", "Latency of read-duplicates part of queue leaves operation in seconds", logIDLabel)
+ queueInsertLeafLatency = mf.NewHistogram("postgres_queue_leaf_latency_leaf", "Latency of insert-leaf part of queue (single) leaf operation in seconds", logIDLabel)
+ queueInsertEntryLatency = mf.NewHistogram("postgres_queue_leaf_latency_entry", "Latency of insert-entry part of queue (single) leaf operation in seconds", logIDLabel)
+
+ dequeueLatency = mf.NewHistogram("postgres_dequeue_leaves_latency", "Latency of dequeue leaves operation in seconds", logIDLabel)
+ dequeueSelectLatency = mf.NewHistogram("postgres_dequeue_leaves_latency_select", "Latency of selection part of dequeue leaves operation in seconds", logIDLabel)
+ dequeueRemoveLatency = mf.NewHistogram("postgres_dequeue_leaves_latency_remove", "Latency of removal part of dequeue leaves operation in seconds", logIDLabel)
+}
+
+func labelForTX(t *logTreeTX) string {
+ return strconv.FormatInt(t.treeID, 10)
+}
+
+func observe(hist monitoring.Histogram, duration time.Duration, label string) {
+ hist.Observe(duration.Seconds(), label)
+}
+
+type postgresLogStorage struct {
+ *pgTreeStorage
+ admin storage.AdminStorage
+ metricFactory monitoring.MetricFactory
+}
+
+// NewLogStorage creates a storage.LogStorage instance for the specified PostgreSQL URL.
+// It assumes storage.AdminStorage is backed by the same PostgreSQL database as well.
+func NewLogStorage(db *sql.DB, mf monitoring.MetricFactory) storage.LogStorage {
+ if mf == nil {
+ mf = monitoring.InertMetricFactory{}
+ }
+ return &postgresLogStorage{
+ admin: NewAdminStorage(db),
+ pgTreeStorage: newTreeStorage(db),
+ metricFactory: mf,
+ }
+}
+
+func (m *postgresLogStorage) CheckDatabaseAccessible(ctx context.Context) error {
+ return m.db.PingContext(ctx)
+}
+
+func (m *postgresLogStorage) getLeavesByIndexStmt(ctx context.Context, num int) (*sql.Stmt, error) {
+ stmt := &statementSkeleton{
+ sql: selectLeavesByIndexSQL,
+ firstInsertion: "%s",
+ firstPlaceholders: 1,
+ restInsertion: "%s",
+ restPlaceholders: 1,
+ num: num,
+ }
+ return m.getStmt(ctx, stmt)
+}
+
+func (m *postgresLogStorage) getLeavesByMerkleHashStmt(ctx context.Context, num int, orderBySequence bool) (*sql.Stmt, error) {
+ if orderBySequence {
+
+ orderByStmt := &statementSkeleton{
+ sql: selectLeavesByMerkleHashOrderedBySequenceSQL,
+ firstInsertion: "%s",
+ firstPlaceholders: 1,
+ restInsertion: "%s",
+ restPlaceholders: 1,
+ num: num,
+ }
+
+ return m.getStmt(ctx, orderByStmt)
+ }
+
+ merkleHashStmt := &statementSkeleton{
+ sql: selectLeavesByMerkleHashSQL,
+ firstInsertion: "%s",
+ firstPlaceholders: 1,
+ restInsertion: "%s",
+ restPlaceholders: 1,
+ num: num,
+ }
+
+ return m.getStmt(ctx, merkleHashStmt)
+}
+
+func (m *postgresLogStorage) getLeavesByLeafIdentityHashStmt(ctx context.Context, num int) (*sql.Stmt, error) {
+ identityHashStmt := &statementSkeleton{
+ sql: selectLeavesByLeafIdentityHashSQL,
+ firstInsertion: "%s",
+ firstPlaceholders: 1,
+ restInsertion: "%s",
+ restPlaceholders: 1,
+ num: num,
+ }
+
+ return m.getStmt(ctx, identityHashStmt)
+}
+
+// readOnlyLogTX implements storage.ReadOnlyLogTX
+type readOnlyLogTX struct {
+ ls *postgresLogStorage
+ tx *sql.Tx
+}
+
+func (m *postgresLogStorage) Snapshot(ctx context.Context) (storage.ReadOnlyLogTX, error) {
+ tx, err := m.db.BeginTx(ctx, nil /* opts */)
+ if err != nil {
+ glog.Warningf("Could not start ReadOnlyLogTX: %s", err)
+ return nil, err
+ }
+ return &readOnlyLogTX{m, tx}, nil
+}
+
+func (t *readOnlyLogTX) Commit() error {
+ return t.tx.Commit()
+}
+
+func (t *readOnlyLogTX) Rollback() error {
+ return t.tx.Rollback()
+}
+
+func (t *readOnlyLogTX) Close() error {
+ if err := t.Rollback(); err != nil && err != sql.ErrTxDone {
+ glog.Warningf("Rollback error on Close(): %v", err)
+ return err
+ }
+ return nil
+}
+
+func (t *readOnlyLogTX) GetActiveLogIDs(ctx context.Context) ([]int64, error) {
+ // Include logs that are DRAINING in the active list as we're still
+ // integrating leaves into them.
+ rows, err := t.tx.QueryContext(
+ ctx, selectNonDeletedTreeIDByTypeAndStateSQL,
+ trillian.TreeType_LOG.String(), trillian.TreeType_PREORDERED_LOG.String(),
+ trillian.TreeState_ACTIVE.String(), trillian.TreeState_DRAINING.String())
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+ ids := []int64{}
+ for rows.Next() {
+ var treeID int64
+ if err := rows.Scan(&treeID); err != nil {
+ return nil, err
+ }
+ ids = append(ids, treeID)
+ }
+ return ids, rows.Err()
+}
+
+func (m *postgresLogStorage) beginInternal(ctx context.Context, tree *trillian.Tree) (storage.LogTreeTX, error) {
+ once.Do(func() {
+ createMetrics(m.metricFactory)
+ })
+ hasher, err := hashers.NewLogHasher(tree.HashStrategy)
+ if err != nil {
+ return nil, err
+ }
+
+ stCache := cache.NewLogSubtreeCache(defaultLogStrata, hasher)
+ ttx, err := m.beginTreeTx(ctx, tree, hasher.Size(), stCache)
+ if err != nil && err != storage.ErrTreeNeedsInit {
+ return nil, err
+ }
+
+ ltx := &logTreeTX{
+ treeTX: ttx,
+ ls: m,
+ }
+ ltx.slr, err = ltx.fetchLatestRoot(ctx)
+ if err == storage.ErrTreeNeedsInit {
+ return ltx, err
+ } else if err != nil {
+ ttx.Rollback()
+ return nil, err
+ }
+ if err := ltx.root.UnmarshalBinary(ltx.slr.LogRoot); err != nil {
+ ttx.Rollback()
+ return nil, err
+ }
+
+ ltx.treeTX.writeRevision = int64(ltx.root.Revision) + 1
+ return ltx, nil
+}
+
+func (m *postgresLogStorage) ReadWriteTransaction(ctx context.Context, tree *trillian.Tree, f storage.LogTXFunc) error {
+ tx, err := m.beginInternal(ctx, tree)
+ if err != nil && err != storage.ErrTreeNeedsInit {
+ return err
+ }
+ defer tx.Close()
+ if err := f(ctx, tx); err != nil {
+ return err
+ }
+ return tx.Commit()
+}
+
+func (m *postgresLogStorage) AddSequencedLeaves(ctx context.Context, tree *trillian.Tree, leaves []*trillian.LogLeaf, timestamp time.Time) ([]*trillian.QueuedLogLeaf, error) {
+ tx, err := m.beginInternal(ctx, tree)
+ if tx != nil {
+ defer tx.Close()
+ }
+ if err != nil {
+ return nil, err
+ }
+ res, err := tx.AddSequencedLeaves(ctx, leaves, timestamp)
+ if err != nil {
+ return nil, err
+ }
+ if err := tx.Commit(); err != nil {
+ return nil, err
+ }
+ return res, nil
+}
+
+func (m *postgresLogStorage) SnapshotForTree(ctx context.Context, tree *trillian.Tree) (storage.ReadOnlyLogTreeTX, error) {
+ tx, err := m.beginInternal(ctx, tree)
+ if err != nil && err != storage.ErrTreeNeedsInit {
+ return nil, err
+ }
+ return tx, err
+}
+
+func (m *postgresLogStorage) QueueLeaves(ctx context.Context, tree *trillian.Tree, leaves []*trillian.LogLeaf, queueTimestamp time.Time) ([]*trillian.QueuedLogLeaf, error) {
+ tx, err := m.beginInternal(ctx, tree)
+ if tx != nil {
+ defer tx.Close()
+ }
+ if err != nil {
+ return nil, err
+ }
+ existing, err := tx.QueueLeaves(ctx, leaves, queueTimestamp)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := tx.Commit(); err != nil {
+ return nil, err
+ }
+
+ ret := make([]*trillian.QueuedLogLeaf, len(leaves))
+ for i, e := range existing {
+ if e != nil {
+ ret[i] = &trillian.QueuedLogLeaf{
+ Leaf: e,
+ Status: status.Newf(codes.AlreadyExists, "leaf already exists: %v", e.LeafIdentityHash).Proto(),
+ }
+ continue
+ }
+ ret[i] = &trillian.QueuedLogLeaf{Leaf: leaves[i]}
+ }
+ return ret, nil
+}
+
+type logTreeTX struct {
+ treeTX
+ ls *postgresLogStorage
+ root types.LogRootV1
+ slr trillian.SignedLogRoot
+}
+
+func (t *logTreeTX) ReadRevision(ctx context.Context) (int64, error) {
+ return int64(t.root.Revision), nil
+}
+
+func (t *logTreeTX) WriteRevision(ctx context.Context) (int64, error) {
+ if t.treeTX.writeRevision < 0 {
+ return t.treeTX.writeRevision, errors.New("logTreeTX write revision not populated")
+ }
+ return t.treeTX.writeRevision, nil
+}
+
+func (t *logTreeTX) DequeueLeaves(ctx context.Context, limit int, cutoffTime time.Time) ([]*trillian.LogLeaf, error) {
+ if t.treeType == trillian.TreeType_PREORDERED_LOG {
+ // TODO(pavelkalinnikov): Optimize this by fetching only the required
+ // fields of LogLeaf. We can avoid joining with LeafData table here.
+ return t.GetLeavesByRange(ctx, int64(t.root.TreeSize), int64(limit))
+ }
+
+ start := time.Now()
+ stx, err := t.tx.PrepareContext(ctx, selectQueuedLeavesSQL)
+ if err != nil {
+ glog.Warningf("Failed to prepare dequeue select: %s", err)
+ return nil, err
+ }
+ defer stx.Close()
+
+ leaves := make([]*trillian.LogLeaf, 0, limit)
+ dq := make([]dequeuedLeaf, 0, limit)
+ rows, err := stx.QueryContext(ctx, t.treeID, cutoffTime.UnixNano(), limit)
+ if err != nil {
+ glog.Warningf("Failed to select rows for work: %s", err)
+ return nil, err
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ leaf, dqInfo, err := t.dequeueLeaf(rows)
+ if err != nil {
+ glog.Warningf("Error dequeuing leaf: %v %v", err, selectQueuedLeavesSQL)
+ return nil, err
+ }
+
+ if len(leaf.LeafIdentityHash) != t.hashSizeBytes {
+ return nil, errors.New("dequeued a leaf with incorrect hash size")
+ }
+
+ leaves = append(leaves, leaf)
+ dq = append(dq, dqInfo)
+ }
+
+ if rows.Err() != nil {
+ return nil, rows.Err()
+ }
+ label := labelForTX(t)
+ selectDuration := time.Since(start)
+ observe(dequeueSelectLatency, selectDuration, label)
+
+ // The convention is that if leaf processing succeeds (by committing this tx)
+ // then the unsequenced entries for them are removed
+ if len(leaves) > 0 {
+ err = t.removeSequencedLeaves(ctx, dq)
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ totalDuration := time.Since(start)
+ removeDuration := totalDuration - selectDuration
+ observe(dequeueRemoveLatency, removeDuration, label)
+ observe(dequeueLatency, totalDuration, label)
+ dequeuedCounter.Add(float64(len(leaves)), label)
+
+ return leaves, nil
+}
+
+// sortLeavesForInsert returns a slice containing the passed in leaves sorted
+// by LeafIdentityHash, and paired with their original positions.
+// QueueLeaves and AddSequencedLeaves use this to make the order that LeafData
+// row locks are acquired deterministic and reduce the chance of deadlocks.
+func sortLeavesForInsert(leaves []*trillian.LogLeaf) []leafAndPosition {
+ ordLeaves := make([]leafAndPosition, len(leaves))
+ for i, leaf := range leaves {
+ ordLeaves[i] = leafAndPosition{leaf: leaf, idx: i}
+ }
+ sort.Sort(byLeafIdentityHashWithPosition(ordLeaves))
+ return ordLeaves
+}
+
+func (t *logTreeTX) QueueLeaves(ctx context.Context, leaves []*trillian.LogLeaf, queueTimestamp time.Time) ([]*trillian.LogLeaf, error) {
+ // Don't accept batches if any of the leaves are invalid.
+ for _, leaf := range leaves {
+ if len(leaf.LeafIdentityHash) != t.hashSizeBytes {
+ return nil, fmt.Errorf("queued leaf must have a leaf ID hash of length %d", t.hashSizeBytes)
+ }
+ var err error
+ leaf.QueueTimestamp, err = ptypes.TimestampProto(queueTimestamp)
+ if err != nil {
+ return nil, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ }
+ start := time.Now()
+ label := labelForTX(t)
+
+ ordLeaves := sortLeavesForInsert(leaves)
+ existingCount := 0
+ existingLeaves := make([]*trillian.LogLeaf, len(leaves))
+
+ for _, ol := range ordLeaves {
+ i, leaf := ol.idx, ol.leaf
+
+ leafStart := time.Now()
+ qTimestamp, err := ptypes.Timestamp(leaf.QueueTimestamp)
+ if err != nil {
+ return nil, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ dupCheckRow, err := t.tx.QueryContext(ctx, insertLeafDataSQL, t.treeID, leaf.LeafIdentityHash, leaf.LeafValue, leaf.ExtraData, qTimestamp.UnixNano())
+ if err != nil {
+ return nil, fmt.Errorf("dupecheck failed: %v", err)
+ }
+ insertDuration := time.Since(leafStart)
+ observe(queueInsertLeafLatency, insertDuration, label)
+ resultData := false
+ for dupCheckRow.Next() {
+ err := dupCheckRow.Scan(&resultData)
+ if err != nil {
+ return nil, fmt.Errorf("dupecheck failed: %v", err)
+ }
+ if !resultData {
+ break
+ }
+ }
+ dupCheckRow.Close()
+ if !resultData {
+ // Remember the duplicate leaf, using the requested leaf for now.
+ existingLeaves[i] = leaf
+ existingCount++
+ queuedDupCounter.Inc(label)
+ glog.Warningf("Found duplicate %v %v", t.treeID, leaf)
+ continue
+ }
+
+ // Create the work queue entry
+ args := []interface{}{
+ t.treeID,
+ leaf.LeafIdentityHash,
+ leaf.MerkleLeafHash,
+ }
+ queueTimestamp, err := ptypes.Timestamp(leaf.QueueTimestamp)
+ if err != nil {
+ return nil, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ args = append(args, queueArgs(t.treeID, leaf.LeafIdentityHash, queueTimestamp)...)
+ _, err = t.tx.ExecContext(
+ ctx,
+ insertUnsequencedEntrySQL,
+ args...,
+ )
+ if err != nil {
+ glog.Warningf("Error inserting into Unsequenced: %s query %v arguements: %v", err, insertUnsequencedEntrySQL, args)
+ return nil, fmt.Errorf("Unsequenced: %v -- %v", err, args)
+ }
+ leafDuration := time.Since(leafStart)
+ observe(queueInsertEntryLatency, (leafDuration - insertDuration), label)
+ }
+ insertDuration := time.Since(start)
+ observe(queueInsertLatency, insertDuration, label)
+ queuedCounter.Add(float64(len(leaves)), label)
+
+ if existingCount == 0 {
+ return existingLeaves, nil
+ }
+
+ // For existing leaves, we need to retrieve the contents. First collate the desired LeafIdentityHash values.
+ var toRetrieve [][]byte
+ for _, existing := range existingLeaves {
+ if existing != nil {
+ toRetrieve = append(toRetrieve, existing.LeafIdentityHash)
+ }
+ }
+ results, err := t.getLeafDataByIdentityHash(ctx, toRetrieve)
+ if err != nil {
+ return nil, fmt.Errorf("failed to retrieve existing leaves: %v %v", err, toRetrieve)
+ }
+ if len(results) != len(toRetrieve) {
+ return nil, fmt.Errorf("failed to retrieve all existing leaves: got %d, want %d", len(results), len(toRetrieve))
+ }
+ // Replace the requested leaves with the actual leaves.
+ for i, requested := range existingLeaves {
+ if requested == nil {
+ continue
+ }
+ found := false
+ for _, result := range results {
+ if bytes.Equal(result.LeafIdentityHash, requested.LeafIdentityHash) {
+ existingLeaves[i] = result
+ found = true
+ break
+ }
+ }
+ if !found {
+ return nil, fmt.Errorf("failed to find existing leaf for hash %x", requested.LeafIdentityHash)
+ }
+ }
+ totalDuration := time.Since(start)
+ readDuration := totalDuration - insertDuration
+ observe(queueReadLatency, readDuration, label)
+ observe(queueLatency, totalDuration, label)
+
+ return existingLeaves, nil
+}
+
+func (t *logTreeTX) AddSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf, timestamp time.Time) ([]*trillian.QueuedLogLeaf, error) {
+ res := make([]*trillian.QueuedLogLeaf, len(leaves))
+ ok := status.New(codes.OK, "OK").Proto()
+
+ // Leaves in this transaction are inserted in two tables. For each leaf, if
+ // one of the two inserts fails, we remove the side effect by rolling back to
+ // a savepoint installed before the first insert of the two.
+ const savepoint = "SAVEPOINT AddSequencedLeaves"
+ if _, err := t.tx.ExecContext(ctx, savepoint); err != nil {
+ glog.Errorf("Error adding savepoint: %s", err)
+ return nil, err
+ }
+ // TODO(pavelkalinnikov): Consider performance implication of executing this
+ // extra SAVEPOINT, especially for 1-entry batches. Optimize if necessary.
+
+ // Note: LeafData inserts are presumably protected from deadlocks due to
+ // sorting, but the order of the corresponding SequencedLeafData inserts
+ // becomes indeterministic. However, in a typical case when leaves are
+ // supplied in contiguous non-intersecting batches, the chance of having
+ // circular dependencies between transactions is significantly lower.
+ ordLeaves := sortLeavesForInsert(leaves)
+ for _, ol := range ordLeaves {
+ i, leaf := ol.idx, ol.leaf
+
+ // This should fail on insert, but catch it early.
+ if got, want := len(leaf.LeafIdentityHash), t.hashSizeBytes; got != want {
+ return nil, status.Errorf(codes.FailedPrecondition, "leaves[%d] has incorrect hash size %d, want %d", i, got, want)
+ }
+
+ if _, err := t.tx.ExecContext(ctx, savepoint); err != nil {
+ glog.Errorf("Error updating savepoint: %s", err)
+ return nil, err
+ }
+
+ res[i] = &trillian.QueuedLogLeaf{Status: ok}
+
+ // TODO(pavelkalinnikov): Measure latencies.
+ _, err := t.tx.ExecContext(ctx, insertLeafDataSQL,
+ t.treeID, leaf.LeafIdentityHash, leaf.LeafValue, leaf.ExtraData, timestamp.UnixNano())
+ // TODO(pavelkalinnikov): Detach PREORDERED_LOG integration latency metric.
+ if err != nil {
+ glog.Errorf("Error inserting leaves[%d] into LeafData: %s", i, err)
+ return nil, err
+ }
+
+ dupCheckRow, err := t.tx.QueryContext(ctx, insertSequencedLeafSQL,
+ t.treeID, leaf.LeafIndex, leaf.LeafIdentityHash, leaf.MerkleLeafHash, 0)
+ // TODO(pavelkalinnikov): Update IntegrateTimestamp on integrating the leaf.
+ if err != nil {
+ glog.Errorf("Error inserting leaves[%d] into SequencedLeafData: %s %s", i, err, leaf.LeafIdentityHash)
+ return nil, err
+ }
+ // The insert function reports false when a conflicting row already exists.
+ resultData := true
+ for dupCheckRow.Next() {
+ if err := dupCheckRow.Scan(&resultData); err != nil {
+ dupCheckRow.Close()
+ return nil, err
+ }
+ if !resultData {
+ break
+ }
+ }
+ dupCheckRow.Close()
+ if !resultData {
+ res[i].Status = status.New(codes.FailedPrecondition, "conflicting LeafIndex").Proto()
+ if _, err := t.tx.ExecContext(ctx, "ROLLBACK TO "+savepoint); err != nil {
+ glog.Errorf("Error rolling back to savepoint: %s", err)
+ return nil, err
+ }
+ }
+
+ // TODO(pavelkalinnikov): Load LeafData for conflicting entries.
+ }
+
+ if _, err := t.tx.ExecContext(ctx, "RELEASE "+savepoint); err != nil {
+ glog.Errorf("Error releasing savepoint: %s", err)
+ return nil, err
+ }
+
+ return res, nil
+}
+
+func (t *logTreeTX) GetSequencedLeafCount(ctx context.Context) (int64, error) {
+ var sequencedLeafCount int64
+
+ err := t.tx.QueryRowContext(ctx, selectSequencedLeafCountSQL, t.treeID).Scan(&sequencedLeafCount)
+ if err != nil {
+ glog.Warningf("Error getting sequenced leaf count: %s", err)
+ }
+
+ return sequencedLeafCount, err
+}
+
+func (t *logTreeTX) GetLeavesByIndex(ctx context.Context, leaves []int64) ([]*trillian.LogLeaf, error) {
+ if t.treeType == trillian.TreeType_LOG {
+ treeSize := int64(t.root.TreeSize)
+ for _, leaf := range leaves {
+ if leaf < 0 {
+ return nil, status.Errorf(codes.InvalidArgument, "index %d is < 0", leaf)
+ }
+ if leaf >= treeSize {
+ return nil, status.Errorf(codes.OutOfRange, "invalid leaf index %d, want < TreeSize(%d)", leaf, treeSize)
+ }
+ }
+ }
+ tmpl, err := t.ls.getLeavesByIndexStmt(ctx, len(leaves))
+ if err != nil {
+ return nil, err
+ }
+ stx := t.tx.StmtContext(ctx, tmpl)
+ defer stx.Close()
+
+ var args []interface{}
+ for _, nodeID := range leaves {
+ args = append(args, interface{}(int64(nodeID)))
+ }
+ args = append(args, interface{}(t.treeID))
+ rows, err := stx.QueryContext(ctx, args...)
+ if err != nil {
+ glog.Warningf("Failed to get leaves by idx: %s", err)
+ return nil, err
+ }
+ defer rows.Close()
+
+ ret := make([]*trillian.LogLeaf, 0, len(leaves))
+ for rows.Next() {
+ leaf := &trillian.LogLeaf{}
+ var qTimestamp, iTimestamp int64
+ if err := rows.Scan(
+ &leaf.MerkleLeafHash,
+ &leaf.LeafIdentityHash,
+ &leaf.LeafValue,
+ &leaf.LeafIndex,
+ &leaf.ExtraData,
+ &qTimestamp,
+ &iTimestamp); err != nil {
+ glog.Warningf("Failed to scan merkle leaves: %s", err)
+ return nil, err
+ }
+ var err error
+ leaf.QueueTimestamp, err = ptypes.TimestampProto(time.Unix(0, qTimestamp))
+ if err != nil {
+ return nil, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ leaf.IntegrateTimestamp, err = ptypes.TimestampProto(time.Unix(0, iTimestamp))
+ if err != nil {
+ return nil, fmt.Errorf("got invalid integrate timestamp: %v", err)
+ }
+ ret = append(ret, leaf)
+ }
+
+ if got, want := len(ret), len(leaves); got != want {
+ return nil, status.Errorf(codes.Internal, "len(ret): %d, want %d", got, want)
+ }
+ return ret, nil
+}
+
+func (t *logTreeTX) GetLeavesByRange(ctx context.Context, start, count int64) ([]*trillian.LogLeaf, error) {
+ if count <= 0 {
+ return nil, status.Errorf(codes.InvalidArgument, "invalid count %d, want > 0", count)
+ }
+ if start < 0 {
+ return nil, status.Errorf(codes.InvalidArgument, "invalid start %d, want >= 0", start)
+ }
+
+ if t.treeType == trillian.TreeType_LOG {
+ treeSize := int64(t.root.TreeSize)
+ if treeSize <= 0 {
+ return nil, status.Errorf(codes.OutOfRange, "empty tree")
+ } else if start >= treeSize {
+ return nil, status.Errorf(codes.OutOfRange, "invalid start %d, want < TreeSize(%d)", start, treeSize)
+ }
+ // Ensure no entries queried/returned beyond the tree.
+ if maxCount := treeSize - start; count > maxCount {
+ count = maxCount
+ }
+ }
+ // TODO(pavelkalinnikov): Further clip `count` to a safe upper bound like 64k.
+
+ args := []interface{}{start, start + count, t.treeID}
+ rows, err := t.tx.QueryContext(ctx, selectLeavesByRangeSQL, args...)
+ if err != nil {
+ glog.Warningf("Failed to get leaves by range: %s", err)
+ return nil, err
+ }
+ defer rows.Close()
+
+ ret := make([]*trillian.LogLeaf, 0, count)
+ for wantIndex := start; rows.Next(); wantIndex++ {
+ leaf := &trillian.LogLeaf{}
+ var qTimestamp, iTimestamp int64
+ if err := rows.Scan(
+ &leaf.MerkleLeafHash,
+ &leaf.LeafIdentityHash,
+ &leaf.LeafValue,
+ &leaf.LeafIndex,
+ &leaf.ExtraData,
+ &qTimestamp,
+ &iTimestamp); err != nil {
+ glog.Warningf("Failed to scan merkle leaves: %s", err)
+ return nil, err
+ }
+ if leaf.LeafIndex != wantIndex {
+ if wantIndex < int64(t.root.TreeSize) {
+ return nil, fmt.Errorf("got unexpected index %d, want %d", leaf.LeafIndex, wantIndex)
+ }
+ break
+ }
+ var err error
+ leaf.QueueTimestamp, err = ptypes.TimestampProto(time.Unix(0, qTimestamp))
+ if err != nil {
+ return nil, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ leaf.IntegrateTimestamp, err = ptypes.TimestampProto(time.Unix(0, iTimestamp))
+ if err != nil {
+ return nil, fmt.Errorf("got invalid integrate timestamp: %v", err)
+ }
+ ret = append(ret, leaf)
+ }
+
+ return ret, nil
+}
+
+func (t *logTreeTX) GetLeavesByHash(ctx context.Context, leafHashes [][]byte, orderBySequence bool) ([]*trillian.LogLeaf, error) {
+ tmpl, err := t.ls.getLeavesByMerkleHashStmt(ctx, len(leafHashes), orderBySequence)
+ if err != nil {
+ return nil, err
+ }
+
+ return t.getLeavesByHashInternal(ctx, leafHashes, tmpl, "merkle")
+}
+
+// getLeafDataByIdentityHash retrieves leaf data by LeafIdentityHash, returned
+// as a slice of LogLeaf objects for convenience. However, note that the
+// returned LogLeaf objects will not have a valid MerkleLeafHash, LeafIndex, or IntegrateTimestamp.
+func (t *logTreeTX) getLeafDataByIdentityHash(ctx context.Context, leafHashes [][]byte) ([]*trillian.LogLeaf, error) {
+ tmpl, err := t.ls.getLeavesByLeafIdentityHashStmt(ctx, len(leafHashes))
+ if err != nil {
+ return nil, err
+ }
+ return t.getLeavesByHashInternal(ctx, leafHashes, tmpl, "leaf-identity")
+}
+
+func (t *logTreeTX) LatestSignedLogRoot(ctx context.Context) (trillian.SignedLogRoot, error) {
+ return t.slr, nil
+}
+
+// fetchLatestRoot reads the latest SignedLogRoot from the DB and returns it.
+func (t *logTreeTX) fetchLatestRoot(ctx context.Context) (trillian.SignedLogRoot, error) {
+ // var timestamp, treeSize, treeRevision int64
+ var rootSignatureBytes []byte
+ var jsonObj []byte
+
+ if err := t.tx.QueryRowContext(
+ ctx,
+ "select current_tree_data,root_signature from trees where tree_id = $1",
+ t.treeID).Scan(&jsonObj, &rootSignatureBytes); err != nil && err != sql.ErrNoRows {
+ return trillian.SignedLogRoot{}, err
+ }
+ if jsonObj == nil { // A nil tree head means the tree has not been initialised yet (createtree workflow).
+ return trillian.SignedLogRoot{}, storage.ErrTreeNeedsInit
+ }
+ var logRoot types.LogRootV1
+ if err := json.Unmarshal(jsonObj, &logRoot); err != nil {
+ return trillian.SignedLogRoot{}, err
+ }
+ newRoot, err := logRoot.MarshalBinary()
+ if err != nil {
+ return trillian.SignedLogRoot{}, err
+ }
+ return trillian.SignedLogRoot{
+ KeyHint: types.SerializeKeyHint(t.treeID),
+ LogRoot: newRoot,
+ LogRootSignature: rootSignatureBytes,
+ }, nil
+}
+
+func (t *logTreeTX) StoreSignedLogRoot(ctx context.Context, root trillian.SignedLogRoot) error {
+ var logRoot types.LogRootV1
+ if err := logRoot.UnmarshalBinary(root.LogRoot); err != nil {
+ glog.Warningf("Failed to parse log root: %x %v", root.LogRoot, err)
+ return err
+ }
+ if len(logRoot.Metadata) != 0 {
+ return fmt.Errorf("unimplemented: postgres storage does not support log root metadata")
+
+ }
+ // Store a JSON copy of the tree head on the tree row as well.
+ data, err := json.Marshal(logRoot)
+ if err != nil {
+ return err
+ }
+ if _, err := t.tx.ExecContext(
+ ctx,
+ "update trees set current_tree_data = $1,root_signature = $2 where tree_id = $3",
+ data,
+ root.LogRootSignature,
+ t.treeID); err != nil {
+ glog.Warningf("Failed to update tree data: %s", err)
+ return err
+ }
+ res, err := t.tx.ExecContext(
+ ctx,
+ insertTreeHeadSQL,
+ t.treeID,
+ logRoot.TimestampNanos,
+ logRoot.TreeSize,
+ logRoot.RootHash,
+ logRoot.Revision,
+ root.LogRootSignature)
+ if err != nil {
+ glog.Warningf("Failed to store signed root: %s", err)
+ }
+
+ return checkResultOkAndRowCountIs(res, err, 1)
+}
+
+func (t *logTreeTX) getLeavesByHashInternal(ctx context.Context, leafHashes [][]byte, tmpl *sql.Stmt, desc string) ([]*trillian.LogLeaf, error) {
+ stx := t.tx.StmtContext(ctx, tmpl)
+ defer stx.Close()
+
+ var args []interface{}
+ for _, hash := range leafHashes {
+ args = append(args, interface{}([]byte(hash)))
+ }
+ args = append(args, interface{}(t.treeID))
+ rows, err := stx.QueryContext(ctx, args...)
+ if err != nil {
+ glog.Warningf("Query() %s hash = %v", desc, err)
+ return nil, err
+ }
+ defer rows.Close()
+
+ // The tree could include duplicates so we don't know how many results will be returned
+ var ret []*trillian.LogLeaf
+ for rows.Next() {
+ leaf := &trillian.LogLeaf{}
+ // We might be using a LEFT JOIN in our statement, so leaves which are
+ // queued but not yet integrated will have a NULL IntegrateTimestamp
+ // when there is no corresponding entry in SequencedLeafData, even though
+ // the table definition forbids that. We therefore use a nullable type
+ // here and check its validity below.
+ var integrateTS sql.NullInt64
+ var queueTS int64
+
+ if err := rows.Scan(&leaf.MerkleLeafHash, &leaf.LeafIdentityHash, &leaf.LeafValue, &leaf.LeafIndex, &leaf.ExtraData, &queueTS, &integrateTS); err != nil {
+ glog.Warningf("LogID: %d Scan() %s = %s", t.treeID, desc, err)
+ return nil, err
+ }
+ var err error
+ leaf.QueueTimestamp, err = ptypes.TimestampProto(time.Unix(0, queueTS))
+ if err != nil {
+ return nil, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ if integrateTS.Valid {
+ leaf.IntegrateTimestamp, err = ptypes.TimestampProto(time.Unix(0, integrateTS.Int64))
+ if err != nil {
+ return nil, fmt.Errorf("got invalid integrate timestamp: %v", err)
+ }
+ }
+
+ if got, want := len(leaf.MerkleLeafHash), t.hashSizeBytes; got != want {
+ return nil, fmt.Errorf("LogID: %d Scanned leaf %s does not have hash length %d, got %d", t.treeID, desc, want, got)
+ }
+
+ ret = append(ret, leaf)
+ }
+
+ return ret, nil
+}
+
+func (t *readOnlyLogTX) GetUnsequencedCounts(ctx context.Context) (storage.CountByLogID, error) {
+ stx, err := t.tx.PrepareContext(ctx, selectUnsequencedLeafCountSQL)
+ if err != nil {
+ glog.Warningf("Failed to prep unsequenced leaf count statement: %v", err)
+ return nil, err
+ }
+ defer stx.Close()
+
+ rows, err := stx.QueryContext(ctx)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ ret := make(map[int64]int64)
+ for rows.Next() {
+ var logID, count int64
+ if err := rows.Scan(&logID, &count); err != nil {
+ return nil, fmt.Errorf("failed to scan row from unsequenced counts: %v", err)
+ }
+ ret[logID] = count
+ }
+ return ret, nil
+}
+
+// leafAndPosition records original position before sort.
+type leafAndPosition struct {
+ leaf *trillian.LogLeaf
+ idx int
+}
+
+// byLeafIdentityHashWithPosition allows sorting (as above), but where we need
+// to remember the original position
+type byLeafIdentityHashWithPosition []leafAndPosition
+
+func (l byLeafIdentityHashWithPosition) Len() int {
+ return len(l)
+}
+func (l byLeafIdentityHashWithPosition) Swap(i, j int) {
+ l[i], l[j] = l[j], l[i]
+}
+func (l byLeafIdentityHashWithPosition) Less(i, j int) bool {
+ return bytes.Compare(l[i].leaf.LeafIdentityHash, l[j].leaf.LeafIdentityHash) == -1
+}
diff --git a/storage/postgres/log_storage_test.go b/storage/postgres/log_storage_test.go
new file mode 100644
index 0000000000..e0413cbefd
--- /dev/null
+++ b/storage/postgres/log_storage_test.go
@@ -0,0 +1,1364 @@
+// Copyright 2016 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "bytes"
+ "context"
+ "crypto"
+ "crypto/sha256"
+ "database/sql"
+ "fmt"
+ "reflect"
+ "sort"
+ "testing"
+ "time"
+
+ "github.com/golang/protobuf/proto"
+ "github.com/golang/protobuf/ptypes"
+ "github.com/google/trillian"
+ "github.com/google/trillian/storage"
+ "github.com/google/trillian/storage/testonly"
+ "github.com/google/trillian/types"
+ "github.com/kylelemons/godebug/pretty"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ tcrypto "github.com/google/trillian/crypto"
+ ttestonly "github.com/google/trillian/testonly"
+
+ _ "github.com/lib/pq"
+)
+
+// Must be 32 bytes to match sha256 length if it was a real hash
+var dummyHash = []byte("hashxxxxhashxxxxhashxxxxhashxxxx")
+var dummyRawHash = []byte("xxxxhashxxxxhashxxxxhashxxxxhash")
+var dummyRawHash2 = []byte("yyyyhashyyyyhashyyyyhashyyyyhash")
+var dummyHash2 = []byte("HASHxxxxhashxxxxhashxxxxhashxxxx")
+var dummyHash3 = []byte("hashxxxxhashxxxxhashxxxxHASHxxxx")
+
+// Time we will queue all leaves at
+var fakeQueueTime = time.Date(2016, 11, 10, 15, 16, 27, 0, time.UTC)
+
+// Time we will integrate all leaves at
+var fakeIntegrateTime = time.Date(2016, 11, 10, 15, 16, 30, 0, time.UTC)
+
+// Time we'll request for guard cutoff in tests that don't test this (should include all above)
+var fakeDequeueCutoffTime = time.Date(2016, 11, 10, 15, 16, 30, 0, time.UTC)
+
+// Used for tests involving extra data
+var someExtraData = []byte("Some extra data")
+var someExtraData2 = []byte("Some even more extra data")
+
+const leavesToInsert = 5
+const sequenceNumber int64 = 237
+
+// Tests that access the db should each use a distinct log ID to prevent lock contention,
+// race conditions and unexpected interactions when run in parallel. Tests that pass should
+// hold no locks afterwards.
+
+func createFakeLeaf(ctx context.Context, db *sql.DB, logID int64, rawHash, hash, data, extraData []byte, seq int64, t *testing.T) *trillian.LogLeaf {
+ t.Helper()
+ queuedAtNanos := fakeQueueTime.UnixNano()
+ integratedAtNanos := fakeIntegrateTime.UnixNano()
+ _, err := db.ExecContext(ctx, "select * from insert_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)", logID, rawHash, data, extraData, queuedAtNanos)
+ _, err2 := db.ExecContext(ctx, "select * from insert_sequenced_leaf_data_ignore_duplicates($1,$2,$3,$4,$5)", logID, seq, rawHash, hash, integratedAtNanos)
+
+ if err != nil || err2 != nil {
+ t.Fatalf("Failed to create test leaves: %v %v", err, err2)
+ }
+ integrateTimestamp, err := ptypes.TimestampProto(fakeIntegrateTime)
+ if err != nil {
+ panic(err)
+ }
+ return &trillian.LogLeaf{
+ MerkleLeafHash: hash,
+ LeafValue: data,
+ ExtraData: extraData,
+ LeafIndex: seq,
+ LeafIdentityHash: rawHash,
+ IntegrateTimestamp: integrateTimestamp,
+ }
+}
+
+func checkLeafContents(leaf *trillian.LogLeaf, seq int64, rawHash, hash, data, extraData []byte, t *testing.T) {
+ t.Helper()
+ if got, want := leaf.MerkleLeafHash, hash; !bytes.Equal(got, want) {
+ t.Fatalf("Wrong leaf hash in returned leaf got\n%v\nwant:\n%v", got, want)
+ }
+
+ if got, want := leaf.LeafIdentityHash, rawHash; !bytes.Equal(got, want) {
+ t.Fatalf("Wrong raw leaf hash in returned leaf got\n%v\nwant:\n%v", got, want)
+ }
+
+ if got, want := leaf.LeafIndex, seq; got != want {
+ t.Fatalf("Bad sequence number in returned leaf got: %d, want: %d", got, want)
+ }
+
+ if got, want := leaf.LeafValue, data; !bytes.Equal(got, want) {
+ t.Fatalf("Unexpected leaf value in returned leaf. got:\n%v\nwant:\n%v", got, want)
+ }
+
+ if got, want := leaf.ExtraData, extraData; !bytes.Equal(got, want) {
+ t.Fatalf("Unexpected extra data in returned leaf. got:\n%v\nwant:\n%v", got, want)
+ }
+
+ iTime, err := ptypes.Timestamp(leaf.IntegrateTimestamp)
+ if err != nil {
+ t.Fatalf("Got invalid integrate timestamp: %v", err)
+ }
+ if got, want := iTime.UnixNano(), fakeIntegrateTime.UnixNano(); got != want {
+ t.Errorf("Wrong IntegrateTimestamp: got %v, want %v", got, want)
+ }
+}
+
+func TestPostgresLogStorage_CheckDatabaseAccessible(t *testing.T) {
+ cleanTestDB(db, t)
+ s := NewLogStorage(db, nil)
+ if err := s.CheckDatabaseAccessible(context.Background()); err != nil {
+ t.Errorf("CheckDatabaseAccessible() = %v, want = nil", err)
+ }
+}
+
+func TestSnapshot(t *testing.T) {
+ cleanTestDB(db, t)
+
+ frozenLog := createTreeOrPanic(db, testonly.LogTree)
+ createFakeSignedLogRoot(db, frozenLog, 0)
+ if _, err := updateTree(db, frozenLog.TreeId, func(tree *trillian.Tree) {
+ tree.TreeState = trillian.TreeState_FROZEN
+ }); err != nil {
+ t.Fatalf("Error updating frozen tree: %v", err)
+ }
+
+ activeLog := createTreeOrPanic(db, testonly.LogTree)
+ createFakeSignedLogRoot(db, activeLog, 0)
+ mapTreeID := createTreeOrPanic(db, testonly.MapTree).TreeId
+
+ tests := []struct {
+ desc string
+ tree *trillian.Tree
+ wantErr bool
+ }{
+ {
+ desc: "unknownSnapshot",
+ tree: logTree(-1),
+ wantErr: true,
+ },
+ {
+ desc: "activeLogSnapshot",
+ tree: activeLog,
+ },
+ {
+ desc: "frozenSnapshot",
+ tree: frozenLog,
+ },
+ {
+ desc: "mapSnapshot",
+ tree: logTree(mapTreeID),
+ wantErr: true,
+ },
+ }
+
+ ctx := context.Background()
+ s := NewLogStorage(db, nil)
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ tx, err := s.SnapshotForTree(ctx, test.tree)
+
+ if err == storage.ErrTreeNeedsInit {
+ defer tx.Close()
+ }
+
+ if hasErr := err != nil; hasErr != test.wantErr {
+ t.Fatalf("err = %q, wantErr = %v", err, test.wantErr)
+ } else if hasErr {
+ return
+ }
+ defer tx.Close()
+
+ _, err = tx.LatestSignedLogRoot(ctx)
+ if err != nil {
+ t.Errorf("LatestSignedLogRoot() returned err = %v", err)
+ }
+ if err := tx.Commit(); err != nil {
+ t.Errorf("Commit() returned err = %v", err)
+ }
+ })
+ }
+}
+
+func TestReadWriteTransaction(t *testing.T) {
+ cleanTestDB(db, t)
+ activeLog := createTreeOrPanic(db, testonly.LogTree)
+ createFakeSignedLogRoot(db, activeLog, 0)
+
+ tests := []struct {
+ desc string
+ tree *trillian.Tree
+ wantErr bool
+ wantLogRoot []byte
+ wantTXRev int64
+ }{
+ {
+ // Unknown logs IDs are now handled outside storage.
+ desc: "unknownBegin",
+ tree: logTree(-1),
+ wantLogRoot: nil,
+ wantTXRev: -1,
+ },
+ {
+ desc: "activeLogBegin",
+ tree: activeLog,
+ wantLogRoot: func() []byte {
+ b, err := (&types.LogRootV1{RootHash: []byte{0}}).MarshalBinary()
+ if err != nil {
+ panic(err)
+ }
+ return b
+ }(),
+ wantTXRev: 1,
+ },
+ }
+
+ ctx := context.Background()
+ s := NewLogStorage(db, nil)
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ err := s.ReadWriteTransaction(ctx, test.tree, func(ctx context.Context, tx storage.LogTreeTX) error {
+ root, err := tx.LatestSignedLogRoot(ctx)
+ if err != nil {
+ t.Fatalf("%v: LatestSignedLogRoot() returned err = %v", test.desc, err)
+ }
+ gotRev, _ := tx.WriteRevision(ctx)
+ if gotRev != test.wantTXRev {
+ t.Errorf("%v: WriteRevision() = %v, want = %v", test.desc, gotRev, test.wantTXRev)
+ }
+ if got, want := root.LogRoot, test.wantLogRoot; !bytes.Equal(got, want) {
+ t.Errorf("%v: LogRoot: \n%x, want \n%x", test.desc, got, want)
+ }
+ return nil
+ })
+ if hasErr := err != nil; hasErr != test.wantErr {
+ t.Fatalf("%v: err = %q, wantErr = %v", test.desc, err, test.wantErr)
+ } else if hasErr {
+ return
+ }
+ })
+ }
+}
+
+func TestQueueDuplicateLeaf(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+ count := 15
+ leaves := createTestLeaves(int64(count), 10)
+ leaves2 := createTestLeaves(int64(count), 12)
+ leaves3 := createTestLeaves(3, 100)
+
+ // Note that tests accumulate queued leaves on top of each other.
+ var tests = []struct {
+ desc string
+ leaves []*trillian.LogLeaf
+ want []*trillian.LogLeaf
+ }{
+ {
+ desc: "[10, 11, 12, ...]",
+ leaves: leaves,
+ want: make([]*trillian.LogLeaf, count),
+ },
+ {
+ desc: "[12, 13, 14, ...] so first (count-2) are duplicates",
+ leaves: leaves2,
+ want: append(leaves[2:], nil, nil),
+ },
+ {
+ desc: "[10, 100, 11, 101, 102] so [dup, new, dup, new, dup]",
+ leaves: []*trillian.LogLeaf{leaves[0], leaves3[0], leaves[1], leaves3[1], leaves[2]},
+ want: []*trillian.LogLeaf{leaves[0], nil, leaves[1], nil, leaves[2]},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ existing, err := tx.QueueLeaves(ctx, test.leaves, fakeQueueTime)
+ if err != nil {
+ t.Errorf("Failed to queue leaves: %v", err)
+ return err
+ }
+
+ if len(existing) != len(test.want) {
+ t.Fatalf("|QueueLeaves()|=%d; want %d", len(existing), len(test.want))
+ }
+ for i, want := range test.want {
+ got := existing[i]
+ if want == nil {
+ if got != nil {
+ t.Fatalf("QueueLeaves()[%d]=%v; want nil", i, got)
+ }
+ continue
+ }
+ if got == nil {
+ t.Fatalf("QueueLeaves()[%d]=nil; want non-nil", i)
+ } else if !bytes.Equal(got.LeafIdentityHash, want.LeafIdentityHash) {
+ t.Fatalf("QueueLeaves()[%d].LeafIdentityHash=%x; want %x", i, got.LeafIdentityHash, want.LeafIdentityHash)
+ }
+ }
+ return nil
+ })
+ })
+ }
+}
+
+func TestQueueLeaves(t *testing.T) {
+ ctx := context.Background()
+
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ leaves := createTestLeaves(leavesToInsert, 20)
+ if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil {
+ t.Fatalf("Failed to queue leaves: %v", err)
+ }
+ return nil
+ })
+
+ // Should see the leaves in the database. There is no API to read from the unsequenced data.
+ var count int
+ if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM unsequenced WHERE tree_id=$1", tree.TreeId).Scan(&count); err != nil {
+ t.Fatalf("Could not query row count: %v", err)
+ }
+ if leavesToInsert != count {
+ t.Fatalf("Expected %d unsequenced rows but got: %d", leavesToInsert, count)
+ }
+
+ // Additional check on timestamp being set correctly in the database
+ var queueTimestamp int64
+ if err := db.QueryRowContext(ctx, "SELECT DISTINCT queue_timestamp_nanos FROM unsequenced WHERE tree_id=$1", tree.TreeId).Scan(&queueTimestamp); err != nil {
+ t.Fatalf("Could not query timestamp: %v", err)
+ }
+ if got, want := queueTimestamp, fakeQueueTime.UnixNano(); got != want {
+ t.Fatalf("Incorrect queue timestamp got: %d want: %d", got, want)
+ }
+}
+
+// AddSequencedLeaves tests. ---------------------------------------------------
+
+type addSequencedLeavesTest struct {
+ t *testing.T
+ s storage.LogStorage
+ tree *trillian.Tree
+}
+
+func initAddSequencedLeavesTest(t *testing.T) addSequencedLeavesTest {
+ cleanTestDB(db, t)
+ s := NewLogStorage(db, nil)
+ tree := createTreeOrPanic(db, testonly.PreorderedLogTree)
+ return addSequencedLeavesTest{t, s, tree}
+}
+
+func (t *addSequencedLeavesTest) addSequencedLeaves(leaves []*trillian.LogLeaf) {
+ runLogTX(t.s, t.tree, t.t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ if _, err := tx.AddSequencedLeaves(ctx, leaves, fakeQueueTime); err != nil {
+ t.t.Fatalf("Failed to add sequenced leaves: %v", err)
+ }
+ // TODO(pavelkalinnikov): Verify returned status for each leaf.
+ return nil
+ })
+}
+
+func (t *addSequencedLeavesTest) verifySequencedLeaves(start, count int64, exp []*trillian.LogLeaf) {
+ var stored []*trillian.LogLeaf
+ runLogTX(t.s, t.tree, t.t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ var err error
+ stored, err = tx.GetLeavesByRange(ctx, start, count)
+ if err != nil {
+ t.t.Fatalf("Failed to read sequenced leaves: %v", err)
+ }
+ return nil
+ })
+ if got, want := len(stored), len(exp); got != want {
+ t.t.Fatalf("Unexpected number of leaves: got %d, want %d %d %d %v", got, want, start, count, exp)
+ }
+
+ for i, leaf := range stored {
+ if got, want := leaf.LeafIndex, exp[i].LeafIndex; got != want {
+ t.t.Fatalf("Leaf #%d: LeafIndex=%v, want %v", i, got, want)
+ }
+ if got, want := leaf.LeafIdentityHash, exp[i].LeafIdentityHash; !bytes.Equal(got, want) {
+ t.t.Fatalf("Leaf #%d: LeafIdentityHash=%v, want %v %d %d %v", i, got, want, start, count, t.tree)
+ }
+ }
+}
+
+func TestAddSequencedLeavesUnordered(t *testing.T) {
+ const chunk = leavesToInsert
+ const count = chunk * 5
+ const extraCount = 16
+ leaves := createTestLeaves(count, 0)
+
+ aslt := initAddSequencedLeavesTest(t)
+ for _, idx := range []int{1, 0, 4, 2} {
+ aslt.addSequencedLeaves(leaves[chunk*idx : chunk*(idx+1)])
+ }
+ aslt.verifySequencedLeaves(0, count+extraCount, leaves[:chunk*3])
+ aslt.verifySequencedLeaves(chunk*4, chunk+extraCount, leaves[chunk*4:count])
+ aslt.addSequencedLeaves(leaves[chunk*3 : chunk*4])
+ aslt.verifySequencedLeaves(0, count+extraCount, leaves)
+}
+
+func TestAddSequencedLeavesWithDuplicates(t *testing.T) {
+ leaves := createTestLeaves(6, 0)
+
+ aslt := initAddSequencedLeavesTest(t)
+ aslt.addSequencedLeaves(leaves[:3])
+ aslt.verifySequencedLeaves(0, 3, leaves[:3])
+ aslt.addSequencedLeaves(leaves[2:]) // Full dup.
+ aslt.verifySequencedLeaves(0, 6, leaves)
+
+ dupLeaves := createTestLeaves(4, 6)
+ dupLeaves[0].LeafIdentityHash = leaves[0].LeafIdentityHash // Hash dup.
+ dupLeaves[2].LeafIndex = 2 // Index dup.
+ aslt.addSequencedLeaves(dupLeaves)
+ aslt.verifySequencedLeaves(6, 4, dupLeaves[0:2])
+ aslt.verifySequencedLeaves(7, 4, dupLeaves[1:2])
+ aslt.verifySequencedLeaves(8, 4, nil)
+ aslt.verifySequencedLeaves(9, 4, dupLeaves[3:4])
+ dupLeaves = createTestLeaves(4, 6)
+ aslt.addSequencedLeaves(dupLeaves)
+}
+
+// -----------------------------------------------------------------------------
+
+func TestDequeueLeavesNoneQueued(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ leaves, err := tx.DequeueLeaves(ctx, 999, fakeDequeueCutoffTime)
+ if err != nil {
+ t.Fatalf("Didn't expect an error on dequeue with no work to be done: %v", err)
+ }
+ if len(leaves) > 0 {
+ t.Fatalf("Expected nothing to be dequeued but we got %d leaves", len(leaves))
+ }
+ return nil
+ })
+}
+
+func TestDequeueLeaves(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ leaves := createTestLeaves(leavesToInsert, 20)
+ if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil {
+ t.Fatalf("Failed to queue leaves: %v", err)
+ }
+ return nil
+ })
+ }
+
+ {
+ // Now try to dequeue them
+ runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
+ leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves: %v", err)
+ }
+ if len(leaves2) != leavesToInsert {
+ t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert)
+ }
+ ensureAllLeavesDistinct(leaves2, t)
+ return nil
+ })
+ }
+
+ {
+ // If we dequeue again then we should now get nothing
+ runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error {
+ leaves3, err := tx3.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves (second time): %v", err)
+ }
+ if len(leaves3) != 0 {
+ t.Fatalf("Dequeued %d leaves but expected to get none", len(leaves3))
+ }
+ return nil
+ })
+ }
+}
+
+func TestDequeueLeavesHaveQueueTimestamp(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ leaves := createTestLeaves(leavesToInsert, 20)
+ if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil {
+ t.Fatalf("Failed to queue leaves: %v", err)
+ }
+ return nil
+ })
+ }
+
+ {
+ // Now try to dequeue them
+ runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
+ leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves: %v", err)
+ }
+ if len(leaves2) != leavesToInsert {
+ t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert)
+ }
+ ensureLeavesHaveQueueTimestamp(t, leaves2, fakeDequeueCutoffTime)
+ return nil
+ })
+ }
+}
+
+func TestDequeueLeavesTwoBatches(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ leavesToDequeue1 := 3
+ leavesToDequeue2 := 2
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ leaves := createTestLeaves(leavesToInsert, 20)
+ if _, err := tx.QueueLeaves(ctx, leaves, fakeDequeueCutoffTime); err != nil {
+ t.Fatalf("Failed to queue leaves: %v", err)
+ }
+ return nil
+ })
+ }
+
+ var err error
+ var leaves2, leaves3, leaves4 []*trillian.LogLeaf
+ {
+ // Now try to dequeue some of them
+ runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
+ leaves2, err = tx2.DequeueLeaves(ctx, leavesToDequeue1, fakeDequeueCutoffTime)
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves: %v", err)
+ }
+ if len(leaves2) != leavesToDequeue1 {
+ t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert)
+ }
+ ensureAllLeavesDistinct(leaves2, t)
+ ensureLeavesHaveQueueTimestamp(t, leaves2, fakeDequeueCutoffTime)
+ return nil
+ })
+
+ // Now try to dequeue the rest of them
+ runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error {
+ leaves3, err = tx3.DequeueLeaves(ctx, leavesToDequeue2, fakeDequeueCutoffTime)
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves: %v", err)
+ }
+ if len(leaves3) != leavesToDequeue2 {
+ t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves3), leavesToDequeue2)
+ }
+ ensureAllLeavesDistinct(leaves3, t)
+ ensureLeavesHaveQueueTimestamp(t, leaves3, fakeDequeueCutoffTime)
+
+ // Plus the union of the leaf batches should all have distinct hashes
+ leaves4 = append(leaves2, leaves3...)
+ ensureAllLeavesDistinct(leaves4, t)
+ return nil
+ })
+ }
+
+ {
+ // If we dequeue again then we should now get nothing
+ runLogTX(s, tree, t, func(ctx context.Context, tx4 storage.LogTreeTX) error {
+ leaves5, err := tx4.DequeueLeaves(ctx, 99, fakeDequeueCutoffTime)
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves (second time): %v", err)
+ }
+ if len(leaves5) != 0 {
+ t.Fatalf("Dequeued %d leaves but expected to get none", len(leaves5))
+ }
+ return nil
+ })
+ }
+}
+
+// Queues leaves and attempts to dequeue before the guard cutoff allows it. This should
+// return nothing. Then retry with an inclusive guard cutoff and ensure the leaves
+// are returned.
+func TestDequeueLeavesGuardInterval(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ leaves := createTestLeaves(leavesToInsert, 20)
+ if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil {
+ t.Fatalf("Failed to queue leaves: %v", err)
+ }
+ return nil
+ })
+ }
+
+ {
+ // Now try to dequeue them using a cutoff that means we should get none
+ runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
+ leaves2, err := tx2.DequeueLeaves(ctx, 99, fakeQueueTime.Add(-time.Second))
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves: %v", err)
+ }
+ if len(leaves2) != 0 {
+ t.Fatalf("Dequeued %d leaves when they all should be in guard interval", len(leaves2))
+ }
+
+ // Try to dequeue again using a cutoff that should include them
+ leaves2, err = tx2.DequeueLeaves(ctx, 99, fakeQueueTime.Add(time.Second))
+ if err != nil {
+ t.Fatalf("Failed to dequeue leaves: %v", err)
+ }
+ if len(leaves2) != leavesToInsert {
+ t.Fatalf("Dequeued %d leaves but expected to get %d", len(leaves2), leavesToInsert)
+ }
+ ensureAllLeavesDistinct(leaves2, t)
+ return nil
+ })
+ }
+}
+
+func TestDequeueLeavesTimeOrdering(t *testing.T) {
+ // Queue two small batches of leaves at different timestamps. Do two separate dequeue
+ // transactions and make sure the returned leaves are respecting the time ordering of the
+ // queue.
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ batchSize := 2
+ leaves := createTestLeaves(int64(batchSize), 0)
+ leaves2 := createTestLeaves(int64(batchSize), int64(batchSize))
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ if _, err := tx.QueueLeaves(ctx, leaves, fakeQueueTime); err != nil {
+ t.Fatalf("QueueLeaves(1st batch) = %v", err)
+ }
+ // These are one second earlier so should be dequeued first
+ if _, err := tx.QueueLeaves(ctx, leaves2, fakeQueueTime.Add(-time.Second)); err != nil {
+ t.Fatalf("QueueLeaves(2nd batch) = %v", err)
+ }
+ return nil
+ })
+ }
+
+ {
+ // Now try to dequeue two leaves and we should get the second batch
+ runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
+ dequeue1, err := tx2.DequeueLeaves(ctx, batchSize, fakeQueueTime)
+ if err != nil {
+ t.Fatalf("DequeueLeaves(1st) = %v", err)
+ }
+ if got, want := len(dequeue1), batchSize; got != want {
+ t.Fatalf("Dequeue count mismatch (1st) got: %d, want: %d", got, want)
+ }
+ ensureAllLeavesDistinct(dequeue1, t)
+
+ // Ensure this is the second batch queued by comparing leaf hashes (must be distinct as
+ // the leaf data was).
+ if !leafInBatch(dequeue1[0], leaves2) || !leafInBatch(dequeue1[1], leaves2) {
+ t.Fatalf("Got leaf from wrong batch (1st dequeue): %v", dequeue1)
+ }
+ return nil
+ })
+
+ // Try to dequeue again and we should get the batch that was queued first, though at a later time
+ runLogTX(s, tree, t, func(ctx context.Context, tx3 storage.LogTreeTX) error {
+ dequeue2, err := tx3.DequeueLeaves(ctx, batchSize, fakeQueueTime)
+ if err != nil {
+ t.Fatalf("DequeueLeaves(2nd) = %v", err)
+ }
+ if got, want := len(dequeue2), batchSize; got != want {
+ t.Fatalf("Dequeue count mismatch (2nd) got: %d, want: %d", got, want)
+ }
+ ensureAllLeavesDistinct(dequeue2, t)
+
+ // Ensure this is the first batch by comparing leaf hashes.
+ if !leafInBatch(dequeue2[0], leaves) || !leafInBatch(dequeue2[1], leaves) {
+ t.Fatalf("Got leaf from wrong batch (2nd dequeue): %v", dequeue2)
+ }
+ return nil
+ })
+ }
+}
+
+func TestGetLeavesByHashNotPresent(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ hashes := [][]byte{[]byte("thisdoesn'texist")}
+ leaves, err := tx.GetLeavesByHash(ctx, hashes, false)
+ if err != nil {
+ t.Fatalf("Error getting leaves by hash: %v", err)
+ }
+ if len(leaves) != 0 {
+ t.Fatalf("Expected no leaves returned but got %d", len(leaves))
+ }
+ return nil
+ })
+}
+
+func TestGetLeavesByHash(t *testing.T) {
+ ctx := context.Background()
+
+ // Create fake leaf as if it had been sequenced
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ data := []byte("some data")
+ createFakeLeaf(ctx, db, tree.TreeId, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t)
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ hashes := [][]byte{dummyHash}
+ leaves, err := tx.GetLeavesByHash(ctx, hashes, false)
+ if err != nil {
+ t.Fatalf("Unexpected error getting leaf by hash: %v", err)
+ }
+ if len(leaves) != 1 {
+ t.Fatalf("Got %d leaves but expected one", len(leaves))
+ }
+ checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
+ return nil
+ })
+}
+
+func TestGetLeavesByIndex(t *testing.T) {
+ ctx := context.Background()
+
+ // Create fake leaf as if it had been sequenced, read it back and check contents
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ // The leaf indices are checked against the tree size so we need a root.
+ createFakeSignedLogRoot(db, tree, uint64(sequenceNumber+1))
+
+ data := []byte("some data")
+ data2 := []byte("some other data")
+ createFakeLeaf(ctx, db, tree.TreeId, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t)
+ createFakeLeaf(ctx, db, tree.TreeId, dummyRawHash2, dummyHash2, data2, someExtraData2, sequenceNumber-1, t)
+
+ var tests = []struct {
+ desc string
+ indices []int64
+ wantErr bool
+ wantCode codes.Code
+ checkFn func([]*trillian.LogLeaf, *testing.T)
+ }{
+ {
+ desc: "InTree",
+ indices: []int64{sequenceNumber},
+ checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
+ checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
+ },
+ },
+ {
+ desc: "InTree2",
+ indices: []int64{sequenceNumber - 1},
+ wantErr: false,
+ checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
+ checkLeafContents(leaves[0], sequenceNumber - 1, dummyRawHash2, dummyHash2, data2, someExtraData2, t)
+ },
+ },
+ {
+ desc: "InTreeMultiple",
+ indices: []int64{sequenceNumber - 1, sequenceNumber},
+ checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
+ checkLeafContents(leaves[1], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
+ checkLeafContents(leaves[0], sequenceNumber - 1, dummyRawHash2, dummyHash2, data2, someExtraData2, t)
+ },
+ },
+ {
+ desc: "InTreeMultipleReverse",
+ indices: []int64{sequenceNumber, sequenceNumber - 1},
+ checkFn: func(leaves []*trillian.LogLeaf, t *testing.T) {
+ checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t)
+ checkLeafContents(leaves[1], sequenceNumber - 1, dummyRawHash2, dummyHash2, data2, someExtraData2, t)
+ },
+ }, {
+ desc: "OutsideTree",
+ indices: []int64{sequenceNumber + 1},
+ wantErr: true,
+ wantCode: codes.OutOfRange,
+ },
+ {
+ desc: "LongWayOutsideTree",
+ indices: []int64{9999},
+ wantErr: true,
+ wantCode: codes.OutOfRange,
+ },
+ {
+ desc: "MixedInOutTree",
+ indices: []int64{sequenceNumber, sequenceNumber + 1},
+ wantErr: true,
+ wantCode: codes.OutOfRange,
+ },
+ {
+ desc: "MixedInOutTree2",
+ indices: []int64{sequenceNumber - 1, sequenceNumber + 1},
+ wantErr: true,
+ wantCode: codes.OutOfRange,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ got, err := tx.GetLeavesByIndex(ctx, test.indices)
+ if test.wantErr {
+ if err == nil || status.Code(err) != test.wantCode {
+ t.Errorf("GetLeavesByIndex(%v)=%v,%v; want: nil, err with code %v", test.indices, got, err, test.wantCode)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("GetLeavesByIndex(%v)=%v,%v; want: got, nil", test.indices, got, err)
+ }
+ }
+ return nil
+ })
+ })
+ }
+}
+
+// GetLeavesByRange tests. -----------------------------------------------------
+
+type getLeavesByRangeTest struct {
+ start, count int64
+ want []int64
+ wantErr bool
+}
+
+func testGetLeavesByRangeImpl(t *testing.T, create *trillian.Tree, tests []getLeavesByRangeTest) {
+ cleanTestDB(db, t)
+
+ ctx := context.Background()
+ tree, err := createTree(db, create)
+ if err != nil {
+ t.Fatalf("Error creating log: %v", err)
+ }
+ // Note: GetLeavesByRange loads the root internally to get the tree size.
+ createFakeSignedLogRoot(db, tree, 14)
+ s := NewLogStorage(db, nil)
+
+ // Create leaves [0]..[19] but drop leaf [5] and set the tree size to 14.
+ for i := int64(0); i < 20; i++ {
+ if i == 5 {
+ continue
+ }
+ data := []byte{byte(i)}
+ identityHash := sha256.Sum256(data)
+ createFakeLeaf(ctx, db, tree.TreeId, identityHash[:], identityHash[:], data, someExtraData, i, t)
+ }
+
+ for _, test := range tests {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ leaves, err := tx.GetLeavesByRange(ctx, test.start, test.count)
+ if err != nil {
+ if !test.wantErr {
+ t.Errorf("GetLeavesByRange(%d, +%d)=_,%v; want _,nil", test.start, test.count, err)
+ }
+ return nil
+ }
+ if test.wantErr {
+ t.Errorf("GetLeavesByRange(%d, +%d)=_,nil; want _,non-nil", test.start, test.count)
+ }
+ got := make([]int64, len(leaves))
+ for i, leaf := range leaves {
+ got[i] = leaf.LeafIndex
+ }
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("GetLeavesByRange(%d, +%d)=%+v; want %+v", test.start, test.count, got, test.want)
+ }
+ return nil
+ })
+ }
+}
+
+func TestGetLeavesByRangeFromLog(t *testing.T) {
+ var tests = []getLeavesByRangeTest{
+ {start: 0, count: 1, want: []int64{0}},
+ {start: 0, count: 2, want: []int64{0, 1}},
+ {start: 1, count: 3, want: []int64{1, 2, 3}},
+ {start: 10, count: 7, want: []int64{10, 11, 12, 13}},
+ {start: 13, count: 1, want: []int64{13}},
+ {start: 14, count: 4, wantErr: true}, // Starts right after tree size.
+ {start: 19, count: 2, wantErr: true}, // Starts further away.
+ {start: 3, count: 5, wantErr: true}, // Hits non-contiguous leaves.
+ {start: 5, count: 5, wantErr: true}, // Starts from a missing leaf.
+ {start: 1, count: 0, wantErr: true}, // Empty range.
+ {start: -1, count: 1, wantErr: true}, // Negative start.
+ {start: 1, count: -1, wantErr: true}, // Negative count.
+ {start: 100, count: 30, wantErr: true}, // Starts after all stored leaves.
+ }
+ testGetLeavesByRangeImpl(t, testonly.LogTree, tests)
+}
+
+func TestGetLeavesByRangeFromPreorderedLog(t *testing.T) {
+ var tests = []getLeavesByRangeTest{
+ {start: 0, count: 1, want: []int64{0}},
+ {start: 0, count: 2, want: []int64{0, 1}},
+ {start: 1, count: 3, want: []int64{1, 2, 3}},
+ {start: 10, count: 7, want: []int64{10, 11, 12, 13, 14, 15, 16}},
+ {start: 13, count: 1, want: []int64{13}},
+ // Starts right after tree size.
+ {start: 14, count: 4, want: []int64{14, 15, 16, 17}},
+ {start: 19, count: 2, want: []int64{19}}, // Starts further away.
+ {start: 3, count: 5, wantErr: true}, // Hits non-contiguous leaves.
+ {start: 5, count: 5, wantErr: true}, // Starts from a missing leaf.
+ {start: 1, count: 0, wantErr: true}, // Empty range.
+ {start: -1, count: 1, wantErr: true}, // Negative start.
+ {start: 1, count: -1, wantErr: true}, // Negative count.
+ {start: 100, count: 30, want: []int64{}}, // Starts after all stored leaves.
+ }
+ testGetLeavesByRangeImpl(t, testonly.PreorderedLogTree, tests)
+}
+
+// -----------------------------------------------------------------------------
+
+func TestLatestSignedRootNoneWritten(t *testing.T) {
+ ctx := context.Background()
+
+ cleanTestDB(db, t)
+ tree, err := createTree(db, testonly.LogTree)
+ if err != nil {
+ t.Fatalf("createTree: %v", err)
+ }
+ s := NewLogStorage(db, nil)
+
+ tx, err := s.SnapshotForTree(ctx, tree)
+ if err != storage.ErrTreeNeedsInit {
+ t.Fatalf("SnapshotForTree gave %v, want %v", err, storage.ErrTreeNeedsInit)
+ }
+ commit(tx, t)
+}
+
+func TestLatestSignedLogRoot(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256)
+ root, err := signer.SignLogRoot(&types.LogRootV1{
+ TimestampNanos: 98765,
+ TreeSize: 16,
+ Revision: 5,
+ RootHash: []byte(dummyHash),
+ })
+ if err != nil {
+ t.Fatalf("SignLogRoot(): %v", err)
+ }
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
+ t.Fatalf("Failed to store signed root: %v", err)
+ }
+ return nil
+ })
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
+ root2, err := tx2.LatestSignedLogRoot(ctx)
+ if err != nil {
+ t.Fatalf("Failed to read back new log root: %v", err)
+ }
+ if !proto.Equal(root, &root2) {
+ t.Fatalf("Root round trip failed: <%v> and: <%v>", root, root2)
+ }
+ return nil
+ })
+ }
+}
+
+func TestDuplicateSignedLogRoot(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256)
+ root, err := signer.SignLogRoot(&types.LogRootV1{
+ TimestampNanos: 98765,
+ TreeSize: 16,
+ Revision: 5,
+ RootHash: []byte(dummyHash),
+ })
+ if err != nil {
+ t.Fatalf("SignLogRoot(): %v", err)
+ }
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
+ t.Fatalf("Failed to store signed root: %v", err)
+ }
+ // Shouldn't be able to do it again
+ // if err := tx.StoreSignedLogRoot(ctx, *root); err == nil {
+ // t.Fatal("Allowed duplicate signed root")
+ // }
+ return nil
+ })
+}
+
+func TestLogRootUpdate(t *testing.T) {
+ // Write two roots for a log and make sure the one with the newest timestamp supersedes
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ signer := tcrypto.NewSigner(tree.TreeId, ttestonly.NewSignerWithFixedSig(nil, []byte("notempty")), crypto.SHA256)
+ root, err := signer.SignLogRoot(&types.LogRootV1{
+ TimestampNanos: 98765,
+ TreeSize: 16,
+ Revision: 5,
+ RootHash: []byte(dummyHash),
+ })
+ if err != nil {
+ t.Fatalf("SignLogRoot(): %v", err)
+ }
+ root2, err := signer.SignLogRoot(&types.LogRootV1{
+ TimestampNanos: 98766,
+ TreeSize: 16,
+ Revision: 6,
+ RootHash: []byte(dummyHash),
+ })
+ if err != nil {
+ t.Fatalf("SignLogRoot(): %v", err)
+ }
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
+ t.Fatalf("Failed to store signed root: %v", err)
+ }
+ if err := tx.StoreSignedLogRoot(ctx, *root2); err != nil {
+ t.Fatalf("Failed to store signed root: %v", err)
+ }
+ return nil
+ })
+
+ runLogTX(s, tree, t, func(ctx context.Context, tx2 storage.LogTreeTX) error {
+ root3, err := tx2.LatestSignedLogRoot(ctx)
+ if err != nil {
+ t.Fatalf("Failed to read back new log root: %v", err)
+ }
+ if !proto.Equal(root2, &root3) {
+ t.Fatalf("Root round trip failed: <%v> and: <%v>", root, root2)
+ }
+ return nil
+ })
+}
+
+func TestGetActiveLogIDs(t *testing.T) {
+ ctx := context.Background()
+
+ cleanTestDB(db, t)
+ admin := NewAdminStorage(db)
+
+ // Create a few test trees
+ log1 := proto.Clone(testonly.LogTree).(*trillian.Tree)
+ log2 := proto.Clone(testonly.LogTree).(*trillian.Tree)
+ log3 := proto.Clone(testonly.PreorderedLogTree).(*trillian.Tree)
+ drainingLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
+ frozenLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
+ deletedLog := proto.Clone(testonly.LogTree).(*trillian.Tree)
+ map1 := proto.Clone(testonly.MapTree).(*trillian.Tree)
+ map2 := proto.Clone(testonly.MapTree).(*trillian.Tree)
+ deletedMap := proto.Clone(testonly.MapTree).(*trillian.Tree)
+ for _, tree := range []*trillian.Tree{log1, log2, log3, drainingLog, frozenLog, deletedLog, map1, map2, deletedMap} {
+ newTree, err := storage.CreateTree(ctx, admin, tree)
+ if err != nil {
+ t.Fatalf("CreateTree(%+v) returned err = %v", tree, err)
+ }
+ *tree = *newTree
+ }
+
+ // FROZEN is not a valid initial state, so we have to update it separately.
+ if _, err := storage.UpdateTree(ctx, admin, frozenLog.TreeId, func(t *trillian.Tree) {
+ t.TreeState = trillian.TreeState_FROZEN
+ }); err != nil {
+ t.Fatalf("UpdateTree() returned err = %v", err)
+ }
+ // DRAINING is not a valid initial state, so we have to update it separately.
+ if _, err := storage.UpdateTree(ctx, admin, drainingLog.TreeId, func(t *trillian.Tree) {
+ t.TreeState = trillian.TreeState_DRAINING
+ }); err != nil {
+ t.Fatalf("UpdateTree() returned err = %v", err)
+ }
+
+ // Update deleted trees accordingly
+ updateDeletedStmt, err := db.PrepareContext(ctx, "UPDATE Trees SET Deleted = $1 WHERE Tree_Id = $2")
+ if err != nil {
+ t.Fatalf("PrepareContext() returned err = %v", err)
+ }
+ defer updateDeletedStmt.Close()
+ for _, treeID := range []int64{deletedLog.TreeId, deletedMap.TreeId} {
+ if _, err := updateDeletedStmt.ExecContext(ctx, true, treeID); err != nil {
+ t.Fatalf("ExecContext(%v) returned err = %v", treeID, err)
+ }
+ }
+
+ s := NewLogStorage(db, nil)
+ tx, err := s.Snapshot(ctx)
+ if err != nil {
+ t.Fatalf("Snapshot() returns err = %v", err)
+ }
+ defer tx.Close()
+ got, err := tx.GetActiveLogIDs(ctx)
+ if err != nil {
+ t.Fatalf("GetActiveLogIDs() returns err = %v", err)
+ }
+ if err := tx.Commit(); err != nil {
+ t.Errorf("Commit() returned err = %v", err)
+ }
+
+ want := []int64{log1.TreeId, log2.TreeId, log3.TreeId, drainingLog.TreeId}
+ sort.Slice(got, func(i, j int) bool { return got[i] < got[j] })
+ sort.Slice(want, func(i, j int) bool { return want[i] < want[j] })
+ if diff := pretty.Compare(got, want); diff != "" {
+ t.Errorf("post-GetActiveLogIDs diff (-got +want):\n%v", diff)
+ }
+}
+
+func TestGetActiveLogIDsEmpty(t *testing.T) {
+ ctx := context.Background()
+
+ cleanTestDB(db, t)
+ s := NewLogStorage(db, nil)
+
+ tx, err := s.Snapshot(ctx)
+ if err != nil {
+ t.Fatalf("Snapshot() = (_, %v), want = (_, nil)", err)
+ }
+ defer tx.Close()
+ ids, err := tx.GetActiveLogIDs(ctx)
+ if err != nil {
+ t.Fatalf("GetActiveLogIDs() = (_, %v), want = (_, nil)", err)
+ }
+ if err := tx.Commit(); err != nil {
+ t.Errorf("Commit() = %v, want = nil", err)
+ }
+
+ if got, want := len(ids), 0; got != want {
+ t.Errorf("GetActiveLogIDs(): got %v IDs, want = %v", got, want)
+ }
+}
+
+func TestReadOnlyLogTX_Rollback(t *testing.T) {
+ ctx := context.Background()
+ cleanTestDB(db, t)
+ s := NewLogStorage(db, nil)
+ tx, err := s.Snapshot(ctx)
+ if err != nil {
+ t.Fatalf("Snapshot() = (_, %v), want = (_, nil)", err)
+ }
+ defer tx.Close()
+ if _, err := tx.GetActiveLogIDs(ctx); err != nil {
+ t.Fatalf("GetActiveLogIDs() = (_, %v), want = (_, nil)", err)
+ }
+ // It's a bit hard to have a more meaningful test. This should suffice.
+ if err := tx.Rollback(); err != nil {
+ t.Errorf("Rollback() = (_, %v), want = (_, nil)", err)
+ }
+}
+
+func TestGetSequencedLeafCount(t *testing.T) {
+ ctx := context.Background()
+
+ // We'll create leaves for two different trees
+ cleanTestDB(db, t)
+ log1 := createTreeOrPanic(db, testonly.LogTree)
+ log2 := createTreeOrPanic(db, testonly.LogTree)
+ s := NewLogStorage(db, nil)
+
+ {
+ // Create fake leaf as if it had been sequenced
+ data := []byte("some data")
+ createFakeLeaf(ctx, db, log1.TreeId, dummyHash, dummyRawHash, data, someExtraData, sequenceNumber, t)
+
+ // Create fake leaves for second tree as if they had been sequenced
+ data2 := []byte("some data 2")
+ data3 := []byte("some data 3")
+ createFakeLeaf(ctx, db, log2.TreeId, dummyHash2, dummyRawHash, data2, someExtraData, sequenceNumber, t)
+ createFakeLeaf(ctx, db, log2.TreeId, dummyHash3, dummyRawHash, data3, someExtraData, sequenceNumber+1, t)
+ }
+
+ // Read back the leaf counts from both trees
+ runLogTX(s, log1, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ count1, err := tx.GetSequencedLeafCount(ctx)
+ if err != nil {
+ t.Fatalf("unexpected error getting leaf count: %v", err)
+ }
+ if want, got := int64(1), count1; want != got {
+ t.Fatalf("expected %d sequenced for logId but got %d", want, got)
+ }
+ return nil
+ })
+
+ runLogTX(s, log2, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ count2, err := tx.GetSequencedLeafCount(ctx)
+ if err != nil {
+ t.Fatalf("unexpected error getting leaf count2: %v", err)
+ }
+ if want, got := int64(2), count2; want != got {
+ t.Fatalf("expected %d sequenced for logId2 but got %d", want, got)
+ }
+ return nil
+ })
+}
+
+func TestSortByLeafIdentityHash(t *testing.T) {
+ l := make([]*trillian.LogLeaf, 30)
+ for i := range l {
+ hash := sha256.Sum256([]byte{byte(i)})
+ leaf := trillian.LogLeaf{
+ LeafIdentityHash: hash[:],
+ LeafValue: []byte(fmt.Sprintf("Value %d", i)),
+ ExtraData: []byte(fmt.Sprintf("Extra %d", i)),
+ LeafIndex: int64(i),
+ }
+ l[i] = &leaf
+ }
+ sort.Sort(byLeafIdentityHash(l))
+ for i := range l {
+ if i == 0 {
+ continue
+ }
+ if bytes.Compare(l[i-1].LeafIdentityHash, l[i].LeafIdentityHash) != -1 {
+ t.Errorf("sorted leaves not in order, [%d] = %x, [%d] = %x", i-1, l[i-1].LeafIdentityHash, i, l[i].LeafIdentityHash)
+ }
+ }
+}
+
+func ensureAllLeavesDistinct(leaves []*trillian.LogLeaf, t *testing.T) {
+ t.Helper()
+ // All the leaf value hashes should be distinct because the leaves were created with distinct
+ // leaf data. If only we had maps with slices as keys or sets or pretty much any kind of usable
+ // data structures we could do this properly.
+ for i := range leaves {
+ for j := range leaves {
+ if i != j && bytes.Equal(leaves[i].LeafIdentityHash, leaves[j].LeafIdentityHash) {
+ t.Fatalf("Unexpectedly got a duplicate leaf hash: %v %v",
+ leaves[i].LeafIdentityHash, leaves[j].LeafIdentityHash)
+ }
+ }
+ }
+}
+
+func ensureLeavesHaveQueueTimestamp(t *testing.T, leaves []*trillian.LogLeaf, want time.Time) {
+ t.Helper()
+ for _, leaf := range leaves {
+ gotQTimestamp, err := ptypes.Timestamp(leaf.QueueTimestamp)
+ if err != nil {
+ t.Fatalf("Got invalid queue timestamp: %v", err)
+ }
+ if got, want := gotQTimestamp.UnixNano(), want.UnixNano(); got != want {
+ t.Errorf("Got leaf with QueueTimestampNanos = %v, want %v: %v", got, want, leaf)
+ }
+ }
+}
+
+// Creates some test leaves with predictable data
+func createTestLeaves(n, startSeq int64) []*trillian.LogLeaf {
+ var leaves []*trillian.LogLeaf
+ for l := int64(0); l < n; l++ {
+ lv := fmt.Sprintf("Leaf %d", l+startSeq)
+ h := sha256.New()
+ h.Write([]byte(lv))
+ leafHash := h.Sum(nil)
+ leaf := &trillian.LogLeaf{
+ LeafIdentityHash: leafHash,
+ MerkleLeafHash: leafHash,
+ LeafValue: []byte(lv),
+ ExtraData: []byte(fmt.Sprintf("Extra %d", l)),
+ LeafIndex: int64(startSeq + l),
+ }
+ leaves = append(leaves, leaf)
+ }
+
+ return leaves
+}
+
+// Convenience methods to avoid copying out "if err != nil { blah }" all over the place
+func runLogTX(s storage.LogStorage, tree *trillian.Tree, t *testing.T, f storage.LogTXFunc) {
+ t.Helper()
+ if err := s.ReadWriteTransaction(context.Background(), tree, f); err != nil {
+ t.Fatalf("Failed to run log tx: %v", err)
+ }
+}
+
+type committableTX interface {
+ Commit() error
+}
+
+func commit(tx committableTX, t *testing.T) {
+ t.Helper()
+ if err := tx.Commit(); err != nil {
+ t.Errorf("Failed to commit tx: %v", err)
+ }
+}
+
+func leafInBatch(leaf *trillian.LogLeaf, batch []*trillian.LogLeaf) bool {
+ for _, bl := range batch {
+ if bytes.Equal(bl.LeafIdentityHash, leaf.LeafIdentityHash) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// byLeafIdentityHash allows sorting of leaves by their identity hash, so DB
+// operations always happen in a consistent order.
+type byLeafIdentityHash []*trillian.LogLeaf
+
+func (l byLeafIdentityHash) Len() int { return len(l) }
+func (l byLeafIdentityHash) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
+func (l byLeafIdentityHash) Less(i, j int) bool {
+ return bytes.Compare(l[i].LeafIdentityHash, l[j].LeafIdentityHash) == -1
+}
+
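+// logTree returns a minimal LOG-type Tree proto for the given tree ID.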
+func logTree(logID int64) *trillian.Tree {
+ return &trillian.Tree{
+ TreeId: logID,
+ TreeType: trillian.TreeType_LOG,
+ HashStrategy: trillian.HashStrategy_RFC6962_SHA256,
+ }
+}
diff --git a/storage/postgres/queue.go b/storage/postgres/queue.go
new file mode 100644
index 0000000000..d6b8356967
--- /dev/null
+++ b/storage/postgres/queue.go
@@ -0,0 +1,130 @@
+// +build !batched_queue
+
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/golang/protobuf/ptypes"
+ "github.com/google/trillian"
+)
+
+const (
+ // If this statement's ORDER BY clause is changed, refer to the comment in removeSequencedLeaves
+ selectQueuedLeavesSQL = `SELECT leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos
+ FROM unsequenced
+ WHERE tree_id=$1
+ AND bucket=0
+ AND queue_timestamp_nanos<=$2
+ ORDER BY queue_timestamp_nanos,leaf_identity_hash ASC LIMIT $3`
+ insertUnsequencedEntrySQL = "select insert_leaf_data_ignore_duplicates($1,$2,$3,$4)"
+ deleteUnsequencedSQL = "DELETE FROM unsequenced WHERE tree_id = $1 and bucket=0 and queue_timestamp_nanos = $2 and leaf_identity_hash=$3"
+)
+
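+// dequeuedLeaf captures the columns needed to later delete a row from the
+// unsequenced table once its leaf has been sequenced (see deleteUnsequencedSQL).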
+type dequeuedLeaf struct {
+ queueTimestampNanos int64
+ leafIdentityHash []byte
+}
+
+func dequeueInfo(leafIDHash []byte, queueTimestamp int64) dequeuedLeaf {
+ return dequeuedLeaf{queueTimestampNanos: queueTimestamp, leafIdentityHash: leafIDHash}
+}
+
+func (t *logTreeTX) dequeueLeaf(rows *sql.Rows) (*trillian.LogLeaf, dequeuedLeaf, error) {
+ var leafIDHash []byte
+ var merkleHash []byte
+ var queueTimestamp int64
+
+ err := rows.Scan(&leafIDHash, &merkleHash, &queueTimestamp)
+ if err != nil {
+ glog.Warningf("Error scanning work rows: %s", err)
+ return nil, dequeuedLeaf{}, err
+ }
+
+ // Note: the LeafData and ExtraData being nil here is OK as this is only used by the
+ // sequencer. The sequencer only writes to the SequencedLeafData table and the client
+ // supplied data was already written to LeafData as part of queueing the leaf.
+ queueTimestampProto, err := ptypes.TimestampProto(time.Unix(0, queueTimestamp))
+ if err != nil {
+ return nil, dequeuedLeaf{}, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ leaf := &trillian.LogLeaf{
+ LeafIdentityHash: leafIDHash,
+ MerkleLeafHash: merkleHash,
+ QueueTimestamp: queueTimestampProto,
+ }
+ return leaf, dequeueInfo(leafIDHash, queueTimestamp), nil
+}
+
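+// queueArgs returns the additional arguments used when queuing a leaf. In this
+// (non-batched) build only the queue timestamp is needed; treeID and
+// identityHash are accepted so both build variants expose the same signature.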
+func queueArgs(treeID int64, identityHash []byte, queueTimestamp time.Time) []interface{} {
+ return []interface{}{queueTimestamp.UnixNano()}
+}
+
+func (t *logTreeTX) UpdateSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf) error {
+ for _, leaf := range leaves {
+ // This should fail on insert but catch it early
+ if len(leaf.LeafIdentityHash) != t.hashSizeBytes {
+ return errors.New("sequenced leaf has incorrect hash size")
+ }
+
+ iTimestamp, err := ptypes.Timestamp(leaf.IntegrateTimestamp)
+ if err != nil {
+ return fmt.Errorf("got invalid integrate timestamp: %v", err)
+ }
+ _, err = t.tx.ExecContext(
+ ctx,
+ insertSequencedLeafSQL+valuesPlaceholder5,
+ t.treeID,
+ leaf.LeafIdentityHash,
+ leaf.MerkleLeafHash,
+ leaf.LeafIndex,
+ iTimestamp.UnixNano())
+ if err != nil {
+ glog.Warningf("Failed to update sequenced leaves: %s", err)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// removeSequencedLeaves deletes the unsequenced rows for the passed-in leaves
+// (the slice itself may be modified as part of the operation).
+func (t *logTreeTX) removeSequencedLeaves(ctx context.Context, leaves []dequeuedLeaf) error {
+ // We don't need to re-sort because the query is ordered by leaf hash. If that changes
+ // because the ordering makes the query too expensive, the sort will need to be done
+ // here instead. See the comment in QueueLeaves.
+ stx, err := t.tx.PrepareContext(ctx, deleteUnsequencedSQL)
+ if err != nil {
+ glog.Warningf("Failed to prep delete statement for sequenced work: %v", err)
+ return err
+ }
+ for _, dql := range leaves {
+ result, err := stx.ExecContext(ctx, t.treeID, dql.queueTimestampNanos, dql.leafIdentityHash)
+ err = checkResultOkAndRowCountIs(result, err, int64(1))
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
diff --git a/storage/postgres/queue_batching.go b/storage/postgres/queue_batching.go
new file mode 100644
index 0000000000..e319a9c701
--- /dev/null
+++ b/storage/postgres/queue_batching.go
@@ -0,0 +1,147 @@
+// +build batched_queue
+
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "crypto/sha256"
+ "database/sql"
+ "encoding/binary"
+ "fmt"
+ "strings"
+ "time"
+
+ "github.com/golang/glog"
+ "github.com/golang/protobuf/ptypes"
+ "github.com/google/trillian"
+)
+
+const (
+ // If this statement's ORDER BY clause is changed, refer to the comment in removeSequencedLeaves
+ selectQueuedLeavesSQL = `SELECT leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos,queue_id
+ FROM unsequenced
+ WHERE tree_id=$1
+ AND Bucket=0
+ AND queue_timestamp_nanos<=$2
+ ORDER BY queue_timestamp_nanos,leaf_identity_hash ASC LIMIT $3`
+ insertUnsequencedEntrySQL = `INSERT INTO unsequenced(tree_id,Bucket,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos,queue_id) VALUES($1,0,$2,$3,$4,$5)`
+ deleteUnsequencedSQL = "DELETE FROM unsequenced WHERE queue_id IN ()"
+)
+
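+// In the batched_queue build a dequeued leaf is identified solely by its
+// queue_id, which is all removeSequencedLeaves needs for the batched delete.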
+type dequeuedLeaf []byte
+
+func dequeueInfo(_ []byte, queueID []byte) dequeuedLeaf {
+ return dequeuedLeaf(queueID)
+}
+
+func (t *logTreeTX) dequeueLeaf(rows *sql.Rows) (*trillian.LogLeaf, dequeuedLeaf, error) {
+ var leafIDHash []byte
+ var merkleHash []byte
+ var queueTimestamp int64
+ var queueID []byte
+
+ err := rows.Scan(&leafIDHash, &merkleHash, &queueTimestamp, &queueID)
+ if err != nil {
+ glog.Warningf("Error scanning work rows: %s", err)
+ return nil, nil, err
+ }
+
+ queueTimestampProto, err := ptypes.TimestampProto(time.Unix(0, queueTimestamp))
+ if err != nil {
+ return nil, dequeuedLeaf{}, fmt.Errorf("got invalid queue timestamp: %v", err)
+ }
+ // Note: the LeafData and ExtraData being nil here is OK as this is only used by the
+ // sequencer. The sequencer only writes to the SequencedLeafData table and the client
+ // supplied data was already written to LeafData as part of queueing the leaf.
+ leaf := &trillian.LogLeaf{
+ LeafIdentityHash: leafIDHash,
+ MerkleLeafHash: merkleHash,
+ QueueTimestamp: queueTimestampProto,
+ }
+ return leaf, dequeueInfo(leafIDHash, queueID), nil
+}
+
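+// generateQueueID derives the unsequenced.queue_id value as a SHA-256 hash over
+// the varint-encoded tree ID and queue timestamp followed by the leaf identity
+// hash, matching the column description in the schema.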
+func generateQueueID(treeID int64, leafIdentityHash []byte, timestamp int64) []byte {
+ h := sha256.New()
+ b := make([]byte, 10)
+ binary.PutVarint(b, treeID)
+ h.Write(b)
+ b = make([]byte, 10)
+ binary.PutVarint(b, timestamp)
+ h.Write(b)
+ h.Write(leafIdentityHash)
+ return h.Sum(nil)
+}
+
+func queueArgs(treeID int64, identityHash []byte, queueTimestamp time.Time) []interface{} {
+ timestamp := queueTimestamp.UnixNano()
+ return []interface{}{timestamp, generateQueueID(treeID, identityHash, timestamp)}
+}
+
+func (t *logTreeTX) UpdateSequencedLeaves(ctx context.Context, leaves []*trillian.LogLeaf) error {
+ querySuffix := []string{}
+ args := []interface{}{}
+ for _, leaf := range leaves {
+ iTimestamp, err := ptypes.Timestamp(leaf.IntegrateTimestamp)
+ if err != nil {
+ return fmt.Errorf("got invalid integrate timestamp: %v", err)
+ }
+ querySuffix = append(querySuffix, valuesPlaceholder5)
+ args = append(args, t.treeID, leaf.LeafIdentityHash, leaf.MerkleLeafHash, leaf.LeafIndex, iTimestamp.UnixNano())
+ }
+ result, err := t.tx.ExecContext(ctx, insertSequencedLeafSQL+strings.Join(querySuffix, ","), args...)
+ if err != nil {
+ glog.Warningf("Failed to update sequenced leaves: %s", err)
+ }
+ return checkResultOkAndRowCountIs(result, err, int64(len(leaves)))
+}
+
+func (m *postgresLogStorage) getDeleteUnsequencedStmt(ctx context.Context, num int) (*sql.Stmt, error) {
+ stmt := &statementSkeleton{
+ sql: deleteUnsequencedSQL,
+ firstInsertion: "%s",
+ firstPlaceholders: 1,
+ restInsertion: "%s",
+ restPlaceholders: 1,
+ num: num,
+ }
+ return m.getStmt(ctx, stmt)
+}
+
+// removeSequencedLeaves deletes the unsequenced rows for the passed-in queue IDs
+// (the slice itself may be modified as part of the operation).
+func (t *logTreeTX) removeSequencedLeaves(ctx context.Context, queueIDs []dequeuedLeaf) error {
+ // We don't need to re-sort because the query is ordered by leaf hash. If that changes
+ // because the ordering makes the query too expensive, the sort will need to be done
+ // here instead. See the comment in QueueLeaves.
+ tmpl, err := t.ls.getDeleteUnsequencedStmt(ctx, len(queueIDs))
+ if err != nil {
+ glog.Warningf("Failed to get delete statement for sequenced work: %s", err)
+ return err
+ }
+ stx := t.tx.StmtContext(ctx, tmpl)
+ args := make([]interface{}, len(queueIDs))
+ for i, q := range queueIDs {
+ args[i] = []byte(q)
+ }
+ result, err := stx.ExecContext(ctx, args...)
+ if err != nil {
+ // Error is handled by checkResultOkAndRowCountIs() below
+ glog.Warningf("Failed to delete sequenced work: %s", err)
+ }
+ return checkResultOkAndRowCountIs(result, err, int64(len(queueIDs)))
+}
diff --git a/storage/postgres/schema/storage.sql b/storage/postgres/schema/storage.sql
index ea42d301aa..0f33976d88 100644
--- a/storage/postgres/schema/storage.sql
+++ b/storage/postgres/schema/storage.sql
@@ -4,11 +4,11 @@
-- ---------------------------------------------
-- Tree Enums
-CREATE TYPE E_TREE_STATE AS ENUM('ACTIVE', 'FROZEN', 'DRAINING');
-CREATE TYPE E_TREE_TYPE AS ENUM('LOG', 'MAP', 'PREORDERED_LOG');
-CREATE TYPE E_HASH_STRATEGY AS ENUM('RFC6962_SHA256', 'TEST_MAP_HASHER', 'OBJECT_RFC6962_SHA256', 'CONIKS_SHA512_256', 'CONIKS_SHA256');
-CREATE TYPE E_HASH_ALGORITHM AS ENUM('SHA256');
-CREATE TYPE E_SIGNATURE_ALGORITHM AS ENUM('ECDSA', 'RSA');
+CREATE TYPE E_TREE_STATE AS ENUM('ACTIVE', 'FROZEN', 'DRAINING');--end
+CREATE TYPE E_TREE_TYPE AS ENUM('LOG', 'MAP', 'PREORDERED_LOG');--end
+CREATE TYPE E_HASH_STRATEGY AS ENUM('RFC6962_SHA256', 'TEST_MAP_HASHER', 'OBJECT_RFC6962_SHA256', 'CONIKS_SHA512_256', 'CONIKS_SHA256');--end
+CREATE TYPE E_HASH_ALGORITHM AS ENUM('SHA256');--end
+CREATE TYPE E_SIGNATURE_ALGORITHM AS ENUM('ECDSA', 'RSA');--end
-- Tree parameters should not be changed after creation. Doing so can
-- render the data in the tree unusable or inconsistent.
@@ -28,8 +28,10 @@ CREATE TABLE IF NOT EXISTS trees (
public_key BYTEA NOT NULL,
deleted BOOLEAN NOT NULL DEFAULT FALSE,
delete_time_millis BIGINT,
+ current_tree_data json,
+ root_signature BYTEA,
PRIMARY KEY(tree_id)
-);
+);--end
-- This table contains tree parameters that can be changed at runtime such as for
-- administrative purposes.
@@ -40,7 +42,7 @@ CREATE TABLE IF NOT EXISTS tree_control(
sequence_interval_seconds INTEGER NOT NULL,
PRIMARY KEY(tree_id),
FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE
-);
+);--end
CREATE TABLE IF NOT EXISTS subtree(
tree_id BIGINT NOT NULL,
@@ -49,7 +51,7 @@ CREATE TABLE IF NOT EXISTS subtree(
subtree_revision INTEGER NOT NULL,
PRIMARY KEY(tree_id, subtree_id, subtree_revision),
FOREIGN KEY(tree_id) REFERENCES Trees(tree_id) ON DELETE CASCADE
-);
+);--end
-- The TreeRevisionIdx is used to enforce that there is only one STH at any
-- tree revision
@@ -62,11 +64,11 @@ CREATE TABLE IF NOT EXISTS tree_head(
tree_revision BIGINT,
PRIMARY KEY(tree_id, tree_revision),
FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE
-);
+);--end
-- TODO(vishal) benchmark this to see if it's a suitable replacement for not
-- having a DESC scan on the primary key
-CREATE UNIQUE INDEX TreeHeadRevisionIdx ON tree_head(tree_id, tree_revision DESC);
+CREATE UNIQUE INDEX TreeHeadRevisionIdx ON tree_head(tree_id, tree_revision DESC);--end
-- ---------------------------------------------
-- Log specific stuff here
@@ -94,7 +96,7 @@ CREATE TABLE IF NOT EXISTS leaf_data(
queue_timestamp_nanos BIGINT NOT NULL,
PRIMARY KEY(tree_id, leaf_identity_hash),
FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE
-);
+);--end
-- When a leaf is sequenced a row is added to this table. If logs allow duplicates then
-- multiple rows will exist with different sequence numbers. The signed timestamp
@@ -116,9 +118,9 @@ CREATE TABLE IF NOT EXISTS sequenced_leaf_data(
PRIMARY KEY(tree_id, sequence_number),
FOREIGN KEY(tree_id) REFERENCES trees(tree_id) ON DELETE CASCADE,
FOREIGN KEY(tree_id, leaf_identity_hash) REFERENCES leaf_data(tree_id, leaf_identity_hash) ON DELETE CASCADE
-);
+);--end
-CREATE INDEX SequencedLeafMerkleIdx ON sequenced_leaf_data(tree_id, merkle_leaf_hash);
+CREATE INDEX SequencedLeafMerkleIdx ON sequenced_leaf_data(tree_id, merkle_leaf_hash);--end
CREATE TABLE IF NOT EXISTS unsequenced(
tree_id BIGINT NOT NULL,
@@ -138,4 +140,51 @@ CREATE TABLE IF NOT EXISTS unsequenced(
-- built with the batched_queue tag.
queue_id BYTEA DEFAULT NULL UNIQUE,
PRIMARY KEY (tree_id, bucket, queue_timestamp_nanos, leaf_identity_hash)
-);
+);--end
+
+CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, leaf_value bytea, extra_data bytea, queue_timestamp_nanos bigint)
+ RETURNS boolean
+ LANGUAGE plpgsql
+AS $function$
+ begin
+ INSERT INTO leaf_data(tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos) VALUES (tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos);
+ return true;
+ exception
+ when unique_violation then
+ return false;
+ when others then
+ raise notice '% %', SQLERRM, SQLSTATE;
+ end;
+$function$;--end
+
+CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, merkle_leaf_hash bytea, queue_timestamp_nanos bigint)
+ RETURNS boolean
+ LANGUAGE plpgsql
+AS $function$
+ begin
+ INSERT INTO unsequenced(tree_id,bucket,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos) VALUES(tree_id,0,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos);
+ return true;
+ exception
+ when unique_violation then
+ return false;
+ when others then
+ raise notice '% %', SQLERRM, SQLSTATE;
+ end;
+$function$;--end
+
+
+
+CREATE OR REPLACE FUNCTION public.insert_sequenced_leaf_data_ignore_duplicates(tree_id bigint, sequence_number bigint, leaf_identity_hash bytea, merkle_leaf_hash bytea, integrate_timestamp_nanos bigint)
+ RETURNS boolean
+ LANGUAGE plpgsql
+AS $function$
+ begin
+ INSERT INTO sequenced_leaf_data(tree_id, sequence_number, leaf_identity_hash, merkle_leaf_hash, integrate_timestamp_nanos) VALUES(tree_id, sequence_number, leaf_identity_hash, merkle_leaf_hash, integrate_timestamp_nanos);
+ return true;
+ exception
+ when unique_violation then
+ return false;
+ when others then
+ raise notice '% %', SQLERRM, SQLSTATE;
+ end;
+$function$;--end
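+
+-- NOTE: statements in this file are terminated with ';--end' rather than a bare
+-- ';' so that testdb.go can split the schema on that marker without splitting
+-- the multi-statement plpgsql function bodies above.
+-- Illustrative call (assumes a tree with tree_id=1 already exists; returns true
+-- on the first insert and false when the same row is inserted again):
+--   SELECT insert_sequenced_leaf_data_ignore_duplicates(1, 0, '\x01'::bytea, '\x02'::bytea, 0);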
diff --git a/storage/postgres/storage_test.go b/storage/postgres/storage_test.go
index 036a71f749..57cae6916f 100644
--- a/storage/postgres/storage_test.go
+++ b/storage/postgres/storage_test.go
@@ -1,4 +1,4 @@
-// Copyright 2018 Google Inc. All Rights Reserved.
+// Copyright 2019 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -11,22 +11,157 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+
package postgres
import (
+ "bytes"
"context"
+ "crypto"
+ "crypto/sha256"
"database/sql"
"flag"
+ "fmt"
"os"
"testing"
- "time"
"github.com/golang/glog"
+ "github.com/google/trillian"
+ tcrypto "github.com/google/trillian/crypto"
+ "github.com/google/trillian/storage"
"github.com/google/trillian/storage/postgres/testdb"
+ storageto "github.com/google/trillian/storage/testonly"
+ "github.com/google/trillian/testonly"
+ "github.com/google/trillian/types"
)
-// db is shared throughout all postgres tests
-var db *sql.DB
+func TestNodeRoundTrip(t *testing.T) {
+ cleanTestDB(db, t)
+ tree := createTreeOrPanic(db, storageto.LogTree)
+ s := NewLogStorage(db, nil)
+
+ const writeRevision = int64(100)
+ nodesToStore := createSomeNodes()
+ nodeIDsToRead := make([]storage.NodeID, len(nodesToStore))
+ for i := range nodesToStore {
+ nodeIDsToRead[i] = nodesToStore[i].NodeID
+ }
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ forceWriteRevision(writeRevision, tx)
+
+ // Need to read nodes before attempting to write
+ if _, err := tx.GetMerkleNodes(ctx, 99, nodeIDsToRead); err != nil {
+ t.Fatalf("Failed to read nodes: %s", err)
+ }
+ if err := tx.SetMerkleNodes(ctx, nodesToStore); err != nil {
+ t.Fatalf("Failed to store nodes: %s", err)
+ }
+ return nil
+ })
+ }
+
+ {
+ runLogTX(s, tree, t, func(ctx context.Context, tx storage.LogTreeTX) error {
+ readNodes, err := tx.GetMerkleNodes(ctx, 100, nodeIDsToRead)
+ if err != nil {
+ t.Fatalf("Failed to retrieve nodes: %s", err)
+ }
+ if err := nodesAreEqual(readNodes, nodesToStore); err != nil {
+ t.Fatalf("Read back different nodes from the ones stored: %s", err)
+ }
+ return nil
+ })
+ }
+}
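+
+// forceWriteRevision reaches into the underlying logTreeTX so the test can
+// control the revision used when storing nodes; it panics if tx is not the
+// postgres logTreeTX implementation.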
+func forceWriteRevision(rev int64, tx storage.TreeTX) {
+ mtx, ok := tx.(*logTreeTX)
+ if !ok {
+ panic(fmt.Sprintf("tx is %T, want *logTreeTX", tx))
+ }
+ mtx.treeTX.writeRevision = rev
+}
+
+func createSomeNodes() []storage.Node {
+ r := make([]storage.Node, 4)
+ for i := range r {
+ r[i].NodeID = storage.NewNodeIDWithPrefix(uint64(i), 8, 8, 8)
+ h := sha256.Sum256([]byte{byte(i)})
+ r[i].Hash = h[:]
+ glog.Infof("Node to store: %v", r[i].NodeID)
+ }
+ return r
+}
+
+func nodesAreEqual(lhs []storage.Node, rhs []storage.Node) error {
+ if ls, rs := len(lhs), len(rhs); ls != rs {
+ return fmt.Errorf("different number of nodes, %d vs %d", ls, rs)
+ }
+ for i := range lhs {
+ if l, r := lhs[i].NodeID.String(), rhs[i].NodeID.String(); l != r {
+ return fmt.Errorf("NodeIDs are not the same,\nlhs = %v,\nrhs = %v", l, r)
+ }
+ if l, r := lhs[i].Hash, rhs[i].Hash; !bytes.Equal(l, r) {
+ return fmt.Errorf("hashes are not the same for %s,\nlhs = %v,\nrhs = %v", lhs[i].NodeID.CoordString(), l, r)
+ }
+ }
+ return nil
+}
+
+func openTestDBOrDie() *sql.DB {
+ db, err := testdb.NewTrillianDB(context.TODO())
+ if err != nil {
+ panic(err)
+ }
+ return db
+}
+
+func createFakeSignedLogRoot(db *sql.DB, tree *trillian.Tree, treeSize uint64) {
+ signer := tcrypto.NewSigner(0, testonly.NewSignerWithFixedSig(nil, []byte("notnil")), crypto.SHA256)
+
+ ctx := context.Background()
+ l := NewLogStorage(db, nil)
+ err := l.ReadWriteTransaction(ctx, tree, func(ctx context.Context, tx storage.LogTreeTX) error {
+ root, err := signer.SignLogRoot(&types.LogRootV1{TreeSize: treeSize, RootHash: []byte{0}})
+ if err != nil {
+ return fmt.Errorf("error creating new SignedLogRoot: %v", err)
+ }
+ if err := tx.StoreSignedLogRoot(ctx, *root); err != nil {
+ return fmt.Errorf("error storing new SignedLogRoot: %v", err)
+ }
+ return nil
+ })
+ if err != nil {
+ panic(fmt.Sprintf("ReadWriteTransaction() = %v", err))
+ }
+}
+
+// createTree creates the specified tree using AdminStorage.
+func createTree(db *sql.DB, tree *trillian.Tree) (*trillian.Tree, error) {
+ ctx := context.Background()
+ s := NewAdminStorage(db)
+ tree, err := storage.CreateTree(ctx, s, tree)
+ if err != nil {
+ return nil, err
+ }
+ return tree, nil
+}
+
+func createTreeOrPanic(db *sql.DB, create *trillian.Tree) *trillian.Tree {
+ tree, err := createTree(db, create)
+ if err != nil {
+ panic(fmt.Sprintf("Error creating tree: %v", err))
+ }
+ return tree
+}
+
+// updateTree updates the specified tree using AdminStorage.
+func updateTree(db *sql.DB, treeID int64, updateFn func(*trillian.Tree)) (*trillian.Tree, error) {
+ ctx := context.Background()
+ s := NewAdminStorage(db)
+ return storage.UpdateTree(ctx, s, treeID, updateFn)
+}
func TestMain(m *testing.M) {
flag.Parse()
@@ -34,9 +169,7 @@ func TestMain(m *testing.M) {
glog.Errorf("PG not available, skipping all PG storage tests")
return
}
- ctx, cancel := context.WithTimeout(context.Background(), time.Duration(time.Second*30))
- defer cancel()
- db = testdb.NewTrillianDBOrDie(ctx)
+ db = openTestDBOrDie()
defer db.Close()
os.Exit(m.Run())
}
diff --git a/storage/postgres/storage_unsafe.sql b/storage/postgres/storage_unsafe.sql
new file mode 100644
index 0000000000..5f6008d2f7
--- /dev/null
+++ b/storage/postgres/storage_unsafe.sql
@@ -0,0 +1,167 @@
+-- Postgres impl of storage
+-- ---------------------------------------------
+-- Tree stuff here
+-- ---------------------------------------------
+
+-- Tree Enums
+CREATE TYPE E_TREE_STATE AS ENUM('ACTIVE', 'FROZEN', 'DRAINING');
+CREATE TYPE E_TREE_TYPE AS ENUM('LOG', 'MAP', 'PREORDERED_LOG');
+CREATE TYPE E_HASH_STRATEGY AS ENUM('RFC6962_SHA256', 'TEST_MAP_HASHER', 'OBJECT_RFC6962_SHA256', 'CONIKS_SHA512_256', 'CONIKS_SHA256');
+CREATE TYPE E_HASH_ALGORITHM AS ENUM('SHA256');
+CREATE TYPE E_SIGNATURE_ALGORITHM AS ENUM('ECDSA', 'RSA');
+
+-- Tree parameters should not be changed after creation. Doing so can
+-- render the data in the tree unusable or inconsistent.
+CREATE TABLE IF NOT EXISTS trees (
+ tree_id BIGINT NOT NULL,
+ tree_state E_TREE_STATE NOT NULL,
+ tree_type E_TREE_TYPE NOT NULL,
+ hash_strategy E_HASH_STRATEGY NOT NULL,
+ hash_algorithm E_HASH_ALGORITHM NOT NULL,
+ signature_algorithm E_SIGNATURE_ALGORITHM NOT NULL,
+ display_name VARCHAR(20),
+ description VARCHAR(200),
+ create_time_millis BIGINT NOT NULL,
+ update_time_millis BIGINT NOT NULL,
+ max_root_duration_millis BIGINT NOT NULL,
+ private_key BYTEA NOT NULL,
+ public_key BYTEA NOT NULL,
+ deleted BOOLEAN NOT NULL DEFAULT FALSE,
+ delete_time_millis BIGINT,
+ current_tree_data json,
+ root_signature BYTEA,
+ PRIMARY KEY(tree_id)
+);
+
+-- This table contains tree parameters that can be changed at runtime such as for
+-- administrative purposes.
+CREATE TABLE IF NOT EXISTS tree_control(
+ tree_id BIGINT NOT NULL,
+ signing_enabled BOOLEAN NOT NULL,
+ sequencing_enabled BOOLEAN NOT NULL,
+ sequence_interval_seconds INTEGER NOT NULL,
+ PRIMARY KEY(tree_id)
+);
+
+CREATE TABLE IF NOT EXISTS subtree(
+ tree_id BIGINT NOT NULL,
+ subtree_id BYTEA NOT NULL,
+ nodes BYTEA NOT NULL,
+ subtree_revision INTEGER NOT NULL,
+ PRIMARY KEY(subtree_id, subtree_revision)
+);
+
+-- The TreeRevisionIdx is used to enforce that there is only one STH at any
+-- tree revision
+CREATE TABLE IF NOT EXISTS tree_head(
+ tree_id BIGINT NOT NULL,
+ tree_head_timestamp BIGINT,
+ tree_size BIGINT,
+ root_hash BYTEA NOT NULL,
+ root_signature BYTEA NOT NULL,
+ tree_revision BIGINT,
+ PRIMARY KEY(tree_id, tree_revision)
+);
+
+-- TODO(vishal) benchmark this to see if it's a suitable replacement for not
+-- having a DESC scan on the primary key
+CREATE UNIQUE INDEX TreeHeadRevisionIdx ON tree_head(tree_id, tree_revision DESC);
+
+-- ---------------------------------------------
+-- Log specific stuff here
+-- ---------------------------------------------
+
+-- Creating index at same time as table allows some storage engines to better
+-- optimize physical storage layout. Most engines allow multiple nulls in a
+-- unique index but some may not.
+
+-- A leaf that has not been sequenced has a row in this table. If duplicate leaves
+-- are allowed they will all reference this row.
+CREATE TABLE IF NOT EXISTS leaf_data(
+ tree_id BIGINT NOT NULL,
+ -- This is a personality specific hash of some subset of the leaf data.
+ -- Its only purpose is to allow Trillian to identify duplicate entries in
+ -- the context of the personality.
+ leaf_identity_hash BYTEA NOT NULL,
+ -- This is the data stored in the leaf for example in CT it contains a DER encoded
+ -- X.509 certificate but is application dependent
+ leaf_value BYTEA NOT NULL,
+ -- This is extra data that the application can associate with the leaf should it wish to.
+ -- This data is not included in signing and hashing.
+ extra_data BYTEA,
+ -- The timestamp from when this leaf data was first queued for inclusion.
+ queue_timestamp_nanos BIGINT NOT NULL,
+ PRIMARY KEY(leaf_identity_hash)
+);
+
+-- When a leaf is sequenced a row is added to this table. If logs allow duplicates then
+-- multiple rows will exist with different sequence numbers. The signed timestamp
+-- will be communicated via the unsequenced table as this might need to be unique, depending
+-- on the log parameters and we can't insert into this table until we have the sequence number
+-- which is not available at the time we queue the entry. We need both hashes because the
+-- LeafData table is keyed by the raw data hash.
+CREATE TABLE IF NOT EXISTS sequenced_leaf_data(
+ tree_id BIGINT NOT NULL,
+ sequence_number BIGINT NOT NULL,
+ -- This is a personality specific hash of some subset of the leaf data.
+ -- Its only purpose is to allow Trillian to identify duplicate entries in
+ -- the context of the personality.
+ leaf_identity_hash BYTEA NOT NULL,
+ -- This is a MerkleLeafHash as defined by the treehasher that the log uses. For example for
+ -- CT this hash will include the leaf prefix byte as well as the leaf data.
+ merkle_leaf_hash BYTEA NOT NULL,
+ integrate_timestamp_nanos BIGINT NOT NULL,
+ PRIMARY KEY(sequence_number)
+);
+
+CREATE INDEX SequencedLeafMerkleIdx ON sequenced_leaf_data(tree_id, merkle_leaf_hash);
+
+CREATE TABLE IF NOT EXISTS unsequenced(
+ tree_id BIGINT NOT NULL,
+ -- The bucket field is to allow the use of time based ring bucketed schemes if desired. If
+ -- unused this should be set to zero for all entries.
+ bucket INTEGER NOT NULL,
+ -- This is a personality specific hash of some subset of the leaf data.
+ -- Its only purpose is to allow Trillian to identify duplicate entries in
+ -- the context of the personality.
+ leaf_identity_hash BYTEA NOT NULL,
+ -- This is a MerkleLeafHash as defined by the treehasher that the log uses. For example for
+ -- CT this hash will include the leaf prefix byte as well as the leaf data.
+ merkle_leaf_hash BYTEA NOT NULL,
+ queue_timestamp_nanos BIGINT NOT NULL,
+ -- This is a SHA256 hash of the TreeID, LeafIdentityHash and QueueTimestampNanos. It is used
+ -- for batched deletes from the table when trillian_log_server and trillian_log_signer are
+ -- built with the batched_queue tag.
+ queue_id BYTEA DEFAULT NULL UNIQUE,
+ PRIMARY KEY (queue_timestamp_nanos, leaf_identity_hash)
+);
+
+CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, leaf_value bytea, extra_data bytea, queue_timestamp_nanos bigint)
+ RETURNS boolean
+ LANGUAGE plpgsql
+AS $function$
+ begin
+ INSERT INTO leaf_data(tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos) VALUES (tree_id,leaf_identity_hash,leaf_value,extra_data,queue_timestamp_nanos);
+ return true;
+ exception
+ when unique_violation then
+ return false;
+ when others then
+ raise notice '% %', SQLERRM, SQLSTATE;
+ end;
+$function$;
+
+CREATE OR REPLACE FUNCTION public.insert_leaf_data_ignore_duplicates(tree_id bigint, leaf_identity_hash bytea, merkle_leaf_hash bytea, queue_timestamp_nanos bigint)
+ RETURNS boolean
+ LANGUAGE plpgsql
+AS $function$
+ begin
+ INSERT INTO unsequenced(tree_id,bucket,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos) VALUES(tree_id,0,leaf_identity_hash,merkle_leaf_hash,queue_timestamp_nanos);
+ return true;
+ exception
+ when unique_violation then
+ return false;
+ when others then
+ raise notice '% %', SQLERRM, SQLSTATE;
+ end;
+$function$;
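+
+-- NOTE: compared with schema/storage.sql this variant drops tree_id from the
+-- primary keys, omits the foreign-key constraints, and keeps plain ';'
+-- terminators instead of the ';--end' markers that testdb.go splits on.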
diff --git a/storage/postgres/testdb/testdb.go b/storage/postgres/testdb/testdb.go
index 7ed9a81d7c..31c43f5fc8 100644
--- a/storage/postgres/testdb/testdb.go
+++ b/storage/postgres/testdb/testdb.go
@@ -27,8 +27,6 @@ import (
"time"
"github.com/google/trillian/testonly"
-
- _ "github.com/lib/pq" // pg driver
)
var (
@@ -52,6 +50,28 @@ func PGAvailable() bool {
return true
}
+// TestSQL executes a simple query in the configured database. It is only used
+// as a placeholder for trying out queries and seeing how Go returns results.
+func TestSQL(ctx context.Context) string {
+ db, err := sql.Open("postgres", getConnStr(*dbName))
+ if err != nil {
+ fmt.Println("Error: ", err)
+ return "error"
+ }
+ defer db.Close()
+ rows, err := db.QueryContext(ctx, "select 1=1")
+ if err != nil {
+ fmt.Println("Error: ", err)
+ return "error"
+ }
+ defer rows.Close()
+ // Rows must be advanced with Next() before Scan() can read a value.
+ var resultData bool
+ if rows.Next() {
+ if err := rows.Scan(&resultData); err != nil {
+ fmt.Println("Error: ", err)
+ return "error"
+ }
+ }
+ if resultData {
+ fmt.Println("Result: ", resultData)
+ }
+ return "done"
+}
+
// newEmptyDB creates a new, empty database.
func newEmptyDB(ctx context.Context) (*sql.DB, error) {
db, err := sql.Open("postgres", getConnStr(*dbName))
@@ -61,12 +81,10 @@ func newEmptyDB(ctx context.Context) (*sql.DB, error) {
// Create a randomly-named database and then connect using the new name.
name := fmt.Sprintf("trl_%v", time.Now().UnixNano())
-
stmt := fmt.Sprintf("CREATE DATABASE %v", name)
if _, err := db.ExecContext(ctx, stmt); err != nil {
return nil, fmt.Errorf("error running statement %q: %v", stmt, err)
}
-
db.Close()
db, err = sql.Open("postgres", getConnStr(name))
if err != nil {
@@ -90,7 +108,7 @@ func NewTrillianDB(ctx context.Context) (*sql.DB, error) {
return nil, err
}
- for _, stmt := range strings.Split(sanitize(string(sqlBytes)), ";") {
+ for _, stmt := range strings.Split(sanitize(string(sqlBytes)), ";--end") {
stmt = strings.TrimSpace(stmt)
if stmt == "" {
continue
diff --git a/storage/postgres/tree_storage.go b/storage/postgres/tree_storage.go
index 60aba30aa7..fdf3b79a72 100644
--- a/storage/postgres/tree_storage.go
+++ b/storage/postgres/tree_storage.go
@@ -20,6 +20,7 @@ import (
"encoding/base64"
"fmt"
"runtime/debug"
+ "strconv"
"strings"
"sync"
@@ -43,13 +44,15 @@ const (
SELECT n.subtree_id, max(n.subtree_revision) AS max_revision
FROM subtree n
WHERE n.subtree_id IN (` + placeholderSQL + `) AND
- n.tree_id = ? AND n.subtree_revision <= ?
+ n.tree_id = <param> AND n.subtree_revision <= <param>
GROUP BY n.subtree_id
) AS x
INNER JOIN subtree
ON subtree.subtree_id = x.subtree_id
AND subtree.subtree_revision = x.max_revision
- AND subtree.tree_id = ?`
+ AND subtree.tree_id = <param>`
+ insertTreeHeadSQL = `INSERT INTO tree_head(tree_id,tree_head_timestamp,tree_size,root_hash,tree_revision,root_signature)
+ VALUES($1,$2,$3,$4,$5,$6)`
)
// pgTreeStorage contains the pgLogStorage implementation.
@@ -141,6 +144,13 @@ func (p *pgTreeStorage) getStmt(ctx context.Context, skeleton *statementSkeleton
}
statement, err := expandPlaceholderSQL(skeleton)
+
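+ // PostgreSQL uses numbered placeholders ($1, $2, ...) rather than '?', so any
+ // remaining <param> tokens are rewritten with ordinals that follow the ones
+ // already produced by expandPlaceholderSQL.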
+ counter := skeleton.restPlaceholders*skeleton.num + 1
+ for strings.Contains(statement, "") {
+ statement = strings.Replace(statement, "", "$"+strconv.Itoa(counter), 1)
+ counter++
+ }
+
if err != nil {
glog.Warningf("Failed to expand placeholder sql: %v", skeleton)
return nil, err
@@ -254,10 +264,9 @@ func (t *treeTX) getSubtrees(ctx context.Context, treeRevision int64, nodeIDs []
args = append(args, interface{}(t.treeID))
args = append(args, interface{}(treeRevision))
args = append(args, interface{}(t.treeID))
-
rows, err := stx.QueryContext(ctx, args...)
if err != nil {
- glog.Warningf("Failed to get merkle subtrees: %s", err)
+ glog.Warningf("Failed to get merkle subtrees: QueryContext(%v) = (_, %q)", args, err)
return nil, err
}
defer rows.Close()
@@ -393,3 +402,52 @@ func (t *treeTX) Close() error {
}
return nil
}
+
+func (t *treeTX) GetMerkleNodes(ctx context.Context, treeRevision int64, nodeIDs []storage.NodeID) ([]storage.Node, error) {
+ return t.subtreeCache.GetNodes(nodeIDs, t.getSubtreesAtRev(ctx, treeRevision))
+}
+
+func (t *treeTX) SetMerkleNodes(ctx context.Context, nodes []storage.Node) error {
+ for _, n := range nodes {
+ err := t.subtreeCache.SetNodeHash(n.NodeID, n.Hash,
+ func(nID storage.NodeID) (*storagepb.SubtreeProto, error) {
+ return t.getSubtree(ctx, t.writeRevision, nID)
+ })
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (t *treeTX) IsOpen() bool {
+ return !t.closed
+}
+
+// getSubtreesAtRev returns a GetSubtreesFunc which reads at the passed in rev.
+func (t *treeTX) getSubtreesAtRev(ctx context.Context, rev int64) cache.GetSubtreesFunc {
+ return func(ids []storage.NodeID) ([]*storagepb.SubtreeProto, error) {
+ return t.getSubtrees(ctx, rev, ids)
+ }
+}
+
+func checkResultOkAndRowCountIs(res sql.Result, err error, count int64) error {
+ // The Exec() might have just failed
+ if err != nil {
+ return err
+ }
+
+ // Otherwise we have to look at the result of the operation
+ rowsAffected, rowsError := res.RowsAffected()
+
+ if rowsError != nil {
+ return rowsError
+ }
+
+ if rowsAffected != count {
+ return fmt.Errorf("Expected %d row(s) to be affected but saw: %d", count,
+ rowsAffected)
+ }
+
+ return nil
+}