Skip to content

Commit

Permalink
Merge pull request #1494 from hashicorp/child-set-process-group
Browse files Browse the repository at this point in the history
Add setpgid for all called commands
  • Loading branch information
eikenb authored Jul 28, 2021
2 parents c39bb42 + 7d17648 commit 518b957
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 51 deletions.
64 changes: 43 additions & 21 deletions child/child.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ type Child struct {
stopLock sync.RWMutex
stopCh chan struct{}
stopped bool

// whether to set process group id or not (default on)
setpgid bool
}

// NewInput is input to the NewChild function.
Expand Down Expand Up @@ -135,6 +138,7 @@ func New(i *NewInput) (*Child, error) {
killTimeout: i.KillTimeout,
splay: i.Splay,
stopCh: make(chan struct{}, 1),
setpgid: true,
}

return child, nil
Expand Down Expand Up @@ -264,6 +268,7 @@ func (c *Child) start() error {
cmd.Stdout = c.stdout
cmd.Stderr = c.stderr
cmd.Env = c.env
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: c.setpgid}
if err := cmd.Start(); err != nil {
return err
}
Expand Down Expand Up @@ -353,7 +358,18 @@ func (c *Child) signal(s os.Signal) error {
if !c.running() {
return nil
}
return c.cmd.Process.Signal(s)

sig, ok := s.(syscall.Signal)
if !ok {
return fmt.Errorf("bad signal: %s", s)
}
pid := c.cmd.Process.Pid
if c.setpgid {
// kill takes negative pid to indicate that you want to use gpid
pid = -(pid)
}
// kill is syscall's only signal API
return syscall.Kill(pid, sig)
}

func (c *Child) reload() error {
Expand Down Expand Up @@ -381,32 +397,38 @@ func (c *Child) kill(immediately bool) {
}
}

exited := false
process := c.cmd.Process
var exited bool
defer func() {
if !exited {
c.cmd.Process.Kill()
}
c.cmd = nil
}()

if c.killSignal != nil {
if err := process.Signal(c.killSignal); err == nil {
// Wait a few seconds for it to exit
killCh := make(chan struct{}, 1)
go func() {
defer close(killCh)
process.Wait()
}()
if c.killSignal == nil {
return
}

select {
case <-c.stopCh:
case <-killCh:
exited = true
case <-time.After(c.killTimeout):
}
if err := c.signal(c.killSignal); err != nil {
log.Printf("[ERR] (child) Kill failed: %s", err)
if err == syscall.ESRCH {
exited = true // ESRCH == no such process, ie. already exited
}
return
}

if !exited {
process.Kill()
}
killCh := make(chan struct{}, 1)
go func() {
defer close(killCh)
c.cmd.Process.Wait()
}()

c.cmd = nil
select {
case <-c.stopCh:
case <-killCh:
exited = true
case <-time.After(c.killTimeout):
}
}

func (c *Child) running() bool {
Expand Down
94 changes: 64 additions & 30 deletions child/child_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/hashicorp/go-gatedio"
)

const fileWaitSleepDelay = 500 * time.Millisecond
const fileWaitSleepDelay = 50 * time.Millisecond

func testChild(t *testing.T) *Child {
c, err := New(&NewInput{
Expand Down Expand Up @@ -207,11 +207,11 @@ func TestSignal(t *testing.T) {
t.Parallel()

c := testChild(t)
c.command = "bash"
c.args = []string{"-c", "trap 'echo one; exit' SIGUSR1; while true; do sleep 0.2; done"}
c.command = "sh"
c.args = []string{"-c", "trap 'echo one; exit' USR1; while true; do sleep 0.2; done"}

out := gatedio.NewByteBuffer()
c.stdout, c.stderr = out, out
c.stdout = out

if err := c.Start(); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -248,12 +248,12 @@ func TestReload_signal(t *testing.T) {
t.Parallel()

c := testChild(t)
c.command = "bash"
c.args = []string{"-c", "trap 'echo one; exit' SIGUSR1; while true; do sleep 0.2; done"}
c.command = "sh"
c.args = []string{"-c", "trap 'echo one; exit' USR1; while true; do sleep 0.2; done"}
c.reloadSignal = syscall.SIGUSR1

out := gatedio.NewByteBuffer()
c.stdout, c.stderr = out, out
c.stdout = out

if err := c.Start(); err != nil {
t.Fatal(err)
Expand All @@ -280,14 +280,11 @@ func TestReload_noSignal(t *testing.T) {
t.Parallel()

c := testChild(t)
c.command = "bash"
c.command = "sh"
c.args = []string{"-c", "while true; do sleep 0.2; done"}
c.killTimeout = 10 * time.Millisecond
c.reloadSignal = nil

out := gatedio.NewByteBuffer()
c.stdout, c.stderr = out, out

if err := c.Start(); err != nil {
t.Fatal(err)
}
Expand All @@ -309,9 +306,6 @@ func TestReload_noSignal(t *testing.T) {
// Get the new pid
npid := c.cmd.Process.Pid

// Stop the child now
c.Stop()

if opid == npid {
t.Error("expected new process to restart")
}
Expand All @@ -331,12 +325,12 @@ func TestKill_signal(t *testing.T) {
t.Parallel()

c := testChild(t)
c.command = "bash"
c.args = []string{"-c", "trap 'echo one; exit' SIGUSR1; while true; do sleep 0.2; done"}
c.command = "sh"
c.args = []string{"-c", "trap 'echo one; exit' USR1; while true; do sleep 0.2; done"}
c.killSignal = syscall.SIGUSR1

out := gatedio.NewByteBuffer()
c.stdout, c.stderr = out, out
c.stdout = out

if err := c.Start(); err != nil {
t.Fatal(err)
Expand All @@ -361,14 +355,11 @@ func TestKill_noSignal(t *testing.T) {
t.Parallel()

c := testChild(t)
c.command = "bash"
c.command = "sh"
c.args = []string{"-c", "while true; do sleep 0.2; done"}
c.killTimeout = 20 * time.Millisecond
c.killSignal = nil

out := gatedio.NewByteBuffer()
c.stdout, c.stderr = out, out

if err := c.Start(); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -398,14 +389,14 @@ func TestKill_noProcess(t *testing.T) {
func TestStop_noWaitForSplay(t *testing.T) {
t.Parallel()
c := testChild(t)
c.command = "bash"
c.args = []string{"-c", "trap 'echo one; exit' SIGUSR1; while true; do sleep 0.2; done"}
c.command = "sh"
c.args = []string{"-c", "trap 'echo one; exit' USR1; while true; do sleep 0.2; done"}
c.splay = 100 * time.Second
c.reloadSignal = nil
c.killSignal = syscall.SIGUSR1

out := gatedio.NewByteBuffer()
c.stdout, c.stderr = out, out
c.stdout = out

if err := c.Start(); err != nil {
t.Fatal(err)
Expand All @@ -423,23 +414,20 @@ func TestStop_noWaitForSplay(t *testing.T) {
t.Errorf("expected %q to be %q", out.String(), expected)
}

if killEndTime.Sub(killStartTime) > 500*time.Millisecond {
if killEndTime.Sub(killStartTime) > fileWaitSleepDelay {
t.Error("expected not to wait for splay")
}
}

func TestStop_childAlreadyDead(t *testing.T) {
t.Parallel()
c := testChild(t)
c.command = "bash"
c.command = "sh"
c.args = []string{"-c", "exit 1"}
c.splay = 100 * time.Second
c.reloadSignal = nil
c.killSignal = syscall.SIGTERM

out := gatedio.NewByteBuffer()
c.stdout, c.stderr = out, out

if err := c.Start(); err != nil {
t.Fatal(err)
}
Expand All @@ -451,7 +439,53 @@ func TestStop_childAlreadyDead(t *testing.T) {
c.Stop()
killEndTime := time.Now()

if killEndTime.Sub(killStartTime) > 500*time.Millisecond {
if killEndTime.Sub(killStartTime) > fileWaitSleepDelay {
t.Error("expected not to wait for splay")
}
}

func TestSetpgid(t *testing.T) {
t.Run("true", func(t *testing.T) {
c := testChild(t)
c.command = "sh"
c.args = []string{"-c", "while true; do sleep 0.2; done"}
// default, but to be explicit for the test
c.setpgid = true

if err := c.Start(); err != nil {
t.Fatal(err)
}
defer c.Stop()

// when setpgid is true, the pid and gpid should be the same
gpid, err := syscall.Getpgid(c.Pid())
if err != nil {
t.Fatal("Getpgid error:", err)
}

if c.Pid() != gpid {
t.Fatal("pid and gpid should match")
}
})
t.Run("false", func(t *testing.T) {
c := testChild(t)
c.command = "sh"
c.args = []string{"-c", "while true; do sleep 0.2; done"}
c.setpgid = false

if err := c.Start(); err != nil {
t.Fatal(err)
}
defer c.Stop()

// when setpgid is true, the pid and gpid should be the same
gpid, err := syscall.Getpgid(c.Pid())
if err != nil {
t.Fatal("Getpgid error:", err)
}

if c.Pid() == gpid {
t.Fatal("pid and gpid should NOT match")
}
})
}

0 comments on commit 518b957

Please sign in to comment.