Skip to content

Commit

Permalink
Merge pull request #185 from alexflint/default-value-issue
Browse files Browse the repository at this point in the history
Do not turn values intro strings and then back into values when processing default values
  • Loading branch information
alexflint authored Oct 29, 2022
2 parents dbc2ba5 + 3489ea5 commit 727f853
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 68 deletions.
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,10 @@ require (
github.com/stretchr/testify v1.7.0
)

go 1.13
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

go 1.18
134 changes: 87 additions & 47 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,19 @@ func (p path) Child(f reflect.StructField) path {

// spec represents a command line option
type spec struct {
dest path
field reflect.StructField // the struct field from which this option was created
long string // the --long form for this option, or empty if none
short string // the -s short form for this option, or empty if none
cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
required bool // if true, this option must be present on the command line
positional bool // if true, this option will be looked for in the positional flags
separate bool // if true, each slice and map entry will have its own --flag
help string // the help text for this option
env string // the name of the environment variable for this option, or empty for none
defaultVal string // default value for this option
placeholder string // name of the data in help
dest path
field reflect.StructField // the struct field from which this option was created
long string // the --long form for this option, or empty if none
short string // the -s short form for this option, or empty if none
cardinality cardinality // determines how many tokens will be present (possible values: zero, one, multiple)
required bool // if true, this option must be present on the command line
positional bool // if true, this option will be looked for in the positional flags
separate bool // if true, each slice and map entry will have its own --flag
help string // the help text for this option
env string // the name of the environment variable for this option, or empty for none
defaultValue reflect.Value // default value for this option
defaultString string // default value for this option, in string form to be displayed in help text
placeholder string // name of the data in help
}

// command represents a named subcommand, or the top-level command
Expand Down Expand Up @@ -210,18 +211,31 @@ func NewParser(config Config, dests ...interface{}) (*Parser, error) {
return nil, err
}

// add nonzero field values as defaults
// for backwards compatibility, add nonzero field values as defaults
// this applies only to the top-level command, not to subcommands (this inconsistency
// is the reason that this method for setting default values was deprecated)
for _, spec := range cmd.specs {
if v := p.val(spec.dest); v.IsValid() && !isZero(v) {
if defaultVal, ok := v.Interface().(encoding.TextMarshaler); ok {
str, err := defaultVal.MarshalText()
if err != nil {
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
spec.defaultVal = string(str)
} else {
spec.defaultVal = fmt.Sprintf("%v", v)
// get the value
v := p.val(spec.dest)

// if the value is the "zero value" (e.g. nil pointer, empty struct) then ignore
if isZero(v) {
continue
}

// store as a default
spec.defaultValue = v

// we need a string to display in help text
// if MarshalText is implemented then use that
if m, ok := v.Interface().(encoding.TextMarshaler); ok {
s, err := m.MarshalText()
if err != nil {
return nil, fmt.Errorf("%v: error marshaling default value to string: %v", spec.dest, err)
}
spec.defaultString = string(s)
} else {
spec.defaultString = fmt.Sprintf("%v", v)
}
}

Expand Down Expand Up @@ -293,11 +307,6 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.help = help
}

defaultVal, hasDefault := field.Tag.Lookup("default")
if hasDefault {
spec.defaultVal = defaultVal
}

// Look at the tag
var isSubcommand bool // tracks whether this field is a subcommand
for _, key := range strings.Split(tag, ",") {
Expand All @@ -324,11 +333,6 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
}
spec.short = key[1:]
case key == "required":
if hasDefault {
errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
t.Name(), field.Name))
return false
}
spec.required = true
case key == "positional":
spec.positional = true
Expand Down Expand Up @@ -377,27 +381,60 @@ func cmdFromStruct(name string, dest path, t reflect.Type) (*command, error) {
spec.placeholder = strings.ToUpper(spec.field.Name)
}

// Check whether this field is supported. It's good to do this here rather than
// if this is a subcommand then we've done everything we need to do
if isSubcommand {
return false
}

// check whether this field is supported. It's good to do this here rather than
// wait until ParseValue because it means that a program with invalid argument
// fields will always fail regardless of whether the arguments it received
// exercised those fields.
if !isSubcommand {
cmd.specs = append(cmd.specs, &spec)
var err error
spec.cardinality, err = cardinalityOf(field.Type)
if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
return false
}

var err error
spec.cardinality, err = cardinalityOf(field.Type)
if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: %s fields are not supported",
t.Name(), field.Name, field.Type.String()))
defaultString, hasDefault := field.Tag.Lookup("default")
if hasDefault {
// we do not support default values for maps and slices
if spec.cardinality == multiple {
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",
t.Name(), field.Name))
return false
}
if spec.cardinality == multiple && hasDefault {
errs = append(errs, fmt.Sprintf("%s.%s: default values are not supported for slice or map fields",

// a required field cannot also have a default value
if spec.required {
errs = append(errs, fmt.Sprintf("%s.%s: 'required' cannot be used when a default value is specified",
t.Name(), field.Name))
return false
}

// parse the default value
spec.defaultString = defaultString
if field.Type.Kind() == reflect.Ptr {
// here we have a field of type *T and we create a new T, no need to dereference
// in order for the value to be settable
spec.defaultValue = reflect.New(field.Type.Elem())
} else {
// here we have a field of type T and we create a new T and then dereference it
// so that the resulting value is settable
spec.defaultValue = reflect.New(field.Type).Elem()
}
err := scalar.ParseValue(spec.defaultValue, defaultString)
if err != nil {
errs = append(errs, fmt.Sprintf("%s.%s: error processing default value: %v", t.Name(), field.Name, err))
return false
}
}

// add the spec to the list of specs
cmd.specs = append(cmd.specs, &spec)

// if this was an embedded field then we already returned true up above
return false
})
Expand Down Expand Up @@ -680,11 +717,14 @@ func (p *Parser) process(args []string) error {
}
return errors.New(msg)
}
if !p.config.IgnoreDefault && spec.defaultVal != "" {
err := scalar.ParseValue(p.val(spec.dest), spec.defaultVal)
if err != nil {
return fmt.Errorf("error processing default value for %s: %v", name, err)
}

if spec.defaultValue.IsValid() && !p.config.IgnoreDefault {
// One issue here is that if the user now modifies the value then
// the default value stored in the spec will be corrupted. There
// is no general way to "deep-copy" values in Go, and we still
// support the old-style method for specifying defaults as
// Go values assigned directly to the struct field, so we are stuck.
p.val(spec.dest).Set(spec.defaultValue)
}
}

Expand Down
102 changes: 92 additions & 10 deletions parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package arg

import (
"bytes"
"encoding/json"
"fmt"
"net"
"net/mail"
Expand Down Expand Up @@ -1396,13 +1397,21 @@ func TestDefaultOptionValues(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, 123, args.A)
assert.Equal(t, 123, *args.B)
if assert.NotNil(t, args.B) {
assert.Equal(t, 123, *args.B)
}
assert.Equal(t, "xyz", args.C)
assert.Equal(t, "abc", *args.D)
if assert.NotNil(t, args.D) {
assert.Equal(t, "abc", *args.D)
}
assert.Equal(t, 4.56, args.E)
assert.Equal(t, 1.23, *args.F)
assert.True(t, args.G)
if assert.NotNil(t, args.F) {
assert.Equal(t, 1.23, *args.F)
}
assert.True(t, args.G)
if assert.NotNil(t, args.H) {
assert.True(t, *args.H)
}
}

func TestDefaultUnparseable(t *testing.T) {
Expand All @@ -1411,7 +1420,7 @@ func TestDefaultUnparseable(t *testing.T) {
}

err := parse("", &args)
assert.EqualError(t, err, `error processing default value for --a: strconv.ParseInt: parsing "x": invalid syntax`)
assert.EqualError(t, err, `.A: error processing default value: strconv.ParseInt: parsing "x": invalid syntax`)
}

func TestDefaultPositionalValues(t *testing.T) {
Expand All @@ -1430,13 +1439,21 @@ func TestDefaultPositionalValues(t *testing.T) {
require.NoError(t, err)

assert.Equal(t, 456, args.A)
assert.Equal(t, 789, *args.B)
if assert.NotNil(t, args.B) {
assert.Equal(t, 789, *args.B)
}
assert.Equal(t, "abc", args.C)
assert.Equal(t, "abc", *args.D)
if assert.NotNil(t, args.D) {
assert.Equal(t, "abc", *args.D)
}
assert.Equal(t, 1.23, args.E)
assert.Equal(t, 1.23, *args.F)
assert.True(t, args.G)
if assert.NotNil(t, args.F) {
assert.Equal(t, 1.23, *args.F)
}
assert.True(t, args.G)
if assert.NotNil(t, args.H) {
assert.True(t, *args.H)
}
}

func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {
Expand All @@ -1450,7 +1467,7 @@ func TestDefaultValuesNotAllowedWithRequired(t *testing.T) {

func TestDefaultValuesNotAllowedWithSlice(t *testing.T) {
var args struct {
A []int `default:"123"` // required not allowed with default!
A []int `default:"invalid"` // default values not allowed with slices
}

err := parse("", &args)
Expand Down Expand Up @@ -1532,3 +1549,68 @@ func TestMustParsePrintsVersion(t *testing.T) {
assert.Equal(t, 0, *exitCode)
assert.Equal(t, "example 3.2.1\n", b.String())
}

type mapWithUnmarshalText struct {
val map[string]string
}

func (v *mapWithUnmarshalText) UnmarshalText(data []byte) error {
return json.Unmarshal(data, &v.val)
}

func TestTextUnmarshalerEmpty(t *testing.T) {
// based on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config mapWithUnmarshalText `arg:"--config"`
}

err := parse("", &args)
require.NoError(t, err)
assert.Empty(t, args.Config)
}

func TestTextUnmarshalerEmptyPointer(t *testing.T) {
// a slight variant on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config *mapWithUnmarshalText `arg:"--config"`
}

err := parse("", &args)
require.NoError(t, err)
assert.Nil(t, args.Config)
}

// similar to the above but also implements MarshalText
type mapWithMarshalText struct {
val map[string]string
}

func (v *mapWithMarshalText) MarshalText(data []byte) error {
return json.Unmarshal(data, &v.val)
}

func (v *mapWithMarshalText) UnmarshalText(data []byte) error {
return json.Unmarshal(data, &v.val)
}

func TestTextMarshalerUnmarshalerEmpty(t *testing.T) {
// based on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config mapWithMarshalText `arg:"--config"`
}

err := parse("", &args)
require.NoError(t, err)
assert.Empty(t, args.Config)
}

func TestTextMarshalerUnmarshalerEmptyPointer(t *testing.T) {
// a slight variant on https://github.com/alexflint/go-arg/issues/184
var args struct {
Config *mapWithMarshalText `arg:"--config"`
}

err := parse("", &args)
require.NoError(t, err)
assert.Nil(t, args.Config)
}
17 changes: 11 additions & 6 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
var textUnmarshalerType = reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()

// cardinality tracks how many tokens are expected for a given spec
// - zero is a boolean, which does to expect any value
// - one is an ordinary option that will be parsed from a single token
// - multiple is a slice or map that can accept zero or more tokens
// - zero is a boolean, which does to expect any value
// - one is an ordinary option that will be parsed from a single token
// - multiple is a slice or map that can accept zero or more tokens
type cardinality int

const (
Expand Down Expand Up @@ -74,10 +74,10 @@ func cardinalityOf(t reflect.Type) (cardinality, error) {
}
}

// isBoolean returns true if the type can be parsed from a single string
// isBoolean returns true if the type is a boolean or a pointer to a boolean
func isBoolean(t reflect.Type) bool {
switch {
case t.Implements(textUnmarshalerType):
case isTextUnmarshaler(t):
return false
case t.Kind() == reflect.Bool:
return true
Expand All @@ -88,6 +88,11 @@ func isBoolean(t reflect.Type) bool {
}
}

// isTextUnmarshaler returns true if the type or its pointer implements encoding.TextUnmarshaler
func isTextUnmarshaler(t reflect.Type) bool {
return t.Implements(textUnmarshalerType) || reflect.PtrTo(t).Implements(textUnmarshalerType)
}

// isExported returns true if the struct field name is exported
func isExported(field string) bool {
r, _ := utf8.DecodeRuneInString(field) // returns RuneError for empty string or invalid UTF8
Expand All @@ -97,7 +102,7 @@ func isExported(field string) bool {
// isZero returns true if v contains the zero value for its type
func isZero(v reflect.Value) bool {
t := v.Type()
if t.Kind() == reflect.Slice || t.Kind() == reflect.Map {
if t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice || t.Kind() == reflect.Map || t.Kind() == reflect.Chan || t.Kind() == reflect.Interface {
return v.IsNil()
}
if !t.Comparable() {
Expand Down
Loading

0 comments on commit 727f853

Please sign in to comment.