diff --git a/ec2cluster/volume/ec2volume_test.go b/ec2cluster/volume/ec2volume_test.go index c403ba5..2aecf6f 100644 --- a/ec2cluster/volume/ec2volume_test.go +++ b/ec2cluster/volume/ec2volume_test.go @@ -5,8 +5,6 @@ package volume import ( - "fmt" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/ec2" @@ -15,28 +13,30 @@ import ( type mockEC2Client struct { ec2iface.EC2API - descVolsFn func() (*ec2.DescribeVolumesOutput, error) - descVolsModsFn func() (*ec2.DescribeVolumesModificationsOutput, error) + descVolsFn func(volIds []string) (*ec2.DescribeVolumesOutput, error) + descVolsModsFn func(volIds []string) (*ec2.DescribeVolumesModificationsOutput, error) modVolsFn func(volId string) (*ec2.ModifyVolumeOutput, error) + + nModVols int } -func (e *mockEC2Client) DescribeVolumesWithContext(ctx aws.Context, input *ec2.DescribeVolumesInput, _ ...request.Option) (*ec2.DescribeVolumesOutput, error) { - if len(input.VolumeIds) == 0 { - return nil, fmt.Errorf("must specify at least one volume id") +func (e *mockEC2Client) DescribeVolumesWithContext(_ aws.Context, input *ec2.DescribeVolumesInput, _ ...request.Option) (*ec2.DescribeVolumesOutput, error) { + ids := make([]string, 0, len(input.VolumeIds)) + for _, id := range input.VolumeIds { + ids = append(ids, aws.StringValue(id)) } - return e.descVolsFn() + return e.descVolsFn(ids) } -func (e *mockEC2Client) DescribeVolumesModificationsWithContext(ctx aws.Context, input *ec2.DescribeVolumesModificationsInput, _ ...request.Option) (*ec2.DescribeVolumesModificationsOutput, error) { - if len(input.VolumeIds) == 0 { - return nil, fmt.Errorf("must specify at least one volume id") +func (e *mockEC2Client) DescribeVolumesModificationsWithContext(_ aws.Context, input *ec2.DescribeVolumesModificationsInput, _ ...request.Option) (*ec2.DescribeVolumesModificationsOutput, error) { + ids := make([]string, 0, len(input.VolumeIds)) + for _, id := range input.VolumeIds { + ids = append(ids, aws.StringValue(id)) } - return e.descVolsModsFn() + return e.descVolsModsFn(ids) } -func (e *mockEC2Client) ModifyVolumeWithContext(ctx aws.Context, input *ec2.ModifyVolumeInput, _ ...request.Option) (*ec2.ModifyVolumeOutput, error) { - if input.VolumeId == nil { - return nil, fmt.Errorf("volume id cannot be empty") - } - return e.modVolsFn(*input.VolumeId) +func (e *mockEC2Client) ModifyVolumeWithContext(_ aws.Context, input *ec2.ModifyVolumeInput, _ ...request.Option) (*ec2.ModifyVolumeOutput, error) { + e.nModVols++ + return e.modVolsFn(aws.StringValue(input.VolumeId)) } diff --git a/ec2cluster/volume/volume.go b/ec2cluster/volume/volume.go index 970ec29..e96ba7c 100644 --- a/ec2cluster/volume/volume.go +++ b/ec2cluster/volume/volume.go @@ -238,9 +238,9 @@ func (v *ebsLvmVolume) ResizeEBS(ctx context.Context, newSize data.Size) error { var ( state stateT stateStr string - modifiedIds []string modifiedStates map[string]string ) + modifiedIds := make(map[string]struct{}) retriesByState := make(map[stateT]int) for state < stateDone { err = nil @@ -258,13 +258,12 @@ func (v *ebsLvmVolume) ResizeEBS(ctx context.Context, newSize data.Size) error { cancel() return nil }) - modifiedIds = nil var failed []string for idx, err := range errs { if err != nil { failed = append(failed, idsToModify[idx]) } else { - modifiedIds = append(modifiedIds, idsToModify[idx]) + modifiedIds[idsToModify[idx]] = struct{}{} } } idsToModify = failed @@ -272,8 +271,12 @@ func (v *ebsLvmVolume) ResizeEBS(ctx context.Context, newSize data.Size) error { err = fmt.Errorf("failed to modify (%s): %v", strings.Join(failed, ", "), errs) } case stateGetModificationStatus: - stateStr = fmt.Sprintf("get modification status (%s)", strings.Join(modifiedIds, ", ")) - modifiedStates, err = v.getModificationStateById(ctx, modifiedIds) + ids := make([]string, 0, len(modifiedIds)) + for id, _ := range modifiedIds { + ids = append(ids, id) + } + stateStr = fmt.Sprintf("get modification status (%s)", strings.Join(ids, ", ")) + modifiedStates, err = v.getModificationStateById(ctx, ids) case stateCheckModificationStatus: stateStr = fmt.Sprintf("checking modification status (%v)", modifiedStates) var reasons []string diff --git a/ec2cluster/volume/volume_test.go b/ec2cluster/volume/volume_test.go index 943906c..9459d6f 100644 --- a/ec2cluster/volume/volume_test.go +++ b/ec2cluster/volume/volume_test.go @@ -7,6 +7,8 @@ package volume import ( "context" "fmt" + "reflect" + "sort" "sync" "testing" "time" @@ -20,16 +22,23 @@ import ( var errTest = fmt.Errorf("test error") -func descVolsFn(m map[string]int64, err error) func() (*ec2.DescribeVolumesOutput, error) { +func descVolsFn(m map[string]int64, err error) func([]string) (*ec2.DescribeVolumesOutput, error) { var out *ec2.DescribeVolumesOutput + outIds := make([]string, 0, len(m)) if err == nil { var vols []*ec2.Volume for k, v := range m { vols = append(vols, &ec2.Volume{VolumeId: aws.String(k), Size: aws.Int64(v)}) + outIds = append(outIds, k) } out = &ec2.DescribeVolumesOutput{Volumes: vols} } - return func() (*ec2.DescribeVolumesOutput, error) { + return func(volIds []string) (*ec2.DescribeVolumesOutput, error) { + sort.Strings(volIds) + sort.Strings(outIds) + if !reflect.DeepEqual(volIds, outIds) { + return nil, fmt.Errorf("DescribeVolumes called with %v, want %v", volIds, outIds) + } return out, err } } @@ -37,7 +46,7 @@ func descVolsFn(m map[string]int64, err error) func() (*ec2.DescribeVolumesOutpu func TestEBSSize(t *testing.T) { for _, tt := range []struct { vols []string - fn func() (*ec2.DescribeVolumesOutput, error) + fn func([]string) (*ec2.DescribeVolumesOutput, error) wsize data.Size werr bool }{ @@ -59,9 +68,9 @@ func TestEBSSize(t *testing.T) { } } -func volModsFn(outs []*ec2.DescribeVolumesModificationsOutput, errs []error) func() (*ec2.DescribeVolumesModificationsOutput, error) { +func volModsFn(outs []*ec2.DescribeVolumesModificationsOutput, errs []error) func([]string) (*ec2.DescribeVolumesModificationsOutput, error) { var idx int - return func() (*ec2.DescribeVolumesModificationsOutput, error) { + return func(volIds []string) (*ec2.DescribeVolumesModificationsOutput, error) { var ( out *ec2.DescribeVolumesModificationsOutput err error @@ -72,6 +81,16 @@ func volModsFn(outs []*ec2.DescribeVolumesModificationsOutput, errs []error) fun } else { out = outs[len(outs)-1] } + outIds := make([]string, 0, len(out.VolumesModifications)) + for _, mod := range out.VolumesModifications { + outIds = append(outIds, aws.StringValue(mod.VolumeId)) + } + sort.Strings(volIds) + sort.Strings(outIds) + if !reflect.DeepEqual(volIds, outIds) { + out = nil + return nil, fmt.Errorf("DescribeVolumeModifications called with %v, want %v", volIds, outIds) + } } if errs != nil { if idx < len(errs) { @@ -119,14 +138,16 @@ func modVolsFn(m map[string][]error) func(string) (*ec2.ModifyVolumeOutput, erro } func TestResizeEBS(t *testing.T) { + retries := 5 for _, tt := range []struct { name string vols []string newSz data.Size - descVol func() (*ec2.DescribeVolumesOutput, error) - descVolMod func() (*ec2.DescribeVolumesModificationsOutput, error) + descVol func([]string) (*ec2.DescribeVolumesOutput, error) + descVolMod func([]string) (*ec2.DescribeVolumesModificationsOutput, error) modVol func(volId string) (*ec2.ModifyVolumeOutput, error) werr bool + nModVols int }{ { name: "fail to describe volumes", @@ -180,7 +201,8 @@ func TestResizeEBS(t *testing.T) { "volB": {nil}, }, ), - werr: true, + werr: true, + nModVols: 2 + retries, }, { name: "volume modification requests succeed but volume fails to modify", @@ -200,7 +222,8 @@ func TestResizeEBS(t *testing.T) { "volB": {nil}, }, ), - werr: true, + werr: true, + nModVols: 2 + retries, }, { name: "all volumes resize successfully", @@ -220,7 +243,8 @@ func TestResizeEBS(t *testing.T) { "volB": {nil}, }, ), - werr: false, + werr: false, + nModVols: 2, }, { name: "all volumes resize successfully after checking modification status 3 times", @@ -242,10 +266,34 @@ func TestResizeEBS(t *testing.T) { "volB": {nil}, }, ), - werr: false, + werr: false, + nModVols: 2, }, { - name: "one volume doesn't need to be resized, other volume resizes successfully", + name: "both resizes succeed on the second try, but failures are detected at different points in time", + vols: []string{"volA", "volB"}, + newSz: 50 * data.GiB, + descVol: descVolsFn(map[string]int64{"volA": 5, "volB": 5}, nil), + descVolMod: volModsFn( + []*ec2.DescribeVolumesModificationsOutput{ + volModsOut(map[string]string{"volA": "completed", "volB": "completed"}), + volModsOut(map[string]string{"volA": "modifying", "volB": "failed"}), + volModsOut(map[string]string{"volA": "failed", "volB": "modifying"}), + volModsOut(map[string]string{"volA": "optimizing", "volB": "optimizing"}), + }, + nil, + ), + modVol: modVolsFn( + map[string][]error{ + "volA": {nil}, + "volB": {nil}, + }, + ), + werr: false, + nModVols: 4, + }, + { + name: "volA doesn't need to be resized, volB resizes successfully", vols: []string{"volA", "volB"}, newSz: 50 * data.GiB, descVol: descVolsFn(map[string]int64{"volA": 5, "volB": 25}, nil), @@ -261,16 +309,21 @@ func TestResizeEBS(t *testing.T) { "volA": {nil}, }, ), - werr: false, + werr: false, + nModVols: 1, }, } { + c := &mockEC2Client{descVolsFn: tt.descVol, descVolsModsFn: tt.descVolMod, modVolsFn: tt.modVol} v := &ebsLvmVolume{ebsVolIds: tt.vols, log: log.Std, ebsVolType: ec2.VolumeTypeGp3, - ec2: &mockEC2Client{descVolsFn: tt.descVol, descVolsModsFn: tt.descVolMod, modVolsFn: tt.modVol}, - retrier: retry.MaxRetries(retry.Backoff(10*time.Millisecond, 20*time.Millisecond, 1.5), 5), + ec2: c, + retrier: retry.MaxRetries(retry.Backoff(10*time.Millisecond, 20*time.Millisecond, 1.5), retries), } err := v.ResizeEBS(context.Background(), tt.newSz) if gotE := err != nil; gotE != tt.werr { - t.Errorf("got error: %v, want error: %t", err, tt.werr) + t.Errorf("%s: error: got: %v, want: %t", tt.name, err, tt.werr) + } + if got, want := c.nModVols, tt.nModVols; got != want { + t.Errorf("%s: nModVols: got %d, want: %d", tt.name, got, want) } } }