diff --git a/docs/_docs/customizingyourgateway.md b/docs/_docs/customizingyourgateway.md index 86e0f8d06fa..41fc86f428a 100644 --- a/docs/_docs/customizingyourgateway.md +++ b/docs/_docs/customizingyourgateway.md @@ -234,91 +234,19 @@ if err := pb.RegisterMyServiceHandlerFromEndpoint(ctx, mux, serviceEndpoint, opt ``` ## Error handler -The gateway uses two different error handlers for non-streaming requests: - - * `runtime.HTTPError` is called for errors from backend calls - * `runtime.OtherErrorHandler` is called for errors from parsing and routing client requests - -To override all error handling for a `*runtime.ServeMux`, use the -`runtime.WithProtoErrorHandler` serve option. - -Alternatively, you can override the global default `HTTPError` handling by -setting `runtime.GlobalHTTPErrorHandler` to a custom function, and override -the global default `OtherErrorHandler` by setting `runtime.OtherErrorHandler` -to a custom function. - -You should not set `runtime.HTTPError` directly, because that might break -any `ServeMux` set up with the `WithProtoErrorHandler` option. +To override error handling for a `*runtime.ServeMux`, use the +`runtime.WithErrorHandler` option. This will configure all unary error +responses to pass through this error handler. See https://mycodesmells.com/post/grpc-gateway-error-handler for an example -of writing a custom error handler function. +of writing a custom error handler function. Note that this post targets +the v1 release of the gateway, and you no longer assign to `HTTPError` to +configure an error handler. ## Stream Error Handler The error handler described in the previous section applies only -to RPC methods that have a unary response. - -When the method has a streaming response, grpc-gateway handles -that by emitting a newline-separated stream of "chunks". Each -chunk is an envelope that can container either a response message -or an error. Only the last chunk will include an error, and only -when the RPC handler ends abnormally (i.e. with an error code). - -Because of the way the errors are included in the response body, -the other error handler signature is insufficient. So for server -streams, you must install a _different_ error handler: - -```go -mux := runtime.NewServeMux( - runtime.WithStreamErrorHandler(handleStreamError)) -``` - -The signature of the handler is much more rigid because we need -to know the structure of the error payload in order to properly -encode the "chunk" schema into a Swagger/OpenAPI spec. - -So the function must return a `*runtime.StreamError`. The handler -can choose to omit some fields and can filter/transform the original -error, such as stripping stack traces from error messages. - -Here's an example custom handler: -```go -// handleStreamError overrides default behavior for computing an error -// message for a server stream. -// -// It uses a default "502 Bad Gateway" HTTP code; only emits "safe" -// messages; and does not set gRPC code or details fields (so they will -// be omitted from the resulting JSON object that is sent to client). -func handleStreamError(ctx context.Context, err error) *runtime.StreamError { - code := http.StatusBadGateway - msg := "unexpected error" - if s, ok := status.FromError(err); ok { - code = runtime.HTTPStatusFromCode(s.Code()) - // default message, based on the name of the gRPC code - msg = code.String() - // see if error details include "safe" message to send - // to external callers - for _, msg := s.Details() { - if safe, ok := msg.(*SafeMessage); ok { - msg = safe.Text - break - } - } - } - return &runtime.StreamError{ - HttpCode: int32(code), - HttpStatus: http.StatusText(code), - Message: msg, - } -} -``` - -If no custom handler is provided, the default stream error handler -will include any gRPC error attributes (code, message, detail messages), -if the error being reported includes them. If the error does not have -these attributes, a gRPC code of `Unknown` (2) is reported. The default -handler will also include an HTTP code and status, which is derived -from the gRPC code (or set to `"500 Internal Server Error"` when -the source error has no gRPC attributes). +to RPC methods that have a unary response. It is currently +not possible to configure the stream error handler. ## Replace a response forwarder per method You might want to keep the behavior of the current marshaler but change only a message forwarding of a certain API method. diff --git a/examples/internal/integration/BUILD.bazel b/examples/internal/integration/BUILD.bazel index bcc837b56f1..9755e1b7f7c 100644 --- a/examples/internal/integration/BUILD.bazel +++ b/examples/internal/integration/BUILD.bazel @@ -7,7 +7,6 @@ go_test( "fieldmask_test.go", "integration_test.go", "main_test.go", - "proto_error_test.go", ], deps = [ "//examples/internal/clients/abe:go_default_library", @@ -23,9 +22,7 @@ go_test( "@com_github_golang_protobuf//descriptor:go_default_library_gen", "@com_github_golang_protobuf//jsonpb:go_default_library_gen", "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_golang_protobuf//ptypes:go_default_library_gen", "@com_github_google_go_cmp//cmp:go_default_library", - "@go_googleapis//google/rpc:errdetails_go_proto", "@go_googleapis//google/rpc:status_go_proto", "@io_bazel_rules_go//proto/wkt:empty_go_proto", "@io_bazel_rules_go//proto/wkt:field_mask_go_proto", diff --git a/examples/internal/integration/integration_test.go b/examples/internal/integration/integration_test.go index f23f6e7e133..59a320a18d2 100644 --- a/examples/internal/integration/integration_test.go +++ b/examples/internal/integration/integration_test.go @@ -1345,13 +1345,13 @@ func TestUnknownPath(t *testing.T) { return } - if got, want := resp.StatusCode, http.StatusNotFound; got != want { + if got, want := resp.StatusCode, http.StatusNotImplemented; got != want { t.Errorf("resp.StatusCode = %d; want %d", got, want) t.Logf("%s", buf) } } -func TestMethodNotAllowed(t *testing.T) { +func TestIncorrectMethod(t *testing.T) { if testing.Short() { t.Skip() return @@ -1370,7 +1370,7 @@ func TestMethodNotAllowed(t *testing.T) { return } - if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want { + if got, want := resp.StatusCode, http.StatusNotImplemented; got != want { t.Errorf("resp.StatusCode = %d; want %d", got, want) t.Logf("%s", buf) } diff --git a/examples/internal/integration/proto_error_test.go b/examples/internal/integration/proto_error_test.go deleted file mode 100644 index 12a636c1a62..00000000000 --- a/examples/internal/integration/proto_error_test.go +++ /dev/null @@ -1,281 +0,0 @@ -package integration_test - -import ( - "context" - "fmt" - "io/ioutil" - "net/http" - "strings" - "testing" - "time" - - "github.com/golang/protobuf/jsonpb" - "github.com/golang/protobuf/ptypes" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "google.golang.org/genproto/googleapis/rpc/errdetails" - spb "google.golang.org/genproto/googleapis/rpc/status" - "google.golang.org/grpc/codes" -) - -func runServer(ctx context.Context, t *testing.T, port uint16) { - opt := runtime.WithProtoErrorHandler(runtime.DefaultHTTPProtoErrorHandler) - if err := runGateway(ctx, fmt.Sprintf(":%d", port), opt); err != nil { - t.Errorf("runGateway() failed with %v; want success", err) - } -} - -func TestWithProtoErrorHandler(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8082 - go runServer(ctx, t, port) - if err := waitForGateway(ctx, 8082); err != nil { - t.Errorf("waitForGateway(ctx, 8082) failed with %v; want success", err) - } - testEcho(t, port, "v1", "application/json") - testEchoBody(t, port, "v1") -} - -func TestABEWithProtoErrorHandler(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8083 - go runServer(ctx, t, port) - if err := waitForGateway(ctx, 8083); err != nil { - t.Errorf("waitForGateway(ctx, 8083) failed with %v; want success", err) - } - - testABECreate(t, port) - testABECreateBody(t, port) - testABEBulkCreate(t, port) - testABELookup(t, port) - testABELookupNotFoundWithProtoError(t, port) - testABELookupNotFoundWithProtoErrorIncludingDetails(t, port) - testABEList(t, port) - testABEBulkEcho(t, port) - testABEBulkEchoZeroLength(t, port) - testAdditionalBindings(t, port) -} - -func testABELookupNotFoundWithProtoError(t *testing.T, port uint16) { - url := fmt.Sprintf("http://localhost:%d/v1/example/a_bit_of_everything", port) - uuid := "not_exist" - url = fmt.Sprintf("%s/%s", url, uuid) - resp, err := http.Get(url) - if err != nil { - t.Errorf("http.Get(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusNotFound; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - return - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.NotFound); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if got, want := msg.Message, "not found"; got != want { - t.Errorf("msg.Message = %s; want %s", got, want) - return - } - - if got, want := resp.Header.Get("Grpc-Metadata-Uuid"), uuid; got != want { - t.Errorf("Grpc-Metadata-Uuid was %s, wanted %s", got, want) - } - if got, want := resp.Trailer.Get("Grpc-Trailer-Foo"), "foo2"; got != want { - t.Errorf("Grpc-Trailer-Foo was %q, wanted %q", got, want) - } - if got, want := resp.Trailer.Get("Grpc-Trailer-Bar"), "bar2"; got != want { - t.Errorf("Grpc-Trailer-Bar was %q, wanted %q", got, want) - } -} - -func testABELookupNotFoundWithProtoErrorIncludingDetails(t *testing.T, port uint16) { - uuid := "errorwithdetails" - url := fmt.Sprintf("http://localhost:%d/v2/example/%s", port, uuid) - resp, err := http.Get(url) - if err != nil { - t.Errorf("http.Get(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusInternalServerError; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - return - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.Unknown); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if got, want := msg.Message, "with details"; got != want { - t.Errorf("msg.Message = %s; want %s", got, want) - return - } - - details := msg.Details - if got, want := len(details), 1; got != want { - t.Fatalf("got %q details, wanted %q", got, want) - } - - detail := errdetails.DebugInfo{} - if got, want := ptypes.UnmarshalAny(msg.Details[0], &detail), error(nil); got != want { - t.Errorf("unmarshaling any: got %q, wanted %q", got, want) - } - - if got, want := len(detail.StackEntries), 1; got != want { - t.Fatalf("got %d stack entries, expected %d", got, want) - } - if got, want := detail.StackEntries[0], "foo:1"; got != want { - t.Errorf("StackEntries[0]: got %q; want %q", got, want) - } -} - -func TestUnknownPathWithProtoError(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8084 - go runServer(ctx, t, port) - if err := waitForGateway(ctx, 8084); err != nil { - t.Errorf("waitForGateway(ctx, 8084) failed with %v; want success", err) - } - - url := fmt.Sprintf("http://localhost:%d", port) - resp, err := http.Post(url, "application/json", strings.NewReader("{}")) - if err != nil { - t.Errorf("http.Post(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusNotImplemented; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.Unimplemented); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if msg.Message == "" { - t.Errorf("msg.Message should not be empty") - return - } -} - -func TestMethodNotAllowedWithProtoError(t *testing.T) { - if testing.Short() { - t.Skip() - return - } - - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const port = 8085 - go runServer(ctx, t, port) - - // Waiting for the server's getting available. - // TODO(yugui) find a better way to wait - time.Sleep(100 * time.Millisecond) - - url := fmt.Sprintf("http://localhost:%d/v1/example/echo/myid", port) - resp, err := http.Get(url) - if err != nil { - t.Errorf("http.Post(%q) failed with %v; want success", url, err) - return - } - defer resp.Body.Close() - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) - return - } - - if got, want := resp.StatusCode, http.StatusNotImplemented; got != want { - t.Errorf("resp.StatusCode = %d; want %d", got, want) - t.Logf("%s", buf) - } - - var msg spb.Status - if err := jsonpb.UnmarshalString(string(buf), &msg); err != nil { - t.Errorf("jsonpb.UnmarshalString(%s, &msg) failed with %v; want success", buf, err) - return - } - - if got, want := msg.Code, int32(codes.Unimplemented); got != want { - t.Errorf("msg.Code = %d; want %d", got, want) - return - } - - if msg.Message == "" { - t.Errorf("msg.Message should not be empty") - return - } -} diff --git a/runtime/BUILD.bazel b/runtime/BUILD.bazel index 745fa692547..b5295bd120a 100644 --- a/runtime/BUILD.bazel +++ b/runtime/BUILD.bazel @@ -20,7 +20,6 @@ go_library( "mux.go", "pattern.go", "proto2_convert.go", - "proto_errors.go", "query.go", ], importpath = "github.com/grpc-ecosystem/grpc-gateway/v2/runtime", diff --git a/runtime/errors.go b/runtime/errors.go index 58c80eec857..9434992ed24 100644 --- a/runtime/errors.go +++ b/runtime/errors.go @@ -10,6 +10,9 @@ import ( "google.golang.org/grpc/status" ) +// ErrorHandlerFunc handles the error as a gRPC error generated via status package and replies to the request. +type ErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error) + // HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status. // See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto func HTTPStatusFromCode(code codes.Code) int { @@ -55,61 +58,19 @@ func HTTPStatusFromCode(code codes.Code) int { return http.StatusInternalServerError } -var ( - // HTTPError replies to the request with an error. - // - // HTTPError is called: - // - From generated per-endpoint gateway handler code, when calling the backend results in an error. - // - From gateway runtime code, when forwarding the response message results in an error. - // - // The default value for HTTPError calls the custom error handler configured on the ServeMux via the - // WithProtoErrorHandler serve option if that option was used, calling GlobalHTTPErrorHandler otherwise. - // - // To customize the error handling of a particular ServeMux instance, use the WithProtoErrorHandler - // serve option. - // - // To customize the error format for all ServeMux instances not using the WithProtoErrorHandler serve - // option, set GlobalHTTPErrorHandler to a custom function. - // - // Setting this variable directly to customize error format is deprecated. - HTTPError = MuxOrGlobalHTTPError - - // GlobalHTTPErrorHandler is the HTTPError handler for all ServeMux instances not using the - // WithProtoErrorHandler serve option. - // - // You can set a custom function to this variable to customize error format. - GlobalHTTPErrorHandler = DefaultHTTPError - - // OtherErrorHandler handles gateway errors from parsing and routing client requests for all - // ServeMux instances not using the WithProtoErrorHandler serve option. - // - // It returns the following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest - // - // To customize parsing and routing error handling of a particular ServeMux instance, use the - // WithProtoErrorHandler serve option. - // - // To customize parsing and routing error handling of all ServeMux instances not using the - // WithProtoErrorHandler serve option, set a custom function to this variable. - OtherErrorHandler = DefaultOtherErrorHandler -) - -// MuxOrGlobalHTTPError uses the mux-configured error handler, falling back to GlobalErrorHandler. -func MuxOrGlobalHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { - if mux.protoErrorHandler != nil { - mux.protoErrorHandler(ctx, mux, marshaler, w, r, err) - } else { - GlobalHTTPErrorHandler(ctx, mux, marshaler, w, r, err) - } +// HTTPError uses the mux-configured error handler. +func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) { + mux.errorHandler(ctx, mux, marshaler, w, r, err) } -// DefaultHTTPError is the default implementation of HTTPError. -// If "err" is an error from gRPC system, the function replies with the status code mapped by HTTPStatusFromCode. +// defaultHTTPErrorHandler is the default error handler. +// If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode. // If otherwise, it replies with http.StatusInternalServerError. // -// The response body returned by this function is a JSON object, -// which contains a member whose key is "message" and whose value is err.Error(). -func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) { - const fallback = `{"error": "failed to marshal error message"}` +// The response body written by this function is a Status message marshaled by the Marshaler. +func defaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) { + // return Internal when Marshal failed + const fallback = `{"code": 13, "message": "failed to marshal error message"}` s := status.Convert(err) pb := s.Proto() @@ -117,7 +78,7 @@ func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w w.Header().Del("Trailer") contentType := marshaler.ContentType() - // Check marshaler on run time in order to keep backwards compatability + // Check marshaler at runtime in order to keep backwards compatibility. // An interface param needs to be added to the ContentType() function on // the Marshal interface to be able to remove this check if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { @@ -150,9 +111,3 @@ func DefaultHTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w handleForwardResponseTrailer(w, md) } - -// DefaultOtherErrorHandler is the default implementation of OtherErrorHandler. -// It simply writes a string representation of the given error into "w". -func DefaultOtherErrorHandler(w http.ResponseWriter, _ *http.Request, msg string, code int) { - http.Error(w, msg, code) -} diff --git a/runtime/errors_test.go b/runtime/errors_test.go index 83ec1666a03..5ba54eb97db 100644 --- a/runtime/errors_test.go +++ b/runtime/errors_test.go @@ -62,7 +62,8 @@ func TestDefaultHTTPError(t *testing.T) { } { w := httptest.NewRecorder() req, _ := http.NewRequest("", "", nil) // Pass in an empty request to match the signature - runtime.DefaultHTTPError(ctx, &runtime.ServeMux{}, &runtime.JSONPb{}, w, req, spec.err) + mux := runtime.NewServeMux() + runtime.HTTPError(ctx, mux, &runtime.JSONPb{}, w, req, spec.err) if got, want := w.Header().Get("Content-Type"), "application/json"; got != want { t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err) diff --git a/runtime/mux.go b/runtime/mux.go index eb7e435365e..ddb857f3da1 100644 --- a/runtime/mux.go +++ b/runtime/mux.go @@ -16,9 +16,8 @@ import ( // A HandlerFunc handles a specific pair of path pattern and HTTP method. type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) -// ErrUnknownURI is the error supplied to a custom ProtoErrorHandlerFunc when -// a request is received with a URI path that does not match any registered -// service method. +// ErrUnknownURI is the error returned when a request is received with a URI path +// or HTTP Method that does not match any registered service method. // // Since gRPC servers return an "Unimplemented" code for requests with an // unrecognized URI path, this error also has a gRPC "Unimplemented" code. @@ -34,7 +33,7 @@ type ServeMux struct { incomingHeaderMatcher HeaderMatcherFunc outgoingHeaderMatcher HeaderMatcherFunc metadataAnnotators []func(context.Context, *http.Request) metadata.MD - protoErrorHandler ProtoErrorHandlerFunc + errorHandler ErrorHandlerFunc disablePathLengthFallback bool lastMatchWins bool } @@ -110,14 +109,12 @@ func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) Se } } -// WithProtoErrorHandler returns a ServeMuxOption for configuring a custom error handler. +// WithErrorHandler returns a ServeMuxOption for configuring a custom error handler. // -// This can be used to handle an error as general proto message defined by gRPC. -// When this option is used, the mux uses the configured error handler instead of HTTPError and -// OtherErrorHandler. -func WithProtoErrorHandler(fn ProtoErrorHandlerFunc) ServeMuxOption { +// This can be used to configure a custom error response. +func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption { return func(serveMux *ServeMux) { - serveMux.protoErrorHandler = fn + serveMux.errorHandler = fn } } @@ -143,6 +140,7 @@ func NewServeMux(opts ...ServeMuxOption) *ServeMux { handlers: make(map[string][]handler), forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), marshalers: makeMarshalerMIMERegistry(), + errorHandler: defaultHTTPErrorHandler, } for _, opt := range opts { @@ -177,28 +175,22 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { path := r.URL.Path if !strings.HasPrefix(path, "/") { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest)) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.InvalidArgument, http.StatusText(http.StatusBadRequest)) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } components := strings.Split(path[1:], "/") l := len(components) var verb string - if idx := strings.LastIndex(components[l-1], ":"); idx == 0 { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound) - } + idx := strings.LastIndex(components[l-1], ":") + if idx == 0 { + _, outboundMarshaler := MarshalerForRequest(s, r) + s.errorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) return - } else if idx > 0 { + } + if idx > 0 { c := components[l-1] components[l-1], verb = c[:idx], c[idx+1:] } @@ -206,13 +198,9 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) { r.Method = strings.ToUpper(override) if err := r.ParseForm(); err != nil { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - sterr := status.Error(codes.InvalidArgument, err.Error()) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr) - } else { - OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.InvalidArgument, err.Error()) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } } @@ -225,8 +213,7 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // lookup other methods to handle fallback from GET to POST and - // to determine if it is MethodNotAllowed or NotFound. + // lookup other methods to handle fallback from GET to POST. for m, handlers := range s.handlers { if m == r.Method { continue @@ -239,34 +226,22 @@ func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // X-HTTP-Method-Override is optional. Always allow fallback to POST. if s.isPathLengthFallback(r) { if err := r.ParseForm(); err != nil { - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - sterr := status.Error(codes.InvalidArgument, err.Error()) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, sterr) - } else { - OtherErrorHandler(w, r, err.Error(), http.StatusBadRequest) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + sterr := status.Error(codes.InvalidArgument, err.Error()) + s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } h.h(w, r, pathParams) return } - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + s.errorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) return } } - if s.protoErrorHandler != nil { - _, outboundMarshaler := MarshalerForRequest(s, r) - s.protoErrorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) - } else { - OtherErrorHandler(w, r, http.StatusText(http.StatusNotFound), http.StatusNotFound) - } + _, outboundMarshaler := MarshalerForRequest(s, r) + s.errorHandler(ctx, s, outboundMarshaler, w, r, ErrUnknownURI) } // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux. diff --git a/runtime/mux_test.go b/runtime/mux_test.go index 21ea5d143f5..3da476cf536 100644 --- a/runtime/mux_test.go +++ b/runtime/mux_test.go @@ -2,7 +2,6 @@ package runtime_test import ( "bytes" - "context" "fmt" "net/http" "net/http/httptest" @@ -22,6 +21,8 @@ func TestMuxServeHTTP(t *testing.T) { verb string } for _, spec := range []struct { + name string + patterns []stubPattern patternOpts []runtime.PatternOpt @@ -33,16 +34,17 @@ func TestMuxServeHTTP(t *testing.T) { respContent string disablePathLengthFallback bool - errHandler runtime.ProtoErrorHandlerFunc muxOpts []runtime.ServeMuxOption }{ { + name: "GET to unregistered path with no registered paths should return 501 Not Implemented", patterns: nil, reqMethod: "GET", reqPath: "/", - respStatus: http.StatusNotFound, + respStatus: http.StatusNotImplemented, }, { + name: "GET to registered path should return GET 200 OK", patterns: []stubPattern{ { method: "GET", @@ -56,6 +58,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo", }, { + name: "GET to unregistered path with should return 501 Not Implemented", patterns: []stubPattern{ { method: "GET", @@ -65,9 +68,10 @@ func TestMuxServeHTTP(t *testing.T) { }, reqMethod: "GET", reqPath: "/bar", - respStatus: http.StatusNotFound, + respStatus: http.StatusNotImplemented, }, { + name: "GET to registered path with two registered paths should return GET 200 OK", patterns: []stubPattern{ { method: "GET", @@ -85,6 +89,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo", }, { + name: "POST to registered path with GET also registered should return POST 200 OK", patterns: []stubPattern{ { method: "GET", @@ -103,6 +108,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "POST /foo", }, { + name: "DELETE to path with GET registered should return 501 NotImplemented", patterns: []stubPattern{ { method: "GET", @@ -112,9 +118,10 @@ func TestMuxServeHTTP(t *testing.T) { }, reqMethod: "DELETE", reqPath: "/foo", - respStatus: http.StatusMethodNotAllowed, + respStatus: http.StatusNotImplemented, }, { + name: "POST with path length fallback to registered path with GET should return GET 200 OK", patterns: []stubPattern{ { method: "GET", @@ -131,6 +138,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo", }, { + name: "POST with path length fallback to registered path with GET with path length fallback disabled should return 501 Not Implemented", patterns: []stubPattern{ { method: "GET", @@ -143,11 +151,11 @@ func TestMuxServeHTTP(t *testing.T) { headers: map[string]string{ "Content-Type": "application/x-www-form-urlencoded", }, - respStatus: http.StatusMethodNotAllowed, - respContent: "Method Not Allowed\n", + respStatus: http.StatusNotImplemented, disablePathLengthFallback: true, }, { + name: "POST with path length fallback to registered path with POST with path length fallback disabled should return POST 200 OK", patterns: []stubPattern{ { method: "GET", @@ -170,6 +178,7 @@ func TestMuxServeHTTP(t *testing.T) { disablePathLengthFallback: true, }, { + name: "POST with path length fallback and method override to registered path with GET should return GET 200 OK", patterns: []stubPattern{ { method: "GET", @@ -192,6 +201,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo", }, { + name: "POST to registered path with GET should return 501 NotImplemented", patterns: []stubPattern{ { method: "GET", @@ -204,9 +214,10 @@ func TestMuxServeHTTP(t *testing.T) { headers: map[string]string{ "Content-Type": "application/json", }, - respStatus: http.StatusMethodNotAllowed, + respStatus: http.StatusNotImplemented, }, { + name: "POST to registered path with verb should return POST 200 OK", patterns: []stubPattern{ { method: "POST", @@ -224,6 +235,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "POST /foo:bar", }, { + name: "POST to registered path with verb and non-verb should return POST 200 OK", patterns: []stubPattern{ { method: "GET", @@ -246,38 +258,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo/{id=*}:verb", }, { - // mux identifying invalid path results in 'Not Found' status - // (with custom handler looking for ErrUnknownURI) - patterns: []stubPattern{ - { - method: "GET", - ops: []int{int(utilities.OpLitPush), 0}, - pool: []string{"unimplemented"}, - }, - }, - reqMethod: "GET", - reqPath: "/foobar", - respStatus: http.StatusNotFound, - respContent: "GET /foobar", - errHandler: unknownPathIs404, - }, - { - // server returning unimplemented results in 'Not Implemented' code - // even when using custom error handler - patterns: []stubPattern{ - { - method: "GET", - ops: []int{int(utilities.OpLitPush), 0}, - pool: []string{"unimplemented"}, - }, - }, - reqMethod: "GET", - reqPath: "/unimplemented", - respStatus: http.StatusNotImplemented, - respContent: `GET /unimplemented`, - errHandler: unknownPathIs404, - }, - { + name: "GET to registered path without AssumeColonVerb should return GET 200 OK", patterns: []stubPattern{ { method: "GET", @@ -295,6 +276,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo/{id=*}", }, { + name: "GET to registered path without AssumeColonVerb with colon verb should return GET 200 OK", patterns: []stubPattern{ { method: "GET", @@ -312,6 +294,7 @@ func TestMuxServeHTTP(t *testing.T) { respContent: "GET /foo/{id=*}", }, { + name: "POST to registered path without AssumeColonVerb with LastMatchWins should match the correctly", patterns: []stubPattern{ { method: "POST", @@ -336,64 +319,53 @@ func TestMuxServeHTTP(t *testing.T) { muxOpts: []runtime.ServeMuxOption{runtime.WithLastMatchWins()}, }, } { - opts := spec.muxOpts - if spec.disablePathLengthFallback { - opts = append(opts, runtime.WithDisablePathLengthFallback()) - } - if spec.errHandler != nil { - opts = append(opts, runtime.WithProtoErrorHandler(spec.errHandler)) - } - mux := runtime.NewServeMux(opts...) - for _, p := range spec.patterns { - func(p stubPattern) { - pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb, spec.patternOpts...) - if err != nil { - t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, p.verb, err) - } - mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { - if r.URL.Path == "/unimplemented" { - // simulate method returning "unimplemented" error - _, m := runtime.MarshalerForRequest(mux, r) - runtime.HTTPError(r.Context(), mux, m, w, r, status.Error(codes.Unimplemented, http.StatusText(http.StatusNotImplemented))) - w.WriteHeader(http.StatusNotImplemented) - return + t.Run(spec.name, func(t *testing.T) { + opts := spec.muxOpts + if spec.disablePathLengthFallback { + opts = append(opts, runtime.WithDisablePathLengthFallback()) + } + mux := runtime.NewServeMux(opts...) + for _, p := range spec.patterns { + func(p stubPattern) { + pat, err := runtime.NewPattern(1, p.ops, p.pool, p.verb, spec.patternOpts...) + if err != nil { + t.Fatalf("runtime.NewPattern(1, %#v, %#v, %q) failed with %v; want success", p.ops, p.pool, p.verb, err) } - fmt.Fprintf(w, "%s %s", p.method, pat.String()) - }) - }(p) - } - - url := fmt.Sprintf("http://host.example%s", spec.reqPath) - r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil)) - if err != nil { - t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err) - } - for name, value := range spec.headers { - r.Header.Set(name, value) - } - w := httptest.NewRecorder() - mux.ServeHTTP(w, r) + mux.Handle(p.method, pat, func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + if r.URL.Path == "/unimplemented" { + // simulate method returning "unimplemented" error + _, m := runtime.MarshalerForRequest(mux, r) + st := status.New(codes.Unimplemented, http.StatusText(http.StatusNotImplemented)) + m.NewEncoder(w).Encode(st.Proto()) + w.WriteHeader(http.StatusNotImplemented) + return + } + fmt.Fprintf(w, "%s %s", p.method, pat.String()) + }) + }(p) + } - if got, want := w.Code, spec.respStatus; got != want { - t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r) - } - if spec.respContent != "" { - if got, want := w.Body.String(), spec.respContent; got != want { - t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r) + url := fmt.Sprintf("http://host.example%s", spec.reqPath) + r, err := http.NewRequest(spec.reqMethod, url, bytes.NewReader(nil)) + if err != nil { + t.Fatalf("http.NewRequest(%q, %q, nil) failed with %v; want success", spec.reqMethod, url, err) } - } - } -} + for name, value := range spec.headers { + r.Header.Set(name, value) + } + w := httptest.NewRecorder() + mux.ServeHTTP(w, r) -func unknownPathIs404(ctx context.Context, mux *runtime.ServeMux, m runtime.Marshaler, w http.ResponseWriter, r *http.Request, err error) { - if err == runtime.ErrUnknownURI { - w.WriteHeader(http.StatusNotFound) - } else { - c := status.Convert(err).Code() - w.WriteHeader(runtime.HTTPStatusFromCode(c)) + if got, want := w.Code, spec.respStatus; got != want { + t.Errorf("w.Code = %d; want %d; patterns=%v; req=%v", got, want, spec.patterns, r) + } + if spec.respContent != "" { + if got, want := w.Body.String(), spec.respContent; got != want { + t.Errorf("w.Body = %q; want %q; patterns=%v; req=%v", got, want, spec.patterns, r) + } + } + }) } - - fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path) } var defaultHeaderMatcherTests = []struct { diff --git a/runtime/pattern.go b/runtime/pattern.go index c2e4bf956b0..a1bd2496fc9 100644 --- a/runtime/pattern.go +++ b/runtime/pattern.go @@ -21,7 +21,8 @@ type op struct { operand int } -// Pattern is a template pattern of http request paths defined in github.com/googleapis/googleapis/google/api/http.proto. +// Pattern is a template pattern of http request paths defined in +// https://github.com/googleapis/googleapis/blob/master/google/api/http.proto type Pattern struct { // ops is a list of operations ops []op diff --git a/runtime/proto_errors.go b/runtime/proto_errors.go deleted file mode 100644 index b0cf0d0bb3f..00000000000 --- a/runtime/proto_errors.go +++ /dev/null @@ -1,70 +0,0 @@ -package runtime - -import ( - "context" - "io" - "net/http" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/status" -) - -// ProtoErrorHandlerFunc handles the error as a gRPC error generated via status package and replies to the request. -type ProtoErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error) - -var _ ProtoErrorHandlerFunc = DefaultHTTPProtoErrorHandler - -// DefaultHTTPProtoErrorHandler is an implementation of HTTPError. -// If "err" is an error from gRPC system, the function replies with the status code mapped by HTTPStatusFromCode. -// If otherwise, it replies with http.StatusInternalServerError. -// -// The response body returned by this function is a Status message marshaled by a Marshaler. -// -// Do not set this function to HTTPError variable directly, use WithProtoErrorHandler option instead. -func DefaultHTTPProtoErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, _ *http.Request, err error) { - // return Internal when Marshal failed - const fallback = `{"code": 13, "message": "failed to marshal error message"}` - - s, ok := status.FromError(err) - if !ok { - s = status.New(codes.Unknown, err.Error()) - } - - w.Header().Del("Trailer") - - contentType := marshaler.ContentType() - // Check marshaler on run time in order to keep backwards compatability - // An interface param needs to be added to the ContentType() function on - // the Marshal interface to be able to remove this check - if typeMarshaler, ok := marshaler.(contentTypeMarshaler); ok { - pb := s.Proto() - contentType = typeMarshaler.ContentTypeFromMessage(pb) - } - w.Header().Set("Content-Type", contentType) - - buf, merr := marshaler.Marshal(s.Proto()) - if merr != nil { - grpclog.Infof("Failed to marshal error message %q: %v", s.Proto(), merr) - w.WriteHeader(http.StatusInternalServerError) - if _, err := io.WriteString(w, fallback); err != nil { - grpclog.Infof("Failed to write response: %v", err) - } - return - } - - md, ok := ServerMetadataFromContext(ctx) - if !ok { - grpclog.Infof("Failed to extract ServerMetadata from context") - } - - handleForwardResponseServerMetadata(w, mux, md) - handleForwardResponseTrailerHeader(w, md) - st := HTTPStatusFromCode(s.Code()) - w.WriteHeader(st) - if _, err := w.Write(buf); err != nil { - grpclog.Infof("Failed to write response: %v", err) - } - - handleForwardResponseTrailer(w, md) -}