diff --git a/agent/agent.go b/agent/agent.go index b72d2dd..55cf485 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -18,6 +18,7 @@ import ( "strings" "sync" "syscall" + "time" "github.com/nakabonne/gosivy/process" "github.com/nakabonne/gosivy/stats" @@ -124,11 +125,10 @@ func gracefulShutdown() { } func listen() { - sig := make([]byte, 1) for { conn, err := listener.Accept() if err != nil { - // TODO: Find better way to check for closed connection, see: https://golang.org/issues/4373. + // TODO: Use net.ErrClosed after upgrading Go1.16, see: https://golang.org/issues/4373. if !strings.Contains(err.Error(), "use of closed network connection") { fmt.Fprintf(logWriter, "gosivy: %v\n", err) } @@ -137,43 +137,52 @@ func listen() { } continue } - if _, err := conn.Read(sig); err != nil { - fmt.Fprintf(logWriter, "gosivy: %v\n", err) - continue - } - if err := handle(conn, sig); err != nil { - fmt.Fprintf(logWriter, "gosivy: %v\n", err) - continue - } - conn.Close() + fmt.Fprintf(logWriter, "gosivy: accept %v\n", conn.RemoteAddr()) + go func() { + if err := handle(conn); err != nil { + fmt.Fprintf(logWriter, "gosivy: %v\n", err) + } + }() } } -func handle(conn io.ReadWriter, msg []byte) error { - switch msg[0] { - case stats.SignalMeta: - meta, err := stats.NewMeta() - if err != nil { - return err - } - b, err := json.Marshal(meta) - if err != nil { - return err - } - _, err = conn.Write(b) - return err - case stats.SignalStats: - s, err := stats.NewStats() - if err != nil { +// handle keeps using the given connection until an issue occurred. +func handle(conn net.Conn) error { + defer conn.Close() + + for { + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + sig := make([]byte, 1) + if _, err := conn.Read(sig); err != nil { return err } - b, err := json.Marshal(s) - if err != nil { - return err + switch sig[0] { + case stats.SignalMeta: + meta, err := stats.NewMeta() + if err != nil { + return err + } + b, err := json.Marshal(meta) + if err != nil { + return err + } + if _, err := conn.Write(append(b, stats.Delimiter)); err != nil { + return err + } + case stats.SignalStats: + s, err := stats.NewStats() + if err != nil { + return err + } + b, err := json.Marshal(s) + if err != nil { + return err + } + if _, err := conn.Write(append(b, stats.Delimiter)); err != nil { + return err + } + default: + return fmt.Errorf("unknown signal received: %b", sig[0]) } - _, err = conn.Write(b) - return err - default: - return fmt.Errorf("unknown signal received: %b", msg[0]) } } diff --git a/agent/agent_test.go b/agent/agent_test.go index 9bdb072..ca89304 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1,13 +1,10 @@ package agent import ( - "bytes" "os" "testing" "github.com/stretchr/testify/assert" - - "github.com/nakabonne/gosivy/stats" ) func TestListenAndClose(t *testing.T) { @@ -19,34 +16,3 @@ func TestListenAndClose(t *testing.T) { assert.True(t, os.IsNotExist(err)) assert.Empty(t, pidFile) } - -func TestHandle(t *testing.T) { - tests := []struct { - name string - signal byte - wantErr bool - }{ - { - name: "signal meta received", - signal: stats.SignalMeta, - wantErr: false, - }, - { - name: "signal stats received", - signal: stats.SignalStats, - wantErr: false, - }, - { - name: "unknown signal received", - signal: byte(0x9), - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b := new(bytes.Buffer) - err := handle(b, []byte{tt.signal}) - assert.Equal(t, tt.wantErr, err != nil) - }) - } -} diff --git a/diagnoser/diagnoser.go b/diagnoser/diagnoser.go index 665e0f9..7530c0e 100644 --- a/diagnoser/diagnoser.go +++ b/diagnoser/diagnoser.go @@ -3,9 +3,10 @@ package diagnoser import ( + "bufio" "context" "encoding/json" - "io/ioutil" + "fmt" "net" "time" @@ -57,24 +58,25 @@ func (d *diagnoser) Run() error { func (d *diagnoser) startScraping(ctx context.Context, statsCh chan<- *stats.Stats) (*stats.Meta, error) { conn, err := net.DialTCP("tcp", nil, d.addr) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to dial TCP: %w", err) } + // First up, fetch meta data of process, - buf := []byte{stats.SignalMeta} - if _, err := conn.Write(buf); err != nil { + if _, err := conn.Write([]byte{stats.SignalMeta}); err != nil { return nil, err } - res, err := ioutil.ReadAll(conn) + reader := bufio.NewReader(conn) + res, err := reader.ReadBytes(stats.Delimiter) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to read metadata: %w", err) } - conn.Close() var meta stats.Meta if err := json.Unmarshal(res, &meta); err != nil { - return nil, err + return nil, fmt.Errorf("failed to decode metadata: %w", err) } go func(ctx context.Context, ch chan<- *stats.Stats) { + defer conn.Close() tick := time.NewTicker(d.scrapeInterval) defer tick.Stop() for { @@ -82,24 +84,26 @@ func (d *diagnoser) startScraping(ctx context.Context, statsCh chan<- *stats.Sta case <-ctx.Done(): return case <-tick.C: - // TODO: Reuse connections instead of creating each time. - conn, err := net.DialTCP("tcp", nil, d.addr) - if err != nil { - logrus.Errorf("failed to create connection: %v", err) - continue + if conn == nil { + conn, err = net.DialTCP("tcp", nil, d.addr) + if err != nil { + logrus.Errorf("failed to dial: %v", err) + continue + } } - buf := []byte{stats.SignalStats} - if _, err := conn.Write(buf); err != nil { + if _, err := conn.Write([]byte{stats.SignalStats}); err != nil { logrus.Errorf("failed to write into connection: %v", err) + conn = nil continue } - res, err := ioutil.ReadAll(conn) + reader.Reset(conn) + res, err := reader.ReadBytes(stats.Delimiter) if err != nil { logrus.Errorf("failed to read the response: %v", err) + conn = nil continue } - conn.Close() var stats stats.Stats if err := json.Unmarshal(res, &stats); err != nil { diff --git a/diagnoser/diagnoser_test.go b/diagnoser/diagnoser_test.go index 39929e8..98e981d 100644 --- a/diagnoser/diagnoser_test.go +++ b/diagnoser/diagnoser_test.go @@ -40,10 +40,10 @@ func startServer() *net.TCPAddr { switch sig[0] { case stats.SignalMeta: b, _ := json.Marshal(&stats.Meta{}) - _, _ = conn.Write(b) + _, _ = conn.Write(append(b, stats.Delimiter)) case stats.SignalStats: b, _ := json.Marshal(&stats.Stats{}) - _, _ = conn.Write(b) + _, _ = conn.Write(append(b, stats.Delimiter)) } conn.Close() } diff --git a/stats/signal.go b/stats/signal.go index 745a0b7..e5099af 100644 --- a/stats/signal.go +++ b/stats/signal.go @@ -8,4 +8,7 @@ const ( // SignalStats reports Go process stats. SignalStats = byte(0x2) + + // Delimiter indicates to complete the writing. + Delimiter = '\n' )