diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java new file mode 100644 index 0000000000000..575230301ae2f --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java @@ -0,0 +1,322 @@ +package io.quarkus.websockets.next.deployment; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.Predicate; + +import org.jboss.jandex.AnnotationInstance; +import org.jboss.jandex.AnnotationValue; +import org.jboss.jandex.DotName; +import org.jboss.jandex.IndexView; +import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.MethodParameterInfo; +import org.jboss.jandex.Type; +import org.jboss.jandex.Type.Kind; + +import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; +import io.quarkus.arc.processor.Annotations; +import io.quarkus.arc.processor.DotNames; +import io.quarkus.gizmo.BytecodeCreator; +import io.quarkus.gizmo.FieldDescriptor; +import io.quarkus.gizmo.ResultHandle; +import io.quarkus.websockets.next.WebSocketException; +import io.quarkus.websockets.next.deployment.CallbackArgument.InvocationBytecodeContext; +import io.quarkus.websockets.next.deployment.CallbackArgument.ParameterContext; +import io.quarkus.websockets.next.runtime.WebSocketConnectionBase; +import io.quarkus.websockets.next.runtime.WebSocketEndpoint.ExecutionModel; +import io.quarkus.websockets.next.runtime.WebSocketEndpointBase; + +/** + * Represents either an endpoint callback or a global error handler. + */ +public class Callback { + + public final Target target; + public final String endpointPath; + public final AnnotationInstance annotation; + public final MethodInfo method; + public final ExecutionModel executionModel; + public final MessageType messageType; + public final List arguments; + + public Callback(Target target, AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + String endpointPath, IndexView index) { + this.target = target; + this.method = method; + this.annotation = annotation; + this.executionModel = executionModel; + if (WebSocketDotNames.ON_BINARY_MESSAGE.equals(annotation.name())) { + this.messageType = MessageType.BINARY; + } else if (WebSocketDotNames.ON_TEXT_MESSAGE.equals(annotation.name())) { + this.messageType = MessageType.TEXT; + } else if (WebSocketDotNames.ON_PONG_MESSAGE.equals(annotation.name())) { + this.messageType = MessageType.PONG; + } else { + this.messageType = MessageType.NONE; + } + this.endpointPath = endpointPath; + this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, index); + } + + public boolean isGlobal() { + return endpointPath == null; + } + + public boolean isClient() { + return target == Target.CLIENT; + } + + public boolean isServer() { + return target == Target.SERVER; + } + + public boolean isOnOpen() { + return annotation.name().equals(WebSocketDotNames.ON_OPEN); + } + + public boolean isOnClose() { + return annotation.name().equals(WebSocketDotNames.ON_CLOSE); + } + + public boolean isOnError() { + return annotation.name().equals(WebSocketDotNames.ON_ERROR); + } + + public Type returnType() { + return method.returnType(); + } + + public Type messageParamType() { + return acceptsMessage() ? method.parameterType(0) : null; + } + + public boolean isReturnTypeVoid() { + return returnType().kind() == Kind.VOID; + } + + public boolean isReturnTypeUni() { + return WebSocketDotNames.UNI.equals(returnType().name()); + } + + public boolean isReturnTypeMulti() { + return WebSocketDotNames.MULTI.equals(returnType().name()); + } + + public boolean acceptsMessage() { + return messageType != MessageType.NONE; + } + + public boolean acceptsBinaryMessage() { + return messageType == MessageType.BINARY || messageType == MessageType.PONG; + } + + public boolean acceptsMulti() { + return acceptsMessage() && method.parameterType(0).name().equals(WebSocketDotNames.MULTI); + } + + public Callback.MessageType messageType() { + return messageType; + } + + public boolean broadcast() { + AnnotationValue broadcastValue = annotation.value("broadcast"); + return broadcastValue != null && broadcastValue.asBoolean(); + } + + public DotName getInputCodec() { + return getCodec("codec"); + } + + public DotName getOutputCodec() { + DotName output = getCodec("outputCodec"); + return output != null ? output : getInputCodec(); + } + + public String asString() { + return method.declaringClass().name() + "#" + method.name() + "()"; + } + + private DotName getCodec(String valueName) { + AnnotationValue codecValue = annotation.value(valueName); + if (codecValue != null) { + return codecValue.asClass().name(); + } + return null; + } + + public enum MessageType { + NONE, + PONG, + TEXT, + BINARY + } + + public enum Target { + CLIENT, + SERVER, + UNDEFINED + } + + public ResultHandle[] generateArguments(ResultHandle endpointThis, BytecodeCreator bytecode, + TransformedAnnotationsBuildItem transformedAnnotations, IndexView index) { + if (arguments.isEmpty()) { + return new ResultHandle[] {}; + } + ResultHandle[] resultHandles = new ResultHandle[arguments.size()]; + int idx = 0; + for (CallbackArgument argument : arguments) { + resultHandles[idx] = argument.get( + invocationBytecodeContext(annotation, method.parameters().get(idx), transformedAnnotations, index, + endpointThis, bytecode)); + idx++; + } + return resultHandles; + } + + private List collectArguments(AnnotationInstance annotation, MethodInfo method, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, + IndexView index) { + List parameters = method.parameters(); + if (parameters.isEmpty()) { + return List.of(); + } + List arguments = new ArrayList<>(parameters.size()); + for (MethodParameterInfo parameter : parameters) { + List found = callbackArguments + .findMatching(parameterContext(annotation, parameter, transformedAnnotations, index)); + if (found.isEmpty()) { + String msg = String.format("Unable to inject @%s callback parameter '%s' declared on %s: no injector found", + DotNames.simpleName(annotation.name()), + parameter.name() != null ? parameter.name() : "#" + parameter.position(), + asString()); + throw new WebSocketException(msg); + } else if (found.size() > 1 && (found.get(0).priotity() == found.get(1).priotity())) { + String msg = String.format( + "Unable to inject @%s callback parameter '%s' declared on %s: ambiguous injectors found: %s", + DotNames.simpleName(annotation.name()), + parameter.name() != null ? parameter.name() : "#" + parameter.position(), + asString(), + found.stream().map(p -> p.getClass().getSimpleName() + ":" + p.priotity())); + throw new WebSocketException(msg); + } + arguments.add(found.get(0)); + } + return List.copyOf(arguments); + } + + Type argumentType(Predicate filter) { + for (int i = 0; i < arguments.size(); i++) { + if (filter.test(arguments.get(i))) { + return method.parameterType(i); + } + } + return null; + } + + private ParameterContext parameterContext(AnnotationInstance callbackAnnotation, MethodParameterInfo parameter, + TransformedAnnotationsBuildItem transformedAnnotations, IndexView index) { + return new ParameterContext() { + + @Override + public Target callbackTarget() { + return target; + } + + @Override + public MethodParameterInfo parameter() { + return parameter; + } + + @Override + public Set parameterAnnotations() { + return Annotations.getParameterAnnotations( + transformedAnnotations::getAnnotations, parameter.method(), parameter.position()); + } + + @Override + public AnnotationInstance callbackAnnotation() { + return callbackAnnotation; + } + + @Override + public String endpointPath() { + return endpointPath; + } + + @Override + public IndexView index() { + return index; + } + + }; + } + + private InvocationBytecodeContext invocationBytecodeContext(AnnotationInstance callbackAnnotation, + MethodParameterInfo parameter, TransformedAnnotationsBuildItem transformedAnnotations, IndexView index, + ResultHandle endpointThis, BytecodeCreator bytecode) { + return new InvocationBytecodeContext() { + + @Override + public Target callbackTarget() { + return target; + } + + @Override + public AnnotationInstance callbackAnnotation() { + return callbackAnnotation; + } + + @Override + public MethodParameterInfo parameter() { + return parameter; + } + + @Override + public Set parameterAnnotations() { + return Annotations.getParameterAnnotations( + transformedAnnotations::getAnnotations, parameter.method(), parameter.position()); + } + + @Override + public String endpointPath() { + return endpointPath; + } + + @Override + public IndexView index() { + return index; + } + + @Override + public BytecodeCreator bytecode() { + return bytecode; + } + + @Override + public ResultHandle getPayload() { + return acceptsMessage() || callbackAnnotation.name().equals(WebSocketDotNames.ON_ERROR) + ? bytecode.getMethodParam(0) + : null; + } + + @Override + public ResultHandle getDecodedMessage(Type parameterType) { + return acceptsMessage() + ? WebSocketProcessor.decodeMessage(endpointThis, bytecode, acceptsBinaryMessage(), + parameterType, + getPayload(), Callback.this) + : null; + } + + @Override + public ResultHandle getConnection() { + return bytecode.readInstanceField( + FieldDescriptor.of(WebSocketEndpointBase.class, "connection", WebSocketConnectionBase.class), + endpointThis); + } + }; + } + +} diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java index 752cea1ed4f25..3a3899b80a8a3 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/CallbackArgument.java @@ -13,7 +13,8 @@ import io.quarkus.websockets.next.OnError; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.WebSocketConnection; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; +import io.quarkus.websockets.next.deployment.Callback.Target; /** * Provides arguments for method parameters of a callback method declared on a WebSocket endpoint. @@ -24,7 +25,7 @@ interface CallbackArgument { * * @param context * @return {@code true} if this provider matches the given parameter context, {@code false} otherwise - * @throws WebSocketServerException If an invalid parameter is detected + * @throws WebSocketException If an invalid parameter is detected */ boolean matches(ParameterContext context); @@ -49,6 +50,12 @@ default int priotity() { interface ParameterContext { + /** + * + * @return the callback target + */ + Target callbackTarget(); + /** * * @return the endpoint path or {@code null} for global error handlers diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionCallbackArgument.java index d6d16fc468430..607ae698c2aee 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionCallbackArgument.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/ConnectionCallbackArgument.java @@ -1,12 +1,30 @@ package io.quarkus.websockets.next.deployment; +import org.jboss.jandex.DotName; + import io.quarkus.gizmo.ResultHandle; +import io.quarkus.websockets.next.WebSocketException; +import io.quarkus.websockets.next.deployment.Callback.Target; class ConnectionCallbackArgument implements CallbackArgument { @Override public boolean matches(ParameterContext context) { - return context.parameter().type().name().equals(WebSocketDotNames.WEB_SOCKET_CONNECTION); + DotName paramTypeName = context.parameter().type().name(); + if (context.callbackTarget() == Target.SERVER) { + if (WebSocketDotNames.WEB_SOCKET_CONNECTION.equals(paramTypeName)) { + return true; + } else if (WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION.equals(paramTypeName)) { + throw new WebSocketException("@WebSocket callback method may not accept WebSocketClientConnection"); + } + } else if (context.callbackTarget() == Target.CLIENT) { + if (WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION.equals(paramTypeName)) { + return true; + } else if (WebSocketDotNames.WEB_SOCKET_CONNECTION.equals(paramTypeName)) { + throw new WebSocketException("@WebSocketClient callback method may not accept WebSocketConnection"); + } + } + return false; } @Override diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GeneratedEndpointBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GeneratedEndpointBuildItem.java index 0c65f2320a52f..ec69bab43d6e7 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GeneratedEndpointBuildItem.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GeneratedEndpointBuildItem.java @@ -3,7 +3,7 @@ import io.quarkus.builder.item.MultiBuildItem; /** - * A generated representation of a {@link io.quarkus.websockets.next.runtime.WebSocketEndpoint}. + * A generated representation of a WebSocket endpoint. */ public final class GeneratedEndpointBuildItem extends MultiBuildItem { @@ -11,12 +11,39 @@ public final class GeneratedEndpointBuildItem extends MultiBuildItem { public final String endpointClassName; public final String generatedClassName; public final String path; + public final boolean isClient; - GeneratedEndpointBuildItem(String endpointId, String endpointClassName, String generatedClassName, String path) { + GeneratedEndpointBuildItem(String endpointId, String endpointClassName, String generatedClassName, String path, + boolean isClient) { this.endpointId = endpointId; this.endpointClassName = endpointClassName; this.generatedClassName = generatedClassName; this.path = path; + this.isClient = isClient; + } + + public boolean isServer() { + return !isClient; + } + + public boolean isClient() { + return isClient; + } + + public String getEndpointId() { + return endpointId; + } + + public String getEndpointClassName() { + return endpointClassName; + } + + public String getGeneratedClassName() { + return generatedClassName; + } + + public String getPath() { + return path; } } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GlobalErrorHandlersBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GlobalErrorHandlersBuildItem.java index 723610606ae1d..c2b91420f5369 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GlobalErrorHandlersBuildItem.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/GlobalErrorHandlersBuildItem.java @@ -3,7 +3,8 @@ import java.util.List; import io.quarkus.builder.item.SimpleBuildItem; -import io.quarkus.websockets.next.deployment.WebSocketServerProcessor.GlobalErrorHandler; +import io.quarkus.websockets.next.deployment.Callback.Target; +import io.quarkus.websockets.next.deployment.WebSocketProcessor.GlobalErrorHandler; final class GlobalErrorHandlersBuildItem extends SimpleBuildItem { @@ -13,4 +14,11 @@ final class GlobalErrorHandlersBuildItem extends SimpleBuildItem { this.handlers = handlers; } + List forServer() { + return handlers.stream().filter(h -> h.callback().isServer() || h.callback().target == Target.UNDEFINED).toList(); + } + + List forClient() { + return handlers.stream().filter(h -> h.callback().isClient() || h.callback().target == Target.UNDEFINED).toList(); + } } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/HandshakeRequestCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/HandshakeRequestCallbackArgument.java index 0e252ae9e26da..14cfedf39a1e3 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/HandshakeRequestCallbackArgument.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/HandshakeRequestCallbackArgument.java @@ -2,7 +2,8 @@ import io.quarkus.gizmo.MethodDescriptor; import io.quarkus.gizmo.ResultHandle; -import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.HandshakeRequest; +import io.quarkus.websockets.next.runtime.WebSocketConnectionBase; class HandshakeRequestCallbackArgument implements CallbackArgument { @@ -13,9 +14,9 @@ public boolean matches(ParameterContext context) { @Override public ResultHandle get(InvocationBytecodeContext context) { - ResultHandle connection = context.getConnection(); - return context.bytecode().invokeInterfaceMethod(MethodDescriptor.ofMethod(WebSocketConnection.class, "handshakeRequest", - WebSocketConnection.HandshakeRequest.class), connection); + return context.bytecode() + .invokeVirtualMethod(MethodDescriptor.ofMethod(WebSocketConnectionBase.class, "handshakeRequest", + HandshakeRequest.class), context.getConnection()); } } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/PathParamCallbackArgument.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/PathParamCallbackArgument.java index 9fb550fb9eeab..317daabc01c20 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/PathParamCallbackArgument.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/PathParamCallbackArgument.java @@ -10,8 +10,8 @@ import io.quarkus.arc.processor.Annotations; import io.quarkus.gizmo.MethodDescriptor; import io.quarkus.gizmo.ResultHandle; -import io.quarkus.websockets.next.WebSocketConnection; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; +import io.quarkus.websockets.next.runtime.WebSocketConnectionBase; class PathParamCallbackArgument implements CallbackArgument { @@ -20,20 +20,20 @@ public boolean matches(ParameterContext context) { String name = getParamName(context); if (name != null) { if (!context.parameter().type().name().equals(WebSocketDotNames.STRING)) { - throw new WebSocketServerException("Method parameter annotated with @PathParam must be java.lang.String: " - + WebSocketServerProcessor.callbackToString(context.parameter().method())); + throw new WebSocketException("Method parameter annotated with @PathParam must be java.lang.String: " + + WebSocketProcessor.methodToString(context.parameter().method())); } if (context.endpointPath() == null) { - throw new WebSocketServerException("Global error handlers may not accept @PathParam parameters: " - + WebSocketServerProcessor.callbackToString(context.parameter().method())); + throw new WebSocketException("Global error handlers may not accept @PathParam parameters: " + + WebSocketProcessor.methodToString(context.parameter().method())); } List pathParams = getPathParamNames(context.endpointPath()); if (!pathParams.contains(name)) { - throw new WebSocketServerException( + throw new WebSocketException( String.format( "@PathParam name [%s] must be used in the endpoint path [%s]: %s", name, context.endpointPath(), - WebSocketServerProcessor.callbackToString(context.parameter().method()))); + WebSocketProcessor.methodToString(context.parameter().method()))); } return true; } @@ -42,10 +42,10 @@ public boolean matches(ParameterContext context) { @Override public ResultHandle get(InvocationBytecodeContext context) { - ResultHandle connection = context.getConnection(); String paramName = getParamName(context); - return context.bytecode().invokeInterfaceMethod( - MethodDescriptor.ofMethod(WebSocketConnection.class, "pathParam", String.class, String.class), connection, + return context.bytecode().invokeVirtualMethod( + MethodDescriptor.ofMethod(WebSocketConnectionBase.class, "pathParam", String.class, String.class), + context.getConnection(), context.bytecode().load(paramName)); } @@ -61,7 +61,7 @@ private String getParamName(ParameterContext context) { name = context.parameter().name(); } if (name == null) { - throw new WebSocketServerException(String.format( + throw new WebSocketException(String.format( "Unable to extract the path parameter name - method parameter names not recorded for %s: compile the class with -parameters", context.parameter().method().declaringClass().name())); } @@ -72,7 +72,7 @@ private String getParamName(ParameterContext context) { static List getPathParamNames(String path) { List names = new ArrayList<>(); - Matcher m = WebSocketServerProcessor.TRANSLATED_PATH_PARAM_PATTERN.matcher(path); + Matcher m = WebSocketProcessor.TRANSLATED_PATH_PARAM_PATTERN.matcher(path); while (m.find()) { names.add(m.group().substring(1)); } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java index 98dfd77f50cad..311852514b82e 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketDotNames.java @@ -4,6 +4,7 @@ import org.jboss.jandex.DotName; +import io.quarkus.websockets.next.HandshakeRequest; import io.quarkus.websockets.next.OnBinaryMessage; import io.quarkus.websockets.next.OnClose; import io.quarkus.websockets.next.OnError; @@ -12,7 +13,10 @@ import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.PathParam; import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketClientConnection; import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.WebSocketConnector; import io.smallrye.common.annotation.Blocking; import io.smallrye.common.annotation.NonBlocking; import io.smallrye.common.annotation.RunOnVirtualThread; @@ -25,7 +29,10 @@ final class WebSocketDotNames { static final DotName WEB_SOCKET = DotName.createSimple(WebSocket.class); + static final DotName WEB_SOCKET_CLIENT = DotName.createSimple(WebSocketClient.class); static final DotName WEB_SOCKET_CONNECTION = DotName.createSimple(WebSocketConnection.class); + static final DotName WEB_SOCKET_CLIENT_CONNECTION = DotName.createSimple(WebSocketClientConnection.class); + static final DotName WEB_SOCKET_CONNECTOR = DotName.createSimple(WebSocketConnector.class); static final DotName ON_OPEN = DotName.createSimple(OnOpen.class); static final DotName ON_TEXT_MESSAGE = DotName.createSimple(OnTextMessage.class); static final DotName ON_BINARY_MESSAGE = DotName.createSimple(OnBinaryMessage.class); @@ -43,7 +50,7 @@ final class WebSocketDotNames { static final DotName JSON_ARRAY = DotName.createSimple(JsonArray.class); static final DotName VOID = DotName.createSimple(Void.class); static final DotName PATH_PARAM = DotName.createSimple(PathParam.class); - static final DotName HANDSHAKE_REQUEST = DotName.createSimple(WebSocketConnection.HandshakeRequest.class); + static final DotName HANDSHAKE_REQUEST = DotName.createSimple(HandshakeRequest.class); static final DotName THROWABLE = DotName.createSimple(Throwable.class); static final List CALLBACK_ANNOTATIONS = List.of(ON_OPEN, ON_CLOSE, ON_BINARY_MESSAGE, ON_TEXT_MESSAGE, diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java index d5ff556891e9e..e143eae0831ec 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketEndpointBuildItem.java @@ -1,44 +1,28 @@ package io.quarkus.websockets.next.deployment; -import java.util.ArrayList; import java.util.List; -import java.util.Set; -import java.util.function.Predicate; -import org.jboss.jandex.AnnotationInstance; -import org.jboss.jandex.AnnotationValue; import org.jboss.jandex.DotName; -import org.jboss.jandex.IndexView; -import org.jboss.jandex.MethodInfo; -import org.jboss.jandex.MethodParameterInfo; -import org.jboss.jandex.Type; -import org.jboss.jandex.Type.Kind; -import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; -import io.quarkus.arc.processor.Annotations; import io.quarkus.arc.processor.BeanInfo; -import io.quarkus.arc.processor.DotNames; import io.quarkus.builder.item.MultiBuildItem; -import io.quarkus.gizmo.BytecodeCreator; -import io.quarkus.gizmo.FieldDescriptor; -import io.quarkus.gizmo.ResultHandle; +import io.quarkus.websockets.next.InboundProcessingMode; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketConnection; -import io.quarkus.websockets.next.WebSocketServerException; -import io.quarkus.websockets.next.deployment.CallbackArgument.InvocationBytecodeContext; -import io.quarkus.websockets.next.deployment.CallbackArgument.ParameterContext; -import io.quarkus.websockets.next.runtime.WebSocketEndpoint.ExecutionModel; -import io.quarkus.websockets.next.runtime.WebSocketEndpointBase; +import io.quarkus.websockets.next.WebSocketClient; /** - * This build item represents a WebSocket endpoint class. + * This build item represents a WebSocket endpoint class, i.e. class annotated with {@link WebSocket} or + * {@link WebSocketClient}. */ public final class WebSocketEndpointBuildItem extends MultiBuildItem { + public final boolean isClient; public final BeanInfo bean; + // The path is using Vertx syntax for path params, i.e. /foo/:bar public final String path; - public final String endpointId; - public final WebSocket.ExecutionMode executionMode; + // @WebSocket#endpointId() or @WebSocketClient#clientId() + public final String id; + public final InboundProcessingMode inboundProcessingMode; public final Callback onOpen; public final Callback onTextMessage; public final Callback onBinaryMessage; @@ -46,14 +30,15 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem { public final Callback onClose; public final List onErrors; - WebSocketEndpointBuildItem(BeanInfo bean, String path, String endpointId, WebSocket.ExecutionMode executionMode, - Callback onOpen, - Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose, + WebSocketEndpointBuildItem(boolean isClient, BeanInfo bean, String path, String id, + InboundProcessingMode inboundProcessingMode, + Callback onOpen, Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose, List onErrors) { + this.isClient = isClient; this.bean = bean; this.path = path; - this.endpointId = endpointId; - this.executionMode = executionMode; + this.id = id; + this.inboundProcessingMode = inboundProcessingMode; this.onOpen = onOpen; this.onTextMessage = onTextMessage; this.onBinaryMessage = onBinaryMessage; @@ -62,266 +47,16 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem { this.onErrors = onErrors; } - public static class Callback { - - public final String endpointPath; - public final AnnotationInstance annotation; - public final MethodInfo method; - public final ExecutionModel executionModel; - public final MessageType messageType; - public final List arguments; - - public Callback(AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath, IndexView index) { - this.method = method; - this.annotation = annotation; - this.executionModel = executionModel; - if (WebSocketDotNames.ON_BINARY_MESSAGE.equals(annotation.name())) { - this.messageType = MessageType.BINARY; - } else if (WebSocketDotNames.ON_TEXT_MESSAGE.equals(annotation.name())) { - this.messageType = MessageType.TEXT; - } else if (WebSocketDotNames.ON_PONG_MESSAGE.equals(annotation.name())) { - this.messageType = MessageType.PONG; - } else { - this.messageType = MessageType.NONE; - } - this.endpointPath = endpointPath; - this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, index); - } - - public boolean isGlobal() { - return endpointPath == null; - } - - public boolean isOnOpen() { - return annotation.name().equals(WebSocketDotNames.ON_OPEN); - } - - public boolean isOnClose() { - return annotation.name().equals(WebSocketDotNames.ON_CLOSE); - } - - public boolean isOnError() { - return annotation.name().equals(WebSocketDotNames.ON_ERROR); - } - - public Type returnType() { - return method.returnType(); - } - - public Type messageParamType() { - return acceptsMessage() ? method.parameterType(0) : null; - } - - public boolean isReturnTypeVoid() { - return returnType().kind() == Kind.VOID; - } - - public boolean isReturnTypeUni() { - return WebSocketDotNames.UNI.equals(returnType().name()); - } - - public boolean isReturnTypeMulti() { - return WebSocketDotNames.MULTI.equals(returnType().name()); - } - - public boolean acceptsMessage() { - return messageType != MessageType.NONE; - } - - public boolean acceptsBinaryMessage() { - return messageType == MessageType.BINARY || messageType == MessageType.PONG; - } - - public boolean acceptsMulti() { - return acceptsMessage() && method.parameterType(0).name().equals(WebSocketDotNames.MULTI); - } - - public MessageType messageType() { - return messageType; - } - - public boolean broadcast() { - AnnotationValue broadcastValue = annotation.value("broadcast"); - return broadcastValue != null && broadcastValue.asBoolean(); - } - - public DotName getInputCodec() { - return getCodec("codec"); - } - - public DotName getOutputCodec() { - DotName output = getCodec("outputCodec"); - return output != null ? output : getInputCodec(); - } - - private DotName getCodec(String valueName) { - AnnotationValue codecValue = annotation.value(valueName); - if (codecValue != null) { - return codecValue.asClass().name(); - } - return null; - } - - public enum MessageType { - NONE, - PONG, - TEXT, - BINARY - } - - public ResultHandle[] generateArguments(ResultHandle endpointThis, BytecodeCreator bytecode, - TransformedAnnotationsBuildItem transformedAnnotations, IndexView index) { - if (arguments.isEmpty()) { - return new ResultHandle[] {}; - } - ResultHandle[] resultHandles = new ResultHandle[arguments.size()]; - int idx = 0; - for (CallbackArgument argument : arguments) { - resultHandles[idx] = argument.get( - invocationBytecodeContext(annotation, method.parameters().get(idx), transformedAnnotations, index, - endpointThis, bytecode)); - idx++; - } - return resultHandles; - } - - private List collectArguments(AnnotationInstance annotation, MethodInfo method, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - IndexView index) { - List parameters = method.parameters(); - if (parameters.isEmpty()) { - return List.of(); - } - List arguments = new ArrayList<>(parameters.size()); - for (MethodParameterInfo parameter : parameters) { - List found = callbackArguments - .findMatching(parameterContext(annotation, parameter, transformedAnnotations, index)); - if (found.isEmpty()) { - String msg = String.format("Unable to inject @%s callback parameter '%s' declared on %s: no injector found", - DotNames.simpleName(annotation.name()), - parameter.name() != null ? parameter.name() : "#" + parameter.position(), - WebSocketServerProcessor.callbackToString(method)); - throw new WebSocketServerException(msg); - } else if (found.size() > 1 && (found.get(0).priotity() == found.get(1).priotity())) { - String msg = String.format( - "Unable to inject @%s callback parameter '%s' declared on %s: ambiguous injectors found: %s", - DotNames.simpleName(annotation.name()), - parameter.name() != null ? parameter.name() : "#" + parameter.position(), - WebSocketServerProcessor.callbackToString(method), - found.stream().map(p -> p.getClass().getSimpleName() + ":" + p.priotity())); - throw new WebSocketServerException(msg); - } - arguments.add(found.get(0)); - } - return List.copyOf(arguments); - } - - Type argumentType(Predicate filter) { - int idx = 0; - for (int i = 0; i < arguments.size(); i++) { - if (filter.test(arguments.get(idx))) { - return method.parameterType(i); - } - } - return null; - } - - private ParameterContext parameterContext(AnnotationInstance callbackAnnotation, MethodParameterInfo parameter, - TransformedAnnotationsBuildItem transformedAnnotations, IndexView index) { - return new ParameterContext() { - - @Override - public MethodParameterInfo parameter() { - return parameter; - } - - @Override - public Set parameterAnnotations() { - return Annotations.getParameterAnnotations( - transformedAnnotations::getAnnotations, parameter.method(), parameter.position()); - } - - @Override - public AnnotationInstance callbackAnnotation() { - return callbackAnnotation; - } - - @Override - public String endpointPath() { - return endpointPath; - } - - @Override - public IndexView index() { - return index; - } - - }; - } - - private InvocationBytecodeContext invocationBytecodeContext(AnnotationInstance callbackAnnotation, - MethodParameterInfo parameter, TransformedAnnotationsBuildItem transformedAnnotations, IndexView index, - ResultHandle endpointThis, BytecodeCreator bytecode) { - return new InvocationBytecodeContext() { - - @Override - public AnnotationInstance callbackAnnotation() { - return callbackAnnotation; - } - - @Override - public MethodParameterInfo parameter() { - return parameter; - } - - @Override - public Set parameterAnnotations() { - return Annotations.getParameterAnnotations( - transformedAnnotations::getAnnotations, parameter.method(), parameter.position()); - } - - @Override - public String endpointPath() { - return endpointPath; - } - - @Override - public IndexView index() { - return index; - } - - @Override - public BytecodeCreator bytecode() { - return bytecode; - } - - @Override - public ResultHandle getPayload() { - return acceptsMessage() || callbackAnnotation.name().equals(WebSocketDotNames.ON_ERROR) - ? bytecode.getMethodParam(0) - : null; - } - - @Override - public ResultHandle getDecodedMessage(Type parameterType) { - return acceptsMessage() - ? WebSocketServerProcessor.decodeMessage(endpointThis, bytecode, acceptsBinaryMessage(), - parameterType, - getPayload(), Callback.this) - : null; - } + public boolean isClient() { + return isClient; + } - @Override - public ResultHandle getConnection() { - return bytecode.readInstanceField( - FieldDescriptor.of(WebSocketEndpointBase.class, "connection", WebSocketConnection.class), - endpointThis); - } - }; - } + public boolean isServer() { + return !isClient; + } + public DotName beanClassName() { + return bean.getImplClazz().name(); } } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java similarity index 74% rename from extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java rename to extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index c0216132ab1f1..ab446a8cd060f 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -36,16 +36,19 @@ import io.quarkus.arc.deployment.CustomScopeBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; -import io.quarkus.arc.deployment.UnremovableBeanBuildItem; +import io.quarkus.arc.deployment.ValidationPhaseBuildItem; +import io.quarkus.arc.deployment.ValidationPhaseBuildItem.ValidationErrorBuildItem; import io.quarkus.arc.processor.Annotations; import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.BuiltinScope; import io.quarkus.arc.processor.DotNames; +import io.quarkus.arc.processor.InjectionPointInfo; import io.quarkus.arc.processor.Types; import io.quarkus.deployment.GeneratedClassGizmoAdaptor; import io.quarkus.deployment.annotations.BuildProducer; import io.quarkus.deployment.annotations.BuildStep; import io.quarkus.deployment.annotations.Record; +import io.quarkus.deployment.builditem.CombinedIndexBuildItem; import io.quarkus.deployment.builditem.FeatureBuildItem; import io.quarkus.deployment.builditem.GeneratedClassBuildItem; import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; @@ -62,17 +65,24 @@ import io.quarkus.vertx.http.deployment.HttpRootPathBuildItem; import io.quarkus.vertx.http.deployment.RouteBuildItem; import io.quarkus.vertx.http.runtime.HandlerType; -import io.quarkus.websockets.next.TextMessageCodec; -import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.InboundProcessingMode; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketClientException; import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.WebSocketException; import io.quarkus.websockets.next.WebSocketServerException; -import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; -import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback; -import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback.MessageType; +import io.quarkus.websockets.next.deployment.Callback.MessageType; +import io.quarkus.websockets.next.deployment.Callback.Target; +import io.quarkus.websockets.next.runtime.BasicWebSocketConnectorImpl; +import io.quarkus.websockets.next.runtime.ClientConnectionManager; import io.quarkus.websockets.next.runtime.Codecs; import io.quarkus.websockets.next.runtime.ConnectionManager; import io.quarkus.websockets.next.runtime.ContextSupport; import io.quarkus.websockets.next.runtime.JsonTextMessageCodec; +import io.quarkus.websockets.next.runtime.WebSocketClientRecorder; +import io.quarkus.websockets.next.runtime.WebSocketClientRecorder.ClientEndpoint; +import io.quarkus.websockets.next.runtime.WebSocketConnectionBase; +import io.quarkus.websockets.next.runtime.WebSocketConnectorImpl; import io.quarkus.websockets.next.runtime.WebSocketEndpoint; import io.quarkus.websockets.next.runtime.WebSocketEndpoint.ExecutionModel; import io.quarkus.websockets.next.runtime.WebSocketEndpointBase; @@ -87,9 +97,10 @@ import io.vertx.core.json.JsonArray; import io.vertx.core.json.JsonObject; -public class WebSocketServerProcessor { +public class WebSocketProcessor { - static final String ENDPOINT_SUFFIX = "_WebSocketEndpoint"; + static final String SERVER_ENDPOINT_SUFFIX = "_WebSocketServerEndpoint"; + static final String CLIENT_ENDPOINT_SUFFIX = "_WebSocketClientEndpoint"; static final String NESTED_SEPARATOR = "$_"; // Parameter names consist of alphanumeric characters and underscore @@ -102,8 +113,10 @@ FeatureBuildItem feature() { } @BuildStep - BeanDefiningAnnotationBuildItem beanDefiningAnnotation() { - return new BeanDefiningAnnotationBuildItem(WebSocketDotNames.WEB_SOCKET, DotNames.SINGLETON); + void beanDefiningAnnotations(BuildProducer beanDefiningAnnotations) { + beanDefiningAnnotations.produce(new BeanDefiningAnnotationBuildItem(WebSocketDotNames.WEB_SOCKET, DotNames.SINGLETON)); + beanDefiningAnnotations + .produce(new BeanDefiningAnnotationBuildItem(WebSocketDotNames.WEB_SOCKET_CLIENT, DotNames.SINGLETON)); } @BuildStep @@ -115,11 +128,6 @@ AutoAddScopeBuildItem addScopeToGlobalErrorHandlers() { .defaultScope(BuiltinScope.SINGLETON).build(); } - @BuildStep - void unremovableBeans(BuildProducer unremovableBeans) { - unremovableBeans.produce(UnremovableBeanBuildItem.beanTypes(TextMessageCodec.class)); - } - @BuildStep ExecutionModelAnnotationsAllowedBuildItem executionModelAnnotations( TransformedAnnotationsBuildItem transformedAnnotations) { @@ -133,12 +141,63 @@ public boolean test(MethodInfo method) { } @BuildStep - public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, + CallbackArgumentsBuildItem collectCallbackArguments(List callbackArguments) { + List sorted = new ArrayList<>(); + for (CallbackArgumentBuildItem callbackArgument : callbackArguments) { + sorted.add(callbackArgument.getProvider()); + } + sorted.sort(Comparator.comparingInt(CallbackArgument::priotity).reversed()); + return new CallbackArgumentsBuildItem(sorted); + } + + @BuildStep + void additionalBeans(CombinedIndexBuildItem combinedIndex, BuildProducer additionalBeans) { + IndexView index = combinedIndex.getIndex(); + + // Always register the removable beans + AdditionalBeanBuildItem removable = AdditionalBeanBuildItem.builder() + .setRemovable() + .addBeanClasses(WebSocketConnectorImpl.class, JsonTextMessageCodec.class) + .build(); + additionalBeans.produce(removable); + + AdditionalBeanBuildItem.Builder unremovable = AdditionalBeanBuildItem.builder() + .setUnremovable() + .addBeanClasses(Codecs.class, ClientConnectionManager.class, BasicWebSocketConnectorImpl.class); + if (!index.getAnnotations(WebSocketDotNames.WEB_SOCKET).isEmpty()) { + unremovable.addBeanClasses(ConnectionManager.class, WebSocketHttpServerOptionsCustomizer.class); + } + additionalBeans.produce(unremovable.build()); + } + + @BuildStep + ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuildItem phase) { + return new ContextConfiguratorBuildItem(phase.getContext() + .configure(SessionScoped.class) + .normal() + .contextClass(WebSocketSessionContext.class)); + } + + @BuildStep + CustomScopeBuildItem registerSessionScope() { + return new CustomScopeBuildItem(DotName.createSimple(SessionScoped.class.getName())); + } + + @BuildStep + void builtinCallbackArguments(BuildProducer providers) { + providers.produce(new CallbackArgumentBuildItem(new MessageCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new ConnectionCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new PathParamCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new HandshakeRequestCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new ErrorCallbackArgument())); + } + + @BuildStep + void collectGlobalErrorHandlers(BeanArchiveIndexBuildItem beanArchiveIndex, BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, + BuildProducer globalErrorHandlers, CallbackArgumentsBuildItem callbackArguments, - TransformedAnnotationsBuildItem transformedAnnotations, - BuildProducer endpoints, - BuildProducer globalErrorHandlers) { + TransformedAnnotationsBuildItem transformedAnnotations) { IndexView index = beanArchiveIndex.getIndex(); @@ -146,92 +205,156 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, Map globalErrors = new HashMap<>(); for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) { ClassInfo beanClass = bean.getTarget().get().asClass(); - if (beanClass.annotation(WebSocketDotNames.WEB_SOCKET) == null) { - for (Callback callback : findErrorHandlers(index, beanClass, callbackArguments, transformedAnnotations, null)) { + if (beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET) == null + && beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET_CLIENT) == null) { + for (Callback callback : findErrorHandlers(Target.UNDEFINED, index, beanClass, callbackArguments, + transformedAnnotations, null)) { GlobalErrorHandler errorHandler = new GlobalErrorHandler(bean, callback); DotName errorTypeName = callback.argumentType(ErrorCallbackArgument::isError).name(); if (globalErrors.containsKey(errorTypeName)) { - throw new WebSocketServerException(String.format( + throw new WebSocketException(String.format( "Multiple global @OnError callbacks may not accept the same error parameter: %s\n\t- %s\n\t- %s", errorTypeName, - callbackToString(callback.method), - callbackToString(globalErrors.get(errorTypeName).callback.method))); + callback.asString(), + globalErrors.get(errorTypeName).callback.asString())); } globalErrors.put(errorTypeName, errorHandler); } } } globalErrorHandlers.produce(new GlobalErrorHandlersBuildItem(List.copyOf(globalErrors.values()))); + } + + @BuildStep + public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, + BeanDiscoveryFinishedBuildItem beanDiscoveryFinished, + CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, + BuildProducer endpoints) { + + IndexView index = beanArchiveIndex.getIndex(); // Collect WebSocket endpoints - Map idToEndpoint = new HashMap<>(); - Map pathToEndpoint = new HashMap<>(); + Map serverIdToEndpoint = new HashMap<>(); + Map serverPathToEndpoint = new HashMap<>(); + Map clientIdToEndpoint = new HashMap<>(); + Map clientPathToEndpoint = new HashMap<>(); + for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) { ClassInfo beanClass = bean.getTarget().get().asClass(); AnnotationInstance webSocketAnnotation = beanClass.annotation(WebSocketDotNames.WEB_SOCKET); + AnnotationInstance webSocketClientAnnotation = beanClass.annotation(WebSocketDotNames.WEB_SOCKET_CLIENT); + + if (webSocketAnnotation == null && webSocketClientAnnotation == null) { + continue; + } else if (webSocketAnnotation != null && webSocketClientAnnotation != null) { + throw new WebSocketException( + "Endpoint class may not be annotated with both @WebSocket and @WebSocketClient: " + beanClass); + } + String path; + String id; + AnnotationValue inboundProcessingMode; + Target target; + if (webSocketAnnotation != null) { - String path = getPath(webSocketAnnotation.value("path").asString()); + target = Target.SERVER; + path = getPath(webSocketAnnotation.value("path").asString()); if (beanClass.nestingType() == NestingType.INNER) { // Sub-websocket - merge the path from the enclosing classes path = mergePath(getPathPrefix(index, beanClass.enclosingClass()), path); } - DotName prevPath = pathToEndpoint.put(path, beanClass.name()); + DotName prevPath = serverPathToEndpoint.put(path, beanClass.name()); if (prevPath != null) { throw new WebSocketServerException( String.format("Multiple endpoints [%s, %s] define the same path: %s", prevPath, beanClass, path)); } - String endpointId; AnnotationValue endpointIdValue = webSocketAnnotation.value("endpointId"); if (endpointIdValue == null) { - endpointId = beanClass.name().toString(); + id = beanClass.name().toString(); } else { - endpointId = endpointIdValue.asString(); + id = endpointIdValue.asString(); } - DotName prevId = idToEndpoint.put(endpointId, beanClass.name()); + DotName prevId = serverIdToEndpoint.put(id, beanClass.name()); if (prevId != null) { throw new WebSocketServerException( String.format("Multiple endpoints [%s, %s] define the same endpoint id: %s", prevId, beanClass, - endpointId)); + id)); + } + inboundProcessingMode = webSocketAnnotation.value("inboundProcessingMode"); + } else { + target = Target.CLIENT; + path = getPath(webSocketClientAnnotation.value("path").asString()); + DotName prevPath = clientPathToEndpoint.put(path, beanClass.name()); + if (prevPath != null) { + throw new WebSocketServerException( + String.format("Multiple client endpoints [%s, %s] define the same path: %s", prevPath, beanClass, + path)); } - Callback onOpen = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN, - callbackArguments, transformedAnnotations, path); - Callback onTextMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_TEXT_MESSAGE, - callbackArguments, transformedAnnotations, path); - Callback onBinaryMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, - WebSocketDotNames.ON_BINARY_MESSAGE, callbackArguments, transformedAnnotations, path); - Callback onPongMessage = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_PONG_MESSAGE, - callbackArguments, transformedAnnotations, path, - this::validateOnPongMessage); - Callback onClose = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE, - callbackArguments, transformedAnnotations, path, - this::validateOnClose); - if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) { + AnnotationValue clientIdValue = webSocketClientAnnotation.value("clientId"); + if (clientIdValue == null) { + id = beanClass.name().toString(); + } else { + id = clientIdValue.asString(); + } + DotName prevId = clientIdToEndpoint.put(id, beanClass.name()); + if (prevId != null) { throw new WebSocketServerException( - "The endpoint must declare at least one method annotated with @OnTextMessage, @OnBinaryMessage, @OnPongMessage or @OnOpen: " - + beanClass); + String.format("Multiple client endpoints [%s, %s] define the same endpoint id: %s", prevId, + beanClass, + id)); } - AnnotationValue executionMode = webSocketAnnotation.value("executionMode"); - endpoints.produce(new WebSocketEndpointBuildItem(bean, path, endpointId, - executionMode != null ? WebSocket.ExecutionMode.valueOf(executionMode.asEnum()) - : WebSocket.ExecutionMode.SERIAL, - onOpen, - onTextMessage, - onBinaryMessage, - onPongMessage, - onClose, - findErrorHandlers(index, beanClass, callbackArguments, transformedAnnotations, path))); + inboundProcessingMode = webSocketClientAnnotation.value("inboundProcessingMode"); + } + + Callback onOpen = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN, + callbackArguments, transformedAnnotations, path); + Callback onTextMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, + WebSocketDotNames.ON_TEXT_MESSAGE, + callbackArguments, transformedAnnotations, path); + Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, + WebSocketDotNames.ON_BINARY_MESSAGE, callbackArguments, transformedAnnotations, path); + Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, + WebSocketDotNames.ON_PONG_MESSAGE, + callbackArguments, transformedAnnotations, path, + this::validateOnPongMessage); + Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE, + callbackArguments, transformedAnnotations, path, + this::validateOnClose); + if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) { + throw new WebSocketServerException( + "The endpoint must declare at least one method annotated with @OnTextMessage, @OnBinaryMessage, @OnPongMessage or @OnOpen: " + + beanClass); } + endpoints.produce(new WebSocketEndpointBuildItem(target == Target.CLIENT, bean, path, id, + inboundProcessingMode != null ? InboundProcessingMode.valueOf(inboundProcessingMode.asEnum()) + : InboundProcessingMode.SERIAL, + onOpen, + onTextMessage, + onBinaryMessage, + onPongMessage, + onClose, + findErrorHandlers(target, index, beanClass, callbackArguments, transformedAnnotations, path))); } } @BuildStep - CallbackArgumentsBuildItem collectCallbackArguments(List callbackArguments) { - List sorted = new ArrayList<>(); - for (CallbackArgumentBuildItem callbackArgument : callbackArguments) { - sorted.add(callbackArgument.getProvider()); + public void validateConnectorInjectionPoints(List endpoints, + ValidationPhaseBuildItem validationPhase, BuildProducer validationErrors) { + for (InjectionPointInfo injectionPoint : validationPhase.getContext().getInjectionPoints()) { + if (injectionPoint.getRequiredType().name().equals(WebSocketDotNames.WEB_SOCKET_CONNECTOR) + && injectionPoint.hasDefaultedQualifier()) { + Type clientEndpointType = injectionPoint.getRequiredType().asParameterizedType().arguments().get(0); + if (endpoints.stream() + .filter(WebSocketEndpointBuildItem::isClient) + .map(WebSocketEndpointBuildItem::beanClassName) + .noneMatch(clientEndpointType.name()::equals)) { + validationErrors.produce( + new ValidationErrorBuildItem(new WebSocketClientException(String.format( + "Type argument [%s] of the injected WebSocketConnector is not a @WebSocketClient endpoint: %s", + clientEndpointType, injectionPoint.getTargetInfo())))); + } + } } - sorted.sort(Comparator.comparingInt(CallbackArgument::priotity).reversed()); - return new CallbackArgumentsBuildItem(sorted); } @BuildStep @@ -245,7 +368,10 @@ public void generateEndpoints(BeanArchiveIndexBuildItem index, List() { @Override public String apply(String name) { - int idx = name.indexOf(ENDPOINT_SUFFIX); + int idx = name.indexOf(CLIENT_ENDPOINT_SUFFIX); + if (idx == -1) { + idx = name.indexOf(SERVER_ENDPOINT_SUFFIX); + } if (idx != -1) { name = name.substring(0, idx); } @@ -256,16 +382,17 @@ public String apply(String name) { } }); for (WebSocketEndpointBuildItem endpoint : endpoints) { - // For each WebSocket endpoint bean generate an implementation of WebSocketEndpoint + // For each WebSocket endpoint bean we generate an implementation of WebSocketEndpoint // A new instance of this generated endpoint is created for each client connection // The generated endpoint ensures the correct execution model is used // and delegates callback invocations to the endpoint bean String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations, - index.getIndex(), classOutput, globalErrorHandlers); + index.getIndex(), classOutput, globalErrorHandlers, + endpoint.isClient() ? CLIENT_ENDPOINT_SUFFIX : SERVER_ENDPOINT_SUFFIX); reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build()); generatedEndpoints - .produce(new GeneratedEndpointBuildItem(endpoint.endpointId, endpoint.bean.getImplClazz().name().toString(), - generatedName, endpoint.path)); + .produce(new GeneratedEndpointBuildItem(endpoint.id, endpoint.bean.getImplClazz().name().toString(), + generatedName, endpoint.path, endpoint.isClient)); } } @@ -274,8 +401,8 @@ public String apply(String name) { public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildItem httpRootPath, List generatedEndpoints, BuildProducer routes) { - - for (GeneratedEndpointBuildItem endpoint : generatedEndpoints) { + for (GeneratedEndpointBuildItem endpoint : generatedEndpoints.stream().filter(GeneratedEndpointBuildItem::isServer) + .toList()) { RouteBuildItem.Builder builder = RouteBuildItem.builder() .route(httpRootPath.relativePath(endpoint.path)) .displayOnNotFoundPage("WebSocket Endpoint") @@ -285,17 +412,15 @@ public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildIt } } - @BuildStep - AdditionalBeanBuildItem additionalBeans() { - return AdditionalBeanBuildItem.builder().setUnremovable() - .addBeanClasses(Codecs.class, JsonTextMessageCodec.class, ConnectionManager.class, - WebSocketHttpServerOptionsCustomizer.class) - .build(); - } - @BuildStep @Record(RUNTIME_INIT) - void syntheticBeans(WebSocketServerRecorder recorder, BuildProducer syntheticBeans) { + void serverSyntheticBeans(WebSocketServerRecorder recorder, List generatedEndpoints, + BuildProducer syntheticBeans) { + List serverEndpoints = generatedEndpoints.stream() + .filter(GeneratedEndpointBuildItem::isServer).toList(); + if (serverEndpoints.isEmpty()) { + return; + } syntheticBeans.produce(SyntheticBeanBuildItem.configure(WebSocketConnection.class) .scope(SessionScoped.class) .setRuntimeInit() @@ -305,25 +430,29 @@ void syntheticBeans(WebSocketServerRecorder recorder, BuildProducer providers) { - providers.produce(new CallbackArgumentBuildItem(new MessageCallbackArgument())); - providers.produce(new CallbackArgumentBuildItem(new ConnectionCallbackArgument())); - providers.produce(new CallbackArgumentBuildItem(new PathParamCallbackArgument())); - providers.produce(new CallbackArgumentBuildItem(new HandshakeRequestCallbackArgument())); - providers.produce(new CallbackArgumentBuildItem(new ErrorCallbackArgument())); + @Record(RUNTIME_INIT) + void clientSyntheticBeans(WebSocketClientRecorder recorder, List generatedEndpoints, + BuildProducer syntheticBeans) { + List clientEndpoints = generatedEndpoints.stream() + .filter(GeneratedEndpointBuildItem::isClient).toList(); + if (!clientEndpoints.isEmpty()) { + syntheticBeans.produce(SyntheticBeanBuildItem.configure(WebSocketClientConnection.class) + .scope(SessionScoped.class) + .setRuntimeInit() + .supplier(recorder.connectionSupplier()) + .unremovable() + .done()); + } + // ClientEndpointsContext is always registered but is removable + Map endpointMap = new HashMap<>(); + for (GeneratedEndpointBuildItem generatedEndpoint : clientEndpoints) { + endpointMap.put(generatedEndpoint.endpointClassName, new ClientEndpoint(generatedEndpoint.endpointId, + getOriginalPath(generatedEndpoint.path), generatedEndpoint.generatedClassName)); + } + syntheticBeans.produce(SyntheticBeanBuildItem.configure(WebSocketClientRecorder.ClientEndpointsContext.class) + .setRuntimeInit() + .supplier(recorder.createContext(endpointMap)) + .done()); } static String mergePath(String prefix, String path) { @@ -356,11 +485,19 @@ static String getPath(String path) { return path.startsWith("/") ? sb.toString() : "/" + sb.toString(); } - static String callbackToString(MethodInfo callback) { - return callback.declaringClass().name() + "#" + callback.name() + "()"; + public static String getOriginalPath(String path) { + StringBuilder sb = new StringBuilder(); + Matcher m = TRANSLATED_PATH_PARAM_PATTERN.matcher(path); + while (m.find()) { + // Replace :foo with {foo} + String match = m.group(); + m.appendReplacement(sb, "{" + match.subSequence(1, match.length()) + "}"); + } + m.appendTail(sb); + return sb.toString(); } - private String getPathPrefix(IndexView index, DotName enclosingClassName) { + static String getPathPrefix(IndexView index, DotName enclosingClassName) { ClassInfo enclosingClass = index.getClassByName(enclosingClassName); if (enclosingClass == null) { throw new WebSocketServerException("Enclosing class not found in index: " + enclosingClass); @@ -378,22 +515,22 @@ private String getPathPrefix(IndexView index, DotName enclosingClassName) { } private void validateOnPongMessage(Callback callback) { - if (callback.returnType().kind() != Kind.VOID && !WebSocketServerProcessor.isUniVoid(callback.returnType())) { + if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { throw new WebSocketServerException( - "@OnPongMessage callback must return void or Uni: " + callbackToString(callback.method)); + "@OnPongMessage callback must return void or Uni: " + callback.asString()); } Type messageType = callback.argumentType(MessageCallbackArgument::isMessage); if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) { throw new WebSocketServerException( "@OnPongMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: " - + callbackToString(callback.method)); + + callback.asString()); } } private void validateOnClose(Callback callback) { - if (callback.returnType().kind() != Kind.VOID && !WebSocketServerProcessor.isUniVoid(callback.returnType())) { + if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { throw new WebSocketServerException( - "@OnClose callback must return void or Uni: " + callbackToString(callback.method)); + "@OnClose callback must return void or Uni: " + callback.asString()); } } @@ -456,12 +593,13 @@ private void validateOnClose(Callback callback) { * @param classOutput * @return the name of the generated class */ - private String generateEndpoint(WebSocketEndpointBuildItem endpoint, + static String generateEndpoint(WebSocketEndpointBuildItem endpoint, CallbackArgumentsBuildItem argumentProviders, TransformedAnnotationsBuildItem transformedAnnotations, IndexView index, ClassOutput classOutput, - GlobalErrorHandlersBuildItem globalErrorHandlers) { + GlobalErrorHandlersBuildItem globalErrorHandlers, + String endpointSuffix) { ClassInfo implClazz = endpoint.bean.getImplClazz(); String baseName; if (implClazz.enclosingClass() != null) { @@ -471,23 +609,24 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, baseName = DotNames.simpleName(implClazz.name()); } String generatedName = DotNames.internalPackageNameWithTrailingSlash(implClazz.name()) + baseName - + ENDPOINT_SUFFIX; + + endpointSuffix; ClassCreator endpointCreator = ClassCreator.builder().classOutput(classOutput).className(generatedName) .superClass(WebSocketEndpointBase.class) .build(); - MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketConnection.class, - Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class); + MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketConnectionBase.class, + Codecs.class, ContextSupport.class); constructor.invokeSpecialMethod( - MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketConnection.class, - Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class), + MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketConnectionBase.class, + Codecs.class, ContextSupport.class), constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1), - constructor.getMethodParam(2), constructor.getMethodParam(3)); + constructor.getMethodParam(2)); constructor.returnNull(); - MethodCreator executionMode = endpointCreator.getMethodCreator("executionMode", WebSocket.ExecutionMode.class); - executionMode.returnValue(executionMode.load(endpoint.executionMode)); + MethodCreator inboundProcessingMode = endpointCreator.getMethodCreator("inboundProcessingMode", + InboundProcessingMode.class); + inboundProcessingMode.returnValue(inboundProcessingMode.load(endpoint.inboundProcessingMode)); MethodCreator beanIdentifier = endpointCreator.getMethodCreator("beanIdentifier", String.class); beanIdentifier.returnValue(beanIdentifier.load(endpoint.bean.getIdentifier())); @@ -539,7 +678,7 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint, return generatedName.replace('/', '.'); } - private void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, + private static void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index) { @@ -548,15 +687,16 @@ private void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuil for (Callback callback : endpoint.onErrors) { DotName errorTypeName = callback.argumentType(ErrorCallbackArgument::isError).name(); if (errors.containsKey(errorTypeName)) { - throw new WebSocketServerException(String.format( + throw new WebSocketException(String.format( "Multiple @OnError callbacks may not accept the same error parameter: %s\n\t- %s\n\t- %s", errorTypeName, - callbackToString(callback.method), callbackToString(errors.get(errorTypeName).method))); + callback.asString(), errors.get(errorTypeName).asString())); } errors.put(errorTypeName, callback); throwableInfos.add(new ThrowableInfo(endpoint.bean, callback, throwableHierarchy(errorTypeName, index))); } - for (GlobalErrorHandler globalErrorHandler : globalErrorHandlers.handlers) { + for (GlobalErrorHandler globalErrorHandler : endpoint.isClient ? globalErrorHandlers.forClient() + : globalErrorHandlers.forServer()) { Callback callback = globalErrorHandler.callback; DotName errorTypeName = callback.argumentType(ErrorCallbackArgument::isError).name(); if (!errors.containsKey(errorTypeName)) { @@ -609,14 +749,14 @@ private void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuil doOnError.getMethodParam(0))); } - private List throwableHierarchy(DotName throwableName, IndexView index) { + private static List throwableHierarchy(DotName throwableName, IndexView index) { // TextDecodeException -> [TextDecodeException, WebSocketServerException, RuntimeException, Exception, Throwable] List ret = new ArrayList<>(); addToThrowableHierarchy(throwableName, index, ret); return ret; } - private void addToThrowableHierarchy(DotName throwableName, IndexView index, List hierarchy) { + private static void addToThrowableHierarchy(DotName throwableName, IndexView index, List hierarchy) { hierarchy.add(throwableName); ClassInfo errorClass = index.getClassByName(throwableName); if (errorClass == null) { @@ -640,7 +780,7 @@ record GlobalErrorHandler(BeanInfo bean, Callback callback) { } - private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback, + private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers) { if (callback == null) { @@ -695,7 +835,7 @@ private void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBu } } - private TryBlock uniFailureTryBlock(BytecodeCreator method) { + private static TryBlock uniFailureTryBlock(BytecodeCreator method) { TryBlock tryBlock = method.tryBlock(); CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class); // return Uni.createFrom().failure(t); @@ -707,7 +847,7 @@ private TryBlock uniFailureTryBlock(BytecodeCreator method) { return tryBlock; } - private TryBlock onErrorTryBlock(BytecodeCreator method, ResultHandle endpointThis) { + private static TryBlock onErrorTryBlock(BytecodeCreator method, ResultHandle endpointThis) { TryBlock tryBlock = method.tryBlock(); CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class); // return doOnError(t); @@ -728,7 +868,7 @@ static ResultHandle decodeMessage( // Binary message if (WebSocketDotNames.BUFFER.equals(valueType.name())) { return value; - } else if (WebSocketServerProcessor.isByteArray(valueType)) { + } else if (WebSocketProcessor.isByteArray(valueType)) { // byte[] message = buffer.getBytes(); return method.invokeInterfaceMethod( MethodDescriptor.ofMethod(Buffer.class, "getBytes", byte[].class), value); @@ -771,7 +911,7 @@ static ResultHandle decodeMessage( // Buffer message = Buffer.buffer(string); return method.invokeStaticInterfaceMethod( MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, String.class), value); - } else if (WebSocketServerProcessor.isByteArray(valueType)) { + } else if (WebSocketProcessor.isByteArray(valueType)) { // byte[] message = Buffer.buffer(string).getBytes(); ResultHandle buffer = method.invokeStaticInterfaceMethod( MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, byte[].class), value); @@ -789,7 +929,7 @@ static ResultHandle decodeMessage( } } - private ResultHandle uniOnFailureDoOnError(ResultHandle endpointThis, BytecodeCreator method, Callback callback, + private static ResultHandle uniOnFailureDoOnError(ResultHandle endpointThis, BytecodeCreator method, Callback callback, ResultHandle uni, WebSocketEndpointBuildItem endpoint, GlobalErrorHandlersBuildItem globalErrorHandlers) { if (callback.isOnError() || (globalErrorHandlers.handlers.isEmpty() && (endpoint == null || endpoint.onErrors.isEmpty()))) { @@ -811,7 +951,7 @@ private ResultHandle uniOnFailureDoOnError(ResultHandle endpointThis, BytecodeCr uniOnFailure, fun.getInstance()); } - private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator method, Callback callback, + private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator method, Callback callback, GlobalErrorHandlersBuildItem globalErrorHandlers, WebSocketEndpointBuildItem endpoint, ResultHandle value) { if (callback.acceptsBinaryMessage()) { @@ -938,12 +1078,12 @@ private ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator me } } - private ResultHandle encodeBuffer(BytecodeCreator method, Type messageType, ResultHandle value, + private static ResultHandle encodeBuffer(BytecodeCreator method, Type messageType, ResultHandle value, ResultHandle endpointThis, Callback callback) { ResultHandle buffer; if (messageType.name().equals(WebSocketDotNames.BUFFER)) { buffer = value; - } else if (WebSocketServerProcessor.isByteArray(messageType)) { + } else if (WebSocketProcessor.isByteArray(messageType)) { buffer = method.invokeStaticInterfaceMethod( MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, byte[].class), value); } else if (messageType.name().equals(WebSocketDotNames.STRING)) { @@ -965,13 +1105,13 @@ private ResultHandle encodeBuffer(BytecodeCreator method, Type messageType, Resu return buffer; } - private ResultHandle encodeText(BytecodeCreator method, Type messageType, ResultHandle value, + private static ResultHandle encodeText(BytecodeCreator method, Type messageType, ResultHandle value, ResultHandle endpointThis, Callback callback) { ResultHandle text; if (messageType.name().equals(WebSocketDotNames.BUFFER)) { text = method.invokeInterfaceMethod( MethodDescriptor.ofMethod(Buffer.class, "toString", String.class), value); - } else if (WebSocketServerProcessor.isByteArray(messageType)) { + } else if (WebSocketProcessor.isByteArray(messageType)) { ResultHandle buffer = method.invokeStaticInterfaceMethod( MethodDescriptor.ofMethod(Buffer.class, "buffer", Buffer.class, byte[].class), value); text = method.invokeInterfaceMethod( @@ -994,13 +1134,13 @@ private ResultHandle encodeText(BytecodeCreator method, Type messageType, Result return text; } - private ResultHandle uniVoid(BytecodeCreator method) { + private static ResultHandle uniVoid(BytecodeCreator method) { ResultHandle uniCreate = method .invokeStaticInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "createFrom", UniCreate.class)); return method.invokeVirtualMethod(MethodDescriptor.ofMethod(UniCreate.class, "voidItem", Uni.class), uniCreate); } - private void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCreator method, Callback callback, + private static void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCreator method, Callback callback, GlobalErrorHandlersBuildItem globalErrorHandlers, WebSocketEndpointBuildItem endpoint, ResultHandle result) { // The result must be always Uni @@ -1015,7 +1155,8 @@ private void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCreator me } } - private List findErrorHandlers(IndexView index, ClassInfo beanClass, CallbackArgumentsBuildItem callbackArguments, + static List findErrorHandlers(Target target, IndexView index, ClassInfo beanClass, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { List annotations = findCallbackAnnotations(index, beanClass, WebSocketDotNames.ON_ERROR); @@ -1025,22 +1166,29 @@ private List findErrorHandlers(IndexView index, ClassInfo beanClass, C List errorHandlers = new ArrayList<>(); for (AnnotationInstance annotation : annotations) { MethodInfo method = annotation.target().asMethod(); - Callback callback = new Callback(annotation, method, executionModel(method, transformedAnnotations), - callbackArguments, - transformedAnnotations, endpointPath, index); + // If a global error handler accepts a connection param we change the target appropriately + if (method.parameterTypes().stream().map(Type::name).anyMatch(WebSocketDotNames.WEB_SOCKET_CONNECTION::equals)) { + target = Target.SERVER; + } else if (method.parameterTypes().stream().map(Type::name) + .anyMatch(WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION::equals)) { + target = Target.CLIENT; + } + Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations), + callbackArguments, transformedAnnotations, endpointPath, index); long errorArguments = callback.arguments.stream().filter(ca -> ca instanceof ErrorCallbackArgument).count(); if (errorArguments != 1) { - throw new WebSocketServerException( + throw new WebSocketException( String.format("@OnError callback must accept exactly one error parameter; found %s: %s", errorArguments, - callbackToString(callback.method))); + callback.asString())); } errorHandlers.add(callback); } return errorHandlers; } - private List findCallbackAnnotations(IndexView index, ClassInfo beanClass, DotName annotationName) { + private static List findCallbackAnnotations(IndexView index, ClassInfo beanClass, + DotName annotationName) { ClassInfo aClass = beanClass; List annotations = new ArrayList<>(); while (aClass != null) { @@ -1056,13 +1204,14 @@ private List findCallbackAnnotations(IndexView index, ClassI return annotations; } - private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { - return findCallback(index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath, null); + return findCallback(target, index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath, + null); } - private Callback findCallback(IndexView index, ClassInfo beanClass, DotName annotationName, + private static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, Consumer validator) { @@ -1072,37 +1221,43 @@ private Callback findCallback(IndexView index, ClassInfo beanClass, DotName anno } else if (annotations.size() == 1) { AnnotationInstance annotation = annotations.get(0); MethodInfo method = annotation.target().asMethod(); - Callback callback = new Callback(annotation, method, executionModel(method, transformedAnnotations), + Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations), callbackArguments, transformedAnnotations, endpointPath, index); long messageArguments = callback.arguments.stream().filter(ca -> ca instanceof MessageCallbackArgument).count(); if (callback.acceptsMessage()) { if (messageArguments > 1) { - throw new WebSocketServerException( + throw new WebSocketException( String.format("@%s callback may accept at most 1 message parameter; found %s: %s", DotNames.simpleName(callback.annotation.name()), messageArguments, - callbackToString(callback.method))); + callback.asString())); } } else { if (messageArguments != 0) { - throw new WebSocketServerException( + throw new WebSocketException( String.format("@%s callback must not accept a message parameter; found %s: %s", DotNames.simpleName(callback.annotation.name()), messageArguments, - callbackToString(callback.method))); + callback.asString())); } } + if (target == Target.CLIENT && callback.broadcast()) { + throw new WebSocketClientException( + String.format("@%s callback declared on a client endpoint must not broadcast messages: %s", + DotNames.simpleName(callback.annotation.name()), + callback.asString())); + } if (validator != null) { validator.accept(callback); } return callback; } - throw new WebSocketServerException( + throw new WebSocketException( String.format("There can be only one callback annotated with %s declared on %s", annotationName, beanClass)); } - ExecutionModel executionModel(MethodInfo method, TransformedAnnotationsBuildItem transformedAnnotations) { + private static ExecutionModel executionModel(MethodInfo method, TransformedAnnotationsBuildItem transformedAnnotations) { if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.RUN_ON_VIRTUAL_THREAD)) { return ExecutionModel.VIRTUAL_THREAD; } else if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.BLOCKING)) { @@ -1114,7 +1269,7 @@ ExecutionModel executionModel(MethodInfo method, TransformedAnnotationsBuildItem } } - boolean hasBlockingSignature(MethodInfo method) { + static boolean hasBlockingSignature(MethodInfo method) { switch (method.returnType().kind()) { case VOID: case CLASS: @@ -1124,7 +1279,8 @@ boolean hasBlockingSignature(MethodInfo method) { DotName name = method.returnType().asParameterizedType().name(); return !name.equals(WebSocketDotNames.UNI) && !name.equals(WebSocketDotNames.MULTI); default: - throw new WebSocketServerException("Unsupported return type:" + callbackToString(method)); + throw new WebSocketServerException( + "Unsupported return type:" + methodToString(method)); } } @@ -1137,4 +1293,7 @@ static boolean isByteArray(Type type) { return type.kind() == Kind.ARRAY && PrimitiveType.BYTE.equals(type.asArrayType().constituent()); } + static String methodToString(MethodInfo method) { + return method.declaringClass().name() + "#" + method.name() + "()"; + } } diff --git a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java index ed2db7823cce3..5f89d5cfd0482 100644 --- a/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java +++ b/extensions/websockets-next/server/deployment/src/main/java/io/quarkus/websockets/next/deployment/devui/WebSocketServerDevUIProcessor.java @@ -6,7 +6,6 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.regex.Matcher; import java.util.stream.Collectors; import org.jboss.jandex.MethodInfo; @@ -21,10 +20,10 @@ import io.quarkus.devui.spi.JsonRPCProvidersBuildItem; import io.quarkus.devui.spi.page.CardPageBuildItem; import io.quarkus.devui.spi.page.Page; +import io.quarkus.websockets.next.deployment.Callback; import io.quarkus.websockets.next.deployment.GeneratedEndpointBuildItem; import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem; -import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback; -import io.quarkus.websockets.next.deployment.WebSocketServerProcessor; +import io.quarkus.websockets.next.deployment.WebSocketProcessor; import io.quarkus.websockets.next.runtime.devui.WebSocketNextJsonRPCService; public class WebSocketServerDevUIProcessor { @@ -38,7 +37,7 @@ public void pages(List endpoints, List> createEndpointsJson(List endpoints, List generatedEndpoints) { List> json = new ArrayList<>(); - for (WebSocketEndpointBuildItem endpoint : endpoints.stream().sorted(Comparator.comparing(e -> e.path)) + for (WebSocketEndpointBuildItem endpoint : endpoints.stream().filter(WebSocketEndpointBuildItem::isServer) + .sorted(Comparator.comparing(e -> e.path)) .collect(Collectors.toList())) { Map endpointJson = new HashMap<>(); String clazz = endpoint.bean.getImplClazz().name().toString(); @@ -62,8 +62,8 @@ private List> createEndpointsJson(List ge.endpointClassName.equals(clazz)).findFirst() .orElseThrow().generatedClassName); - endpointJson.put("path", getOriginalPath(endpoint.path)); - endpointJson.put("executionMode", endpoint.executionMode.toString()); + endpointJson.put("path", WebSocketProcessor.getOriginalPath(endpoint.path)); + endpointJson.put("executionMode", endpoint.inboundProcessingMode.toString()); List> callbacks = new ArrayList<>(); addCallback(endpoint.onOpen, callbacks); addCallback(endpoint.onBinaryMessage, callbacks); @@ -79,7 +79,7 @@ private List> createEndpointsJson(List> callbacks) { + private void addCallback(Callback callback, List> callbacks) { if (callback != null) { callbacks.add(Map.of("annotation", callback.annotation.toString(), "method", methodToString(callback.method))); } @@ -131,16 +131,4 @@ private String typeToString(Type type) { } } - static String getOriginalPath(String path) { - StringBuilder sb = new StringBuilder(); - Matcher m = WebSocketServerProcessor.TRANSLATED_PATH_PARAM_PATTERN.matcher(path); - while (m.find()) { - // Replace :foo with {foo} - String match = m.group(); - m.appendReplacement(sb, "{" + match.subSequence(1, match.length()) + "}"); - } - m.appendTail(sb); - return sb.toString(); - } - } diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketProcessorTest.java similarity index 54% rename from extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessorTest.java rename to extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketProcessorTest.java index 44e37dc73011d..843312a7a6a2f 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketServerProcessorTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/deployment/WebSocketProcessorTest.java @@ -7,26 +7,26 @@ import io.quarkus.websockets.next.WebSocketServerException; -public class WebSocketServerProcessorTest { +public class WebSocketProcessorTest { @Test public void testGetPath() { - assertEquals("/foo/:id", WebSocketServerProcessor.getPath("/foo/{id}")); - assertEquals("/foo/:id/bar/:id2", WebSocketServerProcessor.getPath("/foo/{id}/bar/{id2}")); - assertEquals("/foo/:bar-:baz", WebSocketServerProcessor.getPath("/foo/{bar}-{baz}")); - assertEquals("/ws/v:version", WebSocketServerProcessor.getPath("/ws/v{version}")); + assertEquals("/foo/:id", WebSocketProcessor.getPath("/foo/{id}")); + assertEquals("/foo/:id/bar/:id2", WebSocketProcessor.getPath("/foo/{id}/bar/{id2}")); + assertEquals("/foo/:bar-:baz", WebSocketProcessor.getPath("/foo/{bar}-{baz}")); + assertEquals("/ws/v:version", WebSocketProcessor.getPath("/ws/v{version}")); WebSocketServerException e = assertThrows(WebSocketServerException.class, - () -> WebSocketServerProcessor.getPath("/foo/v{bar}/{baz}and{alpha_1}-{name}")); + () -> WebSocketProcessor.getPath("/foo/v{bar}/{baz}and{alpha_1}-{name}")); assertEquals( "Path parameter {baz} may not be followed by an alphanumeric character or underscore: /foo/v{bar}/{baz}and{alpha_1}-{name}", e.getMessage()); e = assertThrows(WebSocketServerException.class, - () -> WebSocketServerProcessor.getPath("/foo/v{bar}/{baz}_{alpha_1}-{name}")); + () -> WebSocketProcessor.getPath("/foo/v{bar}/{baz}_{alpha_1}-{name}")); assertEquals( "Path parameter {baz} may not be followed by an alphanumeric character or underscore: /foo/v{bar}/{baz}_{alpha_1}-{name}", e.getMessage()); e = assertThrows(WebSocketServerException.class, - () -> WebSocketServerProcessor.getPath("/foo/v{bar}/{baz}1-{name}")); + () -> WebSocketProcessor.getPath("/foo/v{bar}/{baz}1-{name}")); assertEquals( "Path parameter {baz} may not be followed by an alphanumeric character or underscore: /foo/v{bar}/{baz}1-{name}", e.getMessage()); @@ -34,9 +34,9 @@ public void testGetPath() { @Test public void testMergePath() { - assertEquals("foo/bar", WebSocketServerProcessor.mergePath("foo/", "/bar")); - assertEquals("foo/bar", WebSocketServerProcessor.mergePath("foo", "/bar")); - assertEquals("foo/bar", WebSocketServerProcessor.mergePath("foo/", "bar")); + assertEquals("foo/bar", WebSocketProcessor.mergePath("foo/", "/bar")); + assertEquals("foo/bar", WebSocketProcessor.mergePath("foo", "/bar")); + assertEquals("foo/bar", WebSocketProcessor.mergePath("foo/", "bar")); } } diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/HandshakeRequestArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/HandshakeRequestArgumentTest.java index 5bcb9ac19f0ba..1a56b592c8fb1 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/HandshakeRequestArgumentTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/HandshakeRequestArgumentTest.java @@ -11,9 +11,9 @@ import io.quarkus.test.QuarkusUnitTest; import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.HandshakeRequest; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.quarkus.websockets.next.test.utils.WSClient; import io.vertx.core.Vertx; import io.vertx.core.http.WebSocketConnectOptions; diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java index c5a934eeb6cdf..967414a56e9b1 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnCloseInvalidArgumentTest.java @@ -10,7 +10,7 @@ import io.quarkus.test.QuarkusUnitTest; import io.quarkus.websockets.next.OnClose; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class OnCloseInvalidArgumentTest { @@ -19,7 +19,7 @@ public class OnCloseInvalidArgumentTest { .withApplicationRoot(root -> { root.addClasses(Endpoint.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void testInvalidArgument() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java index 5f3b9071cf546..68426d2132f06 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/OnOpenInvalidArgumentTest.java @@ -10,7 +10,7 @@ import io.quarkus.test.QuarkusUnitTest; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class OnOpenInvalidArgumentTest { @@ -19,7 +19,7 @@ public class OnOpenInvalidArgumentTest { .withApplicationRoot(root -> { root.addClasses(Endpoint.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void testInvalidArgument() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidNameTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidNameTest.java index f23f0343cdf23..ea003d0ba40cc 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidNameTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidNameTest.java @@ -9,7 +9,7 @@ import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.PathParam; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class PathParamArgumentInvalidNameTest { @@ -17,7 +17,7 @@ public class PathParamArgumentInvalidNameTest { public static final QuarkusUnitTest test = new QuarkusUnitTest() .withApplicationRoot(root -> { root.addClasses(MontyEcho.class); - }).setExpectedException(WebSocketServerException.class); + }).setExpectedException(WebSocketException.class); @Test void testInvalidArgument() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidTypeTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidTypeTest.java index 31097c8bf7180..9925539538dea 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidTypeTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/args/PathParamArgumentInvalidTypeTest.java @@ -9,7 +9,7 @@ import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.PathParam; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class PathParamArgumentInvalidTypeTest { @@ -17,7 +17,7 @@ public class PathParamArgumentInvalidTypeTest { public static final QuarkusUnitTest test = new QuarkusUnitTest() .withApplicationRoot(root -> { root.addClasses(MontyEcho.class); - }).setExpectedException(WebSocketServerException.class); + }).setExpectedException(WebSocketException.class); @Test void testInvalidArgument() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorTest.java new file mode 100644 index 0000000000000..fc22881e6ced4 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/BasicConnectorTest.java @@ -0,0 +1,137 @@ +package io.quarkus.websockets.next.test.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BasicWebSocketConnector; +import io.quarkus.websockets.next.BasicWebSocketConnector.ExecutionModel; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.vertx.core.Context; + +public class BasicConnectorTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class); + }); + + @Inject + BasicWebSocketConnector connector; + + @TestHTTPResource("/end") + URI uri; + + static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2); + + static final List MESSAGES = new CopyOnWriteArrayList<>(); + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @Test + void testClient() throws InterruptedException { + + assertThrows(NullPointerException.class, () -> connector.baseUri(null)); + assertThrows(NullPointerException.class, () -> connector.path(null)); + assertThrows(NullPointerException.class, () -> connector.addHeader(null, "foo")); + assertThrows(NullPointerException.class, () -> connector.addHeader("foo", null)); + assertThrows(NullPointerException.class, () -> connector.pathParam(null, "foo")); + assertThrows(NullPointerException.class, () -> connector.pathParam("foo", null)); + assertThrows(NullPointerException.class, () -> connector.addSubprotocol(null)); + assertThrows(NullPointerException.class, () -> connector.executionModel(null)); + assertThrows(NullPointerException.class, () -> connector.onBinaryMessage(null)); + assertThrows(NullPointerException.class, () -> connector.onTextMessage(null)); + assertThrows(NullPointerException.class, () -> connector.onOpen(null)); + assertThrows(NullPointerException.class, () -> connector.onClose(null)); + assertThrows(NullPointerException.class, () -> connector.onPong(null)); + assertThrows(NullPointerException.class, () -> connector.onError(null)); + + WebSocketClientConnection connection1 = connector + .baseUri(uri) + .path("/{name}") + .pathParam("name", "Lu") + .onTextMessage((c, m) -> { + assertTrue(Context.isOnWorkerThread()); + String name = c.pathParam("name"); + MESSAGE_LATCH.countDown(); + MESSAGES.add(name + ":" + m); + }) + .onClose((c, s) -> CLOSED_LATCH.countDown()) + .connectAndAwait(); + assertEquals("Lu", connection1.pathParam("name")); + connection1.sendTextAndAwait("Hi!"); + + assertTrue(MESSAGE_LATCH.await(5, TimeUnit.SECONDS)); + // Note that ordering is not guaranteed + assertThat(MESSAGES.get(0)).isIn("Lu:Hello Lu!", "Lu:Hi!"); + assertThat(MESSAGES.get(1)).isIn("Lu:Hello Lu!", "Lu:Hi!"); + + connection1.closeAndAwait(); + assertTrue(CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + + CountDownLatch CONN2_LATCH = new CountDownLatch(1); + WebSocketClientConnection connection2 = BasicWebSocketConnector + .create() + .baseUri(uri) + .path("/Cool") + .executionModel(ExecutionModel.NON_BLOCKING) + .addHeader("X-Test", "foo") + .onTextMessage((c, m) -> { + assertTrue(Context.isOnEventLoopThread()); + // Path params not set + assertNull(c.pathParam("name")); + assertTrue(c.handshakeRequest().path().endsWith("Cool")); + assertEquals("foo", c.handshakeRequest().header("X-Test")); + CONN2_LATCH.countDown(); + }) + .connectAndAwait(); + assertNotNull(connection2); + assertTrue(CONN2_LATCH.await(5, TimeUnit.SECONDS)); + } + + @WebSocket(path = "/end/{name}") + public static class ServerEndpoint { + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnOpen + String open(@PathParam String name) { + return "Hello " + name + "!"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/BroadCastOnClientTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/BroadCastOnClientTest.java new file mode 100644 index 0000000000000..a3e33c2ea3eaa --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/BroadCastOnClientTest.java @@ -0,0 +1,47 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketClientException; + +public class BroadCastOnClientTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class, ClientEndpoint.class); + }) + .setExpectedException(WebSocketClientException.class, true); + + @Test + void testInvalidBroadcast() { + fail(); + } + + @WebSocket(path = "/end") + public static class ServerEndpoint { + + @OnOpen + void open() { + } + + } + + @WebSocketClient(path = "/end") + public static class ClientEndpoint { + + @OnTextMessage(broadcast = true) + String echo(String message) { + return message; + } + + } +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientAutoPingIntervalTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientAutoPingIntervalTest.java new file mode 100644 index 0000000000000..52b4734273670 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientAutoPingIntervalTest.java @@ -0,0 +1,66 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnPongMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketConnector; +import io.vertx.core.buffer.Buffer; + +public class ClientAutoPingIntervalTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class, ClientEndpoint.class); + }).overrideConfigKey("quarkus.websockets-next.client.auto-ping-interval", "200ms"); + + @TestHTTPResource("/") + URI uri; + + @Inject + WebSocketConnector connector; + + @Test + public void testPingPong() throws InterruptedException, ExecutionException { + connector.baseUri(uri).connectAndAwait(); + // Ping messages are sent automatically + assertTrue(ClientEndpoint.PONG.await(5, TimeUnit.SECONDS)); + } + + @WebSocket(path = "/end") + public static class ServerEndpoint { + + @OnOpen + void open() { + } + + } + + @WebSocketClient(path = "/end") + public static class ClientEndpoint { + + static final CountDownLatch PONG = new CountDownLatch(3); + + @OnPongMessage + void pong(Buffer data) { + PONG.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientEndpointTest.java new file mode 100644 index 0000000000000..37b0b71121d95 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientEndpointTest.java @@ -0,0 +1,107 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketConnector; + +public class ClientEndpointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class, ClientEndpoint.class); + }); + + @Inject + WebSocketConnector connector; + + @TestHTTPResource("/") + URI uri; + + @Test + void testClient() throws InterruptedException { + WebSocketClientConnection connection = connector + .baseUri(uri) + .pathParam("name", "Lu") + .connectAndAwait(); + assertEquals("Lu", connection.pathParam("name")); + connection.sendTextAndAwait("Hi!"); + + assertTrue(ClientEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS)); + assertEquals("Lu:Hello Lu!", ClientEndpoint.MESSAGES.get(0)); + assertEquals("Lu:Hi!", ClientEndpoint.MESSAGES.get(1)); + + connection.closeAndAwait(); + assertTrue(ClientEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + } + + @WebSocket(path = "/endpoint/{name}") + public static class ServerEndpoint { + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnOpen + String open(@PathParam String name) { + return "Hello " + name + "!"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + + @WebSocketClient(path = "/endpoint/{name}") + public static class ClientEndpoint { + + static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2); + + static final List MESSAGES = new CopyOnWriteArrayList<>(); + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnTextMessage + void onMessage(@PathParam String name, String message, WebSocketClientConnection connection) { + if (!name.equals(connection.pathParam("name"))) { + throw new IllegalArgumentException(); + } + MESSAGE_LATCH.countDown(); + MESSAGES.add(name + ":" + message); + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientIdConfigBaseUriTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientIdConfigBaseUriTest.java new file mode 100644 index 0000000000000..3cae6ad74ada5 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/ClientIdConfigBaseUriTest.java @@ -0,0 +1,77 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketConnector; + +public class ClientIdConfigBaseUriTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class, ClientEndpoint.class); + }); + + @TestHTTPResource("/") + URI uri; + + @Inject + WebSocketConnector connector; + + @Test + public void testConfiguredBaseUri() throws InterruptedException, ExecutionException { + String key = "c1.base-uri"; + String prev = System.getProperty(key); + System.setProperty(key, uri.toString()); + try { + // No need to pass baseUri + connector.connectAndAwait(); + assertTrue(ServerEndpoint.OPEN_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ClientEndpoint.OPEN_LATCH.await(5, TimeUnit.SECONDS)); + } finally { + if (prev != null) { + System.setProperty(key, prev); + } + } + } + + @WebSocket(path = "/end") + public static class ServerEndpoint { + + static final CountDownLatch OPEN_LATCH = new CountDownLatch(1); + + @OnOpen + void open() { + OPEN_LATCH.countDown(); + } + + } + + @WebSocketClient(path = "/end", clientId = "c1") + public static class ClientEndpoint { + + static final CountDownLatch OPEN_LATCH = new CountDownLatch(1); + + @OnOpen + void open() { + OPEN_LATCH.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/InvalidConnectorInjectionPointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/InvalidConnectorInjectionPointTest.java new file mode 100644 index 0000000000000..d85ad6b0f9250 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/InvalidConnectorInjectionPointTest.java @@ -0,0 +1,39 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.fail; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Unremovable; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.websockets.next.WebSocketClientException; +import io.quarkus.websockets.next.WebSocketConnector; + +public class InvalidConnectorInjectionPointTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(Service.class); + }) + .setExpectedException(WebSocketClientException.class, true); + + @Test + void testInvalidInjectionPoint() { + fail(); + } + + @Unremovable + @Singleton + public static class Service { + + @Inject + WebSocketConnector invalid; + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/OpenClientConnectionsTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/OpenClientConnectionsTest.java new file mode 100644 index 0000000000000..7e9199f0e5dc4 --- /dev/null +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/client/OpenClientConnectionsTest.java @@ -0,0 +1,121 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.BasicWebSocketConnector; +import io.quarkus.websockets.next.HandshakeRequest; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OpenClientConnections; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketConnector; + +public class OpenClientConnectionsTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(ServerEndpoint.class, ClientEndpoint.class); + }); + + @Inject + OpenClientConnections connections; + + @Inject + WebSocketConnector connector; + + @Inject + BasicWebSocketConnector basicConnector; + + @TestHTTPResource("/") + URI uri; + + @Test + void testClient() throws InterruptedException { + for (WebSocketClientConnection c : connections) { + fail("No connection should be found: " + c); + } + + WebSocketClientConnection connection1 = connector + .baseUri(uri) + .addHeader("X-Test", "foo") + .connectAndAwait(); + + WebSocketClientConnection connection2 = connector + .baseUri(uri) + .addHeader("X-Test", "bar") + .connectAndAwait(); + + CountDownLatch CONN3_OPEN_LATCH = new CountDownLatch(1); + WebSocketClientConnection connection3 = basicConnector + .baseUri(uri) + .onOpen(c -> CONN3_OPEN_LATCH.countDown()) + .path("end") + .connectAndAwait(); + + assertTrue(ServerEndpoint.OPEN_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ClientEndpoint.OPEN_LATCH.await(5, TimeUnit.SECONDS)); + + assertNotNull(connections.findByConnectionId(connection1.id())); + assertNotNull(connections.findByConnectionId(connection2.id())); + assertNotNull(connections.findByConnectionId(connection3.id())); + assertEquals(3, connections.listAll().size()); + assertEquals(2, connections.findByClientId("client").size()); + assertEquals(1, connections.findByClientId(BasicWebSocketConnector.class.getName()).size()); + + connection2.closeAndAwait(); + assertTrue(ClientEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertEquals(2, connections.stream().toList().size()); + } + + @WebSocket(path = "/end") + public static class ServerEndpoint { + + static final CountDownLatch OPEN_LATCH = new CountDownLatch(2); + + @OnOpen + void open(HandshakeRequest handshakeRequest) { + if (handshakeRequest.header("X-Test") != null) { + OPEN_LATCH.countDown(); + } + } + + } + + @WebSocketClient(path = "/end", clientId = "client") + public static class ClientEndpoint { + + static final CountDownLatch OPEN_LATCH = new CountDownLatch(2); + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnOpen + void open(HandshakeRequest handshakeRequest) { + if (handshakeRequest.header("X-Test") != null) { + OPEN_LATCH.countDown(); + } + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + +} diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseInSubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseInSubEndpointTest.java index 51221b33afb02..a8933b1f10d3c 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseInSubEndpointTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseInSubEndpointTest.java @@ -7,7 +7,7 @@ import io.quarkus.websockets.next.OnClose; import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class TooManyOnCloseInSubEndpointTest { @@ -16,7 +16,7 @@ public class TooManyOnCloseInSubEndpointTest { .withApplicationRoot(root -> { root.addClasses(ParentEndpoint.class, ParentEndpoint.SubEndpointWithTooManyOnClose.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void verifyThatSubEndpointWithoutTooManyOnCloseFailsToDeploy() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseTest.java index 275c70c6230b7..55ba2dcb7d9c0 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnCloseTest.java @@ -7,7 +7,7 @@ import io.quarkus.websockets.next.OnClose; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class TooManyOnCloseTest { @@ -16,7 +16,7 @@ public class TooManyOnCloseTest { .withApplicationRoot(root -> { root.addClasses(TooManyOnClose.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void verifyThatEndpointWithMultipleOnCloseMethodsFailsToDeploy() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageInSubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageInSubEndpointTest.java index b15900524e29e..0f25cd5f87ebb 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageInSubEndpointTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageInSubEndpointTest.java @@ -6,7 +6,7 @@ import io.quarkus.test.QuarkusUnitTest; import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class TooManyOnMessageInSubEndpointTest { @@ -15,7 +15,7 @@ public class TooManyOnMessageInSubEndpointTest { .withApplicationRoot(root -> { root.addClasses(ParentEndpoint.class, ParentEndpoint.SubEndpointWithTooManyOnMessage.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void verifyThatSubEndpointWithoutTooManyOnMessageFailsToDeploy() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageTest.java index 4b7d8dc826f58..adc76a40ff909 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnMessageTest.java @@ -6,7 +6,7 @@ import io.quarkus.test.QuarkusUnitTest; import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class TooManyOnMessageTest { @@ -15,7 +15,7 @@ public class TooManyOnMessageTest { .withApplicationRoot(root -> { root.addClasses(TooManyOnMessage.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void verifyThatEndpointWithMultipleOnMessageMethodsFailsToDeploy() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenInSubEndpointTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenInSubEndpointTest.java index c729625534bd2..90b5778d9deb6 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenInSubEndpointTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenInSubEndpointTest.java @@ -7,7 +7,7 @@ import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class TooManyOnOpenInSubEndpointTest { @@ -16,7 +16,7 @@ public class TooManyOnOpenInSubEndpointTest { .withApplicationRoot(root -> { root.addClasses(ParentEndpoint.class, ParentEndpoint.SubEndpointWithTooManyOnOpen.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void verifyThatSubEndpointWithoutTooManyOnOpenFailsToDeploy() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenTest.java index 5fd41b498f817..8d8b2b4da1e27 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/endpoints/TooManyOnOpenTest.java @@ -6,7 +6,7 @@ import io.quarkus.test.QuarkusUnitTest; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class TooManyOnOpenTest { @@ -15,7 +15,7 @@ public class TooManyOnOpenTest { .withApplicationRoot(root -> { root.addClasses(TooManyOnOpen.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void verifyThatEndpointWithMultipleOnOpenMethodsFailsToDeploy() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/GlobalErrorHandlerWithPathParamTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/GlobalErrorHandlerWithPathParamTest.java index a4ca1f3684317..be8c7655416c9 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/GlobalErrorHandlerWithPathParamTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/GlobalErrorHandlerWithPathParamTest.java @@ -10,7 +10,7 @@ import io.quarkus.test.QuarkusUnitTest; import io.quarkus.websockets.next.OnError; import io.quarkus.websockets.next.PathParam; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class GlobalErrorHandlerWithPathParamTest { @@ -19,7 +19,7 @@ public class GlobalErrorHandlerWithPathParamTest { .withApplicationRoot(root -> { root.addClasses(GlobalErrorHandlers.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void testMultipleAmbiguousErrorHandlers() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousErrorHandlersTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousErrorHandlersTest.java index cbd863f0babcb..2abbf73a27e87 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousErrorHandlersTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousErrorHandlersTest.java @@ -9,7 +9,7 @@ import io.quarkus.websockets.next.OnError; import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.WebSocket; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class MultipleAmbiguousErrorHandlersTest { @@ -18,7 +18,7 @@ public class MultipleAmbiguousErrorHandlersTest { .withApplicationRoot(root -> { root.addClasses(Endpoint.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void testMultipleAmbiguousErrorHandlers() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousGlobalErrorHandlersTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousGlobalErrorHandlersTest.java index 08751fc65f512..3a653229c1ab7 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousGlobalErrorHandlersTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/errors/MultipleAmbiguousGlobalErrorHandlersTest.java @@ -10,7 +10,7 @@ import io.quarkus.arc.Unremovable; import io.quarkus.test.QuarkusUnitTest; import io.quarkus.websockets.next.OnError; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; public class MultipleAmbiguousGlobalErrorHandlersTest { @@ -19,7 +19,7 @@ public class MultipleAmbiguousGlobalErrorHandlersTest { .withApplicationRoot(root -> { root.addClasses(GlobalErrorHandlers.class); }) - .setExpectedException(WebSocketServerException.class); + .setExpectedException(WebSocketException.class); @Test void testMultipleAmbiguousErrorHandlers() { diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/ConcurrentExecutionModeTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/inboundprocessing/ConcurrentInboundProcessingTest.java similarity index 87% rename from extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/ConcurrentExecutionModeTest.java rename to extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/inboundprocessing/ConcurrentInboundProcessingTest.java index 6b4c0659d5c5c..50c27fec38a65 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/ConcurrentExecutionModeTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/inboundprocessing/ConcurrentInboundProcessingTest.java @@ -1,6 +1,6 @@ -package io.quarkus.websockets.next.test.executionmode; +package io.quarkus.websockets.next.test.inboundprocessing; -import static io.quarkus.websockets.next.WebSocket.ExecutionMode.CONCURRENT; +import static io.quarkus.websockets.next.InboundProcessingMode.CONCURRENT; import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; @@ -19,7 +19,7 @@ import io.quarkus.websockets.next.test.utils.WSClient; import io.vertx.core.Vertx; -public class ConcurrentExecutionModeTest { +public class ConcurrentInboundProcessingTest { @RegisterExtension public static final QuarkusUnitTest test = new QuarkusUnitTest() @@ -46,7 +46,7 @@ void testSimultaneousExecution() { } } - @WebSocket(path = "/sim", executionMode = CONCURRENT) + @WebSocket(path = "/sim", inboundProcessingMode = CONCURRENT) public static class Sim { private final CountDownLatch latch = new CountDownLatch(4); diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/SerialExecutionModeErrorTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/inboundprocessing/SerialInboundProcessingErrorTest.java similarity index 87% rename from extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/SerialExecutionModeErrorTest.java rename to extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/inboundprocessing/SerialInboundProcessingErrorTest.java index 327febe64060b..79d2cafb63af1 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/executionmode/SerialExecutionModeErrorTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/inboundprocessing/SerialInboundProcessingErrorTest.java @@ -1,6 +1,6 @@ -package io.quarkus.websockets.next.test.executionmode; +package io.quarkus.websockets.next.test.inboundprocessing; -import static io.quarkus.websockets.next.WebSocket.ExecutionMode.SERIAL; +import static io.quarkus.websockets.next.InboundProcessingMode.SERIAL; import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; @@ -18,7 +18,7 @@ import io.quarkus.websockets.next.test.utils.WSClient; import io.vertx.core.Vertx; -public class SerialExecutionModeErrorTest { +public class SerialInboundProcessingErrorTest { @RegisterExtension public static final QuarkusUnitTest test = new QuarkusUnitTest() @@ -45,7 +45,7 @@ void testSerialExecution() { } } - @WebSocket(path = "/sim", executionMode = SERIAL) + @WebSocket(path = "/sim", inboundProcessingMode = SERIAL) public static class Sim { @OnTextMessage diff --git a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java index b922be4955450..f8363d3a56632 100644 --- a/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java +++ b/extensions/websockets-next/server/deployment/src/test/java/io/quarkus/websockets/next/test/subprotocol/SubprotocolSelectedTest.java @@ -1,6 +1,6 @@ package io.quarkus.websockets.next.test.subprotocol; -import static io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest.SEC_WEBSOCKET_PROTOCOL; +import static io.quarkus.websockets.next.HandshakeRequest.SEC_WEBSOCKET_PROTOCOL; import static org.junit.jupiter.api.Assertions.assertEquals; import java.net.URI; diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BasicWebSocketConnector.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BasicWebSocketConnector.java new file mode 100644 index 0000000000000..56d6704085116 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BasicWebSocketConnector.java @@ -0,0 +1,181 @@ +package io.quarkus.websockets.next; + +import java.net.URI; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import io.quarkus.arc.Arc; +import io.smallrye.common.annotation.CheckReturnValue; +import io.smallrye.common.annotation.Experimental; +import io.smallrye.mutiny.Uni; +import io.vertx.core.buffer.Buffer; + +/** + * This basic connector can be used to configure and open new client connections. Unlike with {@link WebSocketConnector} a + * client endpoint class is not needed. + *

+ * This construct is not thread-safe and should not be used concurrently. + * + * @see WebSocketClientConnection + */ +@Experimental("This API is experimental and may change in the future") +public interface BasicWebSocketConnector { + + /** + * Obtains a new basic connector. An alternative to {@code @Inject BasicWebSocketConnector}. + * + * @return a new basic connector + */ + static BasicWebSocketConnector create() { + return Arc.container().instance(BasicWebSocketConnector.class).get(); + } + + /** + * Set the base URI. + * + * @param uri + * @return self + */ + BasicWebSocketConnector baseUri(URI uri); + + /** + * Set the path that should be appended to the path of the URI set by {@link #baseUri(URI)}. + *

+ * The path may contain path parameters as defined by {@link WebSocketClient#path()}. In this case, the + * {@link #pathParam(String, String)} method must be used to pass path param values. + * + * @param path + * @return self + */ + BasicWebSocketConnector path(String path); + + /** + * Set the path param. + * + * @param name + * @param value + * @return self + * @throws IllegalArgumentException If the path set by {@link #path(String)} does not contain a parameter with the given + * name + */ + BasicWebSocketConnector pathParam(String name, String value); + + /** + * Add a header used during the initial handshake request. + * + * @param name + * @param value + * @return self + * @see HandshakeRequest + */ + BasicWebSocketConnector addHeader(String name, String value); + + /** + * Add the subprotocol. + * + * @param name + * @param value + * @return self + */ + BasicWebSocketConnector addSubprotocol(String value); + + /** + * Set the execution model for callback handlers. + *

+ * By default, {@link ExecutionModel#BLOCKING} is used. + * + * @return self + * @see #onTextMessage(BiConsumer) + * @see #onBinaryMessage(BiConsumer) + * @see #onPong(BiConsumer) + * @see #onOpen(Consumer) + * @see #onClose(BiConsumer) + * @see #onError(BiConsumer) + */ + BasicWebSocketConnector executionModel(ExecutionModel model); + + /** + * Set a callback to be invoked when a connection to the server is open. + * + * @param consumer + * @return self + * @see #executionModel(ExecutionModel) + */ + BasicWebSocketConnector onOpen(Consumer consumer); + + /** + * Set a callback to be invoked when a text message is received from the server. + * + * @param consumer + * @return self + * @see #executionModel(ExecutionModel) + */ + BasicWebSocketConnector onTextMessage(BiConsumer consumer); + + /** + * Set a callback to be invoked when a binary message is received from the server. + * + * @param consumer + * @return self + * @see #executionModel(ExecutionModel) + */ + BasicWebSocketConnector onBinaryMessage(BiConsumer consumer); + + /** + * Set a callback to be invoked when a pong message is received from the server. + * + * @param consumer + * @return self + * @see #executionModel(ExecutionModel) + */ + BasicWebSocketConnector onPong(BiConsumer consumer); + + /** + * Set a callback to be invoked when a connection to the server is closed. + * + * @param consumer + * @return self + * @see #executionModel(ExecutionModel) + */ + BasicWebSocketConnector onClose(BiConsumer consumer); + + /** + * Set a callback to be invoked when an error occurs. + * + * @param consumer + * @return self + * @see #executionModel(ExecutionModel) + */ + BasicWebSocketConnector onError(BiConsumer consumer); + + /** + * + * @return a new {@link Uni} with a {@link WebSocketClientConnection} item + */ + @CheckReturnValue + Uni connect(); + + /** + * + * @return the client connection + */ + default WebSocketClientConnection connectAndAwait() { + return connect().await().indefinitely(); + } + + enum ExecutionModel { + /** + * Callback may block the current thread. + */ + BLOCKING, + /** + * Callback is executed on the event loop and may not block the current thread. + */ + NON_BLOCKING, + /** + * Callback is executed on a virtual thread. + */ + VIRTUAL_THREAD, + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryDecodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryDecodeException.java index 897092042bfb1..fad7edcffcc57 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryDecodeException.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryDecodeException.java @@ -8,7 +8,7 @@ * @see BinaryMessageCodec */ @Experimental("This API is experimental and may change in the future") -public class BinaryDecodeException extends WebSocketServerException { +public class BinaryDecodeException extends WebSocketException { private static final long serialVersionUID = 6814319993301938091L; diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryEncodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryEncodeException.java index 74eb0a425a7a9..153c72fba8a3e 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryEncodeException.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/BinaryEncodeException.java @@ -7,7 +7,7 @@ * @see BinaryMessageCodec */ @Experimental("This API is experimental and may change in the future") -public class BinaryEncodeException extends WebSocketServerException { +public class BinaryEncodeException extends WebSocketException { private static final long serialVersionUID = -8042792962717461873L; diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/HandshakeRequest.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/HandshakeRequest.java new file mode 100644 index 0000000000000..052dda407a11b --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/HandshakeRequest.java @@ -0,0 +1,89 @@ +package io.quarkus.websockets.next; + +import java.util.List; +import java.util.Map; + +/** + * Provides some useful information about the initial handshake request. + */ +public interface HandshakeRequest { + + /** + * The name is case insensitive. + * + * @param name + * @return the first header value for the given header name, or {@code null} + */ + String header(String name); + + /** + * The name is case insensitive. + * + * @param name + * @return an immutable list of header values for the given header name, never {@code null} + */ + List headers(String name); + + /** + * Returned header names are lower case. + * + * @return an immutable map of header names to header values + */ + Map> headers(); + + /** + * + * @return the scheme + */ + String scheme(); + + /** + * + * @return the host + */ + String host(); + + /** + * + * @return the port + */ + int port(); + + /** + * + * @return the path + */ + String path(); + + /** + * + * @return the query string + */ + String query(); + + /** + * See The WebSocket Protocol. + */ + public static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; + + /** + * See The WebSocket Protocol. + */ + public static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions"; + + /** + * See The WebSocket Protocol. + */ + public static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + + /** + * See The WebSocket Protocol. + */ + public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + + /** + * See The WebSocket Protocol. + */ + public static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; + +} \ No newline at end of file diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/InboundProcessingMode.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/InboundProcessingMode.java new file mode 100644 index 0000000000000..dc206f4a23743 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/InboundProcessingMode.java @@ -0,0 +1,24 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +/** + * Defines the mode used to process incoming messages for a specific connection. + * + * @see WebSocketConnection + * @see WebSocketClientConnection + */ +@Experimental("This API is experimental and may change in the future") +public enum InboundProcessingMode { + + /** + * Messages are processed serially, ordering is guaranteed. + */ + SERIAL, + + /** + * Messages are processed concurrently, there are no ordering guarantees. + */ + CONCURRENT, + +} \ No newline at end of file diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java index ee8957a0632f6..2eb5f62221ce4 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnBinaryMessage.java @@ -6,11 +6,10 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; -import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** - * A {@link WebSocket} endpoint method annotated with this annotation consumes binary messages. + * {@link WebSocket} and {@link WebSocketClient} endpoint methods annotated with this annotation consume binary messages. *

* The method must accept exactly one message parameter. A binary message is always represented as a * {@link io.vertx.core.buffer.Buffer}. Therefore, the following conversion rules @@ -31,7 +30,7 @@ *

* The method may also accept the following parameters: *

    - *
  • {@link WebSocketConnection}
  • + *
  • {@link WebSocketConnection}/{@link WebSocketClientConnection}; depending on the endpoint type
  • *
  • {@link HandshakeRequest}
  • *
  • {@link String} parameters annotated with {@link PathParam}
  • *
@@ -44,6 +43,7 @@ public @interface OnBinaryMessage { /** + * Broadcasting is only supported for server endpoints annotated with {@link WebSocket}. * * @return {@code true} if all the connected clients should receive the objects returned by the annotated method * @see WebSocketConnection#broadcast() diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java index 636a6ec3fb723..39ed396ec24ba 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnClose.java @@ -6,17 +6,16 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; -import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** - * A {@link WebSocket} endpoint method annotated with this annotation is invoked when the client disconnects from the - * socket. + * {@link WebSocket} and {@link WebSocketClient} endpoint methods annotated with this annotation are invoked when a connection + * is closed. *

* The method must return {@code void} or {@code io.smallrye.mutiny.Uni}. * The method may accept the following parameters: *

    - *
  • {@link WebSocketConnection}
  • + *
  • {@link WebSocketConnection}/{@link WebSocketClientConnection}; depending on the endpoint type
  • *
  • {@link HandshakeRequest}
  • *
  • {@link String} parameters annotated with {@link PathParam}
  • *
diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnError.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnError.java index 890219ab6450e..ee277a3373b71 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnError.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnError.java @@ -6,11 +6,11 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; -import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** - * A {@link WebSocket} endpoint method annotated with this annotation is invoked when an error occurs. + * {@link WebSocket} and {@link WebSocketClient} endpoint methods annotated with this annotation are invoked when an error + * occurs. *

* It is used when an endpoint callback throws a runtime error, or when a conversion errors occurs, or when a returned * {@link io.smallrye.mutiny.Uni} receives a failure. @@ -18,7 +18,7 @@ * The method must accept exactly one "error" parameter, i.e. a parameter that is assignable from {@link java.lang.Throwable}. * The method may also accept the following parameters: *

    - *
  • {@link WebSocketConnection}
  • + *
  • {@link WebSocketConnection}/{@link WebSocketClientConnection}; depending on the endpoint type
  • *
  • {@link HandshakeRequest}
  • *
  • {@link String} parameters annotated with {@link PathParam}
  • *
@@ -26,9 +26,12 @@ * An endpoint may declare multiple methods annotated with this annotation. However, each method must declare a different error * parameter. The method that declares a most-specific supertype of the actual exception is selected. *

- * This annotation can be also used to declare a global error handler, i.e. a method that is not declared on a {@link WebSocket} - * endpoint. Such a method may not accept {@link PathParam} paremeters. Error handlers declared on an endpoint take - * precedence over the global error handlers. + * This annotation can be also used to declare a global error handler, i.e. a method that is not declared on a + * {@link WebSocket}/{@link WebSocketClient} endpoint. Such a method may not accept {@link PathParam} paremeters. If a global + * error handler accepts {@link WebSocketConnection} then it's only applied to server-side errors. If a global error + * handler accepts {@link WebSocketClientConnection} then it's only applied to client-side errors. + * + * Error handlers declared on an endpoint take precedence over the global error handlers. */ @Retention(RUNTIME) @Target(METHOD) diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java index 6aff4281fbdab..7489fb35c89ce 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnOpen.java @@ -6,16 +6,15 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; -import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** - * A {@link WebSocket} endpoint method annotated with this annotation is invoked when the client connects to a web socket - * endpoint. + * {@link WebSocket} and {@link WebSocketClient} endpoint methods annotated with this annotation are invoked when a new + * connection is opened. *

* The method may accept the following parameters: *

    - *
  • {@link WebSocketConnection}
  • + *
  • {@link WebSocketConnection}/{@link WebSocketClientConnection}; depending on the endpoint type
  • *
  • {@link HandshakeRequest}
  • *
  • {@link String} parameters annotated with {@link PathParam}
  • *
@@ -28,6 +27,8 @@ public @interface OnOpen { /** + * Broadcasting is only supported for server endpoints annotated with {@link WebSocket}. + * * @return {@code true} if all the connected clients should receive the objects emitted by the annotated method * @see WebSocketConnection#broadcast() */ diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java index e628954dafec4..a1b077704fb94 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnPongMessage.java @@ -6,16 +6,15 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; -import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** - * A {@link WebSocket} endpoint method annotated with this annotation consumes pong messages. + * {@link WebSocket} and {@link WebSocketClient} endpoint methods annotated with this annotation consume pong messages. * * The method must accept exactly one pong message parameter represented as a {@link io.vertx.core.buffer.Buffer}. The method * may also accept the following parameters: *
    - *
  • {@link WebSocketConnection}
  • + *
  • {@link WebSocketConnection}/{@link WebSocketClientConnection}; depending on the endpoint type
  • *
  • {@link HandshakeRequest}
  • *
  • {@link String} parameters annotated with {@link PathParam}
  • *
diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java index 922101c3dda0f..ac7cd56b27a42 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OnTextMessage.java @@ -6,11 +6,10 @@ import java.lang.annotation.Retention; import java.lang.annotation.Target; -import io.quarkus.websockets.next.WebSocketConnection.HandshakeRequest; import io.smallrye.common.annotation.Experimental; /** - * A {@link WebSocket} endpoint method annotated with this annotation consumes text messages. + * {@link WebSocket} and {@link WebSocketClient} endpoint methods annotated with this annotation consume text messages. *

* The method must accept exactly one message parameter. A text message is always represented as a {@link String}. Therefore, * the following conversion rules apply. The types listed @@ -30,7 +29,7 @@ *

* The method may also accept the following parameters: *

    - *
  • {@link WebSocketConnection}
  • + *
  • {@link WebSocketConnection}/{@link WebSocketClientConnection}; depending on the endpoint type
  • *
  • {@link HandshakeRequest}
  • *
  • {@link String} parameters annotated with {@link PathParam}
  • *
@@ -43,6 +42,7 @@ public @interface OnTextMessage { /** + * Broadcasting is only supported for server endpoints annotated with {@link WebSocket}. * * @return {@code true} if all the connected clients should receive the objects returned by the annotated method * @see WebSocketConnection#broadcast() diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OpenClientConnections.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OpenClientConnections.java new file mode 100644 index 0000000000000..e4270cc8b54ae --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OpenClientConnections.java @@ -0,0 +1,55 @@ +package io.quarkus.websockets.next; + +import java.util.Collection; +import java.util.Optional; +import java.util.stream.Stream; + +import io.smallrye.common.annotation.Experimental; + +/** + * Provides convenient access to all open client connections. + *

+ * Quarkus provides a built-in CDI bean with the {@link jakarta.inject.Singleton} scope that implements this interface. + */ +@Experimental("This API is experimental and may change in the future") +public interface OpenClientConnections extends Iterable { + + /** + * Returns an immutable snapshot of all open connections at the given time. + * + * @return an immutable collection of all open connections + */ + default Collection listAll() { + return stream().toList(); + } + + /** + * Returns an immutable snapshot of all open connections for the given client id. + * + * @param endpointId + * @return an immutable collection of all open connections for the given client id + * @see WebSocketClient#clientId() + */ + default Collection findByClientId(String clientId) { + return stream().filter(c -> c.clientId().equals(clientId)).toList(); + } + + /** + * Returns the open connection with the given id. + * + * @param connectionId + * @return the open connection or empty {@link Optional} if no open connection with the given id exists + * @see WebSocketConnection#id() + */ + default Optional findByConnectionId(String connectionId) { + return stream().filter(c -> c.id().equals(connectionId)).findFirst(); + } + + /** + * Returns the stream of all open connections at the given time. + * + * @return the stream of open connections + */ + Stream stream(); + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OpenConnections.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OpenConnections.java index b5a6859daafd0..c8a5c797289c7 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OpenConnections.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/OpenConnections.java @@ -7,7 +7,7 @@ import io.smallrye.common.annotation.Experimental; /** - * Provides convenient access to all open connections from clients to {@link WebSocket} endpoints on the server. + * Provides convenient access to all open connections. *

* Quarkus provides a built-in CDI bean with the {@link jakarta.inject.Singleton} scope that implements this interface. */ diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextDecodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextDecodeException.java index 62e49f246946f..c3b565741a868 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextDecodeException.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextDecodeException.java @@ -7,7 +7,7 @@ * @see TextMessageCodec */ @Experimental("This API is experimental and may change in the future") -public class TextDecodeException extends WebSocketServerException { +public class TextDecodeException extends WebSocketException { private static final long serialVersionUID = 6814319993301938091L; diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextEncodeException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextEncodeException.java index 47a74133f1779..d41aaf9615922 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextEncodeException.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/TextEncodeException.java @@ -7,7 +7,7 @@ * @see TextMessageCodec */ @Experimental("This API is experimental and may change in the future") -public class TextEncodeException extends WebSocketServerException { +public class TextEncodeException extends WebSocketException { private static final long serialVersionUID = 837621296462089705L; diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocket.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocket.java index 736c080cc92a6..819bce0cd4064 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocket.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocket.java @@ -48,32 +48,13 @@ public String endpointId() default FCQN_NAME; /** - * The execution mode used to process incoming messages for a specific connection. + * The mode used to process incoming messages for a specific connection. */ - public ExecutionMode executionMode() default ExecutionMode.SERIAL; + public InboundProcessingMode inboundProcessingMode() default InboundProcessingMode.SERIAL; /** * Constant value for {@link #endpointId()} indicating that the fully qualified name of the annotated class should be used. */ String FCQN_NAME = "<>"; - /** - * Defines the execution mode used to process incoming messages for a specific connection. - * - * @see WebSocketConnection - */ - enum ExecutionMode { - - /** - * Messages are processed serially, ordering is guaranteed. - */ - SERIAL, - - /** - * Messages are processed concurrently, there are no ordering guarantees. - */ - CONCURRENT, - - } - } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClient.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClient.java new file mode 100644 index 0000000000000..92c960b5fc4ae --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClient.java @@ -0,0 +1,60 @@ +package io.quarkus.websockets.next; + +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Singleton; + +import io.smallrye.common.annotation.Experimental; + +/** + * Denotes a WebSocket client endpoint. + *

+ * An endpoint must declare a method annotated with {@link OnTextMessage}, {@link OnBinaryMessage}, {@link OnPongMessage} or + * {@link OnOpen}. An endpoint may declare a method annotated with {@link OnClose}. + * + *

Lifecycle and concurrency

+ * Client endpoint implementation class must be a CDI bean. If no scope annotation is defined then {@link Singleton} is used. + * {@link ApplicationScoped} and {@link Singleton} client endpoints are shared accross all WebSocket client connections. + * Therefore, implementations should be either stateless or thread-safe. + */ +@Retention(RUNTIME) +@Target(TYPE) +@Experimental("This API is experimental and may change in the future") +public @interface WebSocketClient { + + /** + * The path of the endpoint on the server. + *

+ * It is possible to match path parameters. The placeholder of a path parameter consists of the parameter name surrounded by + * curly brackets. The actual value of a path parameter can be obtained using + * {@link WebSocketClientConnection#pathParam(String)}. For example, the path /foo/{bar} defines the path + * parameter {@code bar}. + * + * @see WebSocketConnection#pathParam(String) + */ + public String path(); + + /** + * By default, the fully qualified name of the annotated class is used. + * + * @return the endpoint id + * @see WebSocketClientConnection#clientId() + */ + public String clientId() default FCQN_NAME; + + /** + * The execution mode used to process incoming messages for a specific connection. + */ + public InboundProcessingMode inboundProcessingMode() default InboundProcessingMode.SERIAL; + + /** + * Constant value for {@link #clientId()} indicating that the fully qualified name of the annotated class should be used. + */ + String FCQN_NAME = "<>"; + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientConnection.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientConnection.java new file mode 100644 index 0000000000000..9a987ed0004c9 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientConnection.java @@ -0,0 +1,74 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.CheckReturnValue; +import io.smallrye.common.annotation.Experimental; +import io.smallrye.mutiny.Uni; + +/** + * This interface represents a client connection to a WebSocket endpoint. + *

+ * Quarkus provides a built-in CDI bean that implements this interface and can be injected in a {@link WebSocketClient} + * endpoint and used to interact with the connected server. + */ +@Experimental("This API is experimental and may change in the future") +public interface WebSocketClientConnection extends Sender, BlockingSender { + + /** + * + * @return the unique identifier assigned to this connection + */ + String id(); + + /* + * @return the client id + */ + String clientId(); + + /** + * + * @param name + * @return the actual value of the path parameter or null + * @see WebSocketClient#path() + */ + String pathParam(String name); + + /** + * @return {@code true} if the HTTP connection is encrypted via SSL/TLS + */ + boolean isSecure(); + + /** + * @return {@code true} if the WebSocket is closed + */ + boolean isClosed(); + + /** + * + * @return {@code true} if the WebSocket is open + */ + default boolean isOpen() { + return !isClosed(); + } + + /** + * Close the connection. + * + * @return a new {@link Uni} with a {@code null} item + */ + @CheckReturnValue + Uni close(); + + /** + * Close the connection. + */ + default void closeAndAwait() { + close().await().indefinitely(); + } + + /** + * + * @return the handshake request + */ + HandshakeRequest handshakeRequest(); + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientException.java new file mode 100644 index 0000000000000..50dd541d7c3b3 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientException.java @@ -0,0 +1,22 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +@Experimental("This API is experimental and may change in the future") +public class WebSocketClientException extends WebSocketException { + + private static final long serialVersionUID = -4213710383874397185L; + + public WebSocketClientException(String message, Throwable cause) { + super(message, cause); + } + + public WebSocketClientException(String message) { + super(message); + } + + public WebSocketClientException(Throwable cause) { + super(cause); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java index 63a3b66058697..e59e5cb0dbea0 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java @@ -1,8 +1,6 @@ package io.quarkus.websockets.next; import java.time.Instant; -import java.util.List; -import java.util.Map; import java.util.Set; import java.util.function.Predicate; @@ -127,89 +125,4 @@ interface BroadcastSender extends Sender, BlockingSender { } - /** - * Provides some useful information about the initial handshake request. - */ - interface HandshakeRequest { - - /** - * The name is case insensitive. - * - * @param name - * @return the first header value for the given header name, or {@code null} - */ - String header(String name); - - /** - * The name is case insensitive. - * - * @param name - * @return an immutable list of header values for the given header name, never {@code null} - */ - List headers(String name); - - /** - * Returned header names are lower case. - * - * @return an immutable map of header names to header values - */ - Map> headers(); - - /** - * - * @return the scheme - */ - String scheme(); - - /** - * - * @return the host - */ - String host(); - - /** - * - * @return the port - */ - int port(); - - /** - * - * @return the path - */ - String path(); - - /** - * - * @return the query string - */ - String query(); - - /** - * See The WebSocket Protocol. - */ - public static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; - - /** - * See The WebSocket Protocol. - */ - public static final String SEC_WEBSOCKET_EXTENSIONS = "Sec-WebSocket-Extensions"; - - /** - * See The WebSocket Protocol. - */ - public static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; - - /** - * See The WebSocket Protocol. - */ - public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; - - /** - * See The WebSocket Protocol. - */ - public static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; - - } - } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnector.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnector.java new file mode 100644 index 0000000000000..4b771a66c7833 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnector.java @@ -0,0 +1,74 @@ +package io.quarkus.websockets.next; + +import java.net.URI; + +import io.smallrye.common.annotation.CheckReturnValue; +import io.smallrye.common.annotation.Experimental; +import io.smallrye.mutiny.Uni; + +/** + * This connector can be used to configure and open new client connections using a client endpoint class. + *

+ * This construct is not thread-safe and should not be used concurrently. + * + * @param The client endpoint class + * @see WebSocketClient + * @see WebSocketClientConnection + */ +@Experimental("This API is experimental and may change in the future") +public interface WebSocketConnector { + + /** + * Set the base URI. + * + * @param baseUri + * @return self + */ + WebSocketConnector baseUri(URI baseUri); + + /** + * Set the path param. + * + * @param name + * @param value + * @return self + * @throws IllegalArgumentException If the client endpoint path does not contain a parameter with the given name + * @see WebSocketClient#path() + */ + WebSocketConnector pathParam(String name, String value); + + /** + * Add a header used during the initial handshake request. + * + * @param name + * @param value + * @return self + * @see HandshakeRequest + */ + WebSocketConnector addHeader(String name, String value); + + /** + * Add the subprotocol. + * + * @param name + * @param value + * @return self + */ + WebSocketConnector addSubprotocol(String value); + + /** + * + * @return a new {@link Uni} with a {@link WebSocketClientConnection} item + */ + @CheckReturnValue + Uni connect(); + + /** + * + * @return the client connection + */ + default WebSocketClientConnection connectAndAwait() { + return connect().await().indefinitely(); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketException.java new file mode 100644 index 0000000000000..340ff835c1b6d --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketException.java @@ -0,0 +1,22 @@ +package io.quarkus.websockets.next; + +import io.smallrye.common.annotation.Experimental; + +@Experimental("This API is experimental and may change in the future") +public class WebSocketException extends RuntimeException { + + private static final long serialVersionUID = 903932032264812404L; + + public WebSocketException(String message, Throwable cause) { + super(message, cause); + } + + public WebSocketException(String message) { + super(message); + } + + public WebSocketException(Throwable cause) { + super(cause); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerException.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerException.java index c226d78983d47..b4cd21e1c423b 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerException.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketServerException.java @@ -3,9 +3,9 @@ import io.smallrye.common.annotation.Experimental; @Experimental("This API is experimental and may change in the future") -public class WebSocketServerException extends RuntimeException { +public class WebSocketServerException extends WebSocketException { - private static final long serialVersionUID = 903932032264812404L; + private static final long serialVersionUID = 815788270725783535L; public WebSocketServerException(String message, Throwable cause) { super(message, cause); diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java new file mode 100644 index 0000000000000..dff4780aa45c7 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java @@ -0,0 +1,43 @@ +package io.quarkus.websockets.next; + +import java.time.Duration; +import java.util.Optional; +import java.util.OptionalInt; + +import io.quarkus.runtime.annotations.ConfigPhase; +import io.quarkus.runtime.annotations.ConfigRoot; +import io.smallrye.config.ConfigMapping; +import io.smallrye.config.WithDefault; + +@ConfigMapping(prefix = "quarkus.websockets-next.client") +@ConfigRoot(phase = ConfigPhase.RUN_TIME) +public interface WebSocketsClientRuntimeConfig { + + /** + * Compression Extensions for WebSocket are supported by default. + *

+ * See also RFC 7692 + */ + @WithDefault("false") + boolean offerPerMessageCompression(); + + /** + * The compression level must be a value between 0 and 9. The default value is + * {@value io.vertx.core.http.HttpClientOptions#DEFAULT_WEBSOCKET_COMPRESSION_LEVEL}. + */ + OptionalInt compressionLevel(); + + /** + * The maximum size of a message in bytes. The default values is + * {@value io.vertx.core.http.HttpClientOptions#DEFAULT_MAX_WEBSOCKET_MESSAGE_SIZE}. + */ + OptionalInt maxMessageSize(); + + /** + * The interval after which, when set, the client sends a ping message to a connected server automatically. + *

+ * Ping messages are not sent automatically by default. + */ + Optional autoPingInterval(); + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java index 9566afca3ea5a..28e9d284c2fce 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsServerRuntimeConfig.java @@ -9,7 +9,6 @@ import io.quarkus.runtime.annotations.ConfigRoot; import io.smallrye.config.ConfigMapping; import io.smallrye.config.WithDefault; -import io.vertx.core.http.HttpServerOptions; @ConfigMapping(prefix = "quarkus.websockets-next.server") @ConfigRoot(phase = ConfigPhase.RUN_TIME) @@ -30,13 +29,13 @@ public interface WebSocketsServerRuntimeConfig { /** * The compression level must be a value between 0 and 9. The default value is - * {@value HttpServerOptions#DEFAULT_WEBSOCKET_COMPRESSION_LEVEL}. + * {@value io.vertx.core.http.HttpServerOptions#DEFAULT_WEBSOCKET_COMPRESSION_LEVEL}. */ OptionalInt compressionLevel(); /** * The maximum size of a message in bytes. The default values is - * {@value HttpServerOptions#DEFAULT_MAX_WEBSOCKET_MESSAGE_SIZE}. + * {@value io.vertx.core.http.HttpServerOptions#DEFAULT_MAX_WEBSOCKET_MESSAGE_SIZE}. */ OptionalInt maxMessageSize(); diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java new file mode 100644 index 0000000000000..e059f5e12c6b9 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java @@ -0,0 +1,292 @@ +package io.quarkus.websockets.next.runtime; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import jakarta.enterprise.context.Dependent; +import jakarta.enterprise.inject.Typed; + +import org.jboss.logging.Logger; + +import io.quarkus.virtual.threads.VirtualThreadsRecorder; +import io.quarkus.websockets.next.BasicWebSocketConnector; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketClientException; +import io.quarkus.websockets.next.WebSocketsClientRuntimeConfig; +import io.smallrye.mutiny.Uni; +import io.smallrye.mutiny.vertx.UniHelper; +import io.vertx.core.Context; +import io.vertx.core.Handler; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.http.WebSocketClientOptions; +import io.vertx.core.http.WebSocketConnectOptions; + +@Typed(BasicWebSocketConnector.class) +@Dependent +public class BasicWebSocketConnectorImpl extends WebSocketConnectorBase + implements BasicWebSocketConnector { + + private static final Logger LOG = Logger.getLogger(BasicWebSocketConnectorImpl.class); + + // mutable state + + private ExecutionModel executionModel = ExecutionModel.BLOCKING; + + private Consumer openHandler; + + private BiConsumer textMessageHandler; + + private BiConsumer binaryMessageHandler; + + private BiConsumer pongMessageHandler; + + private BiConsumer closeHandler; + + private BiConsumer errorHandler; + + BasicWebSocketConnectorImpl(Vertx vertx, Codecs codecs, ClientConnectionManager connectionManager, + WebSocketsClientRuntimeConfig config) { + super(vertx, codecs, connectionManager, config); + } + + @Override + public BasicWebSocketConnector executionModel(ExecutionModel model) { + this.executionModel = Objects.requireNonNull(model); + return self(); + } + + @Override + public BasicWebSocketConnector path(String path) { + setPath(Objects.requireNonNull(path)); + return self(); + } + + @Override + public BasicWebSocketConnector onOpen(Consumer consumer) { + this.openHandler = Objects.requireNonNull(consumer); + return self(); + } + + @Override + public BasicWebSocketConnector onTextMessage(BiConsumer consumer) { + this.textMessageHandler = Objects.requireNonNull(consumer); + return self(); + } + + @Override + public BasicWebSocketConnector onBinaryMessage(BiConsumer consumer) { + this.binaryMessageHandler = Objects.requireNonNull(consumer); + return self(); + } + + @Override + public BasicWebSocketConnector onPong(BiConsumer consumer) { + this.pongMessageHandler = Objects.requireNonNull(consumer); + return self(); + } + + @Override + public BasicWebSocketConnector onClose(BiConsumer consumer) { + this.closeHandler = Objects.requireNonNull(consumer); + return self(); + } + + @Override + public BasicWebSocketConnector onError(BiConsumer consumer) { + this.errorHandler = Objects.requireNonNull(consumer); + return self(); + } + + @Override + public Uni connect() { + if (baseUri == null) { + throw new WebSocketClientException("Endpoint URI not set!"); + } + + // Currently we create a new client for each connection + // The client is closed when the connection is closed + // TODO would it make sense to share clients? + WebSocketClientOptions clientOptions = new WebSocketClientOptions(); + if (config.offerPerMessageCompression()) { + clientOptions.setTryUsePerMessageCompression(true); + if (config.compressionLevel().isPresent()) { + clientOptions.setCompressionLevel(config.compressionLevel().getAsInt()); + } + } + if (config.maxMessageSize().isPresent()) { + clientOptions.setMaxMessageSize(config.maxMessageSize().getAsInt()); + } + + WebSocketClient client = vertx.createWebSocketClient(); + + WebSocketConnectOptions connectOptions = new WebSocketConnectOptions() + .setSsl(baseUri.getScheme().equals("https")) + .setHost(baseUri.getHost()) + .setPort(baseUri.getPort()); + StringBuilder requestUri = new StringBuilder(); + String mergedPath = mergePath(baseUri.getPath(), replacePathParameters(path)); + requestUri.append(mergedPath); + if (baseUri.getQuery() != null) { + requestUri.append("?").append(baseUri.getQuery()); + } + connectOptions.setURI(requestUri.toString()); + for (Entry> e : headers.entrySet()) { + for (String val : e.getValue()) { + connectOptions.addHeader(e.getKey(), val); + } + } + subprotocols.forEach(connectOptions::addSubProtocol); + + URI serverEndpointUri; + try { + serverEndpointUri = new URI(baseUri.getScheme(), baseUri.getUserInfo(), baseUri.getHost(), baseUri.getPort(), + mergedPath, + baseUri.getQuery(), baseUri.getFragment()); + } catch (URISyntaxException e) { + throw new WebSocketClientException(e); + } + + return UniHelper.toUni(client.connect(connectOptions)) + .map(ws -> { + String clientId = BasicWebSocketConnector.class.getName(); + WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId, ws, + codecs, + pathParams, + serverEndpointUri, + headers); + LOG.debugf("Client connection created: %s", connection); + connectionManager.add(BasicWebSocketConnectorImpl.class.getName(), connection); + + if (openHandler != null) { + doExecute(connection, null, (c, ignored) -> openHandler.accept(c)); + } + + if (textMessageHandler != null) { + ws.textMessageHandler(new Handler() { + @Override + public void handle(String event) { + doExecute(connection, event, textMessageHandler); + } + }); + } + + if (binaryMessageHandler != null) { + ws.binaryMessageHandler(new Handler() { + + @Override + public void handle(Buffer event) { + doExecute(connection, event, binaryMessageHandler); + } + }); + } + + if (pongMessageHandler != null) { + ws.pongHandler(new Handler() { + + @Override + public void handle(Buffer event) { + doExecute(connection, event, pongMessageHandler); + } + }); + } + + if (errorHandler != null) { + ws.exceptionHandler(new Handler() { + + @Override + public void handle(Throwable event) { + doExecute(connection, event, errorHandler); + } + }); + } + + ws.closeHandler(new Handler() { + + @Override + public void handle(Void event) { + if (closeHandler != null) { + doExecute(connection, ws.closeStatusCode(), closeHandler); + } + connectionManager.remove(BasicWebSocketConnectorImpl.class.getName(), connection); + client.close(); + } + + }); + + return connection; + }); + } + + private void doExecute(WebSocketClientConnectionImpl connection, MESSAGE message, + BiConsumer consumer) { + // We always invoke callbacks on a new duplicated context and offload if blocking/virtualThread is needed + Context context = vertx.getOrCreateContext(); + ContextSupport.createNewDuplicatedContext(context, connection).runOnContext(new Handler() { + @Override + public void handle(Void event) { + if (executionModel == ExecutionModel.VIRTUAL_THREAD) { + VirtualThreadsRecorder.getCurrent().execute(new Runnable() { + public void run() { + try { + consumer.accept(connection, message); + } catch (Exception e) { + LOG.errorf(e, "Unable to call handler: " + connection); + } + } + }); + } else if (executionModel == ExecutionModel.BLOCKING) { + vertx.executeBlocking(new Callable() { + @Override + public Void call() { + try { + consumer.accept(connection, message); + } catch (Exception e) { + LOG.errorf(e, "Unable to call handler: " + connection); + } + return null; + } + }, false); + } else { + // Non-blocking -> event loop + try { + consumer.accept(connection, message); + } catch (Exception e) { + LOG.errorf(e, "Unable to call handler: " + connection); + } + } + } + }); + } + + private String mergePath(String path1, String path2) { + StringBuilder path = new StringBuilder(); + if (path1 != null) { + path.append(path1); + } + if (path2 != null) { + if (path1.endsWith("/")) { + if (path2.startsWith("/")) { + path.append(path2.substring(1)); + } else { + path.append(path2); + } + } else { + if (path2.startsWith("/")) { + path.append(path2); + } else { + path.append(path2.substring(1)); + } + } + } + return path.toString(); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ClientConnectionManager.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ClientConnectionManager.java new file mode 100644 index 0000000000000..a81eef5786d64 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ClientConnectionManager.java @@ -0,0 +1,102 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.stream.Stream; + +import jakarta.annotation.PreDestroy; +import jakarta.inject.Singleton; + +import org.jboss.logging.Logger; + +import io.quarkus.websockets.next.OpenClientConnections; +import io.quarkus.websockets.next.WebSocketClientConnection; + +@Singleton +public class ClientConnectionManager implements OpenClientConnections { + + private static final Logger LOG = Logger.getLogger(ClientConnectionManager.class); + + private final ConcurrentMap> endpointToConnections = new ConcurrentHashMap<>(); + + private final List listeners = new CopyOnWriteArrayList<>(); + + @Override + public Iterator iterator() { + return stream().iterator(); + } + + @Override + public Stream stream() { + return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketClientConnection::isOpen); + } + + void add(String endpoint, WebSocketClientConnection connection) { + LOG.debugf("Add client connection: %s", connection); + if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) { + if (!listeners.isEmpty()) { + for (ClientConnectionListener listener : listeners) { + try { + listener.connectionAdded(endpoint, connection); + } catch (Exception e) { + LOG.warnf("Unable to call listener#connectionAdded() on [%s]: %s", listener.getClass(), + e.toString()); + } + } + } + } + } + + void remove(String endpoint, WebSocketClientConnection connection) { + LOG.debugf("Remove client connection: %s", connection); + Set connections = endpointToConnections.get(endpoint); + if (connections != null) { + if (connections.remove(connection)) { + if (!listeners.isEmpty()) { + for (ClientConnectionListener listener : listeners) { + try { + listener.connectionRemoved(endpoint, connection.id()); + } catch (Exception e) { + LOG.warnf("Unable to call listener#connectionRemoved() on [%s]: %s", listener.getClass(), + e.toString()); + } + } + } + } + } + } + + /** + * + * @param endpoint + * @return the connections for the given client endpoint, never {@code null} + */ + public Set getConnections(String endpoint) { + Set ret = endpointToConnections.get(endpoint); + if (ret == null) { + return Set.of(); + } + return ret; + } + + public void addListener(ClientConnectionListener listener) { + this.listeners.add(listener); + } + + @PreDestroy + void destroy() { + endpointToConnections.clear(); + } + + public interface ClientConnectionListener { + + void connectionAdded(String endpoint, WebSocketClientConnection connection); + + void connectionRemoved(String endpoint, String connectionId); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java index 0b3b7f3abaec1..d4bd9ad1424b3 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Codecs.java @@ -13,7 +13,7 @@ import io.quarkus.websockets.next.TextDecodeException; import io.quarkus.websockets.next.TextEncodeException; import io.quarkus.websockets.next.TextMessageCodec; -import io.quarkus.websockets.next.WebSocketServerException; +import io.quarkus.websockets.next.WebSocketException; import io.vertx.core.buffer.Buffer; @Singleton @@ -139,7 +139,7 @@ public Buffer binaryEncode(T message, Class codecBeanClass) { throw noCodecToEncode(true, message, type); } - WebSocketServerException noCodecToDecode(String text, Buffer bytes, Type type) { + WebSocketException noCodecToDecode(String text, Buffer bytes, Type type) { String message = String.format("No %s codec handles the type %s", bytes != null ? "binary" : "text", type); if (bytes != null) { return new BinaryDecodeException(bytes, message); @@ -148,7 +148,7 @@ WebSocketServerException noCodecToDecode(String text, Buffer bytes, Type type) { } } - WebSocketServerException noCodecToEncode(boolean binary, Object encodedObject, Type type) { + WebSocketException noCodecToEncode(boolean binary, Object encodedObject, Type type) { String message = String.format("No %s codec handles the type %s", binary ? "binary" : "text", type); if (binary) { return new BinaryEncodeException(encodedObject, message); @@ -157,7 +157,7 @@ WebSocketServerException noCodecToEncode(boolean binary, Object encodedObject, T } } - WebSocketServerException unableToEncode(boolean binary, MessageCodec codec, Object encodedObject, Exception e) { + WebSocketException unableToEncode(boolean binary, MessageCodec codec, Object encodedObject, Exception e) { String message = String.format("Unable to encode %s message with %s", binary ? "binary" : "text", codec.getClass().getName()); if (binary) { @@ -167,7 +167,7 @@ WebSocketServerException unableToEncode(boolean binary, MessageCodec codec } } - WebSocketServerException unableToDecode(String text, Buffer bytes, MessageCodec codec, Exception e) { + WebSocketException unableToDecode(String text, Buffer bytes, MessageCodec codec, Exception e) { String message = String.format("Unable to decode %s message with %s", bytes != null ? "binary" : "text", codec.getClass().getName()); if (bytes != null) { @@ -177,7 +177,7 @@ WebSocketServerException unableToDecode(String text, Buffer bytes, MessageCodec< } } - WebSocketServerException forcedCannotEncode(boolean binary, MessageCodec codec, Object encodedObject) { + WebSocketException forcedCannotEncode(boolean binary, MessageCodec codec, Object encodedObject) { String message = String.format("Forced %s codec [%s] cannot handle the type %s", binary ? "binary" : "text", codec.getClass().getName(), encodedObject.getClass()); if (binary) { @@ -187,7 +187,7 @@ WebSocketServerException forcedCannotEncode(boolean binary, MessageCodec c } } - WebSocketServerException forcedCannotDecode(String text, Buffer bytes, MessageCodec codec, Type type) { + WebSocketException forcedCannotDecode(String text, Buffer bytes, MessageCodec codec, Type type) { String message = String.format("Forced %s codec [%s] cannot decode the type %s", bytes != null ? "binary" : "text", codec.getClass().getName(), type); if (bytes != null) { diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConcurrencyLimiter.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConcurrencyLimiter.java index 9f13ca901e8e2..8a690d793ce5d 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConcurrencyLimiter.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ConcurrencyLimiter.java @@ -5,7 +5,6 @@ import org.jboss.logging.Logger; -import io.quarkus.websockets.next.WebSocketConnection; import io.smallrye.mutiny.helpers.queues.Queues; import io.vertx.core.Context; import io.vertx.core.Handler; @@ -18,12 +17,12 @@ class ConcurrencyLimiter { private static final Logger LOG = Logger.getLogger(ConcurrencyLimiter.class); - private final WebSocketConnection connection; + private final WebSocketConnectionBase connection; private final Queue queue; private final AtomicLong uncompleted; private final AtomicLong queueCounter; - ConcurrencyLimiter(WebSocketConnection connection) { + ConcurrencyLimiter(WebSocketConnectionBase connection) { this.connection = connection; this.uncompleted = new AtomicLong(); this.queueCounter = new AtomicLong(); diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java index 9e4b7a4e81525..0b018b6fe2eaf 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/ContextSupport.java @@ -5,7 +5,6 @@ import io.quarkus.arc.InjectableContext.ContextState; import io.quarkus.arc.ManagedContext; import io.quarkus.vertx.core.runtime.context.VertxContextSafetyToggle; -import io.quarkus.websockets.next.WebSocketConnection; import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; import io.smallrye.common.vertx.VertxContext; import io.vertx.core.Context; @@ -14,12 +13,14 @@ public class ContextSupport { private static final Logger LOG = Logger.getLogger(ContextSupport.class); - private final WebSocketConnection connection; + static final String WEB_SOCKET_CONN_KEY = WebSocketConnectionBase.class.getName(); + + private final WebSocketConnectionBase connection; private final SessionContextState sessionContextState; private final WebSocketSessionContext sessionContext; private final ManagedContext requestContext; - ContextSupport(WebSocketConnection connection, SessionContextState sessionContextState, + ContextSupport(WebSocketConnectionBase connection, SessionContextState sessionContextState, WebSocketSessionContext sessionContext, ManagedContext requestContext) { this.connection = connection; @@ -72,12 +73,12 @@ ContextState currentRequestContextState() { return requestContext.getStateIfActive(); } - static Context createNewDuplicatedContext(Context context, WebSocketConnection connection) { + static Context createNewDuplicatedContext(Context context, WebSocketConnectionBase connection) { Context duplicated = VertxContext.createNewDuplicatedContext(context); VertxContextSafetyToggle.setContextSafe(duplicated, true); // We need to store the connection in the duplicated context // It's used to initialize the synthetic bean later on - duplicated.putLocal(WebSocketServerRecorder.WEB_SOCKET_CONN_KEY, connection); + duplicated.putLocal(ContextSupport.WEB_SOCKET_CONN_KEY, connection); LOG.debugf("New vertx duplicated context [%s] created: %s", duplicated, connection); return duplicated; } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java new file mode 100644 index 0000000000000..15c2933c5feca --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -0,0 +1,303 @@ +package io.quarkus.websockets.next.runtime; + +import java.time.Duration; +import java.util.Optional; +import java.util.function.Consumer; + +import jakarta.enterprise.context.SessionScoped; + +import org.jboss.logging.Logger; + +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.InjectableContext; +import io.quarkus.websockets.next.WebSocketException; +import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; +import io.smallrye.mutiny.Multi; +import io.smallrye.mutiny.operators.multi.processors.BroadcastProcessor; +import io.vertx.core.Context; +import io.vertx.core.Handler; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.WebSocketBase; + +class Endpoints { + + private static final Logger LOG = Logger.getLogger(Endpoints.class); + + static void initialize(Vertx vertx, ArcContainer container, Codecs codecs, WebSocketConnectionBase connection, + WebSocketBase ws, String generatedEndpointClass, Optional autoPingInterval, Runnable onClose) { + + Context context = vertx.getOrCreateContext(); + + // Initialize and capture the session context state that will be activated + // during message processing + WebSocketSessionContext sessionContext = sessionContext(container); + SessionContextState sessionContextState = sessionContext.initializeContextState(); + ContextSupport contextSupport = new ContextSupport(connection, sessionContextState, + sessionContext(container), + container.requestContext()); + + // Create an endpoint that delegates callbacks to the endpoint bean + WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, contextSupport); + + // A broadcast processor is only needed if Multi is consumed by the callback + BroadcastProcessor textBroadcastProcessor = endpoint.consumedTextMultiType() != null + ? BroadcastProcessor.create() + : null; + BroadcastProcessor binaryBroadcastProcessor = endpoint.consumedBinaryMultiType() != null + ? BroadcastProcessor.create() + : null; + + // NOTE: We always invoke callbacks on a new duplicated context + // and the endpoint is responsible to make the switch if blocking/virtualThread + + Context onOpenContext = ContextSupport.createNewDuplicatedContext(context, connection); + onOpenContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.onOpen().onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnOpen callback completed: %s", connection); + // If Multi is consumed we need to invoke the callback eagerly + // but after @OnOpen completes + if (textBroadcastProcessor != null) { + Multi multi = textBroadcastProcessor.onCancellation().call(connection::close); + onOpenContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.onTextMessage(multi).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnTextMessage callback consuming Multi completed: %s", + connection); + } else { + LOG.errorf(r.cause(), + "Unable to complete @OnTextMessage callback consuming Multi: %s", + connection); + } + }); + } + }); + } + if (binaryBroadcastProcessor != null) { + Multi multi = binaryBroadcastProcessor.onCancellation().call(connection::close); + onOpenContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.onBinaryMessage(multi).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnBinaryMessage callback consuming Multi completed: %s", + connection); + } else { + LOG.errorf(r.cause(), + "Unable to complete @OnBinaryMessage callback consuming Multi: %s", + connection); + } + }); + } + }); + } + } else { + LOG.errorf(r.cause(), "Unable to complete @OnOpen callback: %s", connection); + } + }); + } + }); + + if (textBroadcastProcessor == null) { + // Multi not consumed - invoke @OnTextMessage callback for each message received + textMessageHandler(connection, endpoint, ws, onOpenContext, m -> { + endpoint.onTextMessage(m).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnTextMessage callback consumed text message: %s", connection); + } else { + LOG.errorf(r.cause(), "Unable to consume text message in @OnTextMessage callback: %s", + connection); + } + }); + }, true); + } else { + textMessageHandler(connection, endpoint, ws, onOpenContext, m -> { + contextSupport.start(); + try { + textBroadcastProcessor.onNext(endpoint.decodeTextMultiItem(m)); + LOG.debugf("Text message >> Multi: %s", connection); + } catch (Throwable throwable) { + endpoint.doOnError(throwable).subscribe().with( + v -> LOG.debugf("Text message >> Multi: %s", connection), + t -> LOG.errorf(t, "Unable to send text message to Multi: %s", connection)); + } finally { + contextSupport.end(false); + } + }, false); + } + + if (binaryBroadcastProcessor == null) { + // Multi not consumed - invoke @OnBinaryMessage callback for each message received + binaryMessageHandler(connection, endpoint, ws, onOpenContext, m -> { + endpoint.onBinaryMessage(m).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnBinaryMessage callback consumed text message: %s", connection); + } else { + LOG.errorf(r.cause(), "Unable to consume text message in @OnBinaryMessage callback: %s", + connection); + } + }); + }, true); + } else { + binaryMessageHandler(connection, endpoint, ws, onOpenContext, m -> { + contextSupport.start(); + try { + binaryBroadcastProcessor.onNext(endpoint.decodeBinaryMultiItem(m)); + LOG.debugf("Binary message >> Multi: %s", connection); + } catch (Throwable throwable) { + endpoint.doOnError(throwable).subscribe().with( + v -> LOG.debugf("Binary message >> Multi: %s", connection), + t -> LOG.errorf(t, "Unable to send binary message to Multi: %s", connection)); + } finally { + contextSupport.end(false); + } + }, false); + } + + pongMessageHandler(connection, endpoint, ws, onOpenContext, m -> { + endpoint.onPongMessage(m).onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnPongMessage callback consumed text message: %s", connection); + } else { + LOG.errorf(r.cause(), "Unable to consume text message in @OnPongMessage callback: %s", + connection); + } + }); + }); + + Long timerId; + if (autoPingInterval.isPresent()) { + timerId = vertx.setPeriodic(autoPingInterval.get().toMillis(), new Handler() { + @Override + public void handle(Long timerId) { + connection.sendAutoPing(); + } + }); + } else { + timerId = null; + } + + ws.closeHandler(new Handler() { + @Override + public void handle(Void event) { + ContextSupport.createNewDuplicatedContext(context, connection).runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.onClose().onComplete(r -> { + if (r.succeeded()) { + LOG.debugf("@OnClose callback completed: %s", connection); + } else { + LOG.errorf(r.cause(), "Unable to complete @OnClose callback: %s", connection); + } + onClose.run(); + if (timerId != null) { + vertx.cancelTimer(timerId); + } + }); + } + }); + } + }); + + ws.exceptionHandler(new Handler() { + @Override + public void handle(Throwable t) { + ContextSupport.createNewDuplicatedContext(context, connection).runOnContext(new Handler() { + @Override + public void handle(Void event) { + endpoint.doOnError(t).subscribe().with( + v -> LOG.debugf("Error [%s] processed: %s", t.getClass(), connection), + t -> LOG.errorf(t, "Unhandled error occured: %s", t.toString(), + connection)); + } + }); + } + }); + } + + private static void textMessageHandler(WebSocketConnectionBase connection, WebSocketEndpoint endpoint, WebSocketBase ws, + Context context, Consumer textAction, boolean newDuplicatedContext) { + ws.textMessageHandler(new Handler() { + @Override + public void handle(String message) { + Context duplicatedContext = newDuplicatedContext + ? ContextSupport.createNewDuplicatedContext(context, connection) + : context; + duplicatedContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + textAction.accept(message); + } + }); + } + }); + } + + private static void binaryMessageHandler(WebSocketConnectionBase connection, WebSocketEndpoint endpoint, WebSocketBase ws, + Context context, Consumer binaryAction, boolean newDuplicatedContext) { + ws.binaryMessageHandler(new Handler() { + @Override + public void handle(Buffer message) { + Context duplicatedContext = newDuplicatedContext + ? ContextSupport.createNewDuplicatedContext(context, connection) + : context; + duplicatedContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + binaryAction.accept(message); + } + }); + } + }); + } + + private static void pongMessageHandler(WebSocketConnectionBase connection, WebSocketEndpoint endpoint, WebSocketBase ws, + Context context, Consumer pongAction) { + ws.pongHandler(new Handler() { + @Override + public void handle(Buffer message) { + Context duplicatedContext = ContextSupport.createNewDuplicatedContext(context, connection); + duplicatedContext.runOnContext(new Handler() { + @Override + public void handle(Void event) { + pongAction.accept(message); + } + }); + } + }); + } + + private static WebSocketEndpoint createEndpoint(String endpointClassName, Context context, + WebSocketConnectionBase connection, + Codecs codecs, ContextSupport contextSupport) { + try { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + if (cl == null) { + cl = WebSocketServerRecorder.class.getClassLoader(); + } + @SuppressWarnings("unchecked") + Class endpointClazz = (Class) cl + .loadClass(endpointClassName); + WebSocketEndpoint endpoint = (WebSocketEndpoint) endpointClazz + .getDeclaredConstructor(WebSocketConnectionBase.class, Codecs.class, ContextSupport.class) + .newInstance(connection, codecs, contextSupport); + return endpoint; + } catch (Exception e) { + throw new WebSocketException("Unable to create endpoint instance: " + endpointClassName, e); + } + } + + private static WebSocketSessionContext sessionContext(ArcContainer container) { + for (InjectableContext injectableContext : container.getContexts(SessionScoped.class)) { + if (WebSocketSessionContext.class.equals(injectableContext.getClass())) { + return (WebSocketSessionContext) injectableContext; + } + } + throw new WebSocketException("CDI session context not registered"); + } +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientConnectionImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientConnectionImpl.java new file mode 100644 index 0000000000000..83b9745ab7cf5 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientConnectionImpl.java @@ -0,0 +1,117 @@ +package io.quarkus.websockets.next.runtime; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; + +import io.quarkus.websockets.next.HandshakeRequest; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.vertx.core.http.WebSocket; +import io.vertx.core.http.WebSocketBase; + +class WebSocketClientConnectionImpl extends WebSocketConnectionBase implements WebSocketClientConnection { + + private final String clientId; + + private final WebSocket webSocket; + + WebSocketClientConnectionImpl(String clientId, WebSocket webSocket, Codecs codecs, + Map pathParams, URI serverEndpointUri, Map> headers) { + super(Map.copyOf(pathParams), codecs, new ClientHandshakeRequestImpl(serverEndpointUri, headers)); + this.clientId = clientId; + this.webSocket = Objects.requireNonNull(webSocket); + } + + @Override + WebSocketBase webSocket() { + return webSocket; + } + + @Override + public String clientId() { + return clientId; + } + + @Override + public String toString() { + return "WebSocket client connection [id=" + identifier + ", clientId=" + clientId + "]"; + } + + @Override + public int hashCode() { + return Objects.hash(identifier); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + WebSocketClientConnectionImpl other = (WebSocketClientConnectionImpl) obj; + return Objects.equals(identifier, other.identifier); + } + + private static class ClientHandshakeRequestImpl implements HandshakeRequest { + + private final URI serverEndpointUrl; + private final Map> headers; + + ClientHandshakeRequestImpl(URI serverEndpointUrl, Map> headers) { + this.serverEndpointUrl = serverEndpointUrl; + Map> copy = new HashMap<>(); + for (Entry> e : headers.entrySet()) { + copy.put(e.getKey().toLowerCase(), List.copyOf(e.getValue())); + } + this.headers = copy; + } + + @Override + public String header(String name) { + List values = headers(name); + return values.isEmpty() ? null : values.get(0); + } + + @Override + public List headers(String name) { + return headers.getOrDefault(Objects.requireNonNull(name).toLowerCase(), List.of()); + } + + @Override + public Map> headers() { + return headers; + } + + @Override + public String scheme() { + return serverEndpointUrl.getScheme(); + } + + @Override + public String host() { + return serverEndpointUrl.getHost(); + } + + @Override + public int port() { + return serverEndpointUrl.getPort(); + } + + @Override + public String path() { + return serverEndpointUrl.getPath(); + } + + @Override + public String query() { + return serverEndpointUrl.getQuery(); + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientRecorder.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientRecorder.java new file mode 100644 index 0000000000000..fa47dc2e7309c --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketClientRecorder.java @@ -0,0 +1,72 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.Map; +import java.util.function.Supplier; + +import io.quarkus.runtime.annotations.RecordableConstructor; +import io.quarkus.runtime.annotations.Recorder; +import io.quarkus.websockets.next.WebSocketClientException; +import io.smallrye.common.vertx.VertxContext; +import io.vertx.core.Context; +import io.vertx.core.Vertx; + +@Recorder +public class WebSocketClientRecorder { + + public Supplier connectionSupplier() { + return new Supplier() { + + @Override + public Object get() { + Context context = Vertx.currentContext(); + if (context != null && VertxContext.isDuplicatedContext(context)) { + Object connection = context.getLocal(ContextSupport.WEB_SOCKET_CONN_KEY); + if (connection != null) { + return connection; + } + } + throw new WebSocketClientException("Unable to obtain the connection from the Vert.x duplicated context"); + } + }; + } + + public Supplier createContext(Map endpointMap) { + return new Supplier() { + @Override + public Object get() { + return new ClientEndpointsContext() { + + @Override + public ClientEndpoint endpoint(String endpointClass) { + return endpointMap.get(endpointClass); + } + + }; + } + }; + } + + public interface ClientEndpointsContext { + + ClientEndpoint endpoint(String endpointClass); + + } + + public static class ClientEndpoint { + + public final String clientId; + + public final String path; + + public final String generatedEndpointClass; + + @RecordableConstructor + public ClientEndpoint(String clientId, String path, String generatedEndpointClass) { + this.clientId = clientId; + this.path = path; + this.generatedEndpointClass = generatedEndpointClass; + } + + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java new file mode 100644 index 0000000000000..fc3d07727d7bb --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java @@ -0,0 +1,113 @@ +package io.quarkus.websockets.next.runtime; + +import java.time.Instant; +import java.util.Map; +import java.util.UUID; + +import org.jboss.logging.Logger; + +import io.quarkus.vertx.core.runtime.VertxBufferImpl; +import io.quarkus.websockets.next.HandshakeRequest; +import io.quarkus.websockets.next.WebSocketConnection.BroadcastSender; +import io.smallrye.mutiny.Uni; +import io.smallrye.mutiny.vertx.UniHelper; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.buffer.impl.BufferImpl; +import io.vertx.core.http.WebSocketBase; +import io.vertx.core.json.JsonArray; +import io.vertx.core.json.JsonObject; + +public abstract class WebSocketConnectionBase { + + private static final Logger LOG = Logger.getLogger(WebSocketConnectionBase.class); + + protected final String identifier; + + protected final Map pathParams; + + protected final Codecs codecs; + + protected final HandshakeRequest handshakeRequest; + + protected final Instant creationTime; + + WebSocketConnectionBase(Map pathParams, Codecs codecs, HandshakeRequest handshakeRequest) { + this.identifier = UUID.randomUUID().toString(); + this.pathParams = pathParams; + this.codecs = codecs; + this.handshakeRequest = handshakeRequest; + this.creationTime = Instant.now(); + } + + abstract WebSocketBase webSocket(); + + public String id() { + return identifier; + } + + public String pathParam(String name) { + return pathParams.get(name); + } + + public Uni sendText(String message) { + return UniHelper.toUni(webSocket().writeTextMessage(message)); + } + + public Uni sendBinary(Buffer message) { + return UniHelper.toUni(webSocket().writeBinaryMessage(message)); + } + + public Uni sendText(M message) { + String text; + // Use the same conversion rules as defined for the OnTextMessage + if (message instanceof JsonObject || message instanceof JsonArray || message instanceof BufferImpl + || message instanceof VertxBufferImpl) { + text = message.toString(); + } else if (message.getClass().isArray() && message.getClass().arrayType().equals(byte.class)) { + text = Buffer.buffer((byte[]) message).toString(); + } else { + text = codecs.textEncode(message, null); + } + return sendText(text); + } + + public Uni sendPing(Buffer data) { + return UniHelper.toUni(webSocket().writePing(data)); + } + + void sendAutoPing() { + webSocket().writePing(Buffer.buffer("ping")).onComplete(r -> { + if (r.failed()) { + LOG.warnf("Unable to send auto-ping for %s: %s", this, r.cause().toString()); + } + }); + } + + public Uni sendPong(Buffer data) { + return UniHelper.toUni(webSocket().writePong(data)); + } + + public Uni close() { + return UniHelper.toUni(webSocket().close()); + } + + public boolean isSecure() { + return webSocket().isSsl(); + } + + public boolean isClosed() { + return webSocket().isClosed(); + } + + public HandshakeRequest handshakeRequest() { + return handshakeRequest; + } + + public Instant creationTime() { + return creationTime; + } + + public BroadcastSender broadcast() { + throw new UnsupportedOperationException(); + } +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java index 9de3d34d70efd..124fc48bdab6b 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionImpl.java @@ -1,6 +1,5 @@ package io.quarkus.websockets.next.runtime; -import java.time.Instant; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -8,66 +7,44 @@ import java.util.Map.Entry; import java.util.Objects; import java.util.Set; -import java.util.UUID; import java.util.function.BiFunction; import java.util.function.Predicate; import java.util.stream.Collectors; -import org.jboss.logging.Logger; - -import io.quarkus.vertx.core.runtime.VertxBufferImpl; +import io.quarkus.websockets.next.HandshakeRequest; import io.quarkus.websockets.next.WebSocketConnection; import io.smallrye.mutiny.Uni; -import io.smallrye.mutiny.vertx.UniHelper; import io.vertx.core.buffer.Buffer; -import io.vertx.core.buffer.impl.BufferImpl; import io.vertx.core.http.ServerWebSocket; -import io.vertx.core.json.JsonArray; -import io.vertx.core.json.JsonObject; +import io.vertx.core.http.WebSocketBase; import io.vertx.ext.web.RoutingContext; -class WebSocketConnectionImpl implements WebSocketConnection { - - private static final Logger LOG = Logger.getLogger(WebSocketConnectionImpl.class); +class WebSocketConnectionImpl extends WebSocketConnectionBase implements WebSocketConnection { private final String generatedEndpointClass; private final String endpointId; - private final String identifier; - private final ServerWebSocket webSocket; private final ConnectionManager connectionManager; - private final Codecs codecs; - - private final Map pathParams; - - private final HandshakeRequest handshakeRequest; - private final BroadcastSender defaultBroadcast; - private final Instant creationTime; - WebSocketConnectionImpl(String generatedEndpointClass, String endpointClass, ServerWebSocket webSocket, ConnectionManager connectionManager, Codecs codecs, RoutingContext ctx) { + super(Map.copyOf(ctx.pathParams()), codecs, new HandshakeRequestImpl(webSocket, ctx)); this.generatedEndpointClass = generatedEndpointClass; this.endpointId = endpointClass; - this.identifier = UUID.randomUUID().toString(); this.webSocket = Objects.requireNonNull(webSocket); this.connectionManager = Objects.requireNonNull(connectionManager); - this.pathParams = Map.copyOf(ctx.pathParams()); this.defaultBroadcast = new BroadcastImpl(null); - this.codecs = codecs; - this.handshakeRequest = new HandshakeRequestImpl(ctx); - this.creationTime = Instant.now(); } @Override - public String id() { - return identifier; + WebSocketBase webSocket() { + return webSocket; } @Override @@ -75,95 +52,22 @@ public String endpointId() { return endpointId; } - @Override - public String pathParam(String name) { - return pathParams.get(name); - } - - @Override - public Uni sendText(String message) { - return UniHelper.toUni(webSocket.writeTextMessage(message)); - } - - @Override - public Uni sendBinary(Buffer message) { - return UniHelper.toUni(webSocket.writeBinaryMessage(message)); - } - - @Override - public Uni sendText(M message) { - String text; - // Use the same conversion rules as defined for the OnTextMessage - if (message instanceof JsonObject || message instanceof JsonArray || message instanceof BufferImpl - || message instanceof VertxBufferImpl) { - text = message.toString(); - } else if (message.getClass().isArray() && message.getClass().arrayType().equals(byte.class)) { - text = Buffer.buffer((byte[]) message).toString(); - } else { - text = codecs.textEncode(message, null); - } - return sendText(text); - } - - @Override - public Uni sendPing(Buffer data) { - return UniHelper.toUni(webSocket.writePing(data)); - } - - void sendAutoPing() { - webSocket.writePing(Buffer.buffer("ping")).onComplete(r -> { - if (r.failed()) { - LOG.warnf("Unable to send auto-ping for %s: %s", this, r.cause().toString()); - } - }); - } - - @Override - public Uni sendPong(Buffer data) { - return UniHelper.toUni(webSocket.writePong(data)); - } - @Override public BroadcastSender broadcast() { return defaultBroadcast; } - @Override - public Uni close() { - return UniHelper.toUni(webSocket.close()); - } - - @Override - public boolean isSecure() { - return webSocket.isSsl(); - } - - @Override - public boolean isClosed() { - return webSocket.isClosed(); - } - @Override public Set getOpenConnections() { return connectionManager.getConnections(generatedEndpointClass).stream().filter(WebSocketConnection::isOpen) .collect(Collectors.toUnmodifiableSet()); } - @Override - public HandshakeRequest handshakeRequest() { - return handshakeRequest; - } - @Override public String subprotocol() { return webSocket.subProtocol(); } - @Override - public Instant creationTime() { - return creationTime; - } - @Override public String toString() { return "WebSocket connection [id=" + identifier + ", path=" + webSocket.path() + "]"; @@ -186,11 +90,14 @@ public boolean equals(Object obj) { return Objects.equals(identifier, other.identifier); } - private class HandshakeRequestImpl implements HandshakeRequest { + private static class HandshakeRequestImpl implements HandshakeRequest { + + private final ServerWebSocket webSocket; private final Map> headers; - HandshakeRequestImpl(RoutingContext ctx) { + HandshakeRequestImpl(ServerWebSocket webSocket, RoutingContext ctx) { + this.webSocket = webSocket; this.headers = initHeaders(ctx); } diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java new file mode 100644 index 0000000000000..dc53643b588e7 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java @@ -0,0 +1,129 @@ +package io.quarkus.websockets.next.runtime; + +import java.net.URI; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import io.quarkus.websockets.next.WebSocketClientException; +import io.quarkus.websockets.next.WebSocketsClientRuntimeConfig; +import io.vertx.core.Vertx; + +abstract class WebSocketConnectorBase> { + + protected static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{[a-zA-Z0-9_]+\\}"); + + // mutable state + + protected URI baseUri; + + protected final Map pathParams; + + protected final Map> headers; + + protected final Set subprotocols; + + protected String path; + + protected Set pathParamNames; + + // injected dependencies + + protected final Vertx vertx; + + protected final Codecs codecs; + + protected final ClientConnectionManager connectionManager; + + protected final WebSocketsClientRuntimeConfig config; + + WebSocketConnectorBase(Vertx vertx, Codecs codecs, + ClientConnectionManager connectionManager, WebSocketsClientRuntimeConfig config) { + this.headers = new HashMap<>(); + this.subprotocols = new HashSet<>(); + this.pathParams = new HashMap<>(); + this.vertx = vertx; + this.codecs = codecs; + this.connectionManager = connectionManager; + this.config = config; + this.pathParamNames = Set.of(); + } + + public THIS baseUri(URI baseUri) { + this.baseUri = Objects.requireNonNull(baseUri); + return self(); + } + + public THIS addHeader(String name, String value) { + Objects.requireNonNull(name); + Objects.requireNonNull(value); + List values = headers.get(name); + if (values == null) { + values = new ArrayList<>(); + headers.put(name, values); + } + values.add(value); + return self(); + } + + public THIS pathParam(String name, String value) { + Objects.requireNonNull(name); + Objects.requireNonNull(value); + if (!pathParamNames.contains(name)) { + throw new IllegalArgumentException( + String.format("[%s] is not a valid path parameter in the path %s", name, path)); + } + pathParams.put(name, value); + return self(); + } + + public THIS addSubprotocol(String value) { + subprotocols.add(Objects.requireNonNull(value)); + return self(); + } + + void setPath(String path) { + this.path = path; + this.pathParamNames = getPathParamNames(path); + } + + @SuppressWarnings("unchecked") + protected THIS self() { + return (THIS) this; + } + + Set getPathParamNames(String path) { + Set names = new HashSet<>(); + Matcher m = PATH_PARAM_PATTERN.matcher(path); + while (m.find()) { + String match = m.group(); + String paramName = match.substring(1, match.length() - 1); + names.add(paramName); + } + return names; + } + + String replacePathParameters(String path) { + StringBuilder sb = new StringBuilder(); + Matcher m = PATH_PARAM_PATTERN.matcher(path); + while (m.find()) { + // Replace {foo} with the param value + String match = m.group(); + String paramName = match.substring(1, match.length() - 1); + String val = pathParams.get(paramName); + if (val == null) { + throw new WebSocketClientException("Unable to obtain the path param for: " + paramName); + } + m.appendReplacement(sb, val); + } + m.appendTail(sb); + return path.startsWith("/") ? sb.toString() : "/" + sb.toString(); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java new file mode 100644 index 0000000000000..a4abe65f42162 --- /dev/null +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java @@ -0,0 +1,135 @@ +package io.quarkus.websockets.next.runtime; + +import java.lang.reflect.ParameterizedType; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; + +import jakarta.enterprise.context.Dependent; +import jakarta.enterprise.inject.Typed; +import jakarta.enterprise.inject.spi.InjectionPoint; + +import org.eclipse.microprofile.config.ConfigProvider; +import org.jboss.logging.Logger; + +import io.quarkus.arc.Arc; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketClientException; +import io.quarkus.websockets.next.WebSocketConnector; +import io.quarkus.websockets.next.WebSocketsClientRuntimeConfig; +import io.quarkus.websockets.next.runtime.WebSocketClientRecorder.ClientEndpoint; +import io.quarkus.websockets.next.runtime.WebSocketClientRecorder.ClientEndpointsContext; +import io.smallrye.mutiny.Uni; +import io.smallrye.mutiny.vertx.UniHelper; +import io.vertx.core.Vertx; +import io.vertx.core.http.WebSocketClient; +import io.vertx.core.http.WebSocketClientOptions; +import io.vertx.core.http.WebSocketConnectOptions; + +@Typed(WebSocketConnector.class) +@Dependent +public class WebSocketConnectorImpl extends WebSocketConnectorBase> + implements WebSocketConnector { + + private static final Logger LOG = Logger.getLogger(WebSocketConnectorImpl.class); + + // derived properties + + private final ClientEndpoint clientEndpoint; + + WebSocketConnectorImpl(InjectionPoint injectionPoint, Codecs codecs, Vertx vertx, ClientConnectionManager connectionManager, + ClientEndpointsContext endpointsContext, WebSocketsClientRuntimeConfig config) { + super(vertx, codecs, connectionManager, config); + this.clientEndpoint = Objects.requireNonNull(endpointsContext.endpoint(getEndpointClass(injectionPoint))); + setPath(clientEndpoint.path); + } + + @Override + public Uni connect() { + // Currently we create a new client for each connection + // The client is closed when the connection is closed + // TODO would it make sense to share clients? + WebSocketClientOptions clientOptions = new WebSocketClientOptions(); + if (config.offerPerMessageCompression()) { + clientOptions.setTryUsePerMessageCompression(true); + if (config.compressionLevel().isPresent()) { + clientOptions.setCompressionLevel(config.compressionLevel().getAsInt()); + } + } + if (config.maxMessageSize().isPresent()) { + clientOptions.setMaxMessageSize(config.maxMessageSize().getAsInt()); + } + + WebSocketClient client = vertx.createWebSocketClient(); + + StringBuilder serverEndpoint = new StringBuilder(); + if (baseUri != null) { + serverEndpoint.append(baseUri.toString()); + } else { + // Obtain the base URI from the config + String key = clientEndpoint.clientId + ".base-uri"; + Optional maybeBaseUri = ConfigProvider.getConfig().getOptionalValue(key, String.class); + if (maybeBaseUri.isEmpty()) { + throw new WebSocketClientException("Unable to obtain the config value for: " + key); + } + serverEndpoint.append(maybeBaseUri.get()); + } + serverEndpoint.append(replacePathParameters(clientEndpoint.path)); + + URI serverEndpointUri; + try { + serverEndpointUri = new URI(serverEndpoint.toString()); + } catch (URISyntaxException e) { + throw new WebSocketClientException(e); + } + + WebSocketConnectOptions connectOptions = new WebSocketConnectOptions() + .setSsl(serverEndpointUri.getScheme().equals("https")) + .setHost(serverEndpointUri.getHost()) + .setPort(serverEndpointUri.getPort()); + StringBuilder uri = new StringBuilder(); + if (serverEndpointUri.getPath() != null) { + uri.append(serverEndpointUri.getPath()); + } + if (serverEndpointUri.getQuery() != null) { + uri.append("?").append(serverEndpointUri.getQuery()); + } + connectOptions.setURI(uri.toString()); + for (Entry> e : headers.entrySet()) { + for (String val : e.getValue()) { + connectOptions.addHeader(e.getKey(), val); + } + } + subprotocols.forEach(connectOptions::addSubProtocol); + + return UniHelper.toUni(client.connect(connectOptions)) + .map(ws -> { + WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientEndpoint.clientId, ws, + codecs, + pathParams, + serverEndpointUri, headers); + LOG.debugf("Client connection created: %s", connection); + connectionManager.add(clientEndpoint.generatedEndpointClass, connection); + + Endpoints.initialize(vertx, Arc.container(), codecs, connection, ws, + clientEndpoint.generatedEndpointClass, config.autoPingInterval(), + () -> { + connectionManager.remove(clientEndpoint.generatedEndpointClass, connection); + client.close(); + }); + + return connection; + }); + } + + String getEndpointClass(InjectionPoint injectionPoint) { + // The type is validated during build - if it does not represent a client endpoint the build fails + // WebSocketConnectorImpl -> org.acme.Foo + ParameterizedType parameterizedType = (ParameterizedType) injectionPoint.getType(); + return parameterizedType.getActualTypeArguments()[0].getTypeName(); + } + +} diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java index a5dfb70a86076..b811cfc22b049 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpoint.java @@ -2,7 +2,7 @@ import java.lang.reflect.Type; -import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.InboundProcessingMode; import io.smallrye.mutiny.Uni; import io.vertx.core.Future; import io.vertx.core.buffer.Buffer; @@ -10,16 +10,15 @@ /** * Internal representation of a WebSocket endpoint. *

- * A new instance is created for each client connection. + * A new instance is created for each connection. */ public interface WebSocketEndpoint { /** * - * @see WebSocket#executionMode() - * @return the execution mode + * @return the inbound processing mode */ - WebSocket.ExecutionMode executionMode(); + InboundProcessingMode inboundProcessingMode(); // @OnOpen diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java index 3a7b4da9dfa33..ed453f59a97c9 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketEndpointBase.java @@ -15,9 +15,7 @@ import io.quarkus.arc.InjectableBean; import io.quarkus.arc.InjectableContext.ContextState; import io.quarkus.virtual.threads.VirtualThreadsRecorder; -import io.quarkus.websockets.next.WebSocket.ExecutionMode; -import io.quarkus.websockets.next.WebSocketConnection; -import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; +import io.quarkus.websockets.next.InboundProcessingMode; import io.quarkus.websockets.next.runtime.ConcurrencyLimiter.PromiseComplete; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.Uni; @@ -34,15 +32,12 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private static final Logger LOG = Logger.getLogger(WebSocketEndpointBase.class); // Keep this field public - there's a problem with ConnectionArgumentProvider reading the protected field in the test mode - public final WebSocketConnection connection; + public final WebSocketConnectionBase connection; protected final Codecs codecs; private final ConcurrencyLimiter limiter; - @SuppressWarnings("unused") - private final WebSocketsServerRuntimeConfig config; - private final ArcContainer container; private final ContextSupport contextSupport; @@ -50,12 +45,10 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint { private final InjectableBean bean; private final Object beanInstance; - public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs, - WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) { + public WebSocketEndpointBase(WebSocketConnectionBase connection, Codecs codecs, ContextSupport contextSupport) { this.connection = connection; this.codecs = codecs; - this.limiter = executionMode() == ExecutionMode.SERIAL ? new ConcurrencyLimiter(connection) : null; - this.config = config; + this.limiter = inboundProcessingMode() == InboundProcessingMode.SERIAL ? new ConcurrencyLimiter(connection) : null; this.container = Arc.container(); this.contextSupport = contextSupport; InjectableBean bean = container.bean(beanIdentifier()); diff --git a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index bca955cc02b4c..e580cf85791e7 100644 --- a/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/server/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -1,29 +1,20 @@ package io.quarkus.websockets.next.runtime; -import java.util.function.Consumer; import java.util.function.Supplier; -import jakarta.enterprise.context.SessionScoped; - import org.jboss.logging.Logger; import io.quarkus.arc.Arc; import io.quarkus.arc.ArcContainer; -import io.quarkus.arc.InjectableContext; import io.quarkus.runtime.annotations.Recorder; import io.quarkus.vertx.core.runtime.VertxCoreRecorder; -import io.quarkus.websockets.next.WebSocketConnection; import io.quarkus.websockets.next.WebSocketServerException; import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig; -import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState; import io.smallrye.common.vertx.VertxContext; -import io.smallrye.mutiny.Multi; -import io.smallrye.mutiny.operators.multi.processors.BroadcastProcessor; import io.vertx.core.Context; import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Vertx; -import io.vertx.core.buffer.Buffer; import io.vertx.core.http.ServerWebSocket; import io.vertx.ext.web.RoutingContext; @@ -32,8 +23,6 @@ public class WebSocketServerRecorder { private static final Logger LOG = Logger.getLogger(WebSocketServerRecorder.class); - static final String WEB_SOCKET_CONN_KEY = WebSocketConnection.class.getName(); - private final WebSocketsServerRuntimeConfig config; public WebSocketServerRecorder(WebSocketsServerRuntimeConfig config) { @@ -47,7 +36,7 @@ public Supplier connectionSupplier() { public Object get() { Context context = Vertx.currentContext(); if (context != null && VertxContext.isDuplicatedContext(context)) { - Object connection = context.getLocal(WEB_SOCKET_CONN_KEY); + Object connection = context.getLocal(ContextSupport.WEB_SOCKET_CONN_KEY); if (connection != null) { return connection; } @@ -68,288 +57,17 @@ public void handle(RoutingContext ctx) { Future future = ctx.request().toWebSocket(); future.onSuccess(ws -> { Vertx vertx = VertxCoreRecorder.getVertx().get(); - Context context = vertx.getOrCreateContext(); WebSocketConnectionImpl connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointId, ws, connectionManager, codecs, ctx); connectionManager.add(generatedEndpointClass, connection); LOG.debugf("Connection created: %s", connection); - // Initialize and capture the session context state that will be activated - // during message processing - WebSocketSessionContext sessionContext = sessionContext(container); - SessionContextState sessionContextState = sessionContext.initializeContextState(); - ContextSupport contextSupport = new ContextSupport(connection, sessionContextState, - sessionContext(container), - container.requestContext()); - - // Create an endpoint that delegates callbacks to the @WebSocket bean - WebSocketEndpoint endpoint = createEndpoint(generatedEndpointClass, context, connection, codecs, config, - contextSupport); - - // A broadcast processor is only needed if Multi is consumed by the callback - BroadcastProcessor textBroadcastProcessor = endpoint.consumedTextMultiType() != null - ? BroadcastProcessor.create() - : null; - BroadcastProcessor binaryBroadcastProcessor = endpoint.consumedBinaryMultiType() != null - ? BroadcastProcessor.create() - : null; - - // NOTE: We always invoke callbacks on a new duplicated context - // and the endpoint is responsible to make the switch if blocking/virtualThread - - Context onOpenContext = ContextSupport.createNewDuplicatedContext(context, connection); - onOpenContext.runOnContext(new Handler() { - @Override - public void handle(Void event) { - endpoint.onOpen().onComplete(r -> { - if (r.succeeded()) { - LOG.debugf("@OnOpen callback completed: %s", connection); - // If Multi is consumed we need to invoke the callback eagerly - // but after @OnOpen completes - if (textBroadcastProcessor != null) { - Multi multi = textBroadcastProcessor.onCancellation().call(connection::close); - onOpenContext.runOnContext(new Handler() { - @Override - public void handle(Void event) { - endpoint.onTextMessage(multi).onComplete(r -> { - if (r.succeeded()) { - LOG.debugf("@OnTextMessage callback consuming Multi completed: %s", - connection); - } else { - LOG.errorf(r.cause(), - "Unable to complete @OnTextMessage callback consuming Multi: %s", - connection); - } - }); - } - }); - } - if (binaryBroadcastProcessor != null) { - Multi multi = binaryBroadcastProcessor.onCancellation().call(connection::close); - onOpenContext.runOnContext(new Handler() { - @Override - public void handle(Void event) { - endpoint.onBinaryMessage(multi).onComplete(r -> { - if (r.succeeded()) { - LOG.debugf("@OnBinaryMessage callback consuming Multi completed: %s", - connection); - } else { - LOG.errorf(r.cause(), - "Unable to complete @OnBinaryMessage callback consuming Multi: %s", - connection); - } - }); - } - }); - } - } else { - LOG.errorf(r.cause(), "Unable to complete @OnOpen callback: %s", connection); - } - }); - } - }); - - if (textBroadcastProcessor == null) { - // Multi not consumed - invoke @OnTextMessage callback for each message received - textMessageHandler(connection, endpoint, ws, onOpenContext, m -> { - endpoint.onTextMessage(m).onComplete(r -> { - if (r.succeeded()) { - LOG.debugf("@OnTextMessage callback consumed text message: %s", connection); - } else { - LOG.errorf(r.cause(), "Unable to consume text message in @OnTextMessage callback: %s", - connection); - } - }); - }, true); - } else { - textMessageHandler(connection, endpoint, ws, onOpenContext, m -> { - contextSupport.start(); - try { - textBroadcastProcessor.onNext(endpoint.decodeTextMultiItem(m)); - LOG.debugf("Text message >> Multi: %s", connection); - } catch (Throwable throwable) { - endpoint.doOnError(throwable).subscribe().with( - v -> LOG.debugf("Text message >> Multi: %s", connection), - t -> LOG.errorf(t, "Unable to send text message to Multi: %s", connection)); - } finally { - contextSupport.end(false); - } - }, false); - } - - if (binaryBroadcastProcessor == null) { - // Multi not consumed - invoke @OnBinaryMessage callback for each message received - binaryMessageHandler(connection, endpoint, ws, onOpenContext, m -> { - endpoint.onBinaryMessage(m).onComplete(r -> { - if (r.succeeded()) { - LOG.debugf("@OnBinaryMessage callback consumed text message: %s", connection); - } else { - LOG.errorf(r.cause(), "Unable to consume text message in @OnBinaryMessage callback: %s", - connection); - } - }); - }, true); - } else { - binaryMessageHandler(connection, endpoint, ws, onOpenContext, m -> { - contextSupport.start(); - try { - binaryBroadcastProcessor.onNext(endpoint.decodeBinaryMultiItem(m)); - LOG.debugf("Binary message >> Multi: %s", connection); - } catch (Throwable throwable) { - endpoint.doOnError(throwable).subscribe().with( - v -> LOG.debugf("Binary message >> Multi: %s", connection), - t -> LOG.errorf(t, "Unable to send binary message to Multi: %s", connection)); - } finally { - contextSupport.end(false); - } - }, false); - } - - pongMessageHandler(connection, endpoint, ws, onOpenContext, m -> { - endpoint.onPongMessage(m).onComplete(r -> { - if (r.succeeded()) { - LOG.debugf("@OnPongMessage callback consumed text message: %s", connection); - } else { - LOG.errorf(r.cause(), "Unable to consume text message in @OnPongMessage callback: %s", - connection); - } - }); - }); - - Long timerId; - if (config.autoPingInterval().isPresent()) { - timerId = vertx.setPeriodic(config.autoPingInterval().get().toMillis(), new Handler() { - @Override - public void handle(Long timerId) { - connection.sendAutoPing(); - } - }); - } else { - timerId = null; - } - - ws.closeHandler(new Handler() { - @Override - public void handle(Void event) { - ContextSupport.createNewDuplicatedContext(context, connection).runOnContext(new Handler() { - @Override - public void handle(Void event) { - endpoint.onClose().onComplete(r -> { - if (r.succeeded()) { - LOG.debugf("@OnClose callback completed: %s", connection); - } else { - LOG.errorf(r.cause(), "Unable to complete @OnClose callback: %s", connection); - } - connectionManager.remove(generatedEndpointClass, connection); - if (timerId != null) { - vertx.cancelTimer(timerId); - } - }); - } - }); - } - }); - - ws.exceptionHandler(new Handler() { - @Override - public void handle(Throwable t) { - ContextSupport.createNewDuplicatedContext(context, connection).runOnContext(new Handler() { - @Override - public void handle(Void event) { - endpoint.doOnError(t).subscribe().with( - v -> LOG.debugf("Error [%s] processed: %s", t.getClass(), connection), - t -> LOG.errorf(t, "Unhandled error occured: %s", t.toString(), - connection)); - } - }); - } - }); - + Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass, + config.autoPingInterval(), () -> connectionManager.remove(generatedEndpointClass, connection)); }); } }; } - private void textMessageHandler(WebSocketConnection connection, WebSocketEndpoint endpoint, ServerWebSocket ws, - Context context, Consumer textAction, boolean newDuplicatedContext) { - ws.textMessageHandler(new Handler() { - @Override - public void handle(String message) { - Context duplicatedContext = newDuplicatedContext - ? ContextSupport.createNewDuplicatedContext(context, connection) - : context; - duplicatedContext.runOnContext(new Handler() { - @Override - public void handle(Void event) { - textAction.accept(message); - } - }); - } - }); - } - - private void binaryMessageHandler(WebSocketConnection connection, WebSocketEndpoint endpoint, ServerWebSocket ws, - Context context, Consumer binaryAction, boolean newDuplicatedContext) { - ws.binaryMessageHandler(new Handler() { - @Override - public void handle(Buffer message) { - Context duplicatedContext = newDuplicatedContext - ? ContextSupport.createNewDuplicatedContext(context, connection) - : context; - duplicatedContext.runOnContext(new Handler() { - @Override - public void handle(Void event) { - binaryAction.accept(message); - } - }); - } - }); - } - - private void pongMessageHandler(WebSocketConnection connection, WebSocketEndpoint endpoint, ServerWebSocket ws, - Context context, Consumer pongAction) { - ws.pongHandler(new Handler() { - @Override - public void handle(Buffer message) { - Context duplicatedContext = ContextSupport.createNewDuplicatedContext(context, connection); - duplicatedContext.runOnContext(new Handler() { - @Override - public void handle(Void event) { - pongAction.accept(message); - } - }); - } - }); - } - - private WebSocketEndpoint createEndpoint(String endpointClassName, Context context, WebSocketConnection connection, - Codecs codecs, WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) { - try { - ClassLoader cl = Thread.currentThread().getContextClassLoader(); - if (cl == null) { - cl = WebSocketServerRecorder.class.getClassLoader(); - } - @SuppressWarnings("unchecked") - Class endpointClazz = (Class) cl - .loadClass(endpointClassName); - WebSocketEndpoint endpoint = (WebSocketEndpoint) endpointClazz - .getDeclaredConstructor(WebSocketConnection.class, Codecs.class, - WebSocketsServerRuntimeConfig.class, ContextSupport.class) - .newInstance(connection, codecs, config, contextSupport); - return endpoint; - } catch (Exception e) { - throw new WebSocketServerException("Unable to create endpoint instance: " + endpointClassName, e); - } - } - - private static WebSocketSessionContext sessionContext(ArcContainer container) { - for (InjectableContext injectableContext : container.getContexts(SessionScoped.class)) { - if (WebSocketSessionContext.class.equals(injectableContext.getClass())) { - return (WebSocketSessionContext) injectableContext; - } - } - throw new WebSocketServerException("CDI session context not registered"); - } - }