Skip to content

Commit

Permalink
Normalize allowed request headers and store them in a sorted set (fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jub0bs authored Apr 24, 2024
1 parent 8d33ca4 commit 4c32059
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 192 deletions.
18 changes: 17 additions & 1 deletion bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cors

import (
"net/http"
"strings"
"testing"
)

Expand Down Expand Up @@ -87,7 +88,22 @@ func BenchmarkPreflightHeader(b *testing.B) {
req, _ := http.NewRequest(http.MethodOptions, dummyEndpoint, nil)
req.Header.Add(headerOrigin, dummyOrigin)
req.Header.Add(headerACRM, http.MethodGet)
req.Header.Add(headerACRH, "Accept")
req.Header.Add(headerACRH, "accept")
handler := Default().Handler(testHandler)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
handler.ServeHTTP(resps[i], req)
}
}

func BenchmarkPreflightAdversarialACRH(b *testing.B) {
resps := makeFakeResponses(b.N)
req, _ := http.NewRequest(http.MethodOptions, dummyEndpoint, nil)
req.Header.Add(headerOrigin, dummyOrigin)
req.Header.Add(headerACRM, http.MethodGet)
req.Header.Add(headerACRH, strings.Repeat(",", 1024))
handler := Default().Handler(testHandler)

b.ReportAllocs()
Expand Down
57 changes: 22 additions & 35 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"os"
"strconv"
"strings"

"github.com/rs/cors/internal"
)

var headerVaryOrigin = []string{"Origin"}
Expand Down Expand Up @@ -111,7 +113,11 @@ type Cors struct {
// Optional origin validator function
allowOriginFunc func(r *http.Request, origin string) (bool, []string)
// Normalized list of allowed headers
allowedHeaders []string
// Note: the Fetch standard guarantees that CORS-unsafe request-header names
// (i.e. the values listed in the Access-Control-Request-Headers header)
// are unique and sorted;
// see https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names.
allowedHeaders internal.SortedSet
// Normalized list of allowed methods
allowedMethods []string
// Pre-computed normalized list of exposed headers
Expand Down Expand Up @@ -183,15 +189,19 @@ func New(options Options) *Cors {
}

// Allowed Headers
// Note: the Fetch standard guarantees that CORS-unsafe request-header names
// (i.e. the values listed in the Access-Control-Request-Headers header)
// are lowercase; see https://fetch.spec.whatwg.org/#cors-unsafe-request-header-names.
if len(options.AllowedHeaders) == 0 {
// Use sensible defaults
c.allowedHeaders = []string{"Accept", "Content-Type", "X-Requested-With"}
c.allowedHeaders = internal.NewSortedSet("accept", "content-type", "x-requested-with")
} else {
c.allowedHeaders = convert(options.AllowedHeaders, http.CanonicalHeaderKey)
normalized := convert(options.AllowedHeaders, strings.ToLower)
c.allowedHeaders = internal.NewSortedSet(normalized...)
for _, h := range options.AllowedHeaders {
if h == "*" {
c.allowedHeadersAll = true
c.allowedHeaders = nil
c.allowedHeaders = internal.SortedSet{}
break
}
}
Expand Down Expand Up @@ -351,10 +361,12 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
c.logf(" Preflight aborted: method '%s' not allowed", reqMethod)
return
}
reqHeadersRaw := r.Header["Access-Control-Request-Headers"]
reqHeaders, reqHeadersEdited := convertDidCopy(splitHeaderValues(reqHeadersRaw), http.CanonicalHeaderKey)
if !c.areHeadersAllowed(reqHeaders) {
c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders)
// Note: the Fetch standard guarantees that at most one
// Access-Control-Request-Headers header is present in the preflight request;
// see step 5.2 in https://fetch.spec.whatwg.org/#cors-preflight-fetch-0.
reqHeaders, found := first(r.Header, "Access-Control-Request-Headers")
if found && !c.allowedHeadersAll && !c.allowedHeaders.Subsumes(reqHeaders[0]) {
c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders[0])
return
}
if c.allowedOriginsAll {
Expand All @@ -365,14 +377,10 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
// by Access-Control-Request-Method (if supported) can be enough
headers["Access-Control-Allow-Methods"] = r.Header["Access-Control-Request-Method"]
if len(reqHeaders) > 0 {
if found && len(reqHeaders[0]) > 0 {
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Request-Headers can be enough
if reqHeadersEdited || len(reqHeaders) != len(reqHeadersRaw) {
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
} else {
headers["Access-Control-Allow-Headers"] = reqHeadersRaw
}
headers["Access-Control-Allow-Headers"] = reqHeaders
}
if c.allowCredentials {
headers["Access-Control-Allow-Credentials"] = headerTrue
Expand Down Expand Up @@ -492,24 +500,3 @@ func (c *Cors) isMethodAllowed(method string) bool {
}
return false
}

// areHeadersAllowed checks if a given list of headers are allowed to used within
// a cross-domain request.
func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
if c.allowedHeadersAll || len(requestedHeaders) == 0 {
return true
}
for _, header := range requestedHeaders {
found := false
for _, h := range c.allowedHeaders {
if h == header {
found = true
break
}
}
if !found {
return false
}
}
return true
}
78 changes: 10 additions & 68 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,19 +303,19 @@ func TestSpec(t *testing.T) {
"AllowedHeaders",
Options{
AllowedOrigins: []string{"http://foobar.com"},
AllowedHeaders: []string{"X-Header-1", "x-header-2"},
AllowedHeaders: []string{"X-Header-1", "x-header-2", "X-HEADER-3"},
},
"OPTIONS",
map[string]string{
"Origin": "http://foobar.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Header-2, X-HEADER-1",
"Access-Control-Request-Headers": "x-header-1,x-header-2",
},
map[string]string{
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
"Access-Control-Allow-Headers": "x-header-1,x-header-2",
},
true,
},
Expand All @@ -329,13 +329,13 @@ func TestSpec(t *testing.T) {
map[string]string{
"Origin": "http://foobar.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Requested-With",
"Access-Control-Request-Headers": "x-requested-with",
},
map[string]string{
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Headers": "X-Requested-With",
"Access-Control-Allow-Headers": "x-requested-with",
},
true,
},
Expand All @@ -349,13 +349,13 @@ func TestSpec(t *testing.T) {
map[string]string{
"Origin": "http://foobar.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Header-2, X-HEADER-1",
"Access-Control-Request-Headers": "x-header-1,x-header-2",
},
map[string]string{
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
"Access-Control-Allow-Headers": "x-header-1,x-header-2",
},
true,
},
Expand All @@ -369,7 +369,7 @@ func TestSpec(t *testing.T) {
map[string]string{
"Origin": "http://foobar.com",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Header-3, X-Header-1",
"Access-Control-Request-Headers": "x-header-1,x-header-3",
},
map[string]string{
"Vary": "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
Expand Down Expand Up @@ -577,8 +577,8 @@ func TestDefault(t *testing.T) {
if !s.allowedOriginsAll {
t.Error("c.allowedOriginsAll should be true when Default")
}
if s.allowedHeaders == nil {
t.Error("c.allowedHeaders should be nil when Default")
if s.allowedHeaders.Size() == 0 {
t.Error("c.allowedHeaders should be empty when Default")
}
if s.allowedMethods == nil {
t.Error("c.allowedMethods should be nil when Default")
Expand Down Expand Up @@ -712,64 +712,6 @@ func TestOptionsSuccessStatusCodeOverride(t *testing.T) {
})
}

func TestCorsAreHeadersAllowed(t *testing.T) {
cases := []struct {
name string
allowedHeaders []string
requestedHeaders []string
want bool
}{
{
name: "nil allowedHeaders",
allowedHeaders: nil,
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: false,
},
{
name: "star allowedHeaders",
allowedHeaders: []string{"*"},
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: true,
},
{
name: "empty reqHeader",
allowedHeaders: nil,
requestedHeaders: []string{},
want: true,
},
{
name: "match allowedHeaders",
allowedHeaders: []string{"Content-Type", "X-PINGOTHER", "X-APP-KEY"},
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: true,
},
{
name: "not matched allowedHeaders",
allowedHeaders: []string{"X-PINGOTHER"},
requestedHeaders: []string{"X-API-KEY, Content-Type"},
want: false,
},
{
name: "allowedHeaders should be a superset of requestedHeaders",
allowedHeaders: []string{"X-PINGOTHER"},
requestedHeaders: []string{"X-PINGOTHER, Content-Type"},
want: false,
},
}

for _, tt := range cases {
tt := tt

t.Run(tt.name, func(t *testing.T) {
c := New(Options{AllowedHeaders: tt.allowedHeaders})
have := c.areHeadersAllowed(convert(splitHeaderValues(tt.requestedHeaders), http.CanonicalHeaderKey))
if have != tt.want {
t.Errorf("Cors.areHeadersAllowed() have: %t want: %t", have, tt.want)
}
})
}
}

func TestAccessControlExposeHeadersPresence(t *testing.T) {
cases := []struct {
name string
Expand Down
113 changes: 113 additions & 0 deletions internal/sortedset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// adapted from github.com/jub0bs/cors
package internal

import (
"sort"
"strings"
)

// A SortedSet represents a mathematical set of strings sorted in
// lexicographical order.
// Each element has a unique position ranging from 0 (inclusive)
// to the set's cardinality (exclusive).
// The zero value represents an empty set.
type SortedSet struct {
m map[string]int
maxLen int
}

// NewSortedSet returns a SortedSet that contains all of elems,
// but no other elements.
func NewSortedSet(elems ...string) SortedSet {
sort.Strings(elems)
m := make(map[string]int)
var maxLen int
i := 0
for _, s := range elems {
if _, exists := m[s]; exists {
continue
}
m[s] = i
i++
maxLen = max(maxLen, len(s))
}
return SortedSet{
m: m,
maxLen: maxLen,
}
}

// Size returns the cardinality of set.
func (set SortedSet) Size() int {
return len(set.m)
}

// String sorts joins the elements of set (in lexicographical order)
// with a comma and returns the resulting string.
func (set SortedSet) String() string {
elems := make([]string, len(set.m))
for elem, i := range set.m {
elems[i] = elem // safe indexing, by construction of SortedSet
}
return strings.Join(elems, ",")
}

// Subsumes reports whether csv is a sequence of comma-separated names that are
// - all elements of set,
// - sorted in lexicographically order,
// - unique.
func (set SortedSet) Subsumes(csv string) bool {
if csv == "" {
return true
}
posOfLastNameSeen := -1
chunkSize := set.maxLen + 1 // (to accommodate for at least one comma)
for {
// As a defense against maliciously long names in csv,
// we only process at most chunkSize bytes per iteration.
end := min(len(csv), chunkSize)
comma := strings.IndexByte(csv[:end], ',')
var name string
if comma == -1 {
name = csv
} else {
name = csv[:comma]
}
pos, found := set.m[name]
if !found {
return false
}
// The names in csv are expected to be sorted in lexicographical order
// and appear at most once in csv.
// Therefore, the positions (in set) of the names that
// appear in csv should form a strictly increasing sequence.
// If that's not actually the case, bail out.
if pos <= posOfLastNameSeen {
return false
}
posOfLastNameSeen = pos
if comma < 0 { // We've now processed all the names in csv.
break
}
csv = csv[comma+1:]
}
return true
}

// TODO: when updating go directive to 1.21 or later,
// use min builtin instead.
func min(a, b int) int {
if a < b {
return a
}
return b
}

// TODO: when updating go directive to 1.21 or later,
// use max builtin instead.
func max(a, b int) int {
if a > b {
return a
}
return b
}
Loading

0 comments on commit 4c32059

Please sign in to comment.