diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index e27e74cba8..9159a1340c 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -172,8 +172,9 @@ func New(config ...Config) fiber.Handler { // Get originHeader header originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin)) - // If the request does not have an Origin header, the request is outside the scope of CORS - if originHeader == "" { + // If the request does not have Origin and Access-Control-Request-Method + // headers, the request is outside the scope of CORS + if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" { return c.Next() } @@ -211,8 +212,9 @@ func New(config ...Config) fiber.Handler { } // Simple request + // Ommit allowMethods and allowHeaders, only used for pre-flight requests if c.Method() != fiber.MethodOptions { - setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) + setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg) return c.Next() } @@ -233,14 +235,14 @@ func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, expos if cfg.AllowCredentials { // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' - if allowOrigin != "*" && allowOrigin != "" { - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") - } else if allowOrigin == "*" { + if allowOrigin == "*" { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") + } else if allowOrigin != "" { + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + c.Set(fiber.HeaderAccessControlAllowCredentials, "true") } - } else if len(allowOrigin) > 0 { + } else if allowOrigin != "" { // For non-credential requests, it's safe to set to '*' or specific origins c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index ff5cdd7c25..c56d3c503b 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -35,6 +35,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") app.Handler()(ctx) @@ -49,6 +50,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default GET response headers ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) @@ -59,6 +61,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { // Test default OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) @@ -87,6 +90,7 @@ func Test_CORS_Wildcard(t *testing.T) { ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) @@ -101,6 +105,7 @@ func Test_CORS_Wildcard(t *testing.T) { ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) handler(ctx) utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) @@ -128,6 +133,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) @@ -141,6 +147,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.SetMethod(fiber.MethodGet) handler(ctx) @@ -226,6 +233,7 @@ func Test_CORS_Subdomain(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request @@ -240,6 +248,7 @@ func Test_CORS_Subdomain(t *testing.T) { // Make request with domain only (disallowed) ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") handler(ctx) @@ -252,6 +261,7 @@ func Test_CORS_Subdomain(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com") handler(ctx) @@ -366,6 +376,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin) handler(ctx) @@ -422,6 +433,90 @@ func Test_CORS_Next(t *testing.T) { utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) } +// go test -run Test_CORS_Headers_BasedOnRequestType +func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Use(New(Config{})) + + methods := []string{ + fiber.MethodGet, + fiber.MethodPost, + fiber.MethodPut, + fiber.MethodDelete, + fiber.MethodPatch, + fiber.MethodHead, + } + + // Get handler pointer + handler := app.Handler() + + t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) { + // Make request without origin header, and without Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/") + handler(ctx) + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") + } + }) + + t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) { + // Make request with origin header, but without Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") + handler(ctx) + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") + } + }) + + t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) { + // Make request without origin header, but with Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + handler(ctx) + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") + } + }) + + t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + // Make preflight request with origin header and with Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.SetRequestURI("https://example.com/") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) + handler(ctx) + utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") + utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)") + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)") + } + }) + + t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { + // Make non-preflight request with origin header and with Access-Control-Request-Method + for _, method := range methods { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.SetMethod(method) + ctx.Request.SetRequestURI("https://example.com/api/action") + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) + handler(ctx) + utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)") + utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)") + } + }) +} + func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { t.Parallel() // New fiber instance @@ -440,6 +535,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request @@ -454,6 +550,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com") handler(ctx) @@ -466,6 +563,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) @@ -505,6 +603,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) { // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) @@ -652,6 +751,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx) @@ -742,6 +842,7 @@ func Test_CORS_AllowCredentials(t *testing.T) { ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) + ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx)