Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gRPC Dev UI - support streaming calls #17442

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.quarkus.dev.testing;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

public class GrpcWebSocketProxy {

private static final AtomicInteger connectionIdSeq = new AtomicInteger();

private static volatile WebSocketListener webSocketListener;

private static final Map<Integer, Consumer<Runnable>> webSocketConnections = new ConcurrentHashMap<>();

public static Integer addWebSocket(Consumer<String> responseConsumer,
Consumer<Runnable> closeHandler) {
if (webSocketListener != null) {
int id = connectionIdSeq.getAndIncrement();
webSocketListener.onOpen(id, responseConsumer);

webSocketConnections.put(id, closeHandler);
return id;
}
return null;
}

public static void closeAll() {
CountDownLatch latch = new CountDownLatch(webSocketConnections.size());
for (Map.Entry<Integer, Consumer<Runnable>> connection : webSocketConnections.entrySet()) {
connection.getValue().accept(latch::countDown);
webSocketListener.onClose(connection.getKey());
}
try {
if (!latch.await(5, TimeUnit.SECONDS)) {
System.err.println("Failed to close all the websockets in 5 seconds");
}
} catch (InterruptedException e) {
System.err.println("Interrupted while waiting for websockets to be closed");
}
}

public static void closeWebSocket(int id) {
webSocketListener.onClose(id);
}

public static void setWebSocketListener(WebSocketListener listener) {
webSocketListener = listener;
}

public static void addMessage(Integer socketId, String message) {
webSocketListener.newMessage(socketId, message);
}

public interface WebSocketListener {
void onOpen(int id, Consumer<String> responseConsumer);

void newMessage(int id, String content);

void onClose(int id);
}
}
7 changes: 6 additions & 1 deletion extensions/grpc/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@
<artifactId>quarkus-grpc-codegen</artifactId>
</dependency>

<!-- Test dependencies -->
<!-- for dev mode -->
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-vertx-http-deployment</artifactId>
</dependency>

<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-deployment</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ ServiceStartBuildItem initializeServer(GrpcServerRecorder recorder, GrpcConfigur
}
}

if (!bindables.isEmpty()) {
if (!bindables.isEmpty() || LaunchMode.current() == LaunchMode.DEVELOPMENT) {
recorder.initializeGrpcServer(vertx.getVertx(), config, shutdown, blocking, launchModeBuildItem.getLaunchMode());
return new ServiceStartBuildItem(GRPC_SERVER);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.quarkus.grpc.deployment.devmode;

import static io.quarkus.deployment.annotations.ExecutionTime.RUNTIME_INIT;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
Expand All @@ -17,44 +15,38 @@
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.logging.Logger;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.Message.Builder;
import com.google.protobuf.MessageOrBuilder;
import com.google.protobuf.util.JsonFormat;

import io.grpc.Channel;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.PrototypeMarshaller;
import io.grpc.ServiceDescriptor;
import io.grpc.netty.NettyChannelBuilder;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.runtime.BeanLookupSupplier;
import io.quarkus.deployment.IsDevelopment;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.Consume;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.deployment.builditem.RuntimeConfigSetupCompleteBuildItem;
import io.quarkus.deployment.builditem.ServiceStartBuildItem;
import io.quarkus.dev.console.DevConsoleManager;
import io.quarkus.devconsole.spi.DevConsoleRouteBuildItem;
import io.quarkus.dev.testing.GrpcWebSocketProxy;
import io.quarkus.devconsole.spi.DevConsoleRuntimeTemplateInfoBuildItem;
import io.quarkus.grpc.deployment.GrpcDotNames;
import io.quarkus.grpc.protoc.plugin.MutinyGrpcGenerator;
import io.quarkus.grpc.runtime.devmode.GrpcDevConsoleRecorder;
import io.quarkus.grpc.runtime.devmode.GrpcServices;
import io.vertx.core.Handler;
import io.vertx.ext.web.RoutingContext;
import io.quarkus.vertx.http.deployment.NonApplicationRootPathBuildItem;
import io.quarkus.vertx.http.deployment.RouteBuildItem;

public class GrpcDevConsoleProcessor {

private static final Logger LOG = Logger.getLogger(GrpcDevConsoleProcessor.class);

@BuildStep(onlyIf = IsDevelopment.class)
public void devConsoleInfo(BuildProducer<AdditionalBeanBuildItem> beans,
BuildProducer<DevConsoleRuntimeTemplateInfoBuildItem> infos) {
Expand All @@ -72,7 +64,8 @@ public void collectMessagePrototypes(CombinedIndexBuildItem index,
IllegalArgumentException, InvocationTargetException, InvalidProtocolBufferException {
Map<String, String> messagePrototypes = new HashMap<>();

for (Class<?> grpcServiceClass : getGrpcServices(index.getIndex())) {
Collection<Class<?>> grpcServices = getGrpcServices(index.getIndex());
for (Class<?> grpcServiceClass : grpcServices) {

Method method = grpcServiceClass.getDeclaredMethod("getServiceDescriptor");
ServiceDescriptor serviceDescriptor = (ServiceDescriptor) method.invoke(null);
Expand All @@ -89,145 +82,18 @@ public void collectMessagePrototypes(CombinedIndexBuildItem index,
}
DevConsoleManager.setGlobal("io.quarkus.grpc.messagePrototypes", messagePrototypes);

GrpcWebSocketProxy.setWebSocketListener(
new GrpcDevConsoleWebSocketListener(grpcServices, Thread.currentThread().getContextClassLoader()));
}

@Consume(RuntimeConfigSetupCompleteBuildItem.class)
@Record(value = RUNTIME_INIT)
@Record(ExecutionTime.RUNTIME_INIT)
@BuildStep(onlyIf = IsDevelopment.class)
DevConsoleRouteBuildItem registerTestEndpoint(GrpcDevConsoleRecorder recorder, CombinedIndexBuildItem index)
throws ClassNotFoundException, NoSuchMethodException,
SecurityException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
// Store the server config so that it can be used in the test endpoint handler
public RouteBuildItem createWebSocketEndpoint(NonApplicationRootPathBuildItem nonApplicationRootPathBuildItem,
GrpcDevConsoleRecorder recorder) {
recorder.setServerConfiguration();
return new DevConsoleRouteBuildItem("test", "POST", new TestEndpointHandler(getGrpcServices(index.getIndex())), true);
}

static class TestEndpointHandler implements Handler<RoutingContext> {

private Map<String, Object> blockingStubs;
private Map<String, ServiceDescriptor> serviceDescriptors;
private final Collection<Class<?>> grpcServiceClasses;

TestEndpointHandler(Collection<Class<?>> grpcServiceClasses) {
this.grpcServiceClasses = grpcServiceClasses;
}

void init() throws NoSuchMethodException, SecurityException, IllegalAccessException, IllegalArgumentException,
InvocationTargetException {
if (blockingStubs == null) {
blockingStubs = new HashMap<>();
serviceDescriptors = new HashMap<>();

Map<String, Object> serverConfig = DevConsoleManager.getGlobal("io.quarkus.grpc.serverConfig");

if (Boolean.FALSE.equals(serverConfig.get("ssl"))) {
for (Class<?> grpcServiceClass : grpcServiceClasses) {

Method method = grpcServiceClass.getDeclaredMethod("getServiceDescriptor");
ServiceDescriptor serviceDescriptor = (ServiceDescriptor) method.invoke(null);
serviceDescriptors.put(serviceDescriptor.getName(), serviceDescriptor);

// TODO more config options
Channel channel = NettyChannelBuilder
.forAddress(serverConfig.get("host").toString(), (Integer) serverConfig.get("port"))
.usePlaintext()
.build();
Method blockingStubFactoryMethod;

try {
blockingStubFactoryMethod = grpcServiceClass.getDeclaredMethod("newBlockingStub", Channel.class);
} catch (NoSuchMethodException e) {
LOG.warnf("Ignoring gRPC service - newBlockingStub() method not declared on %s", grpcServiceClass);
continue;
}

Object blockingStub = blockingStubFactoryMethod.invoke(null, channel);
blockingStubs.put(serviceDescriptor.getName(), blockingStub);
}
}
}
}

@Override
public void handle(RoutingContext context) {
try {
// Lazily initialize the handler
init();
} catch (Exception e) {
throw new IllegalStateException("Unable to initialize the test endpoint handler");
}

String serviceName = context.request().getParam("serviceName");
String methodName = context.request().getParam("methodName");
String testJsonData = context.getBodyAsString();

Object blockingStub = blockingStubs.get(serviceName);

if (blockingStub == null) {
error(context, "No blocking stub found for: " + serviceName);
} else {
ServiceDescriptor serviceDescriptor = serviceDescriptors.get(serviceName);
MethodDescriptor<?, ?> methodDescriptor = null;
for (MethodDescriptor<?, ?> method : serviceDescriptor.getMethods()) {
if (method.getBareMethodName().equals(methodName)) {
methodDescriptor = method;
}
}

if (methodDescriptor == null) {
error(context, "No method descriptor found for: " + serviceName + "/" + methodName);
} else {

// We need to find the correct method declared on the blocking stub
Method stubMethod = null;
String realMethodName = decapitalize(methodDescriptor.getBareMethodName());

for (Method method : blockingStub.getClass().getDeclaredMethods()) {
if (method.getName().equals(realMethodName)) {
stubMethod = method;
}
}

if (stubMethod == null) {
error(context, realMethodName + " method not declared on the " + blockingStub.getClass());
} else {

// Identify the request class
Marshaller<?> requestMarshaller = methodDescriptor.getRequestMarshaller();
if (requestMarshaller instanceof PrototypeMarshaller) {
PrototypeMarshaller<?> protoMarshaller = (PrototypeMarshaller<?>) requestMarshaller;
Class<?> requestType = protoMarshaller.getMessagePrototype().getClass();

try {
// Create a new builder for the request message, e.g. HelloRequest.newBuilder()
Method newBuilderMethod = requestType.getDeclaredMethod("newBuilder");
Message.Builder builder = (Builder) newBuilderMethod.invoke(null);
;

// Use the test data to build the request object
JsonFormat.parser().merge(testJsonData, builder);

// Invoke the blocking stub method and format the response as JSON
Object response = stubMethod.invoke(blockingStub, builder.build());
context.response().putHeader("Content-Type", "application/json");
context.end(JsonFormat.printer().print((MessageOrBuilder) response));

} catch (Exception e) {
throw new IllegalStateException(e);
}
} else {
error(context, "Unable to identify the request type for: " + methodDescriptor);
}
}
}
}

}
}

static void error(RoutingContext rc, String message) {
LOG.warn(message);
rc.response().setStatusCode(500).end(message);
return nonApplicationRootPathBuildItem.routeBuilder().route("dev/grpc-test")
.handler(recorder.handler()).build();
}

Collection<Class<?>> getGrpcServices(IndexView index) throws ClassNotFoundException {
Expand Down Expand Up @@ -256,18 +122,4 @@ Collection<Class<?>> getGrpcServices(IndexView index) throws ClassNotFoundExcept
}
return serviceClasses;
}

static String decapitalize(String name) {
if (name == null || name.length() == 0) {
return name;
}
if (name.length() > 1 && Character.isUpperCase(name.charAt(1)) &&
Character.isUpperCase(name.charAt(0))) {
return name;
}
char chars[] = name.toCharArray();
chars[0] = Character.toLowerCase(chars[0]);
return new String(chars);
}

}
Loading