Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

token introspection #649

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c0eca01
jwt: added token introspection with caching response for configurable…
Dec 7, 2022
5eff963
renamed interval -> ttl
Dec 21, 2022
b28aac1
added explicit type for TTL: duration
Dec 21, 2022
9549852
non-positive TTL prohibits caching of introspection response
Dec 21, 2022
11191d9
ttl variable only where needed
Jan 3, 2023
48642eb
block access to introspection response cache during introspection req…
Jan 3, 2023
7a3a32f
client authentication methods using JWT: create jti differently
Jan 4, 2023
d85aa75
refactored client authenticator to separate struct
Jan 4, 2023
4236217
use client authenticator in token introspector
Jan 4, 2023
1c7257a
check that token actually arrives at introspection endpoint in form b…
Jan 5, 2023
8ab93b0
try atomic.AddUint32() to fix data race in test
Jan 11, 2023
7d8a716
beta_ prefix token introspection
Feb 24, 2023
835716a
changelog entry
Mar 23, 2023
cffef9f
lock per token
Mar 24, 2023
30af29e
more code documentation; exp() is currently only used internally, so …
Mar 24, 2023
5f1b905
move IntrospectionResponse and related functions
Mar 24, 2023
a482d61
move ClientAuthenticator to separate file
Mar 24, 2023
f227547
enhance function documentation
Mar 24, 2023
de27f2b
passed request can also be introspection request
Mar 24, 2023
396a805
refactored Authenticator
johakoch Jan 31, 2024
c8207c7
refactored Introspector
johakoch Jan 31, 2024
edc5df7
refactor: extracted the alg check functions
johakoch Jan 31, 2024
a856200
refactor method arg validation
malud Feb 1, 2024
964dfac
refactor introspection req
malud Feb 1, 2024
d2c788c
refactor introspection req; cancel on happy path
malud Feb 1, 2024
c21e187
method names
malud Feb 1, 2024
b6783d3
refactored jwt.Validate()
johakoch Feb 3, 2024
2bf8b4d
added jwt token inactive error
johakoch Feb 29, 2024
6c8ce24
Merge branch 'master' into token-introspection
johakoch May 8, 2024
e5ee6af
Merge branch 'master' into token-introspection
malud Jul 8, 2024
75308b5
Merge branch 'master' into token-introspection
johakoch Sep 17, 2024
935f9d4
Merge branch 'master' into token-introspection
johakoch Dec 10, 2024
a1e91fa
Merge branch 'master' into token-introspection
johakoch Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

Unreleased changes are available as `coupergateway/couper:edge` container.

* **Added**
* [`beta_introspection` block](https://docs.couper.io/configuration/block/introspection) in [`jwt` block](https://docs.couper.io/configuration/block/jwt) to facilitate token introspection in order to detect revocated tokens ([#649](https://github.com/avenga/couper/pull/649))

---

## [1.13.0](https://github.com/coupergateway/couper/releases/tag/v1.13.0)
Expand Down
182 changes: 182 additions & 0 deletions accesscontrol/introspection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package accesscontrol

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sync"
"time"

"github.com/hashicorp/hcl/v2"

"github.com/coupergateway/couper/cache"
"github.com/coupergateway/couper/config"
"github.com/coupergateway/couper/config/request"
"github.com/coupergateway/couper/eval"
"github.com/coupergateway/couper/eval/buffer"
"github.com/coupergateway/couper/oauth2"
)

// IntrospectionResponse represents the response body to a token introspection request.
type IntrospectionResponse map[string]interface{}

func NewIntrospectionResponse(res *http.Response) (IntrospectionResponse, error) {
var introspectionData IntrospectionResponse

if res.StatusCode != http.StatusOK {
return nil, fmt.Errorf("introspection response status code %d", res.StatusCode)
}

if !eval.IsJSONMediaType(res.Header.Get("Content-Type")) {
return nil, fmt.Errorf("introspection response is not JSON")
}

err := json.NewDecoder(res.Body).Decode(&introspectionData)
return introspectionData, err
}

// Active returns whether the token is active.
func (ir IntrospectionResponse) Active() bool {
active, _ := ir["active"].(bool)
return active
}

func (ir IntrospectionResponse) exp() int64 {
exp, _ := ir["exp"].(int64)
return exp
}

type lock struct {
mu sync.Mutex
}

// Introspector represents a token introspector.
type Introspector struct {
authenticator oauth2.ClientAuthenticator
conf *config.Introspection
locks sync.Map
memStore *cache.MemoryStore
transport http.RoundTripper
}

// NewIntrospector creates a new token introspector.
func NewIntrospector(evalCtx *hcl.EvalContext, conf *config.Introspection, transport http.RoundTripper, memStore *cache.MemoryStore) (*Introspector, error) {
authenticator, err := oauth2.NewClientAuthenticator(evalCtx, conf.EndpointAuthMethod, "endpoint_auth_method", conf.ClientID, conf.ClientSecret, "", conf.JWTSigningProfile)
if err != nil {
return nil, err
}
return &Introspector{
authenticator: authenticator,
conf: conf,
memStore: memStore,
transport: transport,
}, nil
}

// Introspect retrieves introspection data for the given token using either cached or fresh information.
func (i *Introspector) Introspect(ctx context.Context, token string, exp, nbf int64) (IntrospectionResponse, error) {
var (
introspectionData IntrospectionResponse
key string
)

if i.conf.TTLSeconds > 0 {
// lock per token
entry, _ := i.locks.LoadOrStore(token, &lock{})
l := entry.(*lock)
l.mu.Lock()
defer func() {
i.locks.Delete(token)
l.mu.Unlock()
}()

key = "ir:" + token
cachedIntrospection, _ := i.memStore.Get(key).(IntrospectionResponse)
if cachedIntrospection != nil {
return cachedIntrospection, nil
}
}

introspectionData, err := i.doRequestIntrospection(ctx, token)
if err != nil {
return nil, err
}

if i.conf.TTLSeconds <= 0 {
// do not cache
return introspectionData, nil
}

if exp == 0 {
if isdExp := introspectionData.exp(); isdExp > 0 {
exp = isdExp
}
}

ttl := i.getTTL(exp, nbf, introspectionData.Active())
// cache introspection data
i.memStore.Set(key, introspectionData, ttl)

return introspectionData, nil
}

func (i *Introspector) newIntrospectionRequest(ctx context.Context, token string) (*http.Request, context.CancelFunc, error) {
req, _ := http.NewRequest("POST", i.conf.Endpoint, nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

formParams := &url.Values{}
formParams.Add("token", token)

err := i.authenticator.Authenticate(formParams, req)
if err != nil {
return nil, nil, err
}

eval.SetBody(req, []byte(formParams.Encode()))

outCtx, cancel := context.WithCancel(context.WithValue(ctx, request.RoundTripName, "introspection"))
outCtx = context.WithValue(outCtx, request.BufferOptions, buffer.Response)

return req.WithContext(outCtx), cancel, nil
}

func (i *Introspector) doRequestIntrospection(ctx context.Context, token string) (IntrospectionResponse, error) {
req, cancel, err := i.newIntrospectionRequest(ctx, token)
if err != nil {
return nil, err
}
defer cancel()

response, err := i.transport.RoundTrip(req)
if err != nil {
return nil, fmt.Errorf("introspection response: %s", err)
}
defer response.Body.Close()

introspectionData, err := NewIntrospectionResponse(response)
if err != nil {
return nil, fmt.Errorf("introspection response: %s", err)
}

return introspectionData, nil
}

func (i *Introspector) getTTL(exp, nbf int64, active bool) int64 {
ttl := i.conf.TTLSeconds

if exp > 0 {
now := time.Now().Unix()
maxTTL := exp - now
if !active && (nbf <= 0 || now > nbf) {
// nbf is unknown (token has never been inactive before being active)
// or nbf lies in the past (token has become active after having been inactive):
// token will not become active again, so we can store the response until token expires anyway
ttl = maxTTL
} else if ttl > maxTTL {
ttl = maxTTL
}
}
return ttl
}
65 changes: 50 additions & 15 deletions accesscontrol/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,14 @@ type JWT struct {
rolesMap map[string][]string
permissionsClaim string
permissionsMap map[string][]string
introspector *Introspector
jwks *jwk.JWKS
memStore *cache.MemoryStore
}

// NewJWT parses the key and creates Validation obj which can be referenced in related handlers.
func NewJWT(jwtConf *config.JWT, key []byte, memStore *cache.MemoryStore) (*JWT, error) {
jwtAC, err := newJWT(jwtConf, memStore)
func NewJWT(jwtConf *config.JWT, introspector *Introspector, key []byte, memStore *cache.MemoryStore) (*JWT, error) {
jwtAC, err := newJWT(jwtConf, introspector, memStore)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -115,12 +116,12 @@ func parsePublicPEMKey(key []byte) (pub interface{}, err error) {
return pubKey, nil
}

func NewJWTFromJWKS(jwtConf *config.JWT, jwks *jwk.JWKS, memStore *cache.MemoryStore) (*JWT, error) {
func NewJWTFromJWKS(jwtConf *config.JWT, introspector *Introspector, jwks *jwk.JWKS, memStore *cache.MemoryStore) (*JWT, error) {
if jwks == nil {
return nil, fmt.Errorf("invalid JWKS")
}

jwtAC, err := newJWT(jwtConf, memStore)
jwtAC, err := newJWT(jwtConf, introspector, memStore)
if err != nil {
return nil, err
}
Expand All @@ -143,6 +144,19 @@ type parserConfig struct {
issuer string
}

func newParserConfig(algorithms []string, expectedClaims map[string]interface{}) parserConfig {
parserConfig := parserConfig{
algorithms: algorithms,
}
if aud, ok := expectedClaims["aud"].(string); ok {
parserConfig.audience = aud
}
if iss, ok := expectedClaims["iss"].(string); ok {
parserConfig.issuer = iss
}
return parserConfig
}

func (p parserConfig) key() string {
return fmt.Sprintf("pc:%s:%s:%s", p.algorithms, p.audience, p.issuer)
}
Expand All @@ -162,7 +176,7 @@ func (p parserConfig) newParser() *jwt.Parser {
return jwt.NewParser(options...)
}

func newJWT(jwtConf *config.JWT, memStore *cache.MemoryStore) (*JWT, error) {
func newJWT(jwtConf *config.JWT, introspector *Introspector, memStore *cache.MemoryStore) (*JWT, error) {
source, err := NewTokenSource(jwtConf.Bearer, jwtConf.Dpop, jwtConf.Cookie, jwtConf.Header, jwtConf.TokenValue)
if err != nil {
return nil, err
Expand All @@ -176,6 +190,7 @@ func newJWT(jwtConf *config.JWT, memStore *cache.MemoryStore) (*JWT, error) {
claims: jwtConf.Claims,
claimsRequired: jwtConf.ClaimsRequired,
disablePrivateCaching: jwtConf.DisablePrivateCaching,
introspector: introspector,
memStore: memStore,
name: jwtConf.Name,
rolesClaim: jwtConf.RolesClaim,
Expand Down Expand Up @@ -237,16 +252,8 @@ func (j *JWT) Validate(req *http.Request) error {
}

tokenClaims := jwt.MapClaims{}
_, err = parser.ParseWithClaims(tokenValue, tokenClaims, j.getValidationKey)
if err != nil {
if goerrors.Is(err, jwt.ErrTokenExpired) {
return errors.JwtTokenExpired.With(err)
}
if goerrors.Is(err, jwt.ErrTokenInvalidClaims) {
// TODO throw different error?
return errors.JwtTokenInvalid.With(err)
}
return errors.JwtTokenInvalid.With(err)
if err = j.parse(parser, tokenValue, tokenClaims); err != nil {
return err
}

if err = j.source.ValidateTokenClaims(tokenValue, tokenClaims, req); err != nil {
Expand All @@ -260,6 +267,19 @@ func (j *JWT) Validate(req *http.Request) error {
}

ctx := req.Context()
if j.introspector != nil {
exp, _ := tokenClaims["exp"].(float64)
nbf, _ := tokenClaims["nbf"].(float64)
introspectionResponse, err := j.introspector.Introspect(ctx, tokenValue, int64(exp), int64(nbf))
if err != nil {
return err
}

if !introspectionResponse.Active() {
return errors.JwtTokenInactive.Message("token inactive")
}
}

acMap, ok := ctx.Value(request.AccessControls).(map[string]interface{})
if !ok {
acMap = make(map[string]interface{})
Expand All @@ -282,6 +302,21 @@ func (j *JWT) Validate(req *http.Request) error {
return nil
}

func (j *JWT) parse(parser *jwt.Parser, tokenValue string, tokenClaims jwt.MapClaims) error {
_, err := parser.ParseWithClaims(tokenValue, tokenClaims, j.getValidationKey)
if err != nil {
if goerrors.Is(err, jwt.ErrTokenExpired) {
return errors.JwtTokenExpired.With(err)
}
if goerrors.Is(err, jwt.ErrTokenInvalidClaims) {
// TODO throw different error?
return errors.JwtTokenInvalid.With(err)
}
return errors.JwtTokenInvalid.With(err)
}
return nil
}

func (j *JWT) getValidationKey(token *jwt.Token) (interface{}, error) {
if j.jwks != nil {
return j.jwks.GetSigKeyForToken(token)
Expand Down
12 changes: 6 additions & 6 deletions accesscontrol/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ QolLGgj3tz4NbDEitq+zKMr0uTHvP1Vyu1mXAflcpYcJA4ZmuB3Oj39e0U0gnmr/
ClaimsRequired: tt.fields.claimsRequired,
Name: "test_ac",
SignatureAlgorithm: tt.fields.algorithm,
}, key, memStore)
}, nil, key, memStore)
if jerr != nil {
if tt.wantErr != jerr.Error() {
subT.Errorf("error: %v, want: %v", jerr.Error(), tt.wantErr)
Expand Down Expand Up @@ -333,7 +333,7 @@ func Test_JWT_Validate(t *testing.T) {
Cookie: tt.fields.cookie,
Header: tt.fields.header,
TokenValue: tt.fields.tokenValue,
}, tt.fields.pubKey, memStore)
}, nil, tt.fields.pubKey, memStore)
if err != nil {
subT.Error(err)
return
Expand Down Expand Up @@ -477,7 +477,7 @@ func Test_JWT_Validate_claims(t *testing.T) {
SignatureAlgorithm: "HS256",
Claims: hcl.StaticExpr(cty.ObjectVal(claimValMap), hcl.Range{}),
Bearer: true,
}, key, memStore)
}, nil, key, memStore)
if err != nil {
subT.Error(err)
return
Expand Down Expand Up @@ -554,7 +554,7 @@ mwIDAQAB
jwtAC, err := ac.NewJWT(&config.JWT{
Dpop: true,
SignatureAlgorithm: algo.String(),
}, pubKeyBytes, memStore)
}, nil, pubKeyBytes, memStore)
h.Must(err)

type testCase struct {
Expand Down Expand Up @@ -916,7 +916,7 @@ func Test_JWT_yields_permissions(t *testing.T) {
RolesClaim: tt.rolesClaim,
RolesMap: rolesMap,
SignatureAlgorithm: algo.String(),
}, pubKeyBytes, memStore)
}, nil, pubKeyBytes, memStore)
if err != nil {
subT.Fatal(err)
}
Expand Down Expand Up @@ -1397,7 +1397,7 @@ func Test_JWT_Validate_Concurrency(t *testing.T) {
Claims: hcl.StaticExpr(cty.ObjectVal(claimValMap), hcl.Range{}),
Name: "test_ac",
Bearer: true,
}, key, memStore)
}, nil, key, memStore)
if err != nil {
t.Error(err)
return
Expand Down
Loading