From 3c19ae61c7a2c78b5148e97b01d7a48f23b83d8f Mon Sep 17 00:00:00 2001 From: Nicolas De Loof Date: Thu, 7 Dec 2023 14:25:55 +0100 Subject: [PATCH] introduce general-purpose service dependency graph traversal functions Signed-off-by: Nicolas De Loof --- dotenv/godotenv_test.go | 2 +- graph/graph.go | 144 +++++++++++++ graph/graph_test.go | 453 +++++++++++++++++++++++++++++++++++++++ graph/traversal.go | 234 ++++++++++++++++++++ types/config_test.go | 4 +- types/project.go | 87 ++++---- types/project_test.go | 16 +- utils/collectionutils.go | 10 +- utils/set.go | 95 ++++++++ utils/set_test.go | 43 ++++ utils/stringutils.go | 10 - 11 files changed, 1038 insertions(+), 60 deletions(-) create mode 100644 graph/graph.go create mode 100644 graph/graph_test.go create mode 100644 graph/traversal.go create mode 100644 utils/set.go create mode 100644 utils/set_test.go diff --git a/dotenv/godotenv_test.go b/dotenv/godotenv_test.go index 1b60f7102..fa9118e4f 100644 --- a/dotenv/godotenv_test.go +++ b/dotenv/godotenv_test.go @@ -19,7 +19,7 @@ func parseAndCompare(t *testing.T, rawEnvLine string, expectedKey string, expect assert.NilError(t, err) actualValue, ok := env[expectedKey] if !ok { - t.Errorf("Key %q was not found in env: %v", expectedKey, env) + t.Errorf("key %q was not found in env: %v", expectedKey, env) } else if actualValue != expectedValue { t.Errorf("Expected '%v' to parse as '%v' => '%v', got '%v' => '%v' instead", rawEnvLine, expectedKey, expectedValue, expectedKey, actualValue) } diff --git a/graph/graph.go b/graph/graph.go new file mode 100644 index 000000000..37d7a9144 --- /dev/null +++ b/graph/graph.go @@ -0,0 +1,144 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package graph + +import ( + "fmt" + "strings" + + "github.com/compose-spec/compose-go/v2/types" + "github.com/compose-spec/compose-go/v2/utils" + "golang.org/x/exp/slices" +) + +// graph represents project as service dependencies +type graph struct { + vertices map[string]*vertex +} + +// vertex represents a service in the dependencies structure +type vertex struct { + key string + service *types.ServiceConfig + children map[string]*vertex + parents map[string]*vertex +} + +// newGraph creates a service graph from project +func newGraph(project *types.Project) (*graph, error) { + g := &graph{ + vertices: map[string]*vertex{}, + } + + for name, s := range project.Services { + g.addVertex(name, s) + } + + for name, s := range project.Services { + src := g.vertices[name] + for dep, condition := range s.DependsOn { + dest, ok := g.vertices[dep] + if !ok { + if condition.Required { + if ds, exists := project.DisabledServices[dep]; exists { + return nil, fmt.Errorf("service %q is required by %q but is disabled. Can be enabled by profiles %s", dep, name, ds.Profiles) + } + return nil, fmt.Errorf("service %q depends on unknown service %q", name, dep) + } + delete(s.DependsOn, name) + project.Services[name] = s + continue + } + src.children[dep] = dest + dest.parents[name] = src + } + } + + err := g.checkCycle() + return g, err +} + +func (g *graph) addVertex(name string, service types.ServiceConfig) { + g.vertices[name] = &vertex{ + key: name, + service: &service, + parents: map[string]*vertex{}, + children: map[string]*vertex{}, + } +} + +func (g *graph) addEdge(src, dest string) { + g.vertices[src].children[dest] = g.vertices[dest] + g.vertices[dest].parents[src] = g.vertices[src] +} + +func (g *graph) roots() []*vertex { + var res []*vertex + for _, v := range g.vertices { + if len(v.parents) == 0 { + res = append(res, v) + } + } + return res +} + +func (g *graph) leaves() []*vertex { + var res []*vertex + for _, v := range g.vertices { + if len(v.children) == 0 { + res = append(res, v) + } + } + + return res +} + +func (g *graph) checkCycle() error { + names := utils.MapKeys(g.vertices) + for _, name := range names { + err := searchCycle([]string{name}, g.vertices[name]) + if err != nil { + return err + } + } + return nil +} + +func searchCycle(path []string, v *vertex) error { + names := utils.MapKeys(v.children) + for _, name := range names { + if i := slices.Index(path, name); i > 0 { + return fmt.Errorf("dependency cycle detected: %s", strings.Join(path[i:], " -> ")) + } + ch := v.children[name] + err := searchCycle(append(path, name), ch) + if err != nil { + return err + } + } + return nil +} + +// descendents return all descendents for a vertex, might contain duplicates +func (v *vertex) descendents() []string { + var vx []string + for _, n := range v.children { + vx = append(vx, n.key) + vx = append(vx, n.descendents()...) + } + return vx +} diff --git a/graph/graph_test.go b/graph/graph_test.go new file mode 100644 index 000000000..d5f29843b --- /dev/null +++ b/graph/graph_test.go @@ -0,0 +1,453 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package graph + +import ( + "context" + "fmt" + "sort" + "sync" + "testing" + + "github.com/compose-spec/compose-go/v2/types" + "github.com/compose-spec/compose-go/v2/utils" + "github.com/stretchr/testify/require" + "gotest.tools/v3/assert" +) + +func TestTraversalWithMultipleParents(t *testing.T) { + dependent := types.ServiceConfig{ + Name: "dependent", + DependsOn: make(types.DependsOnConfig), + } + + project := types.Project{ + Services: types.Services{"dependent": dependent}, + } + + for i := 1; i <= 100; i++ { + name := fmt.Sprintf("svc_%d", i) + dependent.DependsOn[name] = types.ServiceDependency{} + + svc := types.ServiceConfig{Name: name} + project.Services[name] = svc + } + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + svc := make(chan string, 10) + seen := make(map[string]int) + done := make(chan struct{}) + go func() { + for service := range svc { + seen[service]++ + } + done <- struct{}{} + }() + + err := InDependencyOrder(ctx, &project, func(ctx context.Context, name string, _ types.ServiceConfig) error { + svc <- name + return nil + }) + require.NoError(t, err, "Error during iteration") + close(svc) + <-done + + assert.Equal(t, len(seen), 101) + for svc, count := range seen { + assert.Equal(t, 1, count, "service: %s", svc) + } +} + +func TestInDependencyUpCommandOrder(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + var order []string + result, err := CollectInDependencyOrder(ctx, exampleProject(), + func(ctx context.Context, name string, _ types.ServiceConfig) (string, error) { + order = append(order, name) + return name, nil + }, WithMaxConcurrency(10)) + require.NoError(t, err, "Error during iteration") + require.Equal(t, []string{"test3", "test2", "test1"}, order) + assert.DeepEqual(t, result, map[string]string{ + "test1": "test1", + "test2": "test2", + "test3": "test3", + }) +} + +func TestInDependencyReverseDownCommandOrder(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + var order []string + fn := func(ctx context.Context, name string, _ types.ServiceConfig) error { + order = append(order, name) + return nil + } + err := InDependencyOrder(ctx, exampleProject(), fn, InReverseOrder) + require.NoError(t, err, "Error during iteration") + require.Equal(t, []string{"test1", "test2", "test3"}, order) +} + +func TestBuildGraph(t *testing.T) { + testCases := []struct { + desc string + services types.Services + disabled types.Services + expectedVertices map[string]*vertex + expectedError string + }{ + { + desc: "builds graph with single service", + services: types.Services{ + "test": { + Name: "test", + DependsOn: types.DependsOnConfig{}, + }, + }, + expectedVertices: map[string]*vertex{ + "test": { + key: "test", + service: &types.ServiceConfig{Name: "test"}, + children: map[string]*vertex{}, + parents: map[string]*vertex{}, + }, + }, + }, + { + desc: "builds graph with two separate services", + services: types.Services{ + "test": { + Name: "test", + DependsOn: types.DependsOnConfig{}, + }, + "another": { + Name: "another", + DependsOn: types.DependsOnConfig{}, + }, + }, + expectedVertices: map[string]*vertex{ + "test": { + key: "test", + service: &types.ServiceConfig{Name: "test"}, + children: map[string]*vertex{}, + parents: map[string]*vertex{}, + }, + "another": { + key: "another", + service: &types.ServiceConfig{Name: "another"}, + children: map[string]*vertex{}, + parents: map[string]*vertex{}, + }, + }, + }, + { + desc: "builds graph with a service and a dependency", + services: types.Services{ + "test": { + Name: "test", + DependsOn: types.DependsOnConfig{ + "another": types.ServiceDependency{}, + }, + }, + "another": { + Name: "another", + DependsOn: types.DependsOnConfig{}, + }, + }, + expectedVertices: map[string]*vertex{ + "test": { + key: "test", + service: &types.ServiceConfig{Name: "test"}, + children: map[string]*vertex{ + "another": {}, + }, + parents: map[string]*vertex{}, + }, + "another": { + key: "another", + service: &types.ServiceConfig{Name: "another"}, + children: map[string]*vertex{}, + parents: map[string]*vertex{ + "test": {}, + }, + }, + }, + }, + { + desc: "builds graph with a service and optional (missing) dependency", + services: types.Services{ + "test": { + Name: "test", + DependsOn: types.DependsOnConfig{ + "another": types.ServiceDependency{ + Required: false, + }, + }, + }, + }, + expectedVertices: map[string]*vertex{ + "test": { + key: "test", + service: &types.ServiceConfig{Name: "test"}, + children: map[string]*vertex{}, + parents: map[string]*vertex{}, + }, + }, + }, + { + desc: "builds graph with a service and required (missing) dependency", + services: types.Services{ + "test": { + Name: "test", + DependsOn: types.DependsOnConfig{ + "another": types.ServiceDependency{ + Required: true, + }, + }, + }, + }, + expectedError: `service "test" depends on unknown service "another"`, + }, + { + desc: "builds graph with a service and disabled dependency", + services: types.Services{ + "test": { + Name: "test", + DependsOn: types.DependsOnConfig{ + "another": types.ServiceDependency{ + Required: true, + }, + }, + }, + }, + disabled: types.Services{ + "another": { + Name: "another", + Profiles: []string{"test"}, + DependsOn: types.DependsOnConfig{}, + }, + }, + expectedError: `service "another" is required by "test" but is disabled. Can be enabled by profiles [test]`, + }, + { + desc: "builds graph with multiple dependency levels", + services: types.Services{ + "test": { + Name: "test", + DependsOn: types.DependsOnConfig{ + "another": types.ServiceDependency{}, + }, + }, + "another": { + Name: "another", + DependsOn: types.DependsOnConfig{ + "another_dep": types.ServiceDependency{}, + }, + }, + "another_dep": { + Name: "another_dep", + DependsOn: types.DependsOnConfig{}, + }, + }, + expectedVertices: map[string]*vertex{ + "test": { + key: "test", + service: &types.ServiceConfig{Name: "test"}, + children: map[string]*vertex{ + "another": {}, + }, + parents: map[string]*vertex{}, + }, + "another": { + key: "another", + service: &types.ServiceConfig{Name: "another"}, + children: map[string]*vertex{ + "another_dep": {}, + }, + parents: map[string]*vertex{ + "test": {}, + }, + }, + "another_dep": { + key: "another_dep", + service: &types.ServiceConfig{Name: "another_dep"}, + children: map[string]*vertex{}, + parents: map[string]*vertex{ + "another": {}, + }, + }, + }, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + project := types.Project{ + Services: tC.services, + DisabledServices: tC.disabled, + } + + graph, err := newGraph(&project) + if tC.expectedError != "" { + assert.Error(t, err, tC.expectedError) + return + } + + assert.NilError(t, err, fmt.Sprintf("failed to build graph for: %s", tC.desc)) + for k, vertex := range graph.vertices { + expected, ok := tC.expectedVertices[k] + assert.Equal(t, true, ok) + assertVertexEqual(t, *expected, *vertex) + } + }) + } +} + +func Test_detectCycle(t *testing.T) { + graph := exampleGraph() + graph.addEdge("B", "D") + err := graph.checkCycle() + assert.Error(t, err, "dependency cycle detected: D -> C -> B") +} + +func TestWith_RootNodesAndUp(t *testing.T) { + graph := exampleGraph() + + tests := []struct { + name string + nodes []string + want []string + }{ + { + name: "whole graph", + nodes: []string{"A", "B"}, + want: []string{"A", "B", "C", "D", "E", "F", "G"}, + }, + { + name: "only leaves", + nodes: []string{"F", "G"}, + want: []string{"F", "G"}, + }, + { + name: "simple dependent", + nodes: []string{"D"}, + want: []string{"D", "F"}, + }, + { + name: "diamond dependents", + nodes: []string{"B"}, + want: []string{"B", "C", "D", "E", "F"}, + }, + { + name: "partial graph", + nodes: []string{"A"}, + want: []string{"A", "C", "D", "F", "G"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mx := sync.Mutex{} + expected := utils.Set[string]{} + expected.AddAll("C", "G", "D", "F") + var visited []string + + gt := newTraversal(func(ctx context.Context, name string, service types.ServiceConfig) (any, error) { + mx.Lock() + defer mx.Unlock() + visited = append(visited, name) + return nil, nil + }) + WithRootNodesAndDown(tt.nodes)(gt.Options) + err := walk(context.TODO(), graph, gt) + assert.NilError(t, err) + sort.Strings(visited) + assert.DeepEqual(t, tt.want, visited) + }) + } +} + +func assertVertexEqual(t *testing.T, a, b vertex) { + assert.Equal(t, a.key, b.key) + assert.Equal(t, a.service.Name, b.service.Name) + for c := range a.children { + _, ok := b.children[c] + assert.Check(t, ok, "expected children missing %s", c) + } + for p := range a.parents { + _, ok := b.parents[p] + assert.Check(t, ok, "expected parent missing %s", p) + } +} + +func exampleGraph() *graph { + graph := &graph{ + vertices: map[string]*vertex{}, + } + + /** graph topology: + A B + / \ / \ + G C E + \ / + D + | + F + */ + + graph.addVertex("A", types.ServiceConfig{Name: "A"}) + graph.addVertex("B", types.ServiceConfig{Name: "B"}) + graph.addVertex("C", types.ServiceConfig{Name: "C"}) + graph.addVertex("D", types.ServiceConfig{Name: "D"}) + graph.addVertex("E", types.ServiceConfig{Name: "E"}) + graph.addVertex("F", types.ServiceConfig{Name: "F"}) + graph.addVertex("G", types.ServiceConfig{Name: "G"}) + + graph.addEdge("C", "A") + graph.addEdge("C", "B") + graph.addEdge("E", "B") + graph.addEdge("D", "C") + graph.addEdge("D", "E") + graph.addEdge("F", "D") + graph.addEdge("G", "A") + return graph +} + +func exampleProject() *types.Project { + return &types.Project{ + Services: types.Services{ + "test1": { + Name: "test1", + DependsOn: map[string]types.ServiceDependency{ + "test2": {}, + }, + }, + "test2": { + Name: "test2", + DependsOn: map[string]types.ServiceDependency{ + "test3": {}, + }, + }, + "test3": { + Name: "test3", + }, + }, + } +} diff --git a/graph/traversal.go b/graph/traversal.go new file mode 100644 index 000000000..a10f224cf --- /dev/null +++ b/graph/traversal.go @@ -0,0 +1,234 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package graph + +import ( + "context" + "sync" + + "github.com/compose-spec/compose-go/v2/types" + "golang.org/x/exp/slices" + "golang.org/x/sync/errgroup" +) + +// CollectorFn executes on each graph vertex based on visit order and return associated value +type CollectorFn[T any] func(context.Context, string, types.ServiceConfig) (T, error) + +// VisitorFn executes on each graph nodes based on visit order +type VisitorFn func(context.Context, string, types.ServiceConfig) error + +// InDependencyOrder walk the service graph an invoke VisitorFn in respect to dependency order +func InDependencyOrder(ctx context.Context, project *types.Project, fn VisitorFn, options ...func(*Options)) error { + _, err := CollectInDependencyOrder[any](ctx, project, func(ctx context.Context, s string, config types.ServiceConfig) (any, error) { + return nil, fn(ctx, s, config) + }, options...) + return err +} + +// CollectInDependencyOrder walk the service graph an invoke CollectorFn in respect to dependency order, then return result for each call +func CollectInDependencyOrder[T any](ctx context.Context, project *types.Project, fn CollectorFn[T], options ...func(*Options)) (map[string]T, error) { + graph, err := newGraph(project) + if err != nil { + return nil, err + } + t := newTraversal(fn) + for _, option := range options { + option(t.Options) + } + err = walk(ctx, graph, t) + return t.results, err +} + +type traversal[T any] struct { + *Options + visitor CollectorFn[T] + + mu sync.Mutex + status map[string]int + results map[string]T +} + +type Options struct { + // inverse reverse the traversal direction + inverse bool + // maxConcurrency limit the concurrent execution of visitorFn while walking the graph + maxConcurrency int + // after marks a set of node as starting points walking the graph + after []string +} + +const ( + vertexEntered = iota + vertexVisited +) + +func newTraversal[T any](fn CollectorFn[T]) *traversal[T] { + return &traversal[T]{ + Options: &Options{}, + status: map[string]int{}, + results: map[string]T{}, + visitor: fn, + } +} + +// WithMaxConcurrency configure traversal to limit concurrency walking graph nodes +func WithMaxConcurrency(max int) func(*Options) { + return func(o *Options) { + o.maxConcurrency = max + } +} + +// InReverseOrder configure traversal to walk the graph in reverse dependency order +func InReverseOrder(o *Options) { + o.inverse = true +} + +// WithRootNodesAndDown creates a graphTraversal to start from selected nodes +func WithRootNodesAndDown(nodes []string) func(*Options) { + return func(o *Options) { + o.after = nodes + } +} + +func walk[T any](ctx context.Context, g *graph, t *traversal[T]) error { + expect := len(g.vertices) + if expect == 0 { + return nil + } + // nodeCh need to allow n=expect writers while reader goroutine could have returned after ctx.Done + nodeCh := make(chan *vertex, expect) + defer close(nodeCh) + + eg, ctx := errgroup.WithContext(ctx) + if t.maxConcurrency > 0 { + eg.SetLimit(t.maxConcurrency + 1) + } + + eg.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case node := <-nodeCh: + expect-- + if expect == 0 { + return nil + } + + for _, adj := range t.adjacentNodes(node) { + t.visit(ctx, eg, adj, nodeCh) + } + } + } + }) + + // select nodes to start walking the graph based on traversal.direction + for _, node := range t.extremityNodes(g) { + t.visit(ctx, eg, node, nodeCh) + } + + return eg.Wait() +} + +func (t *traversal[T]) visit(ctx context.Context, eg *errgroup.Group, node *vertex, nodeCh chan *vertex) { + if !t.ready(node) { + // don't visit this service yet as dependencies haven't been visited + return + } + if !t.enter(node) { + // another worker already acquired this node + return + } + eg.Go(func() error { + var ( + err error + result T + ) + if !t.skip(node) { + result, err = t.visitor(ctx, node.key, *node.service) + } + t.done(node, result) + nodeCh <- node + return err + }) +} + +func (t *traversal[T]) extremityNodes(g *graph) []*vertex { + if t.inverse { + return g.roots() + } + return g.leaves() +} + +func (t *traversal[T]) adjacentNodes(v *vertex) map[string]*vertex { + if t.inverse { + return v.children + } + return v.parents +} + +func (t *traversal[T]) ready(v *vertex) bool { + t.mu.Lock() + defer t.mu.Unlock() + + depends := v.children + if t.inverse { + depends = v.parents + } + for name := range depends { + if t.status[name] != vertexVisited { + return false + } + } + return true +} + +func (t *traversal[T]) enter(v *vertex) bool { + t.mu.Lock() + defer t.mu.Unlock() + + if _, ok := t.status[v.key]; ok { + return false + } + t.status[v.key] = vertexEntered + return true +} + +func (t *traversal[T]) done(v *vertex, result T) { + t.mu.Lock() + defer t.mu.Unlock() + t.status[v.key] = vertexVisited + t.results[v.key] = result +} + +func (t *traversal[T]) skip(node *vertex) bool { + if len(t.after) == 0 { + return false + } + if slices.Contains(t.after, node.key) { + return false + } + + // is none of our starting node is a descendent, skip visit + ancestors := node.descendents() + for _, name := range t.after { + if slices.Contains(ancestors, name) { + return false + } + } + return true +} diff --git a/types/config_test.go b/types/config_test.go index ef574c013..e0730b5f8 100644 --- a/types/config_test.go +++ b/types/config_test.go @@ -49,8 +49,8 @@ func Test_WithServices(t *testing.T) { }, } order := []string{} - fn := func(service ServiceConfig) error { - order = append(order, service.Name) + fn := func(name string, _ ServiceConfig) error { + order = append(order, name) return nil } diff --git a/types/project.go b/types/project.go index a3c4e7f7a..e2c161090 100644 --- a/types/project.go +++ b/types/project.go @@ -183,7 +183,7 @@ func (p *Project) AllServices() Services { return all } -type ServiceFunc func(service ServiceConfig) error +type ServiceFunc func(name string, service ServiceConfig) error // WithServices run ServiceFunc on each service and dependencies according to DependencyPolicy func (p *Project) WithServices(names []string, fn ServiceFunc, options ...DependencyOption) error { @@ -194,6 +194,16 @@ func (p *Project) WithServices(names []string, fn ServiceFunc, options ...Depend return p.withServices(names, fn, map[string]bool{}, options, map[string]ServiceDependency{}) } +type withServicesOptions struct { + dependencyPolicy int +} + +const ( + includeDependencies = iota + includeDependents + ignoreDependencies +) + func (p *Project) withServices(names []string, fn ServiceFunc, seen map[string]bool, options []DependencyOption, dependencies map[string]ServiceDependency) error { services, servicesNotFound := p.getServicesByNames(names...) if len(servicesNotFound) > 0 { @@ -203,23 +213,26 @@ func (p *Project) withServices(names []string, fn ServiceFunc, seen map[string]b } } } - for _, service := range services { - if seen[service.Name] { + opts := withServicesOptions{ + dependencyPolicy: includeDependencies, + } + for _, option := range options { + option(&opts) + } + + for name, service := range services { + if seen[name] { continue } - seen[service.Name] = true + seen[name] = true var dependencies map[string]ServiceDependency - for _, policy := range options { - switch policy { - case IncludeDependents: - dependencies = utils.MapsAppend(dependencies, p.dependentsForService(service)) - case IncludeDependencies: - dependencies = utils.MapsAppend(dependencies, service.DependsOn) - case IgnoreDependencies: - // Noop - default: - return fmt.Errorf("unsupported dependency policy %d", policy) - } + switch opts.dependencyPolicy { + case includeDependents: + dependencies = utils.MapsAppend(dependencies, p.dependentsForService(service)) + case includeDependencies: + dependencies = utils.MapsAppend(dependencies, service.DependsOn) + case ignoreDependencies: + // Noop } if len(dependencies) > 0 { err := p.withServices(utils.MapKeys(dependencies), fn, seen, options, dependencies) @@ -227,7 +240,7 @@ func (p *Project) withServices(names []string, fn ServiceFunc, seen map[string]b return err } } - if err := fn(service); err != nil { + if err := fn(name, service); err != nil { return err } } @@ -380,13 +393,19 @@ func (p *Project) WithoutUnnecessaryResources() { p.Configs = configs } -type DependencyOption int +type DependencyOption func(options *withServicesOptions) -const ( - IncludeDependencies = iota - IncludeDependents - IgnoreDependencies -) +func IncludeDependencies(options *withServicesOptions) { + options.dependencyPolicy = includeDependencies +} + +func IncludeDependents(options *withServicesOptions) { + options.dependencyPolicy = includeDependents +} + +func IgnoreDependencies(options *withServicesOptions) { + options.dependencyPolicy = ignoreDependencies +} // ForServices restrict the project model to selected services and dependencies func (p *Project) ForServices(names []string, options ...DependencyOption) error { @@ -395,9 +414,9 @@ func (p *Project) ForServices(names []string, options ...DependencyOption) error return nil } - set := map[string]struct{}{} - err := p.WithServices(names, func(service ServiceConfig) error { - set[service.Name] = struct{}{} + set := utils.NewSet[string]() + err := p.WithServices(names, func(name string, service ServiceConfig) error { + set.Add(name) return nil }, options...) if err != nil { @@ -407,19 +426,15 @@ func (p *Project) ForServices(names []string, options ...DependencyOption) error // Disable all services which are not explicit target or dependencies enabled := Services{} for name, s := range p.Services { - if _, ok := set[s.Name]; ok { - for _, option := range options { - if option == IgnoreDependencies { - // remove all dependencies but those implied by explicitly selected services - dependencies := s.DependsOn - for d := range dependencies { - if _, ok := set[d]; !ok { - delete(dependencies, d) - } - } - s.DependsOn = dependencies + if _, ok := set[name]; ok { + // remove all dependencies but those implied by explicitly selected services + dependencies := s.DependsOn + for d := range dependencies { + if _, ok := set[d]; !ok { + delete(dependencies, d) } } + s.DependsOn = dependencies enabled[name] = s } else { p.DisableService(s) diff --git a/types/project_test.go b/types/project_test.go index 2856b8079..5aed8da39 100644 --- a/types/project_test.go +++ b/types/project_test.go @@ -206,16 +206,16 @@ func Test_ResolveImages(t *testing.T) { func TestWithServices(t *testing.T) { p := makeProject() var seen []string - err := p.WithServices([]string{"service_3"}, func(service ServiceConfig) error { - seen = append(seen, service.Name) + err := p.WithServices([]string{"service_3"}, func(name string, _ ServiceConfig) error { + seen = append(seen, name) return nil }, IncludeDependencies) assert.NilError(t, err) assert.DeepEqual(t, seen, []string{"service_1", "service_2", "service_3"}) seen = []string{} - err = p.WithServices([]string{"service_1"}, func(service ServiceConfig) error { - seen = append(seen, service.Name) + err = p.WithServices([]string{"service_1"}, func(name string, _ ServiceConfig) error { + seen = append(seen, name) return nil }, IncludeDependents) assert.NilError(t, err) @@ -223,16 +223,16 @@ func TestWithServices(t *testing.T) { assert.Check(t, utils.ArrayContains(seen, []string{"service_3", "service_4", "service_2", "service_1"})) seen = []string{} - err = p.WithServices([]string{"service_1"}, func(service ServiceConfig) error { - seen = append(seen, service.Name) + err = p.WithServices([]string{"service_1"}, func(name string, _ ServiceConfig) error { + seen = append(seen, name) return nil }, IgnoreDependencies) assert.NilError(t, err) assert.DeepEqual(t, seen, []string{"service_1"}) seen = []string{} - err = p.WithServices([]string{"service_4"}, func(service ServiceConfig) error { - seen = append(seen, service.Name) + err = p.WithServices([]string{"service_4"}, func(name string, _ ServiceConfig) error { + seen = append(seen, name) return nil }, IncludeDependencies) assert.NilError(t, err) diff --git a/utils/collectionutils.go b/utils/collectionutils.go index 343692250..bd44c5844 100644 --- a/utils/collectionutils.go +++ b/utils/collectionutils.go @@ -16,13 +16,17 @@ package utils -import "golang.org/x/exp/slices" +import ( + "golang.org/x/exp/constraints" + "golang.org/x/exp/slices" +) -func MapKeys[T comparable, U any](theMap map[T]U) []T { - var result []T +func MapKeys[T constraints.Ordered, U any](theMap map[T]U) []T { + result := make([]T, 0, len(theMap)) for key := range theMap { result = append(result, key) } + slices.Sort(result) return result } diff --git a/utils/set.go b/utils/set.go new file mode 100644 index 000000000..bbbeaa966 --- /dev/null +++ b/utils/set.go @@ -0,0 +1,95 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package utils + +type Set[T comparable] map[T]struct{} + +func NewSet[T comparable](v ...T) Set[T] { + if len(v) == 0 { + return make(Set[T]) + } + + out := make(Set[T], len(v)) + for i := range v { + out.Add(v[i]) + } + return out +} + +func (s Set[T]) Has(v T) bool { + _, ok := s[v] + return ok +} + +func (s Set[T]) Add(v T) { + s[v] = struct{}{} +} + +func (s Set[T]) AddAll(v ...T) { + for _, e := range v { + s[e] = struct{}{} + } +} + +func (s Set[T]) Remove(v T) bool { + _, ok := s[v] + if ok { + delete(s, v) + } + return ok +} + +func (s Set[T]) Clear() { + for v := range s { + delete(s, v) + } +} + +func (s Set[T]) Elements() []T { + elements := make([]T, 0, len(s)) + for v := range s { + elements = append(elements, v) + } + return elements +} + +func (s Set[T]) RemoveAll(elements ...T) { + for _, e := range elements { + s.Remove(e) + } +} + +func (s Set[T]) Diff(other Set[T]) Set[T] { + out := make(Set[T]) + for k := range s { + if _, ok := other[k]; !ok { + out[k] = struct{}{} + } + } + return out +} + +func (s Set[T]) Union(other Set[T]) Set[T] { + out := make(Set[T]) + for k := range s { + out[k] = struct{}{} + } + for k := range other { + out[k] = struct{}{} + } + return out +} diff --git a/utils/set_test.go b/utils/set_test.go new file mode 100644 index 000000000..279ede917 --- /dev/null +++ b/utils/set_test.go @@ -0,0 +1,43 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package utils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSet_Has(t *testing.T) { + x := NewSet[string]("value") + require.True(t, x.Has("value")) + require.False(t, x.Has("VALUE")) +} + +func TestSet_Diff(t *testing.T) { + a := NewSet[int](1, 2) + b := NewSet[int](2, 3) + require.ElementsMatch(t, []int{1}, a.Diff(b).Elements()) + require.ElementsMatch(t, []int{3}, b.Diff(a).Elements()) +} + +func TestSet_Union(t *testing.T) { + a := NewSet[int](1, 2) + b := NewSet[int](2, 3) + require.ElementsMatch(t, []int{1, 2, 3}, a.Union(b).Elements()) + require.ElementsMatch(t, []int{1, 2, 3}, b.Union(a).Elements()) +} diff --git a/utils/stringutils.go b/utils/stringutils.go index 182ddf830..dfabf6c97 100644 --- a/utils/stringutils.go +++ b/utils/stringutils.go @@ -22,16 +22,6 @@ import ( "strings" ) -// StringContains check if an array contains a specific value -func StringContains(array []string, needle string) bool { - for _, val := range array { - if val == needle { - return true - } - } - return false -} - // StringToBool converts a string to a boolean ignoring errors func StringToBool(s string) bool { b, _ := strconv.ParseBool(strings.ToLower(strings.TrimSpace(s)))