From 72ff71ddc7a412d30897a9e54a8bca83357eec3f Mon Sep 17 00:00:00 2001 From: Akshay Shah Date: Tue, 8 Aug 2023 21:31:12 -0700 Subject: [PATCH] Generate aliases for connect.Request/Response Reduce Connect's generics-induced wordiness by generating type aliases for `connect.Request` and `connect.Response`. For an actual net reduction in wordiness, we can't generate long identifiers. That reduces our ability to manage name collisions, so we only generate aliases for messages that are declared in the same file and used exclusively as requests or responses (but not both). Notably, we don't attempt to generate aliases for the stream types - they end up even wordier than the generic types, and they end up very confusingly named. --- README.md | 5 +- client_ext_test.go | 4 +- cmd/protoc-gen-connect-go/main.go | 259 ++++++++++++++---- connect_ext_test.go | 49 ++-- error_not_modified_example_test.go | 5 +- handler_example_test.go | 5 +- handler_ext_test.go | 5 +- .../v1/collidev1connect/collide.connect.go | 13 +- .../ping/v1/pingv1connect/ping.connect.go | 37 ++- recover_ext_test.go | 7 +- 10 files changed, 270 insertions(+), 119 deletions(-) diff --git a/README.md b/README.md index 266ce081..048784c8 100644 --- a/README.md +++ b/README.md @@ -73,10 +73,7 @@ type PingServer struct { pingv1connect.UnimplementedPingServiceHandler // returns errors from all methods } -func (ps *PingServer) Ping( - ctx context.Context, - req *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (ps *PingServer) Ping(ctx context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { // connect.Request and connect.Response give you direct access to headers and // trailers. No context-based nonsense! log.Println(req.Header().Get("Some-Header")) diff --git a/client_ext_test.go b/client_ext_test.go index 40cbabed..5bf4050b 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -195,9 +195,7 @@ type notModifiedPingServer struct { etag string } -func (s *notModifiedPingServer) Ping( - _ context.Context, - req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (s *notModifiedPingServer) Ping(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { if req.HTTPMethod() == http.MethodGet && req.Header().Get("If-None-Match") == s.etag { return nil, connect.NewNotModifiedError(http.Header{"Etag": []string{s.etag}}) } diff --git a/cmd/protoc-gen-connect-go/main.go b/cmd/protoc-gen-connect-go/main.go index 61aeedcd..5fed8fe5 100644 --- a/cmd/protoc-gen-connect-go/main.go +++ b/cmd/protoc-gen-connect-go/main.go @@ -135,8 +135,12 @@ func generate(plugin *protogen.Plugin, file *protogen.File) { generatedFile.Import(file.GoImportPath) generatePreamble(generatedFile, file) generateServiceNameConstants(generatedFile, file.Services) + + paramNames := newParameterNames(generatedFile, file.Services) + generateTypeAliases(generatedFile, paramNames) + for _, service := range file.Services { - generateService(generatedFile, service) + generateService(generatedFile, service, paramNames) } } @@ -219,16 +223,16 @@ func generateServiceNameConstants(g *protogen.GeneratedFile, services []*protoge g.P() } -func generateService(g *protogen.GeneratedFile, service *protogen.Service) { +func generateService(g *protogen.GeneratedFile, service *protogen.Service, paramNames *parameterNames) { names := newNames(service) - generateClientInterface(g, service, names) - generateClientImplementation(g, service, names) - generateServerInterface(g, service, names) + generateClientInterface(g, service, names, paramNames) + generateClientImplementation(g, service, names, paramNames) + generateServerInterface(g, service, names, paramNames) generateServerConstructor(g, service, names) - generateUnimplementedServerImplementation(g, service, names) + generateUnimplementedServerImplementation(g, service, names, paramNames) } -func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { wrapComments(g, names.Client, " is a client for the ", service.Desc.FullName(), " service.") if isDeprecatedService(service) { g.P("//") @@ -243,13 +247,13 @@ func generateClientInterface(g *protogen.GeneratedFile, service *protogen.Servic method.Comments.Leading, isDeprecatedMethod(method), ) - g.P(clientSignature(g, method, false /* named */)) + g.P(clientSignature(g, method, paramNames.Get(method), false /* named */)) } g.P("}") g.P() } -func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { clientOption := connectPackage.Ident("ClientOption") // Client constructor. @@ -304,11 +308,11 @@ func generateClientImplementation(g *protogen.GeneratedFile, service *protogen.S g.P("}") g.P() for _, method := range service.Methods { - generateClientMethod(g, method, names) + generateClientMethod(g, method, names, paramNames.Get(method)) } } -func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, names names) { +func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, names names, paramNames *methodParameterNames) { receiver := names.ClientImpl isStreamingClient := method.Desc.IsStreamingClient() isStreamingServer := method.Desc.IsStreamingServer() @@ -317,7 +321,7 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na g.P("//") deprecated(g) } - g.P("func (c *", receiver, ") ", clientSignature(g, method, true /* named */), " {") + g.P("func (c *", receiver, ") ", clientSignature(g, method, paramNames, true /* named */), " {") switch { case isStreamingClient && !isStreamingServer: @@ -333,37 +337,31 @@ func generateClientMethod(g *protogen.GeneratedFile, method *protogen.Method, na g.P() } -func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, named bool) string { +func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, paramNames *methodParameterNames, named bool) string { reqName := "req" ctxName := "ctx" if !named { reqName, ctxName = "", "" } + ctxType := g.QualifiedGoIdent(contextPackage.Ident("Context")) if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { // bidi streaming - return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + - "*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + return method.GoName + "(" + ctxName + " " + ctxType + ") *" + paramNames.ClientOutput.Name() } if method.Desc.IsStreamingClient() { // client streaming - return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ") " + - "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + return method.GoName + "(" + ctxName + " " + ctxType + ") *" + paramNames.ClientOutput.Name() } if method.Desc.IsStreamingServer() { - return method.GoName + "(" + ctxName + " " + g.QualifiedGoIdent(contextPackage.Ident("Context")) + - ", " + reqName + " *" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + - g.QualifiedGoIdent(method.Input.GoIdent) + "]) " + - "(*" + g.QualifiedGoIdent(connectPackage.Ident("ServerStreamForClient")) + - "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + - ", error)" + return method.GoName + "(" + ctxName + " " + ctxType + + ", " + reqName + " *" + paramNames.ClientInput.Name() + ") " + + "(*" + paramNames.ClientOutput.Name() + ", error)" } // unary; symmetric so we can re-use server templating - return method.GoName + serverSignatureParams(g, method, named) + return method.GoName + serverSignatureParams(g, method, paramNames, named) } -func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { wrapComments(g, names.Server, " is an implementation of the ", service.Desc.FullName(), " service.") if isDeprecatedService(service) { g.P("//") @@ -378,7 +376,7 @@ func generateServerInterface(g *protogen.GeneratedFile, service *protogen.Servic isDeprecatedMethod(method), ) g.AnnotateSymbol(names.Server+"."+method.GoName, protogen.Annotation{Location: method.Location}) - g.P(serverSignature(g, method)) + g.P(serverSignature(g, method, paramNames.Get(method))) } g.P("}") g.P() @@ -439,12 +437,12 @@ func generateServerConstructor(g *protogen.GeneratedFile, service *protogen.Serv g.P() } -func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names) { +func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, service *protogen.Service, names names, paramNames *parameterNames) { wrapComments(g, names.UnimplementedServer, " returns CodeUnimplemented from all methods.") g.P("type ", names.UnimplementedServer, " struct {}") g.P() for _, method := range service.Methods { - g.P("func (", names.UnimplementedServer, ") ", serverSignature(g, method), "{") + g.P("func (", names.UnimplementedServer, ") ", serverSignature(g, method, paramNames.Get(method)), "{") if method.Desc.IsStreamingServer() { g.P("return ", connectPackage.Ident("NewError"), "(", connectPackage.Ident("CodeUnimplemented"), ", ", errorsPackage.Ident("New"), @@ -460,46 +458,47 @@ func generateUnimplementedServerImplementation(g *protogen.GeneratedFile, servic g.P() } -func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string { - return method.GoName + serverSignatureParams(g, method, false /* named */) +func generateTypeAliases(g *protogen.GeneratedFile, paramNames *parameterNames) { + if len(paramNames.Aliases) == 0 { + return + } + g.P("type (") + for _, alias := range paramNames.Aliases { + g.P(alias[0], " = ", alias[1]) + } + g.P(")") + g.P() } -func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, named bool) string { +func serverSignature(g *protogen.GeneratedFile, method *protogen.Method, paramNames *methodParameterNames) string { + return method.GoName + serverSignatureParams(g, method, paramNames, false /* named */) +} + +func serverSignatureParams(g *protogen.GeneratedFile, method *protogen.Method, paramNames *methodParameterNames, named bool) string { ctxName := "ctx " reqName := "req " streamName := "stream " if !named { ctxName, reqName, streamName = "", "", "" } + ctxType := g.QualifiedGoIdent(contextPackage.Ident("Context")) if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { // bidi streaming - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + - streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("BidiStream")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + - ") error" + return "(" + ctxName + ctxType + ", " + streamName + "*" + paramNames.HandlerInput.Name() + ") error" } if method.Desc.IsStreamingClient() { // client streaming - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + ", " + - streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + - "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + - ") (*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "] ,error)" + return "(" + ctxName + ctxType + ", " + streamName + "*" + paramNames.HandlerInput.Name() + + ") (*" + paramNames.HandlerOutput.Name() + " ,error)" } if method.Desc.IsStreamingServer() { // server streaming - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + - ", " + reqName + "*" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + - g.QualifiedGoIdent(method.Input.GoIdent) + "], " + - streamName + "*" + g.QualifiedGoIdent(connectPackage.Ident("ServerStream")) + - "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + - ") error" + return "(" + ctxName + ctxType + ", " + reqName + "*" + paramNames.HandlerInput.Name() + ", " + + streamName + "*" + paramNames.HandlerOutput.Name() + ") error" } // unary - return "(" + ctxName + g.QualifiedGoIdent(contextPackage.Ident("Context")) + - ", " + reqName + "*" + g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + - g.QualifiedGoIdent(method.Input.GoIdent) + "]) " + - "(*" + g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + - g.QualifiedGoIdent(method.Output.GoIdent) + "], error)" + return "(" + ctxName + ctxType + ", " + reqName + "*" + paramNames.HandlerInput.Name() + ") " + + "(*" + paramNames.HandlerOutput.Name() + ", error)" } func procedureConstName(m *protogen.Method) string { @@ -628,3 +627,159 @@ func newNames(service *protogen.Service) names { UnimplementedServer: fmt.Sprintf("Unimplemented%sHandler", base), } } + +type parameterNames struct { + Aliases [][2]string + Methods map[protoreflect.FullName]*methodParameterNames +} + +func newParameterNames(g *protogen.GeneratedFile, services []*protogen.Service) *parameterNames { //nolint:gocyclo + // First, make one pass to find alias-able request and response types. We're + // trying to shorten user-visible type names, so there's no point in + // producing aliases that are just as long as the spelled-out generic types. + // + // To safely produce short aliases, we're only aliasing messages that are: + // - used as a connect.Request or connect.Response, but not both. + // - from the same protobuf package and file as the service. + // Ideally we'd allow aliases for types from different files in the same + // package, but the plugin contract doesn't allow us to inspect services in + // files other than the ones we're generating code for. + // + // Notably, we're not generating aliases for Connect's stream types: useful + // aliases for them are just as wordy as the generic types, so the extra + // indirection isn't worth it. + const ( + asRequest = 0b01 + asResponse = 0b10 + ) + aliasable := make(map[protoreflect.FullName]uint8) + for _, service := range services { + pkg := service.Desc.ParentFile().Package() + path := service.Desc.ParentFile().Path() + for _, method := range service.Methods { + if method.Desc.IsStreamingClient() && method.Desc.IsStreamingServer() { + continue + } + if method.Input.Desc.ParentFile().Package() == pkg && method.Input.Desc.ParentFile().Path() == path { + aliasable[method.Input.Desc.FullName()] |= asRequest + } + if method.Output.Desc.ParentFile().Package() == pkg && method.Input.Desc.ParentFile().Path() == path { + aliasable[method.Output.Desc.FullName()] |= asResponse + } + } + } + for fqn, usage := range aliasable { + if usage == asRequest&asResponse { + delete(aliasable, fqn) + } + } + // Now, make another pass to choose names. + params := ¶meterNames{Methods: make(map[protoreflect.FullName]*methodParameterNames)} + for _, service := range services { + for _, method := range service.Methods { + isStreamingClient := method.Desc.IsStreamingClient() + isStreamingServer := method.Desc.IsStreamingServer() + methodParams := &methodParameterNames{} + switch { + case isStreamingClient && isStreamingServer: + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("BidiStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("BidiStream")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + case isStreamingClient: + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ClientStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + + ", " + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ClientStream")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.HandlerOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Response")) + + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + if _, ok := aliasable[method.Output.Desc.FullName()]; ok { + methodParams.HandlerOutput.Alias = method.Output.GoIdent.GoName + } + case isStreamingServer: + methodParams.ClientInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Request")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ServerStreamForClient")) + + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Request")) + + "[" + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.HandlerOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("ServerStream")) + + "[" + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + if _, ok := aliasable[method.Input.Desc.FullName()]; ok { + methodParams.ClientInput.Alias = method.Input.GoIdent.GoName + methodParams.HandlerInput.Alias = methodParams.ClientInput.Alias + } + default: + methodParams.ClientInput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Request")) + "[" + + g.QualifiedGoIdent(method.Input.GoIdent) + "]" + methodParams.ClientOutput.Generic = g.QualifiedGoIdent(connectPackage.Ident("Response")) + "[" + + g.QualifiedGoIdent(method.Output.GoIdent) + "]" + methodParams.HandlerInput.Generic = methodParams.ClientInput.Generic + methodParams.HandlerOutput.Generic = methodParams.ClientOutput.Generic + if _, ok := aliasable[method.Input.Desc.FullName()]; ok { + methodParams.ClientInput.Alias = method.Input.GoIdent.GoName + methodParams.HandlerInput.Alias = methodParams.ClientInput.Alias + } + if _, ok := aliasable[method.Output.Desc.FullName()]; ok { + methodParams.ClientOutput.Alias = method.Output.GoIdent.GoName + methodParams.HandlerOutput.Alias = methodParams.ClientOutput.Alias + } + } + params.Methods[method.Desc.FullName()] = methodParams + } + } + // Finally, another pass to prepare the actual alias declarations. We need to + // deduplicate (in case the same message is used in multiple RPCs), and we'd + // like the aliases to appear in the same order as they're used in the RPC + // definitions. + for _, service := range services { + for _, method := range service.Methods { + methodParams := params.Get(method) + if _, ok := aliasable[method.Input.Desc.FullName()]; ok { + if methodParams.ClientInput.Alias != "" { + params.Aliases = append(params.Aliases, [2]string{methodParams.ClientInput.Alias, methodParams.ClientInput.Generic}) + } + if methodParams.HandlerInput.Alias != "" && methodParams.HandlerInput.Alias != methodParams.ClientInput.Alias { + params.Aliases = append(params.Aliases, [2]string{methodParams.HandlerInput.Alias, methodParams.HandlerInput.Generic}) + } + delete(aliasable, method.Input.Desc.FullName()) + } + if _, ok := aliasable[method.Output.Desc.FullName()]; ok { + if methodParams.ClientOutput.Alias != "" { + params.Aliases = append(params.Aliases, [2]string{methodParams.ClientOutput.Alias, methodParams.ClientOutput.Generic}) + } + if methodParams.HandlerOutput.Alias != "" && methodParams.HandlerOutput.Alias != methodParams.ClientOutput.Alias { + params.Aliases = append(params.Aliases, [2]string{methodParams.HandlerOutput.Alias, methodParams.HandlerOutput.Generic}) + } + delete(aliasable, method.Output.Desc.FullName()) + } + } + } + return params +} + +func (pn *parameterNames) Get(method *protogen.Method) *methodParameterNames { + return pn.Methods[method.Desc.FullName()] +} + +type methodParameterNames struct { + ClientInput aliasedTypeName + ClientOutput aliasedTypeName + HandlerInput aliasedTypeName + HandlerOutput aliasedTypeName +} + +type aliasedTypeName struct { + Generic string + Alias string +} + +func (n aliasedTypeName) Name() string { + if n.Alias != "" { + return n.Alias + } + return n.Generic +} diff --git a/connect_ext_test.go b/connect_ext_test.go index 4545071c..d75a2996 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -501,7 +501,7 @@ func TestHeaderBasic(t *testing.T) { ) pingServer := &pluggablePingServer{ - ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { assert.Equal(t, request.Header().Get(key), cval) response := connect.NewResponse(&pingv1.PingResponse{}) response.Header().Set(key, hval) @@ -529,7 +529,7 @@ func TestHeaderHost(t *testing.T) { ) pingServer := &pluggablePingServer{ - ping: func(_ context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(_ context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { assert.Equal(t, request.Header().Get(key), cval) response := connect.NewResponse(&pingv1.PingResponse{}) return response, nil @@ -583,7 +583,7 @@ func TestTimeoutParsing(t *testing.T) { t.Parallel() const timeout = 10 * time.Minute pingServer := &pluggablePingServer{ - ping: func(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { deadline, ok := ctx.Deadline() assert.True(t, ok) remaining := time.Until(deadline) @@ -1597,7 +1597,7 @@ func TestStreamForServer(t *testing.T) { t.Run("server-stream", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Equal(t, stream.Conn().Spec().StreamType, connect.StreamTypeServer) assert.Equal(t, stream.Conn().Spec().Procedure, pingv1connect.PingServiceCountUpProcedure) assert.False(t, stream.Conn().Spec().IsClient) @@ -1614,7 +1614,7 @@ func TestStreamForServer(t *testing.T) { t.Run("server-stream-send", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error { assert.Nil(t, stream.Send(&pingv1.CountUpResponse{Number: 1})) return nil }, @@ -1631,7 +1631,7 @@ func TestStreamForServer(t *testing.T) { t.Run("server-stream-send-nil", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - countUp: func(ctx context.Context, req *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + countUp: func(ctx context.Context, req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse]) error { stream.ResponseHeader().Set("foo", "bar") stream.ResponseTrailer().Set("bas", "blah") assert.Nil(t, stream.Send(nil)) @@ -1653,7 +1653,7 @@ func TestStreamForServer(t *testing.T) { t.Run("client-stream", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { assert.Equal(t, stream.Spec().StreamType, connect.StreamTypeClient) assert.Equal(t, stream.Spec().Procedure, pingv1connect.PingServiceSumProcedure) assert.False(t, stream.Spec().IsClient) @@ -1675,7 +1675,7 @@ func TestStreamForServer(t *testing.T) { t.Run("client-stream-conn", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { assert.NotNil(t, stream.Conn().Send("not-proto")) return connect.NewResponse(&pingv1.SumResponse{}), nil }, @@ -1690,7 +1690,7 @@ func TestStreamForServer(t *testing.T) { t.Run("client-stream-send-msg", func(t *testing.T) { t.Parallel() client, server := newPingServer(&pluggablePingServer{ - sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { assert.Nil(t, stream.Conn().Send(&pingv1.SumResponse{Sum: 2})) return connect.NewResponse(&pingv1.SumResponse{}), nil }, @@ -1711,7 +1711,7 @@ func TestConnectHTTPErrorCodes(t *testing.T) { t.Helper() mux := http.NewServeMux() pluggableServer := &pluggablePingServer{ - ping: func(_ context.Context, _ *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(_ context.Context, _ *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return nil, connect.NewError(connectCode, errors.New("error")) }, } @@ -1993,7 +1993,7 @@ func TestAllowCustomUserAgent(t *testing.T) { const customAgent = "custom" mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ - ping: func(_ context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { agent := req.Header().Get("User-Agent") assert.Equal(t, agent, customAgent) return connect.NewResponse(&pingv1.PingResponse{Number: req.Msg.Number}), nil @@ -2063,10 +2063,10 @@ func TestHandlerReturnsNilResponse(t *testing.T) { mux := http.NewServeMux() mux.Handle(pingv1connect.NewPingServiceHandler(&pluggablePingServer{ - ping: func(ctx context.Context, req *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { + ping: func(ctx context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return nil, nil //nolint: nilnil }, - sum: func(ctx context.Context, req *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + sum: func(ctx context.Context, req *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) { return nil, nil //nolint: nilnil }, }, connect.WithRecover(recoverPanic))) @@ -2353,29 +2353,26 @@ func (c failCodec) Unmarshal(data []byte, message any) error { type pluggablePingServer struct { pingv1connect.UnimplementedPingServiceHandler - ping func(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) - sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) - countUp func(context.Context, *connect.Request[pingv1.CountUpRequest], *connect.ServerStream[pingv1.CountUpResponse]) error + ping func(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) + sum func(context.Context, *connect.ClientStream[pingv1.SumRequest]) (*pingv1connect.SumResponse, error) + countUp func(context.Context, *pingv1connect.CountUpRequest, *connect.ServerStream[pingv1.CountUpResponse]) error cumSum func(context.Context, *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error } -func (p *pluggablePingServer) Ping( - ctx context.Context, - request *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (p *pluggablePingServer) Ping(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return p.ping(ctx, request) } func (p *pluggablePingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], -) (*connect.Response[pingv1.SumResponse], error) { +) (*pingv1connect.SumResponse, error) { return p.sum(ctx, stream) } func (p *pluggablePingServer) CountUp( ctx context.Context, - req *connect.Request[pingv1.CountUpRequest], + req *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { return p.countUp(ctx, req, stream) @@ -2431,7 +2428,7 @@ type pingServer struct { checkMetadata bool } -func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { +func (p pingServer) Ping(ctx context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2452,7 +2449,7 @@ func (p pingServer) Ping(ctx context.Context, request *connect.Request[pingv1.Pi return response, nil } -func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.FailRequest]) (*connect.Response[pingv1.FailResponse], error) { +func (p pingServer) Fail(ctx context.Context, request *pingv1connect.FailRequest) (*pingv1connect.FailResponse, error) { if err := expectClientHeader(p.checkMetadata, request); err != nil { return nil, err } @@ -2471,7 +2468,7 @@ func (p pingServer) Fail(ctx context.Context, request *connect.Request[pingv1.Fa func (p pingServer) Sum( ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest], -) (*connect.Response[pingv1.SumResponse], error) { +) (*pingv1connect.SumResponse, error) { if p.checkMetadata { if err := expectMetadata(stream.RequestHeader(), "header", clientHeader, headerValue); err != nil { return nil, err @@ -2498,7 +2495,7 @@ func (p pingServer) Sum( func (p pingServer) CountUp( ctx context.Context, - request *connect.Request[pingv1.CountUpRequest], + request *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { if err := expectClientHeader(p.checkMetadata, request); err != nil { diff --git a/error_not_modified_example_test.go b/error_not_modified_example_test.go index 881a4b45..b1f05f84 100644 --- a/error_not_modified_example_test.go +++ b/error_not_modified_example_test.go @@ -34,10 +34,7 @@ type ExampleCachingPingServer struct { // Ping is idempotent and free of side effects (and the Protobuf schema // indicates this), so clients using the Connect protocol may call it with HTTP // GET requests. This implementation uses Etags to manage client-side caching. -func (*ExampleCachingPingServer) Ping( - _ context.Context, - req *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (*ExampleCachingPingServer) Ping(_ context.Context, req *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { resp := connect.NewResponse(&pingv1.PingResponse{ Number: req.Msg.Number, }) diff --git a/handler_example_test.go b/handler_example_test.go index f5c2c0da..2d0666d6 100644 --- a/handler_example_test.go +++ b/handler_example_test.go @@ -30,10 +30,7 @@ type ExamplePingServer struct { } // Ping implements pingv1connect.PingServiceHandler. -func (*ExamplePingServer) Ping( - _ context.Context, - request *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (*ExamplePingServer) Ping(_ context.Context, request *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { return connect.NewResponse( &pingv1.PingResponse{ Number: request.Msg.Number, diff --git a/handler_ext_test.go b/handler_ext_test.go index ca71712e..ea3052b0 100644 --- a/handler_ext_test.go +++ b/handler_ext_test.go @@ -24,7 +24,6 @@ import ( connect "connectrpc.com/connect" "connectrpc.com/connect/internal/assert" - pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" "connectrpc.com/connect/internal/gen/connect/ping/v1/pingv1connect" ) @@ -213,6 +212,6 @@ type successPingServer struct { pingv1connect.UnimplementedPingServiceHandler } -func (successPingServer) Ping(context.Context, *connect.Request[pingv1.PingRequest]) (*connect.Response[pingv1.PingResponse], error) { - return &connect.Response[pingv1.PingResponse]{}, nil +func (successPingServer) Ping(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { + return &pingv1connect.PingResponse{}, nil } diff --git a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go index e0ae47ae..483e1c4b 100644 --- a/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go +++ b/internal/gen/connect/collide/v1/collidev1connect/collide.connect.go @@ -51,9 +51,14 @@ const ( CollideServiceImportProcedure = "/connect.collide.v1.CollideService/Import" ) +type ( + ImportRequest = connect.Request[v1.ImportRequest] + ImportResponse = connect.Response[v1.ImportResponse] +) + // CollideServiceClient is a client for the connect.collide.v1.CollideService service. type CollideServiceClient interface { - Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) + Import(context.Context, *ImportRequest) (*ImportResponse, error) } // NewCollideServiceClient constructs a client for the connect.collide.v1.CollideService service. By @@ -80,13 +85,13 @@ type collideServiceClient struct { } // Import calls connect.collide.v1.CollideService.Import. -func (c *collideServiceClient) Import(ctx context.Context, req *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) { +func (c *collideServiceClient) Import(ctx context.Context, req *ImportRequest) (*ImportResponse, error) { return c._import.CallUnary(ctx, req) } // CollideServiceHandler is an implementation of the connect.collide.v1.CollideService service. type CollideServiceHandler interface { - Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) + Import(context.Context, *ImportRequest) (*ImportResponse, error) } // NewCollideServiceHandler builds an HTTP handler from the service implementation. It returns the @@ -113,6 +118,6 @@ func NewCollideServiceHandler(svc CollideServiceHandler, opts ...connect.Handler // UnimplementedCollideServiceHandler returns CodeUnimplemented from all methods. type UnimplementedCollideServiceHandler struct{} -func (UnimplementedCollideServiceHandler) Import(context.Context, *connect.Request[v1.ImportRequest]) (*connect.Response[v1.ImportResponse], error) { +func (UnimplementedCollideServiceHandler) Import(context.Context, *ImportRequest) (*ImportResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.collide.v1.CollideService.Import is not implemented")) } diff --git a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go index 7a99236c..5fc78432 100644 --- a/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go +++ b/internal/gen/connect/ping/v1/pingv1connect/ping.connect.go @@ -64,16 +64,25 @@ const ( PingServiceCumSumProcedure = "/connect.ping.v1.PingService/CumSum" ) +type ( + PingRequest = connect.Request[v1.PingRequest] + PingResponse = connect.Response[v1.PingResponse] + FailRequest = connect.Request[v1.FailRequest] + FailResponse = connect.Response[v1.FailResponse] + SumResponse = connect.Response[v1.SumResponse] + CountUpRequest = connect.Request[v1.CountUpRequest] +) + // PingServiceClient is a client for the connect.ping.v1.PingService service. type PingServiceClient interface { // Ping sends a ping to the server to determine if it's reachable. - Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) + Ping(context.Context, *PingRequest) (*PingResponse, error) // Fail always fails. - Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) + Fail(context.Context, *FailRequest) (*FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. Sum(context.Context) *connect.ClientStreamForClient[v1.SumRequest, v1.SumResponse] // CountUp returns a stream of the numbers up to the given request. - CountUp(context.Context, *connect.Request[v1.CountUpRequest]) (*connect.ServerStreamForClient[v1.CountUpResponse], error) + CountUp(context.Context, *CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) // CumSum determines the cumulative sum of all the numbers sent on the stream. CumSum(context.Context) *connect.BidiStreamForClient[v1.CumSumRequest, v1.CumSumResponse] } @@ -127,12 +136,12 @@ type pingServiceClient struct { } // Ping calls connect.ping.v1.PingService.Ping. -func (c *pingServiceClient) Ping(ctx context.Context, req *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) { +func (c *pingServiceClient) Ping(ctx context.Context, req *PingRequest) (*PingResponse, error) { return c.ping.CallUnary(ctx, req) } // Fail calls connect.ping.v1.PingService.Fail. -func (c *pingServiceClient) Fail(ctx context.Context, req *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) { +func (c *pingServiceClient) Fail(ctx context.Context, req *FailRequest) (*FailResponse, error) { return c.fail.CallUnary(ctx, req) } @@ -142,7 +151,7 @@ func (c *pingServiceClient) Sum(ctx context.Context) *connect.ClientStreamForCli } // CountUp calls connect.ping.v1.PingService.CountUp. -func (c *pingServiceClient) CountUp(ctx context.Context, req *connect.Request[v1.CountUpRequest]) (*connect.ServerStreamForClient[v1.CountUpResponse], error) { +func (c *pingServiceClient) CountUp(ctx context.Context, req *CountUpRequest) (*connect.ServerStreamForClient[v1.CountUpResponse], error) { return c.countUp.CallServerStream(ctx, req) } @@ -154,13 +163,13 @@ func (c *pingServiceClient) CumSum(ctx context.Context) *connect.BidiStreamForCl // PingServiceHandler is an implementation of the connect.ping.v1.PingService service. type PingServiceHandler interface { // Ping sends a ping to the server to determine if it's reachable. - Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) + Ping(context.Context, *PingRequest) (*PingResponse, error) // Fail always fails. - Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) + Fail(context.Context, *FailRequest) (*FailResponse, error) // Sum calculates the sum of the numbers sent on the stream. - Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) + Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*SumResponse, error) // CountUp returns a stream of the numbers up to the given request. - CountUp(context.Context, *connect.Request[v1.CountUpRequest], *connect.ServerStream[v1.CountUpResponse]) error + CountUp(context.Context, *CountUpRequest, *connect.ServerStream[v1.CountUpResponse]) error // CumSum determines the cumulative sum of all the numbers sent on the stream. CumSum(context.Context, *connect.BidiStream[v1.CumSumRequest, v1.CumSumResponse]) error } @@ -218,19 +227,19 @@ func NewPingServiceHandler(svc PingServiceHandler, opts ...connect.HandlerOption // UnimplementedPingServiceHandler returns CodeUnimplemented from all methods. type UnimplementedPingServiceHandler struct{} -func (UnimplementedPingServiceHandler) Ping(context.Context, *connect.Request[v1.PingRequest]) (*connect.Response[v1.PingResponse], error) { +func (UnimplementedPingServiceHandler) Ping(context.Context, *PingRequest) (*PingResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Ping is not implemented")) } -func (UnimplementedPingServiceHandler) Fail(context.Context, *connect.Request[v1.FailRequest]) (*connect.Response[v1.FailResponse], error) { +func (UnimplementedPingServiceHandler) Fail(context.Context, *FailRequest) (*FailResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Fail is not implemented")) } -func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*connect.Response[v1.SumResponse], error) { +func (UnimplementedPingServiceHandler) Sum(context.Context, *connect.ClientStream[v1.SumRequest]) (*SumResponse, error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.Sum is not implemented")) } -func (UnimplementedPingServiceHandler) CountUp(context.Context, *connect.Request[v1.CountUpRequest], *connect.ServerStream[v1.CountUpResponse]) error { +func (UnimplementedPingServiceHandler) CountUp(context.Context, *CountUpRequest, *connect.ServerStream[v1.CountUpResponse]) error { return connect.NewError(connect.CodeUnimplemented, errors.New("connect.ping.v1.PingService.CountUp is not implemented")) } diff --git a/recover_ext_test.go b/recover_ext_test.go index 99df97c9..54b37319 100644 --- a/recover_ext_test.go +++ b/recover_ext_test.go @@ -33,16 +33,13 @@ type panicPingServer struct { panicWith any } -func (s *panicPingServer) Ping( - context.Context, - *connect.Request[pingv1.PingRequest], -) (*connect.Response[pingv1.PingResponse], error) { +func (s *panicPingServer) Ping(context.Context, *pingv1connect.PingRequest) (*pingv1connect.PingResponse, error) { panic(s.panicWith) //nolint:forbidigo } func (s *panicPingServer) CountUp( _ context.Context, - _ *connect.Request[pingv1.CountUpRequest], + _ *pingv1connect.CountUpRequest, stream *connect.ServerStream[pingv1.CountUpResponse], ) error { if err := stream.Send(&pingv1.CountUpResponse{}); err != nil {