Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change callback signatures to pass down context where applicable #247

Merged
merged 3 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion client/clientimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,12 +627,24 @@ func TestAgentIdentification(t *testing.T) {
ulid.Timestamp(time.Now()), ulid.Monotonic(rand.New(rand.NewSource(0)), 0),
)
var rcvAgentInstanceUid atomic.Value
var sentInvalidId atomic.Bool
srv.OnMessage = func(msg *protobufs.AgentToServer) *protobufs.ServerToAgent {
rcvAgentInstanceUid.Store(msg.InstanceUid)
if sentInvalidId.Load() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

codecov was complaining that this was untested, so i added something that forced an error for an empty string and then resolves it. This will trigger the error checking logic.

return &protobufs.ServerToAgent{
InstanceUid: msg.InstanceUid,
AgentIdentification: &protobufs.AgentIdentification{
// If we sent the invalid one first, send a valid one now
NewInstanceUid: newInstanceUid.String(),
},
}
}
sentInvalidId.Store(true)
return &protobufs.ServerToAgent{
InstanceUid: msg.InstanceUid,
AgentIdentification: &protobufs.AgentIdentification{
NewInstanceUid: newInstanceUid.String(),
// Start by sending an invalid id forcing an error.
NewInstanceUid: "",
},
}
}
Expand Down Expand Up @@ -660,6 +672,21 @@ func TestAgentIdentification(t *testing.T) {
// Send a dummy message
_ = client.SetAgentDescription(createAgentDescr())

// Verify that the old instance id was not overridden
eventually(
t,
func() bool {
instanceUid, ok := rcvAgentInstanceUid.Load().(string)
if !ok {
return false
}
return instanceUid == oldInstanceUid
},
)

// Send a dummy message again to get the _new_ id
_ = client.SetAgentDescription(createAgentDescr())

// When it was sent, the new instance uid should have been used, which should
// have been observed by the Server
eventually(
Expand Down
18 changes: 10 additions & 8 deletions client/internal/receivedprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,10 @@
}

if msg.AgentIdentification != nil {
err := r.rcvAgentIdentification(msg.AgentIdentification)
if err == nil {
err := r.rcvAgentIdentification(ctx, msg.AgentIdentification)
if err != nil {
r.logger.Errorf(ctx, "Failed to set agent ID: %v", err)
} else {
msgData.AgentIdentification = msg.AgentIdentification
}
}
Expand All @@ -146,7 +148,7 @@

err := msg.GetErrorResponse()
if err != nil {
r.processErrorResponse(err)
r.processErrorResponse(ctx, err)

Check warning on line 151 in client/internal/receivedprocessor.go

View check run for this annotation

Codecov / codecov/patch

client/internal/receivedprocessor.go#L151

Added line #L151 was not covered by tests
}
}

Expand Down Expand Up @@ -203,21 +205,21 @@
}
}

func (r *receivedProcessor) processErrorResponse(body *protobufs.ServerErrorResponse) {
func (r *receivedProcessor) processErrorResponse(ctx context.Context, body *protobufs.ServerErrorResponse) {

Check warning on line 208 in client/internal/receivedprocessor.go

View check run for this annotation

Codecov / codecov/patch

client/internal/receivedprocessor.go#L208

Added line #L208 was not covered by tests
// TODO: implement this.
r.logger.Errorf(context.Background(), "received an error from server: %s", body.ErrorMessage)
r.logger.Errorf(ctx, "received an error from server: %s", body.ErrorMessage)

Check warning on line 210 in client/internal/receivedprocessor.go

View check run for this annotation

Codecov / codecov/patch

client/internal/receivedprocessor.go#L210

Added line #L210 was not covered by tests
}

func (r *receivedProcessor) rcvAgentIdentification(agentId *protobufs.AgentIdentification) error {
func (r *receivedProcessor) rcvAgentIdentification(ctx context.Context, agentId *protobufs.AgentIdentification) error {
if agentId.NewInstanceUid == "" {
err := errors.New("empty instance uid is not allowed")
r.logger.Debugf(context.Background(), err.Error())
r.logger.Debugf(ctx, err.Error())
return err
}

err := r.sender.SetInstanceUid(agentId.NewInstanceUid)
if err != nil {
r.logger.Errorf(context.Background(), "Error while setting instance uid: %v", err)
r.logger.Errorf(ctx, "Error while setting instance uid: %v", err)

Check warning on line 222 in client/internal/receivedprocessor.go

View check run for this annotation

Codecov / codecov/patch

client/internal/receivedprocessor.go#L222

Added line #L222 was not covered by tests
return err
}

Expand Down
16 changes: 8 additions & 8 deletions internal/examples/agent/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ func (agent *Agent) connect() error {
return nil
}

func (agent *Agent) disconnect() {
agent.logger.Debugf(context.Background(), "Disconnecting from server...")
agent.opampClient.Stop(context.Background())
func (agent *Agent) disconnect(ctx context.Context) {
agent.logger.Debugf(ctx, "Disconnecting from server...")
agent.opampClient.Stop(ctx)
}

func (agent *Agent) createAgentIdentity() {
Expand Down Expand Up @@ -209,8 +209,8 @@ func (agent *Agent) createAgentIdentity() {
}
}

func (agent *Agent) updateAgentIdentity(instanceId ulid.ULID) {
agent.logger.Debugf(context.Background(), "Agent identify is being changed from id=%v to id=%v",
func (agent *Agent) updateAgentIdentity(ctx context.Context, instanceId ulid.ULID) {
agent.logger.Debugf(ctx, "Agent identify is being changed from id=%v to id=%v",
agent.instanceId.String(),
instanceId.String())
agent.instanceId = instanceId
Expand Down Expand Up @@ -463,13 +463,13 @@ func (agent *Agent) onMessage(ctx context.Context, msg *types.MessageData) {
if err != nil {
agent.logger.Errorf(ctx, err.Error())
}
agent.updateAgentIdentity(newInstanceId)
agent.updateAgentIdentity(ctx, newInstanceId)
}

if configChanged {
err := agent.opampClient.UpdateEffectiveConfig(ctx)
if err != nil {
agent.logger.Errorf(context.Background(), err.Error())
agent.logger.Errorf(ctx, err.Error())
}
}

Expand All @@ -486,7 +486,7 @@ func (agent *Agent) onMessage(ctx context.Context, msg *types.MessageData) {
func (agent *Agent) tryChangeOpAMPCert(ctx context.Context, cert *tls.Certificate) {
agent.logger.Debugf(ctx, "Reconnecting to verify offered client certificate.\n")

agent.disconnect()
agent.disconnect(ctx)

agent.opampClientCert = cert
if err := agent.connect(); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/examples/server/opampsrv/opampsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (srv *Server) onDisconnect(conn types.Connection) {
srv.agents.RemoveConnection(conn)
}

func (srv *Server) onMessage(conn types.Connection, msg *protobufs.AgentToServer) *protobufs.ServerToAgent {
func (srv *Server) onMessage(ctx context.Context, conn types.Connection, msg *protobufs.AgentToServer) *protobufs.ServerToAgent {
instanceId := data.InstanceId(msg.InstanceUid)

agent := srv.agents.FindOrCreateAgent(instanceId, conn)
Expand Down
14 changes: 7 additions & 7 deletions internal/examples/supervisor/supervisor/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ service:
s.agentConfigOwnMetricsSection.Store(cfg)

// Need to recalculate the Agent config so that the metric config is included in it.
configChanged, err := s.recalcEffectiveConfig()
configChanged, err := s.recalcEffectiveConfig(ctx)
if err != nil {
return
}
Expand All @@ -327,7 +327,7 @@ service:
// composeEffectiveConfig composes the effective config from multiple sources:
// 1) the remote config from OpAMP Server, 2) the own metrics config section,
// 3) the local override config that is hard-coded in the Supervisor.
func (s *Supervisor) composeEffectiveConfig(config *protobufs.AgentRemoteConfig) (configChanged bool, err error) {
func (s *Supervisor) composeEffectiveConfig(ctx context.Context, config *protobufs.AgentRemoteConfig) (configChanged bool, err error) {
var k = koanf.New(".")

// Begin with empty config. We will merge received configs on top of it.
Expand Down Expand Up @@ -387,7 +387,7 @@ func (s *Supervisor) composeEffectiveConfig(config *protobufs.AgentRemoteConfig)
newEffectiveConfig := string(effectiveConfigBytes)
configChanged = false
if s.effectiveConfig.Load().(string) != newEffectiveConfig {
s.logger.Debugf(context.Background(), "Effective config changed.")
s.logger.Debugf(ctx, "Effective config changed.")
s.effectiveConfig.Store(newEffectiveConfig)
configChanged = true
}
Expand All @@ -397,11 +397,11 @@ func (s *Supervisor) composeEffectiveConfig(config *protobufs.AgentRemoteConfig)

// Recalculate the Agent's effective config and if the config changes signal to the
// background goroutine that the config needs to be applied to the Agent.
func (s *Supervisor) recalcEffectiveConfig() (configChanged bool, err error) {
func (s *Supervisor) recalcEffectiveConfig(ctx context.Context) (configChanged bool, err error) {

configChanged, err = s.composeEffectiveConfig(s.remoteConfig)
configChanged, err = s.composeEffectiveConfig(ctx, s.remoteConfig)
if err != nil {
s.logger.Errorf(context.Background(), "Error composing effective config. Ignoring received config: %v", err)
s.logger.Errorf(ctx, "Error composing effective config. Ignoring received config: %v", err)
return configChanged, err
}

Expand Down Expand Up @@ -553,7 +553,7 @@ func (s *Supervisor) onMessage(ctx context.Context, msg *types.MessageData) {
s.logger.Debugf(ctx, "Received remote config from server, hash=%x.", s.remoteConfig.ConfigHash)

var err error
configChanged, err = s.recalcEffectiveConfig()
configChanged, err = s.recalcEffectiveConfig(ctx)
if err != nil {
s.opampClient.SetRemoteConfigStatus(&protobufs.RemoteConfigStatus{
LastRemoteConfigHash: msg.RemoteConfig.ConfigHash,
Expand Down
13 changes: 7 additions & 6 deletions server/callbacks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"net/http"

"github.com/open-telemetry/opamp-go/protobufs"
Expand All @@ -27,25 +28,25 @@ func (c CallbacksStruct) OnConnecting(request *http.Request) types.ConnectionRes
// ConnectionCallbacksStruct is a struct that implements ConnectionCallbacks interface and allows
// to override only the methods that are needed.
type ConnectionCallbacksStruct struct {
OnConnectedFunc func(conn types.Connection)
OnMessageFunc func(conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
}

var _ types.ConnectionCallbacks = (*ConnectionCallbacksStruct)(nil)

// OnConnected implements ConnectionCallbacks.OnConnected.
func (c ConnectionCallbacksStruct) OnConnected(conn types.Connection) {
func (c ConnectionCallbacksStruct) OnConnected(ctx context.Context, conn types.Connection) {
if c.OnConnectedFunc != nil {
c.OnConnectedFunc(conn)
c.OnConnectedFunc(ctx, conn)
}
}

// OnMessage implements ConnectionCallbacks.OnMessage.
// If OnMessageFunc is nil then it will send an empty response to the agent
func (c ConnectionCallbacksStruct) OnMessage(conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
func (c ConnectionCallbacksStruct) OnMessage(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent {
if c.OnMessageFunc != nil {
return c.OnMessageFunc(conn, message)
return c.OnMessageFunc(ctx, conn, message)
} else {
// We will send an empty response since there is no user-defined callback to handle it.
return &protobufs.ServerToAgent{
Expand Down
49 changes: 25 additions & 24 deletions server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,16 @@
// No, it is a WebSocket. Upgrade it.
conn, err := s.wsUpgrader.Upgrade(w, req, nil)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot upgrade HTTP connection to WebSocket: %v", err)
s.logger.Errorf(req.Context(), "Cannot upgrade HTTP connection to WebSocket: %v", err)

Check warning on line 182 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L182

Added line #L182 was not covered by tests
return
}

// Return from this func to reduce memory usage.
// Handle the connection on a separate goroutine.
go s.handleWSConnection(conn, connectionCallbacks)
go s.handleWSConnection(req.Context(), conn, connectionCallbacks)
}

func (s *server) handleWSConnection(wsConn *websocket.Conn, connectionCallbacks serverTypes.ConnectionCallbacks) {
func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Conn, connectionCallbacks serverTypes.ConnectionCallbacks) {
agentConn := wsConnection{wsConn: wsConn, connMutex: &sync.Mutex{}}

defer func() {
Expand All @@ -206,43 +206,44 @@
}()

if connectionCallbacks != nil {
connectionCallbacks.OnConnected(agentConn)
connectionCallbacks.OnConnected(reqCtx, agentConn)
}

// Loop until fail to read from the WebSocket connection.
for {
msgContext := context.Background()
// Block until the next message can be read.
mt, bytes, err := wsConn.ReadMessage()
mt, msgBytes, err := wsConn.ReadMessage()
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(context.Background(), "Cannot read a message from WebSocket: %v", err)
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)

Check warning on line 219 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L219

Added line #L219 was not covered by tests
break
}
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(context.Background(), "Agent disconnected: %v", err)
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
break
}
if mt != websocket.BinaryMessage {
s.logger.Errorf(context.Background(), "Received unexpected message type from WebSocket: %v", mt)
s.logger.Errorf(msgContext, "Received unexpected message type from WebSocket: %v", mt)

Check warning on line 227 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L227

Added line #L227 was not covered by tests
continue
}

// Decode WebSocket message as a Protobuf message.
var request protobufs.AgentToServer
err = internal.DecodeWSMessage(bytes, &request)
err = internal.DecodeWSMessage(msgBytes, &request)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot decode message from WebSocket: %v", err)
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)

Check warning on line 235 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L235

Added line #L235 was not covered by tests
continue
}

if connectionCallbacks != nil {
response := connectionCallbacks.OnMessage(agentConn, &request)
response := connectionCallbacks.OnMessage(msgContext, agentConn, &request)
if response.InstanceUid == "" {
response.InstanceUid = request.InstanceUid
}
err = agentConn.Send(context.Background(), response)
err = agentConn.Send(msgContext, response)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot send message to WebSocket: %v", err)
s.logger.Errorf(msgContext, "Cannot send message to WebSocket: %v", err)

Check warning on line 246 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L246

Added line #L246 was not covered by tests
}
}
}
Expand Down Expand Up @@ -286,18 +287,18 @@
}

func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter, connectionCallbacks serverTypes.ConnectionCallbacks) {
bytes, err := s.readReqBody(req)
bodyBytes, err := s.readReqBody(req)
if err != nil {
s.logger.Debugf(context.Background(), "Cannot read HTTP body: %v", err)
s.logger.Debugf(req.Context(), "Cannot read HTTP body: %v", err)

Check warning on line 292 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L292

Added line #L292 was not covered by tests
w.WriteHeader(http.StatusBadRequest)
return
}

// Decode the message as a Protobuf message.
var request protobufs.AgentToServer
err = proto.Unmarshal(bytes, &request)
err = proto.Unmarshal(bodyBytes, &request)
if err != nil {
s.logger.Debugf(context.Background(), "Cannot decode message from HTTP Body: %v", err)
s.logger.Debugf(req.Context(), "Cannot decode message from HTTP Body: %v", err)

Check warning on line 301 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L301

Added line #L301 was not covered by tests
w.WriteHeader(http.StatusBadRequest)
return
}
Expand All @@ -311,7 +312,7 @@
return
}

connectionCallbacks.OnConnected(agentConn)
connectionCallbacks.OnConnected(req.Context(), agentConn)

defer func() {
// Indicate via the callback that the OpAMP Connection is closed. From OpAMP
Expand All @@ -321,15 +322,15 @@
connectionCallbacks.OnConnectionClose(agentConn)
}()

response := connectionCallbacks.OnMessage(agentConn, &request)
response := connectionCallbacks.OnMessage(req.Context(), agentConn, &request)

// Set the InstanceUid if it is not set by the callback.
if response.InstanceUid == "" {
response.InstanceUid = request.InstanceUid
}

// Marshal the response.
bytes, err = proto.Marshal(response)
bodyBytes, err = proto.Marshal(response)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
Expand All @@ -338,17 +339,17 @@
// Send the response.
w.Header().Set(headerContentType, contentTypeProtobuf)
if req.Header.Get(headerAcceptEncoding) == contentEncodingGzip {
bytes, err = compressGzip(bytes)
bodyBytes, err = compressGzip(bodyBytes)
if err != nil {
s.logger.Errorf(context.Background(), "Cannot compress response: %v", err)
s.logger.Errorf(req.Context(), "Cannot compress response: %v", err)

Check warning on line 344 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L344

Added line #L344 was not covered by tests
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set(headerContentEncoding, contentEncodingGzip)
}
_, err = w.Write(bytes)
_, err = w.Write(bodyBytes)

if err != nil {
s.logger.Debugf(context.Background(), "Cannot send HTTP response: %v", err)
s.logger.Debugf(req.Context(), "Cannot send HTTP response: %v", err)

Check warning on line 353 in server/serverimpl.go

View check run for this annotation

Codecov / codecov/patch

server/serverimpl.go#L353

Added line #L353 was not covered by tests
}
}
Loading
Loading