From 0eb16647b7f42298acc8cfbca4922b05005fd3b5 Mon Sep 17 00:00:00 2001 From: Kishor Joshi Date: Tue, 18 Oct 2022 10:07:26 -0700 Subject: [PATCH 1/4] refactor backend SG provider --- controllers/ingress/group_controller.go | 19 +- main.go | 3 +- pkg/deploy/elbv2/listener_manager.go | 3 +- pkg/ingress/class_utils.go | 14 + pkg/ingress/model_build_load_balancer.go | 59 +--- pkg/ingress/model_builder.go | 16 +- pkg/ingress/model_builder_test.go | 10 +- pkg/k8s/meta_utils.go | 12 + pkg/networking/backend_sg_provider.go | 131 ++++++-- pkg/networking/backend_sg_provider_mocks.go | 17 +- pkg/networking/backend_sg_provider_test.go | 309 +++++++++++++++++- pkg/networking/security_group_resolver.go | 104 ++++++ .../security_group_resolver_mocks.go | 50 +++ .../security_group_resolver_test.go | 275 ++++++++++++++++ scripts/gen_mocks.sh | 3 +- 15 files changed, 901 insertions(+), 124 deletions(-) create mode 100644 pkg/ingress/class_utils.go create mode 100644 pkg/k8s/meta_utils.go create mode 100644 pkg/networking/security_group_resolver.go create mode 100644 pkg/networking/security_group_resolver_mocks.go create mode 100644 pkg/networking/security_group_resolver_test.go diff --git a/controllers/ingress/group_controller.go b/controllers/ingress/group_controller.go index 00c036a48..35c5fcf49 100644 --- a/controllers/ingress/group_controller.go +++ b/controllers/ingress/group_controller.go @@ -45,7 +45,8 @@ const ( func NewGroupReconciler(cloud aws.Cloud, k8sClient client.Client, eventRecorder record.EventRecorder, finalizerManager k8s.FinalizerManager, networkingSGManager networkingpkg.SecurityGroupManager, networkingSGReconciler networkingpkg.SecurityGroupReconciler, subnetsResolver networkingpkg.SubnetsResolver, - controllerConfig config.ControllerConfig, backendSGProvider networkingpkg.BackendSGProvider, logger logr.Logger) *groupReconciler { + controllerConfig config.ControllerConfig, backendSGProvider networkingpkg.BackendSGProvider, + sgResolver networkingpkg.SecurityGroupResolver, logger logr.Logger) *groupReconciler { annotationParser := annotations.NewSuffixAnnotationParser(annotations.AnnotationPrefixIngress) authConfigBuilder := ingress.NewDefaultAuthConfigBuilder(annotationParser) @@ -58,7 +59,7 @@ func NewGroupReconciler(cloud aws.Cloud, k8sClient client.Client, eventRecorder annotationParser, subnetsResolver, authConfigBuilder, enhancedBackendBuilder, trackingProvider, elbv2TaggingManager, controllerConfig.FeatureGates, cloud.VpcID(), controllerConfig.ClusterName, controllerConfig.DefaultTags, controllerConfig.ExternalManagedTags, - controllerConfig.DefaultSSLPolicy, controllerConfig.DefaultTargetType, backendSGProvider, + controllerConfig.DefaultSSLPolicy, controllerConfig.DefaultTargetType, backendSGProvider, sgResolver, controllerConfig.EnableBackendSecurityGroup, controllerConfig.DisableRestrictedSGRules, controllerConfig.FeatureGates.Enabled(config.EnableIPTargetType), logger) stackMarshaller := deploy.NewDefaultStackMarshaller() stackDeployer := deploy.NewDefaultStackDeployer(cloud, k8sClient, networkingSGManager, networkingSGReconciler, @@ -144,12 +145,6 @@ func (r *groupReconciler) reconcile(ctx context.Context, req ctrl.Request) error } } - if len(ingGroup.Members) == 0 { - if err := r.backendSGProvider.Release(ctx); err != nil { - return err - } - } - if len(ingGroup.InactiveMembers) > 0 { if err := r.groupFinalizerManager.RemoveGroupFinalizer(ctx, ingGroupID, ingGroup.InactiveMembers); err != nil { r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedRemoveFinalizer, fmt.Sprintf("Failed remove finalizer due to %v", err)) @@ -162,7 +157,7 @@ func (r *groupReconciler) reconcile(ctx context.Context, req ctrl.Request) error } func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingress.Group) (core.Stack, *elbv2model.LoadBalancer, error) { - stack, lb, secrets, err := r.modelBuilder.Build(ctx, ingGroup) + stack, lb, secrets, backendSGRequired, err := r.modelBuilder.Build(ctx, ingGroup) if err != nil { r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedBuildModel, fmt.Sprintf("Failed build model due to %v", err)) return nil, nil, err @@ -180,7 +175,11 @@ func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingr } r.logger.Info("successfully deployed model", "ingressGroup", ingGroup.ID) r.secretsManager.MonitorSecrets(ingGroup.ID.String(), secrets) - return stack, lb, err + if err := r.backendSGProvider.Release(ctx, k8s.ToSliceOfMetaObject(ingress.ExtractIngresses(ingGroup.Members)), + k8s.ToSliceOfMetaObject(ingGroup.InactiveMembers), backendSGRequired); err != nil { + return nil, nil, err + } + return stack, lb, nil } func (r *groupReconciler) recordIngressGroupEvent(_ context.Context, ingGroup ingress.Group, eventType string, reason string, message string) { diff --git a/main.go b/main.go index 5042dffb1..ce68f0e24 100644 --- a/main.go +++ b/main.go @@ -111,9 +111,10 @@ func main() { mgr.GetEventRecorderFor("targetGroupBinding"), ctrl.Log) backendSGProvider := networking.NewBackendSGProvider(controllerCFG.ClusterName, controllerCFG.BackendSecurityGroup, cloud.VpcID(), cloud.EC2(), mgr.GetClient(), controllerCFG.DefaultTags, ctrl.Log.WithName("backend-sg-provider")) + sgResolver := networking.NewDefaultSecurityGroupResolver(cloud.EC2(), cloud.VpcID()) ingGroupReconciler := ingress.NewGroupReconciler(cloud, mgr.GetClient(), mgr.GetEventRecorderFor("ingress"), finalizerManager, sgManager, sgReconciler, subnetResolver, - controllerCFG, backendSGProvider, ctrl.Log.WithName("controllers").WithName("ingress")) + controllerCFG, backendSGProvider, sgResolver, ctrl.Log.WithName("controllers").WithName("ingress")) svcReconciler := service.NewServiceReconciler(cloud, mgr.GetClient(), mgr.GetEventRecorderFor("service"), finalizerManager, sgManager, sgReconciler, subnetResolver, vpcInfoProvider, controllerCFG, ctrl.Log.WithName("controllers").WithName("service")) diff --git a/pkg/deploy/elbv2/listener_manager.go b/pkg/deploy/elbv2/listener_manager.go index a3920e126..d756c6316 100644 --- a/pkg/deploy/elbv2/listener_manager.go +++ b/pkg/deploy/elbv2/listener_manager.go @@ -2,6 +2,8 @@ package elbv2 import ( "context" + "time" + awssdk "github.com/aws/aws-sdk-go/aws" elbv2sdk "github.com/aws/aws-sdk-go/service/elbv2" "github.com/go-logr/logr" @@ -15,7 +17,6 @@ import ( elbv2equality "sigs.k8s.io/aws-load-balancer-controller/pkg/equality/elbv2" elbv2model "sigs.k8s.io/aws-load-balancer-controller/pkg/model/elbv2" "sigs.k8s.io/aws-load-balancer-controller/pkg/runtime" - "time" ) // ListenerManager is responsible for create/update/delete Listener resources. diff --git a/pkg/ingress/class_utils.go b/pkg/ingress/class_utils.go new file mode 100644 index 000000000..9ec9f62f3 --- /dev/null +++ b/pkg/ingress/class_utils.go @@ -0,0 +1,14 @@ +package ingress + +import ( + networking "k8s.io/api/networking/v1" +) + +// ExtractIngresses returns the list of *networking.Ingress contained in the list of classifiedIngresses +func ExtractIngresses(classifiedIngresses []ClassifiedIngress) []*networking.Ingress { + result := make([]*networking.Ingress, len(classifiedIngresses)) + for _, v := range classifiedIngresses { + result = append(result, v.Ing) + } + return result +} diff --git a/pkg/ingress/model_build_load_balancer.go b/pkg/ingress/model_build_load_balancer.go index 4695c4b16..0660450da 100644 --- a/pkg/ingress/model_build_load_balancer.go +++ b/pkg/ingress/model_build_load_balancer.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "fmt" "regexp" - "strings" awssdk "github.com/aws/aws-sdk-go/aws" ec2sdk "github.com/aws/aws-sdk-go/service/ec2" @@ -284,11 +283,12 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont if !t.enableBackendSG { t.backendSGIDToken = managedSG.GroupID() } else { - backendSGID, err := t.backendSGProvider.Get(ctx) + backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(ExtractIngresses(t.ingGroup.Members))) if err != nil { return nil, err } t.backendSGIDToken = core.LiteralStringToken((backendSGID)) + t.backendSGAllocated = true lbSGTokens = append(lbSGTokens, t.backendSGIDToken) } t.logger.Info("Auto Create SG", "LB SGs", lbSGTokens, "backend SG", t.backendSGIDToken) @@ -297,7 +297,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont if err != nil { return nil, err } - frontendSGIDs, err := t.resolveSecurityGroupIDsViaNameOrIDSlice(ctx, sgNameOrIDsViaAnnotation) + frontendSGIDs, err := t.sgResolver.ResolveViaNameOrID(ctx, sgNameOrIDsViaAnnotation) if err != nil { return nil, err } @@ -309,11 +309,12 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont if !t.enableBackendSG { return nil, errors.New("backendSG feature is required to manage worker node SG rules when frontendSG manually specified") } - backendSGID, err := t.backendSGProvider.Get(ctx) + backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(ExtractIngresses(t.ingGroup.Members))) if err != nil { return nil, err } t.backendSGIDToken = core.LiteralStringToken(backendSGID) + t.backendSGAllocated = true lbSGTokens = append(lbSGTokens, t.backendSGIDToken) } t.logger.Info("SG configured via annotation", "LB SGs", lbSGTokens, "backend SG", t.backendSGIDToken) @@ -390,56 +391,6 @@ func (t *defaultModelBuildTask) buildLoadBalancerTags(_ context.Context) (map[st return algorithm.MergeStringMap(t.defaultTags, ingGroupTags), nil } -func (t *defaultModelBuildTask) resolveSecurityGroupIDsViaNameOrIDSlice(ctx context.Context, sgNameOrIDs []string) ([]string, error) { - var sgIDs []string - var sgNames []string - for _, nameOrID := range sgNameOrIDs { - if strings.HasPrefix(nameOrID, "sg-") { - sgIDs = append(sgIDs, nameOrID) - } else { - sgNames = append(sgNames, nameOrID) - } - } - var resolvedSGs []*ec2sdk.SecurityGroup - if len(sgIDs) > 0 { - req := &ec2sdk.DescribeSecurityGroupsInput{ - GroupIds: awssdk.StringSlice(sgIDs), - } - sgs, err := t.ec2Client.DescribeSecurityGroupsAsList(ctx, req) - if err != nil { - return nil, err - } - resolvedSGs = append(resolvedSGs, sgs...) - } - if len(sgNames) > 0 { - req := &ec2sdk.DescribeSecurityGroupsInput{ - Filters: []*ec2sdk.Filter{ - { - Name: awssdk.String("tag:Name"), - Values: awssdk.StringSlice(sgNames), - }, - { - Name: awssdk.String("vpc-id"), - Values: awssdk.StringSlice([]string{t.vpcID}), - }, - }, - } - sgs, err := t.ec2Client.DescribeSecurityGroupsAsList(ctx, req) - if err != nil { - return nil, err - } - resolvedSGs = append(resolvedSGs, sgs...) - } - resolvedSGIDs := make([]string, 0, len(resolvedSGs)) - for _, sg := range resolvedSGs { - resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId)) - } - if len(resolvedSGIDs) != len(sgNameOrIDs) { - return nil, errors.Errorf("couldn't find all securityGroups, nameOrIDs: %v, found: %v", sgNameOrIDs, resolvedSGIDs) - } - return resolvedSGIDs, nil -} - func buildLoadBalancerSubnetMappingsWithSubnets(subnets []*ec2sdk.Subnet) []elbv2model.SubnetMapping { subnetMappings := make([]elbv2model.SubnetMapping, 0, len(subnets)) for _, subnet := range subnets { diff --git a/pkg/ingress/model_builder.go b/pkg/ingress/model_builder.go index 5f72b1411..bdc348dcf 100644 --- a/pkg/ingress/model_builder.go +++ b/pkg/ingress/model_builder.go @@ -32,7 +32,7 @@ const ( // ModelBuilder is responsible for build mode stack for a IngressGroup. type ModelBuilder interface { // build mode stack for a IngressGroup. - Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, error) + Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, error) } // NewDefaultModelBuilder constructs new defaultModelBuilder. @@ -42,7 +42,8 @@ func NewDefaultModelBuilder(k8sClient client.Client, eventRecorder record.EventR authConfigBuilder AuthConfigBuilder, enhancedBackendBuilder EnhancedBackendBuilder, trackingProvider tracking.Provider, elbv2TaggingManager elbv2deploy.TaggingManager, featureGates config.FeatureGates, vpcID string, clusterName string, defaultTags map[string]string, externalManagedTags []string, defaultSSLPolicy string, defaultTargetType string, - backendSGProvider networkingpkg.BackendSGProvider, enableBackendSG bool, disableRestrictedSGRules bool, enableIPTargetType bool, logger logr.Logger) *defaultModelBuilder { + backendSGProvider networkingpkg.BackendSGProvider, sgResolver networkingpkg.SecurityGroupResolver, + enableBackendSG bool, disableRestrictedSGRules bool, enableIPTargetType bool, logger logr.Logger) *defaultModelBuilder { certDiscovery := NewACMCertDiscovery(acmClient, logger) ruleOptimizer := NewDefaultRuleOptimizer(logger) return &defaultModelBuilder{ @@ -54,6 +55,7 @@ func NewDefaultModelBuilder(k8sClient client.Client, eventRecorder record.EventR annotationParser: annotationParser, subnetsResolver: subnetsResolver, backendSGProvider: backendSGProvider, + sgResolver: sgResolver, certDiscovery: certDiscovery, authConfigBuilder: authConfigBuilder, enhancedBackendBuilder: enhancedBackendBuilder, @@ -86,6 +88,7 @@ type defaultModelBuilder struct { annotationParser annotations.Parser subnetsResolver networkingpkg.SubnetsResolver backendSGProvider networkingpkg.BackendSGProvider + sgResolver networkingpkg.SecurityGroupResolver certDiscovery CertDiscovery authConfigBuilder AuthConfigBuilder enhancedBackendBuilder EnhancedBackendBuilder @@ -105,7 +108,7 @@ type defaultModelBuilder struct { } // build mode stack for a IngressGroup. -func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, error) { +func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, error) { stack := core.NewDefaultStack(core.StackID(ingGroup.ID)) task := &defaultModelBuildTask{ k8sClient: b.k8sClient, @@ -123,6 +126,7 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.S elbv2TaggingManager: b.elbv2TaggingManager, featureGates: b.featureGates, backendSGProvider: b.backendSGProvider, + sgResolver: b.sgResolver, logger: b.logger, enableBackendSG: b.enableBackendSG, disableRestrictedSGRules: b.disableRestrictedSGRules, @@ -153,9 +157,9 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group) (core.S backendServices: make(map[types.NamespacedName]*corev1.Service), } if err := task.run(ctx); err != nil { - return nil, nil, nil, err + return nil, nil, nil, false, err } - return task.stack, task.loadBalancer, task.secretKeys, nil + return task.stack, task.loadBalancer, task.secretKeys, task.backendSGAllocated, nil } // the default model build task @@ -168,6 +172,7 @@ type defaultModelBuildTask struct { annotationParser annotations.Parser subnetsResolver networkingpkg.SubnetsResolver backendSGProvider networkingpkg.BackendSGProvider + sgResolver networkingpkg.SecurityGroupResolver certDiscovery CertDiscovery authConfigBuilder AuthConfigBuilder enhancedBackendBuilder EnhancedBackendBuilder @@ -181,6 +186,7 @@ type defaultModelBuildTask struct { sslRedirectConfig *SSLRedirectConfig stack core.Stack backendSGIDToken core.StringToken + backendSGAllocated bool enableBackendSG bool disableRestrictedSGRules bool enableIPTargetType bool diff --git a/pkg/ingress/model_builder_test.go b/pkg/ingress/model_builder_test.go index b3f5ae9f1..d6f3a3674 100644 --- a/pkg/ingress/model_builder_test.go +++ b/pkg/ingress/model_builder_test.go @@ -2920,13 +2920,14 @@ func Test_defaultModelBuilder_Build(t *testing.T) { trackingProvider := tracking.NewDefaultProvider("ingress.k8s.aws", clusterName) stackMarshaller := deploy.NewDefaultStackMarshaller() backendSGProvider := networkingpkg.NewMockBackendSGProvider(ctrl) + sgResolver := networkingpkg.NewDefaultSecurityGroupResolver(ec2Client, vpcID) if tt.fields.enableBackendSG { if len(tt.fields.backendSecurityGroup) > 0 { - backendSGProvider.EXPECT().Get(gomock.Any()).Return(tt.fields.backendSecurityGroup, nil).AnyTimes() + backendSGProvider.EXPECT().Get(gomock.Any(), gomock.Any()).Return(tt.fields.backendSecurityGroup, nil).AnyTimes() } else { - backendSGProvider.EXPECT().Get(gomock.Any()).Return("sg-auto", nil).AnyTimes() + backendSGProvider.EXPECT().Get(gomock.Any(), gomock.Any()).Return("sg-auto", nil).AnyTimes() } - backendSGProvider.EXPECT().Release(gomock.Any()).Return(nil).AnyTimes() + backendSGProvider.EXPECT().Release(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() } defaultTargetType := tt.defaultTargetType if defaultTargetType == "" { @@ -2941,6 +2942,7 @@ func Test_defaultModelBuilder_Build(t *testing.T) { clusterName: clusterName, annotationParser: annotationParser, subnetsResolver: subnetsResolver, + sgResolver: sgResolver, backendSGProvider: backendSGProvider, certDiscovery: certDiscovery, authConfigBuilder: authConfigBuilder, @@ -2962,7 +2964,7 @@ func Test_defaultModelBuilder_Build(t *testing.T) { b.enableIPTargetType = *tt.enableIPTargetType } - gotStack, _, _, err := b.Build(context.Background(), tt.args.ingGroup) + gotStack, _, _, _, err := b.Build(context.Background(), tt.args.ingGroup) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) } else { diff --git a/pkg/k8s/meta_utils.go b/pkg/k8s/meta_utils.go new file mode 100644 index 000000000..028afeeaa --- /dev/null +++ b/pkg/k8s/meta_utils.go @@ -0,0 +1,12 @@ +package k8s + +import metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + +// ToSliceOfMetaObject converts the input slice s to slice of metav1.Object +func ToSliceOfMetaObject[T metav1.Object](s []T) []metav1.Object { + result := make([]metav1.Object, len(s)) + for i, v := range s { + result[i] = v + } + return result +} diff --git a/pkg/networking/backend_sg_provider.go b/pkg/networking/backend_sg_provider.go index e8dc3d683..3b8a944b9 100644 --- a/pkg/networking/backend_sg_provider.go +++ b/pkg/networking/backend_sg_provider.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/hex" "fmt" + "reflect" "regexp" "sort" "strings" @@ -17,7 +18,9 @@ import ( "github.com/go-logr/logr" "github.com/pkg/errors" networking "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" "sigs.k8s.io/aws-load-balancer-controller/pkg/runtime" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -40,9 +43,10 @@ const ( // BackendSGProvider is responsible for providing backend security groups type BackendSGProvider interface { // Get returns the backend security group to use - Get(ctx context.Context) (string, error) + Get(ctx context.Context, activeResources []metav1.Object) (string, error) // Release cleans up the auto-generated backend SG if necessary - Release(ctx context.Context) error + Release(ctx context.Context, activeResources []metav1.Object, + inactiveResources []metav1.Object, backendSGRequired bool) error } // NewBackendSGProvider constructs a new defaultBackendSGProvider @@ -58,6 +62,15 @@ func NewBackendSGProvider(clusterName string, backendSG string, vpcID string, logger: logger, mutex: sync.Mutex{}, + checkIngressFinalizersFunc: func(finalizers []string) bool { + for _, fin := range finalizers { + if fin == implicitGroupFinalizer || strings.HasPrefix(fin, explicitGroupFinalizerPrefix) { + return true + } + } + return false + }, + defaultDeletionPollInterval: defaultSGDeletionPollInterval, defaultDeletionTimeout: defaultSGDeletionTimeout, } @@ -76,37 +89,111 @@ type defaultBackendSGProvider struct { ec2Client services.EC2 k8sClient client.Client logger logr.Logger + // objectsMap keeps track of whether the backend SG is required for any tracked resources in the cluster. + // If any entry in the map is true, or there are resources with this controller specific finalizers which + // haven't been tracked in the map yet, controller doesn't delete the backend SG. If the controller has + // processed all supported resources and none of them require backend SG, i.e. the values are false in this map + // controller deletes the backend SG. + objectsMap sync.Map + + checkIngressFinalizersFunc func([]string) bool defaultDeletionPollInterval time.Duration defaultDeletionTimeout time.Duration } -func (p *defaultBackendSGProvider) Get(ctx context.Context) (string, error) { +func (p *defaultBackendSGProvider) Get(ctx context.Context, activeResources []metav1.Object) (string, error) { if len(p.backendSG) > 0 { return p.backendSG, nil } // Auto generate Backend Security group, and return the id - if err := p.allocateBackendSG(ctx); err != nil { + if err := p.allocateBackendSG(ctx, activeResources); err != nil { p.logger.Error(err, "Failed to auto-create backend SG") return "", err } return p.autoGeneratedSG, nil } -func (p *defaultBackendSGProvider) Release(ctx context.Context) error { +func (p *defaultBackendSGProvider) Release(ctx context.Context, activeResources []metav1.Object, + inactiveResources []metav1.Object, backendSGRequired bool) error { if len(p.backendSG) > 0 { return nil } + defer func() { + for _, res := range inactiveResources { + p.objectsMap.Delete(getObjectKey(res)) + } + }() + p.updateObjectsMap(ctx, activeResources, inactiveResources, backendSGRequired) + p.logger.V(1).Info("release backend SG", "active", activeResources, + "inactive", inactiveResources, "needed", backendSGRequired) + if len(activeResources) > 0 && backendSGRequired { + return nil + } if required, err := p.isBackendSGRequired(ctx); required || err != nil { return err } return p.releaseSG(ctx) } -func (p *defaultBackendSGProvider) allocateBackendSG(ctx context.Context) error { +func (p *defaultBackendSGProvider) updateObjectsMap(_ context.Context, activeResources []metav1.Object, + inactiveResources []metav1.Object, backendSGRequired bool) { + for _, res := range inactiveResources { + p.objectsMap.Store(getObjectKey(res), false) + } + for _, res := range activeResources { + p.objectsMap.Store(getObjectKey(res), backendSGRequired) + } +} + +func (p *defaultBackendSGProvider) isBackendSGRequired(ctx context.Context) (bool, error) { + var requiredForAny bool + p.objectsMap.Range(func(_, v interface{}) bool { + if v.(bool) { + requiredForAny = true + return false + } + return true + }) + if requiredForAny { + return true, nil + } + if required, err := p.checkIngressListForUnmapped(ctx); required || err != nil { + return required, err + } + return false, nil +} + +func (p *defaultBackendSGProvider) checkIngressListForUnmapped(ctx context.Context) (bool, error) { + ingList := &networking.IngressList{} + if err := p.k8sClient.List(ctx, ingList); err != nil { + return true, errors.Wrapf(err, "unable to list ingresses") + } + for _, ing := range ingList.Items { + if !p.checkIngressFinalizersFunc(ing.GetFinalizers()) { + continue + } + if !p.existsInObjectMap(&ing) { + return true, nil + } + } + return false, nil +} + +func (p *defaultBackendSGProvider) existsInObjectMap(obj metav1.Object) bool { + if _, exists := p.objectsMap.Load(getObjectKey(obj)); exists { + return true + } + return false +} + +func (p *defaultBackendSGProvider) allocateBackendSG(ctx context.Context, activeResources []metav1.Object) error { p.mutex.Lock() defer p.mutex.Unlock() + for _, res := range activeResources { + p.objectsMap.Store(getObjectKey(res), true) + } if len(p.autoGeneratedSG) > 0 { return nil } @@ -128,12 +215,12 @@ func (p *defaultBackendSGProvider) allocateBackendSG(ctx context.Context) error Description: awssdk.String(sgDescription), TagSpecifications: p.buildBackendSGTags(ctx), } - p.logger.Info("creating securityGroup", "name", sgName) + p.logger.V(1).Info("creating securityGroup", "name", sgName) resp, err := p.ec2Client.CreateSecurityGroupWithContext(ctx, createReq) if err != nil { return err } - p.logger.V(1).Info("created SecurityGroup", "name", sgName, "id", resp.GroupId) + p.logger.Info("created SecurityGroup", "name", sgName, "id", resp.GroupId) p.autoGeneratedSG = awssdk.StringValue(resp.GroupId) return nil } @@ -194,42 +281,20 @@ func (p *defaultBackendSGProvider) getBackendSGFromEC2(ctx context.Context, sgNa return "", nil } -func (p *defaultBackendSGProvider) isBackendSGRequired(ctx context.Context) (bool, error) { - ingList := &networking.IngressList{} - if err := p.k8sClient.List(ctx, ingList); err != nil { - p.logger.Error(err, "Unable to list ingresses") - return true, errors.Wrapf(err, "unable to list ingresses") - } - for _, ing := range ingList.Items { - if !ing.DeletionTimestamp.IsZero() { - continue - } - for _, fin := range ing.GetFinalizers() { - if fin == implicitGroupFinalizer || strings.HasPrefix(fin, explicitGroupFinalizerPrefix) { - return true, nil - } - } - } - p.logger.Info("No ingress found, backend SG can be deleted", "SG ID", p.autoGeneratedSG) - return false, nil -} - func (p *defaultBackendSGProvider) releaseSG(ctx context.Context) error { p.mutex.Lock() defer p.mutex.Unlock() - if len(p.autoGeneratedSG) == 0 { return nil } if required, err := p.isBackendSGRequired(ctx); required || err != nil { - p.logger.V(1).Info("backend SG is required, releaseSG ignore delete") + p.logger.V(1).Info("releaseSG ignore delete", "required", required, "err", err) return err } req := &ec2sdk.DeleteSecurityGroupInput{ GroupId: awssdk.String(p.autoGeneratedSG), } - p.logger.V(1).Info("deleting default backend SG", "ID", p.autoGeneratedSG) if err := runtime.RetryImmediateOnError(p.defaultDeletionPollInterval, p.defaultDeletionTimeout, isSecurityGroupDependencyViolationError, func() error { _, err := p.ec2Client.DeleteSecurityGroupWithContext(ctx, req) return err @@ -267,3 +332,7 @@ func isEC2SecurityGroupNotFoundError(err error) bool { } return false } + +func getObjectKey(obj metav1.Object) string { + return reflect.TypeOf(obj).String() + "/" + k8s.NamespacedName(obj).String() +} diff --git a/pkg/networking/backend_sg_provider_mocks.go b/pkg/networking/backend_sg_provider_mocks.go index 0efba1a68..fafae1785 100644 --- a/pkg/networking/backend_sg_provider_mocks.go +++ b/pkg/networking/backend_sg_provider_mocks.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) // MockBackendSGProvider is a mock of BackendSGProvider interface. @@ -35,30 +36,30 @@ func (m *MockBackendSGProvider) EXPECT() *MockBackendSGProviderMockRecorder { } // Get mocks base method. -func (m *MockBackendSGProvider) Get(arg0 context.Context) (string, error) { +func (m *MockBackendSGProvider) Get(arg0 context.Context, arg1 []v1.Object) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0) + ret := m.ctrl.Call(m, "Get", arg0, arg1) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockBackendSGProviderMockRecorder) Get(arg0 interface{}) *gomock.Call { +func (mr *MockBackendSGProviderMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockBackendSGProvider)(nil).Get), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockBackendSGProvider)(nil).Get), arg0, arg1) } // Release mocks base method. -func (m *MockBackendSGProvider) Release(arg0 context.Context) error { +func (m *MockBackendSGProvider) Release(arg0 context.Context, arg1, arg2 []v1.Object, arg3 bool) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Release", arg0) + ret := m.ctrl.Call(m, "Release", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // Release indicates an expected call of Release. -func (mr *MockBackendSGProviderMockRecorder) Release(arg0 interface{}) *gomock.Call { +func (mr *MockBackendSGProviderMockRecorder) Release(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockBackendSGProvider)(nil).Release), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockBackendSGProvider)(nil).Release), arg0, arg1, arg2, arg3) } diff --git a/pkg/networking/backend_sg_provider_test.go b/pkg/networking/backend_sg_provider_test.go index eb941fd7a..d4ac29152 100644 --- a/pkg/networking/backend_sg_provider_test.go +++ b/pkg/networking/backend_sg_provider_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" networking "k8s.io/api/networking/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" mock_client "sigs.k8s.io/aws-load-balancer-controller/mocks/controller-runtime/client" @@ -256,7 +257,7 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { sgProvider := NewBackendSGProvider(defaultClusterName, tt.fields.backendSG, defaultVPCID, ec2Client, k8sClient, tt.fields.defaultTags, logr.New(&log.NullLogSink{})) - got, err := sgProvider.Get(context.Background()) + got, err := sgProvider.Get(context.Background(), nil) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else { @@ -275,17 +276,60 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { ingresses []*networking.Ingress err error } + type listServicesCall struct { + services []*corev1.Service + err error + } type deleteSecurityGroupWithContextCall struct { req *ec2sdk.DeleteSecurityGroupInput resp *ec2sdk.DeleteSecurityGroupOutput err error } + type mapItem struct { + key metav1.Object + value bool + } type fields struct { - autogenSG string - backendSG string - defaultTags map[string]string - listIngressCalls []listIngressCall - deleteSGCalls []deleteSecurityGroupWithContextCall + autogenSG string + backendSG string + defaultTags map[string]string + listIngressCalls []listIngressCall + deleteSGCalls []deleteSecurityGroupWithContextCall + listServicesCalls []listServicesCall + activeResources []metav1.Object + inactiveResources []metav1.Object + resourceMapItems []mapItem + backendSGRequiredForActive bool + } + ing := &networking.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "awesome-ns", + Name: "awesome-ing", + }, + } + ing1 := &networking.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "name", + }, + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "awesome-ns", + Name: "awesome-svc", + }, + } + svc1 := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "svc-1", + }, + } + svc2 := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "svc-2", + }, } tests := []struct { name string @@ -296,7 +340,8 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { { name: "backend sg specified via flags", fields: fields{ - backendSG: "sg-first", + backendSG: "sg-first", + inactiveResources: []metav1.Object{ing}, }, }, { @@ -308,6 +353,77 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { ingresses: []*networking.Ingress{}, }, }, + listServicesCalls: []listServicesCall{ + { + services: []*corev1.Service{}, + }, + }, + deleteSGCalls: []deleteSecurityGroupWithContextCall{ + { + req: &ec2sdk.DeleteSecurityGroupInput{ + GroupId: awssdk.String("sg-autogen"), + }, + resp: &ec2sdk.DeleteSecurityGroupOutput{}, + }, + }, + inactiveResources: []metav1.Object{ing}, + }, + }, + { + name: "backend sg required true, for ingress", + fields: fields{ + autogenSG: "sg-autogen", + resourceMapItems: []mapItem{ + { + key: svc2, + value: true, + }, + }, + activeResources: []metav1.Object{ing}, + }, + }, + { + name: "backend sg required true, for service", + fields: fields{ + autogenSG: "sg-autogen", + resourceMapItems: []mapItem{ + { + key: svc2, + value: true, + }, + }, + activeResources: []metav1.Object{svc}, + }, + }, + { + name: "backend sg requirement true for active resource", + fields: fields{ + listIngressCalls: []listIngressCall{ + {}, + }, + listServicesCalls: []listServicesCall{ + {}, + }, + backendSGRequiredForActive: true, + activeResources: []metav1.Object{ing}, + }, + }, + { + name: "backend sg not required for active ingress", + fields: fields{ + autogenSG: "sg-autogen", + backendSGRequiredForActive: false, + activeResources: []metav1.Object{ing}, + listIngressCalls: []listIngressCall{ + { + ingresses: []*networking.Ingress{}, + }, + }, + listServicesCalls: []listServicesCall{ + { + services: []*corev1.Service{}, + }, + }, deleteSGCalls: []deleteSecurityGroupWithContextCall{ { req: &ec2sdk.DeleteSecurityGroupInput{ @@ -341,6 +457,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { }, }, }, + inactiveResources: []metav1.Object{ing}, }, }, { @@ -360,6 +477,134 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { }, }, }, + inactiveResources: []metav1.Object{ing}, + }, + }, + { + name: "backend sg required for svc", + fields: fields{ + autogenSG: "sg-autogen", + listIngressCalls: []listIngressCall{ + {}, + }, + listServicesCalls: []listServicesCall{ + { + services: []*corev1.Service{ + { + ObjectMeta: metav1.ObjectMeta{ + Namespace: "awesome-ns", + Name: "svc-1", + Finalizers: []string{"service.k8s.aws/resources"}, + }, + }, + }, + }, + }, + inactiveResources: []metav1.Object{ing}, + deleteSGCalls: []deleteSecurityGroupWithContextCall{ + { + req: &ec2sdk.DeleteSecurityGroupInput{ + GroupId: awssdk.String("sg-autogen"), + }, + resp: &ec2sdk.DeleteSecurityGroupOutput{}, + }, + }, + }, + }, + { + name: "backend sg requirement for service already known", + fields: fields{ + autogenSG: "sg-autogen", + inactiveResources: []metav1.Object{ing}, + resourceMapItems: []mapItem{ + { + key: svc2, + value: true, + }, + }, + }, + }, + { + name: "backend sg requirement for ingress already known", + fields: fields{ + autogenSG: "sg-autogen", + inactiveResources: []metav1.Object{ing}, + resourceMapItems: []mapItem{ + { + key: ing1, + value: true, + }, + { + key: svc1, + value: false, + }, + { + key: svc2, + value: false, + }, + }, + }, + }, + { + name: "backend sg requirement all known, requires delete", + fields: fields{ + autogenSG: "sg-autogen", + listIngressCalls: []listIngressCall{ + { + ingresses: []*networking.Ingress{ + { + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "name", + Finalizers: []string{"ingress.k8s.aws/resources"}, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Namespace: "awesome-ns", + Name: "awesome-ing", + Finalizers: []string{"group.ingress.k8s.aws/awesome-group"}, + }, + }, + }, + }, + }, + listServicesCalls: []listServicesCall{ + { + services: []*corev1.Service{ + { + ObjectMeta: metav1.ObjectMeta{ + Namespace: "awesome-ns", + Name: "awesome-svc", + Finalizers: []string{"service.k8s.aws/resources"}, + }, + }, + }, + }, + }, + deleteSGCalls: []deleteSecurityGroupWithContextCall{ + { + req: &ec2sdk.DeleteSecurityGroupInput{ + GroupId: awssdk.String("sg-autogen"), + }, + resp: &ec2sdk.DeleteSecurityGroupOutput{}, + }, + }, + activeResources: []metav1.Object{svc}, + resourceMapItems: []mapItem{ + { + key: ing, + value: false, + }, + { + key: ing1, + value: false, + }, + { + key: svc, + value: false, + }, + }, }, }, { @@ -371,6 +616,11 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { ingresses: []*networking.Ingress{}, }, }, + listServicesCalls: []listServicesCall{ + { + services: []*corev1.Service{}, + }, + }, deleteSGCalls: []deleteSecurityGroupWithContextCall{ { req: &ec2sdk.DeleteSecurityGroupInput{ @@ -385,6 +635,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { resp: &ec2sdk.DeleteSecurityGroupOutput{}, }, }, + inactiveResources: []metav1.Object{ing}, }, }, { @@ -393,6 +644,8 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { autogenSG: "sg-autogen", listIngressCalls: []listIngressCall{ {}, + }, + listServicesCalls: []listServicesCall{ {}, }, deleteSGCalls: []deleteSecurityGroupWithContextCall{ @@ -403,11 +656,12 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { err: awserr.New("Something.Else", "unable to delete SG", nil), }, }, + inactiveResources: []metav1.Object{ing}, }, wantErr: errors.New("failed to delete securityGroup: Something.Else: unable to delete SG"), }, { - name: "k8s list returns error", + name: "k8s ingress list returns error", fields: fields{ autogenSG: "sg-autogen", listIngressCalls: []listIngressCall{ @@ -415,9 +669,33 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { err: errors.New("failed"), }, }, + inactiveResources: []metav1.Object{ing}, }, wantErr: errors.New("unable to list ingresses: failed"), }, + { + name: "k8s service list returns error", + fields: fields{ + autogenSG: "sg-autogen", + listIngressCalls: []listIngressCall{ + {}, + }, + listServicesCalls: []listServicesCall{ + { + err: errors.New("failed"), + }, + }, + inactiveResources: []metav1.Object{ing}, + deleteSGCalls: []deleteSecurityGroupWithContextCall{ + { + req: &ec2sdk.DeleteSecurityGroupInput{ + GroupId: awssdk.String("sg-autogen"), + }, + resp: &ec2sdk.DeleteSecurityGroupOutput{}, + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -432,6 +710,9 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { sgProvider.backendSG = "" sgProvider.autoGeneratedSG = tt.fields.autogenSG } + for _, item := range tt.fields.resourceMapItems { + sgProvider.objectsMap.Store(getObjectKey(item.key), item.value) + } var deleteCalls []*gomock.Call for _, call := range tt.fields.deleteSGCalls { deleteCalls = append(deleteCalls, ec2Client.EXPECT().DeleteSecurityGroupWithContext(context.Background(), call.req).Return(call.resp, call.err)) @@ -449,10 +730,20 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { }, ).AnyTimes() } + for _, call := range tt.fields.listServicesCalls { + k8sClient.EXPECT().List(gomock.Any(), &corev1.ServiceList{}, gomock.Any()).DoAndReturn( + func(ctx context.Context, svcList *corev1.ServiceList, opts ...client.ListOption) error { + for _, svc := range call.services { + svcList.Items = append(svcList.Items, *(svc.DeepCopy())) + } + return call.err + }, + ).AnyTimes() + } for _, ing := range tt.env.ingresses { assert.NoError(t, k8sClient.Create(context.Background(), ing.DeepCopy())) } - gotErr := sgProvider.Release(context.Background()) + gotErr := sgProvider.Release(context.Background(), tt.fields.activeResources, tt.fields.inactiveResources, tt.fields.backendSGRequiredForActive) if tt.wantErr != nil { assert.EqualError(t, gotErr, tt.wantErr.Error()) } else { diff --git a/pkg/networking/security_group_resolver.go b/pkg/networking/security_group_resolver.go new file mode 100644 index 000000000..402d1795f --- /dev/null +++ b/pkg/networking/security_group_resolver.go @@ -0,0 +1,104 @@ +package networking + +import ( + "context" + "strings" + + awssdk "github.com/aws/aws-sdk-go/aws" + ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" +) + +// SecurityGroupResolver is responsible for resolving the frontend security groups from the names or IDs +type SecurityGroupResolver interface { + // ResolveViaNameOrID resolves security groups from the security group names or the IDs + ResolveViaNameOrID(ctx context.Context, sgNameOrIDs []string) ([]string, error) +} + +// NewDefaultSecurityGroupResolver constructs new defaultSecurityGroupResolver. +func NewDefaultSecurityGroupResolver(ec2Client services.EC2, vpcID string) *defaultSecurityGroupResolver { + return &defaultSecurityGroupResolver{ + ec2Client: ec2Client, + vpcID: vpcID, + } +} + +var _ SecurityGroupResolver = &defaultSecurityGroupResolver{} + +// default implementation for SecurityGroupResolver +type defaultSecurityGroupResolver struct { + ec2Client services.EC2 + vpcID string +} + +func (r *defaultSecurityGroupResolver) ResolveViaNameOrID(ctx context.Context, sgNameOrIDs []string) ([]string, error) { + sgIDs, sgNames := r.splitIntoSgNameAndIDs(sgNameOrIDs) + var resolvedSGs []*ec2sdk.SecurityGroup + if len(sgIDs) > 0 { + sgs, err := r.resolveViaGroupID(ctx, sgIDs) + if err != nil { + return nil, err + } + resolvedSGs = append(resolvedSGs, sgs...) + } + if len(sgNames) > 0 { + sgs, err := r.resolveViaGroupName(ctx, sgNames) + if err != nil { + return nil, err + } + resolvedSGs = append(resolvedSGs, sgs...) + } + resolvedSGIDs := make([]string, 0, len(resolvedSGs)) + for _, sg := range resolvedSGs { + resolvedSGIDs = append(resolvedSGIDs, awssdk.StringValue(sg.GroupId)) + } + if len(resolvedSGIDs) != len(sgNameOrIDs) { + return nil, errors.Errorf("couldn't find all securityGroups, nameOrIDs: %v, found: %v", sgNameOrIDs, resolvedSGIDs) + } + return resolvedSGIDs, nil +} + +func (r *defaultSecurityGroupResolver) resolveViaGroupID(ctx context.Context, sgIDs []string) ([]*ec2sdk.SecurityGroup, error) { + req := &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice(sgIDs), + } + sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req) + if err != nil { + return nil, err + } + return sgs, nil +} + +func (r *defaultSecurityGroupResolver) resolveViaGroupName(ctx context.Context, sgNames []string) ([]*ec2sdk.SecurityGroup, error) { + req := &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice(sgNames), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{r.vpcID}), + }, + }, + } + sgs, err := r.ec2Client.DescribeSecurityGroupsAsList(ctx, req) + if err != nil { + return nil, err + } + return sgs, nil +} + +func (r *defaultSecurityGroupResolver) splitIntoSgNameAndIDs(sgNameOrIDs []string) ([]string, []string) { + var sgIDs []string + var sgNames []string + for _, nameOrID := range sgNameOrIDs { + if strings.HasPrefix(nameOrID, "sg-") { + sgIDs = append(sgIDs, nameOrID) + } else { + sgNames = append(sgNames, nameOrID) + } + } + return sgIDs, sgNames +} diff --git a/pkg/networking/security_group_resolver_mocks.go b/pkg/networking/security_group_resolver_mocks.go new file mode 100644 index 000000000..d294207db --- /dev/null +++ b/pkg/networking/security_group_resolver_mocks.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: sigs.k8s.io/aws-load-balancer-controller/pkg/networking (interfaces: SecurityGroupResolver) + +// Package networking is a generated GoMock package. +package networking + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockSecurityGroupResolver is a mock of SecurityGroupResolver interface. +type MockSecurityGroupResolver struct { + ctrl *gomock.Controller + recorder *MockSecurityGroupResolverMockRecorder +} + +// MockSecurityGroupResolverMockRecorder is the mock recorder for MockSecurityGroupResolver. +type MockSecurityGroupResolverMockRecorder struct { + mock *MockSecurityGroupResolver +} + +// NewMockSecurityGroupResolver creates a new mock instance. +func NewMockSecurityGroupResolver(ctrl *gomock.Controller) *MockSecurityGroupResolver { + mock := &MockSecurityGroupResolver{ctrl: ctrl} + mock.recorder = &MockSecurityGroupResolverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSecurityGroupResolver) EXPECT() *MockSecurityGroupResolverMockRecorder { + return m.recorder +} + +// ResolveViaNameOrID mocks base method. +func (m *MockSecurityGroupResolver) ResolveViaNameOrID(arg0 context.Context, arg1 []string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResolveViaNameOrID", arg0, arg1) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ResolveViaNameOrID indicates an expected call of ResolveViaNameOrID. +func (mr *MockSecurityGroupResolverMockRecorder) ResolveViaNameOrID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveViaNameOrID", reflect.TypeOf((*MockSecurityGroupResolver)(nil).ResolveViaNameOrID), arg0, arg1) +} diff --git a/pkg/networking/security_group_resolver_test.go b/pkg/networking/security_group_resolver_test.go new file mode 100644 index 000000000..ad155b75a --- /dev/null +++ b/pkg/networking/security_group_resolver_test.go @@ -0,0 +1,275 @@ +package networking + +import ( + "context" + "testing" + + awssdk "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + ec2sdk "github.com/aws/aws-sdk-go/service/ec2" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" +) + +func Test_defaultSecurityGroupResolver_ResolveViaNameOrID(t *testing.T) { + type describeSecurityGroupsAsListCall struct { + req *ec2sdk.DescribeSecurityGroupsInput + resp []*ec2sdk.SecurityGroup + err error + } + type args struct { + nameOrIDs []string + describeSGCalls []describeSecurityGroupsAsListCall + } + defaultVPCID := "vpc-xxyy" + tests := []struct { + name string + args args + want []string + wantErr error + }{ + { + name: "empty input", + }, + { + name: "group ids", + args: args{ + nameOrIDs: []string{ + "sg-xx1", + "sg-xx2", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice([]string{"sg-xx1", "sg-xx2"}), + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-xx1"), + }, + { + GroupId: awssdk.String("sg-xx2"), + }, + }, + }, + }, + }, + want: []string{ + "sg-xx1", + "sg-xx2", + }, + }, + { + name: "group names", + args: args{ + nameOrIDs: []string{ + "sg group one", + "sg group two", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{ + "sg group one", + "sg group two", + }), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, + }, + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-0912f63b"), + }, + { + GroupId: awssdk.String("sg-08982de7"), + }, + }, + }, + }, + }, + want: []string{ + "sg-08982de7", + "sg-0912f63b", + }, + }, + { + name: "mixed group name and id", + args: args{ + nameOrIDs: []string{ + "sg group one", + "sg-id1", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{ + "sg group one", + }), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, + }, + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-0912f63b"), + }, + }, + }, + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice([]string{"sg-id1"}), + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-id1"), + }, + }, + }, + }, + }, + want: []string{ + "sg-0912f63b", + "sg-id1", + }, + }, + { + name: "describe by id returns error", + args: args{ + nameOrIDs: []string{ + "sg group name", + "sg-id", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice([]string{"sg-id"}), + }, + err: awserr.New("Describe.Error", "unable to describe security groups", nil), + }, + }, + }, + wantErr: errors.New("Describe.Error: unable to describe security groups"), + }, + { + name: "describe by name returns error", + args: args{ + nameOrIDs: []string{ + "sg group name", + "sg-id", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{ + "sg group name", + }), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, + }, + }, + err: awserr.New("Describe.Error", "unable to describe security groups", nil), + }, + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice([]string{"sg-id"}), + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-id"), + }, + }, + }, + }, + }, + wantErr: errors.New("Describe.Error: unable to describe security groups"), + }, + { + name: "unable to resolve all security groups", + args: args{ + nameOrIDs: []string{ + "sg group one", + "sg-id1", + "sg-id404", + }, + describeSGCalls: []describeSecurityGroupsAsListCall{ + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + Filters: []*ec2sdk.Filter{ + { + Name: awssdk.String("tag:Name"), + Values: awssdk.StringSlice([]string{ + "sg group one", + }), + }, + { + Name: awssdk.String("vpc-id"), + Values: awssdk.StringSlice([]string{defaultVPCID}), + }, + }, + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-0912f63b"), + }, + }, + }, + { + req: &ec2sdk.DescribeSecurityGroupsInput{ + GroupIds: awssdk.StringSlice([]string{"sg-id1", "sg-id404"}), + }, + resp: []*ec2sdk.SecurityGroup{ + { + GroupId: awssdk.String("sg-id1"), + }, + }, + }, + }, + }, + wantErr: errors.New("couldn't find all securityGroups, nameOrIDs: [sg group one sg-id1 sg-id404], found: [sg-id1 sg-0912f63b]"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ec2Client := services.NewMockEC2(ctrl) + for _, call := range tt.args.describeSGCalls { + ec2Client.EXPECT().DescribeSecurityGroupsAsList(context.Background(), call.req).Return(call.resp, call.err) + } + r := &defaultSecurityGroupResolver{ + ec2Client: ec2Client, + vpcID: defaultVPCID, + } + got, err := r.ResolveViaNameOrID(context.Background(), tt.args.nameOrIDs) + if tt.wantErr != nil { + assert.EqualError(t, err, tt.wantErr.Error()) + } else { + assert.NoError(t, err) + assert.ElementsMatch(t, tt.want, got) + } + }) + } +} diff --git a/scripts/gen_mocks.sh b/scripts/gen_mocks.sh index c67a5e01b..00d24d39f 100755 --- a/scripts/gen_mocks.sh +++ b/scripts/gen_mocks.sh @@ -17,5 +17,6 @@ $MOCKGEN -package=networking -destination=./pkg/networking/az_info_provider_mock $MOCKGEN -package=networking -destination=./pkg/networking/node_info_provider_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking NodeInfoProvider $MOCKGEN -package=networking -destination=./pkg/networking/vpc_info_provider_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking VPCInfoProvider $MOCKGEN -package=networking -destination=./pkg/networking/backend_sg_provider_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking BackendSGProvider +$MOCKGEN -package=networking -destination=./pkg/networking/security_group_resolver_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking SecurityGroupResolver $MOCKGEN -package=ingress -destination=./pkg/ingress/cert_discovery_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/ingress CertDiscovery -$MOCKGEN -package=elbv2 -destination=./pkg/deploy/elbv2/tagging_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/elbv2 TaggingManager +$MOCKGEN -package=elbv2 -destination=./pkg/deploy/elbv2/tagging_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/elbv2 TaggingManager \ No newline at end of file From 513e741aa6eba5804322cae4da55d1041c6446d1 Mon Sep 17 00:00:00 2001 From: Kishor Joshi Date: Sun, 23 Apr 2023 22:33:10 -0700 Subject: [PATCH 2/4] fix ExtractIngresses array append --- pkg/ingress/class_utils.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/ingress/class_utils.go b/pkg/ingress/class_utils.go index 9ec9f62f3..ba7688701 100644 --- a/pkg/ingress/class_utils.go +++ b/pkg/ingress/class_utils.go @@ -7,8 +7,8 @@ import ( // ExtractIngresses returns the list of *networking.Ingress contained in the list of classifiedIngresses func ExtractIngresses(classifiedIngresses []ClassifiedIngress) []*networking.Ingress { result := make([]*networking.Ingress, len(classifiedIngresses)) - for _, v := range classifiedIngresses { - result = append(result, v.Ing) + for i, v := range classifiedIngresses { + result[i] = v.Ing } return result } From 7e96bf32bace00224bff43e8ebec16446f7e2e57 Mon Sep 17 00:00:00 2001 From: Kishor Joshi Date: Sun, 23 Apr 2023 22:43:02 -0700 Subject: [PATCH 3/4] make classifiedIngress type satisfy ObjectMetaAccessor --- controllers/ingress/group_controller.go | 2 +- pkg/ingress/class.go | 5 +++++ pkg/ingress/class_utils.go | 14 -------------- pkg/ingress/model_build_load_balancer.go | 4 ++-- pkg/k8s/meta_utils.go | 17 ++++++++++++++--- 5 files changed, 22 insertions(+), 20 deletions(-) delete mode 100644 pkg/ingress/class_utils.go diff --git a/controllers/ingress/group_controller.go b/controllers/ingress/group_controller.go index 35c5fcf49..bb9495676 100644 --- a/controllers/ingress/group_controller.go +++ b/controllers/ingress/group_controller.go @@ -175,7 +175,7 @@ func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingr } r.logger.Info("successfully deployed model", "ingressGroup", ingGroup.ID) r.secretsManager.MonitorSecrets(ingGroup.ID.String(), secrets) - if err := r.backendSGProvider.Release(ctx, k8s.ToSliceOfMetaObject(ingress.ExtractIngresses(ingGroup.Members)), + if err := r.backendSGProvider.Release(ctx, k8s.ToSliceOfMetaObject(ingGroup.Members), k8s.ToSliceOfMetaObject(ingGroup.InactiveMembers), backendSGRequired); err != nil { return nil, nil, err } diff --git a/pkg/ingress/class.go b/pkg/ingress/class.go index 192219461..64168e07a 100644 --- a/pkg/ingress/class.go +++ b/pkg/ingress/class.go @@ -2,6 +2,7 @@ package ingress import ( networking "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1" ) @@ -19,3 +20,7 @@ type ClassConfiguration struct { // The IngressClassParams for Ingress if any. IngClassParams *elbv2api.IngressClassParams } + +func (c ClassifiedIngress) GetObjectMeta() metav1.Object { + return c.Ing +} diff --git a/pkg/ingress/class_utils.go b/pkg/ingress/class_utils.go deleted file mode 100644 index ba7688701..000000000 --- a/pkg/ingress/class_utils.go +++ /dev/null @@ -1,14 +0,0 @@ -package ingress - -import ( - networking "k8s.io/api/networking/v1" -) - -// ExtractIngresses returns the list of *networking.Ingress contained in the list of classifiedIngresses -func ExtractIngresses(classifiedIngresses []ClassifiedIngress) []*networking.Ingress { - result := make([]*networking.Ingress, len(classifiedIngresses)) - for i, v := range classifiedIngresses { - result[i] = v.Ing - } - return result -} diff --git a/pkg/ingress/model_build_load_balancer.go b/pkg/ingress/model_build_load_balancer.go index 0660450da..94aefd44a 100644 --- a/pkg/ingress/model_build_load_balancer.go +++ b/pkg/ingress/model_build_load_balancer.go @@ -283,7 +283,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont if !t.enableBackendSG { t.backendSGIDToken = managedSG.GroupID() } else { - backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(ExtractIngresses(t.ingGroup.Members))) + backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(t.ingGroup.Members)) if err != nil { return nil, err } @@ -309,7 +309,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont if !t.enableBackendSG { return nil, errors.New("backendSG feature is required to manage worker node SG rules when frontendSG manually specified") } - backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(ExtractIngresses(t.ingGroup.Members))) + backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(t.ingGroup.Members)) if err != nil { return nil, err } diff --git a/pkg/k8s/meta_utils.go b/pkg/k8s/meta_utils.go index 028afeeaa..bdcd3b5ae 100644 --- a/pkg/k8s/meta_utils.go +++ b/pkg/k8s/meta_utils.go @@ -1,12 +1,23 @@ package k8s -import metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) // ToSliceOfMetaObject converts the input slice s to slice of metav1.Object -func ToSliceOfMetaObject[T metav1.Object](s []T) []metav1.Object { +func ToSliceOfMetaObject[T metav1.ObjectMetaAccessor](s []T) []metav1.Object { result := make([]metav1.Object, len(s)) for i, v := range s { - result[i] = v + result[i] = v.GetObjectMeta() + } + return result +} + +func ToSliceOfNamespacedNames[T metav1.ObjectMetaAccessor](s []T) []types.NamespacedName { + result := make([]types.NamespacedName, len(s)) + for i, v := range s { + result[i] = NamespacedName(v.GetObjectMeta()) } return result } From 017b94a9b1bdf49b7ebd3e738e86543135e4bbbf Mon Sep 17 00:00:00 2001 From: Kishor Joshi Date: Mon, 24 Apr 2023 10:16:44 -0700 Subject: [PATCH 4/4] refactor backend SG provider apis --- controllers/ingress/group_controller.go | 9 +- pkg/ingress/model_build_load_balancer.go | 4 +- pkg/ingress/model_builder_test.go | 6 +- pkg/k8s/meta_utils.go | 23 ----- pkg/k8s/utils.go | 9 ++ pkg/networking/backend_sg_provider.go | 60 ++++++----- pkg/networking/backend_sg_provider_mocks.go | 18 ++-- pkg/networking/backend_sg_provider_test.go | 104 +++++++++++++++----- 8 files changed, 138 insertions(+), 95 deletions(-) delete mode 100644 pkg/k8s/meta_utils.go diff --git a/controllers/ingress/group_controller.go b/controllers/ingress/group_controller.go index bb9495676..36c90e728 100644 --- a/controllers/ingress/group_controller.go +++ b/controllers/ingress/group_controller.go @@ -9,6 +9,7 @@ import ( corev1 "k8s.io/api/core/v1" networking "k8s.io/api/networking/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/record" elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1" @@ -175,8 +176,12 @@ func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingr } r.logger.Info("successfully deployed model", "ingressGroup", ingGroup.ID) r.secretsManager.MonitorSecrets(ingGroup.ID.String(), secrets) - if err := r.backendSGProvider.Release(ctx, k8s.ToSliceOfMetaObject(ingGroup.Members), - k8s.ToSliceOfMetaObject(ingGroup.InactiveMembers), backendSGRequired); err != nil { + var inactiveResources []types.NamespacedName + inactiveResources = append(inactiveResources, k8s.ToSliceOfNamespacedNames(ingGroup.InactiveMembers)...) + if !backendSGRequired { + inactiveResources = append(inactiveResources, k8s.ToSliceOfNamespacedNames(ingGroup.Members)...) + } + if err := r.backendSGProvider.Release(ctx, networkingpkg.ResourceTypeIngress, inactiveResources); err != nil { return nil, nil, err } return stack, lb, nil diff --git a/pkg/ingress/model_build_load_balancer.go b/pkg/ingress/model_build_load_balancer.go index 94aefd44a..fc592572d 100644 --- a/pkg/ingress/model_build_load_balancer.go +++ b/pkg/ingress/model_build_load_balancer.go @@ -283,7 +283,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont if !t.enableBackendSG { t.backendSGIDToken = managedSG.GroupID() } else { - backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(t.ingGroup.Members)) + backendSGID, err := t.backendSGProvider.Get(ctx, networking.ResourceTypeIngress, k8s.ToSliceOfNamespacedNames(t.ingGroup.Members)) if err != nil { return nil, err } @@ -309,7 +309,7 @@ func (t *defaultModelBuildTask) buildLoadBalancerSecurityGroups(ctx context.Cont if !t.enableBackendSG { return nil, errors.New("backendSG feature is required to manage worker node SG rules when frontendSG manually specified") } - backendSGID, err := t.backendSGProvider.Get(ctx, k8s.ToSliceOfMetaObject(t.ingGroup.Members)) + backendSGID, err := t.backendSGProvider.Get(ctx, networking.ResourceTypeIngress, k8s.ToSliceOfNamespacedNames(t.ingGroup.Members)) if err != nil { return nil, err } diff --git a/pkg/ingress/model_builder_test.go b/pkg/ingress/model_builder_test.go index d6f3a3674..e8400eef5 100644 --- a/pkg/ingress/model_builder_test.go +++ b/pkg/ingress/model_builder_test.go @@ -2923,11 +2923,11 @@ func Test_defaultModelBuilder_Build(t *testing.T) { sgResolver := networkingpkg.NewDefaultSecurityGroupResolver(ec2Client, vpcID) if tt.fields.enableBackendSG { if len(tt.fields.backendSecurityGroup) > 0 { - backendSGProvider.EXPECT().Get(gomock.Any(), gomock.Any()).Return(tt.fields.backendSecurityGroup, nil).AnyTimes() + backendSGProvider.EXPECT().Get(gomock.Any(), networkingpkg.ResourceType(networkingpkg.ResourceTypeIngress), gomock.Any()).Return(tt.fields.backendSecurityGroup, nil).AnyTimes() } else { - backendSGProvider.EXPECT().Get(gomock.Any(), gomock.Any()).Return("sg-auto", nil).AnyTimes() + backendSGProvider.EXPECT().Get(gomock.Any(), networkingpkg.ResourceType(networkingpkg.ResourceTypeIngress), gomock.Any()).Return("sg-auto", nil).AnyTimes() } - backendSGProvider.EXPECT().Release(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + backendSGProvider.EXPECT().Release(gomock.Any(), networkingpkg.ResourceType(networkingpkg.ResourceTypeIngress), gomock.Any()).Return(nil).AnyTimes() } defaultTargetType := tt.defaultTargetType if defaultTargetType == "" { diff --git a/pkg/k8s/meta_utils.go b/pkg/k8s/meta_utils.go deleted file mode 100644 index bdcd3b5ae..000000000 --- a/pkg/k8s/meta_utils.go +++ /dev/null @@ -1,23 +0,0 @@ -package k8s - -import ( - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" -) - -// ToSliceOfMetaObject converts the input slice s to slice of metav1.Object -func ToSliceOfMetaObject[T metav1.ObjectMetaAccessor](s []T) []metav1.Object { - result := make([]metav1.Object, len(s)) - for i, v := range s { - result[i] = v.GetObjectMeta() - } - return result -} - -func ToSliceOfNamespacedNames[T metav1.ObjectMetaAccessor](s []T) []types.NamespacedName { - result := make([]types.NamespacedName, len(s)) - for i, v := range s { - result[i] = NamespacedName(v.GetObjectMeta()) - } - return result -} diff --git a/pkg/k8s/utils.go b/pkg/k8s/utils.go index 75d144e65..1c04b765f 100644 --- a/pkg/k8s/utils.go +++ b/pkg/k8s/utils.go @@ -12,3 +12,12 @@ func NamespacedName(obj metav1.Object) types.NamespacedName { Name: obj.GetName(), } } + +// ToSliceOfNamespacedNames gets the slice of types.NamespacedName from the input slice s +func ToSliceOfNamespacedNames[T metav1.ObjectMetaAccessor](s []T) []types.NamespacedName { + result := make([]types.NamespacedName, len(s)) + for i, v := range s { + result[i] = NamespacedName(v.GetObjectMeta()) + } + return result +} diff --git a/pkg/networking/backend_sg_provider.go b/pkg/networking/backend_sg_provider.go index 3b8a944b9..8ec000f6b 100644 --- a/pkg/networking/backend_sg_provider.go +++ b/pkg/networking/backend_sg_provider.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/hex" "fmt" - "reflect" "regexp" "sort" "strings" @@ -18,7 +17,7 @@ import ( "github.com/go-logr/logr" "github.com/pkg/errors" networking "k8s.io/api/networking/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" "sigs.k8s.io/aws-load-balancer-controller/pkg/runtime" @@ -40,13 +39,19 @@ const ( sgDescription = "[k8s] Shared Backend SecurityGroup for LoadBalancer" ) +type ResourceType string + +const ( + ResourceTypeIngress = "ingress" + ResourceTypeService = "service" +) + // BackendSGProvider is responsible for providing backend security groups type BackendSGProvider interface { // Get returns the backend security group to use - Get(ctx context.Context, activeResources []metav1.Object) (string, error) + Get(ctx context.Context, resourceType ResourceType, activeResources []types.NamespacedName) (string, error) // Release cleans up the auto-generated backend SG if necessary - Release(ctx context.Context, activeResources []metav1.Object, - inactiveResources []metav1.Object, backendSGRequired bool) error + Release(ctx context.Context, resourceType ResourceType, inactiveResources []types.NamespacedName) error } // NewBackendSGProvider constructs a new defaultBackendSGProvider @@ -102,47 +107,40 @@ type defaultBackendSGProvider struct { defaultDeletionTimeout time.Duration } -func (p *defaultBackendSGProvider) Get(ctx context.Context, activeResources []metav1.Object) (string, error) { +func (p *defaultBackendSGProvider) Get(ctx context.Context, resourceType ResourceType, activeResources []types.NamespacedName) (string, error) { if len(p.backendSG) > 0 { return p.backendSG, nil } // Auto generate Backend Security group, and return the id - if err := p.allocateBackendSG(ctx, activeResources); err != nil { + if err := p.allocateBackendSG(ctx, resourceType, activeResources); err != nil { p.logger.Error(err, "Failed to auto-create backend SG") return "", err } return p.autoGeneratedSG, nil } -func (p *defaultBackendSGProvider) Release(ctx context.Context, activeResources []metav1.Object, - inactiveResources []metav1.Object, backendSGRequired bool) error { +func (p *defaultBackendSGProvider) Release(ctx context.Context, resourceType ResourceType, + inactiveResources []types.NamespacedName) error { if len(p.backendSG) > 0 { return nil } defer func() { for _, res := range inactiveResources { - p.objectsMap.Delete(getObjectKey(res)) + p.objectsMap.Delete(getObjectKey(resourceType, res)) } }() - p.updateObjectsMap(ctx, activeResources, inactiveResources, backendSGRequired) - p.logger.V(1).Info("release backend SG", "active", activeResources, - "inactive", inactiveResources, "needed", backendSGRequired) - if len(activeResources) > 0 && backendSGRequired { - return nil - } + p.updateObjectsMap(ctx, resourceType, inactiveResources, false) + p.logger.V(1).Info("release backend SG", "inactive", inactiveResources) if required, err := p.isBackendSGRequired(ctx); required || err != nil { return err } return p.releaseSG(ctx) } -func (p *defaultBackendSGProvider) updateObjectsMap(_ context.Context, activeResources []metav1.Object, - inactiveResources []metav1.Object, backendSGRequired bool) { - for _, res := range inactiveResources { - p.objectsMap.Store(getObjectKey(res), false) - } - for _, res := range activeResources { - p.objectsMap.Store(getObjectKey(res), backendSGRequired) +func (p *defaultBackendSGProvider) updateObjectsMap(_ context.Context, resourceType ResourceType, + resources []types.NamespacedName, backendSGRequired bool) { + for _, res := range resources { + p.objectsMap.Store(getObjectKey(resourceType, res), backendSGRequired) } } @@ -173,27 +171,25 @@ func (p *defaultBackendSGProvider) checkIngressListForUnmapped(ctx context.Conte if !p.checkIngressFinalizersFunc(ing.GetFinalizers()) { continue } - if !p.existsInObjectMap(&ing) { + if !p.existsInObjectMap(ResourceTypeIngress, k8s.NamespacedName(&ing)) { return true, nil } } return false, nil } -func (p *defaultBackendSGProvider) existsInObjectMap(obj metav1.Object) bool { - if _, exists := p.objectsMap.Load(getObjectKey(obj)); exists { +func (p *defaultBackendSGProvider) existsInObjectMap(resourceType ResourceType, resource types.NamespacedName) bool { + if _, exists := p.objectsMap.Load(getObjectKey(resourceType, resource)); exists { return true } return false } -func (p *defaultBackendSGProvider) allocateBackendSG(ctx context.Context, activeResources []metav1.Object) error { +func (p *defaultBackendSGProvider) allocateBackendSG(ctx context.Context, resourceType ResourceType, activeResources []types.NamespacedName) error { p.mutex.Lock() defer p.mutex.Unlock() - for _, res := range activeResources { - p.objectsMap.Store(getObjectKey(res), true) - } + p.updateObjectsMap(ctx, resourceType, activeResources, true) if len(p.autoGeneratedSG) > 0 { return nil } @@ -333,6 +329,6 @@ func isEC2SecurityGroupNotFoundError(err error) bool { return false } -func getObjectKey(obj metav1.Object) string { - return reflect.TypeOf(obj).String() + "/" + k8s.NamespacedName(obj).String() +func getObjectKey(resourceType ResourceType, resource types.NamespacedName) string { + return string(resourceType) + "/" + resource.String() } diff --git a/pkg/networking/backend_sg_provider_mocks.go b/pkg/networking/backend_sg_provider_mocks.go index fafae1785..adbd09c52 100644 --- a/pkg/networking/backend_sg_provider_mocks.go +++ b/pkg/networking/backend_sg_provider_mocks.go @@ -9,7 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + types "k8s.io/apimachinery/pkg/types" ) // MockBackendSGProvider is a mock of BackendSGProvider interface. @@ -36,30 +36,30 @@ func (m *MockBackendSGProvider) EXPECT() *MockBackendSGProviderMockRecorder { } // Get mocks base method. -func (m *MockBackendSGProvider) Get(arg0 context.Context, arg1 []v1.Object) (string, error) { +func (m *MockBackendSGProvider) Get(arg0 context.Context, arg1 ResourceType, arg2 []types.NamespacedName) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret := m.ctrl.Call(m, "Get", arg0, arg1, arg2) ret0, _ := ret[0].(string) ret1, _ := ret[1].(error) return ret0, ret1 } // Get indicates an expected call of Get. -func (mr *MockBackendSGProviderMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockBackendSGProviderMockRecorder) Get(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockBackendSGProvider)(nil).Get), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockBackendSGProvider)(nil).Get), arg0, arg1, arg2) } // Release mocks base method. -func (m *MockBackendSGProvider) Release(arg0 context.Context, arg1, arg2 []v1.Object, arg3 bool) error { +func (m *MockBackendSGProvider) Release(arg0 context.Context, arg1 ResourceType, arg2 []types.NamespacedName) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Release", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "Release", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // Release indicates an expected call of Release. -func (mr *MockBackendSGProviderMockRecorder) Release(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockBackendSGProviderMockRecorder) Release(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockBackendSGProvider)(nil).Release), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockBackendSGProvider)(nil).Release), arg0, arg1, arg2) } diff --git a/pkg/networking/backend_sg_provider_test.go b/pkg/networking/backend_sg_provider_test.go index d4ac29152..e0fb725b6 100644 --- a/pkg/networking/backend_sg_provider_test.go +++ b/pkg/networking/backend_sg_provider_test.go @@ -2,6 +2,9 @@ package networking import ( "context" + "k8s.io/apimachinery/pkg/types" + "reflect" + "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" "testing" "github.com/go-logr/logr" @@ -39,6 +42,8 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { } type fields struct { backendSG string + ingResources []*networking.Ingress + svcResource *corev1.Service defaultTags map[string]string describeSGCalls []describeSecurityGroupsAsListCall createSGCalls []createSecurityGroupWithContexCall @@ -57,6 +62,24 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { Values: awssdk.StringSlice([]string{"backend-sg"}), }, } + ing := &networking.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "awesome-ns", + Name: "awesome-ing", + }, + } + ing1 := &networking.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "ns", + Name: "name", + }, + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "awesome-ns", + Name: "awesome-svc", + }, + } tests := []struct { name string want string @@ -66,7 +89,8 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { { name: "backend sg enabled", fields: fields{ - backendSG: "sg-xxx", + backendSG: "sg-xxx", + ingResources: []*networking.Ingress{ing}, }, want: "sg-xxx", }, @@ -85,6 +109,7 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { }, }, }, + ingResources: []*networking.Ingress{ing, ing1}, }, want: "sg-autogen", }, @@ -126,6 +151,7 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { }, }, }, + ingResources: []*networking.Ingress{ing, ing1}, }, want: "sg-newauto", }, @@ -184,6 +210,7 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { "KubernetesCluster": defaultClusterName, "defaultTag": "specified", }, + svcResource: svc, }, want: "sg-newauto", }, @@ -198,6 +225,7 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { err: awserr.New("Some.Other.Error", "describe security group as list error", nil), }, }, + ingResources: []*networking.Ingress{ing}, }, wantErr: errors.New("Some.Other.Error: describe security group as list error"), }, @@ -237,6 +265,7 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { err: awserr.New("Create.Error", "unable to create security group", nil), }, }, + ingResources: []*networking.Ingress{ing1}, }, wantErr: errors.New("Create.Error: unable to create security group"), }, @@ -257,7 +286,14 @@ func Test_defaultBackendSGProvider_Get(t *testing.T) { sgProvider := NewBackendSGProvider(defaultClusterName, tt.fields.backendSG, defaultVPCID, ec2Client, k8sClient, tt.fields.defaultTags, logr.New(&log.NullLogSink{})) - got, err := sgProvider.Get(context.Background(), nil) + resourceType := ResourceTypeIngress + var activeResources []types.NamespacedName + if len(tt.fields.ingResources) > 0 { + activeResources = k8s.ToSliceOfNamespacedNames(tt.fields.ingResources) + } else { + activeResources = k8s.ToSliceOfNamespacedNames([]*corev1.Service{tt.fields.svcResource}) + } + got, err := sgProvider.Get(context.Background(), ResourceType(resourceType), activeResources) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else { @@ -296,8 +332,9 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { listIngressCalls []listIngressCall deleteSGCalls []deleteSecurityGroupWithContextCall listServicesCalls []listServicesCall - activeResources []metav1.Object - inactiveResources []metav1.Object + activeIngresses []*networking.Ingress + inactiveIngresses []*networking.Ingress + svcResource *corev1.Service resourceMapItems []mapItem backendSGRequiredForActive bool } @@ -341,7 +378,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { name: "backend sg specified via flags", fields: fields{ backendSG: "sg-first", - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, }, }, { @@ -366,7 +403,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { resp: &ec2sdk.DeleteSecurityGroupOutput{}, }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, }, }, { @@ -379,7 +416,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { value: true, }, }, - activeResources: []metav1.Object{ing}, + activeIngresses: []*networking.Ingress{ing}, }, }, { @@ -392,7 +429,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { value: true, }, }, - activeResources: []metav1.Object{svc}, + svcResource: svc, }, }, { @@ -404,16 +441,20 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { listServicesCalls: []listServicesCall{ {}, }, + resourceMapItems: []mapItem{ + { + key: ing, + value: true, + }, + }, backendSGRequiredForActive: true, - activeResources: []metav1.Object{ing}, }, }, { name: "backend sg not required for active ingress", fields: fields{ - autogenSG: "sg-autogen", - backendSGRequiredForActive: false, - activeResources: []metav1.Object{ing}, + autogenSG: "sg-autogen", + activeIngresses: []*networking.Ingress{ing}, listIngressCalls: []listIngressCall{ { ingresses: []*networking.Ingress{}, @@ -457,7 +498,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { }, }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, }, }, { @@ -477,7 +518,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { }, }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, }, }, { @@ -500,7 +541,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { }, }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, deleteSGCalls: []deleteSecurityGroupWithContextCall{ { req: &ec2sdk.DeleteSecurityGroupInput{ @@ -515,7 +556,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { name: "backend sg requirement for service already known", fields: fields{ autogenSG: "sg-autogen", - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, resourceMapItems: []mapItem{ { key: svc2, @@ -528,7 +569,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { name: "backend sg requirement for ingress already known", fields: fields{ autogenSG: "sg-autogen", - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, resourceMapItems: []mapItem{ { key: ing1, @@ -590,7 +631,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { resp: &ec2sdk.DeleteSecurityGroupOutput{}, }, }, - activeResources: []metav1.Object{svc}, + svcResource: svc, resourceMapItems: []mapItem{ { key: ing, @@ -635,7 +676,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { resp: &ec2sdk.DeleteSecurityGroupOutput{}, }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, }, }, { @@ -656,7 +697,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { err: awserr.New("Something.Else", "unable to delete SG", nil), }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, }, wantErr: errors.New("failed to delete securityGroup: Something.Else: unable to delete SG"), }, @@ -669,7 +710,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { err: errors.New("failed"), }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, }, wantErr: errors.New("unable to list ingresses: failed"), }, @@ -685,7 +726,7 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { err: errors.New("failed"), }, }, - inactiveResources: []metav1.Object{ing}, + inactiveIngresses: []*networking.Ingress{ing}, deleteSGCalls: []deleteSecurityGroupWithContextCall{ { req: &ec2sdk.DeleteSecurityGroupInput{ @@ -711,7 +752,11 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { sgProvider.autoGeneratedSG = tt.fields.autogenSG } for _, item := range tt.fields.resourceMapItems { - sgProvider.objectsMap.Store(getObjectKey(item.key), item.value) + var resourceType ResourceType = ResourceTypeIngress + if reflect.TypeOf(item).String() == "service" { + resourceType = ResourceTypeService + } + sgProvider.objectsMap.Store(getObjectKey(resourceType, k8s.NamespacedName(item.key)), item.value) } var deleteCalls []*gomock.Call for _, call := range tt.fields.deleteSGCalls { @@ -743,7 +788,18 @@ func Test_defaultBackendSGProvider_Release(t *testing.T) { for _, ing := range tt.env.ingresses { assert.NoError(t, k8sClient.Create(context.Background(), ing.DeepCopy())) } - gotErr := sgProvider.Release(context.Background(), tt.fields.activeResources, tt.fields.inactiveResources, tt.fields.backendSGRequiredForActive) + var inactiveResources []types.NamespacedName + var resourceType ResourceType = ResourceTypeIngress + if tt.fields.svcResource != nil { + resourceType = ResourceTypeService + inactiveResources = append(inactiveResources, k8s.NamespacedName(tt.fields.svcResource)) + } else { + inactiveResources = append(inactiveResources, k8s.ToSliceOfNamespacedNames(tt.fields.inactiveIngresses)...) + if !tt.fields.backendSGRequiredForActive { + inactiveResources = append(inactiveResources, k8s.ToSliceOfNamespacedNames(tt.fields.activeIngresses)...) + } + } + gotErr := sgProvider.Release(context.Background(), resourceType, k8s.ToSliceOfNamespacedNames(tt.fields.inactiveIngresses)) if tt.wantErr != nil { assert.EqualError(t, gotErr, tt.wantErr.Error()) } else {