diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 83baef225eb..4daba649ac8 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -28,13 +28,15 @@ const ( ) var ( - // When the package is loaded the driver is registered as an internal plugin - // with the plugin catalog + // PluginID is the mock driver plugin metadata registered in the plugin + // catalog. PluginID = loader.PluginID{ Name: pluginName, PluginType: base.PluginTypeDriver, } + // PluginConfig is the mock driver factory function registered in the + // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, Factory: func(l hclog.Logger) interface{} { return NewMockDriver(l) }, @@ -338,13 +340,13 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru handle := drivers.NewTaskHandle(pluginName) handle.Config = cfg if err := handle.SetDriverState(&driverState); err != nil { - d.logger.Error("failed to start task, error setting driver state", "error", err) + d.logger.Error("failed to start task, error setting driver state", "error", err, "task_name", cfg.Name) return nil, nil, fmt.Errorf("failed to set driver state: %v", err) } d.tasks.Set(cfg.ID, h) - d.logger.Debug("starting task", "name", cfg.Name) + d.logger.Debug("starting task", "task_name", cfg.Name) go h.run() return handle, net, nil @@ -380,11 +382,7 @@ func (d *Driver) StopTask(taskID string, timeout time.Duration, signal string) e return drivers.ErrTaskNotFound } - d.logger.Debug("killing task", - "task_name", h.task.Name, - "kill_after", h.killAfter, - "kill_timeout", h.killTimeout, - ) + d.logger.Debug("killing task", "task_name", h.task.Name, "kill_after", h.killAfter) select { case <-h.waitCh: @@ -392,15 +390,20 @@ func (d *Driver) StopTask(taskID string, timeout time.Duration, signal string) e case <-time.After(h.killAfter): d.logger.Debug("killing task due to kill_after", "task_name", h.task.Name) h.kill() - case <-time.After(h.killTimeout): - d.logger.Debug("killing task after kill_timeout", "task_name", h.task.Name) - h.kill() } return nil } func (d *Driver) DestroyTask(taskID string, force bool) error { - //TODO is there anything else to do here? + handle, ok := d.tasks.Get(taskID) + if !ok { + return drivers.ErrTaskNotFound + } + + if handle.IsRunning() && !force { + return fmt.Errorf("cannot destroy running task") + } + d.tasks.Delete(taskID) return nil } @@ -414,8 +417,8 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) { return nil, nil } -func (d *Driver) TaskEvents(netctx.Context) (<-chan *drivers.TaskEvent, error) { - panic("not implemented") +func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) { + return d.eventer.TaskEvents(ctx) } func (d *Driver) SignalTask(taskID string, signal string) error { diff --git a/drivers/mock/handle.go b/drivers/mock/handle.go index 214ecba1d96..e20215943a8 100644 --- a/drivers/mock/handle.go +++ b/drivers/mock/handle.go @@ -3,6 +3,7 @@ package mock import ( "context" "io" + "sync" "time" hclog "github.com/hashicorp/go-hclog" @@ -16,7 +17,6 @@ type mockTaskHandle struct { runFor time.Duration killAfter time.Duration - killTimeout time.Duration waitCh chan struct{} exitCode int exitSignal int @@ -26,8 +26,12 @@ type mockTaskHandle struct { stdoutRepeat int stdoutRepeatDur time.Duration - task *drivers.TaskConfig - procState drivers.TaskState + task *drivers.TaskConfig + + // stateLock guards the procState field + stateLock sync.Mutex + procState drivers.TaskState + startedAt time.Time completedAt time.Time exitResult *drivers.ExitResult @@ -37,8 +41,25 @@ type mockTaskHandle struct { killCh <-chan struct{} } +func (h *mockTaskHandle) IsRunning() bool { + h.stateLock.Lock() + defer h.stateLock.Unlock() + return h.procState == drivers.TaskStateRunning +} + func (h *mockTaskHandle) run() { - defer close(h.waitCh) + defer func() { + h.stateLock.Lock() + h.procState = drivers.TaskStateExited + h.stateLock.Unlock() + + h.completedAt = time.Now() + close(h.waitCh) + }() + + h.stateLock.Lock() + h.procState = drivers.TaskStateRunning + h.stateLock.Unlock() errCh := make(chan error, 1) diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index fce320c8d9d..12da7560946 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -22,14 +22,24 @@ import ( "golang.org/x/net/context" ) -// When the package is loaded the driver is registered as an internal plugin -// with the plugin catalog +const ( + // pluginName is the name of the plugin + pluginName = "raw_exec" + + // fingerprintPeriod is the interval at which the driver will send fingerprint responses + fingerprintPeriod = 30 * time.Second +) + var ( + // PluginID is the rawexec plugin metadata registered in the plugin + // catalog. PluginID = loader.PluginID{ Name: pluginName, PluginType: base.PluginTypeDriver, } + // PluginConfig is the rawexec factory function registered in the + // plugin catalog. PluginConfig = &loader.InternalPluginConfig{ Config: map[string]interface{}{}, Factory: func(l hclog.Logger) interface{} { return NewRawExecDriver(l) }, @@ -47,14 +57,6 @@ func PluginLoader(opts map[string]string) (map[string]interface{}, error) { return conf, nil } -const ( - // pluginName is the name of the plugin - pluginName = "raw_exec" - - // fingerprintPeriod is the interval at which the driver will send fingerprint responses - fingerprintPeriod = 30 * time.Second -) - var ( // pluginInfo is the response returned for the PluginInfo RPC pluginInfo = &base.PluginInfoResponse{