Skip to content

Commit

Permalink
gRPC Dev UI - support streaming calls
Browse files Browse the repository at this point in the history
  • Loading branch information
michalszynkiewicz committed May 25, 2021
1 parent b54f1ca commit 3666102
Show file tree
Hide file tree
Showing 13 changed files with 680 additions and 220 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.quarkus.dev.testing;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

public class GrpcWebSocketProxy {

private static final AtomicInteger connectionIdSeq = new AtomicInteger();

private static volatile WebSocketListener webSocketListener;

private static final Map<Integer, Consumer<Runnable>> webSocketConnections = new ConcurrentHashMap<>();

public static Integer addWebSocket(Consumer<String> responseConsumer,
Consumer<Runnable> closeHandler) {
if (webSocketListener != null) {
int id = connectionIdSeq.getAndIncrement();
webSocketListener.onOpen(id, responseConsumer);

webSocketConnections.put(id, closeHandler);
return id;
}
return null;
}

public static void closeAll() {
CountDownLatch latch = new CountDownLatch(webSocketConnections.size());
for (Map.Entry<Integer, Consumer<Runnable>> connection : webSocketConnections.entrySet()) {
connection.getValue().accept(latch::countDown);
webSocketListener.onClose(connection.getKey());
}
try {
if (!latch.await(5, TimeUnit.SECONDS)) {
System.err.println("Failed to close all the websockets in 5 seconds");
}
} catch (InterruptedException e) {
System.err.println("Interrupted while waiting for websockets to be closed");
}
}

public static void closeWebSocket(int id) {
webSocketListener.onClose(id);
}

public static void setWebSocketListener(WebSocketListener listener) {
webSocketListener = listener;
}

public static void addMessage(Integer socketId, String message) {
webSocketListener.newMessage(socketId, message);
}

public interface WebSocketListener {
void onOpen(int id, Consumer<String> responseConsumer);

void newMessage(int id, String content);

void onClose(int id);
}
}
6 changes: 5 additions & 1 deletion extensions/grpc/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@
<artifactId>quarkus-grpc-codegen</artifactId>
</dependency>

<!-- Test dependencies -->
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx-http-deployment</artifactId>
</dependency>

<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-deployment</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.quarkus.grpc.deployment.devmode;

import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
Expand All @@ -17,42 +15,36 @@
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.logging.Logger;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.util.JsonFormat;

import io.grpc.Channel;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.PrototypeMarshaller;
import io.grpc.ServiceDescriptor;
import io.grpc.netty.NettyChannelBuilder;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.runtime.BeanLookupSupplier;
import io.quarkus.deployment.IsDevelopment;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.ServiceStartBuildItem;
import io.quarkus.dev.console.DevConsoleManager;
import io.quarkus.devconsole.spi.DevConsoleRouteBuildItem;
import io.quarkus.dev.testing.GrpcWebSocketProxy;
import io.quarkus.devconsole.spi.DevConsoleRuntimeTemplateInfoBuildItem;
import io.quarkus.grpc.deployment.GrpcDotNames;
import io.quarkus.grpc.protoc.plugin.MutinyGrpcGenerator;
import io.quarkus.grpc.runtime.devmode.GrpcDevConsoleRecorder;
import io.quarkus.grpc.runtime.devmode.GrpcServices;
import io.vertx.core.Handler;
import io.vertx.ext.web.RoutingContext;
import io.quarkus.vertx.http.deployment.NonApplicationRootPathBuildItem;
import io.quarkus.vertx.http.deployment.RouteBuildItem;

public class GrpcDevConsoleProcessor {

private static final Logger LOG = Logger.getLogger(GrpcDevConsoleProcessor.class);

@BuildStep(onlyIf = IsDevelopment.class)
public void devConsoleInfo(BuildProducer<AdditionalBeanBuildItem> beans,
BuildProducer<DevConsoleRuntimeTemplateInfoBuildItem> infos) {
Expand All @@ -70,7 +62,8 @@ public void collectMessagePrototypes(CombinedIndexBuildItem index,
IllegalArgumentException, InvocationTargetException, InvalidProtocolBufferException {
Map<String, String> messagePrototypes = new HashMap<>();

for (Class<?> grpcServiceClass : getGrpcServices(index.getIndex())) {
Collection<Class<?>> grpcServices = getGrpcServices(index.getIndex());
for (Class<?> grpcServiceClass : grpcServices) {

Method method = grpcServiceClass.getDeclaredMethod("getServiceDescriptor");
ServiceDescriptor serviceDescriptor = (ServiceDescriptor) method.invoke(null);
Expand All @@ -87,144 +80,17 @@ public void collectMessagePrototypes(CombinedIndexBuildItem index,
}
DevConsoleManager.setGlobal("io.quarkus.grpc.messagePrototypes", messagePrototypes);

GrpcWebSocketProxy.setWebSocketListener(
new GrpcDevConsoleWebSocketListener(grpcServices, Thread.currentThread().getContextClassLoader()));
}

@Record(value = RUNTIME_INIT)
@BuildStep
DevConsoleRouteBuildItem registerTestEndpoint(GrpcDevConsoleRecorder recorder, CombinedIndexBuildItem index)
throws ClassNotFoundException, NoSuchMethodException,
SecurityException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
// Store the server config so that it can be used in the test endpoint handler
@BuildStep(onlyIf = IsDevelopment.class)
@Record(ExecutionTime.RUNTIME_INIT)
public RouteBuildItem createWebSocketEndpoint(NonApplicationRootPathBuildItem nonApplicationRootPathBuildItem,
GrpcDevConsoleRecorder recorder) {
recorder.setServerConfiguration();
return new DevConsoleRouteBuildItem("test", "POST", new TestEndpointHandler(getGrpcServices(index.getIndex())), true);
}

static class TestEndpointHandler implements Handler<RoutingContext> {

private Map<String, Object> blockingStubs;
private Map<String, ServiceDescriptor> serviceDescriptors;
private final Collection<Class<?>> grpcServiceClasses;

TestEndpointHandler(Collection<Class<?>> grpcServiceClasses) {
this.grpcServiceClasses = grpcServiceClasses;
}

void init() throws NoSuchMethodException, SecurityException, IllegalAccessException, IllegalArgumentException,
InvocationTargetException {
if (blockingStubs == null) {
blockingStubs = new HashMap<>();
serviceDescriptors = new HashMap<>();

Map<String, Object> serverConfig = DevConsoleManager.getGlobal("io.quarkus.grpc.serverConfig");

if (Boolean.FALSE.equals(serverConfig.get("ssl"))) {
for (Class<?> grpcServiceClass : grpcServiceClasses) {

Method method = grpcServiceClass.getDeclaredMethod("getServiceDescriptor");
ServiceDescriptor serviceDescriptor = (ServiceDescriptor) method.invoke(null);
serviceDescriptors.put(serviceDescriptor.getName(), serviceDescriptor);

// TODO more config options
Channel channel = NettyChannelBuilder
.forAddress(serverConfig.get("host").toString(), (Integer) serverConfig.get("port"))
.usePlaintext()
.build();
Method blockingStubFactoryMethod;

try {
blockingStubFactoryMethod = grpcServiceClass.getDeclaredMethod("newBlockingStub", Channel.class);
} catch (NoSuchMethodException e) {
LOG.warnf("Ignoring gRPC service - newBlockingStub() method not declared on %s", grpcServiceClass);
continue;
}

Object blockingStub = blockingStubFactoryMethod.invoke(null, channel);
blockingStubs.put(serviceDescriptor.getName(), blockingStub);
}
}
}
}

@Override
public void handle(RoutingContext context) {
try {
// Lazily initialize the handler
init();
} catch (Exception e) {
throw new IllegalStateException("Unable to initialize the test endpoint handler");
}

String serviceName = context.request().getParam("serviceName");
String methodName = context.request().getParam("methodName");
String testJsonData = context.getBodyAsString();

Object blockingStub = blockingStubs.get(serviceName);

if (blockingStub == null) {
error(context, "No blocking stub found for: " + serviceName);
} else {
ServiceDescriptor serviceDescriptor = serviceDescriptors.get(serviceName);
MethodDescriptor<?, ?> methodDescriptor = null;
for (MethodDescriptor<?, ?> method : serviceDescriptor.getMethods()) {
if (method.getBareMethodName().equals(methodName)) {
methodDescriptor = method;
}
}

if (methodDescriptor == null) {
error(context, "No method descriptor found for: " + serviceName + "/" + methodName);
} else {

// We need to find the correct method declared on the blocking stub
Method stubMethod = null;
String realMethodName = decapitalize(methodDescriptor.getBareMethodName());

for (Method method : blockingStub.getClass().getDeclaredMethods()) {
if (method.getName().equals(realMethodName)) {
stubMethod = method;
}
}

if (stubMethod == null) {
error(context, realMethodName + " method not declared on the " + blockingStub.getClass());
} else {

// Identify the request class
Marshaller<?> requestMarshaller = methodDescriptor.getRequestMarshaller();
if (requestMarshaller instanceof PrototypeMarshaller) {
PrototypeMarshaller<?> protoMarshaller = (PrototypeMarshaller<?>) requestMarshaller;
Class<?> requestType = protoMarshaller.getMessagePrototype().getClass();

try {
// Create a new builder for the request message, e.g. HelloRequest.newBuilder()
Method newBuilderMethod = requestType.getDeclaredMethod("newBuilder");
Message.Builder builder = (Builder) newBuilderMethod.invoke(null);
;

// Use the test data to build the request object
JsonFormat.parser().merge(testJsonData, builder);

// Invoke the blocking stub method and format the response as JSON
Object response = stubMethod.invoke(blockingStub, builder.build());
context.response().putHeader("Content-Type", "application/json");
context.end(JsonFormat.printer().print((MessageOrBuilder) response));

} catch (Exception e) {
throw new IllegalStateException(e);
}
} else {
error(context, "Unable to identify the request type for: " + methodDescriptor);
}
}
}
}

}
}

static void error(RoutingContext rc, String message) {
LOG.warn(message);
rc.response().setStatusCode(500).end(message);
return nonApplicationRootPathBuildItem.routeBuilder().route("dev/grpc-test")
.handler(recorder.handler()).build();
}

Collection<Class<?>> getGrpcServices(IndexView index) throws ClassNotFoundException {
Expand Down Expand Up @@ -253,18 +119,4 @@ Collection<Class<?>> getGrpcServices(IndexView index) throws ClassNotFoundExcept
}
return serviceClasses;
}

static String decapitalize(String name) {
if (name == null || name.length() == 0) {
return name;
}
if (name.length() > 1 && Character.isUpperCase(name.charAt(1)) &&
Character.isUpperCase(name.charAt(0))) {
return name;
}
char chars[] = name.toCharArray();
chars[0] = Character.toLowerCase(chars[0]);
return new String(chars);
}

}
Loading

0 comments on commit 3666102

Please sign in to comment.