diff --git a/service/gcs/bridge/bridge.go b/service/gcs/bridge/bridge.go index 73ad1466..23721d34 100644 --- a/service/gcs/bridge/bridge.go +++ b/service/gcs/bridge/bridge.go @@ -101,9 +101,9 @@ func (b *bridge) loop() error { if err != nil { logrus.Error(err) } - case prot.ComputeSystemTerminateProcessV1: - logrus.Info("received from HCS: ComputeSystemTerminateProcessV1") - response, err = b.terminateProcess(message) + case prot.ComputeSystemSignalProcessV1: + logrus.Info("received from HCS: ComputeSystemSignalProcessV1") + response, err = b.signalProcess(message) if err != nil { logrus.Error(err) } @@ -268,15 +268,15 @@ func (b *bridge) shutdownContainer(message []byte) (*prot.MessageResponseBase, e return response, nil } -func (b *bridge) terminateProcess(message []byte) (*prot.MessageResponseBase, error) { +func (b *bridge) signalProcess(message []byte) (*prot.MessageResponseBase, error) { response := newResponseBase() - var request prot.ContainerTerminateProcess + var request prot.ContainerSignalProcess if err := commonutils.UnmarshalJSONWithHresult(message, &request); err != nil { return response, errors.Wrapf(err, "failed to unmarshal JSON for message \"%s\"", message) } response.ActivityID = request.ActivityID - if err := b.coreint.TerminateProcess(int(request.ProcessID)); err != nil { + if err := b.coreint.SignalProcess(int(request.ProcessID), request.Options); err != nil { return response, err } diff --git a/service/gcs/bridge/bridge_test.go b/service/gcs/bridge/bridge_test.go index f11fc3ba..f975167d 100644 --- a/service/gcs/bridge/bridge_test.go +++ b/service/gcs/bridge/bridge_test.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "syscall" "github.com/Microsoft/opengcs/service/gcs/core/mockcore" "github.com/Microsoft/opengcs/service/gcs/oslayer" @@ -465,35 +466,39 @@ var _ = Describe("Bridge", func() { }) }) - Describe("calling terminateProcess", func() { + Describe("calling signalProcess", func() { var ( response prot.MessageResponseBase - callArgs mockcore.TerminateProcessCall + callArgs mockcore.SignalProcessCall + options prot.SignalProcessOptions ) BeforeEach(func() { - messageType = prot.ComputeSystemTerminateProcessV1 + messageType = prot.ComputeSystemSignalProcessV1 + options = prot.SignalProcessOptions{Signal: int32(syscall.SIGKILL)} }) JustBeforeEach(func() { response = prot.MessageResponseBase{} err := json.Unmarshal([]byte(responseString), &response) Expect(err).NotTo(HaveOccurred()) responseBase = &response - callArgs = coreint.LastTerminateProcess + callArgs = coreint.LastSignalProcess }) Context("the message is normal ASCII", func() { BeforeEach(func() { - message = prot.ContainerTerminateProcess{ + message = prot.ContainerSignalProcess{ MessageBase: &prot.MessageBase{ ContainerID: containerID, ActivityID: activityID, }, ProcessID: processID, + Options: options, } }) AssertNoResponseErrors() AssertActivityIDCorrect() It("should receive the correct values", func() { Expect(callArgs.Pid).To(Equal(int(processID))) + Expect(callArgs.Options).To(Equal(options)) }) }) }) diff --git a/service/gcs/core/core.go b/service/gcs/core/core.go index e44932c8..2ebf58df 100644 --- a/service/gcs/core/core.go +++ b/service/gcs/core/core.go @@ -16,7 +16,7 @@ type Core interface { CreateContainer(id string, info prot.VMHostedContainerSettings) error ExecProcess(id string, info prot.ProcessParameters, stdioSet *stdio.ConnectionSet) (pid int, err error) SignalContainer(id string, signal oslayer.Signal) error - TerminateProcess(pid int) error + SignalProcess(pid int, options prot.SignalProcessOptions) error ListProcesses(id string) ([]runtime.ContainerProcessState, error) RunExternalProcess(info prot.ProcessParameters, stdioSet *stdio.ConnectionSet) (pid int, err error) ModifySettings(id string, request prot.ResourceModificationRequestResponse) error diff --git a/service/gcs/core/gcs/gcs.go b/service/gcs/core/gcs/gcs.go index 40c7459e..d4d89cd7 100644 --- a/service/gcs/core/gcs/gcs.go +++ b/service/gcs/core/gcs/gcs.go @@ -10,7 +10,6 @@ import ( "path/filepath" "sync" "syscall" - "time" gcserr "github.com/Microsoft/opengcs/service/gcs/errors" "github.com/Microsoft/opengcs/service/gcs/oslayer" @@ -23,10 +22,6 @@ import ( "github.com/pkg/errors" ) -const ( - terminateProcessTimeout = time.Second * 10 -) - // gcsCore is an implementation of the Core interface, defining the // functionality of the GCS. type gcsCore struct { @@ -334,9 +329,8 @@ func (c *gcsCore) SignalContainer(id string, signal oslayer.Signal) error { return nil } -// TerminateProcess sends a SIGTERM signal to the given process. If it does not -// exit after a timeout, it then sends a SIGKILL. -func (c *gcsCore) TerminateProcess(pid int) error { +// SignalProcess sends the signal specified in options to the given process. +func (c *gcsCore) SignalProcess(pid int, options prot.SignalProcessOptions) error { c.processCacheMutex.Lock() c.externalProcessCacheMutex.Lock() if _, ok := c.processCache[pid]; !ok { @@ -349,31 +343,18 @@ func (c *gcsCore) TerminateProcess(pid int) error { c.processCacheMutex.Unlock() c.externalProcessCacheMutex.Unlock() - // First, send the process a SIGTERM. If it doesn't exit before the - // specified timeout, send it a SIGKILL. - exitedChannel := make(chan bool, 1) - exitHook := func(state oslayer.ProcessExitState) { - exitedChannel <- true - } - if err := c.RegisterProcessExitHook(pid, exitHook); err != nil { - return errors.Wrapf(err, "failed to register exit hook during call to TerminateProcess for process %d", pid) - } - if err := c.OS.Kill(pid, syscall.SIGTERM); err != nil { - return errors.Wrapf(err, "failed call to kill on process %d", pid) - } - select { - case <-exitedChannel: // Do nothing. - case <-time.After(terminateProcessTimeout): - // If the timeout is exceeded, kill the process with SIGKILL. - // TODO: Properly handle the race condition between the process exiting - // and Kill being called, so that the error doesn't need to be ignored. - // This can be done by waiting on processes without hanging, locking on - // the waits, and setting a flag to indicate that the process has - // exited. Then, this code can lock on the same lock, and check if the - // process has exited or not before calling Kill. - if err := c.OS.Kill(pid, syscall.SIGKILL); err != nil { - logrus.Error(err) - } + // Interpret signal value 0 as SIGKILL. + // TODO: Remove this special casing when we are not worried about breaking + // older Windows builds which don't support sending signals. + var signal syscall.Signal + if options.Signal == 0 { + signal = syscall.SIGKILL + } else { + signal = syscall.Signal(options.Signal) + } + + if err := c.OS.Kill(pid, signal); err != nil { + return errors.Wrapf(err, "failed call to kill on process %d with signal %d", pid, options.Signal) } return nil diff --git a/service/gcs/core/gcs/gcs_test.go b/service/gcs/core/gcs/gcs_test.go index da51ff45..9eb1fc12 100644 --- a/service/gcs/core/gcs/gcs_test.go +++ b/service/gcs/core/gcs/gcs_test.go @@ -2,6 +2,7 @@ package gcs import ( "fmt" + "syscall" "github.com/Microsoft/opengcs/service/gcs/oslayer" "github.com/Microsoft/opengcs/service/gcs/oslayer/mockos" @@ -845,9 +846,15 @@ var _ = Describe("GCS", func() { }) }) }) - Describe("calling TerminateProcess", func() { + Describe("calling SignalProcess", func() { + var ( + sigkillOptions prot.SignalProcessOptions + ) + BeforeEach(func() { + sigkillOptions = prot.SignalProcessOptions{Signal: int32(syscall.SIGKILL)} + }) JustBeforeEach(func() { - err = coreint.TerminateProcess(processID) + err = coreint.SignalProcess(processID, sigkillOptions) }) Context("the process has already been created", func() { BeforeEach(func() { diff --git a/service/gcs/core/mockcore/mockcore.go b/service/gcs/core/mockcore/mockcore.go index 2dc199a8..876e12db 100644 --- a/service/gcs/core/mockcore/mockcore.go +++ b/service/gcs/core/mockcore/mockcore.go @@ -28,9 +28,10 @@ type SignalContainerCall struct { Signal oslayer.Signal } -// TerminateProcessCall captures the arguments of TerminateProcess. -type TerminateProcessCall struct { - Pid int +// SignalProcessCall captures the arguments of SignalProcess. +type SignalProcessCall struct { + Pid int + Options prot.SignalProcessOptions } // ListProcessesCall captures the arguments of ListProcesses. @@ -71,7 +72,7 @@ type MockCore struct { LastCreateContainer CreateContainerCall LastExecProcess ExecProcessCall LastSignalContainer SignalContainerCall - LastTerminateProcess TerminateProcessCall + LastSignalProcess SignalProcessCall LastListProcesses ListProcessesCall LastRunExternalProcess RunExternalProcessCall LastModifySettings ModifySettingsCall @@ -104,9 +105,12 @@ func (c *MockCore) SignalContainer(id string, signal oslayer.Signal) error { return nil } -// TerminateProcess captures its arguments and returns a nil error. -func (c *MockCore) TerminateProcess(pid int) error { - c.LastTerminateProcess = TerminateProcessCall{Pid: pid} +// SignalProcess captures its arguments and returns a nil error. +func (c *MockCore) SignalProcess(pid int, options prot.SignalProcessOptions) error { + c.LastSignalProcess = SignalProcessCall{ + Pid: pid, + Options: options, + } return nil } diff --git a/service/gcs/prot/protocol.go b/service/gcs/prot/protocol.go index 591d3f32..bc08c7f0 100644 --- a/service/gcs/prot/protocol.go +++ b/service/gcs/prot/protocol.go @@ -73,7 +73,7 @@ const ( ComputeSystemShutdownForcedV1 = 0x10100401 ComputeSystemExecuteProcessV1 = 0x10100501 ComputeSystemWaitForProcessV1 = 0x10100601 - ComputeSystemTerminateProcessV1 = 0x10100701 + ComputeSystemSignalProcessV1 = 0x10100701 ComputeSystemResizeConsoleV1 = 0x10100801 ComputeSystemGetPropertiesV1 = 0x10100901 ComputeSystemModifySettingsV1 = 0x10100a01 @@ -85,7 +85,7 @@ const ( ComputeSystemResponseShutdownForcedV1 = 0x20100401 ComputeSystemResponseExecuteProcessV1 = 0x20100501 ComputeSystemResponseWaitForProcessV1 = 0x20100601 - ComputeSystemResponseTerminateProcessV1 = 0x20100701 + ComputeSystemResponseSignalProcessV1 = 0x20100701 ComputeSystemResponseResizeConsoleV1 = 0x20100801 ComputeSystemResponseGetPropertiesV1 = 0x20100901 ComputeSystemResponseModifySettingsV1 = 0x20100a01 @@ -225,11 +225,12 @@ type ContainerWaitForProcess struct { TimeoutInMs uint32 } -// ContainerTerminateProcess is the message from the HCS specifying to kill the -// given process. -type ContainerTerminateProcess struct { +// ContainerSignalProcess is the message from the HCS specifying to send a +// signal to the given process. +type ContainerSignalProcess struct { *MessageBase - ProcessID uint32 `json:"ProcessId"` + ProcessID uint32 `json:"ProcessId"` + Options SignalProcessOptions `json:",omitempty"` } // ContainerGetProperties is the message from the HCS requesting certain @@ -474,3 +475,8 @@ type ProcessParameters struct { // be specified. OCISpecification oci.Spec `json:"OciSpecification,omitempty"` } + +// SignalProcessOptions represents the options for signaling a process. +type SignalProcessOptions struct { + Signal int32 +}