From 1d9159ea396bc944c285102afcc4e11d3bf4dfb0 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Mon, 23 Dec 2024 00:56:20 +0800 Subject: [PATCH] feat: support form array in three notations (#4498) Signed-off-by: kevin --- core/mapping/unmarshaler.go | 145 ++++++++++++++++++--------- core/mapping/unmarshaler_test.go | 167 ++++++++++++++++++++++++++++--- rest/httpx/requests_test.go | 112 ++++++++++++++++++++- rest/httpx/util.go | 24 ++++- rest/httpx/util_test.go | 22 ++++ rest/router/patrouter_test.go | 63 ++++++++---- 6 files changed, 452 insertions(+), 81 deletions(-) diff --git a/core/mapping/unmarshaler.go b/core/mapping/unmarshaler.go index b4fb356ee319..f68748c68172 100644 --- a/core/mapping/unmarshaler.go +++ b/core/mapping/unmarshaler.go @@ -18,6 +18,7 @@ import ( ) const ( + comma = "," defaultKeyName = "key" delimiter = '.' ignoreKey = "-" @@ -36,6 +37,7 @@ var ( defaultCacheLock sync.Mutex emptyMap = map[string]any{} emptyValue = reflect.ValueOf(lang.Placeholder) + stringSliceType = reflect.TypeOf([]string{}) ) type ( @@ -80,40 +82,11 @@ func (u *Unmarshaler) Unmarshal(i, v any) error { return u.unmarshal(i, v, "") } -func (u *Unmarshaler) unmarshal(i, v any, fullName string) error { - valueType := reflect.TypeOf(v) - if valueType.Kind() != reflect.Ptr { - return errValueNotSettable - } - - elemType := Deref(valueType) - switch iv := i.(type) { - case map[string]any: - if elemType.Kind() != reflect.Struct { - return errTypeMismatch - } - - return u.unmarshalValuer(mapValuer(iv), v, fullName) - case []any: - if elemType.Kind() != reflect.Slice { - return errTypeMismatch - } - - return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName) - default: - return errUnsupportedType - } -} - // UnmarshalValuer unmarshals m into v. func (u *Unmarshaler) UnmarshalValuer(m Valuer, v any) error { return u.unmarshalValuer(simpleValuer{current: m}, v, "") } -func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error { - return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName) -} - func (u *Unmarshaler) fillMap(fieldType reflect.Type, value reflect.Value, mapValue any, fullName string) error { if !value.CanSet() { @@ -173,13 +146,18 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, baseType := fieldType.Elem() dereffedBaseType := Deref(baseType) dereffedBaseKind := dereffedBaseType.Kind() - conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap()) if refValue.Len() == 0 { - value.Set(conv) + value.Set(reflect.MakeSlice(reflect.SliceOf(baseType), 0, 0)) return nil } + if u.opts.fromArray { + refValue = makeStringSlice(refValue) + } + var valid bool + conv := reflect.MakeSlice(reflect.SliceOf(baseType), refValue.Len(), refValue.Cap()) + for i := 0; i < refValue.Len(); i++ { ithValue := refValue.Index(i).Interface() if ithValue == nil { @@ -191,17 +169,9 @@ func (u *Unmarshaler) fillSlice(fieldType reflect.Type, value reflect.Value, switch dereffedBaseKind { case reflect.Struct: - target := reflect.New(dereffedBaseType) - val, ok := ithValue.(map[string]any) - if !ok { - return errTypeMismatch - } - - if err := u.unmarshal(val, target.Interface(), sliceFullName); err != nil { + if err := u.fillStructElement(baseType, conv.Index(i), ithValue, sliceFullName); err != nil { return err } - - SetValue(fieldType.Elem(), conv.Index(i), target.Elem()) case reflect.Slice: if err := u.fillSlice(dereffedBaseType, conv.Index(i), ithValue, sliceFullName); err != nil { return err @@ -236,7 +206,7 @@ func (u *Unmarshaler) fillSliceFromString(fieldType reflect.Type, value reflect. return errUnsupportedType } - baseFieldType := Deref(fieldType.Elem()) + baseFieldType := fieldType.Elem() baseFieldKind := baseFieldType.Kind() conv := reflect.MakeSlice(reflect.SliceOf(baseFieldType), len(slice), cap(slice)) @@ -257,29 +227,39 @@ func (u *Unmarshaler) fillSliceValue(slice reflect.Value, index int, } ithVal := slice.Index(index) + ithValType := ithVal.Type() + switch v := value.(type) { case fmt.Stringer: return setValueFromString(baseKind, ithVal, v.String()) case string: return setValueFromString(baseKind, ithVal, v) case map[string]any: - return u.fillMap(ithVal.Type(), ithVal, value, fullName) + // deref to handle both pointer and non-pointer types. + switch Deref(ithValType).Kind() { + case reflect.Struct: + return u.fillStructElement(ithValType, ithVal, v, fullName) + case reflect.Map: + return u.fillMap(ithValType, ithVal, value, fullName) + default: + return errTypeMismatch + } default: // don't need to consider the difference between int, int8, int16, int32, int64, // uint, uint8, uint16, uint32, uint64, because they're handled as json.Number. if ithVal.Kind() == reflect.Ptr { - baseType := Deref(ithVal.Type()) + baseType := Deref(ithValType) if !reflect.TypeOf(value).AssignableTo(baseType) { return errTypeMismatch } target := reflect.New(baseType).Elem() target.Set(reflect.ValueOf(value)) - SetValue(ithVal.Type(), ithVal, target) + SetValue(ithValType, ithVal, target) return nil } - if !reflect.TypeOf(value).AssignableTo(ithVal.Type()) { + if !reflect.TypeOf(value).AssignableTo(ithValType) { return errTypeMismatch } @@ -310,6 +290,23 @@ func (u *Unmarshaler) fillSliceWithDefault(derefedType reflect.Type, value refle return u.fillSlice(derefedType, value, slice, fullName) } +func (u *Unmarshaler) fillStructElement(baseType reflect.Type, target reflect.Value, + value any, fullName string) error { + val, ok := value.(map[string]any) + if !ok { + return errTypeMismatch + } + + // use Deref(baseType) to get the base type in case the type is a pointer type. + ptr := reflect.New(Deref(baseType)) + if err := u.unmarshal(val, ptr.Interface(), fullName); err != nil { + return err + } + + SetValue(baseType, target, ptr.Elem()) + return nil +} + func (u *Unmarshaler) fillUnmarshalerStruct(fieldType reflect.Type, value reflect.Value, targetValue string) error { if !value.CanSet() { @@ -952,6 +949,35 @@ func (u *Unmarshaler) processNamedFieldWithoutValue(fieldType reflect.Type, valu return nil } +func (u *Unmarshaler) unmarshal(i, v any, fullName string) error { + valueType := reflect.TypeOf(v) + if valueType.Kind() != reflect.Ptr { + return errValueNotSettable + } + + elemType := Deref(valueType) + switch iv := i.(type) { + case map[string]any: + if elemType.Kind() != reflect.Struct { + return errTypeMismatch + } + + return u.unmarshalValuer(mapValuer(iv), v, fullName) + case []any: + if elemType.Kind() != reflect.Slice { + return errTypeMismatch + } + + return u.fillSlice(elemType, reflect.ValueOf(v).Elem(), iv, fullName) + default: + return errUnsupportedType + } +} + +func (u *Unmarshaler) unmarshalValuer(m Valuer, v any, fullName string) error { + return u.unmarshalWithFullName(simpleValuer{current: m}, v, fullName) +} + func (u *Unmarshaler) unmarshalWithFullName(m valuerWithParent, v any, fullName string) error { rv := reflect.ValueOf(v) if err := ValidatePtr(rv); err != nil { @@ -1146,6 +1172,35 @@ func join(elem ...string) string { return builder.String() } +func makeStringSlice(refValue reflect.Value) reflect.Value { + if refValue.Len() != 1 { + return refValue + } + + element := refValue.Index(0) + if element.Kind() != reflect.String { + return refValue + } + + val, ok := element.Interface().(string) + if !ok { + return refValue + } + + splits := strings.Split(val, comma) + if len(splits) <= 1 { + return refValue + } + + slice := reflect.MakeSlice(stringSliceType, len(splits), len(splits)) + for i, split := range splits { + // allow empty strings + slice.Index(i).Set(reflect.ValueOf(split)) + } + + return slice +} + func newInitError(name string) error { return fmt.Errorf("field %q is not set", name) } diff --git a/core/mapping/unmarshaler_test.go b/core/mapping/unmarshaler_test.go index 3270632dc447..ae2aba0edc86 100644 --- a/core/mapping/unmarshaler_test.go +++ b/core/mapping/unmarshaler_test.go @@ -351,7 +351,7 @@ func TestUnmarshalIntSliceOfPtr(t *testing.T) { assert.Error(t, UnmarshalKey(m, &in)) }) - t.Run("int slice with nil", func(t *testing.T) { + t.Run("int slice with nil element", func(t *testing.T) { type inner struct { Ints []int `key:"ints"` } @@ -365,6 +365,21 @@ func TestUnmarshalIntSliceOfPtr(t *testing.T) { assert.Empty(t, in.Ints) } }) + + t.Run("int slice with nil", func(t *testing.T) { + type inner struct { + Ints []int `key:"ints"` + } + + m := map[string]any{ + "ints": []any(nil), + } + + var in inner + if assert.NoError(t, UnmarshalKey(m, &in)) { + assert.Empty(t, in.Ints) + } + }) } func TestUnmarshalIntWithDefault(t *testing.T) { @@ -1374,20 +1389,82 @@ func TestUnmarshalWithFloatPtr(t *testing.T) { } func TestUnmarshalIntSlice(t *testing.T) { - var v struct { - Ages []int `key:"ages"` - Slice []int `key:"slice"` - } - m := map[string]any{ - "ages": []int{1, 2}, - "slice": []any{}, - } + t.Run("int slice from int", func(t *testing.T) { + var v struct { + Ages []int `key:"ages"` + Slice []int `key:"slice"` + } + m := map[string]any{ + "ages": []int{1, 2}, + "slice": []any{}, + } - ast := assert.New(t) - if ast.NoError(UnmarshalKey(m, &v)) { - ast.ElementsMatch([]int{1, 2}, v.Ages) - ast.Equal([]int{}, v.Slice) - } + ast := assert.New(t) + if ast.NoError(UnmarshalKey(m, &v)) { + ast.ElementsMatch([]int{1, 2}, v.Ages) + ast.Equal([]int{}, v.Slice) + } + }) + + t.Run("int slice from one int", func(t *testing.T) { + var v struct { + Ages []int `key:"ages"` + } + m := map[string]any{ + "ages": []int{2}, + } + + ast := assert.New(t) + unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray()) + if ast.NoError(unmarshaler.Unmarshal(m, &v)) { + ast.ElementsMatch([]int{2}, v.Ages) + } + }) + + t.Run("int slice from one int string", func(t *testing.T) { + var v struct { + Ages []int `key:"ages"` + } + m := map[string]any{ + "ages": []string{"2"}, + } + + ast := assert.New(t) + unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray()) + if ast.NoError(unmarshaler.Unmarshal(m, &v)) { + ast.ElementsMatch([]int{2}, v.Ages) + } + }) + + t.Run("int slice from one json.Number", func(t *testing.T) { + var v struct { + Ages []int `key:"ages"` + } + m := map[string]any{ + "ages": []json.Number{"2"}, + } + + ast := assert.New(t) + unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray()) + if ast.NoError(unmarshaler.Unmarshal(m, &v)) { + ast.ElementsMatch([]int{2}, v.Ages) + } + }) + + t.Run("int slice from one int strings", func(t *testing.T) { + var v struct { + Ages []int `key:"ages"` + } + m := map[string]any{ + "ages": []string{"1,2"}, + } + + ast := assert.New(t) + unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray()) + if ast.NoError(unmarshaler.Unmarshal(m, &v)) { + ast.ElementsMatch([]int{1, 2}, v.Ages) + } + }) } func TestUnmarshalString(t *testing.T) { @@ -1442,6 +1519,36 @@ func TestUnmarshalStringSliceFromString(t *testing.T) { } }) + t.Run("slice from empty string", func(t *testing.T) { + var v struct { + Names []string `key:"names"` + } + m := map[string]any{ + "names": []string{""}, + } + + ast := assert.New(t) + unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray()) + if ast.NoError(unmarshaler.Unmarshal(m, &v)) { + ast.ElementsMatch([]string{""}, v.Names) + } + }) + + t.Run("slice from empty and valid string", func(t *testing.T) { + var v struct { + Names []string `key:"names"` + } + m := map[string]any{ + "names": []string{","}, + } + + ast := assert.New(t) + unmarshaler := NewUnmarshaler(defaultKeyName, WithFromArray()) + if ast.NoError(unmarshaler.Unmarshal(m, &v)) { + ast.ElementsMatch([]string{"", ""}, v.Names) + } + }) + t.Run("slice from string with slice error", func(t *testing.T) { var v struct { Names []int `key:"names"` @@ -5862,6 +5969,38 @@ func TestUnmarshal_Unmarshaler(t *testing.T) { }) } +func TestParseJsonStringValue(t *testing.T) { + t.Run("string", func(t *testing.T) { + type GoodsInfo struct { + Sku int64 `json:"sku,optional"` + } + + type GetReq struct { + GoodsList []*GoodsInfo `json:"goods_list"` + } + + input := map[string]any{"goods_list": "[{\"sku\":11},{\"sku\":22}]"} + var v GetReq + assert.NotPanics(t, func() { + assert.NoError(t, UnmarshalJsonMap(input, &v)) + assert.Equal(t, 2, len(v.GoodsList)) + assert.ElementsMatch(t, []int64{11, 22}, []int64{v.GoodsList[0].Sku, v.GoodsList[1].Sku}) + }) + }) + + t.Run("string with invalid type", func(t *testing.T) { + type GetReq struct { + GoodsList []*int `json:"goods_list"` + } + + input := map[string]any{"goods_list": "[{\"sku\":11},{\"sku\":22}]"} + var v GetReq + assert.NotPanics(t, func() { + assert.Error(t, UnmarshalJsonMap(input, &v)) + }) + }) +} + func BenchmarkDefaultValue(b *testing.B) { for i := 0; i < b.N; i++ { var a struct { diff --git a/rest/httpx/requests_test.go b/rest/httpx/requests_test.go index 437b9a136fb5..fd7fb3a5ac5e 100644 --- a/rest/httpx/requests_test.go +++ b/rest/httpx/requests_test.go @@ -88,6 +88,36 @@ func TestParseFormArray(t *testing.T) { } }) + t.Run("slice with empty", func(t *testing.T) { + var v struct { + Name []string `form:"name,optional"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{}, v.Name) + } + }) + + t.Run("slice with empty", func(t *testing.T) { + var v struct { + Name []string `form:"name,optional"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?name=", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{""}, v.Name) + } + }) + t.Run("slice with empty and non-empty", func(t *testing.T) { var v struct { Name []string `form:"name"` @@ -99,7 +129,67 @@ func TestParseFormArray(t *testing.T) { http.NoBody) assert.NoError(t, err) if assert.NoError(t, Parse(r, &v)) { - assert.ElementsMatch(t, []string{"1"}, v.Name) + assert.ElementsMatch(t, []string{"", "1"}, v.Name) + } + }) + + t.Run("slice with one value on array format", func(t *testing.T) { + var v struct { + Names []string `form:"names"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?names=1,2,3", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{"1", "2", "3"}, v.Names) + } + }) + + t.Run("slice with one value on combined array format", func(t *testing.T) { + var v struct { + Names []string `form:"names"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?names=[1,2,3]&names=4", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{"[1,2,3]", "4"}, v.Names) + } + }) + + t.Run("slice with one value on integer array format", func(t *testing.T) { + var v struct { + Numbers []int `form:"numbers"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?numbers=1,2,3", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []int{1, 2, 3}, v.Numbers) + } + }) + + t.Run("slice with one value on array format brackets", func(t *testing.T) { + var v struct { + Names []string `form:"names"` + } + + r, err := http.NewRequest( + http.MethodGet, + "/a?names[]=1&names[]=2&names[]=3", + http.NoBody) + assert.NoError(t, err) + if assert.NoError(t, Parse(r, &v)) { + assert.ElementsMatch(t, []string{"1", "2", "3"}, v.Names) } }) } @@ -528,6 +618,26 @@ func TestCustomUnmarshalerStructRequest(t *testing.T) { assert.Equal(t, "hello", v.Foo.Name) } +func TestParseJsonStringRequest(t *testing.T) { + type GoodsInfo struct { + Sku int64 `json:"sku,optional"` + } + + type GetReq struct { + GoodsList []*GoodsInfo `json:"goods_list"` + } + + input := `{"goods_list":"[{\"sku\":11},{\"sku\":22}]"}` + r := httptest.NewRequest(http.MethodPost, "/a", strings.NewReader(input)) + r.Header.Set(ContentType, JsonContentType) + var v GetReq + assert.NotPanics(t, func() { + assert.NoError(t, Parse(r, &v)) + assert.Equal(t, 2, len(v.GoodsList)) + assert.ElementsMatch(t, []int64{11, 22}, []int64{v.GoodsList[0].Sku, v.GoodsList[1].Sku}) + }) +} + func BenchmarkParseRaw(b *testing.B) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/a?name=hello&age=18&percent=3.4", http.NoBody) if err != nil { diff --git a/rest/httpx/util.go b/rest/httpx/util.go index 19248ae74bb3..c22ad8e0d28b 100644 --- a/rest/httpx/util.go +++ b/rest/httpx/util.go @@ -2,12 +2,23 @@ package httpx import ( "errors" + "fmt" "net/http" + "strings" ) -const xForwardedFor = "X-Forwarded-For" +const ( + xForwardedFor = "X-Forwarded-For" + arraySuffix = "[]" + // most servers and clients have a limit of 8192 bytes (8 KB) + // one parameter at least take 4 chars, for example `?a=b&c=d` + maxFormParamCount = 2048 +) -// GetFormValues returns the form values. +// GetFormValues returns the form values supporting three array notation formats: +// 1. Standard notation: /api?names=alice&names=bob +// 2. Comma notation: /api?names=alice,bob +// 3. Bracket notation: /api?names[]=alice&names[]=bob func GetFormValues(r *http.Request) (map[string]any, error) { if err := r.ParseForm(); err != nil { return nil, err @@ -19,16 +30,23 @@ func GetFormValues(r *http.Request) (map[string]any, error) { } } + var n int params := make(map[string]any, len(r.Form)) for name, values := range r.Form { filtered := make([]string, 0, len(values)) for _, v := range values { - if len(v) > 0 { + if n < maxFormParamCount { filtered = append(filtered, v) + n++ + } else { + return nil, fmt.Errorf("too many form values, error: %s", r.Form.Encode()) } } if len(filtered) > 0 { + if strings.HasSuffix(name, arraySuffix) { + name = name[:len(name)-2] + } params[name] = filtered } } diff --git a/rest/httpx/util_test.go b/rest/httpx/util_test.go index 8e804cbf78ea..19725d47b1ab 100644 --- a/rest/httpx/util_test.go +++ b/rest/httpx/util_test.go @@ -1,7 +1,9 @@ package httpx import ( + "fmt" "net/http" + "net/url" "strings" "testing" @@ -23,3 +25,23 @@ func TestGetRemoteAddrNoHeader(t *testing.T) { assert.True(t, len(GetRemoteAddr(r)) == 0) } + +func TestGetFormValues_TooManyValues(t *testing.T) { + form := url.Values{} + + // Add more values than the limit + for i := 0; i < maxFormParamCount+10; i++ { + form.Add("param", fmt.Sprintf("value%d", i)) + } + + // Create a new request with the form data + req, err := http.NewRequest("POST", "/test", strings.NewReader(form.Encode())) + assert.NoError(t, err) + + // Set the content type for form data + req.Header.Set(ContentType, "application/x-www-form-urlencoded") + + _, err = GetFormValues(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too many form values") +} diff --git a/rest/router/patrouter_test.go b/rest/router/patrouter_test.go index dca589e9fc00..02f21ece3d93 100644 --- a/rest/router/patrouter_test.go +++ b/rest/router/patrouter_test.go @@ -516,28 +516,55 @@ func TestParsePtrInRequestEmpty(t *testing.T) { } func TestParseQueryOptional(t *testing.T) { - r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) - assert.Nil(t, err) + t.Run("optional with string", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) + assert.Nil(t, err) - router := NewRouter() - err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - v := struct { - Nickname string `form:"nickname"` - Zipcode int64 `form:"zipcode,optional"` - }{} + router := NewRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Nickname string `form:"nickname"` + Zipcode string `form:"zipcode,optional"` + }{} + + err = httpx.Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, v.Zipcode)) + assert.Nil(t, err) + })) + assert.Nil(t, err) - err = httpx.Parse(r, &v) - assert.Nil(t, err) - _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode)) - assert.Nil(t, err) - })) - assert.Nil(t, err) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) - rr := httptest.NewRecorder() - router.ServeHTTP(rr, r) + assert.Equal(t, "whatever:", rr.Body.String()) + }) - assert.Equal(t, "whatever:0", rr.Body.String()) + t.Run("optional with int", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever", nil) + assert.Nil(t, err) + + router := NewRouter() + err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + v := struct { + Nickname string `form:"nickname"` + Zipcode int `form:"zipcode,optional"` + }{} + + err = httpx.Parse(r, &v) + assert.Nil(t, err) + _, err = io.WriteString(w, fmt.Sprintf("%s:%d", v.Nickname, v.Zipcode)) + assert.Nil(t, err) + })) + assert.Nil(t, err) + + rr := httptest.NewRecorder() + router.ServeHTTP(rr, r) + + assert.Equal(t, "whatever:0", rr.Body.String()) + }) } func TestParse(t *testing.T) {