Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSockets Next: client endpoints #40309

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
package io.quarkus.websockets.next.deployment;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.Type;
import org.jboss.jandex.Type.Kind;

import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem;
import io.quarkus.arc.processor.Annotations;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.FieldDescriptor;
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.websockets.next.WebSocketException;
import io.quarkus.websockets.next.deployment.CallbackArgument.InvocationBytecodeContext;
import io.quarkus.websockets.next.deployment.CallbackArgument.ParameterContext;
import io.quarkus.websockets.next.runtime.WebSocketConnectionBase;
import io.quarkus.websockets.next.runtime.WebSocketEndpoint.ExecutionModel;
import io.quarkus.websockets.next.runtime.WebSocketEndpointBase;

/**
* Represents either an endpoint callback or a global error handler.
*/
public class Callback {

public final Target target;
public final String endpointPath;
public final AnnotationInstance annotation;
public final MethodInfo method;
public final ExecutionModel executionModel;
public final MessageType messageType;
public final List<CallbackArgument> arguments;

public Callback(Target target, AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel,
CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
String endpointPath, IndexView index) {
this.target = target;
this.method = method;
this.annotation = annotation;
this.executionModel = executionModel;
if (WebSocketDotNames.ON_BINARY_MESSAGE.equals(annotation.name())) {
this.messageType = MessageType.BINARY;
} else if (WebSocketDotNames.ON_TEXT_MESSAGE.equals(annotation.name())) {
this.messageType = MessageType.TEXT;
} else if (WebSocketDotNames.ON_PONG_MESSAGE.equals(annotation.name())) {
this.messageType = MessageType.PONG;
} else {
this.messageType = MessageType.NONE;
}
this.endpointPath = endpointPath;
this.arguments = collectArguments(annotation, method, callbackArguments, transformedAnnotations, index);
}

public boolean isGlobal() {
return endpointPath == null;
}

public boolean isClient() {
return target == Target.CLIENT;
}

public boolean isServer() {
return target == Target.SERVER;
}

public boolean isOnOpen() {
return annotation.name().equals(WebSocketDotNames.ON_OPEN);
}

public boolean isOnClose() {
return annotation.name().equals(WebSocketDotNames.ON_CLOSE);
}

public boolean isOnError() {
return annotation.name().equals(WebSocketDotNames.ON_ERROR);
}

public Type returnType() {
return method.returnType();
}

public Type messageParamType() {
return acceptsMessage() ? method.parameterType(0) : null;
}

public boolean isReturnTypeVoid() {
return returnType().kind() == Kind.VOID;
}

public boolean isReturnTypeUni() {
return WebSocketDotNames.UNI.equals(returnType().name());
}

public boolean isReturnTypeMulti() {
return WebSocketDotNames.MULTI.equals(returnType().name());
}

public boolean acceptsMessage() {
return messageType != MessageType.NONE;
}

public boolean acceptsBinaryMessage() {
return messageType == MessageType.BINARY || messageType == MessageType.PONG;
}

public boolean acceptsMulti() {
return acceptsMessage() && method.parameterType(0).name().equals(WebSocketDotNames.MULTI);
}

public Callback.MessageType messageType() {
return messageType;
}

public boolean broadcast() {
AnnotationValue broadcastValue = annotation.value("broadcast");
return broadcastValue != null && broadcastValue.asBoolean();
}

public DotName getInputCodec() {
return getCodec("codec");
}

public DotName getOutputCodec() {
DotName output = getCodec("outputCodec");
return output != null ? output : getInputCodec();
}

public String asString() {
return method.declaringClass().name() + "#" + method.name() + "()";
}

private DotName getCodec(String valueName) {
AnnotationValue codecValue = annotation.value(valueName);
if (codecValue != null) {
return codecValue.asClass().name();
}
return null;
}

public enum MessageType {
NONE,
PONG,
TEXT,
BINARY
}

public enum Target {
CLIENT,
SERVER,
UNDEFINED
}

public ResultHandle[] generateArguments(ResultHandle endpointThis, BytecodeCreator bytecode,
TransformedAnnotationsBuildItem transformedAnnotations, IndexView index) {
if (arguments.isEmpty()) {
return new ResultHandle[] {};
}
ResultHandle[] resultHandles = new ResultHandle[arguments.size()];
int idx = 0;
for (CallbackArgument argument : arguments) {
resultHandles[idx] = argument.get(
invocationBytecodeContext(annotation, method.parameters().get(idx), transformedAnnotations, index,
endpointThis, bytecode));
idx++;
}
return resultHandles;
}

private List<CallbackArgument> collectArguments(AnnotationInstance annotation, MethodInfo method,
CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
IndexView index) {
List<MethodParameterInfo> parameters = method.parameters();
if (parameters.isEmpty()) {
return List.of();
}
List<CallbackArgument> arguments = new ArrayList<>(parameters.size());
for (MethodParameterInfo parameter : parameters) {
List<CallbackArgument> found = callbackArguments
.findMatching(parameterContext(annotation, parameter, transformedAnnotations, index));
if (found.isEmpty()) {
String msg = String.format("Unable to inject @%s callback parameter '%s' declared on %s: no injector found",
DotNames.simpleName(annotation.name()),
parameter.name() != null ? parameter.name() : "#" + parameter.position(),
asString());
throw new WebSocketException(msg);
} else if (found.size() > 1 && (found.get(0).priotity() == found.get(1).priotity())) {
String msg = String.format(
"Unable to inject @%s callback parameter '%s' declared on %s: ambiguous injectors found: %s",
DotNames.simpleName(annotation.name()),
parameter.name() != null ? parameter.name() : "#" + parameter.position(),
asString(),
found.stream().map(p -> p.getClass().getSimpleName() + ":" + p.priotity()));
throw new WebSocketException(msg);
}
arguments.add(found.get(0));
}
return List.copyOf(arguments);
}

Type argumentType(Predicate<CallbackArgument> filter) {
for (int i = 0; i < arguments.size(); i++) {
if (filter.test(arguments.get(i))) {
return method.parameterType(i);
}
}
return null;
}

private ParameterContext parameterContext(AnnotationInstance callbackAnnotation, MethodParameterInfo parameter,
TransformedAnnotationsBuildItem transformedAnnotations, IndexView index) {
return new ParameterContext() {

@Override
public Target callbackTarget() {
return target;
}

@Override
public MethodParameterInfo parameter() {
return parameter;
}

@Override
public Set<AnnotationInstance> parameterAnnotations() {
return Annotations.getParameterAnnotations(
transformedAnnotations::getAnnotations, parameter.method(), parameter.position());
}

@Override
public AnnotationInstance callbackAnnotation() {
return callbackAnnotation;
}

@Override
public String endpointPath() {
return endpointPath;
}

@Override
public IndexView index() {
return index;
}

};
}

private InvocationBytecodeContext invocationBytecodeContext(AnnotationInstance callbackAnnotation,
MethodParameterInfo parameter, TransformedAnnotationsBuildItem transformedAnnotations, IndexView index,
ResultHandle endpointThis, BytecodeCreator bytecode) {
return new InvocationBytecodeContext() {

@Override
public Target callbackTarget() {
return target;
}

@Override
public AnnotationInstance callbackAnnotation() {
return callbackAnnotation;
}

@Override
public MethodParameterInfo parameter() {
return parameter;
}

@Override
public Set<AnnotationInstance> parameterAnnotations() {
return Annotations.getParameterAnnotations(
transformedAnnotations::getAnnotations, parameter.method(), parameter.position());
}

@Override
public String endpointPath() {
return endpointPath;
}

@Override
public IndexView index() {
return index;
}

@Override
public BytecodeCreator bytecode() {
return bytecode;
}

@Override
public ResultHandle getPayload() {
return acceptsMessage() || callbackAnnotation.name().equals(WebSocketDotNames.ON_ERROR)
? bytecode.getMethodParam(0)
: null;
}

@Override
public ResultHandle getDecodedMessage(Type parameterType) {
return acceptsMessage()
? WebSocketProcessor.decodeMessage(endpointThis, bytecode, acceptsBinaryMessage(),
parameterType,
getPayload(), Callback.this)
: null;
}

@Override
public ResultHandle getConnection() {
return bytecode.readInstanceField(
FieldDescriptor.of(WebSocketEndpointBase.class, "connection", WebSocketConnectionBase.class),
endpointThis);
}
};
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import io.quarkus.websockets.next.OnError;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;
import io.quarkus.websockets.next.WebSocketException;
import io.quarkus.websockets.next.deployment.Callback.Target;

/**
* Provides arguments for method parameters of a callback method declared on a WebSocket endpoint.
Expand All @@ -24,7 +25,7 @@ interface CallbackArgument {
*
* @param context
* @return {@code true} if this provider matches the given parameter context, {@code false} otherwise
* @throws WebSocketServerException If an invalid parameter is detected
* @throws WebSocketException If an invalid parameter is detected
*/
boolean matches(ParameterContext context);

Expand All @@ -49,6 +50,12 @@ default int priotity() {

interface ParameterContext {

/**
*
* @return the callback target
*/
Target callbackTarget();

/**
*
* @return the endpoint path or {@code null} for global error handlers
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
package io.quarkus.websockets.next.deployment;

import org.jboss.jandex.DotName;

import io.quarkus.gizmo.ResultHandle;
import io.quarkus.websockets.next.WebSocketException;
import io.quarkus.websockets.next.deployment.Callback.Target;

class ConnectionCallbackArgument implements CallbackArgument {

@Override
public boolean matches(ParameterContext context) {
return context.parameter().type().name().equals(WebSocketDotNames.WEB_SOCKET_CONNECTION);
DotName paramTypeName = context.parameter().type().name();
if (context.callbackTarget() == Target.SERVER) {
if (WebSocketDotNames.WEB_SOCKET_CONNECTION.equals(paramTypeName)) {
return true;
} else if (WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION.equals(paramTypeName)) {
throw new WebSocketException("@WebSocket callback method may not accept WebSocketClientConnection");
}
} else if (context.callbackTarget() == Target.CLIENT) {
if (WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION.equals(paramTypeName)) {
return true;
} else if (WebSocketDotNames.WEB_SOCKET_CONNECTION.equals(paramTypeName)) {
throw new WebSocketException("@WebSocketClient callback method may not accept WebSocketConnection");
}
}
return false;
}

@Override
Expand Down
Loading
Loading