diff --git a/context.go b/context.go index e76b5b6..cf7fb92 100644 --- a/context.go +++ b/context.go @@ -202,6 +202,19 @@ func (c *Context) QueryString(key string) string { return c.req.query(key) } +// QueryBool returns the value of a given query parameter as a bool. +func (c *Context) QueryBool(key string) (bool, error) { + str := c.req.query(key) + if str == "" { + return false, nil + } + value, err := strconv.ParseBool(str) + if err != nil { + return false, err + } + return value, nil +} + // QueryInt returns the value of a given query parameter as an int. func (c *Context) QueryInt(key string) (int, error) { str := c.req.query(key) diff --git a/context_test.go b/context_test.go index 7655a3e..ad964e4 100644 --- a/context_test.go +++ b/context_test.go @@ -528,6 +528,60 @@ func TestContext_QueryString(t *testing.T) { } } +func TestContext_QueryBool(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=true", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryBool("key") + if err != nil { + t.Errorf("ctx.QueryBool(\"key\") returned an error: %v", err) + } + want := true + if got != want { + t.Errorf("QueryBool() = %v, want %v", got, want) + } +} + +func TestContext_QueryBoolWithException(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=notabool", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + + _, err = ctx.QueryBool("key") + if err == nil { + t.Error("ctx.QueryBool(\"key\") did not return an error") + } +} + +func TestContext_QueryBoolWithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryBool("key") + if err != nil { + t.Errorf("ctx.QueryBool(\"key\") returned an error: %v", err) + } + want := false + if got != want { + t.Errorf("QueryBool() = %v, want %v", got, want) + } +} + func TestContext_QueryInt(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -563,6 +617,25 @@ func TestContext_QueryIntWithException(t *testing.T) { } } +func TestContext_QueryIntWithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryInt("key") + if err != nil { + t.Errorf("ctx.QueryInt(\"key\") returned an error: %v", err) + } + want := 0 + if got != want { + t.Errorf("QueryInt() = %d, want %d", got, want) + } +} + func TestContext_QueryUInt(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -598,6 +671,25 @@ func TestContext_QueryUIntWithException(t *testing.T) { } } +func TestContext_QueryUIntWithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryUInt("key") + if err != nil { + t.Errorf("ctx.QueryUInt(\"key\") returned an error: %v", err) + } + want := uint(0) + if got != want { + t.Errorf("QueryUInt() = %d, want %d", got, want) + } +} + func TestContext_QueryInt8(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -633,6 +725,25 @@ func TestContext_QueryInt8WithException(t *testing.T) { } } +func TestContext_QueryInt8WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryInt8("key") + if err != nil { + t.Errorf("ctx.QueryInt8(\"key\") returned an error: %v", err) + } + want := int8(0) + if got != want { + t.Errorf("QueryInt8() = %d, want %d", got, want) + } +} + func TestContext_QueryUInt8(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -668,6 +779,25 @@ func TestContext_QueryUInt8WithException(t *testing.T) { } } +func TestContext_QueryUInt8WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryUInt8("key") + if err != nil { + t.Errorf("ctx.QueryUInt8(\"key\") returned an error: %v", err) + } + want := uint8(0) + if got != want { + t.Errorf("QueryUInt8() = %d, want %d", got, want) + } +} + func TestContext_QueryInt32(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -703,6 +833,25 @@ func TestContext_QueryInt32WithException(t *testing.T) { } } +func TestContext_QueryInt32WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryInt32("key") + if err != nil { + t.Errorf("ctx.QueryInt32(\"key\") returned an error: %v", err) + } + want := int32(0) + if got != want { + t.Errorf("QueryInt32() = %d, want %d", got, want) + } +} + func TestContext_QueryUInt32(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -738,6 +887,25 @@ func TestContext_QueryUInt32WithException(t *testing.T) { } } +func TestContext_QueryUInt32WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryUInt32("key") + if err != nil { + t.Errorf("ctx.QueryUInt32(\"key\") returned an error: %v", err) + } + want := uint32(0) + if got != want { + t.Errorf("QueryUInt32() = %d, want %d", got, want) + } +} + func TestContext_QueryInt64(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -773,6 +941,25 @@ func TestContext_QueryInt64WithException(t *testing.T) { } } +func TestContext_QueryInt64WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryInt64("key") + if err != nil { + t.Errorf("ctx.QueryInt64(\"key\") returned an error: %v", err) + } + want := int64(0) + if got != want { + t.Errorf("QueryInt64() = %d, want %d", got, want) + } +} + func TestContext_QueryUInt64(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=123", nil) if err != nil { @@ -808,6 +995,25 @@ func TestContext_QueryUInt64WithException(t *testing.T) { } } +func TestContext_QueryUInt64WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryUInt64("key") + if err != nil { + t.Errorf("ctx.QueryUInt64(\"key\") returned an error: %v", err) + } + want := uint64(0) + if got != want { + t.Errorf("QueryUInt64() = %d, want %d", got, want) + } +} + func TestContext_QueryFloat32(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=3.1415", nil) if err != nil { @@ -843,6 +1049,25 @@ func TestContext_QueryFloat32WithException(t *testing.T) { } } +func TestContext_QueryFloat32WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryFloat32("key") + if err != nil { + t.Errorf("ctx.QueryFloat32(\"key\") returned an error: %v", err) + } + want := float32(0) + if got != want { + t.Errorf("QueryFloat32() = %f, want %f", got, want) + } +} + func TestContext_QueryFloat64(t *testing.T) { req, err := http.NewRequest("GET", "/path?key=3.1415", nil) if err != nil { @@ -878,6 +1103,25 @@ func TestContext_QueryFloat64WithException(t *testing.T) { } } +func TestContext_QueryFloat64WithEmptyKey(t *testing.T) { + req, err := http.NewRequest("GET", "/path?key=", nil) + if err != nil { + t.Fatal(err) + } + ctx, err := NewContext(httptest.NewRecorder(), req) + if err != nil { + t.Fatal(err) + } + got, err := ctx.QueryFloat64("key") + if err != nil { + t.Errorf("ctx.QueryFloat64(\"key\") returned an error: %v", err) + } + want := float64(0) + if got != want { + t.Errorf("QueryFloat64() = %f, want %f", got, want) + } +} + func TestContextQueries(t *testing.T) { req, err := http.NewRequest("GET", "/path?foo=bar&baz=qux", nil) if err != nil {