From ed3ecc8d32f5f755cb011e8f79c9393e6dd2bc7f Mon Sep 17 00:00:00 2001 From: Danil Syromolotov Date: Fri, 29 Nov 2024 11:04:21 +0500 Subject: [PATCH] add support for `func`-types to `fx.As()` --- annotated.go | 31 ++++++++++++++++++++++++++----- annotated_test.go | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/annotated.go b/annotated.go index 0e72860e3..4bc20fd39 100644 --- a/annotated.go +++ b/annotated.go @@ -28,6 +28,7 @@ import ( "strings" "go.uber.org/dig" + "go.uber.org/fx/internal/fxreflect" ) @@ -1145,6 +1146,19 @@ var _ Annotation = (*asAnnotation)(nil) // constructor does NOT provide both bytes.Buffer and io.Writer type; it just // provides io.Writer type. // +// Example for function-types: +// +// type domainHandler func(ctx context.Context) error +// +// func anyHandlerProvider() func(ctx context.Context) error { +// ... +// } +// +// fx.Provider( +// anyHandlerProvider(), +// fx.As(new(domainHandler)), +// ) +// // When multiple values are returned by the annotated function, each type // gets mapped to corresponding positional result of the annotated function. // @@ -1211,8 +1225,8 @@ func (at *asAnnotation) apply(ann *annotated) error { continue } t := reflect.TypeOf(typ) - if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Interface { - return fmt.Errorf("fx.As: argument must be a pointer to an interface: got %v", t) + if t.Kind() != reflect.Ptr || !(t.Elem().Kind() == reflect.Interface || t.Elem().Kind() == reflect.Func) { + return fmt.Errorf("fx.As: argument must be a pointer to an interface or function: got %v", t) } t = t.Elem() at.types[i] = asType{typ: t} @@ -1265,8 +1279,11 @@ func (at *asAnnotation) results(ann *annotated) ( continue } - if !t.Implements(at.types[i].typ) { - return nil, nil, fmt.Errorf("invalid fx.As: %v does not implement %v", t, at.types[i]) + if !((at.types[i].typ.Kind() == reflect.Interface && t.Implements(at.types[i].typ)) || + t.ConvertibleTo(at.types[i].typ)) { + return nil, + nil, + fmt.Errorf("invalid fx.As: %v does not implement or is not convertible to %v", t, at.types[i]) } field.Type = at.types[i].typ fields = append(fields, field) @@ -1300,7 +1317,11 @@ func (at *asAnnotation) results(ann *annotated) ( newOutResult := reflect.New(resType).Elem() for i := 1; i < resType.NumField(); i++ { - newOutResult.Field(i).Set(getResult(i, results)) + if newOutResult.Field(i).Kind() == reflect.Func { + newOutResult.Field(i).Set(getResult(i, results).Convert(newOutResult.Field(i).Type())) + } else { + newOutResult.Field(i).Set(getResult(i, results)) + } } outResults = append(outResults, newOutResult) diff --git a/annotated_test.go b/annotated_test.go index e7defa76c..449a3c780 100644 --- a/annotated_test.go +++ b/annotated_test.go @@ -33,6 +33,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/fx" "go.uber.org/fx/fxevent" "go.uber.org/fx/fxtest" @@ -442,6 +443,8 @@ func TestAnnotatedAs(t *testing.T) { type myStringer interface { String() string } + type myProvideFunc func() string + type myInvokeFunc func() string newAsStringer := func() *asStringer { return &asStringer{ @@ -477,6 +480,32 @@ func TestAnnotatedAs(t *testing.T) { assert.Equal(t, s.String(), "another stringer") }, }, + { + desc: "value type convertible to target type", + provide: fx.Provide( + fx.Annotate(func() myProvideFunc { + return func() string { + return "provide func example" + } + }, fx.As(new(myInvokeFunc))), + ), + invoke: func(h myInvokeFunc) { + assert.Equal(t, "provide func example", h()) + }, + }, + { + desc: "anonymous value type convertible to target type", + provide: fx.Provide( + fx.Annotate(func() func() string { + return func() string { + return "anonymous func example" + } + }, fx.As(new(myInvokeFunc))), + ), + invoke: func(h myInvokeFunc) { + assert.Equal(t, "anonymous func example", h()) + }, + }, { desc: "provide with multiple types As", provide: fx.Provide(fx.Annotate(func() (*asStringer, *bytes.Buffer) { @@ -1806,9 +1835,9 @@ func TestAnnotateApplySuccess(t *testing.T) { func assertApp( t *testing.T, app interface { - Start(context.Context) error - Stop(context.Context) error - }, + Start(context.Context) error + Stop(context.Context) error +}, started *bool, stopped *bool, invoked *bool,