-
Notifications
You must be signed in to change notification settings - Fork 205
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into options_for_cors
- Loading branch information
Showing
2 changed files
with
236 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,130 @@ | ||
require "../../spec_helper" | ||
|
||
module Amber | ||
module Pipe | ||
describe CORS do | ||
context "allowed headers" do | ||
# Pipeline with default settings | ||
pipeline = Pipeline.new | ||
pipeline.build :cors do | ||
plug CORS.new | ||
end | ||
pipeline.prepare_pipelines | ||
|
||
# Pipeline with custom settings | ||
pipeline_custom = Pipeline.new | ||
pipeline_custom.build :cors do | ||
plug CORS.new(allow_headers: "max-age") | ||
end | ||
pipeline_custom.prepare_pipelines | ||
|
||
Amber::Server.router.draw :cors do | ||
options "/test", HelloController, :world | ||
end | ||
|
||
it "should allow a case-insensitive header values" do | ||
request = HTTP::Request.new("OPTIONS", "/test") | ||
request.headers["Access-Control-Request-Method"] = "OPTIONS" | ||
request.headers["Access-Control-Request-Headers"] = "cOnTeNt-TyPe" | ||
response = create_request_and_return_io(pipeline, request) | ||
|
||
response.status_code.should eq 200 | ||
end | ||
|
||
it "allows headers 'accept, content-type' by default" do | ||
request = HTTP::Request.new("OPTIONS", "/test") | ||
request.headers["Access-Control-Request-Method"] = "OPTIONS" | ||
request.headers["Access-Control-Request-Headers"] = "accept" | ||
response = create_request_and_return_io(pipeline, request) | ||
|
||
response.status_code.should eq 200 | ||
response.headers["Access-Control-Allow-Headers"].should eq "accept, content-type" | ||
end | ||
|
||
it "can override settings at initialization" do | ||
request = HTTP::Request.new("OPTIONS", "/test") | ||
request.headers["Access-Control-Request-Method"] = "OPTIONS" | ||
request.headers["Access-Control-Request-Headers"] = "max-age" | ||
response = create_request_and_return_io(pipeline_custom, request) | ||
|
||
response.status_code.should eq 200 | ||
response.headers["Access-Control-Allow-Headers"].should eq "max-age" | ||
end | ||
require "../../../spec_helper" | ||
|
||
module Amber::Pipe | ||
describe CORS do | ||
it "supports simple CORS requests" do | ||
context = cors_context("GET", "Origin": "http://localhost:3000") | ||
CORS.new.call(context) | ||
assert_cors_success(context) | ||
end | ||
|
||
it "does not return CORS headers if Origin header not present" do | ||
context = cors_context("GET") | ||
CORS.new.call(context) | ||
assert_cors_failure context | ||
end | ||
|
||
it "supports OPTIONS request" do | ||
context = cors_context("OPTIONS", "Origin": "example.com") | ||
CORS.new.call(context) | ||
assert_cors_success context | ||
end | ||
|
||
it "matches regex :origin settings" do | ||
context = cors_context("GET", "Origin": "http://192.168.0.1:3000") | ||
origins = CORS::OriginType.new | ||
origins << %r(192\.168\.0\.1) | ||
CORS.new(origins: origins).call(context) | ||
assert_cors_success(context) | ||
end | ||
|
||
it "does not return CORS headers if origins is empty" do | ||
context = cors_context("GET", "Origin": "http://localhost:3000") | ||
CORS.new(origins: CORS::OriginType.new).call(context) | ||
assert_cors_failure context | ||
end | ||
|
||
it "supports alternative X-Origin header" do | ||
context = cors_context("GET", "X-Origin": "http://localhost:3000") | ||
CORS.new.call(context) | ||
assert_cors_success(context) | ||
end | ||
|
||
it "supports expose header configuration" do | ||
expose_header = %w(X-Expose) | ||
context = cors_context("GET", "X-Origin": "http://localhost:3000") | ||
CORS.new(expose_headers: expose_header).call(context) | ||
context.response.headers[Amber::Pipe::Headers::ALLOW_EXPOSE].should eq expose_header.join(",") | ||
end | ||
|
||
it "supports expose multiple header configuration" do | ||
expose_header = %w(X-Example X-Another) | ||
context = cors_context("GET", "X-Origin": "http://localhost:3000") | ||
CORS.new(expose_headers: expose_header).call(context) | ||
context.response.headers[Amber::Pipe::Headers::ALLOW_EXPOSE].should eq expose_header.join(",") | ||
end | ||
|
||
it "adds vary header when origin is other than (*)" do | ||
domain = "example.com" | ||
origins = CORS::OriginType.new | ||
origins << domain | ||
context = cors_context("GET", "Origin": domain) | ||
CORS.new(origins: origins).call(context) | ||
context.response.headers[Amber::Pipe::Headers::VARY].should eq "Origin" | ||
end | ||
|
||
it "does not add vary header when origin is (*)" do | ||
origins = CORS::OriginType.new | ||
origins << "*" | ||
context = cors_context("GET", "Origin": "*") | ||
CORS.new(origins: origins).call(context) | ||
context.response.headers[Amber::Pipe::Headers::VARY]?.should be_nil | ||
end | ||
|
||
it "adds Vary header based on :vary option" do | ||
domain = "example.com" | ||
origins = CORS::OriginType.new | ||
origins << domain | ||
context = cors_context("GET", "Origin": domain) | ||
CORS.new(origins: origins, vary: "Other").call(context) | ||
context.response.headers[Amber::Pipe::Headers::VARY].should eq "Origin,Other" | ||
end | ||
|
||
it "sets allow credential headers if credential settings is true" do | ||
domain = "example.com" | ||
origins = CORS::OriginType.new | ||
origins << domain | ||
context = cors_context("GET", "Origin": domain) | ||
CORS.new(credentials: true, origins: origins, vary: "Other").call(context) | ||
context.response.headers[Amber::Pipe::Headers::ALLOW_CREDENTIALS].should eq "true" | ||
end | ||
|
||
context "when preflight request" do | ||
it "process valid preflight request" do | ||
domain = "example.com" | ||
origins = CORS::OriginType.new | ||
origins << domain | ||
context = cors_context( | ||
"OPTIONS", | ||
"Origin": domain, | ||
"Access-Control-Request-Method": "PUT", | ||
"Access-Control-Request-Headers": "Accept" | ||
) | ||
CORS.new(origins: origins).call(context) | ||
|
||
context.response.status_code = 200 | ||
context.response.headers["Content-Length"].should eq "0" | ||
end | ||
end | ||
end | ||
end | ||
|
||
def cors_context(method = "GET", **args) | ||
headers = HTTP::Headers.new | ||
args.each do |k, v| | ||
headers[k.to_s] = v | ||
end | ||
request = HTTP::Request.new(method, "/", headers) | ||
create_context(request) | ||
end | ||
|
||
def assert_cors_success(context) | ||
origin_header = context.response.headers["Access-Control-Allow-Origin"]? | ||
origin_header.should_not be_nil | ||
end | ||
|
||
def assert_cors_failure(context) | ||
origin_header = context.response.headers["Access-Control-Allow-Origin"]? | ||
context.response.status_code.should eq 403 | ||
origin_header.should be_nil | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,134 @@ | ||
require "./base" | ||
|
||
module Amber | ||
module Pipe | ||
# The CORS Handler adds support for Cross Origin Resource Sharing. | ||
module Headers | ||
VARY = "Vary" | ||
ORIGIN = "Origin" | ||
X_ORIGIN = "X-Origin" | ||
REQUEST_METHOD = "Access-Control-Request-Method" | ||
REQUEST_HEADERS = "Access-Control-Request-Headers" | ||
ALLOW_EXPOSE = "Access-Control-Expose-Headers" | ||
ALLOW_ORIGIN = "Access-Control-Allow-Origin" | ||
ALLOW_METHOD = "Access-Control-Allow-Method" | ||
ALLOW_HEADERS = "Access-Control-Allow-Headers" | ||
ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials" | ||
ALLOW_MAX_AGE = "Access-Control-Max-Age" | ||
end | ||
|
||
class CORS < Base | ||
property allow_origin, allow_headers, allow_methods, allow_credentials, | ||
max_age | ||
alias OriginType = Array(String | Regex) | ||
FORBIDDEN = "Forbidden for invalid origins, methods or headers" | ||
ALLOW_METHODS = %w(PUT PATCH DELETE) | ||
ALLOW_HEADERS = %w(Accept Content-type) | ||
|
||
ALLOW_METHODS = %w(GET HEAD POST DELETE OPTIONS PUT PATCH) | ||
ALLOW_HEADERS = %w(accept content-type) | ||
property origins, headers, methods, credentials, max_age | ||
@origin : Origin | ||
|
||
def initialize( | ||
@allow_origin = "*", | ||
@allow_methods = ALLOW_METHODS, | ||
@allow_headers = ALLOW_HEADERS, | ||
@allow_credentials = false, | ||
@max_age = 0 | ||
@origins : OriginType = ["*", %r()], | ||
@methods = ALLOW_METHODS, | ||
@headers = ALLOW_HEADERS, | ||
@credentials = false, | ||
@max_age : Int32? = 0, | ||
@expose_headers : Array(String)? = nil, | ||
@vary : String? = nil | ||
) | ||
@origin = Origin.new(origins) | ||
end | ||
|
||
def initialize( | ||
@allow_origin = "*", | ||
allow_methods : String = ALLOW_METHODS.join(", "), | ||
allow_headers : String = ALLOW_HEADERS.join(", "), | ||
@allow_credentials = false, | ||
@max_age = 0 | ||
) | ||
@allow_methods = allow_methods.strip.split /[\s,]+/ | ||
@allow_headers = allow_headers.strip.split /[\s,]+/ | ||
def call(context : HTTP::Server::Context) | ||
if @origin.match?(context.request) | ||
put_expose_header(context.response) | ||
Preflight.request?(context, self) | ||
put_response_headers(context.response) | ||
else | ||
return forbidden(context) | ||
end | ||
|
||
call_next(context) | ||
end | ||
|
||
def call(context : HTTP::Server::Context) | ||
context.response.headers["Access-Control-Allow-Origin"] = allow_origin | ||
def forbidden(context) | ||
context.response.headers["Content-Type"] = "text/plain" | ||
context.response.respond_with_error FORBIDDEN, 403 | ||
end | ||
|
||
# TODO: verify the actual origin matches allowed origins. | ||
# if requested_origin = context.request.headers["Origin"] | ||
# if allow_origins.includes? requested_origin | ||
# end | ||
# end | ||
private def put_expose_header(response) | ||
response.headers[Headers::ALLOW_EXPOSE] = @expose_headers.as(Array).join(",") if @expose_headers | ||
end | ||
|
||
if allow_credentials | ||
context.response.headers["Access-Control-Allow-credentials"] = "true" | ||
end | ||
private def put_response_headers(response) | ||
response.headers[Headers::ALLOW_CREDENTIALS] = @credentials.to_s if @credentials | ||
response.headers[Headers::ALLOW_ORIGIN] = @origin.request_origin.not_nil! | ||
response.headers[Headers::VARY] = vary unless @origin.any? | ||
end | ||
|
||
if max_age > 0 | ||
context.response.headers["Access-Control-Max-Age"] = max_age.to_s | ||
private def vary | ||
String.build do |str| | ||
str << Headers::ORIGIN | ||
str << "," << @vary if @vary | ||
end | ||
end | ||
end | ||
|
||
# if asking permission for request method or request headers | ||
if context.request.method.downcase == "options" | ||
context.response.status_code = 200 | ||
response = "" | ||
|
||
if requested_method = context.request.headers["Access-Control-Request-Method"] | ||
if allow_methods.includes? requested_method.strip | ||
context.response.headers["Access-Control-Allow-Methods"] = allow_methods.join(", ") | ||
else | ||
context.response.status_code = 403 | ||
response = "Method #{requested_method} not allowed." | ||
end | ||
end | ||
module Preflight | ||
extend self | ||
|
||
if requested_headers = context.request.headers["Access-Control-Request-Headers"] | ||
requested_headers.split(",").each do |requested_header| | ||
if allow_headers.includes? requested_header.strip.downcase | ||
context.response.headers["Access-Control-Allow-Headers"] = allow_headers.join(", ") | ||
else | ||
context.response.status_code = 403 | ||
response = "Headers #{requested_headers} not allowed." | ||
end | ||
end | ||
def request?(context, cors) | ||
if context.request.method == "OPTIONS" | ||
if valid_method?(context.request, cors.methods) && | ||
valid_headers?(context.request, cors.headers) | ||
put_preflight_headers(context.request, context.response, cors.max_age) | ||
else | ||
cors.forbidden(context) | ||
end | ||
end | ||
end | ||
|
||
context.response.content_type = "text/html; charset=utf-8" | ||
context.response.print(response) | ||
else | ||
call_next(context) | ||
def put_preflight_headers(request, response, max_age) | ||
response.headers[Headers::ALLOW_METHOD] = request.headers[Headers::REQUEST_METHOD] | ||
response.headers[Headers::ALLOW_HEADERS] = request.headers[Headers::REQUEST_HEADERS] | ||
response.headers[Headers::ALLOW_MAX_AGE] = max_age.to_s if max_age | ||
response.content_length = 0 | ||
response.flush | ||
end | ||
|
||
def valid_method?(request, methods) | ||
methods.includes? request.headers[Headers::REQUEST_METHOD]? | ||
end | ||
|
||
def valid_headers?(request, headers) | ||
!(headers & request.headers[Headers::REQUEST_HEADERS].split(',')).empty? | ||
end | ||
end | ||
|
||
struct Origin | ||
getter request_origin : String? | ||
|
||
def initialize(@origins : CORS::OriginType) | ||
end | ||
|
||
def match?(request) | ||
return false if @origins.empty? | ||
return false unless origin_header?(request) | ||
return true if any? | ||
|
||
@origins.any? do |origin| | ||
case origin | ||
when String then origin == request_origin | ||
when Regex then origin =~ request_origin | ||
end | ||
end | ||
end | ||
|
||
def any? | ||
@origins.includes? "*" | ||
end | ||
|
||
private def origin_header?(request) | ||
@request_origin ||= request.headers[Headers::ORIGIN]? || request.headers[Headers::X_ORIGIN]? | ||
end | ||
end | ||
end | ||
end |