diff --git a/cmd/controller/main.go b/cmd/controller/main.go index 06f0366d0227..e41c110ca1e0 100644 --- a/cmd/controller/main.go +++ b/cmd/controller/main.go @@ -41,7 +41,13 @@ func main() { EventRecorder: operator.EventRecorder, StartAsync: operator.Elected(), }) - awsCloudProvider := cloudprovider.New(awsCtx) + awsCloudProvider := cloudprovider.New( + awsCtx, + awsCtx.InstanceTypesProvider, + awsCtx.InstanceProvider, + awsCtx.KubeClient, + awsCtx.AMIProvider, + ) lo.Must0(operator.AddHealthzCheck("cloud-provider", awsCloudProvider.LivenessProbe)) cloudProvider := metrics.Decorate(awsCloudProvider) diff --git a/hack/docs/instancetypes_gen_docs.go b/hack/docs/instancetypes_gen_docs.go index 826ee621297a..9b9d4b671172 100644 --- a/hack/docs/instancetypes_gen_docs.go +++ b/hack/docs/instancetypes_gen_docs.go @@ -233,9 +233,10 @@ func (f kubeDnsTransport) RoundTrip(request *http.Request) (*http.Response, erro } func NewAWSCloudProviderForCodeGen(ctx context.Context) *awscloudprovider.CloudProvider { - return awscloudprovider.New(awscontext.NewOrDie(cloudprovider.Context{ + context := awscontext.NewOrDie(cloudprovider.Context{ Context: ctx, RESTConfig: &rest.Config{}, KubernetesInterface: lo.Must(kubernetes.NewForConfigAndClient(&rest.Config{}, &http.Client{Transport: &kubeDnsTransport{}})), - })) + }) + return awscloudprovider.New(context, context.InstanceTypesProvider, context.InstanceProvider, context.KubeClient, context.AMIProvider) } diff --git a/pkg/cloudprovider/cloudprovider.go b/pkg/cloudprovider/cloudprovider.go index 2c3335193c8e..c01b553a094b 100644 --- a/pkg/cloudprovider/cloudprovider.go +++ b/pkg/cloudprovider/cloudprovider.go @@ -42,8 +42,8 @@ import ( "knative.dev/pkg/logging" "sigs.k8s.io/controller-runtime/pkg/client" - awscontext "github.com/aws/karpenter/pkg/context" "github.com/aws/karpenter/pkg/providers/amifamily" + "github.com/aws/karpenter/pkg/providers/instance" "github.com/aws/karpenter/pkg/providers/instancetype" coreapis "github.com/aws/karpenter-core/pkg/apis" @@ -51,11 +51,6 @@ import ( "github.com/aws/karpenter-core/pkg/cloudprovider" ) -const ( - // MaxInstanceTypes defines the number of instance type options to pass to CreateFleet - MaxInstanceTypes = 60 -) - func init() { v1alpha5.NormalizedLabels = lo.Assign(v1alpha5.NormalizedLabels, map[string]string{"topology.ebs.csi.aws.com/zone": v1.LabelTopologyZone}) coreapis.Settings = append(coreapis.Settings, apis.Settings...) @@ -65,25 +60,18 @@ var _ cloudprovider.CloudProvider = (*CloudProvider)(nil) type CloudProvider struct { instanceTypeProvider *instancetype.Provider - instanceProvider *InstanceProvider + instanceProvider *instance.Provider kubeClient client.Client amiProvider *amifamily.Provider } -func New(ctx awscontext.Context) *CloudProvider { +func New(ctx context.Context, instanceTypeProvider *instancetype.Provider, + instanceProvider *instance.Provider, kubeClient client.Client, amiProvider *amifamily.Provider) *CloudProvider { return &CloudProvider{ - kubeClient: ctx.KubeClient, - instanceTypeProvider: ctx.InstanceTypesProvider, - amiProvider: ctx.AMIProvider, - instanceProvider: NewInstanceProvider( - ctx, - aws.StringValue(ctx.Session.Config.Region), - ctx.EC2API, - ctx.UnavailableOfferingsCache, - ctx.InstanceTypesProvider, - ctx.SubnetProvider, - ctx.LaunchTemplateProvider, - ), + instanceTypeProvider: instanceTypeProvider, + instanceProvider: instanceProvider, + kubeClient: kubeClient, + amiProvider: amiProvider, } } @@ -315,7 +303,7 @@ func (c *CloudProvider) resolveProvisionerFromInstance(ctx context.Context, inst return provisioner, nil } -func (c *CloudProvider) instanceToMachine(ctx context.Context, instance *ec2.Instance, instanceType *cloudprovider.InstanceType) *v1alpha5.Machine { +func (c *CloudProvider) instanceToMachine(ctx context.Context, ec2instance *ec2.Instance, instanceType *cloudprovider.InstanceType) *v1alpha5.Machine { machine := &v1alpha5.Machine{} labels := map[string]string{} @@ -328,22 +316,22 @@ func (c *CloudProvider) instanceToMachine(ctx context.Context, instance *ec2.Ins machine.Status.Capacity = functional.FilterMap(instanceType.Capacity, func(_ v1.ResourceName, v resource.Quantity) bool { return !resources.IsZero(v) }) machine.Status.Allocatable = functional.FilterMap(instanceType.Allocatable(), func(_ v1.ResourceName, v resource.Quantity) bool { return !resources.IsZero(v) }) } - labels[v1alpha1.LabelInstanceAMIID] = aws.StringValue(instance.ImageId) - labels[v1.LabelTopologyZone] = aws.StringValue(instance.Placement.AvailabilityZone) - labels[v1alpha5.LabelCapacityType] = getCapacityType(instance) - if tag, ok := lo.Find(instance.Tags, func(t *ec2.Tag) bool { return aws.StringValue(t.Key) == v1alpha5.ProvisionerNameLabelKey }); ok { + labels[v1alpha1.LabelInstanceAMIID] = aws.StringValue(ec2instance.ImageId) + labels[v1.LabelTopologyZone] = aws.StringValue(ec2instance.Placement.AvailabilityZone) + labels[v1alpha5.LabelCapacityType] = instance.GetCapacityType(ec2instance) + if tag, ok := lo.Find(ec2instance.Tags, func(t *ec2.Tag) bool { return aws.StringValue(t.Key) == v1alpha5.ProvisionerNameLabelKey }); ok { labels[v1alpha5.ProvisionerNameLabelKey] = aws.StringValue(tag.Value) } - if tag, ok := lo.Find(instance.Tags, func(t *ec2.Tag) bool { return aws.StringValue(t.Key) == v1alpha5.ManagedByLabelKey }); ok { + if tag, ok := lo.Find(ec2instance.Tags, func(t *ec2.Tag) bool { return aws.StringValue(t.Key) == v1alpha5.ManagedByLabelKey }); ok { labels[v1alpha5.ManagedByLabelKey] = aws.StringValue(tag.Value) } machine.Name = lo.Ternary( settings.FromContext(ctx).NodeNameConvention == settings.ResourceName, - aws.StringValue(instance.InstanceId), - strings.ToLower(aws.StringValue(instance.PrivateDnsName)), + aws.StringValue(ec2instance.InstanceId), + strings.ToLower(aws.StringValue(ec2instance.PrivateDnsName)), ) machine.Labels = labels - machine.CreationTimestamp = metav1.Time{Time: aws.TimeValue(instance.LaunchTime)} - machine.Status.ProviderID = fmt.Sprintf("aws:///%s/%s", aws.StringValue(instance.Placement.AvailabilityZone), aws.StringValue(instance.InstanceId)) + machine.CreationTimestamp = metav1.Time{Time: aws.TimeValue(ec2instance.LaunchTime)} + machine.Status.ProviderID = fmt.Sprintf("aws:///%s/%s", aws.StringValue(ec2instance.Placement.AvailabilityZone), aws.StringValue(ec2instance.InstanceId)) return machine } diff --git a/pkg/cloudprovider/suite_test.go b/pkg/cloudprovider/suite_test.go index 271b6faf3f94..e0180651c273 100644 --- a/pkg/cloudprovider/suite_test.go +++ b/pkg/cloudprovider/suite_test.go @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package cloudprovider +package cloudprovider_test import ( "context" @@ -22,20 +22,18 @@ import ( "time" "github.com/Pallinder/go-randomdata" - "github.com/patrickmn/go-cache" "github.com/samber/lo" "k8s.io/client-go/tools/record" - "github.com/aws/karpenter-core/pkg/events" - . "github.com/aws/karpenter-core/pkg/test/expectations" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/aws/aws-sdk-go/aws" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime/schema" clock "k8s.io/utils/clock/testing" - "knative.dev/pkg/ptr" + + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ssm" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -44,67 +42,42 @@ import ( "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" - awscache "github.com/aws/karpenter/pkg/cache" - "github.com/aws/karpenter/pkg/providers/amifamily" - "github.com/aws/karpenter/pkg/providers/instancetype" - "github.com/aws/karpenter/pkg/providers/launchtemplate" - "github.com/aws/karpenter/pkg/providers/pricing" - "github.com/aws/karpenter/pkg/providers/securitygroup" - "github.com/aws/karpenter/pkg/providers/subnet" - "github.com/aws/karpenter/pkg/test" - - "github.com/aws/karpenter-core/pkg/cloudprovider" - machineutil "github.com/aws/karpenter-core/pkg/utils/machine" - - "github.com/aws/karpenter-core/pkg/operator/controller" - - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ssm" - + "github.com/aws/karpenter/pkg/cloudprovider" "github.com/aws/karpenter/pkg/fake" + "github.com/aws/karpenter/pkg/test" coresettings "github.com/aws/karpenter-core/pkg/apis/settings" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" + corecloudproivder "github.com/aws/karpenter-core/pkg/cloudprovider" "github.com/aws/karpenter-core/pkg/controllers/provisioning" "github.com/aws/karpenter-core/pkg/controllers/state" + "github.com/aws/karpenter-core/pkg/events" + "github.com/aws/karpenter-core/pkg/operator/controller" "github.com/aws/karpenter-core/pkg/operator/injection" "github.com/aws/karpenter-core/pkg/operator/options" "github.com/aws/karpenter-core/pkg/operator/scheme" coretest "github.com/aws/karpenter-core/pkg/test" + . "github.com/aws/karpenter-core/pkg/test/expectations" + machineutil "github.com/aws/karpenter-core/pkg/utils/machine" ) var ctx context.Context var stop context.CancelFunc var opts options.Options var env *coretest.Environment -var launchTemplateCache *cache.Cache -var ssmCache *cache.Cache -var ec2Cache *cache.Cache -var kubernetesVersionCache *cache.Cache -var unavailableOfferingsCache *awscache.UnavailableOfferings -var instanceTypeCache *cache.Cache -var instanceTypesProvider *instancetype.Provider -var launchTemplateProvider *launchtemplate.Provider -var amiProvider *amifamily.Provider -var fakeEC2API *fake.EC2API -var fakeSSMAPI *fake.SSMAPI -var fakeEKSAPI *fake.EKSAPI -var fakePricingAPI *fake.PricingAPI +var awsEnv *test.Environment var prov *provisioning.Provisioner var provisioningController controller.Controller -var cloudProvider *CloudProvider var cluster *state.Cluster var fakeClock *clock.FakeClock var provisioner *v1alpha5.Provisioner var nodeTemplate *v1alpha1.AWSNodeTemplate -var pricingProvider *pricing.Provider -var subnetProvider *subnet.Provider -var securityGroupProvider *securitygroup.Provider +var cloudProvider *cloudprovider.CloudProvider func TestAWS(t *testing.T) { ctx = TestContextWithLogger(t) RegisterFailHandler(Fail) - RunSpecs(t, "CloudProvider/AWS") + RunSpecs(t, "cloudProvider/AWS") } var _ = BeforeSuite(func() { @@ -112,47 +85,10 @@ var _ = BeforeSuite(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) ctx, stop = context.WithCancel(ctx) + awsEnv = test.NewEnvironment(ctx, env) - launchTemplateCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - unavailableOfferingsCache = awscache.NewUnavailableOfferings() - ssmCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - ec2Cache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - kubernetesVersionCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - instanceTypeCache = cache.New(awscache.InstanceTypesAndZonesTTL, awscache.DefaultCleanupInterval) - fakeEC2API = &fake.EC2API{} - fakeSSMAPI = &fake.SSMAPI{} - fakeEKSAPI = &fake.EKSAPI{} - fakePricingAPI = &fake.PricingAPI{} - pricingProvider = pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) - amiProvider = amifamily.NewProvider(env.Client, env.KubernetesInterface, fakeSSMAPI, fakeEC2API, ssmCache, ec2Cache, kubernetesVersionCache) - subnetProvider = subnet.NewProvider(fakeEC2API) - instanceTypesProvider = instancetype.NewProvider( - "", - instanceTypeCache, - fakeEC2API, - subnetProvider, - unavailableOfferingsCache, - pricingProvider, - ) - securityGroupProvider = securitygroup.NewProvider(fakeEC2API) - launchTemplateProvider = launchtemplate.NewProvider( - ctx, - launchTemplateCache, - fakeEC2API, - amifamily.New(env.Client, amiProvider), - securityGroupProvider, - ptr.String("ca-bundle"), - make(chan struct{}), - net.ParseIP("10.0.100.10"), - "https://test-cluster", - ) - cloudProvider = &CloudProvider{ - instanceTypeProvider: instanceTypesProvider, - amiProvider: amiProvider, - instanceProvider: NewInstanceProvider(ctx, "", fakeEC2API, unavailableOfferingsCache, instanceTypesProvider, subnetProvider, launchTemplateProvider), - kubeClient: env.Client, - } fakeClock = clock.NewFakeClock(time.Now()) + cloudProvider = cloudprovider.New(ctx, awsEnv.InstanceTypesProvider, awsEnv.InstanceProvider, env.Client, awsEnv.AMIProvider) cluster = state.NewCluster(fakeClock, env.Client, cloudProvider) prov = provisioning.NewProvisioner(ctx, env.Client, env.KubernetesInterface.CoreV1(), events.NewRecorder(&record.FakeRecorder{}), cloudProvider, cluster) provisioningController = provisioning.NewController(env.Client, prov, events.NewRecorder(&record.FakeRecorder{})) @@ -197,37 +133,17 @@ var _ = BeforeEach(func() { }) cluster.Reset() - fakeEC2API.Reset() - fakeSSMAPI.Reset() - fakeEKSAPI.Reset() - fakePricingAPI.Reset() - launchTemplateCache.Flush() - unavailableOfferingsCache.Flush() - ssmCache.Flush() - ec2Cache.Flush() - kubernetesVersionCache.Flush() - instanceTypeCache.Flush() - subnetProvider.Reset() - securityGroupProvider.Reset() - launchTemplateProvider.KubeDNSIP = net.ParseIP("10.0.100.10") - launchTemplateProvider.ClusterEndpoint = "https://test-cluster" + awsEnv.Reset() - // Reset the pricing provider, so we don't cross-pollinate pricing data - instanceTypesProvider = instancetype.NewProvider( - "", - instanceTypeCache, - fakeEC2API, - subnetProvider, - unavailableOfferingsCache, - pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})), - ) + awsEnv.LaunchTemplateProvider.KubeDNSIP = net.ParseIP("10.0.100.10") + awsEnv.LaunchTemplateProvider.ClusterEndpoint = "https://test-cluster" }) var _ = AfterEach(func() { ExpectCleanedUp(ctx, env.Client) }) -var _ = Describe("Allocation", func() { +var _ = Describe("CloudProvider", func() { Context("Defaulting", func() { // Intent here is that if updates occur on the provisioningController, the Provisioner doesn't need to be recreated It("should not set the InstanceProfile with the default if none provided in Provisioner", func() { @@ -263,8 +179,8 @@ var _ = Describe("Allocation", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(aws.StringValue(createFleetInput.Context)).To(Equal("context-1234")) }) It("should default to no EC2 Context", func() { @@ -273,21 +189,21 @@ var _ = Describe("Allocation", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(createFleetInput.Context).To(BeNil()) }) }) Context("Node Drift", func() { var validAMI string - var selectedInstanceType *cloudprovider.InstanceType + var selectedInstanceType *corecloudproivder.InstanceType var instance *ec2.Instance BeforeEach(func() { validAMI = fake.ImageID() - fakeSSMAPI.GetParameterOutput = &ssm.GetParameterOutput{ + awsEnv.SSMAPI.GetParameterOutput = &ssm.GetParameterOutput{ Parameter: &ssm.Parameter{Value: aws.String(validAMI)}, } - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{ Images: []*ec2.Image{{ImageId: aws.String(validAMI)}}, }) ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) @@ -306,7 +222,7 @@ var _ = Describe("Allocation", func() { }, InstanceId: aws.String(fake.InstanceID()), } - fakeEC2API.DescribeInstancesBehavior.Output.Set(&ec2.DescribeInstancesOutput{ + awsEnv.EC2API.DescribeInstancesBehavior.Output.Set(&ec2.DescribeInstancesOutput{ Reservations: []*ec2.Reservation{{Instances: []*ec2.Instance{instance}}}, }) }) @@ -430,11 +346,11 @@ var _ = Describe("Allocation", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - firstLt := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + firstLt := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() launchTemplate := createFleetInput.LaunchTemplateConfigs[0].LaunchTemplateSpecification Expect(createFleetInput.LaunchTemplateConfigs).To(HaveLen(1)) @@ -459,8 +375,8 @@ var _ = Describe("Allocation", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(aws.StringValueSlice(input.LaunchTemplateData.SecurityGroupIds)).To(ConsistOf( "sg-test1", )) @@ -481,7 +397,7 @@ var _ = Describe("Allocation", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("subnet-test1")) }) It("should use the instance profile on the Provisioner when specified", func() { @@ -501,8 +417,8 @@ var _ = Describe("Allocation", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(*input.LaunchTemplateData.IamInstanceProfile.Name).To(Equal("overridden-profile")) }) }) @@ -515,8 +431,8 @@ var _ = Describe("Allocation", func() { coretest.PodOptions{NodeSelector: map[string]string{v1.LabelArchStable: v1alpha5.ArchitectureAmd64}}) ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - input := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(input.LaunchTemplateConfigs).To(HaveLen(1)) foundNonGPULT := false @@ -535,7 +451,7 @@ var _ = Describe("Allocation", func() { Expect(foundNonGPULT).To(BeTrue()) }) It("should launch instances into subnet with the most available IP addresses", func() { - fakeEC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(100), @@ -545,11 +461,11 @@ var _ = Describe("Allocation", func() { pod := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1.LabelTopologyZone: "test-zone-1a"}}) ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-2")) }) It("should launch instances into subnet with the most available IP addresses in-between cache refreshes", func() { - fakeEC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(11), @@ -562,28 +478,28 @@ var _ = Describe("Allocation", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod1, pod2) ExpectScheduled(ctx, env.Client, pod1) ExpectScheduled(ctx, env.Client, pod2) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-2")) // Provision for another pod that should now use the other subnet since we've consumed some from the first launch. pod3 := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1.LabelTopologyZone: "test-zone-1a"}}) ExpectProvisioned(ctx, env.Client, cluster, prov, pod3) ExpectScheduled(ctx, env.Client, pod3) - createFleetInput = fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput = awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1")) }) It("should update in-flight IPs when a CreateFleet error occurs", func() { - fakeEC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, }}) pod1 := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1.LabelTopologyZone: "test-zone-1a"}}) ExpectApplied(ctx, env.Client, provisioner, nodeTemplate, pod1) - fakeEC2API.CreateFleetBehavior.Error.Set(fmt.Errorf("CreateFleet synthetic error")) + awsEnv.EC2API.CreateFleetBehavior.Error.Set(fmt.Errorf("CreateFleet synthetic error")) bindings := ExpectProvisioned(ctx, env.Client, cluster, prov, pod1) Expect(len(bindings)).To(Equal(0)) }) It("should launch instances into subnets that are excluded by another provisioner", func() { - fakeEC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ + awsEnv.EC2API.DescribeSubnetsOutput.Set(&ec2.DescribeSubnetsOutput{Subnets: []*ec2.Subnet{ {SubnetId: aws.String("test-subnet-1"), AvailabilityZone: aws.String("test-zone-1a"), AvailableIpAddressCount: aws.Int64(10), Tags: []*ec2.Tag{{Key: aws.String("Name"), Value: aws.String("test-subnet-1")}}}, {SubnetId: aws.String("test-subnet-2"), AvailabilityZone: aws.String("test-zone-1b"), AvailableIpAddressCount: aws.Int64(100), @@ -594,7 +510,7 @@ var _ = Describe("Allocation", func() { podSubnet1 := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, podSubnet1) ExpectScheduled(ctx, env.Client, podSubnet1) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-1")) provisioner = test.Provisioner(coretest.ProvisionerOptions{Provider: &v1alpha1.AWS{ @@ -605,7 +521,7 @@ var _ = Describe("Allocation", func() { podSubnet2 := coretest.UnschedulablePod(coretest.PodOptions{NodeSelector: map[string]string{v1alpha5.ProvisionerNameLabelKey: provisioner.Name}}) ExpectProvisioned(ctx, env.Client, cluster, prov, podSubnet2) ExpectScheduled(ctx, env.Client, podSubnet2) - createFleetInput = fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput = awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(fake.SubnetsFromFleetRequest(createFleetInput)).To(ConsistOf("test-subnet-2")) }) }) diff --git a/pkg/context/context.go b/pkg/context/context.go index 430c2ab1a5fd..cad118905089 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -45,6 +45,7 @@ import ( "github.com/aws/karpenter/pkg/apis/settings" awscache "github.com/aws/karpenter/pkg/cache" "github.com/aws/karpenter/pkg/providers/amifamily" + "github.com/aws/karpenter/pkg/providers/instance" "github.com/aws/karpenter/pkg/providers/instancetype" "github.com/aws/karpenter/pkg/providers/launchtemplate" "github.com/aws/karpenter/pkg/providers/pricing" @@ -69,6 +70,7 @@ type Context struct { LaunchTemplateProvider *launchtemplate.Provider PricingProvider *pricing.Provider InstanceTypesProvider *instancetype.Provider + InstanceProvider *instance.Provider } func NewOrDie(ctx cloudprovider.Context) Context { @@ -106,8 +108,8 @@ func NewOrDie(ctx cloudprovider.Context) Context { } unavailableOfferingsCache := awscache.NewUnavailableOfferings() - subnetProvider := subnet.NewProvider(ec2api) - securityGroupProvider := securitygroup.NewProvider(ec2api) + subnetProvider := subnet.NewProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval)) + securityGroupProvider := securitygroup.NewProvider(ec2api, cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval)) pricingProvider := pricing.NewProvider( ctx, pricing.NewAPI(sess, *sess.Config.Region), @@ -137,6 +139,15 @@ func NewOrDie(ctx cloudprovider.Context) Context { unavailableOfferingsCache, pricingProvider, ) + instanceProvider := instance.NewProvider( + ctx, + aws.StringValue(sess.Config.Region), + ec2api, + unavailableOfferingsCache, + instanceTypeProvider, + subnetProvider, + launchTemplateProvider, + ) return Context{ Context: ctx, @@ -150,6 +161,7 @@ func NewOrDie(ctx cloudprovider.Context) Context { LaunchTemplateProvider: launchTemplateProvider, PricingProvider: pricingProvider, InstanceTypesProvider: instanceTypeProvider, + InstanceProvider: instanceProvider, } } diff --git a/pkg/controllers/machine/garbagecollect/suite_test.go b/pkg/controllers/machine/garbagecollect/suite_test.go index 0d54824e35a5..975819ca1ef6 100644 --- a/pkg/controllers/machine/garbagecollect/suite_test.go +++ b/pkg/controllers/machine/garbagecollect/suite_test.go @@ -22,47 +22,39 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/awstesting/mock" "github.com/aws/aws-sdk-go/service/ec2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/patrickmn/go-cache" "github.com/samber/lo" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/tools/record" - clock "k8s.io/utils/clock/testing" . "knative.dev/pkg/logging/testing" "sigs.k8s.io/controller-runtime/pkg/client" coresettings "github.com/aws/karpenter-core/pkg/apis/settings" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" corecloudprovider "github.com/aws/karpenter-core/pkg/cloudprovider" - "github.com/aws/karpenter-core/pkg/events" "github.com/aws/karpenter-core/pkg/operator/controller" "github.com/aws/karpenter-core/pkg/operator/scheme" coretest "github.com/aws/karpenter-core/pkg/test" . "github.com/aws/karpenter-core/pkg/test/expectations" + "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" - awscache "github.com/aws/karpenter/pkg/cache" "github.com/aws/karpenter/pkg/cloudprovider" - awscontext "github.com/aws/karpenter/pkg/context" "github.com/aws/karpenter/pkg/controllers/machine/garbagecollect" "github.com/aws/karpenter/pkg/controllers/machine/link" "github.com/aws/karpenter/pkg/fake" - "github.com/aws/karpenter/pkg/providers/securitygroup" - "github.com/aws/karpenter/pkg/providers/subnet" "github.com/aws/karpenter/pkg/test" ) var ctx context.Context +var awsEnv *test.Environment var env *coretest.Environment -var unavailableOfferingsCache *awscache.UnavailableOfferings -var ec2API *fake.EC2API -var cloudProvider *cloudprovider.CloudProvider var garbageCollectController controller.Controller var linkedMachineCache *cache.Cache +var cloudProvider *cloudprovider.CloudProvider func TestAPIs(t *testing.T) { ctx = TestContextWithLogger(t) @@ -74,24 +66,9 @@ var _ = BeforeSuite(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) env = coretest.NewEnvironment(scheme.Scheme, coretest.WithCRDs(apis.CRDs...)) - unavailableOfferingsCache = awscache.NewUnavailableOfferings() - ec2API = &fake.EC2API{} - cloudProvider = cloudprovider.New(awscontext.Context{ - Context: corecloudprovider.Context{ - Context: ctx, - RESTConfig: env.Config, - KubernetesInterface: env.KubernetesInterface, - KubeClient: env.Client, - EventRecorder: events.NewRecorder(&record.FakeRecorder{}), - Clock: &clock.FakeClock{}, - StartAsync: nil, - }, - SubnetProvider: subnet.NewProvider(ec2API), - SecurityGroupProvider: securitygroup.NewProvider(ec2API), - Session: mock.Session, - UnavailableOfferingsCache: unavailableOfferingsCache, - EC2API: ec2API, - }) + awsEnv = test.NewEnvironment(ctx, env) + + cloudProvider = cloudprovider.New(ctx, awsEnv.InstanceTypesProvider, awsEnv.InstanceProvider, env.Client, awsEnv.AMIProvider) linkedMachineCache = cache.New(time.Minute*10, time.Second*10) linkController := &link.Controller{ Cache: linkedMachineCache, @@ -103,12 +80,15 @@ var _ = AfterSuite(func() { Expect(env.Stop()).To(Succeed(), "Failed to stop environment") }) +var _ = BeforeEach(func() { + awsEnv.Reset() +}) + var _ = Describe("MachineGarbageCollect", func() { var instance *ec2.Instance var providerID string BeforeEach(func() { - ec2API.Reset() instanceID := fake.InstanceID() providerID = fmt.Sprintf("aws:///test-zone-1a/%s", instanceID) nodeTemplate := test.AWSNodeTemplate(v1alpha1.AWSNodeTemplateSpec{}) @@ -153,7 +133,7 @@ var _ = Describe("MachineGarbageCollect", func() { It("should delete an instance if there is no machine owner", func() { // Launch time was 10m ago instance.LaunchTime = aws.Time(time.Now().Add(-time.Minute * 10)) - ec2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) + awsEnv.EC2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) ExpectReconcileSucceeded(ctx, garbageCollectController, client.ObjectKey{}) _, err := cloudProvider.Get(ctx, providerID) @@ -163,7 +143,7 @@ var _ = Describe("MachineGarbageCollect", func() { It("should delete an instance along with the node if there is no machine owner (to quicken scheduling)", func() { // Launch time was 10m ago instance.LaunchTime = aws.Time(time.Now().Add(-time.Minute * 10)) - ec2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) + awsEnv.EC2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) node := coretest.Node(coretest.NodeOptions{ ProviderID: providerID, @@ -182,7 +162,7 @@ var _ = Describe("MachineGarbageCollect", func() { var ids []string for i := 0; i < 500; i++ { instanceID := fake.InstanceID() - ec2API.Instances.Store( + awsEnv.EC2API.Instances.Store( instanceID, &ec2.Instance{ State: &ec2.InstanceState{ @@ -236,7 +216,7 @@ var _ = Describe("MachineGarbageCollect", func() { var machines []*v1alpha5.Machine for i := 0; i < 500; i++ { instanceID := fake.InstanceID() - ec2API.Instances.Store( + awsEnv.EC2API.Instances.Store( instanceID, &ec2.Instance{ State: &ec2.InstanceState{ @@ -297,7 +277,7 @@ var _ = Describe("MachineGarbageCollect", func() { It("should not delete an instance if it is within the machine resolution window (1m)", func() { // Launch time just happened instance.LaunchTime = aws.Time(time.Now()) - ec2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) + awsEnv.EC2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) ExpectReconcileSucceeded(ctx, garbageCollectController, client.ObjectKey{}) _, err := cloudProvider.Get(ctx, providerID) @@ -311,7 +291,7 @@ var _ = Describe("MachineGarbageCollect", func() { // Launch time was 10m ago instance.LaunchTime = aws.Time(time.Now().Add(-time.Minute * 10)) - ec2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) + awsEnv.EC2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) ExpectReconcileSucceeded(ctx, garbageCollectController, client.ObjectKey{}) _, err := cloudProvider.Get(ctx, providerID) @@ -320,7 +300,7 @@ var _ = Describe("MachineGarbageCollect", func() { It("should not delete the instance or node if it already has a machine that matches it", func() { // Launch time was 10m ago instance.LaunchTime = aws.Time(time.Now().Add(-time.Minute * 10)) - ec2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) + awsEnv.EC2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) machine := coretest.Machine(v1alpha5.Machine{ Status: v1alpha5.MachineStatus{ @@ -340,7 +320,7 @@ var _ = Describe("MachineGarbageCollect", func() { It("should not delete an instance if it is linked", func() { // Launch time was 10m ago instance.LaunchTime = aws.Time(time.Now().Add(-time.Minute * 10)) - ec2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) + awsEnv.EC2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) // Create a machine that is actively linking machine := coretest.Machine(v1alpha5.Machine{ @@ -359,7 +339,7 @@ var _ = Describe("MachineGarbageCollect", func() { It("should not delete an instance if it is recently linked but the machine doesn't exist", func() { // Launch time was 10m ago instance.LaunchTime = aws.Time(time.Now().Add(-time.Minute * 10)) - ec2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) + awsEnv.EC2API.Instances.Store(aws.StringValue(instance.InstanceId), instance) // Add a provider id to the recently linked cache linkedMachineCache.SetDefault(providerID, nil) diff --git a/pkg/controllers/machine/link/suite_test.go b/pkg/controllers/machine/link/suite_test.go index 0298af48293c..2ece7472cdb4 100644 --- a/pkg/controllers/machine/link/suite_test.go +++ b/pkg/controllers/machine/link/suite_test.go @@ -21,53 +21,38 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/awstesting/mock" "github.com/aws/aws-sdk-go/service/ec2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/patrickmn/go-cache" "github.com/samber/lo" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/runtime" - "k8s.io/client-go/tools/record" - clock "k8s.io/utils/clock/testing" . "knative.dev/pkg/logging/testing" "sigs.k8s.io/controller-runtime/pkg/client" coresettings "github.com/aws/karpenter-core/pkg/apis/settings" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" - corecloudprovider "github.com/aws/karpenter-core/pkg/cloudprovider" - "github.com/aws/karpenter-core/pkg/events" "github.com/aws/karpenter-core/pkg/operator/controller" "github.com/aws/karpenter-core/pkg/operator/scheme" coretest "github.com/aws/karpenter-core/pkg/test" . "github.com/aws/karpenter-core/pkg/test/expectations" "github.com/aws/karpenter-core/pkg/utils/sets" + "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" - awscache "github.com/aws/karpenter/pkg/cache" "github.com/aws/karpenter/pkg/cloudprovider" - awscontext "github.com/aws/karpenter/pkg/context" "github.com/aws/karpenter/pkg/controllers/machine/link" "github.com/aws/karpenter/pkg/fake" - "github.com/aws/karpenter/pkg/providers/instancetype" - "github.com/aws/karpenter/pkg/providers/pricing" - "github.com/aws/karpenter/pkg/providers/securitygroup" - "github.com/aws/karpenter/pkg/providers/subnet" "github.com/aws/karpenter/pkg/test" "github.com/aws/karpenter/pkg/utils" ) var ctx context.Context +var awsEnv *test.Environment var env *coretest.Environment -var unavailableOfferingsCache *awscache.UnavailableOfferings -var ec2API *fake.EC2API -var cloudProvider *cloudprovider.CloudProvider -var subnetProvider *subnet.Provider var linkController controller.Controller -var pricingProvider *pricing.Provider -var instanceTypesProvider *instancetype.Provider +var cloudProvider *cloudprovider.CloudProvider func TestAPIs(t *testing.T) { ctx = TestContextWithLogger(t) @@ -79,36 +64,19 @@ var _ = BeforeSuite(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) env = coretest.NewEnvironment(scheme.Scheme, coretest.WithCRDs(apis.CRDs...)) - unavailableOfferingsCache = awscache.NewUnavailableOfferings() - ec2API = &fake.EC2API{} - subnetProvider = subnet.NewProvider(ec2API) - pricingProvider = pricing.NewProvider(ctx, &fake.PricingAPI{}, ec2API, "", make(chan struct{})) - instanceTypesProvider = instancetype.NewProvider("", cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval), ec2API, subnetProvider, unavailableOfferingsCache, pricingProvider) - cloudProvider = cloudprovider.New(awscontext.Context{ - Context: corecloudprovider.Context{ - Context: ctx, - RESTConfig: env.Config, - KubernetesInterface: env.KubernetesInterface, - KubeClient: env.Client, - EventRecorder: events.NewRecorder(&record.FakeRecorder{}), - Clock: &clock.FakeClock{}, - StartAsync: nil, - }, - SubnetProvider: subnet.NewProvider(ec2API), - SecurityGroupProvider: securitygroup.NewProvider(ec2API), - Session: mock.Session, - UnavailableOfferingsCache: unavailableOfferingsCache, - EC2API: ec2API, - PricingProvider: pricingProvider, - InstanceTypesProvider: instanceTypesProvider, - }) + awsEnv = test.NewEnvironment(ctx, env) + + cloudProvider = cloudprovider.New(ctx, awsEnv.InstanceTypesProvider, awsEnv.InstanceProvider, env.Client, awsEnv.AMIProvider) linkController = link.NewController(env.Client, cloudProvider) }) - var _ = AfterSuite(func() { Expect(env.Stop()).To(Succeed(), "Failed to stop environment") }) +var _ = BeforeEach(func() { + awsEnv.Reset() +}) + var _ = Describe("MachineLink", func() { var instanceID string var providerID string @@ -116,7 +84,6 @@ var _ = Describe("MachineLink", func() { var nodeTemplate *v1alpha1.AWSNodeTemplate BeforeEach(func() { - ec2API.Reset() instanceID = fake.InstanceID() providerID = fmt.Sprintf("aws:///test-zone-1a/%s", instanceID) nodeTemplate = test.AWSNodeTemplate(v1alpha1.AWSNodeTemplateSpec{}) @@ -129,7 +96,7 @@ var _ = Describe("MachineLink", func() { }) // Store the instance as existing at DescribeInstances - ec2API.Instances.Store( + awsEnv.EC2API.Instances.Store( instanceID, &ec2.Instance{ State: &ec2.InstanceState{ @@ -175,7 +142,7 @@ var _ = Describe("MachineLink", func() { }, } ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - ExpectInstanceExists(ec2API, instanceID) + ExpectInstanceExists(awsEnv.EC2API, instanceID) ExpectReconcileSucceeded(ctx, linkController, client.ObjectKey{}) machineList := &v1alpha5.MachineList{} @@ -191,7 +158,7 @@ var _ = Describe("MachineLink", func() { // Expect machine has linking annotation to get machine details Expect(machine.Annotations).To(HaveKeyWithValue(v1alpha5.MachineLinkedAnnotationKey, providerID)) - instance := ExpectInstanceExists(ec2API, instanceID) + instance := ExpectInstanceExists(awsEnv.EC2API, instanceID) ExpectManagedByTagExists(instance) }) It("should link and instance with expected requirements and labels", func() { @@ -213,7 +180,7 @@ var _ = Describe("MachineLink", func() { }, } ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - ExpectInstanceExists(ec2API, instanceID) + ExpectInstanceExists(awsEnv.EC2API, instanceID) ExpectReconcileSucceeded(ctx, linkController, client.ObjectKey{}) machineList := &v1alpha5.MachineList{} @@ -242,7 +209,7 @@ var _ = Describe("MachineLink", func() { // Expect machine has linking annotation to get machine details Expect(machine.Annotations).To(HaveKeyWithValue(v1alpha5.MachineLinkedAnnotationKey, providerID)) - instance := ExpectInstanceExists(ec2API, instanceID) + instance := ExpectInstanceExists(awsEnv.EC2API, instanceID) ExpectManagedByTagExists(instance) }) It("should link an instance with expected kubelet from provisioner kubelet configuration", func() { @@ -266,18 +233,18 @@ var _ = Describe("MachineLink", func() { // Expect machine has linking annotation to get machine details Expect(machine.Annotations).To(HaveKeyWithValue(v1alpha5.MachineLinkedAnnotationKey, providerID)) - instance := ExpectInstanceExists(ec2API, instanceID) + instance := ExpectInstanceExists(awsEnv.EC2API, instanceID) ExpectManagedByTagExists(instance) }) It("should link many instances to many machines", func() { - ec2API.Reset() // Reset so we don't store the extra instance + awsEnv.EC2API.Reset() // Reset so we don't store the extra instance ExpectApplied(ctx, env.Client, provisioner) // Generate 500 instances that have different instanceIDs var ids []string for i := 0; i < 500; i++ { instanceID = fake.InstanceID() - ec2API.EC2Behavior.Instances.Store( + awsEnv.EC2API.EC2Behavior.Instances.Store( instanceID, &ec2.Instance{ State: &ec2.InstanceState{ @@ -318,7 +285,7 @@ var _ = Describe("MachineLink", func() { Expect(machineInstanceIDs).To(HaveLen(len(ids))) for _, id := range ids { Expect(machineInstanceIDs.Has(id)).To(BeTrue()) - instance := ExpectInstanceExists(ec2API, id) + instance := ExpectInstanceExists(awsEnv.EC2API, id) ExpectManagedByTagExists(instance) } }) @@ -343,7 +310,7 @@ var _ = Describe("MachineLink", func() { // Expect machine has linking annotation to get machine details Expect(machine.Annotations).To(HaveKeyWithValue(v1alpha5.MachineLinkedAnnotationKey, providerID)) - instance := ExpectInstanceExists(ec2API, instanceID) + instance := ExpectInstanceExists(awsEnv.EC2API, instanceID) ExpectManagedByTagExists(instance) }) It("should link an instance without node template existence", func() { @@ -360,17 +327,17 @@ var _ = Describe("MachineLink", func() { // Expect machine has linking annotation to get machine details Expect(machine.Annotations).To(HaveKeyWithValue(v1alpha5.MachineLinkedAnnotationKey, providerID)) - instance := ExpectInstanceExists(ec2API, instanceID) + instance := ExpectInstanceExists(awsEnv.EC2API, instanceID) ExpectManagedByTagExists(instance) }) }) Context("Failed", func() { It("should not link an instance without a provisioner tag", func() { - instance := ExpectInstanceExists(ec2API, instanceID) + instance := ExpectInstanceExists(awsEnv.EC2API, instanceID) instance.Tags = lo.Reject(instance.Tags, func(t *ec2.Tag, _ int) bool { return aws.StringValue(t.Key) == v1alpha5.ProvisionerNameLabelKey }) - ec2API.Instances.Store(instanceID, instance) + awsEnv.EC2API.Instances.Store(instanceID, instance) ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) ExpectReconcileSucceeded(ctx, linkController, client.ObjectKey{}) @@ -407,9 +374,9 @@ var _ = Describe("MachineLink", func() { }) It("should not link an instance that is terminated", func() { // Update the state of the existing instance - instance := ExpectInstanceExists(ec2API, instanceID) + instance := ExpectInstanceExists(awsEnv.EC2API, instanceID) instance.State.Name = aws.String(ec2.InstanceStateNameTerminated) - ec2API.Instances.Store(instanceID, instance) + awsEnv.EC2API.Instances.Store(instanceID, instance) ExpectReconcileSucceeded(ctx, linkController, client.ObjectKey{}) machineList := &v1alpha5.MachineList{} diff --git a/pkg/controllers/nodetemplate/suite_test.go b/pkg/controllers/nodetemplate/suite_test.go index 717832d3ea62..f3f20bffd9dc 100644 --- a/pkg/controllers/nodetemplate/suite_test.go +++ b/pkg/controllers/nodetemplate/suite_test.go @@ -19,38 +19,34 @@ import ( "sort" "testing" + "github.com/aws/aws-sdk-go/service/ec2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/samber/lo" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" . "knative.dev/pkg/logging/testing" _ "knative.dev/pkg/system/testing" "sigs.k8s.io/controller-runtime/pkg/client" - "github.com/aws/aws-sdk-go/service/ec2" - + coresettings "github.com/aws/karpenter-core/pkg/apis/settings" + corecontroller "github.com/aws/karpenter-core/pkg/operator/controller" "github.com/aws/karpenter-core/pkg/operator/injection" "github.com/aws/karpenter-core/pkg/operator/options" "github.com/aws/karpenter-core/pkg/operator/scheme" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - corecontroller "github.com/aws/karpenter-core/pkg/operator/controller" coretest "github.com/aws/karpenter-core/pkg/test" . "github.com/aws/karpenter-core/pkg/test/expectations" + "github.com/aws/karpenter/pkg/apis" + "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" "github.com/aws/karpenter/pkg/controllers/nodetemplate" - "github.com/aws/karpenter/pkg/fake" - "github.com/aws/karpenter/pkg/providers/securitygroup" - "github.com/aws/karpenter/pkg/providers/subnet" + "github.com/aws/karpenter/pkg/test" ) var ctx context.Context var env *coretest.Environment -var fakeEC2API *fake.EC2API +var awsEnv *test.Environment var opts options.Options -var subnetProvider *subnet.Provider -var securityGroupProvider *securitygroup.Provider var nodeTemplate *v1alpha1.AWSNodeTemplate var controller corecontroller.Controller @@ -62,11 +58,11 @@ func TestAPIs(t *testing.T) { var _ = BeforeSuite(func() { env = coretest.NewEnvironment(scheme.Scheme, coretest.WithCRDs(apis.CRDs...)) + ctx = coresettings.ToContext(ctx, coretest.Settings()) + ctx = settings.ToContext(ctx, test.Settings()) + awsEnv = test.NewEnvironment(ctx, env) - fakeEC2API = &fake.EC2API{} - subnetProvider = subnet.NewProvider(fakeEC2API) - securityGroupProvider = securitygroup.NewProvider(fakeEC2API) - controller = nodetemplate.NewController(env.Client, subnetProvider, securityGroupProvider) + controller = nodetemplate.NewController(env.Client, awsEnv.SubnetProvider, awsEnv.SecurityGroupProvider) }) var _ = AfterSuite(func() { @@ -88,7 +84,7 @@ var _ = BeforeEach(func() { }, } - fakeEC2API.Reset() + awsEnv.Reset() }) var _ = AfterEach(func() { @@ -101,7 +97,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ := subnetProvider.List(ctx, nodeTemplate) + subnet, _ := awsEnv.SubnetProvider.List(ctx, nodeTemplate) subnetIDs := lo.Map(subnet, func(ec2subnet *ec2.Subnet, _ int) string { return *ec2subnet.SubnetId }) @@ -116,7 +112,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ := subnetProvider.List(ctx, nodeTemplate) + subnet, _ := awsEnv.SubnetProvider.List(ctx, nodeTemplate) sort.Slice(subnet, func(i, j int) bool { return int(*subnet[i].AvailableIpAddressCount) > int(*subnet[j].AvailableIpAddressCount) }) @@ -133,7 +129,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ := subnetProvider.List(ctx, nodeTemplate) + subnet, _ := awsEnv.SubnetProvider.List(ctx, nodeTemplate) sort.Slice(subnet, func(i, j int) bool { return int(*subnet[i].AvailableIpAddressCount) > int(*subnet[j].AvailableIpAddressCount) }) @@ -150,7 +146,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ := subnetProvider.List(ctx, nodeTemplate) + subnet, _ := awsEnv.SubnetProvider.List(ctx, nodeTemplate) correctSubnetIDs := lo.Map(subnet, func(ec2subnet *ec2.Subnet, _ int) v1alpha1.SubnetStatus { return v1alpha1.SubnetStatus{ ID: *ec2subnet.SubnetId, @@ -164,7 +160,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ := subnetProvider.List(ctx, nodeTemplate) + subnet, _ := awsEnv.SubnetProvider.List(ctx, nodeTemplate) subnetIDs := lo.Map(subnet, func(ec2subnet *ec2.Subnet, _ int) string { return *ec2subnet.SubnetId }) @@ -179,7 +175,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ = subnetProvider.List(ctx, nodeTemplate) + subnet, _ = awsEnv.SubnetProvider.List(ctx, nodeTemplate) sort.Slice(subnet, func(i, j int) bool { return int(*subnet[i].AvailableIpAddressCount) > int(*subnet[j].AvailableIpAddressCount) }) @@ -195,7 +191,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ := subnetProvider.List(ctx, nodeTemplate) + subnet, _ := awsEnv.SubnetProvider.List(ctx, nodeTemplate) subnetIDs := lo.Map(subnet, func(ec2subnet *ec2.Subnet, _ int) string { return *ec2subnet.SubnetId }) @@ -210,7 +206,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ = subnetProvider.List(ctx, nodeTemplate) + subnet, _ = awsEnv.SubnetProvider.List(ctx, nodeTemplate) correctSubnetIDs := lo.Map(subnet, func(ec2subnet *ec2.Subnet, _ int) v1alpha1.SubnetStatus { return v1alpha1.SubnetStatus{ ID: *ec2subnet.SubnetId, @@ -231,7 +227,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - subnet, _ := subnetProvider.List(ctx, nodeTemplate) + subnet, _ := awsEnv.SubnetProvider.List(ctx, nodeTemplate) subnetIDs := lo.Map(subnet, func(ec2subnet *ec2.Subnet, _ int) string { return *ec2subnet.SubnetId }) @@ -256,7 +252,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ := securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) securityGroupsIDInStatus := lo.Map(nodeTemplate.Status.SecurityGroups, func(securitygroup v1alpha1.SecurityGroupStatus, _ int) string { return securitygroup.ID }) @@ -266,7 +262,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ := securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) securityGroupsIDInStatus := lo.Map(nodeTemplate.Status.SecurityGroups, func(securitygroup v1alpha1.SecurityGroupStatus, _ int) string { return securitygroup.ID }) @@ -277,7 +273,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ := securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) correctSecurityGroupsIDs := lo.Map(securityGroupsIDs, func(securitygroup string, _ int) v1alpha1.SecurityGroupStatus { return v1alpha1.SecurityGroupStatus{ ID: securitygroup, @@ -290,7 +286,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ := securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) correctSecurityGroupsIDs := lo.Map(securityGroupsIDs, func(securitygroup string, _ int) v1alpha1.SecurityGroupStatus { return v1alpha1.SecurityGroupStatus{ ID: securitygroup, @@ -302,7 +298,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ := securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) securityGroupsIDInStatus := lo.Map(nodeTemplate.Status.SecurityGroups, func(securitygroup v1alpha1.SecurityGroupStatus, _ int) string { return securitygroup.ID }) @@ -312,7 +308,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ = securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ = awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) correctSecurityGroupsIDs := lo.Map(securityGroupsIDs, func(securitygroup string, _ int) v1alpha1.SecurityGroupStatus { return v1alpha1.SecurityGroupStatus{ ID: securitygroup, @@ -324,7 +320,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ := securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) securityGroupsIDInStatus := lo.Map(nodeTemplate.Status.SecurityGroups, func(securitygroup v1alpha1.SecurityGroupStatus, _ int) string { return securitygroup.ID }) @@ -334,7 +330,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ = securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ = awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) correctSecurityGroupsIDs := lo.Map(securityGroupsIDs, func(securitygroup string, _ int) v1alpha1.SecurityGroupStatus { return v1alpha1.SecurityGroupStatus{ ID: securitygroup, @@ -353,7 +349,7 @@ var _ = Describe("AWSNodeTemplateController", func() { ExpectApplied(ctx, env.Client, nodeTemplate) ExpectReconcileSucceeded(ctx, controller, client.ObjectKeyFromObject(nodeTemplate)) nodeTemplate = ExpectExists(ctx, env.Client, nodeTemplate) - securityGroupsIDs, _ := securityGroupProvider.List(ctx, nodeTemplate) + securityGroupsIDs, _ := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) securityGroupsIDInStatus := lo.Map(nodeTemplate.Status.SecurityGroups, func(securitygroup v1alpha1.SecurityGroupStatus, _ int) string { return securitygroup.ID }) diff --git a/pkg/cloudprovider/instance.go b/pkg/providers/instance/instance.go similarity index 90% rename from pkg/cloudprovider/instance.go rename to pkg/providers/instance/instance.go index cb0d82a8ea00..98a40b7722fe 100644 --- a/pkg/cloudprovider/instance.go +++ b/pkg/providers/instance/instance.go @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package cloudprovider +package instance import ( "context" @@ -46,11 +46,13 @@ import ( "github.com/aws/karpenter-core/pkg/utils/resources" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" - "github.com/aws/karpenter-core/pkg/cloudprovider" + cloudprovider "github.com/aws/karpenter-core/pkg/cloudprovider" "github.com/aws/karpenter-core/pkg/scheduling" ) var ( + // MaxInstanceTypes defines the number of instance type options to pass to CreateFleet + MaxInstanceTypes = 60 instanceTypeFlexibilityThreshold = 5 // falling back to on-demand without flexibility risks insufficient capacity errors instanceStateFilter = &ec2.Filter{ @@ -59,7 +61,7 @@ var ( } ) -type InstanceProvider struct { +type Provider struct { region string ec2api ec2iface.EC2API unavailableOfferings *cache.UnavailableOfferings @@ -69,8 +71,9 @@ type InstanceProvider struct { ec2Batcher *batcher.EC2API } -func NewInstanceProvider(ctx context.Context, region string, ec2api ec2iface.EC2API, unavailableOfferings *cache.UnavailableOfferings, instanceTypeProvider *instancetype.Provider, subnetProvider *subnet.Provider, launchTemplateProvider *launchtemplate.Provider) *InstanceProvider { - return &InstanceProvider{ +func NewProvider(ctx context.Context, region string, ec2api ec2iface.EC2API, unavailableOfferings *cache.UnavailableOfferings, + instanceTypeProvider *instancetype.Provider, subnetProvider *subnet.Provider, launchTemplateProvider *launchtemplate.Provider) *Provider { + return &Provider{ region: region, ec2api: ec2api, unavailableOfferings: unavailableOfferings, @@ -81,7 +84,7 @@ func NewInstanceProvider(ctx context.Context, region string, ec2api ec2iface.EC2 } } -func (p *InstanceProvider) Create(ctx context.Context, nodeTemplate *v1alpha1.AWSNodeTemplate, machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) (*ec2.Instance, error) { +func (p *Provider) Create(ctx context.Context, nodeTemplate *v1alpha1.AWSNodeTemplate, machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) (*ec2.Instance, error) { instanceTypes = p.filterInstanceTypes(machine, instanceTypes) instanceTypes = orderInstanceTypesByPrice(instanceTypes, scheduling.NewNodeSelectorRequirements(machine.Spec.Requirements...)) if len(instanceTypes) > MaxInstanceTypes { @@ -112,12 +115,12 @@ func (p *InstanceProvider) Create(ctx context.Context, nodeTemplate *v1alpha1.AW "hostname", aws.StringValue(instance.PrivateDnsName), "instance-type", aws.StringValue(instance.InstanceType), "zone", aws.StringValue(instance.Placement.AvailabilityZone), - "capacity-type", getCapacityType(instance)).Infof("launched new instance") + "capacity-type", GetCapacityType(instance)).Infof("launched new instance") return instance, nil } -func (p *InstanceProvider) Link(ctx context.Context, id string) error { +func (p *Provider) Link(ctx context.Context, id string) error { _, err := p.ec2api.CreateTagsWithContext(ctx, &ec2.CreateTagsInput{ Resources: aws.StringSlice([]string{id}), Tags: []*ec2.Tag{ @@ -136,7 +139,7 @@ func (p *InstanceProvider) Link(ctx context.Context, id string) error { return nil } -func (p *InstanceProvider) Get(ctx context.Context, id string) (*ec2.Instance, error) { +func (p *Provider) Get(ctx context.Context, id string) (*ec2.Instance, error) { out, err := p.ec2Batcher.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ InstanceIds: aws.StringSlice([]string{id}), Filters: []*ec2.Filter{instanceStateFilter}, @@ -160,7 +163,7 @@ func (p *InstanceProvider) Get(ctx context.Context, id string) (*ec2.Instance, e return instances[0], nil } -func (p *InstanceProvider) List(ctx context.Context) ([]*ec2.Instance, error) { +func (p *Provider) List(ctx context.Context) ([]*ec2.Instance, error) { // Use the machine name data to determine which instances match this machine out, err := p.ec2api.DescribeInstancesWithContext(ctx, &ec2.DescribeInstancesInput{ Filters: []*ec2.Filter{ @@ -182,7 +185,7 @@ func (p *InstanceProvider) List(ctx context.Context) ([]*ec2.Instance, error) { return instances, cloudprovider.IgnoreMachineNotFoundError(err) } -func (p *InstanceProvider) Delete(ctx context.Context, id string) error { +func (p *Provider) Delete(ctx context.Context, id string) error { if _, err := p.ec2Batcher.TerminateInstances(ctx, &ec2.TerminateInstancesInput{ InstanceIds: []*string{aws.String(id)}, }); err != nil { @@ -200,7 +203,7 @@ func (p *InstanceProvider) Delete(ctx context.Context, id string) error { return nil } -func (p *InstanceProvider) launchInstance(ctx context.Context, nodeTemplate *v1alpha1.AWSNodeTemplate, machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) (*string, error) { +func (p *Provider) launchInstance(ctx context.Context, nodeTemplate *v1alpha1.AWSNodeTemplate, machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) (*string, error) { capacityType := p.getCapacityType(machine, instanceTypes) zonalSubnets, err := p.subnetProvider.ZonalSubnetsForLaunch(ctx, nodeTemplate, instanceTypes, capacityType) if err != nil { @@ -260,7 +263,7 @@ func (p *InstanceProvider) launchInstance(ctx context.Context, nodeTemplate *v1a return createFleetOutput.Instances[0].InstanceIds[0], nil } -func (p *InstanceProvider) checkODFallback(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType, launchTemplateConfigs []*ec2.FleetLaunchTemplateConfigRequest) error { +func (p *Provider) checkODFallback(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType, launchTemplateConfigs []*ec2.FleetLaunchTemplateConfigRequest) error { // only evaluate for on-demand fallback if the capacity type for the request is OD and both OD and spot are allowed in requirements if p.getCapacityType(machine, instanceTypes) != v1alpha5.CapacityTypeOnDemand || !scheduling.NewNodeSelectorRequirements(machine.Spec.Requirements...).Get(v1alpha5.LabelCapacityType).Has(v1alpha5.CapacityTypeSpot) { return nil @@ -282,7 +285,7 @@ func (p *InstanceProvider) checkODFallback(machine *v1alpha5.Machine, instanceTy return nil } -func (p *InstanceProvider) getLaunchTemplateConfigs(ctx context.Context, nodeTemplate *v1alpha1.AWSNodeTemplate, machine *v1alpha5.Machine, +func (p *Provider) getLaunchTemplateConfigs(ctx context.Context, nodeTemplate *v1alpha1.AWSNodeTemplate, machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType, zonalSubnets map[string]*ec2.Subnet, capacityType string) ([]*ec2.FleetLaunchTemplateConfigRequest, error) { var launchTemplateConfigs []*ec2.FleetLaunchTemplateConfigRequest launchTemplates, err := p.launchTemplateProvider.EnsureAll(ctx, nodeTemplate, machine, instanceTypes, map[string]string{v1alpha5.LabelCapacityType: capacityType}) @@ -309,7 +312,7 @@ func (p *InstanceProvider) getLaunchTemplateConfigs(ctx context.Context, nodeTem // getOverrides creates and returns launch template overrides for the cross product of InstanceTypes and subnets (with subnets being constrained by // zones and the offerings in InstanceTypes) -func (p *InstanceProvider) getOverrides(instanceTypes []*cloudprovider.InstanceType, zonalSubnets map[string]*ec2.Subnet, zones *scheduling.Requirement, capacityType string) []*ec2.FleetLaunchTemplateOverridesRequest { +func (p *Provider) getOverrides(instanceTypes []*cloudprovider.InstanceType, zonalSubnets map[string]*ec2.Subnet, zones *scheduling.Requirement, capacityType string) []*ec2.FleetLaunchTemplateOverridesRequest { // Unwrap all the offerings to a flat slice that includes a pointer // to the parent instance type name type offeringWithParentName struct { @@ -352,7 +355,7 @@ func (p *InstanceProvider) getOverrides(instanceTypes []*cloudprovider.InstanceT // Update receives a machine and updates the EC2 instance with tags linking it to the machine // Deprecated: This function can be removed when v1alpha6/v1beta1 migration has completed. -func (p *InstanceProvider) Update(ctx context.Context, machine *v1alpha5.Machine) (*ec2.Instance, error) { +func (p *Provider) Update(ctx context.Context, machine *v1alpha5.Machine) (*ec2.Instance, error) { _, err := p.ec2api.CreateTagsWithContext(ctx, &ec2.CreateTagsInput{ Resources: aws.StringSlice([]string{lo.Must(utils.ParseInstanceID(machine.Status.ProviderID))}), Tags: []*ec2.Tag{ @@ -394,7 +397,7 @@ func (p *InstanceProvider) Update(ctx context.Context, machine *v1alpha5.Machine return instance, nil } -func (p *InstanceProvider) updateUnavailableOfferingsCache(ctx context.Context, errors []*ec2.CreateFleetError, capacityType string) { +func (p *Provider) updateUnavailableOfferingsCache(ctx context.Context, errors []*ec2.CreateFleetError, capacityType string) { for _, err := range errors { if awserrors.IsUnfulfillableCapacity(err) { p.unavailableOfferings.MarkUnavailableForFleetErr(ctx, err, capacityType) @@ -405,7 +408,7 @@ func (p *InstanceProvider) updateUnavailableOfferingsCache(ctx context.Context, // getCapacityType selects spot if both constraints are flexible and there is an // available offering. The AWS Cloud Provider defaults to [ on-demand ], so spot // must be explicitly included in capacity type requirements. -func (p *InstanceProvider) getCapacityType(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) string { +func (p *Provider) getCapacityType(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) string { requirements := scheduling.NewNodeSelectorRequirements(machine. Spec.Requirements...) if requirements.Get(v1alpha5.LabelCapacityType).Has(v1alpha5.CapacityTypeSpot) { @@ -441,7 +444,7 @@ func orderInstanceTypesByPrice(instanceTypes []*cloudprovider.InstanceType, requ // filterInstanceTypes is used to provide filtering on the list of potential instance types to further limit it to those // that make the most sense given our specific AWS cloudprovider. -func (p *InstanceProvider) filterInstanceTypes(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) []*cloudprovider.InstanceType { +func (p *Provider) filterInstanceTypes(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) []*cloudprovider.InstanceType { instanceTypes = filterExoticInstanceTypes(instanceTypes) // If we could potentially launch either a spot or on-demand node, we want to filter out the spot instance types that // are more expensive than the cheapest on-demand type. @@ -453,7 +456,7 @@ func (p *InstanceProvider) filterInstanceTypes(machine *v1alpha5.Machine, instan // isMixedCapacityLaunch returns true if provisioners and available offerings could potentially allow either a spot or // and on-demand node to launch -func (p *InstanceProvider) isMixedCapacityLaunch(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) bool { +func (p *Provider) isMixedCapacityLaunch(machine *v1alpha5.Machine, instanceTypes []*cloudprovider.InstanceType) bool { requirements := scheduling.NewNodeSelectorRequirements(machine.Spec.Requirements...) // requirements must allow both if !requirements.Get(v1alpha5.LabelCapacityType).Has(v1alpha5.CapacityTypeSpot) || @@ -558,7 +561,7 @@ func combineFleetErrors(errors []*ec2.CreateFleetError) (errs error) { return fmt.Errorf("with fleet error(s), %w", errs) } -func getCapacityType(instance *ec2.Instance) string { +func GetCapacityType(instance *ec2.Instance) string { if instance.SpotInstanceRequestId != nil { return v1alpha5.CapacityTypeSpot } diff --git a/pkg/providers/instancetype/suite_test.go b/pkg/providers/instancetype/suite_test.go index e30e43c9a643..500804c2d073 100644 --- a/pkg/providers/instancetype/suite_test.go +++ b/pkg/providers/instancetype/suite_test.go @@ -25,11 +25,9 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/awstesting/mock" "github.com/aws/aws-sdk-go/service/ec2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/patrickmn/go-cache" "github.com/samber/lo" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -41,8 +39,9 @@ import ( . "knative.dev/pkg/logging/testing" "knative.dev/pkg/ptr" + coresettings "github.com/aws/karpenter-core/pkg/apis/settings" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" - corecloudprovider "github.com/aws/karpenter-core/pkg/cloudprovider" + corecloudproivder "github.com/aws/karpenter-core/pkg/cloudprovider" "github.com/aws/karpenter-core/pkg/controllers/provisioning" "github.com/aws/karpenter-core/pkg/controllers/state" "github.com/aws/karpenter-core/pkg/events" @@ -51,24 +50,18 @@ import ( "github.com/aws/karpenter-core/pkg/operator/options" "github.com/aws/karpenter-core/pkg/operator/scheme" "github.com/aws/karpenter-core/pkg/scheduling" + coretest "github.com/aws/karpenter-core/pkg/test" . "github.com/aws/karpenter-core/pkg/test/expectations" "github.com/aws/karpenter-core/pkg/utils/resources" - coresettings "github.com/aws/karpenter-core/pkg/apis/settings" - coretest "github.com/aws/karpenter-core/pkg/test" "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" - awscache "github.com/aws/karpenter/pkg/cache" "github.com/aws/karpenter/pkg/cloudprovider" - awscontext "github.com/aws/karpenter/pkg/context" "github.com/aws/karpenter/pkg/fake" - "github.com/aws/karpenter/pkg/providers/amifamily" + "github.com/aws/karpenter/pkg/providers/instance" "github.com/aws/karpenter/pkg/providers/instancetype" - "github.com/aws/karpenter/pkg/providers/launchtemplate" "github.com/aws/karpenter/pkg/providers/pricing" - "github.com/aws/karpenter/pkg/providers/securitygroup" - "github.com/aws/karpenter/pkg/providers/subnet" "github.com/aws/karpenter/pkg/test" ) @@ -76,29 +69,14 @@ var ctx context.Context var stop context.CancelFunc var opts options.Options var env *coretest.Environment -var ssmCache *cache.Cache -var ec2Cache *cache.Cache -var launchTemplateCache *cache.Cache -var instanceTypeCache *cache.Cache -var kubernetesVersionCache *cache.Cache -var fakeEC2API *fake.EC2API -var fakeSSMAPI *fake.SSMAPI +var awsEnv *test.Environment var fakeClock *clock.FakeClock -var fakePricingAPI *fake.PricingAPI -var amiProvider *amifamily.Provider -var amiResolver *amifamily.Resolver -var cloudProvider *cloudprovider.CloudProvider -var unavailableOfferingsCache *awscache.UnavailableOfferings var prov *provisioning.Provisioner var provisioner *v1alpha5.Provisioner -var launchTemplateProvider *launchtemplate.Provider var nodeTemplate *v1alpha1.AWSNodeTemplate var cluster *state.Cluster -var pricingProvider *pricing.Provider -var subnetProvider *subnet.Provider -var instanceTypeProvider *instancetype.Provider -var securityGroupProvider *securitygroup.Provider var provisioningController controller.Controller +var cloudProvider *cloudprovider.CloudProvider func TestAWS(t *testing.T) { ctx = TestContextWithLogger(t) @@ -111,57 +89,10 @@ var _ = BeforeSuite(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) ctx, stop = context.WithCancel(ctx) + awsEnv = test.NewEnvironment(ctx, env) - fakeEC2API = &fake.EC2API{} - fakeSSMAPI = &fake.SSMAPI{} - ssmCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - ec2Cache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - launchTemplateCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - instanceTypeCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - kubernetesVersionCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - fakeClock = clock.NewFakeClock(time.Now()) - unavailableOfferingsCache = awscache.NewUnavailableOfferings() - fakePricingAPI = &fake.PricingAPI{} - pricingProvider = pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) - subnetProvider = subnet.NewProvider(fakeEC2API) - securityGroupProvider = securitygroup.NewProvider(fakeEC2API) - amiProvider = amifamily.NewProvider(env.Client, env.KubernetesInterface, fakeSSMAPI, fakeEC2API, ssmCache, ec2Cache, kubernetesVersionCache) - amiResolver = amifamily.New(env.Client, amiProvider) - instanceTypeProvider = instancetype.NewProvider("", instanceTypeCache, fakeEC2API, subnetProvider, unavailableOfferingsCache, pricingProvider) - - launchTemplateProvider = launchtemplate.NewProvider( - ctx, - launchTemplateCache, - fakeEC2API, - amiResolver, - securityGroupProvider, - ptr.String("ca-bundle"), - make(chan struct{}), - net.ParseIP("10.0.100.10"), - "https://test-cluster", - ) - - cloudProvider = cloudprovider.New(awscontext.Context{ - Context: corecloudprovider.Context{ - Context: ctx, - RESTConfig: env.Config, - KubernetesInterface: env.KubernetesInterface, - KubeClient: env.Client, - EventRecorder: events.NewRecorder(&record.FakeRecorder{}), - Clock: &clock.FakeClock{}, - StartAsync: nil, - }, - SubnetProvider: subnet.NewProvider(fakeEC2API), - SecurityGroupProvider: securityGroupProvider, - Session: mock.Session, - UnavailableOfferingsCache: unavailableOfferingsCache, - EC2API: fakeEC2API, - PricingProvider: pricingProvider, - AMIProvider: amiProvider, - AMIResolver: amiResolver, - LaunchTemplateProvider: launchTemplateProvider, - InstanceTypesProvider: instanceTypeProvider, - }) + fakeClock = &clock.FakeClock{} + cloudProvider = cloudprovider.New(ctx, awsEnv.InstanceTypesProvider, awsEnv.InstanceProvider, env.Client, awsEnv.AMIProvider) cluster = state.NewCluster(fakeClock, env.Client, cloudProvider) prov = provisioning.NewProvisioner(ctx, env.Client, env.KubernetesInterface.CoreV1(), events.NewRecorder(&record.FakeRecorder{}), cloudProvider, cluster) provisioningController = provisioning.NewController(env.Client, prov, events.NewRecorder(&record.FakeRecorder{})) @@ -206,29 +137,10 @@ var _ = BeforeEach(func() { }) cluster.Reset() - fakeEC2API.Reset() - fakeSSMAPI.Reset() - fakePricingAPI.Reset() - launchTemplateCache.Flush() - unavailableOfferingsCache.Flush() - ssmCache.Flush() - ec2Cache.Flush() - kubernetesVersionCache.Flush() - instanceTypeCache.Flush() - subnetProvider.Reset() - securityGroupProvider.Reset() - launchTemplateProvider.KubeDNSIP = net.ParseIP("10.0.100.10") - launchTemplateProvider.ClusterEndpoint = "https://test-cluster" + awsEnv.Reset() - // Reset the pricing provider, so we don't cross-pollinate pricing data - instanceTypeProvider = instancetype.NewProvider( - "", - instanceTypeCache, - fakeEC2API, - subnetProvider, - unavailableOfferingsCache, - pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})), - ) + awsEnv.LaunchTemplateProvider.KubeDNSIP = net.ParseIP("10.0.100.10") + awsEnv.LaunchTemplateProvider.ClusterEndpoint = "https://test-cluster" }) var _ = AfterEach(func() { @@ -295,10 +207,10 @@ var _ = Describe("Instance Types", func() { }) It("should order the instance types by price and only consider the cheapest ones", func() { instances := makeFakeInstances() - fakeEC2API.DescribeInstanceTypesOutput.Set(&ec2.DescribeInstanceTypesOutput{ + awsEnv.EC2API.DescribeInstanceTypesOutput.Set(&ec2.DescribeInstanceTypesOutput{ InstanceTypes: makeFakeInstances(), }) - fakeEC2API.DescribeInstanceTypeOfferingsOutput.Set(&ec2.DescribeInstanceTypeOfferingsOutput{ + awsEnv.EC2API.DescribeInstanceTypeOfferingsOutput.Set(&ec2.DescribeInstanceTypeOfferingsOutput{ InstanceTypeOfferings: makeFakeInstanceOfferings(instances), }) ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) @@ -324,24 +236,24 @@ var _ = Describe("Instance Types", func() { return iPrice < jPrice }) // Expect that the launch template overrides gives the 60 cheapest instance types - expected := sets.NewString(lo.Map(its[:cloudprovider.MaxInstanceTypes], func(i *corecloudprovider.InstanceType, _ int) string { + expected := sets.NewString(lo.Map(its[:instance.MaxInstanceTypes], func(i *corecloudproivder.InstanceType, _ int) string { return i.Name })...) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - call := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + call := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(call.LaunchTemplateConfigs).To(HaveLen(1)) - Expect(call.LaunchTemplateConfigs[0].Overrides).To(HaveLen(cloudprovider.MaxInstanceTypes)) + Expect(call.LaunchTemplateConfigs[0].Overrides).To(HaveLen(instance.MaxInstanceTypes)) for _, override := range call.LaunchTemplateConfigs[0].Overrides { Expect(expected.Has(aws.StringValue(override.InstanceType))).To(BeTrue(), fmt.Sprintf("expected %s to exist in set", aws.StringValue(override.InstanceType))) } }) It("should order the instance types by price and only consider the spot types that are cheaper than the cheapest on-demand", func() { instances := makeFakeInstances() - fakeEC2API.DescribeInstanceTypesOutput.Set(&ec2.DescribeInstanceTypesOutput{ + awsEnv.EC2API.DescribeInstanceTypesOutput.Set(&ec2.DescribeInstanceTypesOutput{ InstanceTypes: makeFakeInstances(), }) - fakeEC2API.DescribeInstanceTypeOfferingsOutput.Set(&ec2.DescribeInstanceTypeOfferingsOutput{ + awsEnv.EC2API.DescribeInstanceTypeOfferingsOutput.Set(&ec2.DescribeInstanceTypeOfferingsOutput{ InstanceTypeOfferings: makeFakeInstanceOfferings(instances), }) @@ -356,8 +268,8 @@ var _ = Describe("Instance Types", func() { }, } ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - fakeEC2API.DescribeSpotPriceHistoryOutput.Set(generateSpotPricing(cloudProvider, provisioner)) - Expect(pricingProvider.UpdateSpotPricing(ctx)).To(Succeed()) + awsEnv.EC2API.DescribeSpotPriceHistoryOutput.Set(generateSpotPricing(cloudProvider, provisioner)) + Expect(awsEnv.PricingProvider.UpdateSpotPricing(ctx)).To(Succeed()) pod := coretest.UnschedulablePod(coretest.PodOptions{ ResourceRequirements: v1.ResourceRequirements{ @@ -382,14 +294,14 @@ var _ = Describe("Instance Types", func() { return iPrice < jPrice }) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - call := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + call := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(call.LaunchTemplateConfigs).To(HaveLen(1)) // find the cheapest OD price that works cheapestODPrice := math.MaxFloat64 for _, override := range call.LaunchTemplateConfigs[0].Overrides { - odPrice, ok := pricingProvider.OnDemandPrice(*override.InstanceType) + odPrice, ok := awsEnv.PricingProvider.OnDemandPrice(*override.InstanceType) Expect(ok).To(BeTrue()) if odPrice < cheapestODPrice { cheapestODPrice = odPrice @@ -397,7 +309,7 @@ var _ = Describe("Instance Types", func() { } // and our spot prices should be cheaper than the OD price for _, override := range call.LaunchTemplateConfigs[0].Overrides { - spotPrice, ok := pricingProvider.SpotPrice(*override.InstanceType, *override.AvailabilityZone) + spotPrice, ok := awsEnv.PricingProvider.SpotPrice(*override.InstanceType, *override.AvailabilityZone) Expect(ok).To(BeTrue()) Expect(spotPrice).To(BeNumerically("<", cheapestODPrice)) } @@ -413,8 +325,8 @@ var _ = Describe("Instance Types", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - call := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + call := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() for _, ltc := range call.LaunchTemplateConfigs { for _, ovr := range ltc.Overrides { Expect(strings.HasSuffix(aws.StringValue(ovr.InstanceType), "metal")).To(BeFalse()) @@ -432,8 +344,8 @@ var _ = Describe("Instance Types", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - call := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + call := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() for _, ltc := range call.LaunchTemplateConfigs { for _, ovr := range ltc.Overrides { Expect(strings.HasPrefix(aws.StringValue(ovr.InstanceType), "g")).To(BeFalse()) @@ -591,7 +503,7 @@ var _ = Describe("Instance Types", func() { ctx = settings.ToContext(ctx, test.Settings(test.SettingOptions{ EnableENILimitedPodDensity: lo.ToPtr(false), })) - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) for _, info := range instanceInfo { it := instancetype.NewInstanceType(ctx, info, provisioner.Spec.KubeletConfiguration, "", nodeTemplate, nil) @@ -599,7 +511,7 @@ var _ = Describe("Instance Types", func() { } }) It("should not set pods to 110 if using ENI-based pod density", func() { - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) for _, info := range instanceInfo { it := instancetype.NewInstanceType(ctx, info, provisioner.Spec.KubeletConfiguration, "", nodeTemplate, nil) @@ -615,7 +527,7 @@ var _ = Describe("Instance Types", func() { })) var ok bool - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) info, ok = lo.Find(instanceInfo, func(i *ec2.InstanceTypeInfo) bool { return aws.StringValue(i.InstanceType) == "m5.xlarge" @@ -877,7 +789,7 @@ var _ = Describe("Instance Types", func() { }) }) It("should set max-pods to user-defined value if specified", func() { - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) provisioner = test.Provisioner(coretest.ProvisionerOptions{Kubelet: &v1alpha5.KubeletConfiguration{MaxPods: ptr.Int32(10)}}) for _, info := range instanceInfo { @@ -890,7 +802,7 @@ var _ = Describe("Instance Types", func() { EnablePodENI: lo.ToPtr(false), })) - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) provisioner = test.Provisioner(coretest.ProvisionerOptions{Kubelet: &v1alpha5.KubeletConfiguration{MaxPods: ptr.Int32(10)}}) for _, info := range instanceInfo { @@ -899,7 +811,7 @@ var _ = Describe("Instance Types", func() { } }) It("should override pods-per-core value", func() { - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) provisioner = test.Provisioner(coretest.ProvisionerOptions{Kubelet: &v1alpha5.KubeletConfiguration{PodsPerCore: ptr.Int32(1)}}) for _, info := range instanceInfo { @@ -908,7 +820,7 @@ var _ = Describe("Instance Types", func() { } }) It("should take the minimum of pods-per-core and max-pods", func() { - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) provisioner = test.Provisioner(coretest.ProvisionerOptions{Kubelet: &v1alpha5.KubeletConfiguration{PodsPerCore: ptr.Int32(4), MaxPods: ptr.Int32(20)}}) for _, info := range instanceInfo { @@ -917,7 +829,7 @@ var _ = Describe("Instance Types", func() { } }) It("should ignore pods-per-core when using Bottlerocket AMI", func() { - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) nodeTemplate.Spec.AMIFamily = &v1alpha1.AMIFamilyBottlerocket provisioner = test.Provisioner(coretest.ProvisionerOptions{Kubelet: &v1alpha5.KubeletConfiguration{PodsPerCore: ptr.Int32(1)}}) @@ -932,7 +844,7 @@ var _ = Describe("Instance Types", func() { EnableENILimitedPodDensity: lo.ToPtr(false), })) - instanceInfo, err := instanceTypeProvider.GetInstanceTypes(ctx) + instanceInfo, err := awsEnv.InstanceTypesProvider.GetInstanceTypes(ctx) Expect(err).To(BeNil()) provisioner = test.Provisioner(coretest.ProvisionerOptions{Kubelet: &v1alpha5.KubeletConfiguration{PodsPerCore: ptr.Int32(0)}}) for _, info := range instanceInfo { @@ -943,7 +855,7 @@ var _ = Describe("Instance Types", func() { }) Context("Insufficient Capacity Error Cache", func() { It("should launch instances of different type on second reconciliation attempt with Insufficient Capacity Error Cache fallback", func() { - fakeEC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "inf1.6xlarge", Zone: "test-zone-1a"}}) + awsEnv.EC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "inf1.6xlarge", Zone: "test-zone-1a"}}) ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) pods := []*v1.Pod{ coretest.UnschedulablePod(coretest.PodOptions{ @@ -976,7 +888,7 @@ var _ = Describe("Instance Types", func() { Expect(nodeNames.Len()).To(Equal(2)) }) It("should launch instances in a different zone on second reconciliation attempt with Insufficient Capacity Error Cache fallback", func() { - fakeEC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "p3.8xlarge", Zone: "test-zone-1a"}}) + awsEnv.EC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "p3.8xlarge", Zone: "test-zone-1a"}}) pod := coretest.UnschedulablePod(coretest.PodOptions{ NodeSelector: map[string]string{v1.LabelInstanceTypeStable: "p3.8xlarge"}, ResourceRequirements: v1.ResourceRequirements{ @@ -1003,7 +915,7 @@ var _ = Describe("Instance Types", func() { HaveKeyWithValue(v1.LabelTopologyZone, "test-zone-1b"))) }) It("should launch smaller instances than optimal if larger instance launch results in Insufficient Capacity Error", func() { - fakeEC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{ + awsEnv.EC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{ {CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "m5.xlarge", Zone: "test-zone-1a"}, }) provisioner.Spec.Requirements = append(provisioner.Spec.Requirements, v1.NodeSelectorRequirement{ @@ -1035,7 +947,7 @@ var _ = Describe("Instance Types", func() { } }) It("should launch instances on later reconciliation attempt with Insufficient Capacity Error Cache expiry", func() { - fakeEC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "inf1.6xlarge", Zone: "test-zone-1a"}}) + awsEnv.EC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "inf1.6xlarge", Zone: "test-zone-1a"}}) ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) pod := coretest.UnschedulablePod(coretest.PodOptions{ NodeSelector: map[string]string{v1.LabelInstanceTypeStable: "inf1.6xlarge"}, @@ -1047,14 +959,14 @@ var _ = Describe("Instance Types", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectNotScheduled(ctx, env.Client, pod) // capacity shortage is over - expire the item from the cache and try again - fakeEC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{}) - unavailableOfferingsCache.Delete("inf1.6xlarge", "test-zone-1a", v1alpha5.CapacityTypeOnDemand) + awsEnv.EC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{}) + awsEnv.UnavailableOfferingsCache.Delete("inf1.6xlarge", "test-zone-1a", v1alpha5.CapacityTypeOnDemand) ExpectProvisioned(ctx, env.Client, cluster, prov, pod) node := ExpectScheduled(ctx, env.Client, pod) Expect(node.Labels).To(HaveKeyWithValue(v1.LabelInstanceTypeStable, "inf1.6xlarge")) }) It("should launch instances in a different zone on second reconciliation attempt with Insufficient Capacity Error Cache fallback (Habana)", func() { - fakeEC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "dl1.24xlarge", Zone: "test-zone-1a"}}) + awsEnv.EC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{{CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "dl1.24xlarge", Zone: "test-zone-1a"}}) pod := coretest.UnschedulablePod(coretest.PodOptions{ NodeSelector: map[string]string{v1.LabelInstanceTypeStable: "dl1.24xlarge"}, ResourceRequirements: v1.ResourceRequirements{ @@ -1081,9 +993,9 @@ var _ = Describe("Instance Types", func() { HaveKeyWithValue(v1.LabelTopologyZone, "test-zone-1b"))) }) It("should launch on-demand capacity if flexible to both spot and on-demand, but spot is unavailable", func() { - Expect(fakeEC2API.DescribeInstanceTypesPagesWithContext(ctx, &ec2.DescribeInstanceTypesInput{}, func(dito *ec2.DescribeInstanceTypesOutput, b bool) bool { + Expect(awsEnv.EC2API.DescribeInstanceTypesPagesWithContext(ctx, &ec2.DescribeInstanceTypesInput{}, func(dito *ec2.DescribeInstanceTypesOutput, b bool) bool { for _, it := range dito.InstanceTypes { - fakeEC2API.InsufficientCapacityPools.Add(fake.CapacityPool{CapacityType: v1alpha5.CapacityTypeSpot, InstanceType: aws.StringValue(it.InstanceType), Zone: "test-zone-1a"}) + awsEnv.EC2API.InsufficientCapacityPools.Add(fake.CapacityPool{CapacityType: v1alpha5.CapacityTypeSpot, InstanceType: aws.StringValue(it.InstanceType), Zone: "test-zone-1a"}) } return true })).To(Succeed()) @@ -1103,7 +1015,7 @@ var _ = Describe("Instance Types", func() { Expect(node.Labels).To(HaveKeyWithValue(v1alpha5.LabelCapacityType, v1alpha5.CapacityTypeOnDemand)) }) It("should return all instance types, even though with no offerings due to Insufficient Capacity Error", func() { - fakeEC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{ + awsEnv.EC2API.InsufficientCapacityPools.Set([]fake.CapacityPool{ {CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "m5.xlarge", Zone: "test-zone-1a"}, {CapacityType: v1alpha5.CapacityTypeOnDemand, InstanceType: "m5.xlarge", Zone: "test-zone-1b"}, {CapacityType: v1alpha5.CapacityTypeSpot, InstanceType: "m5.xlarge", Zone: "test-zone-1a"}, @@ -1137,7 +1049,7 @@ var _ = Describe("Instance Types", func() { } } - instanceTypeCache.Flush() + awsEnv.InstanceTypeCache.Flush() instanceTypes, err := cloudProvider.GetInstanceTypes(ctx, provisioner) Expect(err).To(BeNil()) instanceTypeNames := sets.NewString() @@ -1170,7 +1082,7 @@ var _ = Describe("Instance Types", func() { }) It("should fail to launch capacity when there is no zonal availability for spot", func() { now := time.Now() - fakeEC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ + awsEnv.EC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ SpotPriceHistory: []*ec2.SpotPrice{ { AvailabilityZone: aws.String("test-zone-1a"), @@ -1180,8 +1092,8 @@ var _ = Describe("Instance Types", func() { }, }, }) - Expect(pricingProvider.UpdateSpotPricing(ctx)).To(Succeed()) - Eventually(func() bool { return pricingProvider.SpotLastUpdated().After(now) }).Should(BeTrue()) + Expect(awsEnv.PricingProvider.UpdateSpotPricing(ctx)).To(Succeed()) + Eventually(func() bool { return awsEnv.PricingProvider.SpotLastUpdated().After(now) }).Should(BeTrue()) provisioner.Spec.Requirements = []v1.NodeSelectorRequirement{ {Key: v1alpha5.LabelCapacityType, Operator: v1.NodeSelectorOpIn, Values: []string{v1alpha5.CapacityTypeSpot}}, @@ -1197,7 +1109,7 @@ var _ = Describe("Instance Types", func() { }) It("should succeed to launch spot instance when zonal availability exists", func() { now := time.Now() - fakeEC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ + awsEnv.EC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ SpotPriceHistory: []*ec2.SpotPrice{ { AvailabilityZone: aws.String("test-zone-1a"), @@ -1207,8 +1119,8 @@ var _ = Describe("Instance Types", func() { }, }, }) - Expect(pricingProvider.UpdateSpotPricing(ctx)).To(Succeed()) - Eventually(func() bool { return pricingProvider.SpotLastUpdated().After(now) }).Should(BeTrue()) + Expect(awsEnv.PricingProvider.UpdateSpotPricing(ctx)).To(Succeed()) + Eventually(func() bool { return awsEnv.PricingProvider.SpotLastUpdated().After(now) }).Should(BeTrue()) // not restricting to the zone so we can get any zone provisioner.Spec.Requirements = []v1.NodeSelectorRequirement{ @@ -1229,8 +1141,8 @@ var _ = Describe("Instance Types", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(*input.LaunchTemplateData.MetadataOptions.HttpEndpoint).To(Equal(ec2.LaunchTemplateInstanceMetadataEndpointStateEnabled)) Expect(*input.LaunchTemplateData.MetadataOptions.HttpProtocolIpv6).To(Equal(ec2.LaunchTemplateInstanceMetadataProtocolIpv6Disabled)) Expect(*input.LaunchTemplateData.MetadataOptions.HttpPutResponseHopLimit).To(Equal(int64(2))) @@ -1247,8 +1159,8 @@ var _ = Describe("Instance Types", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(*input.LaunchTemplateData.MetadataOptions.HttpEndpoint).To(Equal(ec2.LaunchTemplateInstanceMetadataEndpointStateDisabled)) Expect(*input.LaunchTemplateData.MetadataOptions.HttpProtocolIpv6).To(Equal(ec2.LaunchTemplateInstanceMetadataProtocolIpv6Enabled)) Expect(*input.LaunchTemplateData.MetadataOptions.HttpPutResponseHopLimit).To(Equal(int64(1))) @@ -1262,7 +1174,7 @@ var _ = Describe("Instance Types", func() { func generateSpotPricing(cp *cloudprovider.CloudProvider, prov *v1alpha5.Provisioner) *ec2.DescribeSpotPriceHistoryOutput { rsp := &ec2.DescribeSpotPriceHistoryOutput{} instanceTypes, err := cp.GetInstanceTypes(ctx, prov) - instanceTypeCache.Flush() + awsEnv.InstanceTypeCache.Flush() Expect(err).To(Succeed()) t := fakeClock.Now() diff --git a/pkg/providers/launchtemplate/launchtemplate_test.go b/pkg/providers/launchtemplate/suite_test.go similarity index 83% rename from pkg/providers/launchtemplate/launchtemplate_test.go rename to pkg/providers/launchtemplate/suite_test.go index 07f9087991e5..e31c2cbfc285 100644 --- a/pkg/providers/launchtemplate/launchtemplate_test.go +++ b/pkg/providers/launchtemplate/suite_test.go @@ -29,11 +29,9 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/awstesting/mock" "github.com/aws/aws-sdk-go/service/ec2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" - "github.com/patrickmn/go-cache" "github.com/samber/lo" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -43,28 +41,18 @@ import ( "k8s.io/client-go/tools/record" clock "k8s.io/utils/clock/testing" . "knative.dev/pkg/logging/testing" - "knative.dev/pkg/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" - awscache "github.com/aws/karpenter/pkg/cache" "github.com/aws/karpenter/pkg/cloudprovider" - awscontext "github.com/aws/karpenter/pkg/context" - "github.com/aws/karpenter/pkg/fake" - "github.com/aws/karpenter/pkg/providers/amifamily" "github.com/aws/karpenter/pkg/providers/amifamily/bootstrap" "github.com/aws/karpenter/pkg/providers/instancetype" - "github.com/aws/karpenter/pkg/providers/launchtemplate" - "github.com/aws/karpenter/pkg/providers/pricing" - "github.com/aws/karpenter/pkg/providers/securitygroup" - "github.com/aws/karpenter/pkg/providers/subnet" "github.com/aws/karpenter/pkg/test" coresettings "github.com/aws/karpenter-core/pkg/apis/settings" "github.com/aws/karpenter-core/pkg/apis/v1alpha5" - corecloudprovider "github.com/aws/karpenter-core/pkg/cloudprovider" "github.com/aws/karpenter-core/pkg/controllers/provisioning" "github.com/aws/karpenter-core/pkg/controllers/state" "github.com/aws/karpenter-core/pkg/events" @@ -79,28 +67,13 @@ var ctx context.Context var stop context.CancelFunc var opts options.Options var env *coretest.Environment -var ssmCache *cache.Cache -var ec2Cache *cache.Cache -var launchTemplateCache *cache.Cache -var instanceTypeCache *cache.Cache -var kubernetesVersionCache *cache.Cache -var fakeEC2API *fake.EC2API -var fakeSSMAPI *fake.SSMAPI +var awsEnv *test.Environment var fakeClock *clock.FakeClock -var fakePricingAPI *fake.PricingAPI -var amiProvider *amifamily.Provider -var amiResolver *amifamily.Resolver -var cloudProvider *cloudprovider.CloudProvider -var unavailableOfferingsCache *awscache.UnavailableOfferings var prov *provisioning.Provisioner var provisioner *v1alpha5.Provisioner -var launchTemplateProvider *launchtemplate.Provider var nodeTemplate *v1alpha1.AWSNodeTemplate var cluster *state.Cluster -var pricingProvider *pricing.Provider -var subnetProvider *subnet.Provider -var instanceTypesProvider *instancetype.Provider -var securityGroupProvider *securitygroup.Provider +var cloudProvider *cloudprovider.CloudProvider func TestAWS(t *testing.T) { ctx = TestContextWithLogger(t) @@ -113,60 +86,12 @@ var _ = BeforeSuite(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) ctx, stop = context.WithCancel(ctx) + awsEnv = test.NewEnvironment(ctx, env) - fakeEC2API = &fake.EC2API{} - fakeSSMAPI = &fake.SSMAPI{} - ssmCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - ec2Cache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - launchTemplateCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - instanceTypeCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - kubernetesVersionCache = cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) - fakeClock = clock.NewFakeClock(time.Now()) - unavailableOfferingsCache = awscache.NewUnavailableOfferings() - fakePricingAPI = &fake.PricingAPI{} - pricingProvider = pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) - subnetProvider = subnet.NewProvider(fakeEC2API) - securityGroupProvider = securitygroup.NewProvider(fakeEC2API) - amiProvider = amifamily.NewProvider(env.Client, env.KubernetesInterface, fakeSSMAPI, fakeEC2API, ssmCache, ec2Cache, kubernetesVersionCache) - amiResolver = amifamily.New(env.Client, amiProvider) - instanceTypesProvider = instancetype.NewProvider("", instanceTypeCache, fakeEC2API, subnetProvider, unavailableOfferingsCache, pricingProvider) - - launchTemplateProvider = launchtemplate.NewProvider( - ctx, - launchTemplateCache, - fakeEC2API, - amiResolver, - securityGroupProvider, - ptr.String("ca-bundle"), - make(chan struct{}), - net.ParseIP("10.0.100.10"), - "https://test-cluster", - ) - - cloudProvider = cloudprovider.New(awscontext.Context{ - Context: corecloudprovider.Context{ - Context: ctx, - RESTConfig: env.Config, - KubernetesInterface: env.KubernetesInterface, - KubeClient: env.Client, - EventRecorder: events.NewRecorder(&record.FakeRecorder{}), - Clock: &clock.FakeClock{}, - StartAsync: nil, - }, - SubnetProvider: subnet.NewProvider(fakeEC2API), - SecurityGroupProvider: securityGroupProvider, - Session: mock.Session, - UnavailableOfferingsCache: unavailableOfferingsCache, - EC2API: fakeEC2API, - PricingProvider: pricingProvider, - AMIProvider: amiProvider, - AMIResolver: amiResolver, - LaunchTemplateProvider: launchTemplateProvider, - InstanceTypesProvider: instanceTypesProvider, - }) + fakeClock = &clock.FakeClock{} + cloudProvider = cloudprovider.New(ctx, awsEnv.InstanceTypesProvider, awsEnv.InstanceProvider, env.Client, awsEnv.AMIProvider) cluster = state.NewCluster(fakeClock, env.Client, cloudProvider) prov = provisioning.NewProvisioner(ctx, env.Client, env.KubernetesInterface.CoreV1(), events.NewRecorder(&record.FakeRecorder{}), cloudProvider, cluster) - }) var _ = AfterSuite(func() { @@ -206,29 +131,11 @@ var _ = BeforeEach(func() { Name: nodeTemplate.Name, }, }) - cluster.Reset() - fakeEC2API.Reset() - fakeSSMAPI.Reset() - launchTemplateCache.Flush() - unavailableOfferingsCache.Flush() - ssmCache.Flush() - ec2Cache.Flush() - instanceTypeCache.Flush() - kubernetesVersionCache.Flush() - securityGroupProvider.Reset() - launchTemplateProvider.KubeDNSIP = net.ParseIP("10.0.100.10") - launchTemplateProvider.ClusterEndpoint = "https://test-cluster" - - // Reset the pricing provider, so we don't cross-pollinate pricing data - instanceTypesProvider = instancetype.NewProvider( - "", - instanceTypeCache, - fakeEC2API, - subnetProvider, - unavailableOfferingsCache, - pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})), - ) + awsEnv.Reset() + + awsEnv.LaunchTemplateProvider.KubeDNSIP = net.ParseIP("10.0.100.10") + awsEnv.LaunchTemplateProvider.ClusterEndpoint = "https://test-cluster" }) var _ = AfterEach(func() { @@ -243,13 +150,13 @@ var _ = Describe("LaunchTemplates", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - firstLt := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + firstLt := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() launchTemplate := createFleetInput.LaunchTemplateConfigs[0].LaunchTemplateSpecification Expect(createFleetInput.LaunchTemplateConfigs).To(HaveLen(1)) @@ -278,9 +185,9 @@ var _ = Describe("LaunchTemplates", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() // Expect these values to be correctly ordered by price overrides := createFleetInput.LaunchTemplateConfigs[0].Overrides @@ -289,7 +196,7 @@ var _ = Describe("LaunchTemplates", func() { }) lastPrice := -math.MaxFloat64 for _, override := range overrides { - offeringPrice, ok := pricingProvider.SpotPrice(*override.InstanceType, *override.AvailabilityZone) + offeringPrice, ok := awsEnv.PricingProvider.SpotPrice(*override.InstanceType, *override.AvailabilityZone) Expect(ok).To(BeTrue()) Expect(offeringPrice).To(BeNumerically(">=", lastPrice)) lastPrice = offeringPrice @@ -303,8 +210,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - input := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(input.LaunchTemplateConfigs).To(HaveLen(1)) launchTemplate := input.LaunchTemplateConfigs[0].LaunchTemplateSpecification Expect(*launchTemplate.LaunchTemplateName).To(Equal("test-launch-template")) @@ -348,8 +255,8 @@ var _ = Describe("LaunchTemplates", func() { }) ExpectProvisioned(ctx, env.Client, cluster, prov, pod1) ExpectScheduled(ctx, env.Client, pod1) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - name1 := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop().LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateName + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + name1 := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop().LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateName pod2 := coretest.UnschedulablePod(coretest.PodOptions{ Tolerations: []v1.Toleration{t2, t3, t1}, @@ -358,8 +265,8 @@ var _ = Describe("LaunchTemplates", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod2) ExpectScheduled(ctx, env.Client, pod2) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - name2 := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop().LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateName + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + name2 := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop().LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateName Expect(name1).To(Equal(name2)) }) It("should recover from an out-of-sync launch template cache", func() { @@ -368,19 +275,19 @@ var _ = Describe("LaunchTemplates", func() { ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - firstLt := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + firstLt := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() ltName := aws.StringValue(firstLt.LaunchTemplateName) - lt, ok := launchTemplateCache.Get(ltName) + lt, ok := awsEnv.LaunchTemplateCache.Get(ltName) Expect(ok).To(Equal(true)) // Remove expiration from cached LT - launchTemplateCache.Set(ltName, lt, -1) + awsEnv.LaunchTemplateCache.Set(ltName, lt, -1) - fakeEC2API.CreateFleetBehavior.Error.Set(awserr.New("InvalidLaunchTemplateName.NotFoundException", "", errors.New(""))) + awsEnv.EC2API.CreateFleetBehavior.Error.Set(awserr.New("InvalidLaunchTemplateName.NotFoundException", "", errors.New(""))) pod = coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) // should call fleet twice. Once will fail on invalid LT and the next will succeed - fleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + fleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(aws.StringValue(fleetInput.LaunchTemplateConfigs[0].LaunchTemplateSpecification.LaunchTemplateName)).To(Equal(ltName)) ExpectScheduled(ctx, env.Client, pod) }) @@ -396,7 +303,7 @@ var _ = Describe("LaunchTemplates", func() { Expect(node.Labels).To(HaveKey(v1.LabelInstanceTypeStable)) }) It("should apply provider labels to the node", func() { - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ { ImageId: aws.String("ami-123"), Architecture: aws.String("x86_64"), @@ -427,8 +334,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(createFleetInput.TagSpecifications).To(HaveLen(3)) tags := map[string]string{ @@ -454,8 +361,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(createFleetInput.TagSpecifications).To(HaveLen(3)) // tags should be included in instance, volume, and fleet tag specification @@ -478,8 +385,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(createFleetInput.TagSpecifications).To(HaveLen(3)) // tags should be included in instance, volume, and fleet tag specification @@ -509,8 +416,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(createFleetInput.TagSpecifications).To(HaveLen(3)) // tags should be included in instance, volume, and fleet tag specification @@ -539,8 +446,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) - createFleetInput := fakeEC2API.CreateFleetBehavior.CalledWithInput.Pop() + Expect(awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Len()).To(Equal(1)) + createFleetInput := awsEnv.EC2API.CreateFleetBehavior.CalledWithInput.Pop() Expect(createFleetInput.TagSpecifications).To(HaveLen(3)) // tags should be included in instance, volume, and fleet tag specification @@ -564,8 +471,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(len(input.LaunchTemplateData.BlockDeviceMappings)).To(Equal(1)) Expect(*input.LaunchTemplateData.BlockDeviceMappings[0].Ebs.VolumeSize).To(Equal(int64(20))) Expect(*input.LaunchTemplateData.BlockDeviceMappings[0].Ebs.VolumeType).To(Equal("gp3")) @@ -601,8 +508,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(input.LaunchTemplateData.BlockDeviceMappings[0].Ebs).To(Equal(&ec2.LaunchTemplateEbsBlockDeviceRequest{ VolumeSize: aws.Int64(187), VolumeType: aws.String("io2"), @@ -650,8 +557,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() // Both of these values are rounded up when converting to Gibibytes Expect(aws.Int64Value(input.LaunchTemplateData.BlockDeviceMappings[0].Ebs.VolumeSize)).To(BeNumerically("==", 4)) @@ -663,8 +570,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(len(input.LaunchTemplateData.BlockDeviceMappings)).To(Equal(2)) // Bottlerocket control volume Expect(*input.LaunchTemplateData.BlockDeviceMappings[0].Ebs.VolumeSize).To(Equal(int64(4))) @@ -681,8 +588,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(len(input.LaunchTemplateData.BlockDeviceMappings)).To(Equal(0)) }) It("should use custom block device mapping for custom AMIFamilies", func() { @@ -704,8 +611,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(len(input.LaunchTemplateData.BlockDeviceMappings)).To(Equal(1)) Expect(*input.LaunchTemplateData.BlockDeviceMappings[0].Ebs.VolumeSize).To(Equal(int64(40))) Expect(*input.LaunchTemplateData.BlockDeviceMappings[0].Ebs.VolumeType).To(Equal("io2")) @@ -864,7 +771,7 @@ var _ = Describe("LaunchTemplates", func() { BeforeEach(func() { var ok bool var instanceInfo []*ec2.InstanceTypeInfo - err := fakeEC2API.DescribeInstanceTypesPagesWithContext(ctx, &ec2.DescribeInstanceTypesInput{ + err := awsEnv.EC2API.DescribeInstanceTypesPagesWithContext(ctx, &ec2.DescribeInstanceTypesInput{ Filters: []*ec2.Filter{ { Name: aws.String("supported-virtualization-type"), @@ -914,7 +821,7 @@ var _ = Describe("LaunchTemplates", func() { BeforeEach(func() { var ok bool var instanceInfo []*ec2.InstanceTypeInfo - err := fakeEC2API.DescribeInstanceTypesPagesWithContext(ctx, &ec2.DescribeInstanceTypesInput{ + err := awsEnv.EC2API.DescribeInstanceTypesPagesWithContext(ctx, &ec2.DescribeInstanceTypesInput{ Filters: []*ec2.Filter{ { Name: aws.String("supported-virtualization-type"), @@ -965,8 +872,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).NotTo(ContainSubstring("--use-max-pods false")) @@ -980,8 +887,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--use-max-pods false")) @@ -993,8 +900,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--use-max-pods false")) @@ -1012,8 +919,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) @@ -1038,8 +945,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) @@ -1064,8 +971,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) @@ -1090,8 +997,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) @@ -1116,8 +1023,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) @@ -1138,8 +1045,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) @@ -1153,8 +1060,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring(fmt.Sprintf("--pods-per-core=%d", 2))) @@ -1168,8 +1075,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring(fmt.Sprintf("--pods-per-core=%d", 2))) @@ -1180,8 +1087,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--container-runtime containerd")) @@ -1192,8 +1099,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--container-runtime dockerd")) @@ -1214,8 +1121,8 @@ var _ = Describe("LaunchTemplates", func() { }) ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--container-runtime containerd")) @@ -1236,20 +1143,20 @@ var _ = Describe("LaunchTemplates", func() { }) ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--container-runtime containerd")) }) It("should specify --dns-cluster-ip and --ip-family when running in an ipv6 cluster", func() { - launchTemplateProvider.KubeDNSIP = net.ParseIP("fd4b:121b:812b::a") + awsEnv.LaunchTemplateProvider.KubeDNSIP = net.ParseIP("fd4b:121b:812b::a") ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--dns-cluster-ip 'fd4b:121b:812b::a'")) @@ -1264,8 +1171,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--image-gc-high-threshold=50")) @@ -1278,8 +1185,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--image-gc-low-threshold=50")) @@ -1303,8 +1210,8 @@ var _ = Describe("LaunchTemplates", func() { }) ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) content, err = os.ReadFile("testdata/br_userdata_merged.golden") @@ -1328,8 +1235,8 @@ var _ = Describe("LaunchTemplates", func() { }) ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) content, err := os.ReadFile("testdata/br_userdata_unmerged.golden") @@ -1380,8 +1287,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) config := &bootstrap.BottlerocketConfig{} @@ -1410,8 +1317,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) config := &bootstrap.BottlerocketConfig{} @@ -1440,8 +1347,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) config := &bootstrap.BottlerocketConfig{} @@ -1465,8 +1372,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) config := &bootstrap.BottlerocketConfig{} @@ -1490,8 +1397,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) content, err = os.ReadFile("testdata/al2_userdata_merged.golden") @@ -1513,8 +1420,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) content, err = os.ReadFile("testdata/al2_userdata_merged.golden") @@ -1533,8 +1440,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) content, err := os.ReadFile("testdata/al2_userdata_unmerged.golden") @@ -1546,7 +1453,7 @@ var _ = Describe("LaunchTemplates", func() { Context("Custom AMI Selector", func() { It("should use ami selector specified in AWSNodeTemplate", func() { nodeTemplate.Spec.AMISelector = map[string]string{"karpenter.sh/discovery": "my-cluster"} - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ { ImageId: aws.String("ami-123"), Architecture: aws.String("x86_64"), @@ -1558,15 +1465,15 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect("ami-123").To(Equal(*input.LaunchTemplateData.ImageId)) }) It("should copy over userData untouched when AMIFamily is Custom", func() { nodeTemplate.Spec.UserData = aws.String("special user data") nodeTemplate.Spec.AMISelector = map[string]string{"karpenter.sh/discovery": "my-cluster"} nodeTemplate.Spec.AMIFamily = &v1alpha1.AMIFamilyCustom - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ { ImageId: aws.String("ami-123"), Architecture: aws.String("x86_64"), @@ -1578,15 +1485,15 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect("special user data").To(Equal(string(userData))) }) It("should correctly use ami selector with specific IDs in AWSNodeTemplate", func() { nodeTemplate.Spec.AMISelector = map[string]string{"aws-ids": "ami-123,ami-456"} - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ { ImageId: aws.String("ami-123"), Architecture: aws.String("x86_64"), @@ -1606,8 +1513,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(2)) - actualFilter := fakeEC2API.CalledWithDescribeImagesInput.Pop().Filters + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(2)) + actualFilter := awsEnv.EC2API.CalledWithDescribeImagesInput.Pop().Filters expectedFilter := []*ec2.Filter{ { Name: aws.String("image-id"), @@ -1617,7 +1524,7 @@ var _ = Describe("LaunchTemplates", func() { Expect(actualFilter).To(Equal(expectedFilter)) }) It("should create multiple launch templates when multiple amis are discovered with non-equivalent requirements", func() { - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ { ImageId: aws.String("ami-123"), Architecture: aws.String("x86_64"), @@ -1638,16 +1545,16 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(2)) + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(2)) expectedImageIds := sets.NewString("ami-123", "ami-456") actualImageIds := sets.NewString( - *fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop().LaunchTemplateData.ImageId, - *fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop().LaunchTemplateData.ImageId, + *awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop().LaunchTemplateData.ImageId, + *awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop().LaunchTemplateData.ImageId, ) Expect(expectedImageIds.Equal(actualImageIds)).To(BeTrue()) }) It("should create a launch template with the newest compatible AMI when multiple amis are discovered", func() { - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ { ImageId: aws.String("ami-123"), Architecture: aws.String("x86_64"), @@ -1681,13 +1588,13 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect("ami-456").To(Equal(*input.LaunchTemplateData.ImageId)) }) It("should fail if no amis match selector.", func() { - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{}}) + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{}}) nodeTemplate.Spec.AMISelector = map[string]string{"karpenter.sh/discovery": "my-cluster"} ExpectApplied(ctx, env.Client, nodeTemplate) newProvisioner := test.Provisioner(coretest.ProvisionerOptions{ProviderRef: &v1alpha5.ProviderRef{Name: nodeTemplate.Name}}) @@ -1695,10 +1602,10 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectNotScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(0)) + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(0)) }) It("should fail if no instanceType matches ami requirements.", func() { - fakeEC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ + awsEnv.EC2API.DescribeImagesOutput.Set(&ec2.DescribeImagesOutput{Images: []*ec2.Image{ {ImageId: aws.String("ami-123"), Architecture: aws.String("newnew"), CreationDate: aws.String("2022-01-01T12:00:00Z")}, }}) nodeTemplate.Spec.AMISelector = map[string]string{"karpenter.sh/discovery": "my-cluster"} @@ -1708,7 +1615,7 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectNotScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(0)) + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(0)) }) It("should choose amis from SSM if no selector specified in AWSNodeTemplate", func() { ExpectApplied(ctx, env.Client, nodeTemplate) @@ -1717,7 +1624,7 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(*input.LaunchTemplateData.ImageId).To(ContainSubstring("test-ami")) }) }) @@ -1728,8 +1635,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() userData, err := base64.StdEncoding.DecodeString(*input.LaunchTemplateData.UserData) Expect(err).To(BeNil()) Expect(string(userData)).To(ContainSubstring("--dns-cluster-ip '10.0.10.100'")) @@ -1741,8 +1648,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(*input.LaunchTemplateData.IamInstanceProfile.Name).To(Equal("test-instance-profile")) }) It("should use the instance profile on the Provisioner when specified", func() { @@ -1751,8 +1658,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(*input.LaunchTemplateData.IamInstanceProfile.Name).To(Equal("overridden-profile")) }) }) @@ -1764,8 +1671,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(aws.BoolValue(input.LaunchTemplateData.Monitoring.Enabled)).To(BeFalse()) }) It("should pass detailed monitoring setting to the launch template at creation", func() { @@ -1775,8 +1682,8 @@ var _ = Describe("LaunchTemplates", func() { pod := coretest.UnschedulablePod() ExpectProvisioned(ctx, env.Client, cluster, prov, pod) ExpectScheduled(ctx, env.Client, pod) - Expect(fakeEC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) - input := fakeEC2API.CalledWithCreateLaunchTemplateInput.Pop() + Expect(awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Len()).To(Equal(1)) + input := awsEnv.EC2API.CalledWithCreateLaunchTemplateInput.Pop() Expect(aws.BoolValue(input.LaunchTemplateData.Monitoring.Enabled)).To(BeTrue()) }) }) diff --git a/pkg/providers/pricing/suite_test.go b/pkg/providers/pricing/suite_test.go index 6ce6a219e5a9..60511eaf2d58 100644 --- a/pkg/providers/pricing/suite_test.go +++ b/pkg/providers/pricing/suite_test.go @@ -46,9 +46,7 @@ var ctx context.Context var stop context.CancelFunc var opts options.Options var env *coretest.Environment -var fakePricingAPI *fake.PricingAPI -var fakeEC2API *fake.EC2API -var pricingProvider *pricing.Provider +var awsEnv *test.Environment func TestAWS(t *testing.T) { ctx = TestContextWithLogger(t) @@ -61,10 +59,7 @@ var _ = BeforeSuite(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) ctx, stop = context.WithCancel(ctx) - - fakeEC2API = &fake.EC2API{} - fakePricingAPI = &fake.PricingAPI{} - pricingProvider = pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + awsEnv = test.NewEnvironment(ctx, env) }) var _ = AfterSuite(func() { @@ -77,8 +72,7 @@ var _ = BeforeEach(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) - fakeEC2API.Reset() - fakePricingAPI.Reset() + awsEnv.Reset() }) var _ = AfterEach(func() { @@ -86,20 +80,16 @@ var _ = AfterEach(func() { }) var _ = Describe("Pricing", func() { - BeforeEach(func() { - fakeEC2API.Reset() - fakePricingAPI.Reset() - }) It("should return static on-demand data if pricing API fails", func() { - fakePricingAPI.NextError.Set(fmt.Errorf("failed")) - p := pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + awsEnv.PricingAPI.NextError.Set(fmt.Errorf("failed")) + p := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "", make(chan struct{})) price, ok := p.OnDemandPrice("c5.large") Expect(ok).To(BeTrue()) Expect(price).To(BeNumerically(">", 0)) }) It("should return static spot data if EC2 describeSpotPriceHistory API fails", func() { - fakePricingAPI.NextError.Set(fmt.Errorf("failed")) - p := pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + awsEnv.PricingAPI.NextError.Set(fmt.Errorf("failed")) + p := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "", make(chan struct{})) price, ok := p.SpotPrice("c5.large", "test-zone-1a") Expect(ok).To(BeTrue()) Expect(price).To(BeNumerically(">", 0)) @@ -107,14 +97,14 @@ var _ = Describe("Pricing", func() { It("should update on-demand pricing with response from the pricing API", func() { // modify our API before creating the pricing provider as it performs an initial update on creation. The pricing // API provides on-demand prices, the ec2 API provides spot prices - fakePricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ + awsEnv.PricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ PriceList: []aws.JSONValue{ fake.NewOnDemandPrice("c98.large", 1.20), fake.NewOnDemandPrice("c99.large", 1.23), }, }) updateStart := time.Now() - p := pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + p := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "", make(chan struct{})) Eventually(func() bool { return p.OnDemandLastUpdated().After(updateStart) }).Should(BeTrue()) price, ok := p.OnDemandPrice("c98.large") @@ -127,7 +117,7 @@ var _ = Describe("Pricing", func() { }) It("should update spot pricing with response from the pricing API", func() { now := time.Now() - fakeEC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ + awsEnv.EC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ SpotPriceHistory: []*ec2.SpotPrice{ { AvailabilityZone: aws.String("test-zone-1a"), @@ -155,14 +145,14 @@ var _ = Describe("Pricing", func() { }, }, }) - fakePricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ + awsEnv.PricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ PriceList: []aws.JSONValue{ fake.NewOnDemandPrice("c98.large", 1.20), fake.NewOnDemandPrice("c99.large", 1.23), }, }) updateStart := time.Now() - p := pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + p := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "", make(chan struct{})) Eventually(func() bool { return p.SpotLastUpdated().After(updateStart) }).Should(BeTrue()) price, ok := p.SpotPrice("c98.large", "test-zone-1b") @@ -175,7 +165,7 @@ var _ = Describe("Pricing", func() { }) It("should update zonal pricing with data from the spot pricing API", func() { now := time.Now() - fakeEC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ + awsEnv.EC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ SpotPriceHistory: []*ec2.SpotPrice{ { AvailabilityZone: aws.String("test-zone-1a"), @@ -191,14 +181,14 @@ var _ = Describe("Pricing", func() { }, }, }) - fakePricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ + awsEnv.PricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ PriceList: []aws.JSONValue{ fake.NewOnDemandPrice("c98.large", 1.20), fake.NewOnDemandPrice("c99.large", 1.23), }, }) updateStart := time.Now() - p := pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + p := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "", make(chan struct{})) Eventually(func() bool { return p.SpotLastUpdated().After(updateStart) }).Should(BeTrue()) price, ok := p.SpotPrice("c98.large", "test-zone-1a") @@ -210,7 +200,7 @@ var _ = Describe("Pricing", func() { }) It("should respond with false if price doesn't exist in zone", func() { now := time.Now() - fakeEC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ + awsEnv.EC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ SpotPriceHistory: []*ec2.SpotPrice{ { AvailabilityZone: aws.String("test-zone-1a"), @@ -220,14 +210,14 @@ var _ = Describe("Pricing", func() { }, }, }) - fakePricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ + awsEnv.PricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ PriceList: []aws.JSONValue{ fake.NewOnDemandPrice("c98.large", 1.20), fake.NewOnDemandPrice("c99.large", 1.23), }, }) updateStart := time.Now() - p := pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + p := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "", make(chan struct{})) Eventually(func() bool { return p.SpotLastUpdated().After(updateStart) }).Should(BeTrue()) _, ok := p.SpotPrice("c99.large", "test-zone-1b") @@ -239,7 +229,7 @@ var _ = Describe("Pricing", func() { // If it doesn't, they have a product description of Linux/UNIX. To work in both cases, we // need to search for both values. updateStart := time.Now() - fakeEC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ + awsEnv.EC2API.DescribeSpotPriceHistoryOutput.Set(&ec2.DescribeSpotPriceHistoryOutput{ SpotPriceHistory: []*ec2.SpotPrice{ { AvailabilityZone: aws.String("test-zone-1a"), @@ -249,15 +239,15 @@ var _ = Describe("Pricing", func() { }, }, }) - fakePricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ + awsEnv.PricingAPI.GetProductsOutput.Set(&awspricing.GetProductsOutput{ PriceList: []aws.JSONValue{ fake.NewOnDemandPrice("c98.large", 1.20), fake.NewOnDemandPrice("c99.large", 1.23), }, }) - p := pricing.NewProvider(ctx, fakePricingAPI, fakeEC2API, "", make(chan struct{})) + p := pricing.NewProvider(ctx, awsEnv.PricingAPI, awsEnv.EC2API, "", make(chan struct{})) Eventually(func() bool { return p.SpotLastUpdated().After(updateStart) }, 5*time.Second).Should(BeTrue()) - inp := fakeEC2API.DescribeSpotPriceHistoryInput.Clone() + inp := awsEnv.EC2API.DescribeSpotPriceHistoryInput.Clone() Expect(lo.Map(inp.ProductDescriptions, func(x *string, _ int) string { return *x })). To(ContainElements("Linux/UNIX", "Linux/UNIX (Amazon VPC)")) }) diff --git a/pkg/providers/securitygroup/securitygroup.go b/pkg/providers/securitygroup/securitygroup.go index 0ea8eb47a27a..178f5e71a24e 100644 --- a/pkg/providers/securitygroup/securitygroup.go +++ b/pkg/providers/securitygroup/securitygroup.go @@ -30,7 +30,6 @@ import ( "github.com/aws/karpenter-core/pkg/utils/functional" "github.com/aws/karpenter-core/pkg/utils/pretty" "github.com/aws/karpenter/pkg/apis/v1alpha1" - awscache "github.com/aws/karpenter/pkg/cache" ) type Provider struct { @@ -42,12 +41,12 @@ type Provider struct { const TTL = 5 * time.Minute -func NewProvider(ec2api ec2iface.EC2API) *Provider { +func NewProvider(ec2api ec2iface.EC2API, cache *cache.Cache) *Provider { return &Provider{ ec2api: ec2api, cm: pretty.NewChangeMonitor(), // TODO: Remove cache for v1beta1, utilize resolved security groups from the AWSNodeTemplate.status - cache: cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval), + cache: cache, } } @@ -122,7 +121,3 @@ func (p *Provider) securityGroupIds(securityGroups []*ec2.SecurityGroup) []strin } return names } - -func (p *Provider) Reset() { - p.cache.Flush() -} diff --git a/pkg/providers/securitygroup/suite_test.go b/pkg/providers/securitygroup/suite_test.go index 405c3f0ae88a..f6183bdf88c5 100644 --- a/pkg/providers/securitygroup/suite_test.go +++ b/pkg/providers/securitygroup/suite_test.go @@ -32,7 +32,6 @@ import ( "github.com/aws/karpenter/pkg/apis" "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" - "github.com/aws/karpenter/pkg/providers/securitygroup" "github.com/aws/karpenter/pkg/test" coresettings "github.com/aws/karpenter-core/pkg/apis/settings" @@ -42,17 +41,15 @@ import ( "github.com/aws/karpenter-core/pkg/operator/scheme" coretest "github.com/aws/karpenter-core/pkg/test" . "github.com/aws/karpenter-core/pkg/test/expectations" - "github.com/aws/karpenter/pkg/fake" ) var ctx context.Context var stop context.CancelFunc var opts options.Options var env *coretest.Environment -var fakeEC2API *fake.EC2API +var awsEnv *test.Environment var provisioner *corev1alpha5.Provisioner var nodeTemplate *v1alpha1.AWSNodeTemplate -var securityGroupProvider *securitygroup.Provider func TestAWS(t *testing.T) { ctx = TestContextWithLogger(t) @@ -65,8 +62,7 @@ var _ = BeforeSuite(func() { ctx = coresettings.ToContext(ctx, coretest.Settings()) ctx = settings.ToContext(ctx, test.Settings()) ctx, stop = context.WithCancel(ctx) - fakeEC2API = &fake.EC2API{} - securityGroupProvider = securitygroup.NewProvider(fakeEC2API) + awsEnv = test.NewEnvironment(ctx, env) }) var _ = AfterSuite(func() { @@ -107,8 +103,7 @@ var _ = BeforeEach(func() { }, }) - fakeEC2API.Reset() - securityGroupProvider.Reset() + awsEnv.Reset() }) var _ = AfterEach(func() { @@ -118,7 +113,7 @@ var _ = AfterEach(func() { var _ = Describe("Security Group Provider", func() { It("should default to the clusters security groups", func() { ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSecurityGroups, err := securityGroupProvider.List(ctx, nodeTemplate) + resolvedSecurityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) Expect(err).To(BeNil()) Expect(len(resolvedSecurityGroups)).To(Equal(3)) Expect(resolvedSecurityGroups).To(ConsistOf( @@ -128,12 +123,12 @@ var _ = Describe("Security Group Provider", func() { )) }) It("should discover security groups by tag", func() { - fakeEC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ + awsEnv.EC2API.DescribeSecurityGroupsOutput.Set(&ec2.DescribeSecurityGroupsOutput{SecurityGroups: []*ec2.SecurityGroup{ {GroupId: aws.String("test-sg-1"), Tags: []*ec2.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-1")}}}, {GroupId: aws.String("test-sg-2"), Tags: []*ec2.Tag{{Key: aws.String("kubernetes.io/cluster/test-cluster"), Value: aws.String("test-sg-2")}}}, }}) ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSecurityGroups, err := securityGroupProvider.List(ctx, nodeTemplate) + resolvedSecurityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) Expect(err).To(BeNil()) Expect(len(resolvedSecurityGroups)).To(Equal(2)) Expect(resolvedSecurityGroups).To(ConsistOf( @@ -144,7 +139,7 @@ var _ = Describe("Security Group Provider", func() { It("should discover security groups by multiple tag values", func() { nodeTemplate.Spec.SecurityGroupSelector = map[string]string{"Name": "test-security-group-1,test-security-group-2"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSecurityGroups, err := securityGroupProvider.List(ctx, nodeTemplate) + resolvedSecurityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) Expect(err).To(BeNil()) Expect(len(resolvedSecurityGroups)).To(Equal(2)) Expect(resolvedSecurityGroups).To(ConsistOf( @@ -155,7 +150,7 @@ var _ = Describe("Security Group Provider", func() { It("should discover security groups by ID", func() { nodeTemplate.Spec.SecurityGroupSelector = map[string]string{"aws-ids": "sg-test1"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSecurityGroups, err := securityGroupProvider.List(ctx, nodeTemplate) + resolvedSecurityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) Expect(err).To(BeNil()) Expect(len(resolvedSecurityGroups)).To(Equal(1)) Expect(resolvedSecurityGroups).To(ConsistOf( @@ -165,7 +160,7 @@ var _ = Describe("Security Group Provider", func() { It("should discover security groups by IDs", func() { nodeTemplate.Spec.SecurityGroupSelector = map[string]string{"aws-ids": "sg-test1,sg-test2"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSecurityGroups, err := securityGroupProvider.List(ctx, nodeTemplate) + resolvedSecurityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) Expect(err).To(BeNil()) Expect(len(resolvedSecurityGroups)).To(Equal(2)) Expect(resolvedSecurityGroups).To(ConsistOf( @@ -176,7 +171,7 @@ var _ = Describe("Security Group Provider", func() { It("should discover security groups by IDs and tags", func() { nodeTemplate.Spec.SecurityGroupSelector = map[string]string{"aws-ids": "sg-test1,sg-test2", "foo": "bar"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSecurityGroups, err := securityGroupProvider.List(ctx, nodeTemplate) + resolvedSecurityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) Expect(err).To(BeNil()) Expect(len(resolvedSecurityGroups)).To(Equal(2)) Expect(resolvedSecurityGroups).To(ConsistOf( @@ -187,7 +182,7 @@ var _ = Describe("Security Group Provider", func() { It("should discover security groups by IDs intersected with tags", func() { nodeTemplate.Spec.SecurityGroupSelector = map[string]string{"aws-ids": "sg-test2", "foo": "bar"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSecurityGroups, err := securityGroupProvider.List(ctx, nodeTemplate) + resolvedSecurityGroups, err := awsEnv.SecurityGroupProvider.List(ctx, nodeTemplate) Expect(err).To(BeNil()) Expect(len(resolvedSecurityGroups)).To(Equal(1)) Expect(resolvedSecurityGroups).To(ConsistOf( diff --git a/pkg/providers/subnet/subnet.go b/pkg/providers/subnet/subnet.go index f72f238e5a04..553f8f11be30 100644 --- a/pkg/providers/subnet/subnet.go +++ b/pkg/providers/subnet/subnet.go @@ -34,7 +34,6 @@ import ( "github.com/aws/karpenter-core/pkg/cloudprovider" "github.com/aws/karpenter-core/pkg/utils/functional" "github.com/aws/karpenter-core/pkg/utils/pretty" - awscache "github.com/aws/karpenter/pkg/cache" ) type Provider struct { @@ -45,13 +44,13 @@ type Provider struct { inflightIPs map[string]int64 } -func NewProvider(ec2api ec2iface.EC2API) *Provider { +func NewProvider(ec2api ec2iface.EC2API, cache *cache.Cache) *Provider { return &Provider{ ec2api: ec2api, cm: pretty.NewChangeMonitor(), // TODO: Remove cache for v1beta1, utilize resolved subnet from the AWSNodeTemplate.status // Subnets are sorted on AvailableIpAddressCount, descending order - cache: cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval), + cache: cache, // inflightIPs is used to track IPs from known launched instances inflightIPs: map[string]int64{}, } @@ -242,8 +241,3 @@ func Pretty(subnets []*ec2.Subnet) []string { } return names } - -func (p *Provider) Reset() { - p.cache.Flush() - p.inflightIPs = map[string]int64{} -} diff --git a/pkg/providers/subnet/suite_test.go b/pkg/providers/subnet/suite_test.go index 034b115bddbd..20203ca2b2e6 100644 --- a/pkg/providers/subnet/suite_test.go +++ b/pkg/providers/subnet/suite_test.go @@ -29,29 +29,27 @@ import ( . "knative.dev/pkg/logging/testing" "github.com/aws/karpenter/pkg/apis" - awssettings "github.com/aws/karpenter/pkg/apis/settings" + "github.com/aws/karpenter/pkg/apis/settings" "github.com/aws/karpenter/pkg/apis/v1alpha1" "github.com/aws/karpenter/pkg/providers/subnet" "github.com/aws/karpenter/pkg/test" - "github.com/aws/karpenter-core/pkg/apis/settings" + coresettings "github.com/aws/karpenter-core/pkg/apis/settings" corev1alpha5 "github.com/aws/karpenter-core/pkg/apis/v1alpha5" "github.com/aws/karpenter-core/pkg/operator/injection" "github.com/aws/karpenter-core/pkg/operator/options" "github.com/aws/karpenter-core/pkg/operator/scheme" coretest "github.com/aws/karpenter-core/pkg/test" . "github.com/aws/karpenter-core/pkg/test/expectations" - "github.com/aws/karpenter/pkg/fake" ) var ctx context.Context var stop context.CancelFunc var opts options.Options var env *coretest.Environment -var fakeEC2API *fake.EC2API +var awsEnv *test.Environment var provisioner *corev1alpha5.Provisioner var nodeTemplate *v1alpha1.AWSNodeTemplate -var subnetProvider *subnet.Provider func TestAWS(t *testing.T) { ctx = TestContextWithLogger(t) @@ -61,10 +59,10 @@ func TestAWS(t *testing.T) { var _ = BeforeSuite(func() { env = coretest.NewEnvironment(scheme.Scheme, coretest.WithCRDs(apis.CRDs...)) + ctx = coresettings.ToContext(ctx, coretest.Settings()) + ctx = settings.ToContext(ctx, test.Settings()) ctx, stop = context.WithCancel(ctx) - - fakeEC2API = &fake.EC2API{} - subnetProvider = subnet.NewProvider(fakeEC2API) + awsEnv = test.NewEnvironment(ctx, env) }) var _ = AfterSuite(func() { @@ -74,8 +72,8 @@ var _ = AfterSuite(func() { var _ = BeforeEach(func() { ctx = injection.WithOptions(ctx, opts) - ctx = settings.ToContext(ctx, coretest.Settings()) - ctx = awssettings.ToContext(ctx, test.Settings()) + ctx = coresettings.ToContext(ctx, coretest.Settings()) + ctx = settings.ToContext(ctx, test.Settings()) nodeTemplate = &v1alpha1.AWSNodeTemplate{ ObjectMeta: metav1.ObjectMeta{ Name: coretest.RandomName(), @@ -105,8 +103,7 @@ var _ = BeforeEach(func() { }, }) - fakeEC2API.Reset() - subnetProvider.Reset() + awsEnv.Reset() }) var _ = AfterEach(func() { @@ -117,7 +114,7 @@ var _ = Describe("Subnet Provider", func() { It("should discover subnet by ID", func() { nodeTemplate.Spec.SubnetSelector = map[string]string{"aws-ids": "subnet-test1"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSubnetProvider, err := subnetProvider.List(ctx, nodeTemplate) + resolvedSubnetProvider, err := awsEnv.SubnetProvider.List(ctx, nodeTemplate) resolvedSubnet := subnet.Pretty(resolvedSubnetProvider) Expect(err).To(BeNil()) Expect(len(resolvedSubnet)).To(Equal(1)) @@ -128,7 +125,7 @@ var _ = Describe("Subnet Provider", func() { It("should discover subnets by IDs", func() { nodeTemplate.Spec.SubnetSelector = map[string]string{"aws-ids": "subnet-test1,subnet-test2"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSubnetProvider, err := subnetProvider.List(ctx, nodeTemplate) + resolvedSubnetProvider, err := awsEnv.SubnetProvider.List(ctx, nodeTemplate) resolvedSubnet := subnet.Pretty(resolvedSubnetProvider) Expect(err).To(BeNil()) Expect(len(resolvedSubnet)).To(Equal(2)) @@ -140,7 +137,7 @@ var _ = Describe("Subnet Provider", func() { It("should discover subnets by IDs and tags", func() { nodeTemplate.Spec.SubnetSelector = map[string]string{"aws-ids": "subnet-test1,subnet-test2", "foo": "bar"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSubnetProvider, err := subnetProvider.List(ctx, nodeTemplate) + resolvedSubnetProvider, err := awsEnv.SubnetProvider.List(ctx, nodeTemplate) resolvedSubnet := subnet.Pretty(resolvedSubnetProvider) Expect(err).To(BeNil()) Expect(len(resolvedSubnet)).To(Equal(2)) @@ -152,7 +149,7 @@ var _ = Describe("Subnet Provider", func() { It("should discover subnets by a single tag", func() { nodeTemplate.Spec.SubnetSelector = map[string]string{"Name": "test-subnet-1"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSubnetProvider, err := subnetProvider.List(ctx, nodeTemplate) + resolvedSubnetProvider, err := awsEnv.SubnetProvider.List(ctx, nodeTemplate) resolvedSubnet := subnet.Pretty(resolvedSubnetProvider) Expect(err).To(BeNil()) Expect(len(resolvedSubnet)).To(Equal(1)) @@ -163,7 +160,7 @@ var _ = Describe("Subnet Provider", func() { It("should discover subnets by multiple tag values", func() { nodeTemplate.Spec.SubnetSelector = map[string]string{"Name": "test-subnet-1,test-subnet-2"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSubnetProvider, err := subnetProvider.List(ctx, nodeTemplate) + resolvedSubnetProvider, err := awsEnv.SubnetProvider.List(ctx, nodeTemplate) resolvedSubnet := subnet.Pretty(resolvedSubnetProvider) Expect(err).To(BeNil()) Expect(len(resolvedSubnet)).To(Equal(2)) @@ -175,7 +172,7 @@ var _ = Describe("Subnet Provider", func() { It("should discover subnets by IDs intersected with tags", func() { nodeTemplate.Spec.SubnetSelector = map[string]string{"aws-ids": "subnet-test2", "foo": "bar"} ExpectApplied(ctx, env.Client, provisioner, nodeTemplate) - resolvedSubnetProvider, err := subnetProvider.List(ctx, nodeTemplate) + resolvedSubnetProvider, err := awsEnv.SubnetProvider.List(ctx, nodeTemplate) resolvedSubnet := subnet.Pretty(resolvedSubnetProvider) Expect(err).To(BeNil()) Expect(len(resolvedSubnet)).To(Equal(1)) diff --git a/pkg/test/environment.go b/pkg/test/environment.go new file mode 100644 index 000000000000..36bcc528f675 --- /dev/null +++ b/pkg/test/environment.go @@ -0,0 +1,148 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package test + +import ( + "context" + "net" + + "knative.dev/pkg/ptr" + + "github.com/patrickmn/go-cache" + + awscache "github.com/aws/karpenter/pkg/cache" + "github.com/aws/karpenter/pkg/fake" + "github.com/aws/karpenter/pkg/providers/amifamily" + "github.com/aws/karpenter/pkg/providers/instance" + "github.com/aws/karpenter/pkg/providers/instancetype" + "github.com/aws/karpenter/pkg/providers/launchtemplate" + "github.com/aws/karpenter/pkg/providers/pricing" + "github.com/aws/karpenter/pkg/providers/securitygroup" + "github.com/aws/karpenter/pkg/providers/subnet" + + coretest "github.com/aws/karpenter-core/pkg/test" +) + +type Environment struct { + // API + EC2API *fake.EC2API + SSMAPI *fake.SSMAPI + PricingAPI *fake.PricingAPI + + // Cache + SSMCache *cache.Cache + EC2Cache *cache.Cache + KubernetesVersionCache *cache.Cache + InstanceTypeCache *cache.Cache + UnavailableOfferingsCache *awscache.UnavailableOfferings + LaunchTemplateCache *cache.Cache + SubnetCache *cache.Cache + SecurityGroupCache *cache.Cache + + // Providers + InstanceTypesProvider *instancetype.Provider + InstanceProvider *instance.Provider + SubnetProvider *subnet.Provider + SecurityGroupProvider *securitygroup.Provider + PricingProvider *pricing.Provider + AMIProvider *amifamily.Provider + AMIResolver *amifamily.Resolver + LaunchTemplateProvider *launchtemplate.Provider +} + +func NewEnvironment(ctx context.Context, env *coretest.Environment) *Environment { + // API + ec2api := &fake.EC2API{} + ssmapi := &fake.SSMAPI{} + + // cache + ssmCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) + ec2Cache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) + kubernetesVersionCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) + instanceTypeCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) + unavailableOfferingsCache := awscache.NewUnavailableOfferings() + launchTemplateCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) + subnetCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) + securityGroupCache := cache.New(awscache.DefaultTTL, awscache.DefaultCleanupInterval) + fakePricingAPI := &fake.PricingAPI{} + + // Providers + pricingProvider := pricing.NewProvider(ctx, fakePricingAPI, ec2api, "", make(chan struct{})) + subnetProvider := subnet.NewProvider(ec2api, subnetCache) + securityGroupProvider := securitygroup.NewProvider(ec2api, securityGroupCache) + amiProvider := amifamily.NewProvider(env.Client, env.KubernetesInterface, ssmapi, ec2api, ssmCache, ec2Cache, kubernetesVersionCache) + amiResolver := amifamily.New(env.Client, amiProvider) + instanceTypesProvider := instancetype.NewProvider("", instanceTypeCache, ec2api, subnetProvider, unavailableOfferingsCache, pricingProvider) + launchTemplateProvider := + launchtemplate.NewProvider( + ctx, + launchTemplateCache, + ec2api, + amiResolver, + securityGroupProvider, + ptr.String("ca-bundle"), + make(chan struct{}), + net.ParseIP("10.0.100.10"), + "https://test-cluster", + ) + instanceProvider := + instance.NewProvider(ctx, + "", + ec2api, + unavailableOfferingsCache, + instanceTypesProvider, + subnetProvider, + launchTemplateProvider, + ) + + return &Environment{ + EC2API: ec2api, + SSMAPI: ssmapi, + PricingAPI: fakePricingAPI, + + SSMCache: ssmCache, + EC2Cache: ec2Cache, + KubernetesVersionCache: kubernetesVersionCache, + InstanceTypeCache: instanceTypeCache, + LaunchTemplateCache: launchTemplateCache, + SubnetCache: subnetCache, + SecurityGroupCache: securityGroupCache, + UnavailableOfferingsCache: unavailableOfferingsCache, + + InstanceTypesProvider: instanceTypesProvider, + InstanceProvider: instanceProvider, + SubnetProvider: subnetProvider, + SecurityGroupProvider: securityGroupProvider, + PricingProvider: pricingProvider, + AMIProvider: amiProvider, + AMIResolver: amiResolver, + LaunchTemplateProvider: launchTemplateProvider, + } +} + +func (env *Environment) Reset() { + env.EC2API.Reset() + env.SSMAPI.Reset() + env.PricingAPI.Reset() + + env.SSMCache.Flush() + env.EC2Cache.Flush() + env.KubernetesVersionCache.Flush() + env.InstanceTypeCache.Flush() + env.UnavailableOfferingsCache.Flush() + env.LaunchTemplateCache.Flush() + env.SubnetCache.Flush() + env.SecurityGroupCache.Flush() +} diff --git a/test/go.mod b/test/go.mod index 694b6e8d2aba..798a6bbff642 100644 --- a/test/go.mod +++ b/test/go.mod @@ -72,6 +72,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect + github.com/pelletier/go-toml/v2 v2.0.6 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.14.0 // indirect github.com/prometheus/client_model v0.3.0 // indirect diff --git a/test/go.sum b/test/go.sum index e7c9dc1ce390..bae2acb95d8c 100644 --- a/test/go.sum +++ b/test/go.sum @@ -284,6 +284,8 @@ github.com/onsi/gomega v1.27.1 h1:rfztXRbg6nv/5f+Raen9RcGoSecHIFgBBLQK3Wdj754= github.com/onsi/gomega v1.27.1/go.mod h1:aHX5xOykVYzWOV4WqQy0sy8BQptgukenXpCXfadcIAw= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= +github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -328,11 +330,16 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=