From df06df73217739dc80df2e020b4aae68e025b77d Mon Sep 17 00:00:00 2001 From: Filip Petkovski Date: Wed, 15 May 2024 19:41:04 +0200 Subject: [PATCH] Propagate request ID through gRPC context (#7356) * Propagate request ID through gRPC context The request ID only gets propagated through HTTP calls and is not available in gRPC servers. This commit adds intereceptors to grpc servers and clients to make sure request ID propagation happens. Signed-off-by: Filip Petkovski * Add license Signed-off-by: Filip Petkovski --------- Signed-off-by: Filip Petkovski --- pkg/extgrpc/client.go | 3 ++ pkg/server/grpc/grpc.go | 2 + pkg/server/grpc/request_id.go | 67 ++++++++++++++++++++++++ pkg/server/http/middleware/request_id.go | 6 +-- 4 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 pkg/server/grpc/request_id.go diff --git a/pkg/extgrpc/client.go b/pkg/extgrpc/client.go index 15ccd03eb6..87c8060cd6 100644 --- a/pkg/extgrpc/client.go +++ b/pkg/extgrpc/client.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + grpcserver "github.com/thanos-io/thanos/pkg/server/grpc" "github.com/thanos-io/thanos/pkg/tls" "github.com/thanos-io/thanos/pkg/tracing" ) @@ -58,12 +59,14 @@ func StoreClientGRPCOpts(logger log.Logger, reg prometheus.Registerer, tracer op grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)), grpc.WithUnaryInterceptor( grpc_middleware.ChainUnaryClient( + grpcserver.NewUnaryClientRequestIDInterceptor(), grpcMets.UnaryClientInterceptor(), tracing.UnaryClientInterceptor(tracer), ), ), grpc.WithStreamInterceptor( grpc_middleware.ChainStreamClient( + grpcserver.NewStreamClientRequestIDInterceptor(), grpcMets.StreamClientInterceptor(), tracing.StreamClientInterceptor(tracer), ), diff --git a/pkg/server/grpc/grpc.go b/pkg/server/grpc/grpc.go index ebdfc678b1..762465cc71 100644 --- a/pkg/server/grpc/grpc.go +++ b/pkg/server/grpc/grpc.go @@ -78,6 +78,7 @@ func New(logger log.Logger, reg prometheus.Registerer, tracer opentracing.Tracer grpc.MaxSendMsgSize(math.MaxInt32), grpc.MaxRecvMsgSize(math.MaxInt32), grpc_middleware.WithUnaryServerChain( + NewUnaryServerRequestIDInterceptor(), grpc_recovery.UnaryServerInterceptor(grpc_recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), met.UnaryServerInterceptor(), tags.UnaryServerInterceptor(tagsOpts...), @@ -85,6 +86,7 @@ func New(logger log.Logger, reg prometheus.Registerer, tracer opentracing.Tracer grpc_logging.UnaryServerInterceptor(kit.InterceptorLogger(logger), logOpts...), ), grpc_middleware.WithStreamServerChain( + NewStreamServerRequestIDInterceptor(), grpc_recovery.StreamServerInterceptor(grpc_recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), met.StreamServerInterceptor(), tags.StreamServerInterceptor(tagsOpts...), diff --git a/pkg/server/grpc/request_id.go b/pkg/server/grpc/request_id.go new file mode 100644 index 0000000000..cca8175516 --- /dev/null +++ b/pkg/server/grpc/request_id.go @@ -0,0 +1,67 @@ +// Copyright (c) The Thanos Authors. +// Licensed under the Apache License 2.0. + +package grpc + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/thanos-io/thanos/pkg/server/http/middleware" +) + +const requestIDKey = "request-id" + +func NewUnaryClientRequestIDInterceptor() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + reqID, ok := middleware.RequestIDFromContext(ctx) + if ok { + ctx = metadata.AppendToOutgoingContext(ctx, requestIDKey, reqID) + } + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +func NewUnaryServerRequestIDInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if vals := metadata.ValueFromIncomingContext(ctx, requestIDKey); len(vals) == 1 { + ctx = middleware.NewContextWithRequestID(ctx, vals[0]) + } + return handler(ctx, req) + } +} + +func NewStreamClientRequestIDInterceptor() grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + reqID, ok := middleware.RequestIDFromContext(ctx) + if ok { + ctx = metadata.AppendToOutgoingContext(ctx, requestIDKey, reqID) + } + return streamer(ctx, desc, cc, method, opts...) + } +} + +func NewStreamServerRequestIDInterceptor() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if vals := metadata.ValueFromIncomingContext(ss.Context(), requestIDKey); len(vals) == 1 { + ctx := middleware.NewContextWithRequestID(ss.Context(), vals[0]) + return handler(srv, newStreamWithContext(ctx, ss)) + } + return handler(srv, ss) + } +} + +type streamWithContext struct { + grpc.ServerStream + ctx context.Context +} + +func newStreamWithContext(ctx context.Context, serverStream grpc.ServerStream) *streamWithContext { + return &streamWithContext{ServerStream: serverStream, ctx: ctx} +} + +func (s streamWithContext) Context() context.Context { + return s.ctx +} diff --git a/pkg/server/http/middleware/request_id.go b/pkg/server/http/middleware/request_id.go index fddeabf932..3a3da07a83 100644 --- a/pkg/server/http/middleware/request_id.go +++ b/pkg/server/http/middleware/request_id.go @@ -16,8 +16,8 @@ type ctxKey int const reqIDKey = ctxKey(0) -// newContextWithRequestID creates a context with a request id. -func newContextWithRequestID(ctx context.Context, rid string) context.Context { +// NewContextWithRequestID creates a context with a request id. +func NewContextWithRequestID(ctx context.Context, rid string) context.Context { return context.WithValue(ctx, reqIDKey, rid) } @@ -36,7 +36,7 @@ func RequestID(h http.Handler) http.HandlerFunc { reqID = ulid.MustNew(ulid.Timestamp(time.Now()), entropy).String() r.Header.Set("X-Request-ID", reqID) } - ctx := newContextWithRequestID(r.Context(), reqID) + ctx := NewContextWithRequestID(r.Context(), reqID) h.ServeHTTP(w, r.WithContext(ctx)) } }