From 3dd80d8eefa589a9639d915793d0802654fa98e3 Mon Sep 17 00:00:00 2001 From: Pavlo Sumkin Date: Fri, 20 Dec 2024 17:48:16 +0100 Subject: [PATCH] GROUNDWORK-3850 prevent sending improper payload --- services/agent.go | 56 ++++++++++++++++++++++++++++++------------ services/agent_test.go | 5 ++-- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/services/agent.go b/services/agent.go index 7c8bfee1..12c3f721 100644 --- a/services/agent.go +++ b/services/agent.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "encoding/json" "expvar" + "fmt" "net/http" "os" "os/signal" @@ -541,6 +542,16 @@ func (service *AgentService) stopNats() error { } func (service *AgentService) startTransport() error { + if service.Connector.AgentID == traceOnDemandAgentID || + service.Connector.AppType == traceOnDemandAppType || + len(service.Connector.AgentID) == 0 || + len(service.Connector.AppType) == 0 { + + err := fmt.Errorf("connector is not configured: AppType/AgentID: %v/%v", + service.Connector.AppType, service.Connector.AgentID) + log.Err(err).Msg("could not start") + return err + } /* Process clients */ gwClients := make([]clients.GWClient, 0, len(config.GetConfig().GWConnections)) for i := range config.GetConfig().GWConnections { @@ -580,28 +591,41 @@ func (service *AgentService) stopTransport() error { // mixTracerContext adds `context` field if absent func (service *AgentService) mixTracerContext(payloadJSON []byte) ([]byte, bool) { - if !bytes.Contains(payloadJSON, []byte(`"context":`)) || !bytes.Contains(payloadJSON, []byte(`"traceToken":`)) { - tc, todoTracerCtx := service.MakeTracerContext(), false - ctxJSON, err := json.Marshal(tc) - if err != nil { - log.Err(err).Msg("could not mixTracerContext") - return payloadJSON, false - } - if tc.AgentID == traceOnDemandAgentID || - tc.AppType == traceOnDemandAppType { - todoTracerCtx = true - } + if bytes.Contains(payloadJSON, []byte(`"context":`)) && + bytes.Contains(payloadJSON, []byte(`"traceToken":`)) { + return payloadJSON, false + } - l := bytes.LastIndexByte(payloadJSON, byte('}')) - return bytes.Join([][]byte{ - payloadJSON[:l], []byte(`,"context":`), ctxJSON, []byte(`}`), - }, []byte(``)), todoTracerCtx + tc, todoTracerCtx := service.MakeTracerContext(), false + ctxJSON, err := json.Marshal(tc) + if err != nil { + log.Err(err).Msg("could not mixTracerContext") + return payloadJSON, false } - return payloadJSON, false + if tc.AgentID == traceOnDemandAgentID || + tc.AppType == traceOnDemandAppType { + todoTracerCtx = true + } + + l := bytes.LastIndexByte(payloadJSON, byte('}')) + return bytes.Join([][]byte{ + payloadJSON[:l], []byte(`,"context":`), ctxJSON, []byte(`}`), + }, []byte(``)), todoTracerCtx } // fixTracerContext replaces placeholders func (service *AgentService) fixTracerContext(payloadJSON []byte) []byte { + if service.Connector.AgentID == traceOnDemandAgentID || + service.Connector.AppType == traceOnDemandAppType || + len(service.Connector.AgentID) == 0 || + len(service.Connector.AppType) == 0 { + + err := fmt.Errorf("connector is not configured: AppType/AgentID: %v/%v", + service.Connector.AppType, service.Connector.AgentID) + log.Err(err).Msg("could not fixTracerContext") + return payloadJSON + } + return bytes.ReplaceAll( bytes.ReplaceAll( payloadJSON, diff --git a/services/agent_test.go b/services/agent_test.go index b63c603a..14e0a7f8 100644 --- a/services/agent_test.go +++ b/services/agent_test.go @@ -38,7 +38,7 @@ func TestAgentService(t *testing.T) { }) t.Run("NATS", func(t *testing.T) { - t.Setenv("TCG_CONNECTOR_NATSSTOREMAXBYTES", "333_222_111_000") + GetAgentService().Connector.NatsStoreMaxBytes = 333_222_111_000 assert.NoError(t, GetAgentService().StartNats()) assert.NoError(t, GetAgentService().StopNats()) assert.NoError(t, GetAgentService().StartNats()) @@ -46,7 +46,8 @@ func TestAgentService(t *testing.T) { }) t.Run("Transport", func(t *testing.T) { - t.Setenv("TCG_CONNECTOR_NATSSTOREMAXBYTES", "333_222_111_000") + GetAgentService().Connector.AgentID = "TESTAGENTID" + GetAgentService().Connector.AppType = "TESTAPPTYPE" assert.NoError(t, GetAgentService().StartNats()) assert.NoError(t, GetAgentService().StartTransport()) assert.NoError(t, GetAgentService().StopTransport())