-
-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathjwks.go
189 lines (166 loc) · 4.47 KB
/
jwks.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
package keyfunc
import (
"context"
"encoding/json"
"errors"
"net/http"
"sync"
"time"
)
var (
	// ErrKIDNotFound indicates that the given key ID was not found in the JWKS.
	// getKey returns it whenever a lookup (including one after a refresh) finds no key.
	ErrKIDNotFound = errors.New("the given key ID was not found in the JWKS")

	// ErrMissingAssets indicates that assets required to create a public key are missing.
	// NOTE(review): not used in this section of the file; presumably returned by the
	// key-parsing methods (ECDSA, EdDSA, Oct, RSA) — confirm.
	ErrMissingAssets = errors.New("required assets are missing to create a public key")
)
// ErrorHandler is a callback that is handed an error to deal with; it returns nothing.
// The JWKS type keeps one in its refreshErrorHandler field.
type ErrorHandler func(error)
// jsonWebKey represents a raw key inside a JWKS.
//
// Each field holds the undecoded string value of the JWK member named in its JSON tag.
// NewJSON picks a parser method (ECDSA, EdDSA, Oct or RSA, defined elsewhere) based on Type
// and stores the result under ID.
type jsonWebKey struct {
	Curve    string `json:"crv"` // "crv" member; presumably read by ECDSA/EdDSA — confirm.
	Exponent string `json:"e"`   // "e" member; presumably read by RSA — confirm.
	K        string `json:"k"`   // "k" member; presumably read by Oct — confirm.
	ID       string `json:"kid"` // Key ID; NewJSON uses it as the key-map index.
	Modulus  string `json:"n"`   // "n" member; presumably read by RSA — confirm.
	Type     string `json:"kty"` // Key type; NewJSON switches on it to choose a parser.
	X        string `json:"x"`   // "x" member; presumably read by ECDSA/EdDSA — confirm.
	Y        string `json:"y"`   // "y" member; presumably read by ECDSA — confirm.
}
// JWKS represents a JSON Web Key Set (JWK Set).
//
// Parsed keys are stored in keys, indexed by key ID. The accessor methods KIDs, Len, RawJWKS,
// ReadOnlyKeys and getKey lock mux before touching keys or raw. Several fields are not used by
// the code in this section of the file; their comments are marked as assumptions to confirm.
type JWKS struct {
	// cancel, when non-nil, is called by EndBackground.
	cancel context.CancelFunc

	// client is presumably the HTTP client used to fetch jwksURL — TODO confirm.
	client *http.Client

	// ctx is the parent of the per-lookup context that getKey creates, and getKey selects on
	// ctx.Done() while requesting a refresh. NOTE(review): it must be non-nil whenever
	// refreshUnknownKID is set, because context.WithCancel panics on a nil parent.
	ctx context.Context

	// raw holds the raw JWKS bytes that RawJWKS copies out; presumably the most recently
	// fetched document — TODO confirm.
	raw []byte

	// givenKeys presumably holds keys supplied directly by the caller, indexed by kid — TODO confirm.
	givenKeys map[string]GivenKey

	// givenKIDOverride presumably decides whether givenKeys win over fetched keys that share a
	// kid — TODO confirm.
	givenKIDOverride bool

	// jwksURL is presumably the remote location of the JWKS — TODO confirm.
	jwksURL string

	// keys maps key ID (kid) to a parsed cryptographic key. NewJSON fills it; KIDs, Len,
	// ReadOnlyKeys and getKey read it while holding mux.
	keys map[string]interface{}

	// mux guards keys and raw for the accessor methods.
	mux sync.RWMutex

	// refreshErrorHandler presumably receives errors from background refreshes — TODO confirm.
	refreshErrorHandler ErrorHandler

	// refreshInterval is presumably the period of background refreshes — TODO confirm.
	refreshInterval time.Duration

	// refreshRateLimit is presumably the minimum spacing between refreshes — TODO confirm.
	refreshRateLimit time.Duration

	// refreshRequests carries refresh requests from getKey. getKey sends a CancelFunc and blocks
	// until that function's context is done, so the receiver is presumably expected to call it
	// once the refresh has finished — TODO confirm against the refresh goroutine.
	refreshRequests chan context.CancelFunc

	// refreshTimeout presumably bounds the duration of a single refresh — TODO confirm.
	refreshTimeout time.Duration

	// refreshUnknownKID makes getKey request a refresh when asked for a kid it does not hold.
	refreshUnknownKID bool

	// requestFactory presumably builds the HTTP request for jwksURL — TODO confirm.
	requestFactory func(ctx context.Context, url string) (*http.Request, error)

	// responseExtractor presumably turns the HTTP response into raw JWKS JSON — TODO confirm.
	responseExtractor func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
}
// rawJWKS is the top-level JSON shape that NewJSON decodes: an object whose "keys" member is an
// array of JWKs. Because the elements are pointers, a JSON null in that array decodes to a nil entry.
type rawJWKS struct {
	Keys []*jsonWebKey `json:"keys"`
}
// NewJSON creates a new JWKS from a raw JSON message.
//
// jwksBytes must decode into an object with a "keys" array. Each entry is parsed according to its
// "kty" member (ktyEC, ktyOKP, ktyOct or ktyRSA). Entries are skipped silently if they are JSON
// null, have an unknown key type, or fail to parse, so the result may hold fewer keys than the
// input. If several entries share a "kid", the last successfully parsed one wins.
//
// Only the key map of the returned JWKS is populated; all refresh-related fields are left at their
// zero values. The only error returned is the one produced while decoding jwksBytes.
func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) {
	var rawKS rawJWKS
	if err = json.Unmarshal(jwksBytes, &rawKS); err != nil {
		return nil, err
	}

	jwks = &JWKS{
		keys: make(map[string]interface{}, len(rawKS.Keys)),
	}
	for _, key := range rawKS.Keys {
		// A JSON null inside "keys" decodes to a nil pointer; dereferencing it would panic on
		// untrusted input.
		if key == nil {
			continue
		}

		var (
			keyInter interface{}
			parseErr error
		)
		switch key.Type {
		case ktyEC:
			keyInter, parseErr = key.ECDSA()
		case ktyOKP:
			keyInter, parseErr = key.EdDSA()
		case ktyOct:
			keyInter, parseErr = key.Oct()
		case ktyRSA:
			keyInter, parseErr = key.RSA()
		default:
			// Ignore unknown key types silently.
			continue
		}
		if parseErr != nil {
			// Best effort: one malformed key must not invalidate the rest of the set.
			continue
		}
		jwks.keys[key.ID] = keyInter
	}
	return jwks, nil
}
// EndBackground ends the background refreshing of the JWKS by invoking its stored cancel function.
// It does nothing when no cancel function was configured. Calling it more than once is harmless,
// since a context.CancelFunc does nothing after its first call.
func (j *JWKS) EndBackground() {
	if stop := j.cancel; stop != nil {
		stop()
	}
}
// KIDs returns the key ID (`kid`) of every key currently held by the JWKS.
// The order is unspecified because it follows map iteration. The result is never nil.
func (j *JWKS) KIDs() (kids []string) {
	j.mux.RLock()
	defer j.mux.RUnlock()

	kids = make([]string, 0, len(j.keys))
	for id := range j.keys {
		kids = append(kids, id)
	}
	return kids
}
// Len reports how many keys the JWKS currently holds.
func (j *JWKS) Len() int {
	j.mux.RLock()
	count := len(j.keys)
	j.mux.RUnlock()
	return count
}
// RawJWKS returns a private copy of the raw JWKS bytes held by the JWKS (presumably the document
// most recently received from the JWKS URL). Callers may modify the result freely.
// The returned slice is never nil, even when no raw bytes are stored.
func (j *JWKS) RawJWKS() []byte {
	j.mux.RLock()
	defer j.mux.RUnlock()
	return append(make([]byte, 0, len(j.raw)), j.raw...)
}
// ReadOnlyKeys returns a snapshot copy of the mapping of key IDs (`kid`) to cryptographic keys.
//
// Adding or removing entries in the returned map does not affect the JWKS. The key values are
// copied as interface values, so pointer-typed keys are shared with the JWKS and must not be
// mutated. The result is never nil.
func (j *JWKS) ReadOnlyKeys() map[string]interface{} {
	// A shared read lock suffices: this only reads j.keys, and using RLock lets concurrent
	// readers (KIDs, Len, getKey, ...) proceed instead of being serialized behind a write lock.
	j.mux.RLock()
	defer j.mux.RUnlock()

	keys := make(map[string]interface{}, len(j.keys))
	for kid, cryptoKey := range j.keys {
		keys[kid] = cryptoKey
	}
	return keys
}
// getKey returns the cryptographic key stored under kid.
//
// If kid is unknown and refreshUnknownKID is set, getKey requests a refresh by sending a
// context.CancelFunc on refreshRequests. It then blocks until that function's context is done,
// which happens when the receiver calls it (presumably once the refresh has finished) or when
// j.ctx is cancelled. After that it looks kid up again. If refreshRequests cannot take the request
// without blocking, or if j.ctx is done, getKey gives up immediately.
//
// Whenever no key is found, the error is ErrKIDNotFound; a nil key is never returned with a nil
// error.
func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) {
	j.mux.RLock()
	jsonKey, ok := j.keys[kid]
	j.mux.RUnlock()
	if ok {
		return jsonKey, nil
	}
	if !j.refreshUnknownKID {
		return nil, ErrKIDNotFound
	}

	ctx, cancel := context.WithCancel(j.ctx)
	// Release the child context on every return path, including the early ones below. The
	// receiver calling cancel as well is harmless: a CancelFunc does nothing after its first call.
	defer cancel()

	select {
	case <-j.ctx.Done():
		// The JWKS is shutting down; report a failed lookup instead of returning (nil, nil).
		return nil, ErrKIDNotFound
	case j.refreshRequests <- cancel:
	default:
		// refreshRequests cannot accept another request right now; fail fast rather than block.
		return nil, ErrKIDNotFound
	}

	// Wait until the receiver calls cancel or j.ctx is cancelled.
	<-ctx.Done()

	j.mux.RLock()
	jsonKey, ok = j.keys[kid]
	j.mux.RUnlock()
	if !ok {
		return nil, ErrKIDNotFound
	}
	return jsonKey, nil
}