Skip to content

Commit

Permalink
ensure that every middleware stores data in an attached context separ…
Browse files Browse the repository at this point in the history
…ate from the fasthttp ctx
  • Loading branch information
felixgehrmann committed Oct 25, 2024
1 parent 585e171 commit 6ba1aae
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 52 deletions.
8 changes: 4 additions & 4 deletions examples/http/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ func main() {
// setup main server
setupMainServer(tracer)

resp, err := callMainServer(context.Background(), tracer.CorrelationIDKey())
resp, err := callMainServer(context.Background(), string(tracer.CorrelationIDKey()))
if err != nil {
panic(err)
}

defer resp.Body.Close()

if resp.Header.Get(tracer.CorrelationIDKey()) != correlationValue {
if resp.Header.Get(string(tracer.CorrelationIDKey())) != correlationValue {
panic("X-Correlation-ID header is not set in response")
}

if resp.Header.Get(tracer.RequestIDKey()) == "" {
if resp.Header.Get(string(tracer.RequestIDKey())) == "" {
panic("X-Request-ID header is not set in response")
}
}
Expand All @@ -59,7 +59,7 @@ func setupMainServer(tracer *tracygo.TracyGo) {
restyClient.OnBeforeRequest(restytracygo.CheckTracingIDs(tracer))

mux.HandleFunc("/", func(_ http.ResponseWriter, r *http.Request) {
if r.Header.Get(tracer.CorrelationIDKey()) != correlationValue {
if r.Header.Get(string(tracer.CorrelationIDKey())) != correlationValue {
panic("X-Correlation-ID header is not set in context")
}

Expand Down
33 changes: 23 additions & 10 deletions middleware/atreugo/middlewares.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package atreugo

import (
"context"

"github.com/Clarilab/tracygo/v2"
"github.com/google/uuid"
"github.com/savsgio/atreugo/v11"
)

// CheckTracingIDs is a middleware for atreugo that checks if a correlationID and requestID have been set
// and creates a new one if they have not been set yet.
func CheckTracingIDs(t *tracygo.TracyGo) func(ctx *atreugo.RequestCtx) error {
return func(ctx *atreugo.RequestCtx) error {
correlationID := string(ctx.Request.Header.Peek(t.CorrelationIDKey()))
requestID := string(ctx.Request.Header.Peek(t.RequestIDKey()))
func CheckTracingIDs(tracy *tracygo.TracyGo) func(request *atreugo.RequestCtx) error {
return func(request *atreugo.RequestCtx) error {
correlationID := string(request.Request.Header.Peek(string(tracy.CorrelationIDKey())))
requestID := string(request.Request.Header.Peek(string(tracy.RequestIDKey())))

if correlationID == "" {
correlationID = uuid.NewString()
Expand All @@ -21,13 +23,24 @@ func CheckTracingIDs(t *tracygo.TracyGo) func(ctx *atreugo.RequestCtx) error {
requestID = uuid.NewString()
}

// set userValues for resty middleware
ctx.SetUserValue(t.CorrelationIDKey(), correlationID)
ctx.SetUserValue(t.RequestIDKey(), requestID)
// Set values to attachedContext. While the request fulfills the context interface, it is not recommended for performance and opens up some pitfalls.
aCtx := request.AttachedContext()
if aCtx == nil {
aCtx = context.Background()
}

aCtx = context.WithValue(aCtx, tracy.CorrelationIDKey(), correlationID)
aCtx = context.WithValue(aCtx, tracy.RequestIDKey(), requestID)

request.AttachContext(aCtx)

// set userValues for resty middleware (legacy)
request.SetUserValue(tracy.CorrelationIDKey(), correlationID)
request.SetUserValue(tracy.RequestIDKey(), requestID)

ctx.Response.Header.Set(t.CorrelationIDKey(), correlationID)
ctx.Response.Header.Set(t.RequestIDKey(), requestID)
request.Response.Header.Set(string(tracy.CorrelationIDKey()), correlationID)
request.Response.Header.Set(string(tracy.RequestIDKey()), requestID)

return ctx.Next()
return request.Next()
}
}
34 changes: 27 additions & 7 deletions middleware/echo/middlewares.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
package echo

import (
"context"

"github.com/Clarilab/tracygo/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
)

const tracyContextKey = "userCtx"

// CheckTracingIDs is a middleware for fiber that checks if a correlationID and requestID have been set
// and creates a new one if they have not been set yet.
func CheckTracingIDs(tracer *tracygo.TracyGo) echo.MiddlewareFunc {
func CheckTracingIDs(t *tracygo.TracyGo) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
correlationID := c.Request().Header.Get(tracer.CorrelationIDKey())
requestID := c.Request().Header.Get(tracer.RequestIDKey())
correlationID := c.Request().Header.Get(string(t.CorrelationIDKey()))
requestID := c.Request().Header.Get(string(t.RequestIDKey()))

if correlationID == "" {
correlationID = uuid.NewString()
Expand All @@ -22,13 +26,29 @@ func CheckTracingIDs(tracer *tracygo.TracyGo) echo.MiddlewareFunc {
requestID = uuid.NewString()
}

c.Set(tracer.CorrelationIDKey(), correlationID)
c.Set(tracer.RequestIDKey(), requestID)
tCtx := context.WithValue(context.Background(), t.CorrelationIDKey(), correlationID)
tCtx = context.WithValue(tCtx, t.RequestIDKey(), requestID)
c.Set(tracyContextKey, tCtx)

c.Response().Header().Set(tracer.CorrelationIDKey(), correlationID)
c.Response().Header().Set(tracer.RequestIDKey(), requestID)
c.Set(string(t.CorrelationIDKey()), correlationID)
c.Set(string(t.RequestIDKey()), requestID)

c.Response().Header().Set(string(t.CorrelationIDKey()), correlationID)
c.Response().Header().Set(string(t.RequestIDKey()), requestID)

return next(c)
}
}
}

// GetUserContext is a helper function to extract a context set by tracygo from a echo.Context
// This mirrors context attachment functionality of other libs
func GetUserContext(c echo.Context) context.Context {
if val := c.Get(tracyContextKey); val != nil {
if ctx, ok := val.(context.Context); ok {
return ctx
}
}

return nil
}
15 changes: 11 additions & 4 deletions middleware/fiber/middlewares.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fiber

import (
"context"

"github.com/Clarilab/tracygo/v2"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
Expand All @@ -10,8 +12,8 @@ import (
// and creates a new one if they have not been set yet.
func CheckTracingIDs(t *tracygo.TracyGo) func(ctx *fiber.Ctx) error {
return func(ctx *fiber.Ctx) error {
correlationID := string(ctx.Request().Header.Peek(t.CorrelationIDKey()))
requestID := string(ctx.Request().Header.Peek(t.RequestIDKey()))
correlationID := string(ctx.Request().Header.Peek(string(t.CorrelationIDKey())))
requestID := string(ctx.Request().Header.Peek(string(t.RequestIDKey())))

if correlationID == "" {
correlationID = uuid.NewString()
Expand All @@ -21,12 +23,17 @@ func CheckTracingIDs(t *tracygo.TracyGo) func(ctx *fiber.Ctx) error {
requestID = uuid.NewString()
}

// Set values to UserContext
userCtx := context.WithValue(ctx.UserContext(), t.CorrelationIDKey(), correlationID)
userCtx = context.WithValue(userCtx, t.RequestIDKey(), requestID)
ctx.SetUserContext(userCtx)

// set userValues for resty middleware
ctx.Context().SetUserValue(t.CorrelationIDKey(), correlationID)
ctx.Context().SetUserValue(t.RequestIDKey(), requestID)

ctx.Response().Header.Set(t.CorrelationIDKey(), correlationID)
ctx.Response().Header.Set(t.RequestIDKey(), requestID)
ctx.Response().Header.Set(string(t.CorrelationIDKey()), correlationID)
ctx.Response().Header.Set(string(t.RequestIDKey()), requestID)

return ctx.Next()
}
Expand Down
16 changes: 8 additions & 8 deletions middleware/grpc/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,24 @@ func CheckTracingIDs(t *tracygo.TracyGo) func(ctx context.Context, req any, _ *g
var correlationID string
var requestID string

if values := md[strings.ToLower(t.CorrelationIDKey())]; len(values) == 1 {
if values := md[strings.ToLower(string(t.CorrelationIDKey()))]; len(values) == 1 {
correlationID = values[0]
}

if values := md[strings.ToLower(t.RequestIDKey())]; len(values) == 1 {
if values := md[strings.ToLower(string(t.RequestIDKey()))]; len(values) == 1 {
requestID = values[0]
}

if correlationID == "" {
correlationID = uuid.NewString()

md.Append(t.CorrelationIDKey(), correlationID)
md.Append(string(t.CorrelationIDKey()), correlationID)
}

if requestID == "" {
requestID = uuid.NewString()

md.Append(t.RequestIDKey(), requestID)
md.Append(string(t.RequestIDKey()), requestID)
}

if err := grpc.SetTrailer(ctx, md); err != nil {
Expand All @@ -48,12 +48,12 @@ func CheckTracingIDs(t *tracygo.TracyGo) func(ctx context.Context, req any, _ *g

ctx = metadata.AppendToOutgoingContext(
ctx,
t.CorrelationIDKey(), correlationID,
t.RequestIDKey(), requestID,
string(t.CorrelationIDKey()), correlationID,
string(t.RequestIDKey()), requestID,
)

ctx = context.WithValue(ctx, t.CorrelationIDKey(), correlationID) //nolint:staticcheck // intended use
ctx = context.WithValue(ctx, t.RequestIDKey(), requestID) //nolint:staticcheck // intended use
ctx = context.WithValue(ctx, t.CorrelationIDKey(), correlationID)
ctx = context.WithValue(ctx, t.RequestIDKey(), requestID)

return handler(ctx, req)
}
Expand Down
12 changes: 6 additions & 6 deletions middleware/http/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
func CheckTracingIDs(t *tracygo.TracyGo) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
correlationID := r.Header.Get(t.CorrelationIDKey())
requestID := r.Header.Get(t.RequestIDKey())
correlationID := r.Header.Get(string(t.CorrelationIDKey()))
requestID := r.Header.Get(string(t.RequestIDKey()))

if correlationID == "" {
correlationID = uuid.NewString()
Expand All @@ -24,11 +24,11 @@ func CheckTracingIDs(t *tracygo.TracyGo) func(next http.Handler) http.Handler {
requestID = uuid.NewString()
}

ctx := context.WithValue(r.Context(), t.RequestIDKey(), requestID) //nolint:staticcheck // intended use
ctx = context.WithValue(ctx, t.CorrelationIDKey(), correlationID) //nolint:staticcheck // intended use
ctx := context.WithValue(r.Context(), t.RequestIDKey(), requestID)
ctx = context.WithValue(ctx, t.CorrelationIDKey(), correlationID)

w.Header().Set(t.RequestIDKey(), requestID)
w.Header().Set(t.CorrelationIDKey(), correlationID)
w.Header().Set(string(t.RequestIDKey()), requestID)
w.Header().Set(string(t.CorrelationIDKey()), correlationID)

next.ServeHTTP(w, r.WithContext(ctx))
}
Expand Down
6 changes: 3 additions & 3 deletions middleware/resty/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ import (
// If they are set, they should be put into the request headers.
func CheckTracingIDs(t *tracygo.TracyGo) func(client *resty.Client, request *resty.Request) error {
return func(_ *resty.Client, request *resty.Request) error {
request.Header.Set(t.RequestIDKey(), uuid.NewString())
request.Header.Set(string(t.RequestIDKey()), uuid.NewString())

correlationID, ok := request.Context().Value(t.CorrelationIDKey()).(string)
if ok && correlationID != "" {
request.Header.Set(t.CorrelationIDKey(), correlationID)
request.Header.Set(string(t.CorrelationIDKey()), correlationID)

return nil
}

request.Header.Set(t.CorrelationIDKey(), uuid.NewString())
request.Header.Set(string(t.CorrelationIDKey()), uuid.NewString())

return nil
}
Expand Down
4 changes: 2 additions & 2 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ package tracygo
type Option func(tracy *TracyGo)

// WithCorrelationID returns a function that sets the key for the correlation id header.
func WithCorrelationID(id string) Option {
func WithCorrelationID(id ContextKey) Option {
return func(tracy *TracyGo) {
tracy.correlationID = id
}
}

// WithRequestID returns a function that sets the key for the request id header.
func WithRequestID(id string) Option {
func WithRequestID(id ContextKey) Option {
return func(tracy *TracyGo) {
tracy.requestID = id
}
Expand Down
18 changes: 10 additions & 8 deletions tracy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,24 @@ import (
"github.com/google/uuid"
)

type ContextKey string

const (
correlationID = "X-Correlation-ID"
requestID = "X-Request-ID"
keyCorrelationID ContextKey = "X-Correlation-ID"
keyRequestID ContextKey = "X-Request-ID"
)

// TracyGo is a struct for the tracy object.
type TracyGo struct {
correlationID string
requestID string
correlationID ContextKey
requestID ContextKey
}

// New creates a new TracyGo object and uses the options on it.
func New(options ...Option) *TracyGo {
tracy := &TracyGo{
correlationID: correlationID,
requestID: requestID,
correlationID: keyCorrelationID,
requestID: keyRequestID,
}

for _, option := range options {
Expand All @@ -34,12 +36,12 @@ func New(options ...Option) *TracyGo {
}

// CorrelationIDKey returns the underlying correlation id key.
func (t *TracyGo) CorrelationIDKey() string {
func (t *TracyGo) CorrelationIDKey() ContextKey {
return t.correlationID
}

// RequestIDKey returns the underlying request id key.
func (t *TracyGo) RequestIDKey() string {
func (t *TracyGo) RequestIDKey() ContextKey {
return t.requestID
}

Expand Down

0 comments on commit 6ba1aae

Please sign in to comment.