From 33e7d161393fa360f50eb4a987dec05f51f430df Mon Sep 17 00:00:00 2001
From: Tim Gross
Date: Fri, 7 Aug 2020 15:37:27 -0400
Subject: [PATCH] CSI: fix missing ACL tokens for leader-driven RPCs (#8607)

The volumewatcher and GC job in the leader can't make CSI RPCs when ACLs
are enabled without the leader ACL token being passed through.
---
 nomad/core_sched.go                         | 14 ++++++++++----
 nomad/server.go                             |  2 +-
 nomad/volumewatcher/volume_watcher.go       | 14 +++++++++++---
 nomad/volumewatcher/volumes_watcher.go      | 14 +++++++++-----
 nomad/volumewatcher/volumes_watcher_test.go |  8 ++++----
 5 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/nomad/core_sched.go b/nomad/core_sched.go
index 5e47b9184cd..85659118ae1 100644
--- a/nomad/core_sched.go
+++ b/nomad/core_sched.go
@@ -724,9 +724,12 @@ func (c *CoreScheduler) csiVolumeClaimGC(eval *structs.Evaluation) error {
 		req := &structs.CSIVolumeClaimRequest{
 			VolumeID: volID,
 			Claim:    structs.CSIVolumeClaimRelease,
+			WriteRequest: structs.WriteRequest{
+				Namespace: ns,
+				Region:    c.srv.Region(),
+				AuthToken: eval.LeaderACL,
+			},
 		}
-		req.Namespace = ns
-		req.Region = c.srv.config.Region
 		err := c.srv.RPC("CSIVolume.Claim", req, &structs.CSIVolumeClaimResponse{})
 		return err
 	}
@@ -850,8 +853,11 @@ func (c *CoreScheduler) csiPluginGC(eval *structs.Evaluation) error {
 			continue
 		}
 
-		req := &structs.CSIPluginDeleteRequest{ID: plugin.ID}
-		req.Region = c.srv.Region()
+		req := &structs.CSIPluginDeleteRequest{ID: plugin.ID,
+			QueryOptions: structs.QueryOptions{
+				Region:    c.srv.Region(),
+				AuthToken: eval.LeaderACL,
+			}}
 		err := c.srv.RPC("CSIPlugin.Delete", req, &structs.CSIPluginDeleteResponse{})
 		if err != nil {
 			if err.Error() == "plugin in use" {
diff --git a/nomad/server.go b/nomad/server.go
index cf441f044e6..396e5a47c57 100644
--- a/nomad/server.go
+++ b/nomad/server.go
@@ -1019,7 +1019,7 @@ func (s *Server) setupDeploymentWatcher() error {
 // setupVolumeWatcher creates a volume watcher that sends CSI RPCs
 func (s *Server) setupVolumeWatcher() error {
 	s.volumeWatcher = volumewatcher.NewVolumesWatcher(
-		s.logger, s.staticEndpoints.CSIVolume)
+		s.logger, s.staticEndpoints.CSIVolume, s.getLeaderAcl())
 	return nil
 }
 
diff --git a/nomad/volumewatcher/volume_watcher.go b/nomad/volumewatcher/volume_watcher.go
index ca413a9ff6f..9137d721e2b 100644
--- a/nomad/volumewatcher/volume_watcher.go
+++ b/nomad/volumewatcher/volume_watcher.go
@@ -23,6 +23,9 @@ type volumeWatcher struct {
 	// server interface for CSI client RPCs
 	rpc CSIVolumeRPC
 
+	// the ACL needed to send RPCs
+	leaderAcl string
+
 	logger      log.Logger
 	shutdownCtx context.Context // parent context
 	ctx         context.Context // own context
@@ -44,6 +47,7 @@ func newVolumeWatcher(parent *Watcher, vol *structs.CSIVolume) *volumeWatcher {
 		v:           vol,
 		state:       parent.state,
 		rpc:         parent.rpc,
+		leaderAcl:   parent.leaderAcl,
 		logger:      parent.logger.With("volume_id", vol.ID, "namespace", vol.Namespace),
 		shutdownCtx: parent.ctx,
 	}
@@ -228,9 +232,13 @@ func (vw *volumeWatcher) collectPastClaims(vol *structs.CSIVolume) *structs.CSIV
 
 func (vw *volumeWatcher) unpublish(vol *structs.CSIVolume, claim *structs.CSIVolumeClaim) error {
 	req := &structs.CSIVolumeUnpublishRequest{
-		VolumeID:     vol.ID,
-		Claim:        claim,
-		WriteRequest: structs.WriteRequest{Namespace: vol.Namespace},
+		VolumeID: vol.ID,
+		Claim:    claim,
+		WriteRequest: structs.WriteRequest{
+			Namespace: vol.Namespace,
+			Region:    vw.state.Config().Region,
+			AuthToken: vw.leaderAcl,
+		},
 	}
 	err := vw.rpc.Unpublish(req, &structs.CSIVolumeUnpublishResponse{})
 	if err != nil {
diff --git a/nomad/volumewatcher/volumes_watcher.go b/nomad/volumewatcher/volumes_watcher.go
index 2f82725b08d..061df061398 100644
--- a/nomad/volumewatcher/volumes_watcher.go
+++ b/nomad/volumewatcher/volumes_watcher.go
@@ -21,6 +21,9 @@ type Watcher struct {
 	// the volumes watcher for RPC
 	rpc CSIVolumeRPC
 
+	// the ACL needed to send RPCs
+	leaderAcl string
+
 	// state is the state that is watched for state changes.
 	state *state.StateStore
 
@@ -36,7 +39,7 @@ type Watcher struct {
 
 // NewVolumesWatcher returns a volumes watcher that is used to watch
 // volumes and trigger the scheduler as needed.
-func NewVolumesWatcher(logger log.Logger, rpc CSIVolumeRPC) *Watcher {
+func NewVolumesWatcher(logger log.Logger, rpc CSIVolumeRPC, leaderAcl string) *Watcher {
 
 	// the leader step-down calls SetEnabled(false) which is what
 	// cancels this context, rather than passing in its own shutdown
@@ -44,10 +47,11 @@ func NewVolumesWatcher(logger log.Logger, rpc CSIVolumeRPC) *Watcher {
 	ctx, exitFn := context.WithCancel(context.Background())
 
 	return &Watcher{
-		rpc:    rpc,
-		logger: logger.Named("volumes_watcher"),
-		ctx:    ctx,
-		exitFn: exitFn,
+		rpc:       rpc,
+		logger:    logger.Named("volumes_watcher"),
+		ctx:       ctx,
+		exitFn:    exitFn,
+		leaderAcl: leaderAcl,
 	}
 }
 
diff --git a/nomad/volumewatcher/volumes_watcher_test.go b/nomad/volumewatcher/volumes_watcher_test.go
index 180fee962da..1ee13ca073c 100644
--- a/nomad/volumewatcher/volumes_watcher_test.go
+++ b/nomad/volumewatcher/volumes_watcher_test.go
@@ -22,7 +22,7 @@ func TestVolumeWatch_EnableDisable(t *testing.T) {
 	srv.state = state.TestStateStore(t)
 	index := uint64(100)
 
-	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv)
+	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
 	watcher.SetEnabled(true, srv.State())
 
 	plugin := mock.CSIPlugin()
@@ -57,7 +57,7 @@ func TestVolumeWatch_Checkpoint(t *testing.T) {
 	srv.state = state.TestStateStore(t)
 	index := uint64(100)
 
-	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv)
+	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
 
 	plugin := mock.CSIPlugin()
 	node := testNode(plugin, srv.State())
@@ -98,7 +98,7 @@ func TestVolumeWatch_StartStop(t *testing.T) {
 	srv := &MockStatefulRPCServer{}
 	srv.state = state.TestStateStore(t)
 	index := uint64(100)
-	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv)
+	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
 	watcher.SetEnabled(true, srv.State())
 
 	require.Equal(0, len(watcher.watchers))
@@ -190,7 +190,7 @@ func TestVolumeWatch_RegisterDeregister(t *testing.T) {
 
 	index := uint64(100)
 
-	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv)
+	watcher := NewVolumesWatcher(testlog.HCLogger(t), srv, "")
 	watcher.SetEnabled(true, srv.State())
 
 	require.Equal(0, len(watcher.watchers))
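
For reviewers, here is a minimal runnable sketch of the pattern the patch applies. It uses
hypothetical, simplified stand-ins for the nomad/structs request types (WriteRequest,
CSIVolumeClaimRequest, and claimRPC below are illustrative only, not Nomad's real API):
with ACLs enabled, a leader-driven RPC is only authorized when the leader's token is
carried on the request, which is what threading eval.LeaderACL and parent.leaderAcl into
WriteRequest above achieves.

// Illustrative sketch only: simplified stand-ins for Nomad's request structs,
// not the real types in nomad/structs.
package main

import (
	"errors"
	"fmt"
)

// WriteRequest mirrors the fields the patch populates on each RPC request.
type WriteRequest struct {
	Namespace string
	Region    string
	AuthToken string // leader ACL token; empty when ACLs are disabled
}

// CSIVolumeClaimRequest is a trimmed-down stand-in for the real request type.
type CSIVolumeClaimRequest struct {
	VolumeID string
	WriteRequest
}

// claimRPC stands in for an ACL-checked server endpoint: with ACLs enabled it
// rejects requests that arrive without a token, which is what the GC job and
// volumewatcher hit before this fix.
func claimRPC(req *CSIVolumeClaimRequest, aclsEnabled bool) error {
	if aclsEnabled && req.AuthToken == "" {
		return errors.New("Permission denied")
	}
	return nil
}

func main() {
	// Before the patch: namespace and region were set, but no AuthToken.
	before := &CSIVolumeClaimRequest{
		VolumeID:     "vol-1",
		WriteRequest: WriteRequest{Namespace: "default", Region: "global"},
	}
	fmt.Println(claimRPC(before, true)) // Permission denied

	// After the patch: the leader ACL token is passed through on the request,
	// the way eval.LeaderACL (core_sched) and parent.leaderAcl (volumewatcher)
	// are threaded into WriteRequest.
	after := &CSIVolumeClaimRequest{
		VolumeID: "vol-1",
		WriteRequest: WriteRequest{
			Namespace: "default",
			Region:    "global",
			AuthToken: "example-leader-acl-token", // hypothetical token value
		},
	}
	fmt.Println(claimRPC(after, true)) // <nil>
}

Carrying the token on the request itself keeps these leader-driven RPCs on the same
authorization path as client-submitted requests, rather than special-casing them inside
each endpoint.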