diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index e90d0d320bdb4..8359fc7c532cb 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -363,9 +363,10 @@ func (m *monitoredDatabases) setCloud(databases types.Databases) { m.cloud = databases } -func (m *monitoredDatabases) isCloud(database types.Database) bool { - m.mu.RLock() - defer m.mu.RUnlock() +// isCloud_Locked returns whether a database was discovered by the cloud +// watchers, aka legacy database discovery done by the db service. +// The lock must be held when calling this function. +func (m *monitoredDatabases) isCloud_Locked(database types.Database) bool { for i := range m.cloud { if m.cloud[i] == database { return true @@ -374,13 +375,17 @@ func (m *monitoredDatabases) isCloud(database types.Database) bool { return false } -func (m *monitoredDatabases) isDiscoveryResource(database types.Database) bool { - return database.Origin() == types.OriginCloud && m.isResource(database) +// isDiscoveryResource_Locked returns whether a database was discovered by the +// discovery service. +// The lock must be held when calling this function. +func (m *monitoredDatabases) isDiscoveryResource_Locked(database types.Database) bool { + return database.Origin() == types.OriginCloud && m.isResource_Locked(database) } -func (m *monitoredDatabases) isResource(database types.Database) bool { - m.mu.RLock() - defer m.mu.RUnlock() +// isResource_Locked returns whether a database is a dynamic database, aka a db +// object. +// The lock must be held when calling this function. +func (m *monitoredDatabases) isResource_Locked(database types.Database) bool { for i := range m.resources { if m.resources[i] == database { return true @@ -389,9 +394,9 @@ func (m *monitoredDatabases) isResource(database types.Database) bool { return false } -func (m *monitoredDatabases) get() map[string]types.Database { - m.mu.RLock() - defer m.mu.RUnlock() +// getLocked returns a slice containing all of the monitored databases. +// The lock must be held when calling this function. +func (m *monitoredDatabases) getLocked() map[string]types.Database { return utils.FromSlice(append(append(m.static, m.resources...), m.cloud...), types.Database.GetName) } diff --git a/lib/srv/db/watcher.go b/lib/srv/db/watcher.go index 575fd2413ffcf..9f61a6a8fae4c 100644 --- a/lib/srv/db/watcher.go +++ b/lib/srv/db/watcher.go @@ -39,7 +39,7 @@ func (s *Server) startReconciler(ctx context.Context) error { reconciler, err := services.NewReconciler(services.ReconcilerConfig[types.Database]{ Matcher: s.matcher, GetCurrentResources: s.getResources, - GetNewResources: s.monitoredDatabases.get, + GetNewResources: s.monitoredDatabases.getLocked, OnCreate: s.onCreate, OnUpdate: s.onUpdate, OnDelete: s.onDelete, @@ -52,12 +52,15 @@ func (s *Server) startReconciler(ctx context.Context) error { for { select { case <-s.reconcileCh: + // don't let monitored dbs change during reconciliation + s.monitoredDatabases.mu.RLock() if err := reconciler.Reconcile(ctx); err != nil { s.log.ErrorContext(ctx, "Failed to reconcile.", "error", err) } if s.cfg.OnReconcile != nil { s.cfg.OnReconcile(s.getProxiedDatabases()) } + s.monitoredDatabases.mu.RUnlock() case <-ctx.Done(): s.log.DebugContext(ctx, "Reconciler done.") return @@ -167,11 +170,15 @@ func (s *Server) onCreate(ctx context.Context, database types.Database) error { // copy here so that any attribute changes to the proxied database will not // affect database objects tracked in s.monitoredDatabases. databaseCopy := database.Copy() - applyResourceMatchersToDatabase(databaseCopy, s.cfg.ResourceMatchers) + + // only apply resource matcher settings to dynamic resources. + if s.monitoredDatabases.isResource_Locked(database) { + s.applyAWSResourceMatcherSettings(databaseCopy) + } // Run DiscoveryResourceChecker after resource matchers are applied to make // sure the correct AssumeRoleARN is used. - if s.monitoredDatabases.isDiscoveryResource(database) { + if s.monitoredDatabases.isDiscoveryResource_Locked(database) { if err := s.cfg.discoveryResourceChecker.Check(ctx, databaseCopy); err != nil { return trace.Wrap(err) } @@ -185,7 +192,11 @@ func (s *Server) onUpdate(ctx context.Context, database, _ types.Database) error // copy here so that any attribute changes to the proxied database will not // affect database objects tracked in s.monitoredDatabases. databaseCopy := database.Copy() - applyResourceMatchersToDatabase(databaseCopy, s.cfg.ResourceMatchers) + + // only apply resource matcher settings to dynamic resources. + if s.monitoredDatabases.isResource_Locked(database) { + s.applyAWSResourceMatcherSettings(databaseCopy) + } return s.updateDatabase(ctx, databaseCopy) } @@ -198,7 +209,7 @@ func (s *Server) onDelete(ctx context.Context, database types.Database) error { func (s *Server) matcher(database types.Database) bool { // In the case of databases discovered by this database server, matchers // should be skipped. - if s.monitoredDatabases.isCloud(database) { + if s.monitoredDatabases.isCloud_Locked(database) { return true // Cloud fetchers return only matching databases. } @@ -207,12 +218,18 @@ func (s *Server) matcher(database types.Database) bool { return services.MatchResourceLabels(s.cfg.ResourceMatchers, database.GetAllLabels()) } -func applyResourceMatchersToDatabase(database types.Database, resourceMatchers []services.ResourceMatcher) { - for _, matcher := range resourceMatchers { +func (s *Server) applyAWSResourceMatcherSettings(database types.Database) { + if !database.IsAWSHosted() { + // dynamic matchers only apply AWS settings (for now), so skip non-AWS + // databases. + return + } + dbLabels := database.GetAllLabels() + for _, matcher := range s.cfg.ResourceMatchers { if len(matcher.Labels) == 0 || matcher.AWS.AssumeRoleARN == "" { continue } - if match, _, _ := services.MatchLabels(matcher.Labels, database.GetAllLabels()); !match { + if match, _, _ := services.MatchLabels(matcher.Labels, dbLabels); !match { continue } diff --git a/lib/srv/db/watcher_test.go b/lib/srv/db/watcher_test.go index 1843816442253..aa6635420589d 100644 --- a/lib/srv/db/watcher_test.go +++ b/lib/srv/db/watcher_test.go @@ -21,6 +21,7 @@ package db import ( "context" "fmt" + "maps" "sort" "testing" "time" @@ -60,11 +61,13 @@ func TestWatcher(t *testing.T) { // watches for databases with label group=a. testCtx.setupDatabaseServer(ctx, t, agentParams{ Databases: []types.Database{db0}, - ResourceMatchers: []services.ResourceMatcher{ - {Labels: types.Labels{ + ResourceMatchers: []services.ResourceMatcher{{ + Labels: types.Labels{ "group": []string{"a"}, - }}, - }, + }, + // these should not be applied to non-AWS databases. + AWS: services.ResourceMatcherAWS{AssumeRoleARN: "some-role", ExternalID: "some-externalid"}, + }}, OnReconcile: func(d types.Databases) { reconcileCh <- d }, @@ -137,7 +140,7 @@ func TestWatcher(t *testing.T) { // ResourceMatchers should be always evaluated for the dynamic registered // resources. func TestWatcherDynamicResource(t *testing.T) { - var db1, db2, db3, db4, db5 *types.DatabaseV3 + var db1, db2, db3, db4, db5, db6 *types.DatabaseV3 ctx := context.Background() testCtx := setupTestContext(ctx, t) @@ -247,6 +250,7 @@ func TestWatcherDynamicResource(t *testing.T) { // ResourceMatchers and has AssumeRoleARN set by the discovery service. discoveredDB5, err := makeDiscoveryDatabase("db5", map[string]string{"group": "b"}, withRDSURL, withDiscoveryAssumeRoleARN) require.NoError(t, err) + require.True(t, discoveredDB5.IsAWSHosted()) require.True(t, discoveredDB5.IsRDS()) err = testCtx.authServer.CreateDatabase(ctx, discoveredDB5) @@ -260,6 +264,22 @@ func TestWatcherDynamicResource(t *testing.T) { assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5}) }) + t.Run("non-AWS discovery resource - AssumeRoleARN not applied", func(t *testing.T) { + // Created a discovery service created database resource that matches + // ResourceMatchers but is not an AWS database + _, azureDB := makeAzureSQLServer(t, "discovery-azure", "group") + setLabels(azureDB, map[string]string{"group": "b"}) + azureDB.SetOrigin(types.OriginCloud) + require.False(t, azureDB.IsAWSHosted()) + require.True(t, azureDB.GetAWS().IsEmpty()) + require.True(t, azureDB.IsAzure()) + err = testCtx.authServer.CreateDatabase(ctx, azureDB) + require.NoError(t, err) + + db6 = azureDB.Copy() + assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5, db6}) + }) + t.Run("discovery resource - fail check", func(t *testing.T) { // Created a discovery service created database resource that fails the // fakeDiscoveryResourceChecker. @@ -268,18 +288,16 @@ func TestWatcherDynamicResource(t *testing.T) { require.NoError(t, testCtx.authServer.CreateDatabase(ctx, dbFailCheck)) // dbFailCheck should not be proxied. - assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5}) + assertReconciledResource(t, reconcileCh, types.Databases{db0, db2, db4, db5, db6}) }) } -func setDiscoveryGroupLabel(r types.ResourceWithLabels, discoveryGroup string) { +func setLabels(r types.ResourceWithLabels, newLabels map[string]string) { staticLabels := r.GetStaticLabels() if staticLabels == nil { staticLabels = make(map[string]string) } - if discoveryGroup != "" { - staticLabels[types.TeleportInternalDiscoveryGroupName] = discoveryGroup - } + maps.Copy(staticLabels, newLabels) r.SetStaticLabels(staticLabels) } @@ -292,13 +310,14 @@ func TestWatcherCloudFetchers(t *testing.T) { redshiftServerlessDatabase, err := discovery.NewDatabaseFromRedshiftServerlessWorkgroup(redshiftServerlessWorkgroup, nil) require.NoError(t, err) redshiftServerlessDatabase.SetStatusAWS(redshiftServerlessDatabase.GetAWS()) - setDiscoveryGroupLabel(redshiftServerlessDatabase, "") redshiftServerlessDatabase.SetOrigin(types.OriginCloud) discovery.ApplyAWSDatabaseNameSuffix(redshiftServerlessDatabase, types.AWSMatcherRedshiftServerless) + require.Empty(t, redshiftServerlessDatabase.GetAWS().AssumeRoleARN) + require.Empty(t, redshiftServerlessDatabase.GetAWS().ExternalID) // Test an Azure fetcher. azSQLServer, azSQLServerDatabase := makeAzureSQLServer(t, "discovery-azure", "group") - setDiscoveryGroupLabel(azSQLServerDatabase, "") azSQLServerDatabase.SetOrigin(types.OriginCloud) + require.False(t, azSQLServerDatabase.IsAWSHosted()) ctx := context.Background() testCtx := setupTestContext(ctx, t) @@ -308,6 +327,13 @@ func TestWatcherCloudFetchers(t *testing.T) { OnReconcile: func(d types.Databases) { reconcileCh <- d }, + ResourceMatchers: []services.ResourceMatcher{{ + Labels: types.Labels{types.Wildcard: []string{types.Wildcard}}, + AWS: services.ResourceMatcherAWS{ + AssumeRoleARN: "role-arn", + ExternalID: "external-id", + }, + }}, CloudClients: &clients.TestCloudClients{ RDS: &mocks.RDSMockUnauth{}, // Access denied error should not affect other fetchers. RedshiftServerless: &mocks.RedshiftServerlessMock{ @@ -341,7 +367,7 @@ func assertReconciledResource(t *testing.T, ch chan types.Databases, databases t select { case d := <-ch: sort.Sort(d) - require.Equal(t, len(d), len(databases)) + require.Equal(t, len(databases), len(d)) require.Empty(t, cmp.Diff(databases, d, cmpopts.IgnoreFields(types.Metadata{}, "Revision"), cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"),