-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathctxflag.go
108 lines (92 loc) · 3.44 KB
/
ctxflag.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
package web
import (
"context"
"math/rand"
"net/http"
"strconv"
"sync"
"github.com/signalfx/golib/v3/log"
)
// HeaderCtxFlag sets a debug value in the context if HeaderName is not empty, a flag string has
// been set to non empty, and the header HeaderName or query string HeaderName is equal to the set
// flag string.
//
// The flag is stored in the context under the *HeaderCtxFlag pointer itself (see WithFlag and
// HasFlag), so each instance tracks an independent flag. The struct contains a sync.RWMutex and
// therefore must not be copied after first use; always use it through a pointer.
type HeaderCtxFlag struct {
	// HeaderName is used both as the request header name and as the URL query parameter name
	// that are compared against the flag string. Flagging is disabled while it is empty.
	HeaderName string
	// mu guards expectedVal, which may be changed by SetFlagStr while requests are served.
	mu sync.RWMutex
	// expectedVal is the flag string a request must present; empty disables flagging.
	// Written by SetFlagStr, read by FlagStr.
	expectedVal string
}
// CreateMiddleware wraps next so that every request first passes through m.ServeHTTPC,
// which may flag the context before handing the request on to next.
func (m *HeaderCtxFlag) CreateMiddleware(next ContextHandler) ContextHandler {
	wrapped := func(ctx context.Context, rw http.ResponseWriter, r *http.Request) {
		m.ServeHTTPC(ctx, rw, r, next)
	}
	return HandlerFunc(wrapped)
}
// SetFlagStr sets the value that the HeaderName header or query parameter must equal for a
// request's context to be flagged. Passing "" disables flagging. Safe to call concurrently
// with request handling.
func (m *HeaderCtxFlag) SetFlagStr(headerVal string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.expectedVal = headerVal
}
// WithFlag returns a child of ctx carrying this HeaderCtxFlag's flag. The receiver pointer
// is the context key, so the flag is visible only to HasFlag on the same instance.
func (m *HeaderCtxFlag) WithFlag(ctx context.Context) context.Context {
	var marker struct{}
	return context.WithValue(ctx, m, marker)
}
// HasFlag reports whether ctx (or one of its ancestors) was produced by WithFlag on this
// same HeaderCtxFlag instance, i.e. whether any non-nil value is stored under key m.
func (m *HeaderCtxFlag) HasFlag(ctx context.Context) bool {
	if v := ctx.Value(m); v != nil {
		return true
	}
	return false
}
// FlagStr returns the flag string most recently stored by SetFlagStr ("" if none), read
// under the read lock so it is safe to call concurrently with SetFlagStr.
func (m *HeaderCtxFlag) FlagStr() string {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.expectedVal
}
// ServeHTTPC calls next with a context flagged if the headers match. Note it checks both headers and query parameters.
//
// Flagging happens only when both the configured flag string and HeaderName are non-empty.
// The request header is checked first; the URL query is parsed only if the header does not
// match.
func (m *HeaderCtxFlag) ServeHTTPC(ctx context.Context, rw http.ResponseWriter, r *http.Request, next ContextHandler) {
	if want := m.FlagStr(); want != "" && m.HeaderName != "" {
		// Short-circuit keeps the (allocating) query parse off the path when the header matches.
		matched := r.Header.Get(m.HeaderName) == want || r.URL.Query().Get(m.HeaderName) == want
		if matched {
			ctx = m.WithFlag(ctx)
		}
	}
	next.ServeHTTPC(ctx, rw, r)
}
// HeadersInRequest is middleware that adds a fixed set of headers to the response of every
// request it handles (via rw.Header().Add in ServeHTTPC). Despite its name it writes response
// headers, not request headers, and it does not consult any context flag.
type HeadersInRequest struct {
	// Headers maps each header name to the single value added for it. Entries are added with
	// Header.Add, so they are appended to any values already present for that name.
	Headers map[string]string
}
// ServeHTTPC adds every entry of m.Headers to rw's response headers and then calls next with
// the unchanged ctx.
//
// The headers are added unconditionally for every request: this middleware has no reference to
// any context flag and does not inspect ctx. Header.Add is used, so each value is appended to,
// rather than replacing, any values already set for that header name. A nil or empty
// m.Headers adds nothing.
func (m *HeadersInRequest) ServeHTTPC(ctx context.Context, rw http.ResponseWriter, r *http.Request, next ContextHandler) {
	for k, v := range m.Headers {
		rw.Header().Add(k, v)
	}
	next.ServeHTTPC(ctx, rw, r)
}
// CreateMiddleware returns a ContextHandler that adds m's response headers and then delegates
// to next.
func (m *HeadersInRequest) CreateMiddleware(next ContextHandler) ContextHandler {
	var handler HandlerFunc = func(ctx context.Context, rw http.ResponseWriter, r *http.Request) {
		m.ServeHTTPC(ctx, rw, r, next)
	}
	return handler
}
// CtxWithFlag is middleware that tags each request with a random, non-negative int64 id.
// The id is written in decimal to the response header named HeaderName and is appended to the
// context through CtxFlagger under the dimension "header_id". The dimensions
// "http_remote_addr", "http_method" and "http_url" (the request's remote address, method and
// URL) are appended alongside it.
//
// NOTE(review): despite its name, this type does not set any context flag.
type CtxWithFlag struct {
	// CtxFlagger appends the request dimensions to the context.
	// NOTE(review): presumably must be non-nil — confirm how log.CtxDimensions handles nil.
	CtxFlagger *log.CtxDimensions
	// HeaderName is the response header that receives the generated id.
	HeaderName string
}
// ServeHTTPC generates a random id for the request and writes it, in decimal, to the response
// header m.HeaderName. It then appends the id and the request's remote address, method and URL
// to ctx via m.CtxFlagger and calls next with the enriched context.
func (m *CtxWithFlag) ServeHTTPC(ctx context.Context, rw http.ResponseWriter, r *http.Request, next ContextHandler) {
	id := rand.Int63()
	idText := strconv.FormatInt(id, 10)
	rw.Header().Add(m.HeaderName, idText)

	enriched := m.CtxFlagger.Append(ctx,
		"header_id", id,
		"http_remote_addr", r.RemoteAddr,
		"http_method", r.Method,
		"http_url", r.URL.String(),
	)
	next.ServeHTTPC(enriched, rw, r)
}
// CreateMiddleware returns a ContextHandler that tags the request (see ServeHTTPC) before
// passing it to next.
func (m *CtxWithFlag) CreateMiddleware(next ContextHandler) ContextHandler {
	inner := func(ctx context.Context, rw http.ResponseWriter, r *http.Request) {
		m.ServeHTTPC(ctx, rw, r, next)
	}
	return HandlerFunc(inner)
}