diff --git a/internal/json/decode.go b/internal/json/decode.go index 434edf8ea..0b136b7e9 100644 --- a/internal/json/decode.go +++ b/internal/json/decode.go @@ -254,6 +254,7 @@ type decodeState struct { nextscan scanner // for calls to nextValue savedError error useNumber bool + ext *Extension } // errPhase is used for errors that should not happen unless @@ -369,6 +370,9 @@ func (d *decodeState) value(v reflect.Value) { case scanBeginLiteral: d.literal(v) + + case scanBeginFunc: + d.function(v) } } @@ -718,6 +722,213 @@ func (d *decodeState) object(v reflect.Value) { } } +// function consumes a function from d.data[d.off-1:], decoding into the value v. +// the first byte of the function name has been read already. +func (d *decodeState) function(v reflect.Value) { + // Check for unmarshaler. + u, ut, pv := d.indirect(v, false) + if u != nil { + d.off-- + err := u.UnmarshalJSON(d.next()) + if err != nil { + d.error(err) + } + return + } + if ut != nil { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + v = pv + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + v.Set(reflect.ValueOf(d.functionInterface())) + return + } + + nameStart := d.off - 1 + + if op := d.scanWhile(scanContinue); op != scanFuncArg { + d.error(errPhase) + } + + funcName := string(d.data[nameStart : d.off-1]) + funcData := d.ext.funcs[funcName] + if funcData.key == "" { + d.error(fmt.Errorf("json: unknown function %s", funcName)) + } + + // Check type of target: + // struct or + // map[string]T or map[encoding.TextUnmarshaler]T + switch v.Kind() { + case reflect.Map: + // Map key must either have string kind or be an encoding.TextUnmarshaler. + t := v.Type() + if t.Key().Kind() != reflect.String && + !reflect.PtrTo(t.Key()).Implements(textUnmarshalerType) { + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + + default: + d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)}) + d.off-- + d.next() // skip over { } in input + return + } + + // TODO Fix case of func field as map. + //topv := v + + // Figure out field corresponding to function. + key := []byte(funcData.key) + if v.Kind() == reflect.Map { + elemType := v.Type().Elem() + v = reflect.New(elemType).Elem() + } else { + var f *field + fields := cachedTypeFields(v.Type()) + for i := range fields { + ff := &fields[i] + if bytes.Equal(ff.nameBytes, key) { + f = ff + break + } + if f == nil && ff.equalFold(ff.nameBytes, key) { + f = ff + } + } + if f != nil { + for _, i := range f.index { + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + v = v.Field(i) + } + if v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + } + + var mapElem reflect.Value + + // Parse function arguments. + for i := 0; ; i++ { + // closing ) - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndFunc { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + if i >= len(funcData.args) { + d.error(fmt.Errorf("json: too many arguments for function %s", funcName)) + } + key := []byte(funcData.args[i]) + + // Figure out field corresponding to key. + var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := v.Type().Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.Set(reflect.Zero(elemType)) + } + subv = mapElem + } else { + var f *field + fields := cachedTypeFields(v.Type()) + for i := range fields { + ff := &fields[i] + if bytes.Equal(ff.nameBytes, key) { + f = ff + break + } + if f == nil && ff.equalFold(ff.nameBytes, key) { + f = ff + } + } + if f != nil { + subv = v + destring = f.quoted + for _, i := range f.index { + if subv.Kind() == reflect.Ptr { + if subv.IsNil() { + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + subv = subv.Field(i) + } + } + } + + // Read value. + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + d.literalStore(nullLiteral, subv, false) + case string: + d.literalStore([]byte(qv), subv, true) + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + d.value(subv) + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kt := v.Type().Key() + var kv reflect.Value + switch { + case kt.Kind() == reflect.String: + kv = reflect.ValueOf(key).Convert(v.Type().Key()) + case reflect.PtrTo(kt).Implements(textUnmarshalerType): + kv = reflect.New(v.Type().Key()) + d.literalStore(key, kv, true) + kv = kv.Elem() + default: + panic("json: Unexpected key type") // should never occur + } + v.SetMapIndex(kv, subv) + } + + // Next token must be , or ). + op = d.scanWhile(scanSkipSpace) + if op == scanEndFunc { + break + } + if op != scanFuncArg { + d.error(errPhase) + } + } +} + // literal consumes a literal from d.data[d.off-1:], decoding into the value v. // The first byte of the literal has been read already // (that's how the caller knows it's a literal). @@ -933,6 +1144,8 @@ func (d *decodeState) valueInterface() interface{} { return d.objectInterface() case scanBeginLiteral: return d.literalInterface() + case scanBeginFunc: + return d.functionInterface() } } @@ -1047,6 +1260,49 @@ func (d *decodeState) literalInterface() interface{} { } } +// functionInterface is like function but returns map[string]interface{}. +func (d *decodeState) functionInterface() map[string]interface{} { + nameStart := d.off - 1 + + if op := d.scanWhile(scanContinue); op != scanFuncArg { + d.error(errPhase) + } + + funcName := string(d.data[nameStart : d.off-1]) + funcData := d.ext.funcs[funcName] + if funcData.key == "" { + d.error(fmt.Errorf("json: unknown function %s", funcName)) + } + + m := make(map[string]interface{}) + for i := 0; ; i++ { + // Look ahead for ) - can only happen on first iteration. + op := d.scanWhile(scanSkipSpace) + if op == scanEndFunc { + break + } + + // Back up so d.value can have the byte we just read. + d.off-- + d.scan.undo(op) + + if i >= len(funcData.args) { + d.error(fmt.Errorf("json: too many arguments for function %s", funcName)) + } + m[funcData.args[i]] = d.valueInterface() + + // Next token must be , or ). + op = d.scanWhile(scanSkipSpace) + if op == scanEndFunc { + break + } + if op != scanFuncArg { + d.error(errPhase) + } + } + return map[string]interface{}{funcData.key: m} +} + // getu4 decodes \uXXXX from the beginning of s, returning the hex value, // or it returns -1. func getu4(s []byte) rune { diff --git a/internal/json/extension.go b/internal/json/extension.go new file mode 100644 index 000000000..c2b7e72af --- /dev/null +++ b/internal/json/extension.go @@ -0,0 +1,36 @@ +package json + +// Extension holds a set of additional rules to be used when unmarshaling +// strict JSON or JSON-like content. +type Extension struct { + funcs map[string]funcExt + keyed map[string]func() interface{} +} + +type funcExt struct { + key string + args []string +} + +// Extend changes the decoder behavior to consider the provided extension. +func (dec *Decoder) Extend(ext *Extension) { dec.d.ext = ext } + +// Func defines a function call that may be observed inside JSON content. +// A function with the provided name will be unmarshaled as the document +// {key: {args[0]: ..., args[N]: ...}}. +func (e *Extension) Func(name string, key string, args ...string) { + if e.funcs == nil { + e.funcs = make(map[string]funcExt) + } + e.funcs[name] = funcExt{key, args} +} + +// KeyedDoc defines a key that when observed as the first element inside a +// JSON document or sub-document triggers the parsing of that document as +// the value returned by the provided function. +func (e *Extension) KeyedDoc(key string, new func() interface{}) { + if e.keyed == nil { + e.keyed = make(map[string]func() interface{}) + } + e.keyed[key] = new +} diff --git a/internal/json/extension_test.go b/internal/json/extension_test.go new file mode 100644 index 000000000..17f3b907b --- /dev/null +++ b/internal/json/extension_test.go @@ -0,0 +1,91 @@ +package json + +import ( + "bytes" + "fmt" + "reflect" + "testing" +) + +type extensionTest struct { + in string + ptr interface{} + out interface{} + err error +} + +var extensionTests = []extensionTest{ + {in: `Func1()`, ptr: new(interface{}), out: map[string]interface{}{ + "$func1": map[string]interface{}{}, + }}, + {in: `Func2(1)`, ptr: new(interface{}), out: map[string]interface{}{ + "$func2": map[string]interface{}{"arg1": float64(1)}, + }}, + {in: `Func2(1, 2)`, ptr: new(interface{}), out: map[string]interface{}{ + "$func2": map[string]interface{}{"arg1": float64(1), "arg2": float64(2)}, + }}, + {in: `Func2(Func1())`, ptr: new(interface{}), out: map[string]interface{}{ + "$func2": map[string]interface{}{"arg1": map[string]interface{}{"$func1": map[string]interface{}{}}}, + }}, + {in: `Func2(1, 2, 3)`, ptr: new(interface{}), err: fmt.Errorf("json: too many arguments for function Func2")}, + {in: `Func3()`, ptr: new(interface{}), err: fmt.Errorf("json: unknown function Func3")}, + + {in: `Func1()`, ptr: new(funcs), out: funcs{Func1: &funcN{}}}, + {in: `Func2(1)`, ptr: new(funcs), out: funcs{Func2: &funcN{Arg1: 1}}}, + {in: `Func2(1, 2)`, ptr: new(funcs), out: funcs{Func2: &funcN{Arg1: 1, Arg2: 2}}}, + + {in: `Func2(1, 2, 3)`, ptr: new(funcs), err: fmt.Errorf("json: too many arguments for function Func2")}, + {in: `Func3()`, ptr: new(funcs), err: fmt.Errorf("json: unknown function Func3")}, +} + +type funcN struct { + Arg1 int `json:"arg1"` + Arg2 int `json:"arg2"` +} + +type funcs struct { + Func2 *funcN `json:"$func2"` + Func1 *funcN `json:"$func1"` +} + +var ext Extension + +func init() { + ext.Func("Func1", "$func1") + ext.Func("Func2", "$func2", "arg1", "arg2") +} + +func TestExtensions(t *testing.T) { + for i, tt := range extensionTests { + var scan scanner + in := []byte(tt.in) + if err := checkValid(in, &scan); err != nil { + if !reflect.DeepEqual(err, tt.err) { + t.Errorf("#%d: checkValid: %#v", i, err) + continue + } + } + if tt.ptr == nil { + continue + } + + // v = new(right-type) + v := reflect.New(reflect.TypeOf(tt.ptr).Elem()) + dec := NewDecoder(bytes.NewReader(in)) + dec.Extend(&ext) + if err := dec.Decode(v.Interface()); !reflect.DeepEqual(err, tt.err) { + t.Errorf("#%d: %v, want %v", i, err, tt.err) + continue + } else if err != nil { + continue + } + if !reflect.DeepEqual(v.Elem().Interface(), tt.out) { + t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v", i, v.Elem().Interface(), tt.out) + data, _ := Marshal(v.Elem().Interface()) + t.Logf("%s", string(data)) + data, _ = Marshal(tt.out) + t.Logf("%s", string(data)) + continue + } + } +} diff --git a/internal/json/scanner.go b/internal/json/scanner.go index a6d8706c7..8dccde40b 100644 --- a/internal/json/scanner.go +++ b/internal/json/scanner.go @@ -125,6 +125,10 @@ const ( scanEndArray // end array (implies scanArrayValue if possible) scanSkipSpace // space byte; can skip; known to be last "continue" result + scanBeginFunc // begin function call + scanFuncArg // begin function argument + scanEndFunc // end function call + // Stop. scanEnd // top-level value ended *before* this byte; known to be first "stop" result scanError // hit an error, scanner.err. @@ -138,6 +142,8 @@ const ( parseObjectKey = iota // parsing object key (before colon) parseObjectValue // parsing object value (after colon) parseArrayValue // parsing array value + parseFuncName // parsing function name + parseFuncArg // parsing function argument value ) // reset prepares the scanner for use. @@ -240,6 +246,10 @@ func stateBeginValue(s *scanner, c byte) int { s.step = state1 return scanBeginLiteral } + if c == '$' || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' { + s.step = stateFuncName + return scanBeginFunc + } return s.error(c, "looking for beginning of value") } @@ -312,6 +322,16 @@ func stateEndValue(s *scanner, c byte) int { return scanEndArray } return s.error(c, "after array element") + case parseFuncArg: + if c == ',' { + s.step = stateBeginValue + return scanFuncArg + } + if c == ')' { + s.popParseState() + return scanEndFunc + } + return s.error(c, "after array element") } return s.error(c, "") } @@ -485,6 +505,30 @@ func stateE0(s *scanner, c byte) int { return stateEndValue(s, c) } +// stateFuncName is the state while reading an unquoted function name. +func stateFuncName(s *scanner, c byte) int { + if c == '$' || 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { + return scanContinue + } + if c == '(' { + s.step = stateFuncArgOrEmpty + s.pushParseState(parseFuncArg) + return scanFuncArg + } + return stateEndValue(s, c) +} + +// stateFuncArgOrEmpty is the state after reading `[`. +func stateFuncArgOrEmpty(s *scanner, c byte) int { + if c <= ' ' && isSpace(c) { + return scanSkipSpace + } + if c == ')' { + return stateEndValue(s, c) + } + return stateBeginValue(s, c) +} + // stateT is the state after reading `t`. func stateT(s *scanner, c byte) int { if c == 'r' {