-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathrequests.go
91 lines (76 loc) · 1.97 KB
/
requests.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
package logkeeper
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"time"
)
// ErrReadSizeLimitExceeded is returned by LimitedReader.Read once the
// underlying reader has produced more bytes than the configured limit.
var ErrReadSizeLimitExceeded = errors.New("read size limit exceeded")

// LimitedReader reads from R but limits the amount of data returned to just
// N bytes. Each call to Read updates N to reflect the amount remaining.
//
// It behaves like io.LimitedReader, except that it returns
// ErrReadSizeLimitExceeded (rather than io.EOF) when R holds more than N
// bytes, so an oversized input can be distinguished from a normal EOF. An
// input of exactly N bytes still ends with whatever R returns at its end
// (normally io.EOF).
//
// To detect the overflow, Read may consume one byte beyond N from R; that
// byte is never returned to the caller. Once the limit has been exceeded N is
// set to -1 and every subsequent Read returns ErrReadSizeLimitExceeded. A
// negative N supplied by the caller is likewise treated as an exceeded limit.
type LimitedReader struct {
	R io.Reader // underlying reader
	N int       // max bytes remaining; negative once the limit is exceeded
}

// Read reads up to len(p) bytes from R without returning more than N bytes in
// total. When R yields data past the limit, Read returns the bytes that still
// fit together with ErrReadSizeLimitExceeded. Otherwise the byte count and
// error from R (including io.EOF) are passed through unchanged.
func (l *LimitedReader) Read(p []byte) (int, error) {
	if l.N < 0 {
		return 0, ErrReadSizeLimitExceeded
	}
	// Ask for one byte more than the remaining budget. If R produces it, the
	// input is genuinely too large; if R reports its end instead, an input of
	// exactly N bytes is accepted. Written as len(p)-1 > N (instead of
	// len(p) > N+1) so that N+1 cannot overflow.
	if len(p)-1 > l.N {
		p = p[:l.N+1]
	}
	n, err := l.R.Read(p)
	if n <= l.N {
		l.N -= n
		return n, err
	}
	// R delivered the extra probe byte: hand back only what fits in the
	// budget and make the overflow sticky for later calls.
	n = l.N
	l.N = -1
	return n, ErrReadSizeLimitExceeded
}
// readJSON decodes a single JSON value from body into out, consuming at most
// maxSize bytes of decoded input via LimitedReader. Only the first JSON value
// is decoded; anything after it is not inspected.
//
// It returns nil on success. If decoding stops because the size limit was hit
// (ErrReadSizeLimitExceeded), the returned error carries MaxSize and HTTP 413.
// Any other decoding failure is reported with HTTP 400.
func readJSON(body io.Reader, maxSize int, out interface{}) *apiError {
	limited := &LimitedReader{R: body, N: maxSize}
	decodeErr := json.NewDecoder(limited).Decode(out)

	switch {
	case decodeErr == nil:
		return nil
	case errors.Is(decodeErr, ErrReadSizeLimitExceeded):
		return &apiError{
			Err:     decodeErr.Error(),
			MaxSize: maxSize,
			code:    http.StatusRequestEntityTooLarge,
		}
	default:
		return &apiError{
			Err:  decodeErr.Error(),
			code: http.StatusBadRequest,
		}
	}
}
// ctxKey is an unexported type for the keys this package stores in a
// context.Context; being package-private, its values cannot collide with
// context keys defined by other packages.
type ctxKey int

// Context keys used by the request helpers in this file.
const (
	// requestIDKey holds the int request ID written by setCtxRequestId.
	requestIDKey ctxKey = 0
	// startAtKey holds the time.Time written by setStartAtTime.
	startAtKey ctxKey = 1
)
// setCtxRequestId returns a shallow copy of r whose context carries reqID
// under requestIDKey; read it back with getCtxRequestId.
func setCtxRequestId(reqID int, r *http.Request) *http.Request {
	withID := context.WithValue(r.Context(), requestIDKey, reqID)
	return r.WithContext(withID)
}
// getCtxRequestId returns the request ID stored in ctx by setCtxRequestId.
// It returns 0 when no ID is present or the stored value is not an int.
func getCtxRequestId(ctx context.Context) int {
	// A failed type assertion (including on a nil value) leaves id at 0.
	id, _ := ctx.Value(requestIDKey).(int)
	return id
}
// setStartAtTime returns a shallow copy of r whose context records startAt
// under startAtKey; read it back with getRequestStartAt.
func setStartAtTime(r *http.Request, startAt time.Time) *http.Request {
	withStart := context.WithValue(r.Context(), startAtKey, startAt)
	return r.WithContext(withStart)
}
// getRequestStartAt returns the time stored in ctx by setStartAtTime. It
// returns the zero time.Time when nothing is stored or the value is not a
// time.Time.
func getRequestStartAt(ctx context.Context) time.Time {
	// A failed type assertion (including on a nil value) yields time.Time{}.
	startedAt, _ := ctx.Value(startAtKey).(time.Time)
	return startedAt
}