diff --git a/libcontainer/cgroups/utils_test.go b/libcontainer/cgroups/utils_test.go index 4a5d95eb691..1645ae7b0ec 100644 --- a/libcontainer/cgroups/utils_test.go +++ b/libcontainer/cgroups/utils_test.go @@ -181,8 +181,11 @@ const cgroup2Mountinfo = `18 64 0:18 / /sys rw,nosuid,nodev,noexec,relatime shar func TestGetCgroupMounts(t *testing.T) { type testData struct { - mountInfo string - root string + mountInfo string + root string + // all is the total number of records expected with all=true, + // or 0 for no extra records expected (most cases). + all int subsystems map[string]bool } testTable := []testData{ @@ -223,6 +226,7 @@ func TestGetCgroupMounts(t *testing.T) { { mountInfo: bedrockMountinfo, root: "/", + all: 50, subsystems: map[string]bool{ "name=systemd": false, "cpuset": false, @@ -274,6 +278,29 @@ func TestGetCgroupMounts(t *testing.T) { t.Fatalf("subsystem %s not found in Subsystems field %v", ss, m.Subsystems) } } + // Test the all=true case. + + // Reset the test input. + mi = bytes.NewBufferString(td.mountInfo) + for k := range td.subsystems { + td.subsystems[k] = false + } + cgMountsAll, err := getCgroupMountsHelper(td.subsystems, mi, true) + if err != nil { + t.Fatal(err) + } + if td.all == 0 { + // Results with and without "all" should be the same. + if len(cgMounts) != len(cgMountsAll) || !reflect.DeepEqual(cgMounts, cgMountsAll) { + t.Errorf("expected same results, got (all=false) %v, (all=true) %v", cgMounts, cgMountsAll) + } + } else { + // Make sure we got all records. + if len(cgMountsAll) != td.all { + t.Errorf("expected %d records, got %d (%+v)", td.all, len(cgMountsAll), cgMountsAll) + } + } + } } diff --git a/libcontainer/cgroups/v1_utils.go b/libcontainer/cgroups/v1_utils.go index 8b9275fb926..f610ed8c475 100644 --- a/libcontainer/cgroups/v1_utils.go +++ b/libcontainer/cgroups/v1_utils.go @@ -173,7 +173,7 @@ func getCgroupMountsHelper(ss map[string]bool, mi io.Reader, all bool) ([]Mount, res := make([]Mount, 0, len(ss)) scanner := bufio.NewScanner(mi) numFound := 0 - for scanner.Scan() && numFound < len(ss) { + for scanner.Scan() && (all || numFound < len(ss)) { txt := scanner.Text() sepIdx := strings.Index(txt, " - ") if sepIdx == -1 {