Skip to content

Commit

Permalink
feat(depinject): key resolvers for interface types (#12103)
Browse files Browse the repository at this point in the history
* Rough draft of key resolvers for interface types

* Add unit test and empty key guard in getResolver

* clean up empty key check in getResolvers
  • Loading branch information
kocubinski authored Jun 1, 2022
1 parent 5e50def commit c4934b7
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 11 deletions.
39 changes: 28 additions & 11 deletions depinject/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import (
type container struct {
*debugConfig

resolvers map[reflect.Type]resolver
resolvers map[reflect.Type]resolver
keyedResolvers map[string]resolver

moduleKeys map[string]*moduleKey

Expand All @@ -29,11 +30,12 @@ type resolveFrame struct {

func newContainer(cfg *debugConfig) *container {
return &container{
debugConfig: cfg,
resolvers: map[reflect.Type]resolver{},
moduleKeys: map[string]*moduleKey{},
callerStack: nil,
callerMap: map[Location]bool{},
debugConfig: cfg,
resolvers: map[reflect.Type]resolver{},
keyedResolvers: map[string]resolver{},
moduleKeys: map[string]*moduleKey{},
callerStack: nil,
callerMap: map[Location]bool{},
}
}

Expand Down Expand Up @@ -76,7 +78,13 @@ func (c *container) call(provider *ProviderDescriptor, moduleKey *moduleKey) ([]
return out, nil
}

func (c *container) getResolver(typ reflect.Type) (resolver, error) {
func (c *container) getResolver(typ reflect.Type, key string) (resolver, error) {
if key != "" {
if vr, ok := c.keyedResolvers[key]; ok {
return vr, nil
}
}

if vr, ok := c.resolvers[typ]; ok {
return vr, nil
}
Expand Down Expand Up @@ -147,7 +155,7 @@ func (c *container) addNode(provider *ProviderDescriptor, key *moduleKey) (inter
return nil, fmt.Errorf("one-per-module type %v can't be used as an input parameter", typ)
}

vr, err := c.getResolver(typ)
vr, err := c.getResolver(typ, in.Key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -189,7 +197,7 @@ func (c *container) addNode(provider *ProviderDescriptor, key *moduleKey) (inter
typ = typ.Elem()
}

vr, err := c.getResolver(typ)
vr, err := c.getResolver(typ, out.Key)
if err != nil {
return nil, err
}
Expand All @@ -211,6 +219,10 @@ func (c *container) addNode(provider *ProviderDescriptor, key *moduleKey) (inter
idxInValues: i,
}
c.resolvers[typ] = vr

if out.Key != "" {
c.keyedResolvers[out.Key] = vr
}
}

c.addGraphEdge(providerGraphNode, vr.typeGraphNode())
Expand Down Expand Up @@ -245,13 +257,18 @@ func (c *container) addNode(provider *ProviderDescriptor, key *moduleKey) (inter
}

typeGraphNode := c.typeGraphNode(typ)
c.resolvers[typ] = &moduleDepResolver{
mdr := &moduleDepResolver{
typ: typ,
idxInValues: i,
node: node,
valueMap: map[*moduleKey]reflect.Value{},
graphNode: typeGraphNode,
}
c.resolvers[typ] = mdr

if out.Key != "" {
c.keyedResolvers[out.Key] = mdr
}

c.addGraphEdge(providerGraphNode, typeGraphNode)
}
Expand Down Expand Up @@ -304,7 +321,7 @@ func (c *container) resolve(in ProviderInput, moduleKey *moduleKey, caller Locat
return reflect.ValueOf(OwnModuleKey{moduleKey}), nil
}

vr, err := c.getResolver(in.Type)
vr, err := c.getResolver(in.Type, in.Key)
if err != nil {
return reflect.Value{}, err
}
Expand Down
42 changes: 42 additions & 0 deletions depinject/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,45 @@ func TestConditionalDebugging(t *testing.T) {
require.Empty(t, logs)
require.True(t, success)
}

type Duck interface {
quack()
}

type AlsoDuck interface {
quack()
}

type Mallard struct{}

func (duck Mallard) quack() {}

type KeyedOutput struct {
depinject.Out
Duck Duck `key:"foo"`
}

type KeyedInput struct {
depinject.In
AlsoDuck AlsoDuck `key:"foo"`
}

type Pond struct {
Duck AlsoDuck
}

func TestKeyedInputOutput(t *testing.T) {
var pond Pond

require.NoError(t,
depinject.Inject(
depinject.Provide(
func() KeyedOutput { return KeyedOutput{Duck: Mallard{}} },
func(in KeyedInput) Pond {
require.NotNil(t, in.AlsoDuck)
return Pond{Duck: in.AlsoDuck}
}),
&pond))

require.NotNil(t, pond)
}
2 changes: 2 additions & 0 deletions depinject/provider_desc.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ type ProviderDescriptor struct {
type ProviderInput struct {
Type reflect.Type
Optional bool
Key string
}

type ProviderOutput struct {
Type reflect.Type
Key string
}

func ExtractProviderDescriptor(provider interface{}) (ProviderDescriptor, error) {
Expand Down
24 changes: 24 additions & 0 deletions depinject/provider_desc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ type StructOut struct {
Y []byte
}

type KeyedIn struct {
depinject.In
X string `key:"theKey"`
}

type KeyedOut struct {
depinject.Out
X string `key:"theKey"`
}

func TestExtractProviderDescriptor(t *testing.T) {
var (
intType = reflect.TypeOf(0)
Expand Down Expand Up @@ -87,6 +97,20 @@ func TestExtractProviderDescriptor(t *testing.T) {
nil,
true,
},
{
name: "keyed input",
ctr: func(_ KeyedIn) int { return 0 },
wantIn: []depinject.ProviderInput{{Type: stringType, Key: "theKey"}},
wantOut: []depinject.ProviderOutput{{Type: intType}},
wantErr: false,
},
{
name: "keyed output",
ctr: func(s string) KeyedOut { return KeyedOut{X: "foo"} },
wantIn: []depinject.ProviderInput{{Type: stringType}},
wantOut: []depinject.ProviderOutput{{Type: stringType, Key: "theKey"}},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
14 changes: 14 additions & 0 deletions depinject/struct_args.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,16 @@ func structArgsInTypes(typ reflect.Type) ([]ProviderInput, error) {
}
}

var key string
keyTag, keyFound := f.Tag.Lookup("key")
if keyFound {
key = keyTag
}

res = append(res, ProviderInput{
Type: f.Type,
Optional: optional,
Key: key,
})
}
return res, nil
Expand Down Expand Up @@ -151,8 +158,15 @@ func structArgsOutTypes(typ reflect.Type) []ProviderOutput {
continue
}

var key string
keyTag, keyFound := f.Tag.Lookup("key")
if keyFound {
key = keyTag
}

res = append(res, ProviderOutput{
Type: f.Type,
Key: key,
})
}
return res
Expand Down

0 comments on commit c4934b7

Please sign in to comment.