Skip to content

Commit

Permalink
add form decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
muir committed Oct 25, 2024
1 parent be8a4f4 commit 4789489
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 51 deletions.
86 changes: 61 additions & 25 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"encoding"
"encoding/json"
"encoding/xml"
"io/ioutil"
"io"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"
Expand All @@ -29,8 +30,8 @@ var ReadBody = nject.Provide("read-body", readBody)
func readBody(r *http.Request) (Body, nject.TerminalError) {
// nolint:errcheck
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
r.Body = ioutil.NopCloser(bytes.NewReader(body))
body, err := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewReader(body))
return Body(body), err
}

Expand Down Expand Up @@ -195,7 +196,9 @@ var deepObjectRE = regexp.MustCompile(`^([^\[]+)\[([^\]]+)\]$`) // id[name]
// allowReserved=false # default
// allowReserved=true # query parameters only
// form=false # default
// form=true # cookies only
// form=true # query paramters only, may extract value from application/x-www-form-urlencoded POST content
// formOnly=false # default
// formOnly=true # query paramters only, extract value from application/x-www-form-urlencoded POST content only
// content=application/json # specifies that the value should be decoded with JSON
// content=application/xml # specifies that the value should be decoded with XML
// content=application/yaml # specifies that the value should be decoded with YAML
Expand All @@ -216,7 +219,7 @@ var deepObjectRE = regexp.MustCompile(`^([^\[]+)\[([^\]]+)\]$`) // id[name]
// optional. Tag them with their name or with "-" if you do not want
// them filled.
//
// type Fillme struct {
// type Fillme struct {
// Embedded struct {
// IntValue int // will get filled by key "IntValue"
// FloatValue float64 `nvelope:"-"` // will not get filled
Expand Down Expand Up @@ -274,7 +277,9 @@ func GenerateDecoder(
var cookieFillers []func(model reflect.Value, r *http.Request) error
var bodyFillers []func(model reflect.Value, body []byte, r *http.Request) error
queryFillers := make(map[string]func(reflect.Value, []string) error)
queryFillersForm := make(map[string]func(reflect.Value, []string) error)
deepObjectFillers := make(map[string]func(reflect.Value, map[string][]string) error)
deepObjectFillersForm := make(map[string]func(reflect.Value, map[string][]string) error)
var returnError error
reflectutils.WalkStructElements(nonPointer, func(field reflect.StructField) bool {
tag, ok := reflectutils.LookupTag(field.Tag, options.tag)
Expand Down Expand Up @@ -376,6 +381,19 @@ func GenerateDecoder(
name, field.Name)
}
}
if tags.Form || tags.FormOnly {
if unpacker.deepObject != nil {
deepObjectFillersForm[name] = deepObjectFillers[name]
if tags.FormOnly {
delete(deepObjectFillers, name)
}
} else {
queryFillersForm[name] = queryFillers[name]
if tags.FormOnly {
delete(queryFillers, name)
}
}
}
case "cookie":
cookieFillers = append(cookieFillers, func(model reflect.Value, r *http.Request) error {
f := model.FieldByIndex(field.Index)
Expand All @@ -402,14 +420,16 @@ func GenerateDecoder(
len(headerFillers) == 0 &&
len(cookieFillers) == 0 &&
len(queryFillers) == 0 &&
len(queryFillersForm) == 0 &&
len(bodyFillers) == 0 &&
len(deepObjectFillers) == 0 {
len(deepObjectFillers) == 0 &&
len(deepObjectFillersForm) == 0 {
continue
}

outputs := []reflect.Type{returnType, terminalErrorType}
inputs := []reflect.Type{httpRequestType}
if len(bodyFillers) != 0 {
if len(bodyFillers) != 0 || len(queryFillersForm) != 0 || len(deepObjectFillersForm) != 0 {
inputs = append(inputs, bodyType)
}

Expand Down Expand Up @@ -461,27 +481,42 @@ func GenerateDecoder(
setError(hf(model, r.Header))
}
var deepObjects map[string]map[string][]string
for key, vals := range r.URL.Query() {
if qf, ok := queryFillers[key]; ok {
setError(qf(model, vals))
continue
}
if len(deepObjectFillers) != 0 {
if m := deepObjectRE.FindStringSubmatch(key); len(m) == 3 {
if _, ok := deepObjectFillers[m[1]]; ok {
if deepObjects == nil {
deepObjects = make(map[string]map[string][]string)
}
if deepObjects[m[1]] == nil {
deepObjects[m[1]] = make(map[string][]string)
handleQueryParams := func(values url.Values, queryFillers map[string]func(reflect.Value, []string) error, deepObjectFillers map[string]func(reflect.Value, map[string][]string) error) {
for key, vals := range values {
if qf, ok := queryFillers[key]; ok {
setError(qf(model, vals))
continue
}
if len(deepObjectFillers) != 0 {
if m := deepObjectRE.FindStringSubmatch(key); len(m) == 3 {
if _, ok := deepObjectFillers[m[1]]; ok {
if deepObjects == nil {
deepObjects = make(map[string]map[string][]string)
}
if deepObjects[m[1]] == nil {
deepObjects[m[1]] = make(map[string][]string)
}
deepObjects[m[1]][m[2]] = vals
continue
}
deepObjects[m[1]][m[2]] = vals
continue
}
}
if options.rejectUnknownQueryParameters {
setError(errors.Errorf("query parameter '%s' not supported", key))
}
}
if options.rejectUnknownQueryParameters {
setError(errors.Errorf("query parameter '%s' not supported", key))
}
handleQueryParams(r.URL.Query(), queryFillers, deepObjectFillers)
if len(queryFillersForm) != 0 || len(deepObjectFillersForm) != 0 {
body := []byte(in[1].Interface().(Body))
ct := r.Header.Get("Content-Type")
if ct == "application/x-www-form-urlencoded" {
values, err := url.ParseQuery(string(body))
if err != nil {
setError(errors.Wrap(err, "could not parse application/x-www-form-urlencoded data"))
} else {
handleQueryParams(values, queryFillersForm, deepObjectFillersForm)
}
}
}
for dofKey, values := range deepObjects {
Expand Down Expand Up @@ -707,7 +742,7 @@ func getUnpacker(
},
}, nil
}
if reflect.PtrTo(fieldType).AssignableTo(textUnmarshallerType) {
if reflect.PointerTo(fieldType).AssignableTo(textUnmarshallerType) {
return unpack{
createMe: true,
single: func(from string, target reflect.Value, value string) error {
Expand Down Expand Up @@ -1007,6 +1042,7 @@ type tags struct {
Delimiter string `pt:"delimiter"`
AllowReserved bool `pt:"allowReserved"`
Form bool `pt:"form"`
FormOnly bool `pt:"formOnly"`
Content string `pt:"content"`
DeepObject bool `pt:"deepObject"`
}
Expand Down
43 changes: 34 additions & 9 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ func TestDecodeQuerySimpleParameters(t *testing.T) {
Complex64 *Complex64 `json:",omitempty" nvelope:"query,name=complex64"`
Complex128 *Complex128 `json:",omitempty" nvelope:"query,name=complex128"`
BoolP *bool `json:",omitempty" nvelope:"query,name=boolp"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})
assert.Equal(t, `200->{"Int":135}`, do("/x?int=135"))
Expand Down Expand Up @@ -111,7 +112,8 @@ func TestDecodeQueryComplexParameters(t *testing.T) {
Int16 int16 `json:",omitempty" nvelope:"eint16"`
String string `json:",omitempty"`
} `json:",omitempty" nvelope:"query,name=emb2,deepObject=true"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})
assert.Equal(t, `200->{"IntSlice":[1,7]}`, do("/x?intslice=1,7"))
Expand Down Expand Up @@ -141,7 +143,8 @@ func TestDecodeQueryJSONParameters(t *testing.T) {
S1 string `json:",omitempty" nvelope:"query,name=s1,content=application/json"`
S2 *string `json:",omitempty" nvelope:"query,name=s2,content=application/json"`
S3 **string `json:",omitempty" nvelope:"query,name=s3,content=application/json"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})
assert.Equal(t, `200->{"Foo":"~bar~"}`, do("/x?foo=bar"))
Expand All @@ -159,7 +162,8 @@ func TestDecodeQueryHeaderParameters(t *testing.T) {
A1 []string `json:",omitempty" nvelope:"header,name=A1"`
A2 []string `json:",omitempty" nvelope:"header,name=A2"`
A3 []string `json:",omitempty" nvelope:"header,explode=false,name=A3"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})
assert.Equal(t, `200->{"S":"yip"}`, do("/x", header("S", "yip")))
Expand All @@ -173,7 +177,8 @@ func TestDecodeQueryCookieParameters(t *testing.T) {
S string `json:",omitempty" nvelope:"cookie,name=S"`
A1 []string `json:",omitempty" nvelope:"cookie,name=A1"`
A3 []string `json:",omitempty" nvelope:"cookie,explode=false,name=A3"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})
assert.Equal(t, `200->{"S":"yip"}`, do("/x", cookie("S", "yip")))
Expand All @@ -186,7 +191,8 @@ func TestDecodeQueryPathParameters(t *testing.T) {
A string `json:",omitempty" nvelope:"path,name=a"`
B *int `json:",omitempty" nvelope:"path,name=b"`
C Foo `json:",omitempty" nvelope:"path,name=c"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})
assert.Equal(t, `200->{"A":"foobar","B":38,"C":"~john~"}`, do("/x/foobar/38/john"))
Expand All @@ -196,7 +202,8 @@ func TestDecodeQueryExplode(t *testing.T) {
do := captureOutput("/x", func(s struct {
M map[string]int `json:",omitempty" nvelope:"query,name=m,explode=true"`
S []string `json:",omitempty" nvelope:"query,name=s,explode=true"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})
assert.Equal(t, `200->{"M":{"a":7,"b":8}}`, do("/x?m=a%3D7&m=b%3D8"))
Expand All @@ -218,7 +225,8 @@ func TestDecodeQueryContentExplode(t *testing.T) {
SE []int `json:",omitempty" nvelope:"query,name=se,explode=true"`
MA map[int]thing `json:",omitempty" nvelope:"query,name=ma,explode=false,content=application/json"`
SA []thing `json:",omitempty" nvelope:"query,name=sa,explode=false,content=application/json"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})

Expand All @@ -234,7 +242,8 @@ func TestDecodeQueryOtherEncoders(t *testing.T) {
do := captureOutput("/x", func(s struct {
XML *thing `json:",omitempty" nvelope:"query,name=xml,explode=false,content=application/xml"`
YAML *thing `json:",omitempty" nvelope:"query,name=yaml,explode=false,content=text/yaml"`
}) (nvelope.Response, error) {
},
) (nvelope.Response, error) {
return s, nil
})

Expand All @@ -252,3 +261,19 @@ func TestDecodeQueryOtherEncoders(t *testing.T) {
assert.Equal(t, `200->{"XML":{"I":3,"F":6.2}}`, do("/x?xml="+xmle(thing{I: 3, F: 6.2})))
assert.Equal(t, `200->{"YAML":{"I":8,"F":2.2}}`, do("/x?yaml="+yamle(thing{I: 8, F: 2.2})))
}

func TestDecodeFormValues(t *testing.T) {
do := captureOutput("/x", func(s struct {
A int `json:",omitempty" nvelope:"query,name=a"`
B int `json:",omitempty" nvelope:"query,form,name=b"`
C int `json:",omitempty" nvelope:"query,formOnly,name=c"`
D int `json:",omitempty" nvelope:"query,formOnly,name=d"`
},
) (nvelope.Response, error) {
return s, nil
})

assert.Equal(t, `200->{"A":7,"B":8,"C":9}`, do("/x?a=7&b=8", header("Content-type", "application/x-www-form-urlencoded"), body(`c=9`)))
assert.Equal(t, `200->{"A":7,"B":8}`, do("/x?a=7&b=8", header("Content-type", "application/json"), body(`{}`)))
assert.Equal(t, `200->{"A":7,"B":8,"C":9,"D":2}`, do("/x?a=7", header("Content-type", "application/x-www-form-urlencoded"), body(`c=9&b=8&d=2`)))
}
2 changes: 0 additions & 2 deletions doc.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Stuff

/*
Package nvelope provides injection handlers that make building
HTTP endpoints simple. In combination with npoint and nject it
provides a API endpoint framework.
Expand All @@ -25,6 +24,5 @@ an error return to cause a specific HTTP error code to be sent.
CatchPanic makes it easy to turn panics into error returns.
The provided example puts it all together.
*/
package nvelope
4 changes: 2 additions & 2 deletions example_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package nvelope_test
import (
"context"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"time"
Expand Down Expand Up @@ -92,7 +92,7 @@ func ExampleServiceWithMiddleware() {
fmt.Println("response error:", err)
return
}
b, err := ioutil.ReadAll(res.Body)
b, err := io.ReadAll(res.Body)
if err != nil {
fmt.Println("read error:", err)
return
Expand Down
4 changes: 2 additions & 2 deletions example_minimalerror_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package nvelope_test
import (
"context"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"

Expand Down Expand Up @@ -38,7 +38,7 @@ func ExampleMinimalErrorHandler() {
fmt.Println("response error:", err)
return
}
b, err := ioutil.ReadAll(res.Body)
b, err := io.ReadAll(res.Body)
if err != nil {
fmt.Println("read error:", err)
return
Expand Down
4 changes: 2 additions & 2 deletions example_mwhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package nvelope_test
import (
"context"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/http/httptest"
"time"
Expand Down Expand Up @@ -92,7 +92,7 @@ func ExampleServiceWithMiddlewareHandler() {
fmt.Println("response error:", err)
return
}
b, err := ioutil.ReadAll(res.Body)
b, err := io.ReadAll(res.Body)
if err != nil {
fmt.Println("read error:", err)
return
Expand Down
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package nvelope_test
import (
"errors"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -91,7 +91,7 @@ func Example() {
fmt.Println("response error:", err)
return
}
b, err := ioutil.ReadAll(res.Body)
b, err := io.ReadAll(res.Body)
if err != nil {
fmt.Println("read error:", err)
return
Expand Down
Loading

0 comments on commit 4789489

Please sign in to comment.