diff --git a/client/allocrunner/taskrunner/plugin_supervisor_hook.go b/client/allocrunner/taskrunner/plugin_supervisor_hook.go index 7d730ead7aa..b95daa85769 100644 --- a/client/allocrunner/taskrunner/plugin_supervisor_hook.go +++ b/client/allocrunner/taskrunner/plugin_supervisor_hook.go @@ -335,10 +335,10 @@ func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, socket } client, err := csi.NewClient(socketPath, h.logger.Named("csi_client").With("plugin.name", h.task.CSIPluginConfig.ID, "plugin.type", h.task.CSIPluginConfig.Type)) - defer client.Close() if err != nil { return false, fmt.Errorf("failed to create csi client: %v", err) } + defer client.Close() healthy, err := client.PluginProbe(ctx) if err != nil { diff --git a/plugins/csi/client.go b/plugins/csi/client.go index ddde111dff9..930dffe5bc2 100644 --- a/plugins/csi/client.go +++ b/plugins/csi/client.go @@ -114,8 +114,12 @@ func NewClient(addr string, logger hclog.Logger) (CSIPlugin, error) { } func newGrpcConn(addr string, logger hclog.Logger) (*grpc.ClientConn, error) { - conn, err := grpc.Dial( + ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) + defer cancel() + conn, err := grpc.DialContext( + ctx, addr, + grpc.WithBlock(), grpc.WithInsecure(), grpc.WithUnaryInterceptor(logging.UnaryClientInterceptor(logger)), grpc.WithStreamInterceptor(logging.StreamClientInterceptor(logger)),