diff --git a/storage/postgres/storage_test.go b/storage/postgres/storage_test.go index eddaa843c7..68ee4f8ac4 100644 --- a/storage/postgres/storage_test.go +++ b/storage/postgres/storage_test.go @@ -39,7 +39,7 @@ func TestMain(m *testing.M) { } ctx, cancel := context.WithTimeout(context.Background(), time.Duration(time.Second*30)) defer cancel() - db = testdb.OpenTestDBOrDie(ctx) + db = testdb.NewTrillianDBOrDie(ctx) defer db.Close() ec = m.Run() } diff --git a/storage/postgres/testdb/testdb.go b/storage/postgres/testdb/testdb.go index 70db31ea43..69d0579e76 100644 --- a/storage/postgres/testdb/testdb.go +++ b/storage/postgres/testdb/testdb.go @@ -76,18 +76,17 @@ func newEmptyDB(ctx context.Context) (*sql.DB, error) { return db, db.Ping() } -// NewTrillianDB creates an empty database with the Trillian schema. The database name is randomly +// NewTrillianDBOrDie creates an empty database with the Trillian schema. The database name is randomly // generated. -// NewTrillianDB is equivalent to Default().NewTrillianDB(ctx). -func NewTrillianDB(ctx context.Context) (*sql.DB, error) { +func NewTrillianDBOrDie(ctx context.Context) *sql.DB { db, err := newEmptyDB(ctx) if err != nil { - return nil, err + panic(err) } sqlBytes, err := ioutil.ReadFile(trillianSQL) if err != nil { - return nil, err + panic(err) } for _, stmt := range strings.Split(sanitize(string(sqlBytes)), ";") { @@ -96,10 +95,10 @@ func NewTrillianDB(ctx context.Context) (*sql.DB, error) { continue } if _, err := db.ExecContext(ctx, stmt); err != nil { - return nil, fmt.Errorf("error running statement %q: %v", stmt, err) + panic(fmt.Errorf("error running statement %q: %v", stmt, err)) } } - return db, nil + return db } // sanitize tries to remove empty lines and comments from a sql script @@ -120,13 +119,3 @@ func sanitize(script string) string { func getConnStr(name string) string { return fmt.Sprintf("database=%s %s", name, *pgOpts) } - -// OpenTestDBOrDie attempts to return a connection to a new postgres -// test database and fails if unable to do so. -func OpenTestDBOrDie(ctx context.Context) *sql.DB { - db, err := NewTrillianDB(ctx) - if err != nil { - panic(err) - } - return db -} diff --git a/storage/postgres/tree_storage.go b/storage/postgres/tree_storage.go index c0ed071b1b..97c0252bb5 100644 --- a/storage/postgres/tree_storage.go +++ b/storage/postgres/tree_storage.go @@ -103,9 +103,9 @@ type statementSkeleton struct { // expandPlaceholderSQL expands an sql statement by adding a specified number of '%s' // placeholder slots. At most one placeholder will be expanded. -func expandPlaceholderSQL(skeleton *statementSkeleton) string { +func expandPlaceholderSQL(skeleton *statementSkeleton) (string, error) { if skeleton.num <= 0 { - panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", skeleton.sql)) + return "", fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", skeleton.sql) } restCount := skeleton.num - 1 @@ -120,7 +120,7 @@ func expandPlaceholderSQL(skeleton *statementSkeleton) string { remainingInsertion := strings.Repeat(","+skeleton.restInsertion, restCount) toInsertBuilder.WriteString(fmt.Sprintf(remainingInsertion, totalArray[skeleton.firstPlaceholders:]...)) - return strings.Replace(skeleton.sql, placeholderSQL, toInsertBuilder.String(), 1) + return strings.Replace(skeleton.sql, placeholderSQL, toInsertBuilder.String(), 1), nil } // getStmt creates and caches sql.Stmt structs based on the passed in statement @@ -137,7 +137,12 @@ func (p *pgTreeStorage) getStmt(ctx context.Context, skeleton *statementSkeleton p.statements[skeleton.sql] = make(map[int]*sql.Stmt) } - s, err := p.db.PrepareContext(ctx, expandPlaceholderSQL(skeleton)) + statement, err := expandPlaceholderSQL(skeleton) + if err != nil { + glog.Warningf("Failed to expand placeholder sql: %v", skeleton) + return nil, err + } + s, err := p.db.PrepareContext(ctx, statement) if err != nil { glog.Warningf("Failed to prepare statement %d: %s", skeleton.num, err) diff --git a/storage/postgres/tree_storage_test.go b/storage/postgres/tree_storage_test.go index 6d62051d9f..9baa7b329e 100644 --- a/storage/postgres/tree_storage_test.go +++ b/storage/postgres/tree_storage_test.go @@ -94,7 +94,11 @@ func TestExpandPlaceholderSQL(t *testing.T) { } for _, tc := range testCases { - if res := expandPlaceholderSQL(tc.input); res != tc.expected { + res, err := expandPlaceholderSQL(tc.input) + if err != nil { + t.Fatalf("Error while expanding placeholder sql: %v", err) + } + if tc.expected != res { t.Fatalf("Expected %v but got %v", tc.expected, res) } }