Skip to content

Commit

Permalink
refactor: add contextual stdio for better testing
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmdm committed May 19, 2024
1 parent 9c6c178 commit 91c8391
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 6 deletions.
2 changes: 1 addition & 1 deletion cmd/yoke/cmd_blackbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,6 @@ func Blackbox(ctx context.Context, params BlackboxParams) error {
return err
}

_, err = fmt.Fprint(os.Stdout, text.DiffColorized(a, b, params.Context))
_, err = fmt.Fprint(internal.Stdout(ctx), text.DiffColorized(a, b, params.Context))
return err
}
2 changes: 1 addition & 1 deletion cmd/yoke/cmd_takeoff.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func TakeOff(ctx context.Context, params TakeoffParams) error {
return err
}

_, err = fmt.Fprint(os.Stdout, text.DiffColorized(a, b, params.Context))
_, err = fmt.Fprint(internal.Stdout(ctx), text.DiffColorized(a, b, params.Context))
return err
}

Expand Down
6 changes: 5 additions & 1 deletion cmd/yoke/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/clientcmd"

"github.com/davidmdm/yoke/internal"
"github.com/davidmdm/yoke/internal/home"
"github.com/davidmdm/yoke/internal/k8s"
)
Expand Down Expand Up @@ -229,7 +230,10 @@ func TestTurbulenceFix(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "corrupt", configmap.Data["key"])

require.NoError(t, Turbulence(context.Background(), TurbulenceParams{GlobalSettings: settings, Release: "foo", Fix: true}))
var stderr bytes.Buffer
ctx := internal.WithStderr(context.Background(), &stderr)
require.NoError(t, Turbulence(ctx, TurbulenceParams{GlobalSettings: settings, Release: "foo", Fix: true}))
require.Equal(t, "fixed drift for: default.core.v1.configmap.test\n", stderr.String())

configmap, err = client.CoreV1().ConfigMaps("default").Get(context.Background(), "test", metav1.GetOptions{})
require.NoError(t, err)
Expand Down
49 changes: 49 additions & 0 deletions internal/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package internal

import (
"context"
"io"
"os"
)

type (
stdoutKey struct{}
stderrKey struct{}
stdinKey struct{}
)

func WithStdout(ctx context.Context, w io.Writer) context.Context {
return context.WithValue(ctx, stdoutKey{}, w)
}

func Stdout(ctx context.Context) io.Writer {
w, ok := ctx.Value(stdoutKey{}).(io.Writer)
if !ok {
return os.Stdout
}
return w
}

func WithStderr(ctx context.Context, w io.Writer) context.Context {
return context.WithValue(ctx, stderrKey{}, w)
}

func Stderr(ctx context.Context) io.Writer {
w, ok := ctx.Value(stderrKey{}).(io.Writer)
if !ok {
return os.Stderr
}
return w
}

func WithStdin(ctx context.Context, r io.Reader) context.Context {
return context.WithValue(ctx, stdinKey{}, r)
}

func Stdin(ctx context.Context) io.Reader {
r, ok := ctx.Value(stdinKey{}).(io.Reader)
if !ok {
return os.Stdin
}
return r
}
5 changes: 2 additions & 3 deletions pkg/yoke/yoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package yoke
import (
"context"
"fmt"
"os"
"reflect"

"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
Expand Down Expand Up @@ -214,7 +213,7 @@ func (client Client) Turbulence(ctx context.Context, params TurbulenceParams) er
if err := client.k8s.ApplyResource(ctx, desired, forceConflicts); err != nil {
errs = append(errs, fmt.Errorf("%s: %w", name, err))
}
fmt.Fprintf(os.Stderr, "fixed drift for: %s\n", name)
fmt.Fprintf(internal.Stderr(ctx), "fixed drift for: %s\n", name)
}

return xerr.MultiErrOrderedFrom("failed to apply desired state to drift", errs...)
Expand Down Expand Up @@ -243,7 +242,7 @@ func (client Client) Turbulence(ctx context.Context, params TurbulenceParams) er
return internal.Warning("no turbulence detected")
}

_, err = fmt.Fprint(os.Stdout, diff)
_, err = fmt.Fprint(internal.Stdout(ctx), diff)
return err
}

Expand Down

0 comments on commit 91c8391

Please sign in to comment.