diff --git a/pkg/testingutil/chan.go b/pkg/testingutil/chan.go index bc24e2c6..22777aa4 100644 --- a/pkg/testingutil/chan.go +++ b/pkg/testingutil/chan.go @@ -2,11 +2,10 @@ package testingutil import ( "context" - "testing" ) // FetchAll pulls all resources from `f` over the passed channel, returning the resources as a slice -func FetchAll[T any](ctx context.Context, t testing.TB, f func(context.Context, chan<- T) error) ([]T, error) { +func FetchAll[T any](ctx context.Context, t TestingTB, f func(context.Context, chan<- T) error) ([]T, error) { t.Helper() var resources []T @@ -27,7 +26,7 @@ func FetchAll[T any](ctx context.Context, t testing.TB, f func(context.Context, } // MustFetchAll is like FetchAll, but fatals the running test if there is an error during fetching -func MustFetchAll[T any](ctx context.Context, t testing.TB, f func(context.Context, chan<- T) error) []T { +func MustFetchAll[T any](ctx context.Context, t TestingTB, f func(context.Context, chan<- T) error) []T { t.Helper() resources, err := FetchAll(ctx, t, f) diff --git a/pkg/testingutil/chan_test.go b/pkg/testingutil/chan_test.go index 11908ea9..afb0ec8c 100644 --- a/pkg/testingutil/chan_test.go +++ b/pkg/testingutil/chan_test.go @@ -71,7 +71,8 @@ func TestMustFetchAllCanceled(t *testing.T) { return util.SendAllFromSlice(ctx, out, in) } - assert.PanicsWithError(t, "error with testingutil.FetchAll: context canceled", func() { - MustFetchAll(ctx, Fake(t), fetchFunc) - }) + fake := Fake(t) + MustFetchAll(ctx, fake, fetchFunc) + assert.Contains(t, fake.Logs, "error with testingutil.FetchAll: context canceled") + assert.True(t, fake.IsFail) } diff --git a/pkg/testingutil/filter.go b/pkg/testingutil/filter.go index e21df18a..e5870395 100644 --- a/pkg/testingutil/filter.go +++ b/pkg/testingutil/filter.go @@ -21,39 +21,12 @@ type ResourceFilter struct { func (f ResourceFilter) String() string { var parts []string - if f.AccountId != "" { - parts = append(parts, fmt.Sprintf("AccountId=%s", f.AccountId)) - } - - if f.Type != "" { - parts = append(parts, fmt.Sprintf("Type=%s", f.Type)) - } - - if f.Region != "" { - parts = append(parts, fmt.Sprintf("Region=%s", f.Region)) - } - - if f.Tags != nil && len(f.Tags) == 0 { - parts = append(parts, "Tags=[]") - } else { - for _, tag := range f.Tags { - if tag.Value == "" { - parts = append(parts, fmt.Sprintf("Tags[%s]", tag.Key)) - } else { - parts = append(parts, fmt.Sprintf("Tags[%s]=%s", tag.Key, tag.Value)) - } + for _, matcher := range f.matchers() { + if !matcher.present() { + continue } - } - if len(f.RawData) > 0 { - rawParts := make([]string, 0, len(f.RawData)) - - for key, val := range f.RawData { - rawParts = append(rawParts, fmt.Sprintf("%s=%v", key, val)) - } - //sorting ensures consistent output for testing - sort.Strings(rawParts) - parts = append(parts, fmt.Sprintf("RawData={%s}", strings.Join(rawParts, ", "))) + parts = append(parts, matcher.stringer()) } fields := strings.Join(parts, ", ") @@ -61,84 +34,177 @@ func (f ResourceFilter) String() string { } func (f ResourceFilter) Matches(resource model.Resource) bool { - if f.AccountId != "" { - if resource.AccountId != f.AccountId { - return false - } - } - - if f.Region != "" { - if resource.Region != f.Region { - return false + for _, matcher := range f.matchers() { + if !matcher.present() { + continue } - } - if f.Type != "" { - if resource.Type != f.Type { + if !matcher.match(resource) { return false } } - // Treat empty slice different from nil - if f.Tags != nil { - // Treat empty slice as special "no tags" filter - if len(f.Tags) == 0 { - if len(resource.Tags) != 0 { - return false - } - } else { - tagMap := make(map[string]string) - for _, tag := range resource.Tags { - tagMap[tag.Key] = tag.Value - } - - for _, tag := range f.Tags { - val, has := tagMap[tag.Key] - if !has { - return false - } + return true +} - if tag.Value == "" { - continue - } +func (f ResourceFilter) Filter(in []model.Resource) []model.Resource { + out := make([]model.Resource, 0, len(in)) - if strings.TrimSpace(val) != strings.TrimSpace(tag.Value) { - return false - } - } + for _, resource := range in { + if f.Matches(resource) { + out = append(out, resource) } } - if len(f.RawData) > 0 { - var raw map[string]any - err := json.Unmarshal(resource.RawData, &raw) - if err != nil { - panic(fmt.Errorf("cannot pase model.Resource.RawData: %s", resource.Id)) + return out +} + +// PartialFilter returns a more detailed filtering of the resources, with filter field processed separately. +// This can aid in debugging to determine why a resource isn't matching a given filter. +func (f ResourceFilter) PartialFilter(in []model.Resource) map[string][]model.Resource { + output := make(map[string][]model.Resource) + + for _, matcher := range f.matchers() { + if !matcher.present() { + continue } - for key, val := range f.RawData { - rawVal, has := raw[key] - if !has { - return false + var resources []model.Resource + for _, res := range in { + if !matcher.match(res) { + continue } - if !reflect.DeepEqual(val, rawVal) { - return false - } + resources = append(resources, res) } + + output[matcher.name] = resources } - return true + return output } -func (f ResourceFilter) Filter(in []model.Resource) []model.Resource { - out := make([]model.Resource, 0, len(in)) +type resourceFilterMatcher struct { + name string + present func() bool + stringer func() string + match func(model.Resource) bool +} - for _, resource := range in { - if f.Matches(resource) { - out = append(out, resource) +func (f ResourceFilter) matchers() []resourceFilterMatcher { + p := func(present bool) func() bool { + return func() bool { + return present + } + } + s := func(format string, val any) func() string { + return func() string { + return fmt.Sprintf(format, val) } } + return []resourceFilterMatcher{ + { + name: "AccountId", + present: p(f.AccountId != ""), + stringer: s("AccountId=%s", f.AccountId), + match: func(r model.Resource) bool { return f.AccountId == r.AccountId }, + }, + { + name: "Type", + present: p(f.Type != ""), + stringer: s("Type=%s", f.Type), + match: func(r model.Resource) bool { return f.Type == r.Type }, + }, + { + name: "Region", + present: p(f.Region != ""), + stringer: s("Region=%s", f.Region), + match: func(r model.Resource) bool { return f.Region == r.Region }, + }, + { + name: "Tags", + present: p(f.Tags != nil), // Empty non-nil slice has special meaning + stringer: func() string { + if len(f.Tags) == 0 { + return "Tags=[]" + } - return out + var parts []string + for _, tag := range f.Tags { + if tag.Value == "" { + parts = append(parts, fmt.Sprintf("Tags[%s]", tag.Key)) + } else { + parts = append(parts, fmt.Sprintf("Tags[%s]=%s", tag.Key, tag.Value)) + } + } + + return strings.Join(parts, ", ") + }, + match: func(r model.Resource) bool { + // Treat empty slice as special "no tags" filter + if len(f.Tags) == 0 { + return len(r.Tags) == 0 + } + + tagMap := make(map[string]string) + for _, tag := range r.Tags { + tagMap[tag.Key] = tag.Value + } + + for _, tag := range f.Tags { + val, has := tagMap[tag.Key] + if !has { + return false + } + + if tag.Value == "" { + continue + } + + if strings.TrimSpace(val) != strings.TrimSpace(tag.Value) { + return false + } + } + + return true + }, + }, + { + name: "RawData", + present: p(len(f.RawData) > 0), + stringer: func() string { + rawParts := make([]string, 0, len(f.RawData)) + + for key, val := range f.RawData { + pair := fmt.Sprintf("%s=%v", key, val) + rawParts = append(rawParts, pair) + } + //sorting ensures consistent output for testing + sort.Strings(rawParts) + + data := strings.Join(rawParts, ", ") + return fmt.Sprintf("RawData={%s}", data) + }, + match: func(r model.Resource) bool { + var raw map[string]any + err := json.Unmarshal(r.RawData, &raw) + if err != nil { + panic(fmt.Errorf("cannot parse model.Resource.RawData: %s", r.Id)) + } + + for key, val := range f.RawData { + rawVal, has := raw[key] + if !has { + return false + } + + if !reflect.DeepEqual(val, rawVal) { + return false + } + } + + return true + }, + }, + } } diff --git a/pkg/testingutil/filter_test.go b/pkg/testingutil/filter_test.go index dde3c1c4..96468269 100644 --- a/pkg/testingutil/filter_test.go +++ b/pkg/testingutil/filter_test.go @@ -5,6 +5,7 @@ import ( "github.com/run-x/cloudgrep/pkg/model" "github.com/stretchr/testify/assert" + "golang.org/x/exp/maps" ) func TestResourceFilter_Matches(t *testing.T) { @@ -298,3 +299,94 @@ func TestResourceFilter_String(t *testing.T) { }) } } + +func TestResourceFilter_matchers_unqiue(t *testing.T) { + names := make(map[string]struct{}) + f := ResourceFilter{} + for _, matcher := range f.matchers() { + if _, has := names[matcher.name]; has { + t.Errorf("duplicate matcher name: %s", matcher.name) + } + + names[matcher.name] = struct{}{} + } +} + +func TestResourceFilter_PartialFilter(t *testing.T) { + tests := []struct { + name string + want map[string][]string + filter ResourceFilter + resources []model.Resource + }{ + { + name: "multiple", + want: map[string][]string{ + "AccountId": {"foo"}, + "Type": {"foo", "bar"}, + "Region": {"bar"}, + }, + filter: ResourceFilter{ + AccountId: "a", + Type: "b", + Region: "c", + }, + resources: []model.Resource{ + { + Id: "foo", + AccountId: "a", + Type: "b", + Region: "d", + }, + { + Id: "bar", + AccountId: "e", + Type: "b", + Region: "c", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual := test.filter.PartialFilter(test.resources) + actualKeys := maps.Keys(actual) + expectedKeys := maps.Keys(test.want) + + assert.ElementsMatch(t, expectedKeys, actualKeys) + + for name, expectedIds := range test.want { + actualResources, has := actual[name] + if !has { + // Already errored on above ElementsMatch + continue + } + + var actualIds []string + for _, resource := range actualResources { + actualIds = append(actualIds, resource.Id) + } + + assert.ElementsMatchf(t, expectedIds, actualIds, "expected ids on matcher %s to match", name) + } + }) + } +} + +func TestResourceFilter_Match_rawDataPanic(t *testing.T) { + f := ResourceFilter{ + RawData: map[string]any{ + "foo": "bar", + }, + } + + r := model.Resource{ + Id: "spam", + RawData: []byte("{"), + } + + assert.PanicsWithError(t, "cannot parse model.Resource.RawData: spam", func() { + f.Matches(r) + }) +} diff --git a/pkg/testingutil/model.go b/pkg/testingutil/model.go index bd7e5e54..45802739 100644 --- a/pkg/testingutil/model.go +++ b/pkg/testingutil/model.go @@ -1,17 +1,21 @@ package testingutil import ( + "fmt" + "strings" "testing" "github.com/run-x/cloudgrep/pkg/model" "github.com/stretchr/testify/assert" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" ) const TestTag = "test" // AssertResourceCount asserts that there is a specific number of given resources with the "test" tag. // If tagValue is not an empty string, it also filters on resources that have the "test" tag with that value. -func AssertResourceCount(t testing.TB, resources []model.Resource, tagValue string, count int) { +func AssertResourceCount(t TestingTB, resources []model.Resource, tagValue string, count int) { t.Helper() if tagValue == "" { resources = ResourceFilterTagKey(resources, TestTag) @@ -101,12 +105,28 @@ func AssertEqualsTags(t *testing.T, a, b model.Tags) { } } -func AssertResourceFilteredCount(t testing.TB, resources []model.Resource, count int, filter ResourceFilter) []model.Resource { +func AssertResourceFilteredCount(t TestingTB, resources []model.Resource, count int, filter ResourceFilter) []model.Resource { t.Helper() filtered := filter.Filter(resources) - assert.Lenf(t, filtered, count, "expected %d resource(s) with filter %s", count, filter) + success := assert.Lenf(t, filtered, count, "expected %d resource(s) with filter %s", count, filter) + if !success { + partialFiltered := filter.PartialFilter(resources) + + names := maps.Keys(partialFiltered) + slices.Sort(names) + + var matches []string + for _, name := range names { + resources := partialFiltered[name] + matches = append(matches, + fmt.Sprintf("%s=%d", name, len(resources)), + ) + } + + t.Errorf("filter %s partial matches: %s", filter, strings.Join(matches, ", ")) + } return filtered } diff --git a/pkg/testingutil/model_test.go b/pkg/testingutil/model_test.go index 77c40afc..3128b454 100644 --- a/pkg/testingutil/model_test.go +++ b/pkg/testingutil/model_test.go @@ -7,6 +7,7 @@ import ( "github.com/run-x/cloudgrep/pkg/model" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAssertResourceCount_keyOnly(t *testing.T) { @@ -114,3 +115,65 @@ func TestAssertEqualTag(t *testing.T) { } AssertEqualsTag(t, &t1, &t2) } + +func TestAssertResourceFilteredCount_success(t *testing.T) { + f := ResourceFilter{ + AccountId: "foo", + } + + in := []model.Resource{ + { + AccountId: "foo", + Id: "spam", + }, + { + AccountId: "foo", + Id: "ham", + }, + { + AccountId: "bar", + Id: "a", + }, + } + + tb := &FakeTB{} + + filtered := AssertResourceFilteredCount(tb, in, 2, f) + assert.False(t, tb.IsFail) + assert.ElementsMatch(t, in[0:2], filtered) + assert.Empty(t, tb.Logs) +} + +func TestAssertResourceFilteredCount_fail(t *testing.T) { + f := ResourceFilter{ + AccountId: "foo", + Region: "us", + } + + in := []model.Resource{ + { + AccountId: "foo", + Region: "us", + Id: "spam", + }, + { + AccountId: "foo", + Region: "eu", + Id: "ham", + }, + { + AccountId: "bar", + Region: "us", + Id: "a", + }, + } + + tb := &FakeTB{} + + filtered := AssertResourceFilteredCount(tb, in, 2, f) + assert.True(t, tb.IsFail) + assert.ElementsMatch(t, in[0:1], filtered) + require.Len(t, tb.Logs, 2) + assert.Contains(t, tb.Logs[0], "expected 2 resource(s) with filter ResourceFilter{AccountId=foo, Region=us}") + assert.Equal(t, "filter ResourceFilter{AccountId=foo, Region=us} partial matches: AccountId=2, Region=2", tb.Logs[1]) +} diff --git a/pkg/testingutil/type.go b/pkg/testingutil/type.go index 193b3bb5..1aa125e4 100644 --- a/pkg/testingutil/type.go +++ b/pkg/testingutil/type.go @@ -4,6 +4,10 @@ import ( "fmt" "path" "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TypeStr is a convenience function to get the fully qualified type identifier for a value @@ -11,3 +15,17 @@ func TypeStr(v any) string { t := reflect.TypeOf(v) return fmt.Sprintf("%v/%v", path.Dir(t.PkgPath()), t.String()) } + +// TestingTB is an interface wrapper around testing.TB reduced to what the funcs in testingutil need (to assist with tesing this package). +// Satisfies the assert.TestingT and require.TestingT interfaces +type TestingTB interface { + Errorf(string, ...any) + Fatalf(string, ...any) + Helper() + FailNow() +} + +var _ TestingTB = &testing.T{} +var _ TestingTB = &testing.B{} +var _ assert.TestingT = TestingTB(&testing.T{}) +var _ require.TestingT = TestingTB(&testing.T{}) diff --git a/pkg/testingutil/util_test.go b/pkg/testingutil/util_test.go index c2c432fa..115cbb43 100644 --- a/pkg/testingutil/util_test.go +++ b/pkg/testingutil/util_test.go @@ -6,31 +6,32 @@ import ( ) type FakeTB struct { - testing.TB IsHelper bool IsFail bool Logs []string } +var _ TestingTB = &FakeTB{} + func Fake(t testing.TB) *FakeTB { - return &FakeTB{ - TB: t, - } + return &FakeTB{} } func (t *FakeTB) Helper() { t.IsHelper = true } -func (t *FakeTB) Fatal(args ...any) { - panic(fmt.Sprint(args...)) -} - func (t *FakeTB) Fatalf(format string, args ...any) { - panic(fmt.Errorf(format, args...)) + t.Errorf(format, args...) + t.FailNow() } func (t *FakeTB) Errorf(format string, args ...any) { t.IsFail = true t.Logs = append(t.Logs, fmt.Sprintf(format, args...)) } + +func (t *FakeTB) FailNow() { + // We have no way of "aborting", since we don't want to actually fail the test. + t.IsFail = true +}