diff --git a/pkg/extgrpc/client.go b/pkg/extgrpc/client.go index 15ccd03eb6..87c8060cd6 100644 --- a/pkg/extgrpc/client.go +++ b/pkg/extgrpc/client.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + grpcserver "github.com/thanos-io/thanos/pkg/server/grpc" "github.com/thanos-io/thanos/pkg/tls" "github.com/thanos-io/thanos/pkg/tracing" ) @@ -58,12 +59,14 @@ func StoreClientGRPCOpts(logger log.Logger, reg prometheus.Registerer, tracer op grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)), grpc.WithUnaryInterceptor( grpc_middleware.ChainUnaryClient( + grpcserver.NewUnaryClientRequestIDInterceptor(), grpcMets.UnaryClientInterceptor(), tracing.UnaryClientInterceptor(tracer), ), ), grpc.WithStreamInterceptor( grpc_middleware.ChainStreamClient( + grpcserver.NewStreamClientRequestIDInterceptor(), grpcMets.StreamClientInterceptor(), tracing.StreamClientInterceptor(tracer), ), diff --git a/pkg/server/grpc/grpc.go b/pkg/server/grpc/grpc.go index ebdfc678b1..762465cc71 100644 --- a/pkg/server/grpc/grpc.go +++ b/pkg/server/grpc/grpc.go @@ -78,6 +78,7 @@ func New(logger log.Logger, reg prometheus.Registerer, tracer opentracing.Tracer grpc.MaxSendMsgSize(math.MaxInt32), grpc.MaxRecvMsgSize(math.MaxInt32), grpc_middleware.WithUnaryServerChain( + NewUnaryServerRequestIDInterceptor(), grpc_recovery.UnaryServerInterceptor(grpc_recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), met.UnaryServerInterceptor(), tags.UnaryServerInterceptor(tagsOpts...), @@ -85,6 +86,7 @@ func New(logger log.Logger, reg prometheus.Registerer, tracer opentracing.Tracer grpc_logging.UnaryServerInterceptor(kit.InterceptorLogger(logger), logOpts...), ), grpc_middleware.WithStreamServerChain( + NewStreamServerRequestIDInterceptor(), grpc_recovery.StreamServerInterceptor(grpc_recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), met.StreamServerInterceptor(), tags.StreamServerInterceptor(tagsOpts...), diff --git a/pkg/server/grpc/request_id.go b/pkg/server/grpc/request_id.go new file mode 100644 index 0000000000..cca8175516 --- /dev/null +++ b/pkg/server/grpc/request_id.go @@ -0,0 +1,67 @@ +// Copyright (c) The Thanos Authors. +// Licensed under the Apache License 2.0. + +package grpc + +import ( + "context" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/thanos-io/thanos/pkg/server/http/middleware" +) + +const requestIDKey = "request-id" + +func NewUnaryClientRequestIDInterceptor() grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + reqID, ok := middleware.RequestIDFromContext(ctx) + if ok { + ctx = metadata.AppendToOutgoingContext(ctx, requestIDKey, reqID) + } + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +func NewUnaryServerRequestIDInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if vals := metadata.ValueFromIncomingContext(ctx, requestIDKey); len(vals) == 1 { + ctx = middleware.NewContextWithRequestID(ctx, vals[0]) + } + return handler(ctx, req) + } +} + +func NewStreamClientRequestIDInterceptor() grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + reqID, ok := middleware.RequestIDFromContext(ctx) + if ok { + ctx = metadata.AppendToOutgoingContext(ctx, requestIDKey, reqID) + } + return streamer(ctx, desc, cc, method, opts...) + } +} + +func NewStreamServerRequestIDInterceptor() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if vals := metadata.ValueFromIncomingContext(ss.Context(), requestIDKey); len(vals) == 1 { + ctx := middleware.NewContextWithRequestID(ss.Context(), vals[0]) + return handler(srv, newStreamWithContext(ctx, ss)) + } + return handler(srv, ss) + } +} + +type streamWithContext struct { + grpc.ServerStream + ctx context.Context +} + +func newStreamWithContext(ctx context.Context, serverStream grpc.ServerStream) *streamWithContext { + return &streamWithContext{ServerStream: serverStream, ctx: ctx} +} + +func (s streamWithContext) Context() context.Context { + return s.ctx +} diff --git a/pkg/server/http/middleware/request_id.go b/pkg/server/http/middleware/request_id.go index fddeabf932..3a3da07a83 100644 --- a/pkg/server/http/middleware/request_id.go +++ b/pkg/server/http/middleware/request_id.go @@ -16,8 +16,8 @@ type ctxKey int const reqIDKey = ctxKey(0) -// newContextWithRequestID creates a context with a request id. -func newContextWithRequestID(ctx context.Context, rid string) context.Context { +// NewContextWithRequestID creates a context with a request id. +func NewContextWithRequestID(ctx context.Context, rid string) context.Context { return context.WithValue(ctx, reqIDKey, rid) } @@ -36,7 +36,7 @@ func RequestID(h http.Handler) http.HandlerFunc { reqID = ulid.MustNew(ulid.Timestamp(time.Now()), entropy).String() r.Header.Set("X-Request-ID", reqID) } - ctx := newContextWithRequestID(r.Context(), reqID) + ctx := NewContextWithRequestID(r.Context(), reqID) h.ServeHTTP(w, r.WithContext(ctx)) } }