diff --git a/README.md b/README.md index ede340699..e44fd7202 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,6 @@ the actual driver's name. args: - "--csi-address=/csi/csi.sock" - "--kubelet-registration-path=/var/lib/kubelet/plugins//csi.sock" - lifecycle: - preStop: - exec: - command: ["/bin/sh", "-c", "rm -rf /registration/ /registration/-reg.sock"] volumeMounts: - name: plugin-dir mountPath: /csi diff --git a/cmd/csi-node-driver-registrar/node_register.go b/cmd/csi-node-driver-registrar/node_register.go index 3fe2a395a..b9fa36015 100644 --- a/cmd/csi-node-driver-registrar/node_register.go +++ b/cmd/csi-node-driver-registrar/node_register.go @@ -20,7 +20,9 @@ import ( "fmt" "net" "os" + "os/signal" "runtime" + "syscall" "google.golang.org/grpc" @@ -36,7 +38,7 @@ func nodeRegister( // as gRPC server which replies to registration requests initiated by kubelet's // pluginswatcher infrastructure. Node labeling is done by kubelet's csi code. registrar := newRegistrationServer(csiDriverName, *kubeletRegistrationPath, supportedVersions) - socketPath := fmt.Sprintf("/registration/%s-reg.sock", csiDriverName) + socketPath := buildSocketPath(csiDriverName) if err := util.CleanupSocketFile(socketPath); err != nil { klog.Errorf("%+v", err) os.Exit(1) @@ -62,6 +64,7 @@ func nodeRegister( // Registers kubelet plugin watcher api. registerapi.RegisterRegistrationServer(grpcServer, registrar) + go removeRegSocket(csiDriverName) // Starts service if err := grpcServer.Serve(lis); err != nil { klog.Errorf("Registration Server stopped serving: %v", err) @@ -70,3 +73,20 @@ func nodeRegister( // If gRPC server is gracefully shutdown, exit os.Exit(0) } + +func buildSocketPath(csiDriverName string) string { + return fmt.Sprintf("/registration/%s-reg.sock", csiDriverName) +} + +func removeRegSocket(csiDriverName string) { + sigc := make(chan os.Signal, 1) + signal.Notify(sigc, syscall.SIGTERM) + <-sigc + socketPath := buildSocketPath(csiDriverName) + err := os.Remove(socketPath) + if err != nil && !os.IsNotExist(err) { + klog.Errorf("failed to remove socket: %s with error: %+v", socketPath, err) + os.Exit(1) + } + os.Exit(0) +}