diff --git a/extension/builtin/default_registry.go b/extension/builtin/default_registry.go index 934ded9b13..044a920cc6 100644 --- a/extension/builtin/default_registry.go +++ b/extension/builtin/default_registry.go @@ -1,6 +1,7 @@ package builtin import ( + "database/sql" "flag" _ "github.com/go-sql-driver/mysql" // Load MySQL driver @@ -15,19 +16,29 @@ var MySQLURIFlag = flag.String("mysql_uri", "test:zaphod@tcp(127.0.0.1:3306)/tes "uri to use with mysql storage") // Default implementation of extension.Registry. -type defaultRegistry struct{} +type defaultRegistry struct { + db *sql.DB +} +// TODO(codingllama): Get rid of the error return func (r defaultRegistry) GetLogStorage(treeID int64) (storage.LogStorage, error) { - return mysql.NewLogStorage(treeID, *MySQLURIFlag) + return mysql.NewLogStorage(treeID, r.db) } +// TODO(codingllama): Get rid of the error return func (r defaultRegistry) GetMapStorage(treeID int64) (storage.MapStorage, error) { - return mysql.NewMapStorage(treeID, *MySQLURIFlag) + return mysql.NewMapStorage(treeID, r.db) } // NewDefaultExtensionRegistry returns the default extension.Registry implementation, which is // backed by a MySQL database and configured via flags. // The returned registry is wraped in a cached registry. func NewDefaultExtensionRegistry() (extension.Registry, error) { - return extension.NewCachedRegistry(defaultRegistry{}), nil + db, err := mysql.OpenDB(*MySQLURIFlag) + if err != nil { + return nil, err + } + return &defaultRegistry{ + db: db, + }, nil } diff --git a/extension/cached_registry.go b/extension/cached_registry.go deleted file mode 100644 index f45e92c74a..0000000000 --- a/extension/cached_registry.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package extension - -import ( - "sync" - - "github.com/google/trillian/storage" -) - -// cachedRegistry delegates method calls to registry, but caches the results for future invocations. -type cachedRegistry struct { - registry Registry - - mu sync.Mutex - logs map[int64]storage.LogStorage - maps map[int64]storage.MapStorage -} - -func (r *cachedRegistry) GetLogStorage(treeID int64) (storage.LogStorage, error) { - r.mu.Lock() - defer r.mu.Unlock() - - storage, ok := r.logs[treeID] - if !ok { - var err error - storage, err = r.registry.GetLogStorage(treeID) - if err != nil { - return nil, err - } - r.logs[treeID] = storage - } - return storage, nil -} - -func (r *cachedRegistry) GetMapStorage(treeID int64) (storage.MapStorage, error) { - r.mu.Lock() - defer r.mu.Unlock() - - storage, ok := r.maps[treeID] - if !ok { - var err error - storage, err = r.registry.GetMapStorage(treeID) - if err != nil { - return nil, err - } - r.maps[treeID] = storage - } - return storage, nil -} - -// NewCachedRegistry wraps a registry into a cached implementation, which caches storages per tree -// ID. -func NewCachedRegistry(registry Registry) Registry { - return &cachedRegistry{ - registry: registry, - logs: make(map[int64]storage.LogStorage), - maps: make(map[int64]storage.MapStorage), - } -} diff --git a/extension/cached_registry_test.go b/extension/cached_registry_test.go deleted file mode 100644 index d61578c330..0000000000 --- a/extension/cached_registry_test.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2016 Google Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package extension - -import ( - "fmt" - "testing" - - "github.com/golang/mock/gomock" - "github.com/google/trillian/storage" -) - -func TestGetLogStorage(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - registry := NewMockRegistry(ctrl) - cachedRegistry := NewCachedRegistry(registry) - - ls1 := storage.NewMockLogStorage(ctrl) - ls2 := storage.NewMockLogStorage(ctrl) - registry.EXPECT().GetLogStorage(int64(1)).Times(1).Return(ls1, nil) - registry.EXPECT().GetLogStorage(int64(2)).Times(1).Return(ls2, nil) - - var tests = []struct { - treeID int64 - want storage.LogStorage - }{ - {1, ls1}, - {1, ls1}, // Same key twice to test caching - {2, ls2}, - } - for _, test := range tests { - got, err := cachedRegistry.GetLogStorage(test.treeID) - switch { - case err != nil: - t.Errorf("GetLogStorage(%v) = (_, %v)", test.treeID, err) - case got != test.want: - t.Errorf("GetLogStorage(%v) = (%q, nil), want %q", test.treeID, got, test.want) - } - } -} - -func TestGetLogStorageError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - registry := NewMockRegistry(ctrl) - cachedRegistry := NewCachedRegistry(registry) - - want := fmt.Errorf("Error getting log storage") - registry.EXPECT().GetLogStorage(int64(1)).Times(2).Return(nil, want) - - // Run twice to make sure caching isn't doing anything funky - for i := 0; i < 2; i++ { - ls, err := cachedRegistry.GetLogStorage(1) - switch { - case err != want: - t.Errorf("GetLogStorage(1) = (_, %q), want %q", err, want) - case ls != nil: - t.Errorf("GetLogStorage(1) = (%q, _), want nil", ls) - } - } -} - -func TestGetMapStorage(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - registry := NewMockRegistry(ctrl) - cachedRegistry := NewCachedRegistry(registry) - - ms1 := storage.NewMockMapStorage(ctrl) - ms2 := storage.NewMockMapStorage(ctrl) - registry.EXPECT().GetMapStorage(int64(1)).Times(1).Return(ms1, nil) - registry.EXPECT().GetMapStorage(int64(2)).Times(1).Return(ms2, nil) - - var tests = []struct { - treeID int64 - want storage.MapStorage - }{ - {1, ms1}, - {1, ms1}, // Same key twice to test caching - {2, ms2}, - } - for _, test := range tests { - got, err := cachedRegistry.GetMapStorage(test.treeID) - switch { - case err != nil: - t.Errorf("GetMapStorage(%v) = (_, %v)", test.treeID, err) - case got != test.want: - t.Errorf("GetMapStorage(%v) = (%q, nil), want %q", test.treeID, got, test.want) - } - } -} - -func TestGetMapStorageError(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - registry := NewMockRegistry(ctrl) - cachedRegistry := NewCachedRegistry(registry) - - want := fmt.Errorf("Error getting map storage") - registry.EXPECT().GetMapStorage(int64(1)).Times(2).Return(nil, want) - - // Run twice to make sure caching isn't doing anything funky - for i := 0; i < 2; i++ { - ls, err := cachedRegistry.GetMapStorage(1) - switch { - case err != want: - t.Errorf("GetMapStorage(1) = (_, %q), want %q", err, want) - case ls != nil: - t.Errorf("GetMapStorage(1) = (%q, _), want nil", ls) - } - } -} diff --git a/storage/mysql/log_admin.go b/storage/mysql/log_admin.go index 863305a8b1..8c724f9d84 100644 --- a/storage/mysql/log_admin.go +++ b/storage/mysql/log_admin.go @@ -1,6 +1,7 @@ package mysql import ( + "database/sql" "fmt" "github.com/google/trillian/crypto" @@ -29,10 +30,10 @@ const ( // CreateTree instantiates a new log with default parameters. // TODO(codinglama): Move to admin API when the admin API is created. -func CreateTree(treeID int64, dbURL string) error { +func CreateTree(treeID int64, db *sql.DB) error { // TODO(codinglama) replace with a GetDatabase from the new extension API when LogID is removed. th := merkle.NewRFC6962TreeHasher(crypto.NewSHA256()) - m, err := newTreeStorage(treeID, dbURL, th.Size(), defaultLogStrata, cache.PopulateLogSubtreeNodes(th)) + m, err := newTreeStorage(treeID, db, th.Size(), defaultLogStrata, cache.PopulateLogSubtreeNodes(th)) if err != nil { return fmt.Errorf("couldn't create a new treeStorage: %s", err) } @@ -60,10 +61,10 @@ func CreateTree(treeID int64, dbURL string) error { } // DeleteTree deletes a tree by the treeID. -func DeleteTree(treeID int64, dbURL string) error { +func DeleteTree(treeID int64, db *sql.DB) error { // TODO(codinglama) replace with a GetDatabase from the new extension API when LogID is removed. th := merkle.NewRFC6962TreeHasher(crypto.NewSHA256()) - m, err := newTreeStorage(treeID, dbURL, th.Size(), defaultLogStrata, cache.PopulateLogSubtreeNodes(th)) + m, err := newTreeStorage(treeID, db, th.Size(), defaultLogStrata, cache.PopulateLogSubtreeNodes(th)) if err != nil { return fmt.Errorf("couldn't create a new treeStorage: %s", err) } diff --git a/storage/mysql/log_storage.go b/storage/mysql/log_storage.go index 36ce4320d0..ff65ddb7e8 100644 --- a/storage/mysql/log_storage.go +++ b/storage/mysql/log_storage.go @@ -75,10 +75,10 @@ type mySQLLogStorage struct { } // NewLogStorage creates a mySQLLogStorage instance for the specified MySQL URL. -func NewLogStorage(id int64, dbURL string) (storage.LogStorage, error) { +func NewLogStorage(id int64, db *sql.DB) (storage.LogStorage, error) { // TODO(al): pass this through/configure from DB th := merkle.NewRFC6962TreeHasher(crypto.NewSHA256()) - ts, err := newTreeStorage(id, dbURL, th.Size(), defaultLogStrata, cache.PopulateLogSubtreeNodes(th)) + ts, err := newTreeStorage(id, db, th.Size(), defaultLogStrata, cache.PopulateLogSubtreeNodes(th)) if err != nil { return nil, fmt.Errorf("couldn't create a new treeStorage: %s", err) } diff --git a/storage/mysql/log_storage_test.go b/storage/mysql/log_storage_test.go index 712458045d..8de8a91c4f 100644 --- a/storage/mysql/log_storage_test.go +++ b/storage/mysql/log_storage_test.go @@ -80,11 +80,11 @@ func checkLeafContents(leaf trillian.LogLeaf, seq int64, rawHash, hash, data, ex func TestOpenStateCommit(t *testing.T) { logID := createLogID("TestOpenStateCommit") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) - tx, err := s.Begin() + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) + tx, err := s.Begin() if err != nil { t.Fatalf("Failed to set up db transaction: %v", err) } @@ -102,11 +102,11 @@ func TestOpenStateCommit(t *testing.T) { func TestOpenStateRollback(t *testing.T) { logID := createLogID("TestOpenStateRollback") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) - tx, err := s.Begin() + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) + tx, err := s.Begin() if err != nil { t.Fatalf("Failed to set up db transaction: %v", err) } @@ -124,9 +124,9 @@ func TestOpenStateRollback(t *testing.T) { func TestQueueDuplicateLeafFails(t *testing.T) { logID := createLogID("TestQueueDuplicateLeafFails") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -149,9 +149,9 @@ func TestQueueDuplicateLeafFails(t *testing.T) { func TestQueueLeaves(t *testing.T) { logID := createLogID("TestQueueLeaves") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer failIfTXStillOpen(t, "TestQueueLeaves", tx) @@ -167,7 +167,7 @@ func TestQueueLeaves(t *testing.T) { // unsequenced data. var count int - if err := db.QueryRow("SELECT COUNT(*) FROM Unsequenced WHERE TreeID=?", logID.logID).Scan(&count); err != nil { + if err := DB.QueryRow("SELECT COUNT(*) FROM Unsequenced WHERE TreeID=?", logID.logID).Scan(&count); err != nil { t.Fatalf("Could not query row count: %v", err) } @@ -177,7 +177,7 @@ func TestQueueLeaves(t *testing.T) { // Additional check on timestamp being set correctly in the database var queueTimestamp int64 - if err := db.QueryRow("SELECT DISTINCT QueueTimestampNanos FROM Unsequenced WHERE TreeID=?", logID.logID).Scan(&queueTimestamp); err != nil { + if err := DB.QueryRow("SELECT DISTINCT QueueTimestampNanos FROM Unsequenced WHERE TreeID=?", logID.logID).Scan(&queueTimestamp); err != nil { t.Fatalf("Could not query timestamp: %v", err) } @@ -188,9 +188,9 @@ func TestQueueLeaves(t *testing.T) { func TestQueueLeavesBadHash(t *testing.T) { logID := createLogID("TestQueueLeavesBadHash") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer failIfTXStillOpen(t, "TestQueueLeavesBadHash", tx) @@ -211,9 +211,9 @@ func TestQueueLeavesBadHash(t *testing.T) { func TestDequeueLeavesNoneQueued(t *testing.T) { logID := createLogID("TestDequeueLeavesNoneQueued") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -230,9 +230,9 @@ func TestDequeueLeavesNoneQueued(t *testing.T) { func TestDequeueLeaves(t *testing.T) { logID := createLogID("TestDequeueLeaves") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) { tx := beginLogTx(s, t) @@ -284,9 +284,9 @@ func TestDequeueLeaves(t *testing.T) { func TestDequeueLeavesTwoBatches(t *testing.T) { logID := createLogID("TestDequeueLeavesTwoBatches") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) leavesToDequeue1 := 3 leavesToDequeue2 := 2 @@ -367,9 +367,9 @@ func TestDequeueLeavesTwoBatches(t *testing.T) { // are returned. func TestDequeueLeavesGuardInterval(t *testing.T) { logID := createLogID("TestDequeueLeavesGuardInterval") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) { tx := beginLogTx(s, t) @@ -419,9 +419,9 @@ func TestDequeueLeavesTimeOrdering(t *testing.T) { // transactions and make sure the returned leaves are respecting the time ordering of the // queue. logID := createLogID("TestDequeueLeavesTimeOrdering") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) batchSize := 2 { @@ -492,7 +492,8 @@ func TestDequeueLeavesTimeOrdering(t *testing.T) { func TestGetLeavesByHashNotPresent(t *testing.T) { logID := createLogID("TestGetLeavesByHashNotPresent") - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -510,7 +511,8 @@ func TestGetLeavesByHashNotPresent(t *testing.T) { func TestGetLeavesByLeafValueHashNotPresent(t *testing.T) { logID := createLogID("TestGetLeavesByLeafValueHashNotPresent") - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -524,7 +526,8 @@ func TestGetLeavesByLeafValueHashNotPresent(t *testing.T) { func TestGetLeavesByIndexNotPresent(t *testing.T) { logID := createLogID("TestGetLeavesByIndexNotPresent") - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -538,13 +541,13 @@ func TestGetLeavesByIndexNotPresent(t *testing.T) { func TestGetLeavesByHash(t *testing.T) { // Create fake leaf as if it had been sequenced logID := createLogID("TestGetLeavesByHash") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) data := []byte("some data") - createFakeLeaf(db, logID.logID, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) + createFakeLeaf(DB, logID.logID, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -566,13 +569,12 @@ func TestGetLeavesByHash(t *testing.T) { func TestGetLeavesByLeafValueHash(t *testing.T) { // Create fake leaf as if it had been sequenced logID := createLogID("TestGetLeavesByLeafValueHash") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) data := []byte("some data") - - createFakeLeaf(db, logID.logID, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) + createFakeLeaf(DB, logID.logID, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -590,12 +592,12 @@ func TestGetLeavesByLeafValueHash(t *testing.T) { func TestGetLeavesByIndex(t *testing.T) { // Create fake leaf as if it had been sequenced, read it back and check contents logID := createLogID("TestGetLeavesByIndex") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) - data := []byte("some data") + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) - createFakeLeaf(db, logID.logID, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) + data := []byte("some data") + createFakeLeaf(DB, logID.logID, dummyRawHash, dummyHash, data, someExtraData, sequenceNumber, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -613,25 +615,15 @@ func TestGetLeavesByIndex(t *testing.T) { checkLeafContents(leaves[0], sequenceNumber, dummyRawHash, dummyHash, data, someExtraData, t) } -func openTestDBOrDie() *sql.DB { - db, err := sql.Open("mysql", "test:zaphod@tcp(127.0.0.1:3306)/test") - if err != nil { - panic(err) - } - - return db -} - func TestLatestSignedRootNoneWritten(t *testing.T) { logID := createLogID("TestLatestSignedLogRootNoneWritten") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer tx.Rollback() root, err := tx.LatestSignedLogRoot() - if err != nil { t.Fatalf("Failed to read an empty log root: %v", err) } @@ -643,9 +635,9 @@ func TestLatestSignedRootNoneWritten(t *testing.T) { func TestLatestSignedLogRoot(t *testing.T) { logID := createLogID("TestLatestSignedLogRoot") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer tx.Rollback() @@ -675,9 +667,9 @@ func TestLatestSignedLogRoot(t *testing.T) { func TestGetTreeRevisionAtSize(t *testing.T) { logID := createLogID("TestGetTreeRevisionAtSize") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) { tx := beginLogTx(s, t) @@ -724,9 +716,9 @@ func TestGetTreeRevisionAtSize(t *testing.T) { func TestGetTreeRevisionMultipleSameSize(t *testing.T) { logID := createLogID("TestGetTreeRevisionAtSize") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) { tx := beginLogTx(s, t) @@ -767,9 +759,9 @@ func TestGetTreeRevisionMultipleSameSize(t *testing.T) { func TestDuplicateSignedLogRoot(t *testing.T) { logID := createLogID("TestDuplicateSignedLogRoot") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) defer commit(tx, t) @@ -789,9 +781,9 @@ func TestDuplicateSignedLogRoot(t *testing.T) { func TestLogRootUpdate(t *testing.T) { // Write two roots for a log and make sure the one with the newest timestamp supersedes logID := createLogID("TestLatestSignedLogRoot") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) // TODO: Tidy up the log id as it looks silly chained 3 times like this @@ -824,15 +816,11 @@ func TestLogRootUpdate(t *testing.T) { } func TestGetActiveLogIDs(t *testing.T) { - // Have to wipe everything to ensure we start with zero log trees configured - cleanTestDB() logID := createLogID("TestGetActiveLogIDs") - s := prepareTestLogStorage(logID, t) - + cleanTestDB(DB) // This creates one tree - db := prepareTestLogDB(logID, t) - defer db.Close() - + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) logIDs, err := tx.GetActiveLogIDs() @@ -847,13 +835,10 @@ func TestGetActiveLogIDs(t *testing.T) { } func TestGetActiveLogIDsWithPendingWork(t *testing.T) { - // Have to wipe everything to ensure we start with zero log trees configured - cleanTestDB() logID := createLogID("TestGetActiveLogIDsWithPendingWork") - s := prepareTestLogStorage(logID, t) - db := prepareTestLogDB(logID, t) - defer db.Close() - + cleanTestDB(DB) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) tx := beginLogTx(s, t) logIDs, err := tx.GetActiveLogIDsWithPendingWork() @@ -896,27 +881,21 @@ func TestGetSequencedLeafCount(t *testing.T) { // We'll create leaves for two different trees logID := createLogID("TestGetSequencedLeafCount") logID2 := createLogID("TestGetSequencedLeafCount2") - s1 := prepareTestLogStorage(logID, t) - s2 := prepareTestLogStorage(logID2, t) + cleanTestDB(DB) + s1 := prepareTestLogStorage(DB, logID, t) + s2 := prepareTestLogStorage(DB, logID2, t) { - db := prepareTestLogDB(logID, t) - // Create fake leaf as if it had been sequenced - defer db.Close() - + prepareTestLogDB(DB, logID, t) data := []byte("some data") - - createFakeLeaf(db, logID.logID, dummyHash, dummyRawHash, data, someExtraData, sequenceNumber, t) + createFakeLeaf(DB, logID.logID, dummyHash, dummyRawHash, data, someExtraData, sequenceNumber, t) // Create fake leaves for second tree as if they had been sequenced - db2 := prepareTestLogDB(logID2, t) - defer db2.Close() - + prepareTestLogDB(DB, logID2, t) data2 := []byte("some data 2") data3 := []byte("some data 3") - - createFakeLeaf(db2, logID2.logID, dummyHash2, dummyRawHash, data2, someExtraData, sequenceNumber, t) - createFakeLeaf(db2, logID2.logID, dummyHash3, dummyRawHash, data3, someExtraData, sequenceNumber+1, t) + createFakeLeaf(DB, logID2.logID, dummyHash2, dummyRawHash, data2, someExtraData, sequenceNumber, t) + createFakeLeaf(DB, logID2.logID, dummyHash3, dummyRawHash, data3, someExtraData, sequenceNumber+1, t) } // Read back the leaf counts from both trees @@ -959,71 +938,6 @@ func ensureAllLeavesDistinct(leaves []trillian.LogLeaf, t *testing.T) { } } -func prepareTestLogStorage(logID logIDAndTest, t *testing.T) storage.LogStorage { - if err := DeleteTree(logID.logID, "test:zaphod@tcp(127.0.0.1:3306)/test"); err != nil { - t.Fatalf("Failed to delete log storage: %s", err) - } - if err := CreateTree(logID.logID, "test:zaphod@tcp(127.0.0.1:3306)/test"); err != nil { - t.Fatalf("Failed to create new log storage: %s", err) - } - s, err := NewLogStorage(logID.logID, "test:zaphod@tcp(127.0.0.1:3306)/test") - if err != nil { - t.Fatalf("Failed to open log storage: %s", err) - } - return s -} - -// This removes all database contents for the specified log id so tests run in a -// predictable environment. For obvious reasons this should only be allowed to run -// against test databases. This method panics if any of the deletions fails to make -// sure tests can't inadvertently succeed. -func prepareTestTreeDB(treeID int64, t *testing.T) *sql.DB { - db := openTestDBOrDie() - - // Wipe out anything that was there for this tree id - for _, table := range allTables { - _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE TreeId=?", table), treeID) - - if err != nil { - t.Fatalf("Failed to delete rows in %s for %d: %s", table, treeID, err) - } - } - return db -} - -// This removes all database contents for the specified log id so tests run in a -// predictable environment. For obvious reasons this should only be allowed to run -// against test databases. This method panics if any of the deletions fails to make -// sure tests can't inadvertently succeed. -func prepareTestLogDB(logID logIDAndTest, t *testing.T) *sql.DB { - db := prepareTestTreeDB(logID.logID, t) - - // Now put back the tree row for this log id - _, err := db.Exec(`REPLACE INTO Trees(TreeId, KeyId, TreeType, LeafHasherType, TreeHasherType) - VALUES(?, ?, "LOG", "SHA256", "SHA256")`, logID.logID, logID.logID) - - if err != nil { - t.Fatalf("Failed to create tree entry for test: %v", err) - } - - return db -} - -// This deletes all the entries in the database. Only use this with a test database -func cleanTestDB() { - db := openTestDBOrDie() - defer db.Close() - - // Wipe out anything that was there for this log id - for _, table := range allTables { - _, err := db.Exec(fmt.Sprintf("DELETE FROM %s", table)) - - if err != nil { - panic(fmt.Errorf("Failed to delete rows in %s: %s", table, err)) - } - } -} - // Creates some test leaves with predictable data func createTestLeaves(n, startSeq int64) []trillian.LogLeaf { var leaves []trillian.LogLeaf diff --git a/storage/mysql/map_storage.go b/storage/mysql/map_storage.go index b1ff75d2f4..5ffe62561f 100644 --- a/storage/mysql/map_storage.go +++ b/storage/mysql/map_storage.go @@ -48,10 +48,10 @@ func (m *mySQLMapStorage) MapID() int64 { } // NewMapStorage creates a mySQLMapStorage instance for the specified MySQL URL. -func NewMapStorage(id int64, dbURL string) (storage.MapStorage, error) { +func NewMapStorage(id int64, db *sql.DB) (storage.MapStorage, error) { // TODO(al): pass this through/configure from DB th := merkle.NewRFC6962TreeHasher(crypto.NewSHA256()) - ts, err := newTreeStorage(id, dbURL, th.Size(), defaultMapStrata, cache.PopulateMapSubtreeNodes(th)) + ts, err := newTreeStorage(id, db, th.Size(), defaultMapStrata, cache.PopulateMapSubtreeNodes(th)) if err != nil { glog.Warningf("Couldn't create a new treeStorage: %s", err) return nil, err diff --git a/storage/mysql/map_storage_test.go b/storage/mysql/map_storage_test.go index c8cfa5bcfa..29fd1b72ec 100644 --- a/storage/mysql/map_storage_test.go +++ b/storage/mysql/map_storage_test.go @@ -12,9 +12,9 @@ import ( func TestMapRootUpdate(t *testing.T) { // Write two roots for a map and make sure the one with the newest timestamp supersedes mapID := createMapID("TestLatestSignedMapRoot") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) tx := beginMapTx(s, t) defer tx.Commit() @@ -69,12 +69,10 @@ var mapLeaf = trillian.MapLeaf{ } func TestMapSetGetRoundTrip(t *testing.T) { - cleanTestDB() - mapID := createMapID("TestMapSetGetRoundTrip") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) readRev := int64(1) @@ -109,12 +107,10 @@ func TestMapSetGetRoundTrip(t *testing.T) { } func TestMapSetSameKeyInSameRevisionFails(t *testing.T) { - cleanTestDB() - mapID := createMapID("TestMapSetSameKeyInSameRevisionFails") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) { tx := beginMapTx(s, t) @@ -140,12 +136,10 @@ func TestMapSetSameKeyInSameRevisionFails(t *testing.T) { } func TestMapGetUnknownKey(t *testing.T) { - cleanTestDB() - mapID := createMapID("TestMapGetUnknownKey") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) { tx := beginMapTx(s, t) @@ -164,13 +158,11 @@ func TestMapGetUnknownKey(t *testing.T) { } func TestMapSetGetMultipleRevisions(t *testing.T) { - cleanTestDB() - // Write two roots for a map and make sure the one with the newest timestamp supersedes mapID := createMapID("TestMapSetGetMultipleRevisions") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) tests := []struct { rev int64 @@ -220,9 +212,9 @@ func TestMapSetGetMultipleRevisions(t *testing.T) { func TestLatestSignedMapRootNoneWritten(t *testing.T) { mapID := createMapID("TestLatestSignedMapRootNoneWritten") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) tx := beginMapTx(s, t) defer tx.Rollback() @@ -239,9 +231,9 @@ func TestLatestSignedMapRootNoneWritten(t *testing.T) { func TestLatestSignedMapRoot(t *testing.T) { mapID := createMapID("TestLatestSignedMapRoot") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) tx := beginMapTx(s, t) defer tx.Rollback() @@ -273,9 +265,9 @@ func TestLatestSignedMapRoot(t *testing.T) { func TestDuplicateSignedMapRoot(t *testing.T) { mapID := createMapID("TestDuplicateSignedMapRoot") - db := prepareTestMapDB(mapID, t) - defer db.Close() - s := prepareTestMapStorage(mapID, t) + cleanTestDB(DB) + prepareTestMapDB(DB, mapID, t) + s := prepareTestMapStorage(DB, mapID, t) tx := beginMapTx(s, t) defer tx.Commit() @@ -292,12 +284,11 @@ func TestDuplicateSignedMapRoot(t *testing.T) { } } -func prepareTestMapStorage(mapID mapIDAndTest, t *testing.T) storage.MapStorage { - s, err := NewMapStorage(mapID.mapID, "test:zaphod@tcp(127.0.0.1:3306)/test") +func prepareTestMapStorage(db *sql.DB, mapID mapIDAndTest, t *testing.T) storage.MapStorage { + s, err := NewMapStorage(mapID.mapID, db) if err != nil { t.Fatalf("Failed to open map storage: %s", err) } - return s } @@ -305,18 +296,14 @@ func prepareTestMapStorage(mapID mapIDAndTest, t *testing.T) storage.MapStorage // predictable environment. For obvious reasons this should only be allowed to run // against test databases. This method panics if any of the deletions fails to make // sure tests can't inadvertently succeed. -func prepareTestMapDB(mapID mapIDAndTest, t *testing.T) *sql.DB { - db := prepareTestTreeDB(mapID.mapID, t) - +func prepareTestMapDB(db *sql.DB, mapID mapIDAndTest, t *testing.T) { + prepareTestTreeDB(DB, mapID.mapID, t) // Now put back the tree row for this log id _, err := db.Exec(`REPLACE INTO Trees(TreeId, KeyId, TreeType, LeafHasherType, TreeHasherType) VALUES(?, ?, "LOG", "SHA256", "SHA256")`, mapID.mapID, mapID.mapID) - if err != nil { t.Fatalf("Failed to create tree entry for test: %v", err) } - - return db } func beginMapTx(s storage.MapStorage, t *testing.T) storage.MapTX { diff --git a/storage/mysql/storage_test.go b/storage/mysql/storage_test.go index ae73ac2a48..921e450926 100644 --- a/storage/mysql/storage_test.go +++ b/storage/mysql/storage_test.go @@ -3,6 +3,7 @@ package mysql import ( "bytes" "crypto/sha256" + "database/sql" "flag" "fmt" "os" @@ -51,9 +52,8 @@ func TestNodeIDSerialization(t *testing.T) { func TestNodeRoundTrip(t *testing.T) { logID := createLogID("TestNodeRoundTrip") - db := prepareTestLogDB(logID, t) - defer db.Close() - s := prepareTestLogStorage(logID, t) + prepareTestLogDB(DB, logID, t) + s := prepareTestLogStorage(DB, logID, t) const writeRevision = int64(100) @@ -131,15 +131,6 @@ func createMapID(testName string) mapIDAndTest { } } -func createTestDB() { - db := openTestDBOrDie() - _, err := db.Exec(`REPLACE INTO Trees(TreeId, KeyId, TreeType, LeafHasherType, TreeHasherType) - VALUES(23, "hi", "LOG", "SHA256", "SHA256")`) - if err != nil { - panic(err) - } -} - func createSomeNodes(testName string, treeID int64) []storage.Node { r := make([]storage.Node, 4) for i := range r { @@ -166,9 +157,87 @@ func nodesAreEqual(lhs []storage.Node, rhs []storage.Node) error { return nil } +func openTestDBOrDie() *sql.DB { + db, err := OpenDB("test:zaphod@tcp(127.0.0.1:3306)/test") + if err != nil { + panic(err) + } + return db +} + +// cleanTestDB deletes all the entries in the database. Only use this with a test database +// TODO(gdbelvin): Migrate to testonly/integration +func cleanTestDB(db *sql.DB) { + // Wipe out anything that was there for this log id + for _, table := range allTables { + _, err := db.Exec(fmt.Sprintf("DELETE FROM %s", table)) + + if err != nil { + panic(fmt.Errorf("Failed to delete rows in %s: %s", table, err)) + } + } +} + +// TODO(codingllama): Migrate to Admin API +func createTestDB(db *sql.DB) { + _, err := db.Exec(`REPLACE INTO Trees(TreeId, KeyId, TreeType, LeafHasherType, TreeHasherType) + VALUES(23, "hi", "LOG", "SHA256", "SHA256")`) + if err != nil { + panic(err) + } +} + +// prepareTestTreeDB removes all database contents for the specified log id so tests run in a predictable environment. For obvious reasons this should only be allowed to run against test databases. This method panics if any of the deletions fails to make sure tests can't inadvertently succeed. +// TODO(gdbelvin): Migrate to testonly/integration / create a new DB for freshness +func prepareTestTreeDB(db *sql.DB, treeID int64, t *testing.T) { + // Wipe out anything that was there for this tree id + for _, table := range allTables { + _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE TreeId=?", table), treeID) + + if err != nil { + t.Fatalf("Failed to delete rows in %s for %d: %s", table, treeID, err) + } + } +} + +// prepareTestLogDB removes all database contents for the specified log id so tests run in a predictable environment. For obvious reasons this should only be allowed to run against test databases. This method panics if any of the deletions fails to make sure tests can't inadvertently succeed. +// TODO(codingllama): Migrate to Admin API +func prepareTestLogDB(db *sql.DB, logID logIDAndTest, t *testing.T) { + prepareTestTreeDB(db, logID.logID, t) + + // Now put back the tree row for this log id + _, err := db.Exec(`REPLACE INTO Trees(TreeId, KeyId, TreeType, LeafHasherType, TreeHasherType) + VALUES(?, ?, "LOG", "SHA256", "SHA256")`, logID.logID, logID.logID) + + if err != nil { + t.Fatalf("Failed to create tree entry for test: %v", err) + } + +} + +func prepareTestLogStorage(db *sql.DB, logID logIDAndTest, t *testing.T) storage.LogStorage { + if err := DeleteTree(logID.logID, db); err != nil { + t.Fatalf("Failed to delete log storage: %s", err) + } + if err := CreateTree(logID.logID, db); err != nil { + t.Fatalf("Failed to create new log storage: %s", err) + } + s, err := NewLogStorage(logID.logID, db) + if err != nil { + t.Fatalf("Failed to open log storage: %s", err) + } + return s +} + +// DB is the database used for tests. It's initialized and closed by TestMain(). +var DB *sql.DB + func TestMain(m *testing.M) { flag.Parse() - cleanTestDB() - createTestDB() - os.Exit(m.Run()) + DB = openTestDBOrDie() + defer DB.Close() + cleanTestDB(DB) + createTestDB(DB) + ec := m.Run() + os.Exit(ec) } diff --git a/storage/mysql/tree_storage.go b/storage/mysql/tree_storage.go index 9275c10e24..af6aa7683d 100644 --- a/storage/mysql/tree_storage.go +++ b/storage/mysql/tree_storage.go @@ -55,7 +55,8 @@ type mySQLTreeStorage struct { strataDepths []int } -func openDB(dbURL string) (*sql.DB, error) { +// OpenDB opens a database connection for all MySQL-based storage implementations. +func OpenDB(dbURL string) (*sql.DB, error) { db, err := sql.Open("mysql", dbURL) if err != nil { // Don't log uri as it could contain credentials @@ -71,22 +72,16 @@ func openDB(dbURL string) (*sql.DB, error) { return db, nil } -func newTreeStorage(treeID int64, dbURL string, hashSizeBytes int, strataDepths []int, populateSubtree storage.PopulateSubtreeFunc) (*mySQLTreeStorage, error) { - db, err := openDB(dbURL) - if err != nil { - return &mySQLTreeStorage{}, err - } - - s := mySQLTreeStorage{ +// TODO(codingllama): Remove error return +func newTreeStorage(treeID int64, db *sql.DB, hashSizeBytes int, strataDepths []int, populateSubtree storage.PopulateSubtreeFunc) (*mySQLTreeStorage, error) { + return &mySQLTreeStorage{ treeID: treeID, db: db, hashSizeBytes: hashSizeBytes, populateSubtree: populateSubtree, statements: make(map[string]map[int]*sql.Stmt), strataDepths: strataDepths, - } - - return &s, nil + }, nil } // expandPlaceholderSQL expands an sql statement by adding a specified number of '?' diff --git a/vmap/toy/vmap_toy.go b/vmap/toy/vmap_toy.go index a39c3cfa85..16559a08a9 100644 --- a/vmap/toy/vmap_toy.go +++ b/vmap/toy/vmap_toy.go @@ -23,10 +23,16 @@ var mySQLURIFlag = flag.String("mysql_uri", "test:zaphod@tcp(127.0.0.1:3306)/tes func main() { flag.Parse() glog.Info("Starting...") + + db, err := mysql.OpenDB(*mySQLURIFlag) + if err != nil { + glog.Fatalf("Failed to open DB connection: %v", err) + } + mapID := int64(1) - ms, err := mysql.NewMapStorage(mapID, *mySQLURIFlag) + ms, err := mysql.NewMapStorage(mapID, db) if err != nil { - glog.Fatalf("Failed to open mysql storage: %v", err) + glog.Fatalf("Failed create MapStorage: %v", err) } hasher := merkle.NewMapHasher(merkle.NewRFC6962TreeHasher(crypto.NewSHA256()))