diff --git a/cmd/spire-agent/cli/api/common.go b/cmd/spire-agent/cli/api/common.go index 6789b0b621..6b319a6799 100644 --- a/cmd/spire-agent/cli/api/common.go +++ b/cmd/spire-agent/cli/api/common.go @@ -13,6 +13,8 @@ import ( "google.golang.org/grpc/metadata" ) +const commandTimeout = 5 * time.Second + type workloadClient struct { workload.SpiffeWorkloadAPIClient timeout time.Duration @@ -71,7 +73,7 @@ func adaptCommand(env *cli.Env, clientsMaker workloadClientMaker, cmd command) * clientsMaker: clientsMaker, cmd: cmd, env: env, - timeout: cli.DurationFlag(time.Second), + timeout: cli.DurationFlag(commandTimeout), } fs := flag.NewFlagSet(cmd.name(), flag.ContinueOnError) diff --git a/cmd/spire-agent/cli/run/run.go b/cmd/spire-agent/cli/run/run.go index 7cf33896c0..92d69efbca 100644 --- a/cmd/spire-agent/cli/run/run.go +++ b/cmd/spire-agent/cli/run/run.go @@ -106,7 +106,8 @@ type experimentalConfig struct { Flags fflag.RawConfig `hcl:"feature_flags"` - UnusedKeys []string `hcl:",unusedKeys"` + UnusedKeys []string `hcl:",unusedKeys"` + X509SVIDCacheMaxSize int `hcl:"x509_svid_cache_max_size"` } type Command struct { @@ -394,6 +395,11 @@ func NewAgentConfig(c *Config, logOptions []log.Option, allowUnknownConfig bool) } } + if c.Agent.Experimental.X509SVIDCacheMaxSize < 0 { + return nil, errors.New("x509_svid_cache_max_size should not be negative") + } + ac.X509SVIDCacheMaxSize = c.Agent.Experimental.X509SVIDCacheMaxSize + serverHostPort := net.JoinHostPort(c.Agent.ServerAddress, strconv.Itoa(c.Agent.ServerPort)) ac.ServerAddress = fmt.Sprintf("dns:///%s", serverHostPort) diff --git a/cmd/spire-agent/cli/run/run_test.go b/cmd/spire-agent/cli/run/run_test.go index 1eb8c3c2dc..c54cdda620 100644 --- a/cmd/spire-agent/cli/run/run_test.go +++ b/cmd/spire-agent/cli/run/run_test.go @@ -727,6 +727,42 @@ func TestNewAgentConfig(t *testing.T) { require.Nil(t, c) }, }, + { + msg: "x509_svid_cache_max_size is set", + input: func(c *Config) { + c.Agent.Experimental.X509SVIDCacheMaxSize = 100 + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 100, c.X509SVIDCacheMaxSize) + }, + }, + { + msg: "x509_svid_cache_max_size is not set", + input: func(c *Config) { + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 0, c.X509SVIDCacheMaxSize) + }, + }, + { + msg: "x509_svid_cache_max_size is zero", + input: func(c *Config) { + c.Agent.Experimental.X509SVIDCacheMaxSize = 0 + }, + test: func(t *testing.T, c *agent.Config) { + require.EqualValues(t, 0, c.X509SVIDCacheMaxSize) + }, + }, + { + msg: "x509_svid_cache_max_size is negative", + expectError: true, + input: func(c *Config) { + c.Agent.Experimental.X509SVIDCacheMaxSize = -10 + }, + test: func(t *testing.T, c *agent.Config) { + require.Nil(t, c) + }, + }, { msg: "allowed_foreign_jwt_claims provided", input: func(c *Config) { diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 7becbcb595..a44d8fb38b 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -212,20 +212,20 @@ func (a *Agent) attest(ctx context.Context, sto storage.Storage, cat catalog.Cat func (a *Agent) newManager(ctx context.Context, sto storage.Storage, cat catalog.Catalog, metrics telemetry.Metrics, as *node_attestor.AttestationResult, cache *storecache.Cache, na nodeattestor.NodeAttestor) (manager.Manager, error) { config := &manager.Config{ - SVID: as.SVID, - SVIDKey: as.Key, - Bundle: as.Bundle, - Reattestable: as.Reattestable, - Catalog: cat, - TrustDomain: a.c.TrustDomain, - ServerAddr: a.c.ServerAddress, - Log: 
a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager), - Metrics: metrics, - WorkloadKeyType: a.c.WorkloadKeyType, - Storage: sto, - SyncInterval: a.c.SyncInterval, - SVIDStoreCache: cache, - NodeAttestor: na, + SVID: as.SVID, + SVIDKey: as.Key, + Bundle: as.Bundle, + Catalog: cat, + TrustDomain: a.c.TrustDomain, + ServerAddr: a.c.ServerAddress, + Log: a.c.Log.WithField(telemetry.SubsystemName, telemetry.Manager), + Metrics: metrics, + WorkloadKeyType: a.c.WorkloadKeyType, + Storage: sto, + SyncInterval: a.c.SyncInterval, + SVIDCacheMaxSize: a.c.X509SVIDCacheMaxSize, + SVIDStoreCache: cache, + NodeAttestor: na, } mgr := manager.New(config) diff --git a/pkg/agent/api/delegatedidentity/v1/service.go b/pkg/agent/api/delegatedidentity/v1/service.go index 87a2020988..56a94dbae9 100644 --- a/pkg/agent/api/delegatedidentity/v1/service.go +++ b/pkg/agent/api/delegatedidentity/v1/service.go @@ -82,24 +82,24 @@ func (s *Service) isCallerAuthorized(ctx context.Context, log logrus.FieldLogger } } - identities := s.manager.MatchingIdentities(callerSelectors) - numRegisteredIDs := len(identities) + entries := s.manager.MatchingRegistrationEntries(callerSelectors) + numRegisteredEntries := len(entries) - if numRegisteredIDs == 0 { + if numRegisteredEntries == 0 { log.Error("no identity issued") return nil, status.Error(codes.PermissionDenied, "no identity issued") } - for _, identity := range identities { - if _, ok := s.authorizedDelegates[identity.Entry.SpiffeId]; ok { + for _, entry := range entries { + if _, ok := s.authorizedDelegates[entry.SpiffeId]; ok { return callerSelectors, nil } } // caller has identity associeted with but none is authorized log.WithFields(logrus.Fields{ - "num_registered_ids": numRegisteredIDs, - "default_id": identities[0].Entry.SpiffeId, + "num_registered_entries": numRegisteredEntries, + "default_id": entries[0].SpiffeId, }).Error("Permission denied; caller not configured as an authorized delegate.") return nil, status.Error(codes.PermissionDenied, "caller not configured as an authorized delegate") @@ -120,7 +120,11 @@ func (s *Service) SubscribeToX509SVIDs(req *delegatedidentityv1.SubscribeToX509S return status.Error(codes.InvalidArgument, "could not parse provided selectors") } - subscriber := s.manager.SubscribeToCacheChanges(selectors) + subscriber, err := s.manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() for { @@ -268,11 +272,11 @@ func (s *Service) FetchJWTSVIDs(ctx context.Context, req *delegatedidentityv1.Fe } var spiffeIDs []spiffeid.ID - identities := s.manager.MatchingIdentities(selectors) - for _, identity := range identities { - spiffeID, err := spiffeid.FromString(identity.Entry.SpiffeId) + entries := s.manager.MatchingRegistrationEntries(selectors) + for _, entry := range entries { + spiffeID, err := spiffeid.FromString(entry.SpiffeId) if err != nil { - log.WithField(telemetry.SPIFFEID, identity.Entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") + log.WithField(telemetry.SPIFFEID, entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") return nil, status.Errorf(codes.InvalidArgument, "invalid requested SPIFFE ID: %v", err) } diff --git a/pkg/agent/api/delegatedidentity/v1/service_test.go b/pkg/agent/api/delegatedidentity/v1/service_test.go index 96035d18a0..8baf8f20b1 100644 --- a/pkg/agent/api/delegatedidentity/v1/service_test.go +++ b/pkg/agent/api/delegatedidentity/v1/service_test.go @@ -87,6 
+87,16 @@ func TestSubscribeToX509SVIDs(t *testing.T) { expectCode: codes.PermissionDenied, expectMsg: "caller not configured as an authorized delegate", }, + { + testName: "subscribe to cache changes error", + authSpiffeID: []string{"spiffe://example.org/one"}, + identities: []cache.Identity{ + identityFromX509SVID(x509SVID1), + }, + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + }, { testName: "workload update with one identity", authSpiffeID: []string{"spiffe://example.org/one"}, @@ -653,10 +663,6 @@ func (fa FakeAttestor) Attest(ctx context.Context) ([]*common.Selector, error) { return fa.selectors, fa.err } -func (m *FakeManager) MatchingIdentities(selectors []*common.Selector) []cache.Identity { - return m.identities -} - type FakeManager struct { manager.Manager @@ -677,9 +683,12 @@ func (m *FakeManager) subscriberDone() { atomic.AddInt32(&m.subscribers, -1) } -func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { + if m.err != nil { + return nil, m.err + } atomic.AddInt32(&m.subscribers, 1) - return newFakeSubscriber(m, m.updates) + return newFakeSubscriber(m, m.updates), nil } func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) { @@ -692,6 +701,14 @@ func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, au }, nil } +func (m *FakeManager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + out := make([]*common.RegistrationEntry, 0, len(m.identities)) + for _, identity := range m.identities { + out = append(out, identity.Entry) + } + return out +} + type fakeSubscriber struct { m *FakeManager ch chan *cache.WorkloadUpdate diff --git a/pkg/agent/config.go b/pkg/agent/config.go index a3f10e9fae..06d60a81e8 100644 --- a/pkg/agent/config.go +++ b/pkg/agent/config.go @@ -59,6 +59,9 @@ type Config struct { // SyncInterval controls how often the agent sync synchronizer waits SyncInterval time.Duration + // X509SVIDCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache + X509SVIDCacheMaxSize int + // Trust domain and associated CA bundle TrustDomain spiffeid.TrustDomain TrustBundle []*x509.Certificate diff --git a/pkg/agent/endpoints/sdsv2/handler.go b/pkg/agent/endpoints/sdsv2/handler.go index dfff824140..3d96f6ccc1 100644 --- a/pkg/agent/endpoints/sdsv2/handler.go +++ b/pkg/agent/endpoints/sdsv2/handler.go @@ -31,7 +31,7 @@ type Attestor interface { } type Manager interface { - SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate } @@ -64,7 +64,11 @@ func (h *Handler) StreamSecrets(stream discovery_v2.SecretDiscoveryService_Strea return err } - sub := h.c.Manager.SubscribeToCacheChanges(selectors) + sub, err := h.c.Manager.SubscribeToCacheChanges(stream.Context(), selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer sub.Finish() updch := sub.Updates() diff --git a/pkg/agent/endpoints/sdsv2/handler_test.go b/pkg/agent/endpoints/sdsv2/handler_test.go index 4ac1a68712..be9bdb96f6 100644 --- a/pkg/agent/endpoints/sdsv2/handler_test.go +++ b/pkg/agent/endpoints/sdsv2/handler_test.go @@ -552,7 +552,7 @@ func 
NewFakeManager(t *testing.T) *FakeManager { } } -func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { require.Equal(m.t, workloadSelectors, selectors) updch := make(chan *cache.WorkloadUpdate, 1) @@ -568,7 +568,7 @@ func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.S return NewFakeSubscriber(updch, func() { delete(m.subs, key) close(updch) - }) + }), nil } func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate { diff --git a/pkg/agent/endpoints/sdsv3/handler.go b/pkg/agent/endpoints/sdsv3/handler.go index 07da317be9..0e98bad562 100644 --- a/pkg/agent/endpoints/sdsv3/handler.go +++ b/pkg/agent/endpoints/sdsv3/handler.go @@ -39,7 +39,7 @@ type Attestor interface { } type Manager interface { - SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate } @@ -74,7 +74,11 @@ func (h *Handler) StreamSecrets(stream secret_v3.SecretDiscoveryService_StreamSe return err } - sub := h.c.Manager.SubscribeToCacheChanges(selectors) + sub, err := h.c.Manager.SubscribeToCacheChanges(stream.Context(), selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer sub.Finish() updch := sub.Updates() diff --git a/pkg/agent/endpoints/sdsv3/handler_test.go b/pkg/agent/endpoints/sdsv3/handler_test.go index 937b307bc6..1fd9455b3c 100644 --- a/pkg/agent/endpoints/sdsv3/handler_test.go +++ b/pkg/agent/endpoints/sdsv3/handler_test.go @@ -831,6 +831,21 @@ func TestStreamSecretsBadNonce(t *testing.T) { requireSecrets(t, resp, workloadTLSCertificate2) } +func TestStreamSecretsErrInSubscribeToCacheChanges(t *testing.T) { + test := setupErrTest(t) + defer test.server.Stop() + + stream, err := test.handler.StreamSecrets(context.Background()) + require.NoError(t, err) + defer func() { + require.NoError(t, stream.CloseSend()) + }() + + resp, err := stream.Recv() + require.Error(t, err) + require.Nil(t, resp) +} + func TestFetchSecrets(t *testing.T) { for _, tt := range []struct { name string @@ -1174,11 +1189,16 @@ func DeltaSecretsTest(t *testing.T) { } func setupTest(t *testing.T) *handlerTest { - return setupTestWithConfig(t, Config{}) + return setupTestWithManager(t, Config{}, NewFakeManager(t)) } -func setupTestWithConfig(t *testing.T, c Config) *handlerTest { +func setupErrTest(t *testing.T) *handlerTest { manager := NewFakeManager(t) + manager.err = errors.New("bad-error") + return setupTestWithManager(t, Config{}, manager) +} + +func setupTestWithManager(t *testing.T, c Config, manager *FakeManager) *handlerTest { defaultConfig := Config{ Manager: manager, Attestor: FakeAttestor(workloadSelectors), @@ -1220,6 +1240,11 @@ func setupTestWithConfig(t *testing.T, c Config) *handlerTest { return test } +func setupTestWithConfig(t *testing.T, c Config) *handlerTest { + manager := NewFakeManager(t) + return setupTestWithManager(t, c, manager) +} + type handlerTest struct { t *testing.T @@ -1279,6 +1304,7 @@ type FakeManager struct { upd *cache.WorkloadUpdate next int subs map[int]chan *cache.WorkloadUpdate + err error } func NewFakeManager(t *testing.T) *FakeManager { @@ -1288,7 +1314,10 @@ func NewFakeManager(t *testing.T) *FakeManager { } } -func (m *FakeManager) 
SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { + if m.err != nil { + return nil, m.err + } require.Equal(m.t, workloadSelectors, selectors) updch := make(chan *cache.WorkloadUpdate, 1) @@ -1304,7 +1333,7 @@ func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.S return NewFakeSubscriber(updch, func() { delete(m.subs, key) close(updch) - }) + }), nil } func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate { diff --git a/pkg/agent/endpoints/workload/handler.go b/pkg/agent/endpoints/workload/handler.go index b1074f95ef..81ab3e5570 100644 --- a/pkg/agent/endpoints/workload/handler.go +++ b/pkg/agent/endpoints/workload/handler.go @@ -30,8 +30,8 @@ import ( ) type Manager interface { - SubscribeToCacheChanges(cache.Selectors) cache.Subscriber - MatchingIdentities([]*common.Selector) []cache.Identity + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) + MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) FetchWorkloadUpdate([]*common.Selector) *cache.WorkloadUpdate } @@ -85,15 +85,15 @@ func (h *Handler) FetchJWTSVID(ctx context.Context, req *workload.JWTSVIDRequest log = log.WithField(telemetry.Registered, true) - identities := h.c.Manager.MatchingIdentities(selectors) - for _, identity := range identities { - if req.SpiffeId != "" && identity.Entry.SpiffeId != req.SpiffeId { + entries := h.c.Manager.MatchingRegistrationEntries(selectors) + for _, entry := range entries { + if req.SpiffeId != "" && entry.SpiffeId != req.SpiffeId { continue } - spiffeID, err := spiffeid.FromString(identity.Entry.SpiffeId) + spiffeID, err := spiffeid.FromString(entry.SpiffeId) if err != nil { - log.WithField(telemetry.SPIFFEID, identity.Entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") + log.WithField(telemetry.SPIFFEID, entry.SpiffeId).WithError(err).Error("Invalid requested SPIFFE ID") return nil, status.Errorf(codes.InvalidArgument, "invalid requested SPIFFE ID: %v", err) } @@ -138,7 +138,11 @@ func (h *Handler) FetchJWTBundles(req *workload.JWTBundlesRequest, stream worklo return err } - subscriber := h.c.Manager.SubscribeToCacheChanges(selectors) + subscriber, err := h.c.Manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() var previousResp *workload.JWTBundlesResponse @@ -224,7 +228,11 @@ func (h *Handler) FetchX509SVID(_ *workload.X509SVIDRequest, stream workload.Spi return err } - subscriber := h.c.Manager.SubscribeToCacheChanges(selectors) + subscriber, err := h.c.Manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() for { @@ -250,7 +258,11 @@ func (h *Handler) FetchX509Bundles(_ *workload.X509BundlesRequest, stream worklo return err } - subscriber := h.c.Manager.SubscribeToCacheChanges(selectors) + subscriber, err := h.c.Manager.SubscribeToCacheChanges(ctx, selectors) + if err != nil { + log.WithError(err).Error("Subscribe to cache changes failed") + return err + } defer subscriber.Finish() var previousResp *workload.X509BundlesResponse diff --git 
a/pkg/agent/endpoints/workload/handler_test.go b/pkg/agent/endpoints/workload/handler_test.go index 0165deebfe..92cabb973b 100644 --- a/pkg/agent/endpoints/workload/handler_test.go +++ b/pkg/agent/endpoints/workload/handler_test.go @@ -56,6 +56,7 @@ func TestFetchX509SVID(t *testing.T) { name string updates []*cache.WorkloadUpdate attestErr error + managerErr error asPID int expectCode codes.Code expectMsg string @@ -103,6 +104,23 @@ func TestFetchX509SVID(t *testing.T) { }, }, }, + { + name: "subscribe to cache changes error", + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Subscribe to cache changes failed", + Data: logrus.Fields{ + "service": "WorkloadAPI", + "method": "FetchX509SVID", + logrus.ErrorKey: "err", + }, + }, + }, + }, { name: "with identity and federated bundles", updates: []*cache.WorkloadUpdate{{ @@ -167,6 +185,7 @@ func TestFetchX509SVID(t *testing.T) { AttestErr: tt.attestErr, ExpectLogs: tt.expectLogs, AsPID: tt.asPID, + ManagerErr: tt.managerErr, } runTest(t, params, func(ctx context.Context, client workloadPB.SpiffeWorkloadAPIClient) { @@ -195,6 +214,7 @@ func TestFetchX509Bundles(t *testing.T) { testName string updates []*cache.WorkloadUpdate attestErr error + managerErr error expectCode codes.Code expectMsg string expectResp *workloadPB.X509BundlesResponse @@ -235,6 +255,23 @@ func TestFetchX509Bundles(t *testing.T) { }, }, }, + { + testName: "subscribe to cache changes error", + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Subscribe to cache changes failed", + Data: logrus.Fields{ + "service": "WorkloadAPI", + "method": "FetchX509Bundles", + logrus.ErrorKey: "err", + }, + }, + }, + }, { testName: "cache update unexpectedly missing bundle", updates: []*cache.WorkloadUpdate{ @@ -307,6 +344,7 @@ func TestFetchX509Bundles(t *testing.T) { AttestErr: tt.attestErr, ExpectLogs: tt.expectLogs, AllowUnauthenticatedVerifiers: tt.allowUnauthenticatedVerifiers, + ManagerErr: tt.managerErr, } runTest(t, params, func(ctx context.Context, client workloadPB.SpiffeWorkloadAPIClient) { @@ -665,6 +703,7 @@ func TestFetchJWTBundles(t *testing.T) { name string updates []*cache.WorkloadUpdate attestErr error + managerErr error expectCode codes.Code expectMsg string expectResp *workloadPB.JWTBundlesResponse @@ -705,6 +744,23 @@ func TestFetchJWTBundles(t *testing.T) { }, }, }, + { + name: "subscribe to cache changes error", + managerErr: errors.New("err"), + expectCode: codes.Unknown, + expectMsg: "err", + expectLogs: []spiretest.LogEntry{ + { + Level: logrus.ErrorLevel, + Message: "Subscribe to cache changes failed", + Data: logrus.Fields{ + "service": "WorkloadAPI", + "method": "FetchJWTBundles", + logrus.ErrorKey: "err", + }, + }, + }, + }, { name: "cache update unexpectedly missing bundle", updates: []*cache.WorkloadUpdate{ @@ -777,6 +833,7 @@ func TestFetchJWTBundles(t *testing.T) { AttestErr: tt.attestErr, ExpectLogs: tt.expectLogs, AllowUnauthenticatedVerifiers: tt.allowUnauthenticatedVerifiers, + ManagerErr: tt.managerErr, } runTest(t, params, func(ctx context.Context, client workloadPB.SpiffeWorkloadAPIClient) { @@ -1300,8 +1357,12 @@ type FakeManager struct { err error } -func (m *FakeManager) MatchingIdentities(selectors []*common.Selector) []cache.Identity { - return m.identities +func (m *FakeManager) MatchingRegistrationEntries(selectors 
[]*common.Selector) []*common.RegistrationEntry { + out := make([]*common.RegistrationEntry, 0, len(m.identities)) + for _, identity := range m.identities { + out = append(out, identity.Entry) + } + return out } func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audience []string) (*client.JWTSVID, error) { @@ -1314,9 +1375,12 @@ func (m *FakeManager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, au }, nil } -func (m *FakeManager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { +func (m *FakeManager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { + if m.err != nil { + return nil, m.err + } atomic.AddInt32(&m.subscribers, 1) - return newFakeSubscriber(m, m.updates) + return newFakeSubscriber(m, m.updates), nil } func (m *FakeManager) FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate { diff --git a/pkg/agent/manager/cache/cache.go b/pkg/agent/manager/cache/cache.go index 7aa980b2de..7f98a7396b 100644 --- a/pkg/agent/manager/cache/cache.go +++ b/pkg/agent/manager/cache/cache.go @@ -1,6 +1,7 @@ package cache import ( + "context" "crypto" "crypto/x509" "sort" @@ -200,16 +201,8 @@ func (c *Cache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdat return c.buildWorkloadUpdate(set) } -func (c *Cache) SubscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { - c.mu.Lock() - defer c.mu.Unlock() - - sub := newSubscriber(c, selectors) - for s := range sub.set { - c.addSelectorIndexSub(s, sub) - } - c.notify(sub) - return sub +func (c *Cache) SubscribeToWorkloadUpdates(ctx context.Context, selectors Selectors) (Subscriber, error) { + return c.subscribeToWorkloadUpdates(selectors), nil } // UpdateEntries updates the cache with the provided registration entries and bundles and @@ -446,6 +439,55 @@ func (c *Cache) GetStaleEntries() []*StaleEntry { return staleEntries } +func (c *Cache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + c.mu.RLock() + defer c.mu.RUnlock() + + set, setDone := allocSelectorSet(selectors...) + defer setDone() + + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. 
+ // TODO: figure out how to determine the "default" identity + out := make([]*common.RegistrationEntry, 0, len(records)) + for record := range records { + out = append(out, record.entry) + } + sortEntriesByID(out) + return out +} + +func (c *Cache) Entries() []*common.RegistrationEntry { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]*common.RegistrationEntry, 0, len(c.records)) + for _, record := range c.records { + out = append(out, record.entry) + } + sortEntriesByID(out) + return out +} + +func (c *Cache) SyncSVIDsWithSubscribers() { + c.log.Error("SyncSVIDsWithSubscribers method is not implemented") +} + +func (c *Cache) subscribeToWorkloadUpdates(selectors []*common.Selector) Subscriber { + c.mu.Lock() + defer c.mu.Unlock() + + sub := newSubscriber(c, selectors) + for s := range sub.set { + c.addSelectorIndexSub(s, sub) + } + c.notify(sub) + return sub +} + func (c *Cache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*cacheRecord, *common.RegistrationEntry) { var existingEntry *common.RegistrationEntry record, recordExists := c.records[newEntry.EntryId] diff --git a/pkg/agent/manager/cache/cache_test.go b/pkg/agent/manager/cache/cache_test.go index e7101b260d..8f8372842e 100644 --- a/pkg/agent/manager/cache/cache_test.go +++ b/pkg/agent/manager/cache/cache_test.go @@ -137,11 +137,11 @@ func TestAllSubscribersNotifiedOnBundleChange(t *testing.T) { cache := newTestCache() // create some subscribers and assert they get the initial bundle - subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) - subB := cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) @@ -168,11 +168,11 @@ func TestSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { }) // subscribe to A and B and assert initial updates are received. 
- subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) - subB := cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -231,11 +231,11 @@ func TestSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { cache := newTestCache() // create subscribers for each combination of selectors - subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() - subB := cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() - subAB := cache.SubscribeToWorkloadUpdates(makeSelectors("A", "B")) + subAB := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) defer subAB.Finish() // assert all subscribers get the initial update @@ -288,7 +288,7 @@ func TestSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -314,7 +314,7 @@ func TestSubscriberNotifiedOnSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -343,7 +343,7 @@ func TestSubcriberNotificationsOnSelectorChanges(t *testing.T) { }) // create subscribers for A and make sure the initial update has FOO - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ Bundle: bundleV1, @@ -388,13 +388,13 @@ func newTestCache() *Cache { func TestSubcriberNotifiedWhenEntryDropped(t *testing.T) { cache := newTestCache() - subA := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + subA := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer subA.Finish() assertAnyWorkloadUpdate(t, subA) // subB's job here is to just make sure we don't notify unrelated // subscribers when dropping registration entries - subB := cache.SubscribeToWorkloadUpdates(makeSelectors("B")) + subB := cache.subscribeToWorkloadUpdates(makeSelectors("B")) defer subB.Finish() assertAnyWorkloadUpdate(t, subB) @@ -438,7 +438,7 @@ func TestSubcriberOnlyGetsEntriesWithSVID(t *testing.T) { } cache.UpdateEntries(updateEntries, nil) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() // workload update does not include the identity because it has no SVID. 
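The exported SubscribeToWorkloadUpdates now takes a context and returns an error so that Cache and LRUCache share one subscription interface, which is why these tests switch to the unexported synchronous helper. A minimal caller-side sketch of the shared contract follows; watchWorkloadUpdates is an illustrative helper (not a SPIRE API), and the context and selector set are assumed to come from the caller:

func watchWorkloadUpdates(ctx context.Context, c *Cache, selectors []*common.Selector) error {
	sub, err := c.SubscribeToWorkloadUpdates(ctx, selectors)
	if err != nil {
		// The plain Cache never fails here; the LRUCache returns ctx.Err() if
		// the required SVIDs are not cached before the context is done.
		return err
	}
	defer sub.Finish()
	for update := range sub.Updates() {
		_ = update // consume the *WorkloadUpdate (identities plus trust bundles)
	}
	return nil
}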
@@ -459,7 +459,7 @@ func TestSubcriberOnlyGetsEntriesWithSVID(t *testing.T) { func TestSubscribersDoNotBlockNotifications(t *testing.T) { cache := newTestCache() - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() cache.UpdateEntries(&UpdateEntries{ @@ -607,7 +607,7 @@ func TestSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -632,7 +632,7 @@ func TestSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { X509SVIDs: makeX509SVIDs(foo, bar), }) - sub := cache.SubscribeToWorkloadUpdates(makeSelectors("A", "B")) + sub := cache.subscribeToWorkloadUpdates(makeSelectors("A", "B")) defer sub.Finish() assertAnyWorkloadUpdate(t, sub) @@ -672,7 +672,7 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { cache.UpdateEntries(updateEntries, nil) for i := 0; i < numWorkloads; i++ { selectors := distinctSelectors(i, selectorsPerWorkload) - cache.SubscribeToWorkloadUpdates(selectors) + cache.subscribeToWorkloadUpdates(selectors) } runtime.GC() @@ -689,6 +689,47 @@ func BenchmarkCacheGlobalNotification(b *testing.B) { } } +func TestMatchingRegistrationEntries(t *testing.T) { + cache := newTestCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + + // check empty result + assert.Equal(t, []*common.RegistrationEntry{}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + // Update SVIDs and MatchingRegistrationEntries should return both entries + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) +} + +func TestEntries(t *testing.T) { + cache := newTestCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, cache.Entries()) +} + func distinctSelectors(id, n int) []*common.Selector { out := make([]*common.Selector, 0, n) for i := 0; i < n; i++ { diff --git a/pkg/agent/manager/cache/lru_cache.go b/pkg/agent/manager/cache/lru_cache.go new file mode 100644 index 0000000000..ad09719ac0 --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache.go @@ -0,0 +1,939 @@ +package cache + +import ( + "context" + "sort" + "sync" + "time" + + "github.com/andres-erbsen/clock" + "github.com/sirupsen/logrus" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/agent/common/backoff" + "github.com/spiffe/spire/pkg/common/bundleutil" + "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/proto/spire/common" +) + +const ( + DefaultSVIDCacheMaxSize = 1000 + SVIDSyncInterval = 500 * time.Millisecond +) + +// Cache caches each registration entry, bundles, and JWT SVIDs for the 
agent. +// The signed X509-SVIDs for those entries are stored in LRU-like cache. +// It allows subscriptions by (workload) selector sets and notifies subscribers when: +// +// 1) a registration entry related to the selectors: +// - is modified +// - has a new X509-SVID signed for it +// - federates with a federated bundle that is updated +// +// 2) the trust bundle for the agent trust domain is updated +// +// When notified, the subscriber is given a WorkloadUpdate containing +// related identities and trust bundles. +// +// The cache does this efficiently by building an index for each unique +// selector it encounters. Each selector index tracks the subscribers (i.e +// workloads) and registration entries that have that selector. +// +// The LRU-like SVID cache has configurable size limit and expiry period. +// 1. Size limit of SVID cache is a soft limit. If SVID has a subscriber present then +// that SVID is never removed from cache. +// 2. Least recently used SVIDs are removed from cache only after the cache expiry period has passed. +// This is done to reduce the overall cache churn. +// 3. Last access timestamp for SVID cache entry is updated when a new subscriber is created +// 4. When a new subscriber is created and there is a cache miss +// then subscriber needs to wait for next SVID sync event to receive WorkloadUpdate with newly minted SVID +// +// The advantage of above approach is that if agent has entry count less than cache size +// then all SVIDs are cached at all times. If agent has entry count greater than cache size then +// subscribers will continue to get SVID updates (potential delay for first WorkloadUpdate if cache miss) +// and least used SVIDs will be removed from cache which will save memory usage. +// This allows agent to support environments where the active simultaneous workload count +// is a small percentage of the large number of registrations assigned to the agent. +// +// When registration entries are added/updated/removed, the set of relevant +// selectors are gathered and the indexes for those selectors are combed for +// all relevant subscribers. +// +// For each relevant subscriber, the selector index for each selector of the +// subscriber is combed for registration whose selectors are a subset of the +// subscriber selector set. Identities for those entries are added to the +// workload update returned to the subscriber. +// +// NOTE: The cache is intended to be able to handle thousands of workload +// subscriptions, which can involve thousands of certificates, keys, bundles, +// and registration entries, etc. The selector index itself is intended to be +// scalable, but the objects themselves can take a considerable amount of +// memory. For maximal safety, the objects should be cloned both coming in and +// leaving the cache. However, during global updates (e.g. trust bundle is +// updated for the agent trust domain) in particular, cloning all of the +// relevant objects for each subscriber causes HUGE amounts of memory pressure +// which adds non-trivial amounts of latency and causes a giant memory spike +// that could OOM the agent on smaller VMs. For this reason, the cache is +// presumed to own ALL data passing in and out of the cache. Producers and +// consumers MUST NOT mutate the data. 
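A minimal construction sketch of the behavior described above; newWorkloadCache is an illustrative helper (not a SPIRE API), and the logger, trust bundle, and metrics sink are assumed to exist in the caller. The size argument is the experimental x509_svid_cache_max_size value plumbed through agent.Config earlier in this diff:

// newWorkloadCache is an illustrative wrapper around NewLRUCache.
func newWorkloadCache(log logrus.FieldLogger, td spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, x509SVIDCacheMaxSize int) *LRUCache {
	// A zero or negative size (knob unset) falls back to DefaultSVIDCacheMaxSize (1000).
	// The limit is soft: SVIDs whose selectors still have an active subscriber are
	// never evicted, and eviction considers the least recently accessed SVIDs first.
	return NewLRUCache(log, td, bundle, metrics, x509SVIDCacheMaxSize, clock.New())
}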
+type LRUCache struct { + *BundleCache + *JWTSVIDCache + + log logrus.FieldLogger + trustDomain spiffeid.TrustDomain + clk clock.Clock + + metrics telemetry.Metrics + + mu sync.RWMutex + + // records holds the records for registration entries, keyed by registration entry ID + records map[string]*lruCacheRecord + + // selectors holds the selector indices, keyed by a selector key + selectors map[selector]*selectorsMapIndex + + // staleEntries holds stale or new registration entries which require new SVID to be stored in cache + staleEntries map[string]bool + + // bundles holds the trust bundles, keyed by trust domain id (i.e. "spiffe://domain.test") + bundles map[spiffeid.TrustDomain]*bundleutil.Bundle + + // svids are stored by entry IDs + svids map[string]*X509SVID + + // svidCacheMaxSize is a soft limit of max number of SVIDs that would be stored in cache + svidCacheMaxSize int + subscribeBackoffFn func() backoff.BackOff +} + +func NewLRUCache(log logrus.FieldLogger, trustDomain spiffeid.TrustDomain, bundle *Bundle, metrics telemetry.Metrics, + svidCacheMaxSize int, clk clock.Clock) *LRUCache { + if svidCacheMaxSize <= 0 { + svidCacheMaxSize = DefaultSVIDCacheMaxSize + } + + return &LRUCache{ + BundleCache: NewBundleCache(trustDomain, bundle), + JWTSVIDCache: NewJWTSVIDCache(), + + log: log, + metrics: metrics, + trustDomain: trustDomain, + records: make(map[string]*lruCacheRecord), + selectors: make(map[selector]*selectorsMapIndex), + staleEntries: make(map[string]bool), + bundles: map[spiffeid.TrustDomain]*bundleutil.Bundle{ + trustDomain: bundle, + }, + svids: make(map[string]*X509SVID), + svidCacheMaxSize: svidCacheMaxSize, + clk: clk, + subscribeBackoffFn: func() backoff.BackOff { + return backoff.NewBackoff(clk, SVIDSyncInterval) + }, + } +} + +// Identities is only used by manager tests +// TODO: We should remove this and find a better way +func (c *LRUCache) Identities() []Identity { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]Identity, 0, len(c.records)) + for _, record := range c.records { + svid, ok := c.svids[record.entry.EntryId] + if !ok { + // The record does not have an SVID yet and should not be returned + // from the cache. + continue + } + out = append(out, makeNewIdentity(record, svid)) + } + sortIdentities(out) + return out +} + +func (c *LRUCache) Entries() []*common.RegistrationEntry { + c.mu.RLock() + defer c.mu.RUnlock() + + out := make([]*common.RegistrationEntry, 0, len(c.records)) + for _, record := range c.records { + out = append(out, record.entry) + } + sortEntriesByID(out) + return out +} + +func (c *LRUCache) CountSVIDs() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.svids) +} + +func (c *LRUCache) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + + c.mu.RLock() + defer c.mu.RUnlock() + return c.matchingEntries(set) +} + +func (c *LRUCache) FetchWorkloadUpdate(selectors []*common.Selector) *WorkloadUpdate { + set, setDone := allocSelectorSet(selectors...) + defer setDone() + + c.mu.RLock() + defer c.mu.RUnlock() + return c.buildWorkloadUpdate(set) +} + +// NewSubscriber creates a subscriber for given selector set. +// Separately call Notify for the first time after this method is invoked to receive latest updates. 
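The two-step contract in the comment above, sketched with an illustrative helper; subscribeToWorkloadUpdates later in this file wraps these same steps with a backoff and a context, so this is a simplification, not the real call path:

func waitForFirstUpdate(c *LRUCache, selectors []*common.Selector) Subscriber {
	sub := c.NewSubscriber(selectors)
	for !c.Notify(selectors) {
		// Not every SVID for these selectors is cached yet; the periodic sync
		// marks them stale and mints them, so retry until Notify succeeds.
		time.Sleep(SVIDSyncInterval)
	}
	<-sub.Updates() // the successful Notify delivered the initial WorkloadUpdate
	return sub
}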
+func (c *LRUCache) NewSubscriber(selectors []*common.Selector) Subscriber { + c.mu.Lock() + defer c.mu.Unlock() + + sub := newLRUCacheSubscriber(c, selectors) + for s := range sub.set { + c.addSelectorIndexSub(s, sub) + } + // update lastAccessTimestamp of records containing provided selectors + c.updateLastAccessTimestamp(selectors) + return sub +} + +// UpdateEntries updates the cache with the provided registration entries and bundles and +// notifies impacted subscribers. The checkSVID callback, if provided, is used to determine +// if the SVID for the entry is stale, or otherwise in need of rotation. Entries marked stale +// through the checkSVID callback are returned from GetStaleEntries() until the SVID is +// updated through a call to UpdateSVIDs. +func (c *LRUCache) UpdateEntries(update *UpdateEntries, checkSVID func(*common.RegistrationEntry, *common.RegistrationEntry, *X509SVID) bool) { + c.mu.Lock() + defer c.mu.Unlock() + + // Remove bundles that no longer exist. The bundle for the agent trust + // domain should NOT be removed even if not present (which should only be + // the case if there is a bug on the server) since it is necessary to + // authenticate the server. + bundleRemoved := false + for id := range c.bundles { + if _, ok := update.Bundles[id]; !ok && id != c.trustDomain { + bundleRemoved = true + // bundle no longer exists. + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle removed") + delete(c.bundles, id) + } + } + + // Update bundles with changes, populating a "changed" set that we can + // check when processing registration entries to know if they need to spawn + // a notification. + bundleChanged := make(map[spiffeid.TrustDomain]bool) + for id, bundle := range update.Bundles { + existing, ok := c.bundles[id] + if !(ok && existing.EqualTo(bundle)) { + if !ok { + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle added") + } else { + c.log.WithField(telemetry.TrustDomainID, id).Debug("Bundle updated") + } + bundleChanged[id] = true + c.bundles[id] = bundle + } + } + trustDomainBundleChanged := bundleChanged[c.trustDomain] + + // Allocate sets from the pool to track changes to selectors and + // federatesWith declarations. These sets must be cleared after EACH use + // and returned to their respective pools when done processing the + // updates. + notifySets := make([]selectorSet, 0) + selAdd, selAddDone := allocSelectorSet() + defer selAddDone() + selRem, selRemDone := allocSelectorSet() + defer selRemDone() + fedAdd, fedAddDone := allocStringSet() + defer fedAddDone() + fedRem, fedRemDone := allocStringSet() + defer fedRemDone() + + // Remove records for registration entries that no longer exist + for id, record := range c.records { + if _, ok := update.RegistrationEntries[id]; !ok { + c.log.WithFields(logrus.Fields{ + telemetry.Entry: id, + telemetry.SPIFFEID: record.entry.SpiffeId, + }).Debug("Entry removed") + + // built a set of selectors for the record being removed, drop the + // record for each selector index, and add the entry selectors to + // the notify set. + clearSelectorSet(selRem) + selRem.Merge(record.entry.Selectors...) + c.delSelectorIndicesRecord(selRem, record) + notifySets = append(notifySets, selRem) + delete(c.records, id) + delete(c.svids, id) + // Remove stale entry since, registration entry is no longer on cache. 
+ delete(c.staleEntries, id) + } + } + + outdatedEntries := make(map[string]struct{}) + + // Add/update records for registration entries in the update + for _, newEntry := range update.RegistrationEntries { + clearSelectorSet(selAdd) + clearSelectorSet(selRem) + clearStringSet(fedAdd) + clearStringSet(fedRem) + + record, existingEntry := c.updateOrCreateRecord(newEntry) + + // Calculate the difference in selectors, add/remove the record + // from impacted selector indices, and add the selector diff to the + // notify set. + c.diffSelectors(existingEntry, newEntry, selAdd, selRem) + selectorsChanged := len(selAdd) > 0 || len(selRem) > 0 + c.addSelectorIndicesRecord(selAdd, record) + c.delSelectorIndicesRecord(selRem, record) + + // Determine if there were changes to FederatesWith declarations or + // if any federated bundles related to the entry were updated. + c.diffFederatesWith(existingEntry, newEntry, fedAdd, fedRem) + federatedBundlesChanged := len(fedAdd) > 0 || len(fedRem) > 0 + if !federatedBundlesChanged { + for _, id := range newEntry.FederatesWith { + td, err := spiffeid.TrustDomainFromString(id) + if err != nil { + c.log.WithFields(logrus.Fields{ + telemetry.TrustDomainID: id, + logrus.ErrorKey: err, + }).Warn("Invalid federated trust domain") + continue + } + if bundleChanged[td] { + federatedBundlesChanged = true + break + } + } + } + + // If any selectors or federated bundles were changed, then make + // sure subscribers for the new and extisting entry selector sets + // are notified. + if selectorsChanged { + if existingEntry != nil { + notifySet, selSetDone := allocSelectorSet() + defer selSetDone() + notifySet.Merge(existingEntry.Selectors...) + notifySets = append(notifySets, notifySet) + } + } + + if federatedBundlesChanged || selectorsChanged { + notifySet, selSetDone := allocSelectorSet() + defer selSetDone() + notifySet.Merge(newEntry.Selectors...) 
+ notifySets = append(notifySets, notifySet) + } + + // Identify stale/outdated entries + if existingEntry != nil && existingEntry.RevisionNumber != newEntry.RevisionNumber { + outdatedEntries[newEntry.EntryId] = struct{}{} + } + + // Log all the details of the update to the DEBUG log + if federatedBundlesChanged || selectorsChanged { + log := c.log.WithFields(logrus.Fields{ + telemetry.Entry: newEntry.EntryId, + telemetry.SPIFFEID: newEntry.SpiffeId, + }) + if len(selAdd) > 0 { + log = log.WithField(telemetry.SelectorsAdded, len(selAdd)) + } + if len(selRem) > 0 { + log = log.WithField(telemetry.SelectorsRemoved, len(selRem)) + } + if len(fedAdd) > 0 { + log = log.WithField(telemetry.FederatedAdded, len(fedAdd)) + } + if len(fedRem) > 0 { + log = log.WithField(telemetry.FederatedRemoved, len(fedRem)) + } + if existingEntry != nil { + log.Debug("Entry updated") + } else { + log.Debug("Entry created") + } + } + } + + // entries with active subscribers which are not cached will be put in staleEntries map; + // irrespective of what svid cache size as we cannot deny identity to a subscriber + activeSubsByEntryID, recordsWithLastAccessTime := c.syncSVIDsWithSubscribers() + extraSize := len(c.svids) - c.svidCacheMaxSize + + // delete svids without subscribers and which have not been accessed since svidCacheExpiryTime + if extraSize > 0 { + // sort recordsWithLastAccessTime + sortByTimestamps(recordsWithLastAccessTime) + + for _, record := range recordsWithLastAccessTime { + if extraSize <= 0 { + // no need to delete SVIDs any further as cache size <= svidCacheMaxSize + break + } + if _, ok := c.svids[record.id]; ok { + if _, exists := activeSubsByEntryID[record.id]; !exists { + // remove svid + c.log.WithField("record_id", record.id). + WithField("record_timestamp", record.timestamp). + Debug("Removing SVID record") + delete(c.svids, record.id) + extraSize-- + } + } + } + } + + // Update all stale svids or svids whose registration entry is outdated + for id, svid := range c.svids { + if _, ok := outdatedEntries[id]; ok || (checkSVID != nil && checkSVID(nil, c.records[id].entry, svid)) { + c.staleEntries[id] = true + } + } + c.log.WithField(telemetry.OutdatedSVIDs, len(outdatedEntries)). + Debug("Updating SVIDs with outdated attributes in cache") + + if bundleRemoved || len(bundleChanged) > 0 { + c.BundleCache.Update(c.bundles) + } + + if trustDomainBundleChanged { + c.notifyAll() + } else { + c.notifyBySelectorSet(notifySets...) + } +} + +func (c *LRUCache) UpdateSVIDs(update *UpdateSVIDs) { + c.mu.Lock() + defer c.mu.Unlock() + + // Allocate a set of selectors that + notifySet, selSetDone := allocSelectorSet() + defer selSetDone() + + // Add/update records for registration entries in the update + for entryID, svid := range update.X509SVIDs { + record, existingEntry := c.records[entryID] + if !existingEntry { + c.log.WithField(telemetry.RegistrationID, entryID).Error("Entry not found") + continue + } + + c.svids[entryID] = svid + notifySet.Merge(record.entry.Selectors...) 
+ log := c.log.WithFields(logrus.Fields{ + telemetry.Entry: record.entry.EntryId, + telemetry.SPIFFEID: record.entry.SpiffeId, + }) + log.Debug("SVID updated") + + // Registration entry is updated, remove it from stale map + delete(c.staleEntries, entryID) + c.notifyBySelectorSet(notifySet) + clearSelectorSet(notifySet) + } +} + +// GetStaleEntries obtains a list of stale entries +func (c *LRUCache) GetStaleEntries() []*StaleEntry { + c.mu.Lock() + defer c.mu.Unlock() + + var staleEntries []*StaleEntry + for entryID := range c.staleEntries { + cachedEntry, ok := c.records[entryID] + if !ok { + c.log.WithField(telemetry.RegistrationID, entryID).Debug("Stale marker found for unknown entry. Please fill a bug") + delete(c.staleEntries, entryID) + continue + } + + var expiresAt time.Time + if cachedSvid, ok := c.svids[entryID]; ok { + expiresAt = cachedSvid.Chain[0].NotAfter + } + + staleEntries = append(staleEntries, &StaleEntry{ + Entry: cachedEntry.entry, + ExpiresAt: expiresAt, + }) + } + + return staleEntries +} + +// SyncSVIDsWithSubscribers will sync svid cache: +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *LRUCache) SyncSVIDsWithSubscribers() { + c.mu.Lock() + defer c.mu.Unlock() + + c.syncSVIDsWithSubscribers() +} + +// Notify subscribers of selector set only if all SVIDs for corresponding selector set are cached +// It returns whether all SVIDs are cached or not. +// This method should be retried with backoff to avoid lock contention. +func (c *LRUCache) Notify(selectors []*common.Selector) bool { + c.mu.RLock() + defer c.mu.RUnlock() + set, setFree := allocSelectorSet(selectors...) + defer setFree() + if !c.missingSVIDRecords(set) { + c.notifyBySelectorSet(set) + return true + } + return false +} + +func (c *LRUCache) SubscribeToWorkloadUpdates(ctx context.Context, selectors Selectors) (Subscriber, error) { + return c.subscribeToWorkloadUpdates(ctx, selectors, nil) +} + +func (c *LRUCache) subscribeToWorkloadUpdates(ctx context.Context, selectors Selectors, notifyCallbackFn func()) (Subscriber, error) { + subscriber := c.NewSubscriber(selectors) + bo := c.subscribeBackoffFn() + // block until all svids are cached and subscriber is notified + for { + // notifyCallbackFn is used for testing + if c.Notify(selectors) { + if notifyCallbackFn != nil { + notifyCallbackFn() + } + return subscriber, nil + } + c.log.WithField(telemetry.Selectors, selectors).Info("Waiting for SVID to get cached") + // used for testing + if notifyCallbackFn != nil { + notifyCallbackFn() + } + + select { + case <-ctx.Done(): + subscriber.Finish() + return nil, ctx.Err() + case <-c.clk.After(bo.NextBackOff()): + } + } +} + +func (c *LRUCache) missingSVIDRecords(set selectorSet) bool { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + for record := range records { + if _, exists := c.svids[record.entry.EntryId]; !exists { + return true + } + } + return false +} + +func (c *LRUCache) updateLastAccessTimestamp(selectors []*common.Selector) { + set, setFree := allocSelectorSet(selectors...) 
+ defer setFree() + + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + now := c.clk.Now().UnixMilli() + for record := range records { + // Set lastAccessTimestamp so that svid LRU cache can be cleaned based on this timestamp + record.lastAccessTimestamp = now + } +} + +// entries with active subscribers which are not cached will be put in staleEntries map +// records which are not cached for remainder of max cache size will also be put in staleEntries map +func (c *LRUCache) syncSVIDsWithSubscribers() (map[string]struct{}, []recordAccessEvent) { + activeSubsByEntryID := make(map[string]struct{}) + lastAccessTimestamps := make([]recordAccessEvent, 0, len(c.records)) + + // iterate over all selectors from cached entries and obtain: + // 1. entries that have active subscribers + // 1.1 if those entries don't have corresponding SVID cached then put them in staleEntries + // so that SVID will be cached in next sync + // 2. get lastAccessTimestamp of each entry + for id, record := range c.records { + for _, sel := range record.entry.Selectors { + if index, ok := c.selectors[makeSelector(sel)]; ok && index != nil { + if len(index.subs) > 0 { + if _, ok := c.svids[record.entry.EntryId]; !ok { + c.staleEntries[id] = true + } + activeSubsByEntryID[id] = struct{}{} + break + } + } + } + lastAccessTimestamps = append(lastAccessTimestamps, newRecordAccessEvent(record.lastAccessTimestamp, id)) + } + + remainderSize := c.svidCacheMaxSize - len(c.svids) + // add records which are not cached for remainder of cache size + for id := range c.records { + if len(c.staleEntries) >= remainderSize { + break + } + if _, svidCached := c.svids[id]; !svidCached { + if _, ok := c.staleEntries[id]; !ok { + c.staleEntries[id] = true + } + } + } + + return activeSubsByEntryID, lastAccessTimestamps +} + +func (c *LRUCache) updateOrCreateRecord(newEntry *common.RegistrationEntry) (*lruCacheRecord, *common.RegistrationEntry) { + var existingEntry *common.RegistrationEntry + record, recordExists := c.records[newEntry.EntryId] + if !recordExists { + record = newLRUCacheRecord() + c.records[newEntry.EntryId] = record + } else { + existingEntry = record.entry + } + record.entry = newEntry + return record, existingEntry +} + +func (c *LRUCache) diffSelectors(existingEntry, newEntry *common.RegistrationEntry, added, removed selectorSet) { + // Make a set of all the selectors being added + if newEntry != nil { + added.Merge(newEntry.Selectors...) + } + + // Make a set of all the selectors that are being removed + if existingEntry != nil { + for _, selector := range existingEntry.Selectors { + s := makeSelector(selector) + if _, ok := added[s]; ok { + // selector already exists in entry + delete(added, s) + } else { + // selector has been removed from entry + removed[s] = struct{}{} + } + } + } +} + +func (c *LRUCache) diffFederatesWith(existingEntry, newEntry *common.RegistrationEntry, added, removed stringSet) { + // Make a set of all the selectors being added + if newEntry != nil { + added.Merge(newEntry.FederatesWith...) 
+ } + + // Make a set of all the selectors that are being removed + if existingEntry != nil { + for _, id := range existingEntry.FederatesWith { + if _, ok := added[id]; ok { + // Bundle already exists in entry + delete(added, id) + } else { + // Bundle has been removed from entry + removed[id] = struct{}{} + } + } + } +} + +func (c *LRUCache) addSelectorIndicesRecord(selectors selectorSet, record *lruCacheRecord) { + for selector := range selectors { + c.addSelectorIndexRecord(selector, record) + } +} + +func (c *LRUCache) addSelectorIndexRecord(s selector, record *lruCacheRecord) { + index := c.getSelectorIndexForWrite(s) + index.records[record] = struct{}{} +} + +func (c *LRUCache) delSelectorIndicesRecord(selectors selectorSet, record *lruCacheRecord) { + for selector := range selectors { + c.delSelectorIndexRecord(selector, record) + } +} + +// delSelectorIndexRecord removes the record from the selector index. If +// the selector index is empty afterwards, it is also removed. +func (c *LRUCache) delSelectorIndexRecord(s selector, record *lruCacheRecord) { + index, ok := c.selectors[s] + if ok { + delete(index.records, record) + if index.isEmpty() { + delete(c.selectors, s) + } + } +} + +func (c *LRUCache) addSelectorIndexSub(s selector, sub *lruCacheSubscriber) { + index := c.getSelectorIndexForWrite(s) + index.subs[sub] = struct{}{} +} + +// delSelectorIndexSub removes the subscription from the selector index. If +// the selector index is empty afterwards, it is also removed. +func (c *LRUCache) delSelectorIndexSub(s selector, sub *lruCacheSubscriber) { + index, ok := c.selectors[s] + if ok { + delete(index.subs, sub) + if index.isEmpty() { + delete(c.selectors, s) + } + } +} + +func (c *LRUCache) unsubscribe(sub *lruCacheSubscriber) { + c.mu.Lock() + defer c.mu.Unlock() + for selector := range sub.set { + c.delSelectorIndexSub(selector, sub) + } +} + +func (c *LRUCache) notifyAll() { + subs, subsDone := c.allSubscribers() + defer subsDone() + for sub := range subs { + c.notify(sub) + } +} + +func (c *LRUCache) notifyBySelectorSet(sets ...selectorSet) { + notifiedSubs, notifiedSubsDone := allocLRUCacheSubscriberSet() + defer notifiedSubsDone() + for _, set := range sets { + subs, subsDone := c.getSubscribers(set) + defer subsDone() + for sub := range subs { + if _, notified := notifiedSubs[sub]; !notified && sub.set.SuperSetOf(set) { + c.notify(sub) + notifiedSubs[sub] = struct{}{} + } + } + } +} + +func (c *LRUCache) notify(sub *lruCacheSubscriber) { + update := c.buildWorkloadUpdate(sub.set) + sub.notify(update) +} + +func (c *LRUCache) allSubscribers() (lruCacheSubscriberSet, func()) { + subs, subsDone := allocLRUCacheSubscriberSet() + for _, index := range c.selectors { + for sub := range index.subs { + subs[sub] = struct{}{} + } + } + return subs, subsDone +} + +func (c *LRUCache) getSubscribers(set selectorSet) (lruCacheSubscriberSet, func()) { + subs, subsDone := allocLRUCacheSubscriberSet() + for s := range set { + if index := c.getSelectorIndexForRead(s); index != nil { + for sub := range index.subs { + subs[sub] = struct{}{} + } + } + } + return subs, subsDone +} + +func (c *LRUCache) matchingIdentities(set selectorSet) []Identity { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. 
+ // TODO: figure out how to determine the "default" identity + out := make([]Identity, 0, len(records)) + for record := range records { + if svid, ok := c.svids[record.entry.EntryId]; ok { + out = append(out, makeNewIdentity(record, svid)) + } + } + sortIdentities(out) + return out +} + +func (c *LRUCache) matchingEntries(set selectorSet) []*common.RegistrationEntry { + records, recordsDone := c.getRecordsForSelectors(set) + defer recordsDone() + + if len(records) == 0 { + return nil + } + + // Return identities in ascending "entry id" order to maintain a consistent + // ordering. + // TODO: figure out how to determine the "default" identity + out := make([]*common.RegistrationEntry, 0, len(records)) + for record := range records { + out = append(out, record.entry) + } + sortEntriesByID(out) + return out +} + +func (c *LRUCache) buildWorkloadUpdate(set selectorSet) *WorkloadUpdate { + w := &WorkloadUpdate{ + Bundle: c.bundles[c.trustDomain], + FederatedBundles: make(map[spiffeid.TrustDomain]*bundleutil.Bundle), + Identities: c.matchingIdentities(set), + } + + // Add in the bundles the workload is federated with. + for _, identity := range w.Identities { + for _, federatesWith := range identity.Entry.FederatesWith { + td, err := spiffeid.TrustDomainFromString(federatesWith) + if err != nil { + c.log.WithFields(logrus.Fields{ + telemetry.TrustDomainID: federatesWith, + logrus.ErrorKey: err, + }).Warn("Invalid federated trust domain") + continue + } + if federatedBundle := c.bundles[td]; federatedBundle != nil { + w.FederatedBundles[td] = federatedBundle + } else { + c.log.WithFields(logrus.Fields{ + telemetry.RegistrationID: identity.Entry.EntryId, + telemetry.SPIFFEID: identity.Entry.SpiffeId, + telemetry.FederatedBundle: federatesWith, + }).Warn("Federated bundle contents missing") + } + } + } + + return w +} + +func (c *LRUCache) getRecordsForSelectors(set selectorSet) (lruCacheRecordSet, func()) { + // Build and dedup a list of candidate entries. Don't check for selector set inclusion yet, since + // that is a more expensive operation and we could easily have duplicate + // entries to check. + records, recordsDone := allocLRUCacheRecordSet() + for selector := range set { + if index := c.getSelectorIndexForRead(selector); index != nil { + for record := range index.records { + records[record] = struct{}{} + } + } + } + + // Filter out records whose registration entry selectors are not within + // inside the selector set. + for record := range records { + for _, s := range record.entry.Selectors { + if !set.In(s) { + delete(records, record) + } + } + } + return records, recordsDone +} + +// getSelectorIndexForWrite gets the selector index for the selector. If one +// doesn't exist, it is created. Callers must hold the write lock. If the index +// is only being read, then getSelectorIndexForRead should be used instead. +func (c *LRUCache) getSelectorIndexForWrite(s selector) *selectorsMapIndex { + index, ok := c.selectors[s] + if !ok { + index = newSelectorsMapIndex() + c.selectors[s] = index + } + return index +} + +// getSelectorIndexForRead gets the selector index for the selector. If one +// doesn't exist, nil is returned. Callers should hold the read or write lock. +// If the index is being modified, callers should use getSelectorIndexForWrite +// instead. 
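For the eviction order, a toy illustration using the recordAccessEvent and sortByTimestamps helpers defined further down in this file; the entry IDs and timestamps are made up:

func exampleEvictionOrder() []recordAccessEvent {
	events := []recordAccessEvent{
		newRecordAccessEvent(30, "entry-c"),
		newRecordAccessEvent(10, "entry-a"),
		newRecordAccessEvent(20, "entry-b"),
	}
	sortByTimestamps(events)
	// Now ordered entry-a, entry-b, entry-c: UpdateEntries walks the least
	// recently accessed records first and evicts an SVID only if its record
	// has no active subscriber and the cache is above its soft size limit.
	return events
}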
+func (c *LRUCache) getSelectorIndexForRead(s selector) *selectorsMapIndex { + if index, ok := c.selectors[s]; ok { + return index + } + return nil +} + +type lruCacheRecord struct { + entry *common.RegistrationEntry + subs map[*lruCacheSubscriber]struct{} + lastAccessTimestamp int64 +} + +func newLRUCacheRecord() *lruCacheRecord { + return &lruCacheRecord{ + subs: make(map[*lruCacheSubscriber]struct{}), + } +} + +type selectorsMapIndex struct { + // subs holds the subscriptions related to this selector + subs map[*lruCacheSubscriber]struct{} + + // records holds the cache records related to this selector + records map[*lruCacheRecord]struct{} +} + +func (x *selectorsMapIndex) isEmpty() bool { + return len(x.subs) == 0 && len(x.records) == 0 +} + +func newSelectorsMapIndex() *selectorsMapIndex { + return &selectorsMapIndex{ + subs: make(map[*lruCacheSubscriber]struct{}), + records: make(map[*lruCacheRecord]struct{}), + } +} + +func sortByTimestamps(records []recordAccessEvent) { + sort.Slice(records, func(a, b int) bool { + return records[a].timestamp < records[b].timestamp + }) +} + +func makeNewIdentity(record *lruCacheRecord, svid *X509SVID) Identity { + return Identity{ + Entry: record.entry, + SVID: svid.Chain, + PrivateKey: svid.PrivateKey, + } +} + +type recordAccessEvent struct { + timestamp int64 + id string +} + +func newRecordAccessEvent(timestamp int64, id string) recordAccessEvent { + return recordAccessEvent{timestamp: timestamp, id: id} +} diff --git a/pkg/agent/manager/cache/lru_cache_subscriber.go b/pkg/agent/manager/cache/lru_cache_subscriber.go new file mode 100644 index 0000000000..00556f89a9 --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache_subscriber.go @@ -0,0 +1,60 @@ +package cache + +import ( + "sync" + + "github.com/spiffe/spire/proto/spire/common" +) + +type lruCacheSubscriber struct { + cache *LRUCache + set selectorSet + setFree func() + + mu sync.Mutex + c chan *WorkloadUpdate + done bool +} + +func newLRUCacheSubscriber(cache *LRUCache, selectors []*common.Selector) *lruCacheSubscriber { + set, setFree := allocSelectorSet(selectors...) 
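Each lruCacheRecord above carries a lastAccessTimestamp, and sortByTimestamps orders access events ascending, so the least recently used entries surface first among candidates when the cache must shed SVIDs beyond its configured size. The eviction pass itself is outside this excerpt; the sketch below only illustrates the ordering, with made-up entry IDs and timestamps.

package main

import (
	"fmt"
	"sort"
)

// accessEvent mirrors recordAccessEvent: when an SVID was last handed out and
// which entry it belongs to.
type accessEvent struct {
	timestamp int64
	id        string
}

func main() {
	events := []accessEvent{
		{timestamp: 300, id: "entry-c"},
		{timestamp: 100, id: "entry-a"},
		{timestamp: 200, id: "entry-b"},
	}

	// Same comparison as sortByTimestamps: ascending by last access, so the
	// least recently used entries come first.
	sort.Slice(events, func(a, b int) bool {
		return events[a].timestamp < events[b].timestamp
	})

	for _, e := range events {
		fmt.Println(e.id, e.timestamp) // entry-a 100, entry-b 200, entry-c 300
	}
}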
+ return &lruCacheSubscriber{ + cache: cache, + set: set, + setFree: setFree, + c: make(chan *WorkloadUpdate, 1), + } +} + +func (s *lruCacheSubscriber) Updates() <-chan *WorkloadUpdate { + return s.c +} + +func (s *lruCacheSubscriber) Finish() { + s.mu.Lock() + done := s.done + if !done { + s.done = true + close(s.c) + } + s.mu.Unlock() + if !done { + s.cache.unsubscribe(s) + s.setFree() + s.set = nil + } +} + +func (s *lruCacheSubscriber) notify(update *WorkloadUpdate) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + + select { + case <-s.c: + default: + } + s.c <- update +} diff --git a/pkg/agent/manager/cache/lru_cache_test.go b/pkg/agent/manager/cache/lru_cache_test.go new file mode 100644 index 0000000000..8fd5ea2bce --- /dev/null +++ b/pkg/agent/manager/cache/lru_cache_test.go @@ -0,0 +1,954 @@ +package cache + +import ( + "context" + "crypto/x509" + "fmt" + "runtime" + "testing" + "time" + + "github.com/andres-erbsen/clock" + "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/bundleutil" + "github.com/spiffe/spire/pkg/common/telemetry" + "github.com/spiffe/spire/proto/spire/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLRUCacheFetchWorkloadUpdate(t *testing.T) { + cache := newTestLRUCache() + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + bar.FederatesWith = makeFederatesWith(otherBundleV1) + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + workloadUpdate := cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + assert.Len(t, workloadUpdate.Identities, 0, "identities should not be returned that don't have SVIDs") + + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + + workloadUpdate = cache.FetchWorkloadUpdate(makeSelectors("A", "B")) + assert.Equal(t, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV1), + Identities: []Identity{ + {Entry: bar}, + {Entry: foo}, + }, + }, workloadUpdate) +} + +func TestLRUCacheMatchingRegistrationIdentities(t *testing.T) { + cache := newTestLRUCache() + + // populate the cache with FOO and BAR without SVIDS + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + + // Update SVIDs and MatchingRegistrationEntries should return both entries + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + } + cache.UpdateSVIDs(updateSVIDs) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) + + // Remove SVIDs and MatchingRegistrationEntries should still return both entries + cache.UpdateSVIDs(&UpdateSVIDs{}) + assert.Equal(t, []*common.RegistrationEntry{bar, foo}, + cache.MatchingRegistrationEntries(makeSelectors("A", "B"))) +} + +func TestLRUCacheCountSVIDs(t *testing.T) { + cache := newTestLRUCache() + + // populate the cache with FOO and BAR without SVIDS + foo := 
makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + } + cache.UpdateEntries(updateEntries, nil) + + // No SVIDs expected + require.Equal(t, 0, cache.CountSVIDs()) + + updateSVIDs := &UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + } + cache.UpdateSVIDs(updateSVIDs) + + // Only one SVID expected + require.Equal(t, 1, cache.CountSVIDs()) +} + +func TestLRUCacheBundleChanges(t *testing.T) { + cache := newTestLRUCache() + + bundleStream := cache.SubscribeToBundleChanges() + assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + }, nil) + if assert.True(t, bundleStream.HasNext(), "has new bundle value after adding bundle") { + bundleStream.Next() + assert.Equal(t, makeBundles(bundleV1, otherBundleV1), bundleStream.Value()) + } + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + }, nil) + + if assert.True(t, bundleStream.HasNext(), "has new bundle value after removing bundle") { + bundleStream.Next() + assert.Equal(t, makeBundles(bundleV1), bundleStream.Value()) + } +} + +func TestLRUCacheAllSubscribersNotifiedOnBundleChange(t *testing.T) { + cache := newTestLRUCache() + + // create some subscribers and assert they get the initial bundle + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV1}) + + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV1}) + + // update the bundle and assert all subscribers gets the updated bundle + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{Bundle: bundleV2}) + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{Bundle: bundleV2}) +} + +func TestLRUCacheSomeSubscribersNotifiedOnFederatedBundleChange(t *testing.T) { + cache := newTestLRUCache() + + // initialize the cache with an entry FOO that has a valid SVID and + // selector "A" + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + // subscribe to A and B and assert initial updates are received. + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) + + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) + + // add the federated bundle with no registration entries federating with + // it and make sure nobody is notified. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertNoWorkloadUpdate(t, subA) + assertNoWorkloadUpdate(t, subB) + + // update FOO to federate with otherdomain.test and make sure subA is + // notified but not subB. 
+ foo = makeRegistrationEntry("FOO", "A") + foo.FederatesWith = makeFederatesWith(otherBundleV1) + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV1), + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + // now change the federated bundle and make sure subA gets notified, but + // again, not subB. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + FederatedBundles: makeBundles(otherBundleV2), + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + // now drop the federation and make sure subA is again notified and no + // longer has the federated bundle. + foo = makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1, otherBundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) +} + +func TestLRUCacheSubscribersGetEntriesWithSelectorSubsets(t *testing.T) { + cache := newTestLRUCache() + + // create subscribers for each combination of selectors + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + subAB := subscribeToWorkloadUpdates(t, cache, makeSelectors("A", "B")) + defer subAB.Finish() + + // assert all subscribers get the initial update + initialUpdate := &WorkloadUpdate{Bundle: bundleV1} + assertWorkloadUpdateEqual(t, subA, initialUpdate) + assertWorkloadUpdateEqual(t, subB, initialUpdate) + assertWorkloadUpdateEqual(t, subAB, initialUpdate) + + // create entry FOO that will target any subscriber with containing (A) + foo := makeRegistrationEntry("FOO", "A") + + // create entry BAR that will target any subscriber with containing (A,C) + bar := makeRegistrationEntry("BAR", "A", "C") + + // update the cache with foo and bar + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + // subA selector set contains (A), but not (A, C), so it should only get FOO + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + + // subB selector set does not contain either (A) or (A,C) so it isn't even + // notified. 
+ assertNoWorkloadUpdate(t, subB) + + // subAB selector set contains (A) but not (A, C), so it should get FOO + assertWorkloadUpdateEqual(t, subAB, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberIsNotNotifiedIfNothingChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Second update is the same (other than X509SVIDs, which, when set, + // always constitute a "change" for the impacted registration entries. + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSubscriberNotifiedOnSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberNotificationsOnSelectorChanges(t *testing.T) { + cache := newTestLRUCache() + + // initialize the cache with a FOO entry with selector A and an SVID + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + // create subscribers for A and make sure the initial update has FOO + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + + // update FOO to have selectors (A,B) and make sure the subscriber loses + // FOO, since (A,B) is not a subset of the subscriber set (A). 
+ foo = makeRegistrationEntry("FOO", "A", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + }) + + // update FOO to drop B and make sure the subscriber regains FOO + foo = makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscriberNotifiedWhenEntryDropped(t *testing.T) { + cache := newTestLRUCache() + + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer subA.Finish() + assertAnyWorkloadUpdate(t, subA) + + // subB's job here is to just make sure we don't notify unrelated + // subscribers when dropping registration entries + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertAnyWorkloadUpdate(t, subB) + + foo := makeRegistrationEntry("FOO", "A") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + } + cache.UpdateEntries(updateEntries, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + // make sure subA gets notified with FOO but not subB + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + assertNoWorkloadUpdate(t, subB) + + updateEntries.RegistrationEntries = nil + cache.UpdateEntries(updateEntries, nil) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + }) + assertNoWorkloadUpdate(t, subB) + + // Make sure trying to update SVIDs of removed entry does not notify + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertNoWorkloadUpdate(t, subB) +} + +func TestLRUCacheSubscriberOnlyGetsEntriesWithSVID(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + updateEntries := &UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + } + cache.UpdateEntries(updateEntries, nil) + + sub := cache.NewSubscriber(makeSelectors("A")) + defer sub.Finish() + assertNoWorkloadUpdate(t, sub) + + // update to include the SVID and now we should get the update + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) +} + +func TestLRUCacheSubscribersDoNotBlockNotifications(t *testing.T) { + cache := newTestLRUCache() + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, nil) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV3), + }, nil) + + assertWorkloadUpdateEqual(t, sub, &WorkloadUpdate{ + Bundle: bundleV3, + }) +} + +func TestLRUCacheCheckSVIDCallback(t *testing.T) { + cache := newTestLRUCache() + + // no calls because there are no registration entries + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + assert.Fail(t, 
"should not be called if there are no registration entries") + + return false + }) + + foo := makeRegistrationEntryWithTTL("FOO", 60) + + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + // should not get invoked + assert.Fail(t, "should not be called as no SVIDs are cached yet") + return false + }) + + // called once for FOO with new SVID + svids := makeX509SVIDs(foo) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: svids, + }) + + // called once for FOO with existing SVID + callCount := 0 + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(foo), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + callCount++ + assert.Equal(t, "FOO", newEntry.EntryId) + if assert.NotNil(t, svid) { + assert.Exactly(t, svids["FOO"], svid) + } + + return true + }) + assert.Equal(t, 1, callCount) + assert.Equal(t, map[string]bool{foo.EntryId: true}, cache.staleEntries) +} + +func TestLRUCacheGetStaleEntries(t *testing.T) { + cache := newTestLRUCache() + + bar := makeRegistrationEntryWithTTL("BAR", 120, "B") + + // Create entry but don't mark it stale from checkSVID method; + // it will be marked stale cause it does not have SVID cached + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(bar), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return false + }) + + // Assert that the entry is returned as stale. The `ExpiresAt` field should be unset since there is no SVID. + expectedEntries := []*StaleEntry{{Entry: cache.records[bar.EntryId].entry}} + assert.Equal(t, expectedEntries, cache.GetStaleEntries()) + + // Update the SVID for the stale entry + svids := make(map[string]*X509SVID) + expiredAt := time.Now() + svids[bar.EntryId] = &X509SVID{ + Chain: []*x509.Certificate{{NotAfter: expiredAt}}, + } + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: svids, + }) + // Assert that updating the SVID removes stale marker from entry + assert.Empty(t, cache.GetStaleEntries()) + + // Update entry again and mark it as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + RegistrationEntries: makeRegistrationEntries(bar), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) + + // Assert that the entry again returns as stale. This time the `ExpiresAt` field should be populated with the expiration of the SVID. 
+ expectedEntries = []*StaleEntry{{ + Entry: cache.records[bar.EntryId].entry, + ExpiresAt: expiredAt, + }} + assert.Equal(t, expectedEntries, cache.GetStaleEntries()) + + // Remove registration entry and assert that it is no longer returned as stale + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV2), + }, func(existingEntry, newEntry *common.RegistrationEntry, svid *X509SVID) bool { + return true + }) + assert.Empty(t, cache.GetStaleEntries()) +} + +func TestLRUCacheSubscriberNotNotifiedOnDifferentSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(bar), + }) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSubscriberNotNotifiedOnOverlappingSVIDChanges(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A", "C") + bar := makeRegistrationEntry("FOO", "A", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + + sub := subscribeToWorkloadUpdates(t, cache, makeSelectors("A", "B")) + defer sub.Finish() + assertAnyWorkloadUpdate(t, sub) + + // Update SVID + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + + assertNoWorkloadUpdate(t, sub) +} + +func TestLRUCacheSVIDCacheExpiry(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(10, clk) + + clk.Add(1 * time.Second) + foo := makeRegistrationEntry("FOO", "A") + // validate workload update for foo + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + subA := subscribeToWorkloadUpdates(t, cache, makeSelectors("A")) + assertWorkloadUpdateEqual(t, subA, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{{Entry: foo}}, + }) + subA.Finish() + + // move clk by 1 sec so that SVID access time will be different + clk.Add(1 * time.Second) + bar := makeRegistrationEntry("BAR", "B") + // validate workload update for bar + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(bar), + }) + + // not closing subscriber immediately + subB := subscribeToWorkloadUpdates(t, cache, makeSelectors("B")) + defer subB.Finish() + assertWorkloadUpdateEqual(t, subB, &WorkloadUpdate{ + Bundle: bundleV1, + Identities: []Identity{ + {Entry: bar}, + }, + }) + + // Move clk by 2 seconds + clk.Add(2 * time.Second) + // update total of 12 entries + updateEntries := createUpdateEntries(10, makeBundles(bundleV1)) + updateEntries.RegistrationEntries[foo.EntryId] = foo + updateEntries.RegistrationEntries[bar.EntryId] = bar + + cache.UpdateEntries(updateEntries, nil) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromMap(updateEntries.RegistrationEntries), + }) + + for id, entry := range 
updateEntries.RegistrationEntries { + // create and close subscribers for remaining entries so that svid cache is full + if id != foo.EntryId && id != bar.EntryId { + sub := cache.NewSubscriber(entry.Selectors) + sub.Finish() + } + } + assert.Equal(t, 12, cache.CountSVIDs()) + + cache.UpdateEntries(updateEntries, nil) + assert.Equal(t, 10, cache.CountSVIDs()) + + // foo SVID should be removed from cache as it does not have active subscriber + assert.False(t, cache.Notify(makeSelectors("A"))) + // bar SVID should be cached as it has active subscriber + assert.True(t, cache.Notify(makeSelectors("B"))) + + subA = cache.NewSubscriber(makeSelectors("A")) + defer subA.Finish() + + cache.UpdateEntries(updateEntries, nil) + + // Make sure foo is marked as stale entry which does not have svid cached + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, foo, cache.GetStaleEntries()[0].Entry) + + assert.Equal(t, 10, cache.CountSVIDs()) +} + +func TestLRUCacheMaxSVIDCacheSize(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(10, clk) + + // create entries more than maxSvidCacheSize + updateEntries := createUpdateEntries(12, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + + require.Len(t, cache.GetStaleEntries(), 10) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + require.Len(t, cache.GetStaleEntries(), 0) + assert.Equal(t, 10, cache.CountSVIDs()) + + // Validate that active subscriber will still get SVID even if SVID count is at maxSvidCacheSize + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + subA := cache.NewSubscriber(foo.Selectors) + defer subA.Finish() + + cache.UpdateEntries(updateEntries, nil) + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, 10, cache.CountSVIDs()) + + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.Equal(t, 11, cache.CountSVIDs()) + require.Len(t, cache.GetStaleEntries(), 0) +} + +func TestSyncSVIDsWithSubscribers(t *testing.T) { + clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(5, clk) + + updateEntries := createUpdateEntries(5, makeBundles(bundleV1)) + cache.UpdateEntries(updateEntries, nil) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDsFromStaleEntries(cache.GetStaleEntries()), + }) + assert.Equal(t, 5, cache.CountSVIDs()) + + // Update foo but its SVID is not yet cached + foo := makeRegistrationEntry("FOO", "A") + updateEntries.RegistrationEntries[foo.EntryId] = foo + + cache.UpdateEntries(updateEntries, nil) + + // Create a subscriber for foo + subA := cache.NewSubscriber(foo.Selectors) + defer subA.Finish() + require.Len(t, cache.GetStaleEntries(), 0) + + // After SyncSVIDsWithSubscribers foo should be marked as stale, requiring signing + cache.SyncSVIDsWithSubscribers() + require.Len(t, cache.GetStaleEntries(), 1) + assert.Equal(t, []*StaleEntry{{Entry: cache.records[foo.EntryId].entry}}, cache.GetStaleEntries()) + + assert.Equal(t, 5, cache.CountSVIDs()) +} + +func TestNotify(t *testing.T) { + cache := newTestLRUCache() + + foo := makeRegistrationEntry("FOO", "A") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo), + }, nil) + + assert.False(t, cache.Notify(makeSelectors("A"))) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo), + }) + assert.True(t, cache.Notify(makeSelectors("A"))) +} + +func TestSubscribeToLRUCacheChanges(t *testing.T) { + 
clk := clock.NewMock() + cache := newTestLRUCacheWithConfig(1, clk) + + foo := makeRegistrationEntry("FOO", "A") + bar := makeRegistrationEntry("BAR", "B") + cache.UpdateEntries(&UpdateEntries{ + Bundles: makeBundles(bundleV1), + RegistrationEntries: makeRegistrationEntries(foo, bar), + }, nil) + + sub1WaitCh := make(chan struct{}, 1) + sub1ErrCh := make(chan error, 1) + go func() { + sub1, err := cache.subscribeToWorkloadUpdates(context.Background(), foo.Selectors, func() { + sub1WaitCh <- struct{}{} + }) + if err != nil { + sub1ErrCh <- err + return + } + + defer sub1.Finish() + u1 := <-sub1.Updates() + if len(u1.Identities) != 1 { + sub1ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u1.Identities)) + return + } + sub1ErrCh <- nil + }() + + sub2WaitCh := make(chan struct{}, 1) + sub2ErrCh := make(chan error, 1) + go func() { + sub2, err := cache.subscribeToWorkloadUpdates(context.Background(), bar.Selectors, func() { + sub2WaitCh <- struct{}{} + }) + if err != nil { + sub2ErrCh <- err + return + } + + defer sub2.Finish() + u2 := <-sub2.Updates() + if len(u2.Identities) != 1 { + sub1ErrCh <- fmt.Errorf("expected 1 SVID, got: %d", len(u2.Identities)) + return + } + sub2ErrCh <- nil + }() + + <-sub1WaitCh + <-sub2WaitCh + cache.SyncSVIDsWithSubscribers() + + assert.Len(t, cache.GetStaleEntries(), 2) + cache.UpdateSVIDs(&UpdateSVIDs{ + X509SVIDs: makeX509SVIDs(foo, bar), + }) + assert.Equal(t, 2, cache.CountSVIDs()) + + clk.Add(SVIDSyncInterval * 2) + + sub1Err := <-sub1ErrCh + assert.NoError(t, sub1Err, "subscriber 1 error") + + sub2Err := <-sub2ErrCh + assert.NoError(t, sub2Err, "subscriber 2 error") +} + +func TestNewLRUCache(t *testing.T) { + // negative value + cache := newTestLRUCacheWithConfig(-5, clock.NewMock()) + require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) + + // zero value + cache = newTestLRUCacheWithConfig(0, clock.NewMock()) + require.Equal(t, DefaultSVIDCacheMaxSize, cache.svidCacheMaxSize) +} + +func BenchmarkLRUCacheGlobalNotification(b *testing.B) { + cache := newTestLRUCache() + + const numEntries = 1000 + const numWorkloads = 1000 + const selectorsPerEntry = 3 + const selectorsPerWorkload = 10 + + // build a set of 1000 registration entries with distinct selectors + bundlesV1 := makeBundles(bundleV1) + bundlesV2 := makeBundles(bundleV2) + updateEntries := &UpdateEntries{ + Bundles: bundlesV1, + RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + } + for i := 0; i < numEntries; i++ { + entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) + updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ + EntryId: entryID, + ParentId: "spiffe://domain.test/node", + SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), + Selectors: distinctSelectors(i, selectorsPerEntry), + } + } + + cache.UpdateEntries(updateEntries, nil) + for i := 0; i < numWorkloads; i++ { + selectors := distinctSelectors(i, selectorsPerWorkload) + cache.NewSubscriber(selectors) + } + + runtime.GC() + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if i%2 == 0 { + updateEntries.Bundles = bundlesV2 + } else { + updateEntries.Bundles = bundlesV1 + } + cache.UpdateEntries(updateEntries, nil) + } +} + +func newTestLRUCache() *LRUCache { + log, _ := test.NewNullLogger() + return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, + telemetry.Blackhole{}, 0, clock.NewMock()) +} + +func newTestLRUCacheWithConfig(svidCacheMaxSize int, clk clock.Clock) *LRUCache { + log, _ := 
test.NewNullLogger() + return NewLRUCache(log, spiffeid.RequireTrustDomainFromString("domain.test"), bundleV1, telemetry.Blackhole{}, + svidCacheMaxSize, clk) +} + +// numEntries should not be more than 12 digits +func createUpdateEntries(numEntries int, bundles map[spiffeid.TrustDomain]*bundleutil.Bundle) *UpdateEntries { + updateEntries := &UpdateEntries{ + Bundles: bundles, + RegistrationEntries: make(map[string]*common.RegistrationEntry, numEntries), + } + + for i := 0; i < numEntries; i++ { + entryID := fmt.Sprintf("00000000-0000-0000-0000-%012d", i) + updateEntries.RegistrationEntries[entryID] = &common.RegistrationEntry{ + EntryId: entryID, + ParentId: "spiffe://domain.test/node", + SpiffeId: fmt.Sprintf("spiffe://domain.test/workload-%d", i), + Selectors: distinctSelectors(i, 1), + } + } + return updateEntries +} + +func makeX509SVIDsFromMap(entries map[string]*common.RegistrationEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.EntryId] = &X509SVID{} + } + return out +} + +func makeX509SVIDsFromStaleEntries(entries []*StaleEntry) map[string]*X509SVID { + out := make(map[string]*X509SVID) + for _, entry := range entries { + out[entry.Entry.EntryId] = &X509SVID{} + } + return out +} + +func subscribeToWorkloadUpdates(t *testing.T, cache *LRUCache, selectors []*common.Selector) Subscriber { + subscriber, err := cache.subscribeToWorkloadUpdates(context.Background(), selectors, nil) + assert.NoError(t, err) + return subscriber +} diff --git a/pkg/agent/manager/cache/sets.go b/pkg/agent/manager/cache/sets.go index 6c9e1701bb..98baf01682 100644 --- a/pkg/agent/manager/cache/sets.go +++ b/pkg/agent/manager/cache/sets.go @@ -30,14 +30,25 @@ var ( return make(recordSet) }, } + + lruCacheRecordSetPool = sync.Pool{ + New: func() interface{} { + return make(lruCacheRecordSet) + }, + } + + lruCacheSubscriberSetPool = sync.Pool{ + New: func() interface{} { + return make(lruCacheSubscriberSet) + }, + } ) // unique set of strings, allocated from a pool type stringSet map[string]struct{} -func allocStringSet(ss ...string) (stringSet, func()) { +func allocStringSet() (stringSet, func()) { set := stringSetPool.Get().(stringSet) - set.Merge(ss...) 
return set, func() { clearStringSet(set) stringSetPool.Put(set) @@ -149,3 +160,37 @@ func clearRecordSet(set recordSet) { delete(set, k) } } + +// unique set of LRU cache records, allocated from a pool +type lruCacheRecordSet map[*lruCacheRecord]struct{} + +func allocLRUCacheRecordSet() (lruCacheRecordSet, func()) { + set := lruCacheRecordSetPool.Get().(lruCacheRecordSet) + return set, func() { + clearLRUCacheRecordSet(set) + lruCacheRecordSetPool.Put(set) + } +} + +func clearLRUCacheRecordSet(set lruCacheRecordSet) { + for k := range set { + delete(set, k) + } +} + +// unique set of LRU cache subscribers, allocated from a pool +type lruCacheSubscriberSet map[*lruCacheSubscriber]struct{} + +func allocLRUCacheSubscriberSet() (lruCacheSubscriberSet, func()) { + set := lruCacheSubscriberSetPool.Get().(lruCacheSubscriberSet) + return set, func() { + clearLRUCacheSubscriberSet(set) + lruCacheSubscriberSetPool.Put(set) + } +} + +func clearLRUCacheSubscriberSet(set lruCacheSubscriberSet) { + for k := range set { + delete(set, k) + } +} diff --git a/pkg/agent/manager/cache/util.go b/pkg/agent/manager/cache/util.go new file mode 100644 index 0000000000..ab365514fd --- /dev/null +++ b/pkg/agent/manager/cache/util.go @@ -0,0 +1,13 @@ +package cache + +import ( + "sort" + + "github.com/spiffe/spire/proto/spire/common" +) + +func sortEntriesByID(entries []*common.RegistrationEntry) { + sort.Slice(entries, func(a, b int) bool { + return entries[a].EntryId < entries[b].EntryId + }) +} diff --git a/pkg/agent/manager/config.go b/pkg/agent/manager/config.go index 34fe82fe5a..3a3fe11eee 100644 --- a/pkg/agent/manager/config.go +++ b/pkg/agent/manager/config.go @@ -9,7 +9,7 @@ import ( "github.com/sirupsen/logrus" "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/catalog" - "github.com/spiffe/spire/pkg/agent/manager/cache" + managerCache "github.com/spiffe/spire/pkg/agent/manager/cache" "github.com/spiffe/spire/pkg/agent/manager/storecache" "github.com/spiffe/spire/pkg/agent/plugin/keymanager" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" @@ -24,7 +24,7 @@ type Config struct { // Agent SVID and key resulting from successful attestation. 
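The lruCacheRecordSet and lruCacheSubscriberSet helpers added to sets.go above follow the file's existing pattern: a set is borrowed from a sync.Pool together with a done func that clears it and returns it to the pool, so the hot notification paths avoid re-allocating maps. A generic sketch of that pattern, using an illustrative itemSet type rather than the package's concrete set types:

package main

import (
	"fmt"
	"sync"
)

type itemSet map[string]struct{}

var itemSetPool = sync.Pool{
	New: func() interface{} { return make(itemSet) },
}

// allocItemSet hands out a pooled set plus a done func that empties the set
// and puts it back, mirroring allocLRUCacheRecordSet/allocLRUCacheSubscriberSet.
func allocItemSet() (itemSet, func()) {
	set := itemSetPool.Get().(itemSet)
	return set, func() {
		for k := range set {
			delete(set, k)
		}
		itemSetPool.Put(set)
	}
}

func main() {
	set, done := allocItemSet()
	defer done()
	set["a"] = struct{}{}
	fmt.Println(len(set)) // 1
}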
SVID []*x509.Certificate SVIDKey keymanager.Key - Bundle *cache.Bundle + Bundle *managerCache.Bundle Reattestable bool Catalog catalog.Catalog TrustDomain spiffeid.TrustDomain @@ -36,6 +36,7 @@ type Config struct { SyncInterval time.Duration RotationInterval time.Duration SVIDStoreCache *storecache.Cache + SVIDCacheMaxSize int NodeAttestor nodeattestor.NodeAttestor // Clk is the clock the manager will use to get time @@ -60,7 +61,15 @@ func newManager(c *Config) *manager { c.Clk = clock.New() } - cache := cache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, c.Metrics) + var cache Cache + if c.SVIDCacheMaxSize > 0 { + // use LRU cache implementation + cache = managerCache.NewLRUCache(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + c.Metrics, c.SVIDCacheMaxSize, c.Clk) + } else { + cache = managerCache.New(c.Log.WithField(telemetry.SubsystemName, telemetry.CacheManager), c.TrustDomain, c.Bundle, + c.Metrics) + } rotCfg := &svid.RotatorConfig{ SVIDKeyManager: keymanager.ForSVID(c.Catalog.GetKeyManager()), diff --git a/pkg/agent/manager/manager.go b/pkg/agent/manager/manager.go index 83771c05b0..c82aba287d 100644 --- a/pkg/agent/manager/manager.go +++ b/pkg/agent/manager/manager.go @@ -35,7 +35,7 @@ type Manager interface { // SubscribeToCacheChanges returns a Subscriber on which cache entry updates are sent // for a particular set of selectors. - SubscribeToCacheChanges(key cache.Selectors) cache.Subscriber + SubscribeToCacheChanges(ctx context.Context, key cache.Selectors) (cache.Subscriber, error) // SubscribeToSVIDChanges returns a new observer.Stream on which svid.State instances are received // each time an SVID rotation finishes. @@ -55,9 +55,9 @@ type Manager interface { // SetRotationFinishedHook sets a hook that will be called when a rotation finished SetRotationFinishedHook(func()) - // MatchingIdentities returns all of the cached identities whose - // registration entry selectors are a subset of the passed selectors. - MatchingIdentities(selectors []*common.Selector) []cache.Identity + // MatchingRegistrationEntries returns all of the cached registration entries whose + // selectors are a subset of the passed selectors. + MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry // FetchWorkloadUpdates gets the latest workload update for the selectors FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate @@ -76,20 +76,62 @@ type Manager interface { GetBundle() *cache.Bundle } +// Cache stores each registration entry, signed X509-SVIDs for those entries, +// bundles, and JWT SVIDs for the agent. +type Cache interface { + SVIDCache + + // Bundle gets latest cached bundle + Bundle() *bundleutil.Bundle + + // SyncSVIDsWithSubscribers syncs SVID cache + SyncSVIDsWithSubscribers() + + // SubscribeToWorkloadUpdates creates a subscriber for given selector set. 
+ SubscribeToWorkloadUpdates(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) + + // SubscribeToBundleChanges creates a stream for providing bundle changes + SubscribeToBundleChanges() *cache.BundleStream + + // MatchingRegistrationEntries with given selectors + MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry + + // CountSVIDs in cache stored + CountSVIDs() int + + // FetchWorkloadUpdate for given selectors + FetchWorkloadUpdate(selectors []*common.Selector) *cache.WorkloadUpdate + + // GetJWTSVID provides JWT-SVID + GetJWTSVID(id spiffeid.ID, audience []string) (*client.JWTSVID, bool) + + // SetJWTSVID adds JWT-SVID to cache + SetJWTSVID(id spiffeid.ID, audience []string, svid *client.JWTSVID) + + // Entries get all registration entries + Entries() []*common.RegistrationEntry + + // Identities get all identities in cache + Identities() []cache.Identity +} + type manager struct { c *Config // Fields protected by mtx mutex. mtx *sync.RWMutex + // Protects multiple goroutines from requesting SVID signings at the same time + updateSVIDMu sync.RWMutex - cache *cache.Cache + cache Cache svid svid.Rotator storage storage.Storage - // backoff calculator for fetch interval, backing off if error is returned on + // synchronizeBackoff calculator for fetch interval, backing off if error is returned on // fetch attempt - backoff backoff.BackOff + synchronizeBackoff backoff.BackOff + svidSyncBackoff backoff.BackOff client client.Client @@ -106,7 +148,8 @@ func (m *manager) Initialize(ctx context.Context) error { m.storeSVID(m.svid.State().SVID, m.svid.State().Reattestable) m.storeBundle(m.cache.Bundle()) - m.backoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) + m.synchronizeBackoff = backoff.NewBackoff(m.clk, m.c.SyncInterval) + m.svidSyncBackoff = backoff.NewBackoff(m.clk, cache.SVIDSyncInterval) err := m.synchronize(ctx) if nodeutil.ShouldAgentReattest(err) { @@ -125,6 +168,7 @@ func (m *manager) Run(ctx context.Context) error { err := util.RunTasks(ctx, m.runSynchronizer, + m.runSyncSVIDs, m.runSVIDObserver, m.runBundleObserver, m.svid.Run) @@ -147,8 +191,8 @@ func (m *manager) Run(ctx context.Context) error { } } -func (m *manager) SubscribeToCacheChanges(selectors cache.Selectors) cache.Subscriber { - return m.cache.SubscribeToWorkloadUpdates(selectors) +func (m *manager) SubscribeToCacheChanges(ctx context.Context, selectors cache.Selectors) (cache.Subscriber, error) { + return m.cache.SubscribeToWorkloadUpdates(ctx, selectors) } func (m *manager) SubscribeToSVIDChanges() observer.Stream { @@ -171,8 +215,8 @@ func (m *manager) SetRotationFinishedHook(f func()) { m.svid.SetRotationFinishedHook(f) } -func (m *manager) MatchingIdentities(selectors []*common.Selector) []cache.Identity { - return m.cache.MatchingIdentities(selectors) +func (m *manager) MatchingRegistrationEntries(selectors []*common.Selector) []*common.RegistrationEntry { + return m.cache.MatchingRegistrationEntries(selectors) } func (m *manager) CountSVIDs() int { @@ -214,9 +258,9 @@ func (m *manager) FetchJWTSVID(ctx context.Context, spiffeID spiffeid.ID, audien } func (m *manager) getEntryID(spiffeID string) string { - for _, identity := range m.cache.Identities() { - if identity.Entry.SpiffeId == spiffeID { - return identity.Entry.EntryId + for _, entry := range m.cache.Entries() { + if entry.SpiffeId == spiffeID { + return entry.EntryId } } return "" @@ -225,7 +269,7 @@ func (m *manager) getEntryID(spiffeID string) string { func (m *manager) runSynchronizer(ctx 
context.Context) error { for { select { - case <-m.clk.After(m.backoff.NextBackOff()): + case <-m.clk.After(m.synchronizeBackoff.NextBackOff()): case <-ctx.Done(): return nil } @@ -242,7 +286,26 @@ func (m *manager) runSynchronizer(ctx context.Context) error { // Just log the error and wait for next synchronization m.c.Log.WithError(err).Error("Synchronize failed") default: - m.backoff.Reset() + m.synchronizeBackoff.Reset() + } + } +} + +func (m *manager) runSyncSVIDs(ctx context.Context) error { + for { + select { + case <-m.clk.After(m.svidSyncBackoff.NextBackOff()): + case <-ctx.Done(): + return nil + } + + err := m.syncSVIDs(ctx) + switch { + case err != nil: + // Just log the error and wait for next synchronization + m.c.Log.WithError(err).Error("SVID sync failed") + default: + m.svidSyncBackoff.Reset() } } } diff --git a/pkg/agent/manager/manager_test.go b/pkg/agent/manager/manager_test.go index 58c463b897..73eef1edce 100644 --- a/pkg/agent/manager/manager_test.go +++ b/pkg/agent/manager/manager_test.go @@ -230,9 +230,9 @@ func TestHappyPathWithoutSyncNorRotation(t *testing.T) { t.Fatal("PrivateKey is not equals to configured one") } - matches := m.MatchingIdentities(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + matches := m.MatchingRegistrationEntries(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) if len(matches) != 2 { - t.Fatal("expected 2 identities") + t.Fatal("expected 2 registration entries") } // Verify bundle @@ -246,10 +246,11 @@ func TestHappyPathWithoutSyncNorRotation(t *testing.T) { compareRegistrationEntries(t, regEntriesMap["resp2"], - []*common.RegistrationEntry{matches[0].Entry, matches[1].Entry}) + []*common.RegistrationEntry{matches[0], matches[1]}) util.RunWithTimeout(t, 5*time.Second, func() { - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) u := <-sub.Updates() if len(u.Identities) != 2 { @@ -320,9 +321,9 @@ func TestRotationWithRSAKey(t *testing.T) { t.Fatal("PrivateKey is not equals to configured one") } - matches := m.MatchingIdentities(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + matches := m.MatchingRegistrationEntries(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) if len(matches) != 2 { - t.Fatal("expected 2 identities") + t.Fatal("expected 2 registration entries") } // Verify bundle @@ -336,10 +337,11 @@ func TestRotationWithRSAKey(t *testing.T) { compareRegistrationEntries(t, regEntriesMap["resp2"], - []*common.RegistrationEntry{matches[0].Entry, matches[1].Entry}) + []*common.RegistrationEntry{matches[0], matches[1]}) util.RunWithTimeout(t, 5*time.Second, func() { - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) u := <-sub.Updates() if len(u.Identities) != 2 { @@ -516,10 +518,11 @@ func TestSynchronization(t *testing.T) { m := newManager(c) - sub := m.SubscribeToCacheChanges(cache.Selectors{ + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ {Type: "unix", Value: "uid:1111"}, {Type: "spiffe_id", Value: joinTokenID.String()}, }) + require.NoError(t, err) defer sub.Finish() if err := m.Initialize(context.Background()); err != nil { @@ -675,7 +678,7 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { // entries. 
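The runSyncSVIDs task above has the same shape as runSynchronizer: sleep for the next backoff interval, run the sync, log and keep backing off on failure, and reset the backoff on success. A compact sketch of that loop, with a hypothetical fixed-step backoff standing in for the agent's backoff.BackOff:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// stepBackoff is a stand-in for backoff.BackOff: it doubles the wait on every
// call, capped at max, and snaps back to the base interval on Reset.
type stepBackoff struct{ base, next, max time.Duration }

func (b *stepBackoff) NextBackOff() time.Duration {
	if b.next == 0 {
		b.next = b.base
	}
	d := b.next
	if b.next *= 2; b.next > b.max {
		b.next = b.max
	}
	return d
}

func (b *stepBackoff) Reset() { b.next = b.base }

// runSyncLoop mirrors the structure of runSyncSVIDs: wait, sync, back off on
// error, reset on success, stop when the context is done.
func runSyncLoop(ctx context.Context, b *stepBackoff, sync func(context.Context) error) {
	for {
		select {
		case <-time.After(b.NextBackOff()):
		case <-ctx.Done():
			return
		}
		if err := sync(ctx); err != nil {
			fmt.Println("SVID sync failed:", err) // logged, retried after a longer wait
			continue
		}
		b.Reset()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	calls := 0
	runSyncLoop(ctx, &stepBackoff{base: 10 * time.Millisecond, max: 40 * time.Millisecond}, func(context.Context) error {
		calls++
		if calls == 1 {
			return errors.New("transient failure")
		}
		return nil
	})
	fmt.Println("sync attempts:", calls)
}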
compareRegistrationEntries(t, append(regEntriesMap["resp1"], regEntriesMap["resp2"]...), - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) // manually synchronize again if err := m.synchronize(context.Background()); err != nil { @@ -685,7 +688,7 @@ func TestSynchronizationClearsStaleCacheEntries(t *testing.T) { // now the cache should have entries from resp2 removed compareRegistrationEntries(t, regEntriesMap["resp1"], - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) } func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { @@ -747,7 +750,7 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { // after initialization, the cache should contain resp2 entries compareRegistrationEntries(t, regEntriesMap["resp2"], - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) // manually synchronize again if err := m.synchronize(context.Background()); err != nil { @@ -757,7 +760,7 @@ func TestSynchronizationUpdatesRegistrationEntries(t *testing.T) { // now the cache should have the updated entries from resp3 compareRegistrationEntries(t, regEntriesMap["resp3"], - regEntriesFromIdentities(m.cache.Identities())) + m.cache.Entries()) } func TestSubscribersGetUpToDateBundle(t *testing.T) { @@ -801,9 +804,9 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { m := newManager(c) - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) - defer initializeAndRunManager(t, m)() + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) util.RunWithTimeout(t, 1*time.Second, func() { // Update should contain a new bundle. @@ -817,6 +820,248 @@ func TestSubscribersGetUpToDateBundle(t *testing.T) { }) } +func TestSynchronizationWithLRUCache(t *testing.T) { + dir := spiretest.TempDir(t) + km := fakeagentkeymanager.New(t, dir) + + clk := clock.NewMock(t) + ttl := 3 + api := newMockAPI(t, &mockAPIConfig{ + km: km, + getAuthorizedEntries: func(*mockAPI, int32, *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + return makeGetAuthorizedEntriesResponse(t, "resp1", "resp2"), nil + }, + batchNewX509SVIDEntries: func(*mockAPI, int32) []*common.RegistrationEntry { + return makeBatchNewX509SVIDEntries("resp1", "resp2") + }, + svidTTL: ttl, + clk: clk, + }) + + baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) + cat := fakeagentcatalog.New() + cat.SetKeyManager(km) + + c := &Config{ + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + RotationInterval: time.Hour, + SyncInterval: time.Hour, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDCacheMaxSize: 10, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + } + + m := newManager(c) + + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + require.Equal(t, clk.Now(), m.GetLastSync()) + + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{ + {Type: "unix", Value: "uid:1111"}, + {Type: "spiffe_id", Value: joinTokenID.String()}, + }) + require.NoError(t, err) + defer sub.Finish() + + // Before synchronization + identitiesBefore := identitiesByEntryID(m.cache.Identities()) + if len(identitiesBefore) != 3 { + t.Fatalf("3 cached identities were expected; got %d", 
len(identitiesBefore)) + } + + // This is the initial update based on the selector set + u := <-sub.Updates() + if len(u.Identities) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(u.Identities)) + } + + if len(u.Bundle.RootCAs()) != 1 { + t.Fatal("expected 1 bundle root CA") + } + + if !u.Bundle.EqualTo(api.bundle) { + t.Fatal("received bundle should be equals to the server bundle") + } + + for key, eu := range identitiesByEntryID(u.Identities) { + eb, ok := identitiesBefore[key] + if !ok { + t.Fatalf("an update was received for an inexistent entry on the cache with EntryId=%v", key) + } + require.Equal(t, eb, eu, "identity received does not match identity on cache") + } + + require.Equal(t, clk.Now(), m.GetLastSync()) + + // SVIDs expire after 3 seconds, so we shouldn't expect any updates after + // 1 second has elapsed. + clk.Add(time.Second) + require.NoError(t, m.synchronize(context.Background())) + select { + case <-sub.Updates(): + t.Fatal("update unexpected after 1 second") + default: + } + + // After advancing another second, the SVIDs should have been refreshed, + // since the half-time has been exceeded. + clk.Add(time.Second) + require.NoError(t, m.synchronize(context.Background())) + select { + case u = <-sub.Updates(): + default: + t.Fatal("update expected after 2 seconds") + } + + // Make sure the update contains the updated entries and that the cache + // has a consistent view. + identitiesAfter := identitiesByEntryID(m.cache.Identities()) + if len(identitiesAfter) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(identitiesAfter)) + } + + for key, eb := range identitiesBefore { + ea, ok := identitiesAfter[key] + if !ok { + t.Fatalf("expected identity with EntryId=%v after synchronization", key) + } + require.NotEqual(t, eb, ea, "there is at least one identity that was not refreshed: %v", ea) + } + + if len(u.Identities) != 3 { + t.Fatalf("expected 3 identities, got: %d", len(u.Identities)) + } + + if len(u.Bundle.RootCAs()) != 1 { + t.Fatal("expected 1 bundle root CA") + } + + if !u.Bundle.EqualTo(api.bundle) { + t.Fatal("received bundle should be equals to the server bundle") + } + + for key, eu := range identitiesByEntryID(u.Identities) { + ea, ok := identitiesAfter[key] + if !ok { + t.Fatalf("an update was received for an inexistent entry on the cache with EntryId=%v", key) + } + require.Equal(t, eu, ea, "entry received does not match entry on cache") + } + + require.Equal(t, clk.Now(), m.GetLastSync()) +} + +func TestSyncSVIDsWithLRUCache(t *testing.T) { + dir := spiretest.TempDir(t) + km := fakeagentkeymanager.New(t, dir) + + clk := clock.NewMock(t) + api := newMockAPI(t, &mockAPIConfig{ + km: km, + getAuthorizedEntries: func(h *mockAPI, count int32, _ *entryv1.GetAuthorizedEntriesRequest) (*entryv1.GetAuthorizedEntriesResponse, error) { + switch count { + case 1: + return makeGetAuthorizedEntriesResponse(t, "resp2"), nil + case 2: + return makeGetAuthorizedEntriesResponse(t, "resp2"), nil + default: + return nil, fmt.Errorf("unexpected getAuthorizedEntries call count: %d", count) + } + }, + batchNewX509SVIDEntries: func(h *mockAPI, count int32) []*common.RegistrationEntry { + switch count { + case 1: + return makeBatchNewX509SVIDEntries("resp2") + case 2: + return makeBatchNewX509SVIDEntries("resp2") + default: + return nil + } + }, + svidTTL: 3, + clk: clk, + }) + + baseSVID, baseSVIDKey := api.newSVID(joinTokenID, 1*time.Hour) + cat := fakeagentcatalog.New() + cat.SetKeyManager(km) + + c := &Config{ + ServerAddr: api.addr, + SVID: baseSVID, + SVIDKey: 
baseSVIDKey, + Log: testLogger, + TrustDomain: trustDomain, + Storage: openStorage(t, dir), + Bundle: api.bundle, + Metrics: &telemetry.Blackhole{}, + Clk: clk, + Catalog: cat, + WorkloadKeyType: workloadkey.ECP256, + SVIDCacheMaxSize: 1, + SVIDStoreCache: storecache.New(&storecache.Config{TrustDomain: trustDomain, Log: testLogger}), + } + + m := newManager(c) + + if err := m.Initialize(context.Background()); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + subErrCh := make(chan error, 1) + go func(ctx context.Context) { + sub, err := m.SubscribeToCacheChanges(ctx, cache.Selectors{ + {Type: "unix", Value: "uid:1111"}, + }) + if err != nil { + subErrCh <- err + return + } + defer sub.Finish() + subErrCh <- nil + }(ctx) + + syncErrCh := make(chan error, 1) + // run svid sync + go func(ctx context.Context) { + syncErrCh <- m.runSyncSVIDs(ctx) + }(ctx) + + // keep clk moving so that subscriber keeps looking for svid + go func(ctx context.Context) { + for { + clk.Add(cache.SVIDSyncInterval) + if ctx.Err() != nil { + return + } + } + }(ctx) + + subErr := <-subErrCh + assert.NoError(t, subErr, "subscriber error") + + // ensure 2 SVIDs corresponding to selectors are cached. + assert.Equal(t, 2, m.cache.CountSVIDs()) + + // cancel the ctx to stop go routines + cancel() + + syncErr := <-syncErrCh + assert.NoError(t, syncErr, "svid sync error") +} + func TestSurvivesCARotation(t *testing.T) { dir := spiretest.TempDir(t) km := fakeagentkeymanager.New(t, dir) @@ -863,7 +1108,8 @@ func TestSurvivesCARotation(t *testing.T) { m := newManager(c) - sub := m.SubscribeToCacheChanges(cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + sub, err := m.SubscribeToCacheChanges(context.Background(), cache.Selectors{{Type: "unix", Value: "uid:1111"}}) + require.NoError(t, err) // This should be the update received when Subscribe function was called. updates := sub.Updates() initialUpdate := <-updates @@ -1127,13 +1373,6 @@ func identitiesByEntryID(ces []cache.Identity) (result map[string]cache.Identity return result } -func regEntriesFromIdentities(ces []cache.Identity) (result []*common.RegistrationEntry) { - for _, ce := range ces { - result = append(result, ce.Entry) - } - return result -} - func compareRegistrationEntries(t *testing.T, expected, actual []*common.RegistrationEntry) { if len(expected) != len(actual) { t.Fatalf("entries count doesn't match, expected: %d, got: %d", len(expected), len(actual)) diff --git a/pkg/agent/manager/sync.go b/pkg/agent/manager/sync.go index 0512f99b73..25b1dfb402 100644 --- a/pkg/agent/manager/sync.go +++ b/pkg/agent/manager/sync.go @@ -25,7 +25,7 @@ type csrRequest struct { CurrentSVIDExpiresAt time.Time } -type Cache interface { +type SVIDCache interface { // UpdateEntries updates entries on cache UpdateEntries(update *cache.UpdateEntries, checkSVID func(*common.RegistrationEntry, *common.RegistrationEntry, *cache.X509SVID) bool) @@ -36,6 +36,15 @@ type Cache interface { GetStaleEntries() []*cache.StaleEntry } +func (m *manager) syncSVIDs(ctx context.Context) (err error) { + // perform syncSVIDs only if using LRU cache + if m.c.SVIDCacheMaxSize > 0 { + m.cache.SyncSVIDsWithSubscribers() + return m.updateSVIDs(ctx, m.c.Log.WithField(telemetry.CacheType, "workload"), m.cache) + } + return nil +} + // synchronize fetches the authorized entries from the server, updates the // cache, and fetches missing/expiring SVIDs. 
func (m *manager) synchronize(ctx context.Context) (err error) { @@ -57,12 +66,11 @@ func (m *manager) synchronize(ctx context.Context) (err error) { return nil } -func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log logrus.FieldLogger, cacheType string, c Cache) error { +func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log logrus.FieldLogger, cacheType string, c SVIDCache) error { // update the cache and build a list of CSRs that need to be processed // in this interval. // // the values in `update` now belong to the cache. DO NOT MODIFY. - var csrs []csrRequest var expiring int var outdated int c.UpdateEntries(update, func(existingEntry, newEntry *common.RegistrationEntry, svid *cache.X509SVID) bool { @@ -98,22 +106,31 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, log.WithField(telemetry.OutdatedSVIDs, outdated).Debug("Updating SVIDs with outdated attributes in cache") } + return m.updateSVIDs(ctx, log, c) +} + +func (m *manager) updateSVIDs(ctx context.Context, log logrus.FieldLogger, c SVIDCache) error { + m.updateSVIDMu.Lock() + defer m.updateSVIDMu.Unlock() + staleEntries := c.GetStaleEntries() if len(staleEntries) > 0 { + var csrs []csrRequest log.WithFields(logrus.Fields{ telemetry.Count: len(staleEntries), telemetry.Limit: limits.SignLimitPerIP, }).Debug("Renewing stale entries") - for _, staleEntry := range staleEntries { + + for _, entry := range staleEntries { // we've exceeded the CSR limit, don't make any more CSRs if len(csrs) >= limits.SignLimitPerIP { break } csrs = append(csrs, csrRequest{ - EntryID: staleEntry.Entry.EntryId, - SpiffeID: staleEntry.Entry.SpiffeId, - CurrentSVIDExpiresAt: staleEntry.ExpiresAt, + EntryID: entry.Entry.EntryId, + SpiffeID: entry.Entry.SpiffeId, + CurrentSVIDExpiresAt: entry.ExpiresAt, }) } @@ -124,7 +141,6 @@ func (m *manager) updateCache(ctx context.Context, update *cache.UpdateEntries, // the values in `update` now belong to the cache. DO NOT MODIFY. c.UpdateSVIDs(update) } - return nil } diff --git a/test/integration/common b/test/integration/common index 89c7bb2f4c..54e7a7651c 100644 --- a/test/integration/common +++ b/test/integration/common @@ -83,6 +83,26 @@ check-synced-entry() { fail-now "timed out waiting for agent to sync down entry" } +check-x509-svid-count() { + MAXCHECKS=50 + CHECKINTERVAL=1 + + for ((i=1;i<=MAXCHECKS;i++)); do + log-info "check X.509-SVID count on agent debug endpoint ($(($i)) of $MAXCHECKS max)..." 
+ COUNT=$(docker-compose exec -T $1 /opt/spire/conf/agent/debugclient -testCase "printDebugPage" | jq '.svidsCount') + log-info "X.509-SVID Count: ${COUNT}" + if [ "$COUNT" -eq "$2" ]; then + log-info "X.509-SVID count of $COUNT from cache matches the expected count of $2" + break + fi + sleep "${CHECKINTERVAL}" + done + + if (( $i>$MAXCHECKS )); then + fail-now "X.509-SVID count validation failed" + fi +} + build-mashup-image() { ENVOY_VERSION=$1 ENVOY_IMAGE_TAG="${ENVOY_VERSION}-latest" diff --git a/test/integration/setup/debugagent/main.go b/test/integration/setup/debugagent/main.go index bff310d966..aa420040e0 100644 --- a/test/integration/setup/debugagent/main.go +++ b/test/integration/setup/debugagent/main.go @@ -38,6 +38,8 @@ func run() error { var err error switch *testCaseFlag { + case "printDebugPage": + err = printDebugPage(ctx) case "agentEndpoints": err = agentEndpoints(ctx) case "serverWithWorkload": @@ -52,26 +54,43 @@ func run() error { } func agentEndpoints(ctx context.Context) error { + s, err := retrieveDebugPage(ctx) + if err != nil { + return err + } + log.Printf("Debug info: %s", s) + return nil +} + +// printDebugPage allows integration tests to easily parse debug page with jq +func printDebugPage(ctx context.Context) error { + s, err := retrieveDebugPage(ctx) + if err != nil { + return err + } + fmt.Println(s) + return nil +} + +func retrieveDebugPage(ctx context.Context) (string, error) { conn, err := grpc.Dial(*socketPathFlag, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - return fmt.Errorf("failed to connect server: %w", err) + return "", fmt.Errorf("failed to connect server: %w", err) } defer conn.Close() client := agent_debugv1.NewDebugClient(conn) resp, err := client.GetInfo(ctx, &agent_debugv1.GetInfoRequest{}) if err != nil { - return fmt.Errorf("failed to get info: %w", err) + return "", fmt.Errorf("failed to get info: %w", err) } m := protojson.MarshalOptions{Indent: " "} s, err := m.Marshal(resp) if err != nil { - return fmt.Errorf("failed to parse proto: %w", err) + return "", fmt.Errorf("failed to parse proto: %w", err) } - - log.Printf("Debug info: %s", string(s)) - return nil + return string(s), nil } func serverWithWorkload(ctx context.Context) error { diff --git a/test/integration/suites/fetch-x509-svids/00-setup b/test/integration/suites/fetch-x509-svids/00-setup new file mode 100755 index 0000000000..c1fb18218e --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/00-setup @@ -0,0 +1,6 @@ +#!/bin/bash + +"${ROOTDIR}/setup/x509pop/setup.sh" conf/server conf/agent + +"${ROOTDIR}/setup/debugserver/build.sh" "${RUNDIR}/conf/server/debugclient" +"${ROOTDIR}/setup/debugagent/build.sh" "${RUNDIR}/conf/agent/debugclient" diff --git a/test/integration/suites/fetch-x509-svids/01-start-server b/test/integration/suites/fetch-x509-svids/01-start-server new file mode 100755 index 0000000000..a3e999b264 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/01-start-server @@ -0,0 +1,3 @@ +#!/bin/bash + +docker-up spire-server diff --git a/test/integration/suites/fetch-x509-svids/02-bootstrap-agent b/test/integration/suites/fetch-x509-svids/02-bootstrap-agent new file mode 100755 index 0000000000..405147f2fd --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/02-bootstrap-agent @@ -0,0 +1,5 @@ +#!/bin/bash + +log-debug "bootstrapping agent..." 
+docker-compose exec -T spire-server \ + /opt/spire/bin/spire-server bundle show > conf/agent/bootstrap.crt diff --git a/test/integration/suites/fetch-x509-svids/03-start-agent b/test/integration/suites/fetch-x509-svids/03-start-agent new file mode 100755 index 0000000000..ac36d05f0d --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/03-start-agent @@ -0,0 +1,3 @@ +#!/bin/bash + +docker-up spire-agent diff --git a/test/integration/suites/fetch-x509-svids/04-create-registration-entries b/test/integration/suites/fetch-x509-svids/04-create-registration-entries new file mode 100755 index 0000000000..1866777122 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/04-create-registration-entries @@ -0,0 +1,18 @@ +#!/bin/bash + +SIZE=10 + +# Create entries for uid 1001 +for ((m=1;m<=$SIZE;m++)); do + log-debug "creating registration entry: $m" + docker-compose exec -T spire-server \ + /opt/spire/bin/spire-server entry create \ + -parentID "spiffe://domain.test/spire/agent/x509pop/$(fingerprint conf/agent/agent.crt.pem)" \ + -spiffeID "spiffe://domain.test/workload-$m" \ + -selector "unix:uid:1001" \ + -ttl 0 & +done + +for ((m=1;m<=$SIZE;m++)); do + check-synced-entry "spire-agent" "spiffe://domain.test/workload-$m" +done diff --git a/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids new file mode 100755 index 0000000000..4bb53c55df --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/05-fetch-x509-svids @@ -0,0 +1,17 @@ +#!/bin/bash + +ENTRYCOUNT=10 +CACHESIZE=8 + +X509SVIDCOUNT=$(docker-compose exec -u 1001 -T spire-agent \ + /opt/spire/bin/spire-agent api fetch x509 \ + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed") + +if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then + fail-now "X.509-SVID check failed. 
Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; +else + log-info "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; +fi + +# Call the agent debug endpoint and check that extra X.509-SVIDs are cleaned up from the cache +check-x509-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-x509-svids/06-create-registration-entries b/test/integration/suites/fetch-x509-svids/06-create-registration-entries new file mode 100755 index 0000000000..f93ae19418 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/06-create-registration-entries @@ -0,0 +1,18 @@ +#!/bin/bash + +SIZE=10 + +# Create entries for uid 1002 +for ((m=1;m<=$SIZE;m++)); do + log-debug "creating registration entry: $m" + docker-compose exec -T spire-server \ + /opt/spire/bin/spire-server entry create \ + -parentID "spiffe://domain.test/spire/agent/x509pop/$(fingerprint conf/agent/agent.crt.pem)" \ + -spiffeID "spiffe://domain.test/workload/$m" \ + -selector "unix:uid:1002" \ + -ttl 0 & +done + +for ((m=1;m<=$SIZE;m++)); do + check-synced-entry "spire-agent" "spiffe://domain.test/workload/$m" +done diff --git a/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids new file mode 100755 index 0000000000..9a46e29602 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/07-fetch-x509-svids @@ -0,0 +1,27 @@ +#!/bin/bash + +CACHESIZE=8 +ENTRYCOUNT=10 + +X509SVIDCOUNT=$(docker-compose exec -u 1002 -T spire-agent \ + /opt/spire/bin/spire-agent api fetch x509 \ + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed") + +if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then + fail-now "X.509-SVID check failed. Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1002"; +else + log-info "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1002"; +fi + +X509SVIDCOUNT=$(docker-compose exec -u 1001 -T spire-agent \ + /opt/spire/bin/spire-agent api fetch x509 \ + -socketPath /opt/spire/sockets/workload_api.sock | grep -i "spiffe://domain.test" | wc -l || fail-now "X.509-SVID check failed") + +if [ "$X509SVIDCOUNT" -ne "$ENTRYCOUNT" ]; then + fail-now "X.509-SVID check failed. Expected $ENTRYCOUNT X.509-SVIDs but received $X509SVIDCOUNT for uid 1001"; +else + log-info "Expected $ENTRYCOUNT X.509-SVIDs and received $X509SVIDCOUNT for uid 1001"; +fi + +# Call the agent debug endpoint and check that extra X.509-SVIDs are cleaned up from the cache +check-x509-svid-count "spire-agent" $CACHESIZE diff --git a/test/integration/suites/fetch-x509-svids/README.md b/test/integration/suites/fetch-x509-svids/README.md new file mode 100644 index 0000000000..896ed8deeb --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/README.md @@ -0,0 +1,5 @@ +# Fetch X.509-SVID Suite + +## Description + +This suite validates X.509-SVID caching in the agent when the experimental LRU cache limit (`x509_svid_cache_max_size`) is configured: it creates more registration entries than the configured cache size, fetches X.509-SVIDs for each workload, and then verifies via the agent debug endpoint that the cached X.509-SVID count is trimmed back to the configured limit.
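For context on the helper the suite relies on: `check-x509-svid-count` parses `.svidsCount` from the JSON that `debugclient -testCase "printDebugPage"` prints, which in turn comes from the agent debug API `GetInfo` call shown above. The sketch below shows the same poll-until-match check done directly against the debug API in Go; the import path and the `GetSvidsCount()` accessor are assumptions inferred from the `svidsCount` field the suite parses, not verified against the SPIRE source tree.

```go
// Sketch: poll the agent debug endpoint until the cached X.509-SVID count
// matches the expected LRU cache size (mirrors check-x509-svid-count).
package debugcheck

import (
	"context"
	"fmt"
	"time"

	// Assumed import path for the agent debug API client.
	agent_debugv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/agent/debug/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func waitForSVIDCount(ctx context.Context, socketAddr string, want int32) error {
	conn, err := grpc.Dial(socketAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("failed to connect to agent debug socket: %w", err)
	}
	defer conn.Close()

	client := agent_debugv1.NewDebugClient(conn)
	for i := 0; i < 50; i++ { // same 50 checks x 1s interval budget as the shell helper
		resp, err := client.GetInfo(ctx, &agent_debugv1.GetInfoRequest{})
		if err != nil {
			return fmt.Errorf("failed to get info: %w", err)
		}
		if resp.GetSvidsCount() == want { // assumed accessor for the svidsCount field
			return nil
		}
		time.Sleep(time.Second)
	}
	return fmt.Errorf("timed out waiting for X.509-SVID count to reach %d", want)
}
```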
diff --git a/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf b/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf new file mode 100644 index 0000000000..bdbc803a95 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/conf/agent/agent.conf @@ -0,0 +1,31 @@ +agent { + data_dir = "/opt/spire/data/agent" + log_level = "DEBUG" + server_address = "spire-server" + server_port = "8081" + socket_path = "/opt/spire/sockets/workload_api.sock" + trust_bundle_path = "/opt/spire/conf/agent/bootstrap.crt" + trust_domain = "domain.test" + admin_socket_path = "/opt/debug.sock" + experimental { + x509_svid_cache_max_size = 8 + } +} + +plugins { + NodeAttestor "x509pop" { + plugin_data { + private_key_path = "/opt/spire/conf/agent/agent.key.pem" + certificate_path = "/opt/spire/conf/agent/agent.crt.pem" + } + } + KeyManager "disk" { + plugin_data { + directory = "/opt/spire/data/agent" + } + } + WorkloadAttestor "unix" { + plugin_data { + } + } +} diff --git a/test/integration/suites/fetch-x509-svids/conf/server/server.conf b/test/integration/suites/fetch-x509-svids/conf/server/server.conf new file mode 100644 index 0000000000..a8f18c0680 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/conf/server/server.conf @@ -0,0 +1,26 @@ +server { + bind_address = "0.0.0.0" + bind_port = "8081" + trust_domain = "domain.test" + data_dir = "/opt/spire/data/server" + log_level = "DEBUG" + ca_ttl = "1h" + default_svid_ttl = "10m" +} + +plugins { + DataStore "sql" { + plugin_data { + database_type = "sqlite3" + connection_string = "/opt/spire/data/server/datastore.sqlite3" + } + } + NodeAttestor "x509pop" { + plugin_data { + ca_bundle_path = "/opt/spire/conf/server/agent-cacert.pem" + } + } + KeyManager "memory" { + plugin_data = {} + } +} diff --git a/test/integration/suites/fetch-x509-svids/docker-compose.yaml b/test/integration/suites/fetch-x509-svids/docker-compose.yaml new file mode 100644 index 0000000000..0e67183c23 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/docker-compose.yaml @@ -0,0 +1,15 @@ +version: '3' +services: + spire-server: + image: spire-server:latest-local + hostname: spire-server + volumes: + - ./conf/server:/opt/spire/conf/server + command: ["-config", "/opt/spire/conf/server/server.conf"] + spire-agent: + image: spire-agent:latest-local + hostname: spire-agent + depends_on: ["spire-server"] + volumes: + - ./conf/agent:/opt/spire/conf/agent + command: ["-config", "/opt/spire/conf/agent/agent.conf"] diff --git a/test/integration/suites/fetch-x509-svids/teardown b/test/integration/suites/fetch-x509-svids/teardown new file mode 100755 index 0000000000..9953dcd3f9 --- /dev/null +++ b/test/integration/suites/fetch-x509-svids/teardown @@ -0,0 +1,6 @@ +#!/bin/bash + +if [ -z "$SUCCESS" ]; then + docker-compose logs +fi +docker-down
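A usage note on the manager API change exercised by the tests above: `SubscribeToCacheChanges` now takes a context and returns an error, since with a bounded cache the subscription may have to wait for SVIDs that are not yet cached (which is what the subscription test drives with `runSyncSVIDs` and the advancing mock clock). The sketch below shows the updated calling pattern; the `manager.Manager` interface name, the import paths, and the `*cache.WorkloadUpdate` element type of the updates channel are assumptions, and the selector value is purely illustrative.

```go
// Sketch of the updated SubscribeToCacheChanges calling pattern.
// Assumptions: import paths, the manager.Manager interface, and the
// *cache.WorkloadUpdate element type of the Updates() channel.
package example

import (
	"context"
	"fmt"
	"time"

	"github.com/spiffe/spire/pkg/agent/manager"
	"github.com/spiffe/spire/pkg/agent/manager/cache"
)

func firstWorkloadUpdate(m manager.Manager) (*cache.WorkloadUpdate, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The call can now fail (e.g. if the context expires before the cache
	// holds SVIDs for the matching entries), so the error must be handled.
	sub, err := m.SubscribeToCacheChanges(ctx, cache.Selectors{
		{Type: "unix", Value: "uid:1001"},
	})
	if err != nil {
		return nil, fmt.Errorf("subscribe to cache changes failed: %w", err)
	}
	defer sub.Finish()

	// The first update on the channel reflects the cache state at
	// subscription time; later updates arrive as the cache is synced.
	select {
	case update := <-sub.Updates():
		return update, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
```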