Skip to content

Commit

Permalink
fix(middleware/cors): CORS handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sixcolors committed Mar 26, 2024
1 parent 7ba02c1 commit 83da096
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 43 deletions.
10 changes: 7 additions & 3 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,13 @@ func New(config ...Config) fiber.Handler {
// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))

// If the request does not have Origin and Access-Control-Request-Method
// headers, the request is outside the scope of CORS
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
// If the request does not have Origin header, the request is outside the scope of CORS
if originHeader == "" {
return c.Next()
}

// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
return c.Next()
}

Expand Down
43 changes: 3 additions & 40 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)

Expand Down Expand Up @@ -104,7 +103,6 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)

require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
Expand Down Expand Up @@ -146,7 +144,6 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)

Expand Down Expand Up @@ -465,7 +462,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
// Get handler pointer
handler := app.Handler()

t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Run("Without origin", func(t *testing.T) {
t.Parallel()
// Make request without origin header, and without Access-Control-Request-Method
for _, method := range methods {
Expand All @@ -478,34 +475,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
}
})

t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request with origin header, but without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})

t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request without origin header, but with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})

t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make preflight request with origin header and with Access-Control-Request-Method
Expand All @@ -523,15 +492,14 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
}
})

t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Run("Non-preflight request with origin", func(t *testing.T) {
t.Parallel()
// Make non-preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/api/action")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
Expand Down Expand Up @@ -1008,7 +976,6 @@ func Benchmark_CORS_NewHandler(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1049,7 +1016,6 @@ func Benchmark_CORS_NewHandlerParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1083,7 +1049,6 @@ func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1124,7 +1089,6 @@ func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1158,7 +1122,6 @@ func Benchmark_CORS_NewHandlerWildcard(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1199,7 +1162,6 @@ func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)
Expand Down Expand Up @@ -1229,6 +1191,7 @@ func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}

// Preflight request
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
Expand Down

0 comments on commit 83da096

Please sign in to comment.