diff --git a/client/v1/client.go b/client/v1/client.go index 2ad9a45b3..da13ff7d6 100644 --- a/client/v1/client.go +++ b/client/v1/client.go @@ -3,6 +3,7 @@ package client import ( "bytes" + "context" "crypto/tls" "encoding/json" "fmt" @@ -12,6 +13,7 @@ import ( "net/url" "path" "strconv" + "sync" "time" "github.com/influxdata/influxdb/influxql" @@ -660,7 +662,7 @@ func (c *Client) Do(req *http.Request, result interface{}, codes ...int) (*http. return resp, nil } -func (c *Client) Logs(w io.Writer, q map[string]string) error { +func (c *Client) Logs(ctx context.Context, w io.Writer, q map[string]string) error { u := c.BaseURL() u.Path = logsPath @@ -674,6 +676,7 @@ func (c *Client) Logs(w io.Writer, q map[string]string) error { if err != nil { return err } + req = req.WithContext(ctx) err = c.prepRequest(req) if err != nil { return err @@ -688,7 +691,15 @@ func (c *Client) Logs(w io.Writer, q map[string]string) error { return fmt.Errorf("bad status code %v", resp.StatusCode) } - _, err = io.Copy(w, resp.Body) + var wg sync.WaitGroup + wg.Add(1) + go func() { + _, err = io.Copy(w, resp.Body) + wg.Done() + }() + + <-ctx.Done() + wg.Wait() return err } diff --git a/cmd/kapacitor/main.go b/cmd/kapacitor/main.go index c451b3a11..086a470e9 100644 --- a/cmd/kapacitor/main.go +++ b/cmd/kapacitor/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "encoding/json" "flag" "fmt" @@ -10,10 +11,12 @@ import ( "log" "net/http" "os" + "os/signal" "path" "sort" "strconv" "strings" + "syscall" "time" humanize "github.com/dustin/go-humanize" @@ -2260,19 +2263,43 @@ func doBackup(args []string) error { } func watchUsage() { - var u = `Usage: kapacitor watch + var u = `Usage: kapacitor watch [ ...] Watch logs associated with a task. + + Examples: + + $ kapacitor logs mytask + $ kapacitor logs mytask node=log5 ` fmt.Fprintln(os.Stderr, u) } func doWatch(args []string) error { - if len(args) != 1 { + m := map[string]string{} + if len(args) < 1 { return errors.New("must provide task ID.") } - err := cli.Logs(os.Stdout, map[string]string{"task": args[0]}) - if err != nil { + m["task"] = args[0] + for _, s := range args[1:] { + pair := strings.Split(s, "=") + if len(pair) != 2 { + return fmt.Errorf("bad keyvalue pair: '%v'", s) + } + m[pair[0]] = pair[1] + } + + ctx, cancel := context.WithCancel(context.Background()) + cancelled := false + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT) + go func() { + <-sigs + cancel() + cancelled = true + }() + + if err := cli.Logs(ctx, os.Stdout, m); err != nil && !cancelled { return errors.Wrap(err, "failed writing logs") } return nil @@ -2297,8 +2324,17 @@ func doLogs(args []string) error { } m[pair[0]] = pair[1] } - err := cli.Logs(os.Stdout, m) - if err != nil { + ctx, cancel := context.WithCancel(context.Background()) + cancelled := false + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT) + go func() { + <-sigs + cancel() + cancelled = true + }() + + if err := cli.Logs(ctx, os.Stdout, m); err != nil && !cancelled { return errors.Wrap(err, "failed writing logs") } return nil