Skip to content

Commit

Permalink
gRPC: fix request context propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
michalszynkiewicz committed Jun 16, 2021
1 parent cceeab7 commit ddc3e3c
Show file tree
Hide file tree
Showing 32 changed files with 1,100 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.quarkus.grpc.runtime.MutinyStub;
import io.quarkus.grpc.runtime.supports.Channels;
import io.quarkus.grpc.runtime.supports.GrpcClientConfigProvider;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.common.annotation.NonBlocking;

Expand All @@ -28,7 +27,6 @@ public class GrpcDotNames {
public static final DotName CHANNEL = DotName.createSimple(Channel.class.getName());
public static final DotName GRPC_CLIENT = DotName.createSimple(GrpcClient.class.getName());
public static final DotName GRPC_SERVICE = DotName.createSimple(GrpcService.class.getName());
public static final DotName GRPC_ENABLE_REQUEST_CONTEXT = DotName.createSimple(GrpcEnableRequestContext.class.getName());

public static final DotName BLOCKING = DotName.createSimple(Blocking.class.getName());
public static final DotName NON_BLOCKING = DotName.createSimple(NonBlocking.class.getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import io.quarkus.arc.processor.AnnotationsTransformer;
import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.Transformation;
import io.quarkus.deployment.IsDevelopment;
import io.quarkus.deployment.IsNormal;
import io.quarkus.deployment.annotations.BuildProducer;
Expand All @@ -60,8 +59,6 @@
import io.quarkus.grpc.runtime.config.GrpcServerBuildTimeConfig;
import io.quarkus.grpc.runtime.health.GrpcHealthEndpoint;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextCdiInterceptor;
import io.quarkus.kubernetes.spi.KubernetesPortBuildItem;
import io.quarkus.netty.deployment.MinNettyAllocatorMaxOrderBuildItem;
import io.quarkus.runtime.LaunchMode;
Expand Down Expand Up @@ -240,14 +237,11 @@ public boolean appliesTo(Kind kind) {
@Override
public void transform(TransformationContext context) {
ClassInfo clazz = context.getTarget().asClass();
if (userDefinedServices.contains(clazz.name())) {
// Add @GrpcEnableRequestContext to activate the request context during each call
Transformation transform = context.transform().add(GrpcDotNames.GRPC_ENABLE_REQUEST_CONTEXT);
if (!customScopes.isScopeDeclaredOn(clazz)) {
// Add @Singleton to make it a bean
transform.add(BuiltinScope.SINGLETON.getName());
}
transform.done();
if (userDefinedServices.contains(clazz.name()) && !customScopes.isScopeDeclaredOn(clazz)) {
// Add @Singleton to make it a bean
context.transform()
.add(BuiltinScope.SINGLETON.getName())
.done();
}
}
});
Expand Down Expand Up @@ -303,8 +297,6 @@ void registerBeans(BuildProducer<AdditionalBeanBuildItem> beans,
List<BindableServiceBuildItem> bindables, BuildProducer<FeatureBuildItem> features) {
// @GrpcService is a CDI qualifier
beans.produce(new AdditionalBeanBuildItem(GrpcService.class));
beans.produce(new AdditionalBeanBuildItem(GrpcRequestContextCdiInterceptor.class));
beans.produce(new AdditionalBeanBuildItem(GrpcEnableRequestContext.class));

if (!bindables.isEmpty() || LaunchMode.current() == LaunchMode.DEVELOPMENT) {
beans.produce(AdditionalBeanBuildItem.unremovableOf(GrpcContainer.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
import io.quarkus.grpc.runtime.devmode.GrpcServerReloader;
import io.quarkus.grpc.runtime.health.GrpcHealthStorage;
import io.quarkus.grpc.runtime.reflection.ReflectionService;
import io.quarkus.grpc.runtime.supports.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.CompressionInterceptor;
import io.quarkus.grpc.runtime.supports.blocking.BlockingServerInterceptor;
import io.quarkus.grpc.runtime.supports.context.GrpcRequestContextGrpcInterceptor;
import io.quarkus.runtime.LaunchMode;
import io.quarkus.runtime.RuntimeValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
import grpc.health.v1.HealthOuterClass.HealthCheckResponse.ServingStatus;
import grpc.health.v1.MutinyHealthGrpc;
import io.quarkus.grpc.GrpcService;
import io.quarkus.grpc.runtime.supports.context.GrpcEnableRequestContext;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.operators.multi.processors.BroadcastProcessor;

// Note that we need to add the scope and interceptor binding explicitly because this class is not part of the index
// Note that we need to add the scope explicitly because this class is not part of the index
@Singleton
@GrpcEnableRequestContext
@GrpcService
public class GrpcHealthEndpoint extends MutinyHealthGrpc.HealthImplBase {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ public static Channel createChannel(String name) throws SSLException {
GrpcClientConfiguration config = configProvider.getConfiguration(name);

if (config == null && LaunchMode.current() == LaunchMode.TEST) {
LOGGER.infof(
"gRPC client %s created without configuration. We are assuming that it's created to test your gRPC services.",
name);
config = testConfig(configProvider.getServerConfiguration());
}

Expand Down Expand Up @@ -164,7 +167,6 @@ public static Channel createChannel(String name) throws SSLException {
}

private static GrpcClientConfiguration testConfig(GrpcServerConfiguration serverConfiguration) {
LOGGER.info("gRPC client created without configuration. We are assuming that it's created to test your gRPC services.");
GrpcClientConfiguration config = new GrpcClientConfiguration();
config.port = serverConfiguration.testPort;
config.host = serverConfiguration.host;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.function.Consumer;

import io.grpc.Context;
import io.grpc.ServerCall;
import io.quarkus.arc.InjectableContext;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Handler;
import io.vertx.core.Promise;

class BlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {
private final ServerCall.Listener<ReqT> delegate;
private final Context grpcContext;
private final Consumer<ServerCall.Listener<ReqT>> consumer;
private final InjectableContext.ContextState state;
private final ManagedContext requestContext;

public BlockingExecutionHandler(Consumer<ServerCall.Listener<ReqT>> consumer, Context grpcContext,
ServerCall.Listener<ReqT> delegate, InjectableContext.ContextState state,
ManagedContext requestContext) {
this.consumer = consumer;
this.grpcContext = grpcContext;
this.delegate = delegate;
this.state = state;
this.requestContext = requestContext;
}

@Override
public void handle(Promise<Object> event) {
final Context previous = Context.current();
grpcContext.attach();
try {
requestContext.activate(state);
try {
consumer.accept(delegate);
} finally {
requestContext.deactivate();
}
event.complete();
} finally {
grpcContext.detach(previous);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package io.quarkus.grpc.runtime.supports;
package io.quarkus.grpc.runtime.supports.blocking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand All @@ -13,12 +12,15 @@
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.quarkus.arc.Arc;
import io.quarkus.arc.InjectableContext.ContextState;
import io.quarkus.arc.ManagedContext;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;

/**
* gRPC Server interceptor offloading the execution of the gRPC method on a wroker thread if the method is annotated
* gRPC Server interceptor offloading the execution of the gRPC method on a worker thread if the method is annotated
* with {@link io.smallrye.common.annotation.Blocking}.
*
* For non-annotated methods, the interceptor acts as a pass-through.
Expand Down Expand Up @@ -63,12 +65,19 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re

if (isBlocking) {
ReplayListener<ReqT> replay = new ReplayListener<>();

final ManagedContext requestContext = getRequestContext();
ContextState state = requestContext.getState();
vertx.executeBlocking(new Handler<Promise<Object>>() {
@Override
public void handle(Promise<Object> f) {
ServerCall.Listener<ReqT> listener = next.startCall(call, headers);
replay.setDelegate(listener);
ServerCall.Listener<ReqT> listener;
try {
requestContext.activate(state);
listener = next.startCall(call, headers);
} finally {
requestContext.deactivate();
}
replay.setDelegate(listener, requestContext);
f.complete(null);
}
}, null);
Expand All @@ -87,28 +96,40 @@ public void handle(Promise<Object> f) {
*/
private class ReplayListener<ReqT> extends ServerCall.Listener<ReqT> {
private ServerCall.Listener<ReqT> delegate;
private final List<Consumer<ServerCall.Listener<ReqT>>> incomingEvents = new LinkedList<>();
private final List<ConsumerWithState> incomingEvents = new ArrayList<>();

synchronized void setDelegate(ServerCall.Listener<ReqT> delegate) {
synchronized void setDelegate(ServerCall.Listener<ReqT> delegate,
ManagedContext requestContext) {
this.delegate = delegate;
for (Consumer<ServerCall.Listener<ReqT>> event : incomingEvents) {
event.accept(delegate);
for (ConsumerWithState event : incomingEvents) {
requestContext.activate(event.contextState);
try {
event.listenerConsumer.accept(delegate);
} finally {
requestContext.deactivate();
}
}
incomingEvents.clear();
}

private synchronized void executeOnContextOrEnqueue(Consumer<ServerCall.Listener<ReqT>> consumer) {
ContextState state = getRequestContext().getState();
if (this.delegate != null) {
final Context grpcContext = Context.current();
Handler<Promise<Object>> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate);
if (devMode) {
blockingHandler = new DevModeBlockingExecutionHandler<ReqT>(Thread.currentThread().getContextClassLoader(),
blockingHandler);
}
vertx.executeBlocking(blockingHandler, true, null);
executeBlockingWithRequestContext(consumer, state);
} else {
incomingEvents.add(consumer);
incomingEvents.add(new ConsumerWithState(consumer, state));
}
}

private void executeBlockingWithRequestContext(Consumer<ServerCall.Listener<ReqT>> consumer, ContextState state) {
final Context grpcContext = Context.current();
Handler<Promise<Object>> blockingHandler = new BlockingExecutionHandler<>(consumer, grpcContext, delegate,
state, getRequestContext());
if (devMode) {
blockingHandler = new DevModeBlockingExecutionHandler(Thread.currentThread().getContextClassLoader(),
blockingHandler);
}
vertx.executeBlocking(blockingHandler, true, null);
}

@Override
Expand Down Expand Up @@ -140,52 +161,20 @@ public void onComplete() {
public void onReady() {
executeOnContextOrEnqueue(ServerCall.Listener::onReady);
}
}

private static class DevModeBlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {

final ClassLoader tccl;
final Handler<Promise<Object>> delegate;
private class ConsumerWithState {
Consumer<ServerCall.Listener<ReqT>> listenerConsumer;
ContextState contextState;

public DevModeBlockingExecutionHandler(ClassLoader tccl, Handler<Promise<Object>> delegate) {
this.tccl = tccl;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
ClassLoader originalTccl = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(tccl);
try {
delegate.handle(event);
} finally {
Thread.currentThread().setContextClassLoader(originalTccl);
public ConsumerWithState(Consumer<ServerCall.Listener<ReqT>> listenerConsumer, ContextState contextState) {
this.listenerConsumer = listenerConsumer;
this.contextState = contextState;
}
}
}

private static class BlockingExecutionHandler<ReqT> implements Handler<Promise<Object>> {
private final ServerCall.Listener<ReqT> delegate;
private final Context grpcContext;
private final Consumer<ServerCall.Listener<ReqT>> consumer;

public BlockingExecutionHandler(Consumer<ServerCall.Listener<ReqT>> consumer, Context grpcContext,
ServerCall.Listener<ReqT> delegate) {
this.consumer = consumer;
this.grpcContext = grpcContext;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
final Context previous = Context.current();
grpcContext.attach();
try {
consumer.accept(delegate);
event.complete();
} finally {
grpcContext.detach(previous);
}
}
// protected for tests
protected ManagedContext getRequestContext() {
return Arc.container().requestContext();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.quarkus.grpc.runtime.supports.blocking;

import io.vertx.core.Handler;
import io.vertx.core.Promise;

class DevModeBlockingExecutionHandler implements Handler<Promise<Object>> {

final ClassLoader tccl;
final Handler<Promise<Object>> delegate;

public DevModeBlockingExecutionHandler(ClassLoader tccl, Handler<Promise<Object>> delegate) {
this.tccl = tccl;
this.delegate = delegate;
}

@Override
public void handle(Promise<Object> event) {
ClassLoader originalTccl = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(tccl);
try {
delegate.handle(event);
} finally {
Thread.currentThread().setContextClassLoader(originalTccl);
}
}
}

This file was deleted.

Loading

0 comments on commit ddc3e3c

Please sign in to comment.