-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcsrf.go
136 lines (117 loc) · 3.38 KB
/
csrf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package common
import (
	"crypto/subtle"
	"fmt"
	"net/http"
	"net/url"

	"github.com/gin-gonic/gin"
)
// Messages returned in the ErrorResponse body when the CSRF middleware
// rejects a request.
const (
	badTooken        = "CSRF token missing or incorrect."
	tookenMissing    = "CSRF cookie not set."
	noReferer        = "Referer checking failed - no Referer."
	malformedReferer = "Referer checking failed - Referer is malformed."
	insecureReferer  = "Referer checking failed - Referer is insecure while host is secure."
)

// Names of the cookie and request headers consulted by the CSRF middleware.
const (
	// CsrfTokenKey is the name of the cookie that stores the CSRF token.
	CsrfTokenKey = "csrftoken"
	// Xcsrf is the request header in which clients echo the CSRF token.
	Xcsrf = "X-CSRF-Token"
	// Authorization is the request header carrying API credentials; a
	// non-empty value exempts the request from CSRF checks (see isAPIUser).
	Authorization = "Authorization"
)
// ignoreMethods lists the HTTP methods treated as safe: requests using them
// are never subject to the CSRF token or Referer checks.
var ignoreMethods = []string{
	http.MethodGet,
	http.MethodHead,
	http.MethodOptions,
	http.MethodTrace,
}
// ErrorResponse provides HTTP error response.
// It is serialized as the JSON body of rejected requests and also satisfies
// the error interface.
type ErrorResponse struct {
	Code    uint   `json:"code" example:"400"`
	Message string `json:"message" example:"Bad request"`
}

// Error renders the response as "<code>: <message>".
func (e ErrorResponse) Error() string {
	code := fmt.Sprint(e.Code)
	return code + ": " + e.Message
}
// getHeader returns the CSRF token the client sent in the Xcsrf request
// header, or "" when the header is absent.
func getHeader(c *gin.Context) string {
	token := c.GetHeader(Xcsrf)
	return token
}
// isAPIUser reports whether the request carries a non-empty Authorization
// header, which the CSRF middleware treats as a machine (API) client.
func isAPIUser(c *gin.Context) bool {
	credentials := c.Request.Header.Get(Authorization)
	return len(credentials) > 0
}
// getCookie returns the raw value of the CsrfTokenKey cookie sent with the
// request, or "" when that cookie is missing.
func getCookie(c *gin.Context) string {
	if cookie, err := c.Request.Cookie(CsrfTokenKey); err == nil {
		return cookie.Value
	}
	return ""
}
// CSRF is middleware for handling CSRF protection in gin.
//
// It implements a double-submit cookie scheme: the token stored in the
// CsrfTokenKey cookie must be echoed in the Xcsrf request header on every
// request that is not exempt. Processing order:
//
//   - Requests with a non-empty Authorization header skip all checks.
//   - Requests whose method is in ignoreMethods, or whose URL path is in
//     excludePaths, are passed through. If the client has no CSRF cookie yet,
//     a fresh token from RandomToken is issued as a cookie. "Vary: Cookie" is
//     added to the response.
//   - For all other requests, when IsHTTPS reports a secure request, the
//     Referer must be present, parse as a URL, use the https scheme and name
//     the same host as the request.
//   - Finally the cookie token must be present and equal to the header token.
//
// Rejections abort the chain with a 403 JSON ErrorResponse. The exception is
// a failure to generate a new token, which aborts with 500.
//
// excludePaths are matched against c.Request.URL.Path via ContainsString
// (presumably exact string membership — confirm against its definition).
func CSRF(excludePaths []string) gin.HandlerFunc {
	return func(c *gin.Context) {
		// Requests authenticating via the Authorization header are treated as
		// machine users and bypass CSRF protection.
		// NOTE(review): browsers attach HTTP Basic credentials automatically,
		// so if Basic auth is accepted anywhere this bypass re-opens CSRF —
		// confirm which Authorization schemes reach this middleware.
		if isAPIUser(c) {
			c.Next()
			return
		}

		csrfToken := getCookie(c)

		// Assume that anything not defined as 'safe' by RFC 7231 needs protection.
		if ContainsString(ignoreMethods, c.Request.Method) || ContainsString(excludePaths, c.Request.URL.Path) {
			if csrfToken == "" && !csrfIssueCookie(c) {
				return
			}
			// The response depends on the Cookie header, so tell caches to key
			// on it. Add (not Set) keeps Vary values from other middleware.
			c.Writer.Header().Add("Vary", "Cookie")
			c.Next()
			return
		}

		if IsHTTPS(c.Request) {
			if msg := csrfRefererError(c.Request); msg != "" {
				csrfAbort(c, http.StatusForbidden, msg)
				return
			}
		}

		if csrfToken == "" {
			csrfAbort(c, http.StatusForbidden, tookenMissing)
			return
		}
		// Constant-time comparison so response timing does not reveal how much
		// of a guessed token matched the real one.
		if subtle.ConstantTimeCompare([]byte(getHeader(c)), []byte(csrfToken)) != 1 {
			csrfAbort(c, http.StatusForbidden, badTooken)
			return
		}

		c.Next()
	}
}

// csrfAbort stops the handler chain and responds with status and an
// ErrorResponse JSON body carrying msg.
func csrfAbort(c *gin.Context, status int, msg string) {
	c.AbortWithStatusJSON(status, ErrorResponse{Code: uint(status), Message: msg})
}

// csrfIssueCookie sets a freshly generated CSRF token cookie on the response.
// If token generation fails it aborts the request with a 500 and returns false.
func csrfIssueCookie(c *gin.Context) bool {
	val, err := RandomToken()
	if err != nil {
		csrfAbort(c, http.StatusInternalServerError, "CSRF token could not be generated.")
		return false
	}
	// Domain is deliberately left unset, which makes this a host-only cookie.
	// c.Request.URL.Host is empty for ordinary server requests and may include
	// a port otherwise, so it is not a usable cookie domain.
	// HttpOnly stays false so client-side script can read the token and echo
	// it in the Xcsrf header. MaxAge is in seconds (12 hours).
	http.SetCookie(c.Writer, &http.Cookie{
		Name:     CsrfTokenKey,
		Value:    val,
		Path:     "/",
		HttpOnly: false,
		Secure:   IsHTTPS(c.Request),
		MaxAge:   12 * 60 * 60,
		SameSite: http.SameSiteLaxMode,
	})
	return true
}

// csrfRefererError validates the Referer of an HTTPS request. It returns ""
// when the Referer is present, parses, uses https and names the request's own
// host. Otherwise it returns the rejection message to send.
func csrfRefererError(r *http.Request) string {
	referer := r.Header.Get("Referer")
	if referer == "" {
		return noReferer
	}
	parsedURL, err := url.Parse(referer)
	if err != nil {
		return malformedReferer
	}
	if parsedURL.Scheme != "https" {
		return insecureReferer
	}
	if parsedURL.Host != r.Host {
		return fmt.Sprintf("Referer checking failed - %s does not match any trusted origins.", parsedURL.Host)
	}
	return ""
}