Skip to content

Commit

Permalink
Add method WatchX509Bundles (#192)
Browse files Browse the repository at this point in the history
Signed-off-by: Max Lambrecht <[email protected]>
  • Loading branch information
maxlambrecht authored Apr 29, 2022
1 parent 23ed83e commit 06f0549
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 29 deletions.
17 changes: 17 additions & 0 deletions v2/bundle/x509bundle/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,23 @@ func Parse(trustDomain spiffeid.TrustDomain, b []byte) (*Bundle, error) {
return bundle, nil
}

// ParseRaw parses a bundle from bytes. The certificate must be ASN.1 DER (concatenated
// with no intermediate padding if there are more than one certificate)
func ParseRaw(trustDomain spiffeid.TrustDomain, b []byte) (*Bundle, error) {
bundle := New(trustDomain)
certs, err := x509.ParseCertificates(b)
if err != nil {
return nil, x509bundleErr.New("cannot parse certificate: %v", err)
}
if len(certs) == 0 {
return nil, x509bundleErr.New("no certificates found")
}
for _, cert := range certs {
bundle.AddX509Authority(cert)
}
return bundle, nil
}

// TrustDomain returns the trust domain that the bundle belongs to.
func (b *Bundle) TrustDomain() spiffeid.TrustDomain {
return b.trustDomain
Expand Down
58 changes: 58 additions & 0 deletions v2/bundle/x509bundle/bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/internal/pemutil"
"github.com/spiffe/go-spiffe/v2/internal/test"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -136,6 +137,49 @@ func TestParse(t *testing.T) {
}
}

func TestParseRaw(t *testing.T) {
tests := []struct {
name string
trustDomain spiffeid.TrustDomain
path string
expNumAuthorities int
expErrContains string
}{
{
name: "Parse multiple certificates should succeed",
path: "testdata/certs.pem",
expNumAuthorities: 2,
},
{
name: "Parse single certificate should succeed",
path: "testdata/cert.pem",
expNumAuthorities: 1,
},
{
name: "Parse should fail if no certificate block is is found",
path: "testdata/key.pem",
expErrContains: "x509bundle: no certificates found",
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
certsBytes := loadRawCertificates(t, test.path)
bundle, err := x509bundle.ParseRaw(td, certsBytes)

if test.expErrContains != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expErrContains)
return
}
require.NoError(t, err)
assert.NotNil(t, bundle)
assert.Len(t, bundle.X509Authorities(), test.expNumAuthorities)
})
}
}

func TestX509AuthorityCRUD(t *testing.T) {
// Load bundle1, which contains a single certificate
bundle1, err := x509bundle.Load(td, "testdata/cert.pem")
Expand Down Expand Up @@ -274,3 +318,17 @@ func TestClone(t *testing.T) {
cloned := original.Clone()
require.True(t, original.Equal(cloned))
}

func loadRawCertificates(t *testing.T, path string) []byte {
certsBytes, err := ioutil.ReadFile(path)
require.NoError(t, err)

certs, err := pemutil.ParseCertificates(certsBytes)
require.NoError(t, err)

var rawBytes []byte
for _, cert := range certs {
rawBytes = append(rawBytes, cert.Raw...)
}
return rawBytes
}
102 changes: 90 additions & 12 deletions v2/internal/test/fakeworkloadapi/workload_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/internal/pemutil"
"github.com/spiffe/go-spiffe/v2/internal/x509util"
"github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
Expand All @@ -29,22 +30,25 @@ import (
var noIdentityError = status.Error(codes.PermissionDenied, "no identity issued")

type WorkloadAPI struct {
tb testing.TB
wg sync.WaitGroup
addr string
server *grpc.Server
mu sync.Mutex
x509Resp *workload.X509SVIDResponse
x509Chans map[chan *workload.X509SVIDResponse]struct{}
jwtResp *workload.JWTSVIDResponse
jwtBundlesResp *workload.JWTBundlesResponse
jwtBundlesChans map[chan *workload.JWTBundlesResponse]struct{}
tb testing.TB
wg sync.WaitGroup
addr string
server *grpc.Server
mu sync.Mutex
x509Resp *workload.X509SVIDResponse
x509Chans map[chan *workload.X509SVIDResponse]struct{}
jwtResp *workload.JWTSVIDResponse
jwtBundlesResp *workload.JWTBundlesResponse
jwtBundlesChans map[chan *workload.JWTBundlesResponse]struct{}
x509BundlesResp *workload.X509BundlesResponse
x509BundlesChans map[chan *workload.X509BundlesResponse]struct{}
}

func New(tb testing.TB) *WorkloadAPI {
w := &WorkloadAPI{
x509Chans: make(map[chan *workload.X509SVIDResponse]struct{}),
jwtBundlesChans: make(map[chan *workload.JWTBundlesResponse]struct{}),
x509Chans: make(map[chan *workload.X509SVIDResponse]struct{}),
jwtBundlesChans: make(map[chan *workload.JWTBundlesResponse]struct{}),
x509BundlesChans: make(map[chan *workload.X509BundlesResponse]struct{}),
}

listener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -126,6 +130,38 @@ func (w *WorkloadAPI) SetJWTBundles(jwtBundles ...*jwtbundle.Bundle) {
}
}

func (w *WorkloadAPI) SetX509Bundles(x509Bundles ...*x509bundle.Bundle) {
resp := &workload.X509BundlesResponse{
Bundles: make(map[string][]byte),
}
for _, bundle := range x509Bundles {
bundleBytes, err := bundle.Marshal()
assert.NoError(w.tb, err)
bundlePem, err := pemutil.ParseCertificates(bundleBytes)
assert.NoError(w.tb, err)

var rawBytes []byte
for _, c := range bundlePem {
rawBytes = append(rawBytes, c.Raw...)
}

resp.Bundles[bundle.TrustDomain().String()] = rawBytes
}

w.mu.Lock()
defer w.mu.Unlock()
w.x509BundlesResp = resp

for ch := range w.x509BundlesChans {
select {
case ch <- w.x509BundlesResp:
default:
<-ch
ch <- w.x509BundlesResp
}
}
}

type workloadAPIWrapper struct {
workload.UnimplementedSpiffeWorkloadAPIServer
w *WorkloadAPI
Expand All @@ -135,6 +171,10 @@ func (w *workloadAPIWrapper) FetchX509SVID(req *workload.X509SVIDRequest, stream
return w.w.fetchX509SVID(req, stream)
}

func (w *workloadAPIWrapper) FetchX509Bundles(req *workload.X509BundlesRequest, stream workload.SpiffeWorkloadAPI_FetchX509BundlesServer) error {
return w.w.fetchX509Bundles(req, stream)
}

func (w *workloadAPIWrapper) FetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest) (*workload.JWTSVIDResponse, error) {
return w.w.fetchJWTSVID(ctx, req)
}
Expand Down Expand Up @@ -221,6 +261,44 @@ func (w *WorkloadAPI) fetchX509SVID(_ *workload.X509SVIDRequest, stream workload
}
}

func (w *WorkloadAPI) fetchX509Bundles(_ *workload.X509BundlesRequest, stream workload.SpiffeWorkloadAPI_FetchX509BundlesServer) error {
if err := checkHeader(stream.Context()); err != nil {
return err
}
ch := make(chan *workload.X509BundlesResponse, 1)
w.mu.Lock()
w.x509BundlesChans[ch] = struct{}{}
resp := w.x509BundlesResp
w.mu.Unlock()

defer func() {
w.mu.Lock()
delete(w.x509BundlesChans, ch)
w.mu.Unlock()
}()

sendResp := func(resp *workload.X509BundlesResponse) error {
if resp == nil {
return noIdentityError
}
return stream.Send(resp)
}

if err := sendResp(resp); err != nil {
return err
}
for {
select {
case resp := <-ch:
if err := sendResp(resp); err != nil {
return err
}
case <-stream.Context().Done():
return stream.Context().Err()
}
}
}

func (w *WorkloadAPI) fetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest) (*workload.JWTSVIDResponse, error) {
if err := checkHeader(ctx); err != nil {
return nil, err
Expand Down
75 changes: 73 additions & 2 deletions v2/workloadapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (c *Client) FetchX509Bundles(ctx context.Context) (*x509bundle.Set, error)
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

stream, err := c.wlClient.FetchX509SVID(ctx, &workload.X509SVIDRequest{})
stream, err := c.wlClient.FetchX509Bundles(ctx, &workload.X509BundlesRequest{})
if err != nil {
return nil, err
}
Expand All @@ -113,7 +113,21 @@ func (c *Client) FetchX509Bundles(ctx context.Context) (*x509bundle.Set, error)
return nil, err
}

return parseX509Bundles(resp)
return parseX509BundlesResponse(resp)
}

// WatchX509Bundles watches for changes to the X.509 bundles. The watcher receives
// the updated X.509 bundles.
func (c *Client) WatchX509Bundles(ctx context.Context, watcher X509BundleWatcher) error {
backoff := newBackoff()
for {
err := c.watchX509Bundles(ctx, watcher, backoff)
watcher.OnX509BundlesWatchError(err)
err = c.handleWatchError(ctx, err, backoff)
if err != nil {
return err
}
}
}

// FetchX509Context fetches the X.509 context, which contains both X509-SVIDs
Expand Down Expand Up @@ -321,6 +335,33 @@ func (c *Client) watchJWTBundles(ctx context.Context, watcher JWTBundleWatcher,
}
}

func (c *Client) watchX509Bundles(ctx context.Context, watcher X509BundleWatcher, backoff *backoff) error {
ctx, cancel := context.WithCancel(withHeader(ctx))
defer cancel()

c.config.log.Debugf("Watching X.509 bundles")
stream, err := c.wlClient.FetchX509Bundles(ctx, &workload.X509BundlesRequest{})
if err != nil {
return err
}

for {
resp, err := stream.Recv()
if err != nil {
return err
}

backoff.Reset()
x509bundleSet, err := parseX509BundlesResponse(resp)
if err != nil {
c.config.log.Errorf("Failed to parse X.509 bundle response: %v", err)
watcher.OnX509BundlesWatchError(err)
continue
}
watcher.OnX509BundlesUpdate(x509bundleSet)
}
}

// X509ContextWatcher receives X509Context updates from the Workload API.
type X509ContextWatcher interface {
// OnX509ContextUpdate is called with the latest X.509 context retrieved
Expand All @@ -343,6 +384,17 @@ type JWTBundleWatcher interface {
OnJWTBundlesWatchError(error)
}

// X509BundleWatcher receives X.509 bundle updates from the Workload API.
type X509BundleWatcher interface {
// OnX509BundlesUpdate is called with the latest X.509 bundle set retrieved
// from the Workload API.
OnX509BundlesUpdate(*x509bundle.Set)

// OnX509BundlesWatchError is called when there is a problem establishing
// or maintaining connectivity with the Workload API.
OnX509BundlesWatchError(error)
}

func withHeader(ctx context.Context) context.Context {
header := metadata.Pairs("workload.spiffe.io", "true")
return metadata.NewOutgoingContext(ctx, header)
Expand Down Expand Up @@ -432,6 +484,25 @@ func parseX509Bundle(spiffeID string, bundle []byte) (*x509bundle.Bundle, error)
return x509bundle.FromX509Authorities(td, certs), nil
}

func parseX509BundlesResponse(resp *workload.X509BundlesResponse) (*x509bundle.Set, error) {
bundles := []*x509bundle.Bundle{}

for tdID, b := range resp.Bundles {
td, err := spiffeid.TrustDomainFromString(tdID)
if err != nil {
return nil, err
}

b, err := x509bundle.ParseRaw(td, b)
if err != nil {
return nil, err
}
bundles = append(bundles, b)
}

return x509bundle.NewSet(bundles...), nil
}

// parseJWTSVIDs parses one or all of the SVIDs in the response. If firstOnly
// is true, then only the first SVID in the response is parsed and returned.
// Otherwise all SVIDs are parsed and returned.
Expand Down
Loading

0 comments on commit 06f0549

Please sign in to comment.