diff --git a/client/client.go b/client/client.go index deac2507a996..dac1ec4ccd28 100644 --- a/client/client.go +++ b/client/client.go @@ -11,7 +11,6 @@ import ( contentapi "github.com/containerd/containerd/api/services/content/v1" "github.com/containerd/containerd/defaults" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" controlapi "github.com/moby/buildkit/api/services/control" "github.com/moby/buildkit/client/connhelper" "github.com/moby/buildkit/session" @@ -54,6 +53,7 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error var tracerProvider trace.TracerProvider var tracerDelegate TracerDelegate var sessionDialer func(context.Context, string, map[string][]string) (net.Conn, error) + var customDialOptions []grpc.DialOption for _, o := range opts { if _, ok := o.(*withFailFast); ok { @@ -82,6 +82,9 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error if sd, ok := o.(*withSessionDialer); ok { sessionDialer = sd.dialer } + if opt, ok := o.(grpc.DialOption); ok { + customDialOptions = append(customDialOptions, opt) + } } if !customTracer { @@ -131,17 +134,9 @@ func New(ctx context.Context, address string, opts ...ClientOpt) (*Client, error unary = append(unary, grpcerrors.UnaryClientInterceptor) stream = append(stream, grpcerrors.StreamClientInterceptor) - if len(unary) == 1 { - gopts = append(gopts, grpc.WithUnaryInterceptor(unary[0])) - } else if len(unary) > 1 { - gopts = append(gopts, grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(unary...))) - } - - if len(stream) == 1 { - gopts = append(gopts, grpc.WithStreamInterceptor(stream[0])) - } else if len(stream) > 1 { - gopts = append(gopts, grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(stream...))) - } + gopts = append(gopts, grpc.WithChainUnaryInterceptor(unary...)) + gopts = append(gopts, grpc.WithChainStreamInterceptor(stream...)) + gopts = append(gopts, customDialOptions...) conn, err := grpc.DialContext(ctx, address, gopts...) if err != nil { diff --git a/client/client_test.go b/client/client_test.go index 4dc968d94d37..6a9637f36eb9 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -65,6 +65,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh/agent" "golang.org/x/sync/errgroup" + "google.golang.org/grpc" ) func init() { @@ -195,6 +196,7 @@ func TestIntegration(t *testing.T) { testMountStubsTimestamp, testSourcePolicy, testLLBMountPerformance, + testClientCustomGRPCOpts, ) } @@ -9020,3 +9022,22 @@ func testLLBMountPerformance(t *testing.T, sb integration.Sandbox) { _, err = c.Solve(timeoutCtx, def, SolveOpt{}, nil) require.NoError(t, err) } + +func testClientCustomGRPCOpts(t *testing.T, sb integration.Sandbox) { + var interceptedMethods []string + intercept := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + interceptedMethods = append(interceptedMethods, method) + return invoker(ctx, method, req, reply, cc, opts...) + } + c, err := New(sb.Context(), sb.Address(), grpc.WithChainUnaryInterceptor(intercept)) + require.NoError(t, err) + defer c.Close() + + st := llb.Image("busybox:latest") + def, err := st.Marshal(sb.Context()) + require.NoError(t, err) + _, err = c.Solve(sb.Context(), def, SolveOpt{}, nil) + require.NoError(t, err) + + require.Contains(t, interceptedMethods, "/moby.buildkit.v1.Control/Solve") +}