Skip to content

Commit

Permalink
Code review changes, remove code panic and refactor test
Browse files Browse the repository at this point in the history
  • Loading branch information
vishalkuo committed Nov 19, 2018
1 parent 86c0e9a commit 77d4644
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 23 deletions.
2 changes: 1 addition & 1 deletion storage/postgres/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestMain(m *testing.M) {
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(time.Second*30))
defer cancel()
db = testdb.OpenTestDBOrDie(ctx)
db = testdb.NewTrillianDBOrDie(ctx)
defer db.Close()
ec = m.Run()
}
23 changes: 6 additions & 17 deletions storage/postgres/testdb/testdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,17 @@ func newEmptyDB(ctx context.Context) (*sql.DB, error) {
return db, db.Ping()
}

// NewTrillianDB creates an empty database with the Trillian schema. The database name is randomly
// NewTrillianDBOrDie creates an empty database with the Trillian schema. The database name is randomly
// generated.
// NewTrillianDB is equivalent to Default().NewTrillianDB(ctx).
func NewTrillianDB(ctx context.Context) (*sql.DB, error) {
func NewTrillianDBOrDie(ctx context.Context) *sql.DB {
db, err := newEmptyDB(ctx)
if err != nil {
return nil, err
panic(err)
}

sqlBytes, err := ioutil.ReadFile(trillianSQL)
if err != nil {
return nil, err
panic(err)
}

for _, stmt := range strings.Split(sanitize(string(sqlBytes)), ";") {
Expand All @@ -96,10 +95,10 @@ func NewTrillianDB(ctx context.Context) (*sql.DB, error) {
continue
}
if _, err := db.ExecContext(ctx, stmt); err != nil {
return nil, fmt.Errorf("error running statement %q: %v", stmt, err)
panic(fmt.Errorf("error running statement %q: %v", stmt, err))
}
}
return db, nil
return db
}

// sanitize tries to remove empty lines and comments from a sql script
Expand All @@ -120,13 +119,3 @@ func sanitize(script string) string {
func getConnStr(name string) string {
return fmt.Sprintf("database=%s %s", name, *pgOpts)
}

// OpenTestDBOrDie attempts to return a connection to a new postgres
// test database and fails if unable to do so.
func OpenTestDBOrDie(ctx context.Context) *sql.DB {
db, err := NewTrillianDB(ctx)
if err != nil {
panic(err)
}
return db
}
13 changes: 9 additions & 4 deletions storage/postgres/tree_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ type statementSkeleton struct {

// expandPlaceholderSQL expands an sql statement by adding a specified number of '%s'
// placeholder slots. At most one placeholder will be expanded.
func expandPlaceholderSQL(skeleton *statementSkeleton) string {
func expandPlaceholderSQL(skeleton *statementSkeleton) (string, error) {
if skeleton.num <= 0 {
panic(fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", skeleton.sql))
return "", fmt.Errorf("trying to expand SQL placeholder with <= 0 parameters: %s", skeleton.sql)
}

restCount := skeleton.num - 1
Expand All @@ -120,7 +120,7 @@ func expandPlaceholderSQL(skeleton *statementSkeleton) string {
remainingInsertion := strings.Repeat(","+skeleton.restInsertion, restCount)
toInsertBuilder.WriteString(fmt.Sprintf(remainingInsertion, totalArray[skeleton.firstPlaceholders:]...))

return strings.Replace(skeleton.sql, placeholderSQL, toInsertBuilder.String(), 1)
return strings.Replace(skeleton.sql, placeholderSQL, toInsertBuilder.String(), 1), nil
}

// getStmt creates and caches sql.Stmt structs based on the passed in statement
Expand All @@ -137,7 +137,12 @@ func (p *pgTreeStorage) getStmt(ctx context.Context, skeleton *statementSkeleton
p.statements[skeleton.sql] = make(map[int]*sql.Stmt)
}

s, err := p.db.PrepareContext(ctx, expandPlaceholderSQL(skeleton))
statement, err := expandPlaceholderSQL(skeleton)
if err != nil {
glog.Warningf("Failed to expand placeholder sql: %v", skeleton)
return nil, err
}
s, err := p.db.PrepareContext(ctx, statement)

if err != nil {
glog.Warningf("Failed to prepare statement %d: %s", skeleton.num, err)
Expand Down
6 changes: 5 additions & 1 deletion storage/postgres/tree_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ func TestExpandPlaceholderSQL(t *testing.T) {
}

for _, tc := range testCases {
if res := expandPlaceholderSQL(tc.input); res != tc.expected {
res, err := expandPlaceholderSQL(tc.input)
if err != nil {
t.Fatalf("Error while expanding placeholder sql: %v", err)
}
if tc.expected != res {
t.Fatalf("Expected %v but got %v", tc.expected, res)
}
}
Expand Down

0 comments on commit 77d4644

Please sign in to comment.