From dd45d0be5b79c05915c293170882a205aef01558 Mon Sep 17 00:00:00 2001 From: Nick Stott Date: Tue, 30 Jan 2024 13:07:34 -0500 Subject: [PATCH] feat: add a WithDefaultJWTSVIDPicker source option Signed-off-by: Nick Stott --- v2/workloadapi/jwtsource.go | 22 ++++++++++++++++++++-- v2/workloadapi/option.go | 28 +++++++++++++++++++++++----- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/v2/workloadapi/jwtsource.go b/v2/workloadapi/jwtsource.go index 47ea83ad..11223539 100644 --- a/v2/workloadapi/jwtsource.go +++ b/v2/workloadapi/jwtsource.go @@ -16,6 +16,7 @@ var jwtsourceErr = errs.Class("jwtsource") // Workload API. type JWTSource struct { watcher *watcher + picker func([]*jwtsvid.SVID) *jwtsvid.SVID mtx sync.RWMutex bundles *jwtbundle.Set @@ -33,7 +34,9 @@ func NewJWTSource(ctx context.Context, options ...JWTSourceOption) (_ *JWTSource option.configureJWTSource(config) } - s := &JWTSource{} + s := &JWTSource{ + picker: config.picker, + } s.watcher, err = newWatcher(ctx, config.watcher, nil, s.setJWTBundles) if err != nil { @@ -61,7 +64,22 @@ func (s *JWTSource) FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*j if err := s.checkClosed(); err != nil { return nil, err } - return s.watcher.client.FetchJWTSVID(ctx, params) + + var ( + svid *jwtsvid.SVID + err error + ) + if s.picker == nil { + svid, err = s.watcher.client.FetchJWTSVID(ctx, params) + } else { + svids, err := s.watcher.client.FetchJWTSVIDs(ctx, params) + if err != nil { + return svid, err + } + svid = s.picker(svids) + } + + return svid, err } // FetchJWTSVIDs fetches all JWT-SVIDs from the source with the given parameters. diff --git a/v2/workloadapi/option.go b/v2/workloadapi/option.go index 00cab7d1..63fc86dc 100644 --- a/v2/workloadapi/option.go +++ b/v2/workloadapi/option.go @@ -2,6 +2,7 @@ package workloadapi import ( "github.com/spiffe/go-spiffe/v2/logger" + "github.com/spiffe/go-spiffe/v2/svid/jwtsvid" "github.com/spiffe/go-spiffe/v2/svid/x509svid" "google.golang.org/grpc" ) @@ -60,12 +61,12 @@ type X509SourceOption interface { configureX509Source(*x509SourceConfig) } -// WithDefaultX509SVIDPicker provides a function that is used to determine the -// default X509-SVID when more than one is provided by the Workload API. By -// default, the first X509-SVID in the list returned by the Workload API is +// WithDefaultJWTSVIDPicker provides a function that is used to determine the +// default JWT-SVID when more than one is provided by the Workload API. By +// default, the first JWT-SVID in the list returned by the Workload API is // used. -func WithDefaultX509SVIDPicker(picker func([]*x509svid.SVID) *x509svid.SVID) X509SourceOption { - return withDefaultX509SVIDPicker{picker: picker} +func WithDefaultJWTSVIDPicker(picker func([]*jwtsvid.SVID) *jwtsvid.SVID) JWTSourceOption { + return withDefaultJWTSVIDPicker{picker: picker} } // JWTSourceOption is an option for the JWTSource. A SourceOption is also a @@ -74,6 +75,14 @@ type JWTSourceOption interface { configureJWTSource(*jwtSourceConfig) } +// WithDefaultX509SVIDPicker provides a function that is used to determine the +// default X509-SVID when more than one is provided by the Workload API. By +// default, the first X509-SVID in the list returned by the Workload API is +// used. +func WithDefaultX509SVIDPicker(picker func([]*x509svid.SVID) *x509svid.SVID) X509SourceOption { + return withDefaultX509SVIDPicker{picker: picker} +} + // BundleSourceOption is an option for the BundleSource. A SourceOption is also // a BundleSourceOption. type BundleSourceOption interface { @@ -100,6 +109,7 @@ type x509SourceConfig struct { type jwtSourceConfig struct { watcher watcherConfig + picker func([]*jwtsvid.SVID) *jwtsvid.SVID } type bundleSourceConfig struct { @@ -145,3 +155,11 @@ type withDefaultX509SVIDPicker struct { func (o withDefaultX509SVIDPicker) configureX509Source(config *x509SourceConfig) { config.picker = o.picker } + +type withDefaultJWTSVIDPicker struct { + picker func([]*jwtsvid.SVID) *jwtsvid.SVID +} + +func (o withDefaultJWTSVIDPicker) configureJWTSource(config *jwtSourceConfig) { + config.picker = o.picker +}