Skip to content

Commit

Permalink
refactor: ResolveTo now can take mult-level pointers
Browse files Browse the repository at this point in the history
If we have T in the resolver. Now ResolveTo() can take *T, **T, ***T
values. It used to only consider *T as valid.
  • Loading branch information
ggicci committed Apr 20, 2024
1 parent 7ebc655 commit 65ec4bd
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 36 deletions.
99 changes: 63 additions & 36 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,25 +290,17 @@ func (r *Resolver) Resolve(opts ...Option) (reflect.Value, error) {
// ResolveTo works like Resolve, but it resolves the struct value to the given
// pointer value instead of creating a new value. The pointer value must be
// non-nil and a pointer to the type the resolver holds.
func (r *Resolver) ResolveTo(value any, opts ...Option) error {
if value == nil {
return fmt.Errorf("cannot resolve to nil value")
}
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("cannot resolve to non-pointer value")
}
if rv.IsNil() {
return fmt.Errorf("cannot resolve to nil pointer value")
}
if rv.Type().Elem() != r.Type {
return fmt.Errorf("%w: cannot resolve to value of type %q, expecting type %q",
ErrTypeMismatch, rv.Type().Elem(), r.Type)
func (r *Resolver) ResolveTo(value any, opts ...Option) (err error) {
rv, err := reflectResolveTargetValue(value, r.Type)
if err != nil {
return err
}
ctx := buildContextWithOptionsApplied(context.Background(), opts...)
return r.resolve(ctx, rv)
return r.resolve(ctx, rv.Addr())
}

// resolve runs the directives on the current field and resolves the children fields.
// NOTE: rootValue must be a pointer to a type, i.e. *User, not User.
func (root *Resolver) resolve(ctx context.Context, rootValue reflect.Value) error {
// Run the directives on current field.
if err := root.runDirectives(ctx, rootValue); err != nil {
Expand Down Expand Up @@ -425,9 +417,9 @@ func buildResolver(typ reflect.Type, field reflect.StructField, parent *Resolver
}

if !root.IsRoot() {
directives, err := parseDirectives(field.Tag.Get(Tag()))
directives, err := parseTag(field.Tag.Get(Tag()))
if err != nil {
return nil, fmt.Errorf("parse directives: %w", err)
return nil, fmt.Errorf("parse directives (tag): %w", err)
}
root.Directives = directives
root.Path = append(root.Parent.Path, field.Name)
Expand Down Expand Up @@ -466,6 +458,29 @@ func buildResolver(typ reflect.Type, field reflect.StructField, parent *Resolver
return root, nil
}

// parseTag creates a slice of Directive instances by parsing a struct tag.
func parseTag(tag string) ([]*Directive, error) {
tag = strings.TrimSpace(tag)
var directives []*Directive
existed := make(map[string]bool)
for _, directive := range strings.Split(tag, ";") {
directive = strings.TrimSpace(directive)
if directive == "" {
continue
}
d, err := ParseDirective(directive)
if err != nil {
return nil, err
}
if existed[d.Name] {
return nil, duplicateDirective(d.Name)
}
existed[d.Name] = true
directives = append(directives, d)
}
return directives, nil
}

func reflectStructType(structValue interface{}) (reflect.Type, error) {
typ, ok := structValue.(reflect.Type)
if !ok {
Expand All @@ -487,24 +502,36 @@ func reflectStructType(structValue interface{}) (reflect.Type, error) {
return typ, nil
}

func parseDirectives(tag string) ([]*Directive, error) {
tag = strings.TrimSpace(tag)
var directives []*Directive
existed := make(map[string]bool)
for _, directive := range strings.Split(tag, ";") {
directive = strings.TrimSpace(directive)
if directive == "" {
continue
}
d, err := ParseDirective(directive)
if err != nil {
return nil, err
}
if existed[d.Name] {
return nil, duplicateDirective(d.Name)
}
existed[d.Name] = true
directives = append(directives, d)
func reflectResolveTargetValue(value any, expectedType reflect.Type) (rv reflect.Value, err error) {
if value == nil {
return rv, fmt.Errorf("cannot resolve to nil value")
}
return directives, nil

rv = reflect.ValueOf(value)
if rv.Kind() != reflect.Pointer {
return rv, fmt.Errorf("cannot resolve to non-pointer value")
}

if rv, err = dereference(rv); err != nil {
return rv, fmt.Errorf("cannot resolve to nil pointer value")
}

if rv.Type() != expectedType {
return rv, fmt.Errorf("%w: cannot resolve to value of type %q, expecting type %q",
ErrTypeMismatch, rv.Type(), expectedType)
}

return rv, nil
}

// dereference returns the value that v points to, or an error if v is nil.
// It can be multiple levels deep. e.g. T -> T, *T -> T; **T -> T, etc.
func dereference(v reflect.Value) (reflect.Value, error) {
if v.Kind() != reflect.Pointer {
return v, nil
}
if v.IsNil() {
return v, errors.New("nil pointer")
}
return dereference(v.Elem())
}
32 changes: 32 additions & 0 deletions resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,38 @@ func TestResolveTo_PopulateFieldsOnDemand(t *testing.T) {
os.Clearenv()
}

func TestRevoleTo_MultiLevelPointer(t *testing.T) {
type User struct {
Name string `owl:"env=OWL_TEST_NAME"`
}

ns := owl.NewNamespace()
ns.RegisterDirectiveExecutor("env", owl.DirectiveExecutorFunc(exeEnvReader))
resolver, err := owl.New(User{}, owl.WithNamespace(ns))
assert.NoError(t, err)

var user = new(User)

// *T
os.Setenv("OWL_TEST_NAME", "owl")
assert.NoError(t, resolver.ResolveTo(user))
assert.Equal(t, "owl", user.Name)

// **T
os.Setenv("OWL_TEST_NAME", "golang")
assert.NoError(t, resolver.ResolveTo(&user))
assert.Equal(t, "golang", user.Name)

// ***T
var userPtr = &user
os.Setenv("OWL_TEST_NAME", "world")
assert.NoError(t, resolver.ResolveTo(&userPtr))
assert.Equal(t, "world", user.Name)
assert.Same(t, *userPtr, user)

os.Clearenv()
}

func TestResolveTo_ErrNilValue(t *testing.T) {
resolver, err := owl.New(User{})
assert.NoError(t, err)
Expand Down

0 comments on commit 65ec4bd

Please sign in to comment.