Skip to content

Commit

Permalink
admin.Server returns the correct error codes (#455)
Browse files Browse the repository at this point in the history
This PR wires trillian/errors to the most common error sources on admin.Server.
It also adds a method to read the errors.Code from generic errors.

Errors originating from the admin storage layer are still unmmapped, therefore
their code is implied to be Unknown. While other codes would be more appropriate
(Unavailable for Begin errors, Internal for enum-mapping errors, so on), that
should be acceptable for now, as those errors are relatively uncommon.
  • Loading branch information
codingllama authored Mar 21, 2017
1 parent 0b20b2f commit 920fb88
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 53 deletions.
10 changes: 10 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ func (e *trillianError) Code() Code {
return e.code
}

// ErrorCode returns the assigned Code if err is a TrillianError, Unknown
// otherwise.
func ErrorCode(err error) Code {
terr, ok := err.(TrillianError)
if ok {
return terr.Code()
}
return Unknown
}

// Errorf creates a TrillianError from the specified code and message.
func Errorf(code Code, format string, a ...interface{}) error {
return &trillianError{code, fmt.Sprintf(format, a...)}
Expand Down
17 changes: 17 additions & 0 deletions errors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package errors

import (
"errors"
"testing"

"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -49,6 +50,22 @@ func TestCodes(t *testing.T) {
}
}

func TestErrorCode(t *testing.T) {
tests := []struct {
err error
wantCode Code
}{
{err: Errorf(InvalidArgument, "invalid argument error"), wantCode: InvalidArgument},
{err: Errorf(NotFound, "not found error"), wantCode: NotFound},
{err: errors.New("generic error"), wantCode: Unknown},
}
for _, test := range tests {
if got := ErrorCode(test.err); got != test.wantCode {
t.Errorf("err = %v, wantCode = %v", test.err, test.wantCode)
}
}
}

func TestErrorf(t *testing.T) {
tests := []struct {
code Code
Expand Down
47 changes: 22 additions & 25 deletions integration/admin/admin_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,9 @@ func TestAdminServer_CreateTree(t *testing.T) {
invalidTree.TreeState = trillian.TreeState_HARD_DELETED

tests := []struct {
desc string
req *trillian.CreateTreeRequest
// TODO(codingllama): Check correctness of returned codes.Code
wantErr bool
desc string
req *trillian.CreateTreeRequest
wantCode codes.Code
}{
{
desc: "validTree",
Expand All @@ -93,26 +92,26 @@ func TestAdminServer_CreateTree(t *testing.T) {
},
},
{
desc: "nilTree",
req: &trillian.CreateTreeRequest{},
wantErr: true,
desc: "nilTree",
req: &trillian.CreateTreeRequest{},
wantCode: codes.InvalidArgument,
},
{
desc: "invalidTree",
req: &trillian.CreateTreeRequest{
Tree: &invalidTree,
},
wantErr: true,
wantCode: codes.InvalidArgument,
},
}

ctx := context.Background()
for _, test := range tests {
createdTree, err := client.CreateTree(ctx, test.req)
if hasErr := err != nil; hasErr != test.wantErr {
t.Errorf("%v: CreateTree() = (_, %v), wantErr = %v", test.desc, err, test.wantErr)
if grpc.Code(err) != test.wantCode {
t.Errorf("%v: CreateTree() = (_, %v), wantCode = %v", test.desc, err, test.wantCode)
continue
} else if hasErr {
} else if err != nil {
continue
}
storedTree, err := client.GetTree(ctx, &trillian.GetTreeRequest{TreeId: createdTree.TreeId})
Expand All @@ -134,31 +133,29 @@ func TestAdminServer_GetTree(t *testing.T) {
defer closeFn()

tests := []struct {
desc string
treeID int64
wantErr bool
desc string
treeID int64
wantCode codes.Code
}{
{
desc: "negativeTreeID",
treeID: -1,
// TODO(codingllama): Check correctness of returned codes.Code
wantErr: true,
desc: "negativeTreeID",
treeID: -1,
wantCode: codes.NotFound,
},
{
desc: "notFound",
treeID: 12345,
wantErr: true,
desc: "notFound",
treeID: 12345,
wantCode: codes.NotFound,
},
}

ctx := context.Background()
for _, test := range tests {
_, err = client.GetTree(ctx, &trillian.GetTreeRequest{TreeId: test.treeID})
if hasErr := err != nil; hasErr != test.wantErr {
t.Errorf("%v: GetTree() = (_, %v), wantErr = %v", test.desc, err, test.wantErr)
if grpc.Code(err) != test.wantCode {
t.Errorf("%v: GetTree() = (_, %v), wantCode = %v", test.desc, err, test.wantCode)
}
// Success of GetTree is part of TestAdminServer_CreateTree, so it's not asserted
// here.
// Success of GetTree is part of TestAdminServer_CreateTree, so it's not asserted here.
}
}

Expand Down
17 changes: 17 additions & 0 deletions server/admin/admin_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/trillian"
"github.com/google/trillian/extension"
"github.com/google/trillian/server/errors"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand All @@ -42,6 +43,14 @@ func (s *Server) ListTrees(context.Context, *trillian.ListTreesRequest) (*trilli

// GetTree implements trillian.TrillianAdminServer.GetTree.
func (s *Server) GetTree(ctx context.Context, request *trillian.GetTreeRequest) (*trillian.Tree, error) {
tree, err := s.getTreeImpl(ctx, request)
if err != nil {
return nil, errors.WrapError(err)
}
return tree, nil
}

func (s *Server) getTreeImpl(ctx context.Context, request *trillian.GetTreeRequest) (*trillian.Tree, error) {
tx, err := s.registry.AdminStorage.Snapshot(ctx)
if err != nil {
return nil, err
Expand All @@ -60,6 +69,14 @@ func (s *Server) GetTree(ctx context.Context, request *trillian.GetTreeRequest)

// CreateTree implements trillian.TrillianAdminServer.CreateTree.
func (s *Server) CreateTree(ctx context.Context, request *trillian.CreateTreeRequest) (*trillian.Tree, error) {
tree, err := s.createTreeImpl(ctx, request)
if err != nil {
return nil, errors.WrapError(err)
}
return tree, err
}

func (s *Server) createTreeImpl(ctx context.Context, request *trillian.CreateTreeRequest) (*trillian.Tree, error) {
tx, err := s.registry.AdminStorage.Begin(ctx)
if err != nil {
return nil, err
Expand Down
11 changes: 9 additions & 2 deletions server/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@
package errors

import (
"database/sql"

te "github.com/google/trillian/errors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)

// WrapError wraps err as a gRPC error if err is a TrillianError, else err is
// returned unmodified.
// WrapError wraps err as a gRPC error if err is a TrillianError or a well-known
// error instance (such as canonical sql errors), else err is returned
// unmodified.
func WrapError(err error) error {
if err == sql.ErrNoRows {
return grpc.Errorf(codes.NotFound, err.Error())
}

switch err := err.(type) {
case te.TrillianError:
return grpc.Errorf(codes.Code(err.Code()), err.Error())
Expand Down
5 changes: 5 additions & 0 deletions server/errors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package errors

import (
"database/sql"
"errors"
"testing"

Expand Down Expand Up @@ -45,6 +46,10 @@ func TestWrapError(t *testing.T) {
err: err,
wantErr: err,
},
{
err: sql.ErrNoRows,
wantErr: grpc.Errorf(codes.NotFound, sql.ErrNoRows.Error()),
},
}
for _, test := range tests {
// We can't use == for rpcErrors because grpc.Errorf returns *rpcError.
Expand Down
46 changes: 22 additions & 24 deletions storage/tree_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
package storage

import (
"errors"
"fmt"

"github.com/golang/protobuf/ptypes"
"github.com/google/trillian"
"github.com/google/trillian/crypto/sigpb"
"github.com/google/trillian/errors"
)

const (
Expand All @@ -35,29 +33,29 @@ const (
func ValidateTreeForCreation(tree *trillian.Tree) error {
switch {
case tree == nil:
return errors.New("a tree is required")
return errors.New(errors.InvalidArgument, "a tree is required")
case tree.TreeState != trillian.TreeState_ACTIVE:
return fmt.Errorf("invalid tree_state: %s", tree.TreeState)
return errors.Errorf(errors.InvalidArgument, "invalid tree_state: %s", tree.TreeState)
case tree.TreeType == trillian.TreeType_UNKNOWN_TREE_TYPE:
return fmt.Errorf("invalid tree_type: %s", tree.TreeType)
return errors.Errorf(errors.InvalidArgument, "invalid tree_type: %s", tree.TreeType)
case tree.HashStrategy == trillian.HashStrategy_UNKNOWN_HASH_STRATEGY:
return fmt.Errorf("invalid hash_strategy: %s", tree.HashStrategy)
return errors.Errorf(errors.InvalidArgument, "invalid hash_strategy: %s", tree.HashStrategy)
case tree.HashAlgorithm == sigpb.DigitallySigned_NONE:
return fmt.Errorf("invalid hash_algorithm: %s", tree.HashAlgorithm)
return errors.Errorf(errors.InvalidArgument, "invalid hash_algorithm: %s", tree.HashAlgorithm)
case tree.SignatureAlgorithm == sigpb.DigitallySigned_ANONYMOUS:
return fmt.Errorf("invalid signature_algorithm: %s", tree.SignatureAlgorithm)
return errors.Errorf(errors.InvalidArgument, "invalid signature_algorithm: %s", tree.SignatureAlgorithm)
case tree.DuplicatePolicy == trillian.DuplicatePolicy_UNKNOWN_DUPLICATE_POLICY:
return fmt.Errorf("invalid duplicate_policy: %s", tree.DuplicatePolicy)
return errors.Errorf(errors.InvalidArgument, "invalid duplicate_policy: %s", tree.DuplicatePolicy)
case tree.PrivateKey == nil:
return errors.New("a private_key is required")
return errors.New(errors.InvalidArgument, "a private_key is required")
}

// Check that the private_key proto contains a valid serialized proto.
// TODO(robpercival): Could we attempt to produce an STH at this point,
// to verify that the key works?
var privateKey ptypes.DynamicAny
if err := ptypes.UnmarshalAny(tree.PrivateKey, &privateKey); err != nil {
return fmt.Errorf("invalid private_key: %v", err)
return errors.Errorf(errors.InvalidArgument, "invalid private_key: %v", err)
}

return validateMutableTreeFields(tree)
Expand All @@ -74,35 +72,35 @@ func ValidateTreeForUpdate(storedTree, newTree *trillian.Tree) error {
// Check that readonly fields didn't change
switch {
case storedTree.TreeId != newTree.TreeId:
return errors.New("readonly field changed: tree_id")
return errors.New(errors.InvalidArgument, "readonly field changed: tree_id")
case storedTree.TreeType != newTree.TreeType:
return errors.New("readonly field changed: tree_type")
return errors.New(errors.InvalidArgument, "readonly field changed: tree_type")
case storedTree.HashStrategy != newTree.HashStrategy:
return errors.New("readonly field changed: hash_strategy")
return errors.New(errors.InvalidArgument, "readonly field changed: hash_strategy")
case storedTree.HashAlgorithm != newTree.HashAlgorithm:
return errors.New("readonly field changed: hash_algorithm")
return errors.New(errors.InvalidArgument, "readonly field changed: hash_algorithm")
case storedTree.SignatureAlgorithm != newTree.SignatureAlgorithm:
return errors.New("readonly field changed: signature_algorithm")
return errors.New(errors.InvalidArgument, "readonly field changed: signature_algorithm")
case storedTree.DuplicatePolicy != newTree.DuplicatePolicy:
return errors.New("readonly field changed: duplicate_policy")
return errors.New(errors.InvalidArgument, "readonly field changed: duplicate_policy")
case storedTree.CreateTimeMillisSinceEpoch != newTree.CreateTimeMillisSinceEpoch:
return errors.New("readonly field changed: create_time")
return errors.New(errors.InvalidArgument, "readonly field changed: create_time")
case storedTree.UpdateTimeMillisSinceEpoch != newTree.UpdateTimeMillisSinceEpoch:
return errors.New("readonly field changed: update_time")
return errors.New(errors.InvalidArgument, "readonly field changed: update_time")
case storedTree.PrivateKey != newTree.PrivateKey:
return errors.New("readonly field changed: private_key")
return errors.New(errors.InvalidArgument, "readonly field changed: private_key")
}
return validateMutableTreeFields(newTree)
}

func validateMutableTreeFields(tree *trillian.Tree) error {
switch {
case tree.TreeState == trillian.TreeState_UNKNOWN_TREE_STATE:
return fmt.Errorf("invalid tree_state: %v", tree.TreeState)
return errors.Errorf(errors.InvalidArgument, "invalid tree_state: %v", tree.TreeState)
case len(tree.DisplayName) > maxDisplayNameLength:
return fmt.Errorf("display_name too big, max length is %v: %v", maxDisplayNameLength, tree.DisplayName)
return errors.Errorf(errors.InvalidArgument, "display_name too big, max length is %v: %v", maxDisplayNameLength, tree.DisplayName)
case len(tree.Description) > maxDescriptionLength:
return fmt.Errorf("description too big, max length is %v: %v", maxDescriptionLength, tree.Description)
return errors.Errorf(errors.InvalidArgument, "description too big, max length is %v: %v", maxDescriptionLength, tree.Description)
}
return nil
}
11 changes: 9 additions & 2 deletions storage/tree_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/google/trillian"
"github.com/google/trillian/crypto/sigpb"
"github.com/google/trillian/errors"
)

func TestValidateTreeForCreation(t *testing.T) {
Expand Down Expand Up @@ -162,8 +163,11 @@ func TestValidateTreeForCreation(t *testing.T) {
}
for i, test := range tests {
err := ValidateTreeForCreation(test.tree)
if hasErr := err != nil; hasErr != test.wantErr {
switch hasErr := err != nil; {
case hasErr != test.wantErr:
t.Errorf("%v: ValidateTreeForCreation() = %v, wantErr = %v", i, err, test.wantErr)
case hasErr && errors.ErrorCode(err) != errors.InvalidArgument:
t.Errorf("%v: ValidateTreeForCreation() = %v, wantCode = %v", i, err, errors.InvalidArgument)
}
}
}
Expand Down Expand Up @@ -263,8 +267,11 @@ func TestValidateTreeForUpdate(t *testing.T) {
test.updatefn(tree)

err := ValidateTreeForUpdate(&baseTree, tree)
if hasErr := err != nil; hasErr != test.wantErr {
switch hasErr := err != nil; {
case hasErr != test.wantErr:
t.Errorf("%v: ValidateTreeForUpdate() = %v, wantErr = %v", test.desc, err, test.wantErr)
case hasErr && errors.ErrorCode(err) != errors.InvalidArgument:
t.Errorf("%v: ValidateTreeForUpdate() = %v, wantCode = %d", test.desc, err, errors.InvalidArgument)
}
}
}
Expand Down

0 comments on commit 920fb88

Please sign in to comment.