diff --git a/internal/util/statuses.go b/internal/util/statuses.go index 9eeae6e8cc..0613ca037f 100644 --- a/internal/util/statuses.go +++ b/internal/util/statuses.go @@ -179,7 +179,6 @@ func (s *NiceStatus) Error() string { // SanitizingInterceptor sanitized error statuses which do not conform to NiceStatus, ensuring // that we don't accidentally leak implementation details over gRPC. func SanitizingInterceptor() grpc.UnaryServerInterceptor { - // TODO: this has no test coverage! return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { ret, err := handler(ctx, req) if err != nil { diff --git a/internal/util/statuses_test.go b/internal/util/statuses_test.go index 7b2e7453a7..be5ce0ee50 100644 --- a/internal/util/statuses_test.go +++ b/internal/util/statuses_test.go @@ -16,11 +16,14 @@ package util_test import ( + "context" "fmt" "testing" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/stacklok/minder/internal/util" ) @@ -37,3 +40,72 @@ func TestNiceStatusCreation(t *testing.T) { expected := "Code: 0\nName: OK\nDescription: OK\nDetails: OK is returned on success." require.Equal(t, expected, fmt.Sprint(s)) } + +func TestSanitizingInterceptor(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + handler grpc.UnaryHandler + wantErr bool + errIsNice bool + }{ + { + name: "success", + handler: func(ctx context.Context, req interface{}) (interface{}, error) { + return "success", nil + }, + wantErr: false, + }, + { + name: "some error", + handler: func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, status.Error(codes.Internal, "some error") + }, + wantErr: true, + }, + { + name: "nice error", + handler: func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, util.UserVisibleError(codes.Internal, "some error") + }, + wantErr: true, + errIsNice: true, + }, + } + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + i := util.SanitizingInterceptor() + ret, err := i(ctx, nil, nil, tt.handler) + if tt.wantErr { + require.Error(t, err) + require.Nil(t, ret) + + if tt.errIsNice { + require.IsType(t, &util.NiceStatus{}, err) + require.Contains(t, err.Error(), "Code: 13\nName: INTERNAL\nDescription: Server error\nDetails: some error") + + st := status.Convert(err) + require.Equal(t, codes.Internal, st.Code()) + } + + return + } + + require.NoError(t, err) + require.NotNil(t, ret) + + // test nil error + st := status.Convert(err) + nicest := util.FromRpcError(st) + require.Equal(t, codes.OK, nicest.Code) + require.Equal(t, "OK", nicest.Name) + }) + } +}