diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ExecutorSwitchingWrappers.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ExecutorSwitchingWrappers.java index cb729960..f326cbb6 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ExecutorSwitchingWrappers.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ExecutorSwitchingWrappers.java @@ -6,7 +6,6 @@ import dev.restate.sdk.core.TypeTag; import dev.restate.sdk.core.syscalls.*; import io.grpc.MethodDescriptor; -import io.grpc.ServerCall; import java.time.Duration; import java.util.Map; import java.util.concurrent.Executor; @@ -17,8 +16,8 @@ class ExecutorSwitchingWrappers { private ExecutorSwitchingWrappers() {} - static ServerCall.Listener serverCallListener( - ServerCall.Listener sc, Executor userExecutor) { + static RestateServerCallListener serverCallListener( + RestateServerCallListener sc, Executor userExecutor) { return new ExecutorSwitchingServerCallListener(sc, userExecutor); } @@ -27,25 +26,20 @@ static SyscallsInternal syscalls(SyscallsInternal sc, Executor syscallsExecutor) } private static class ExecutorSwitchingServerCallListener - extends ServerCall.Listener { + implements RestateServerCallListener { - private final ServerCall.Listener listener; + private final RestateServerCallListener listener; private final Executor userExecutor; private ExecutorSwitchingServerCallListener( - ServerCall.Listener listener, Executor userExecutor) { + RestateServerCallListener listener, Executor userExecutor) { this.listener = listener; this.userExecutor = userExecutor; } @Override - public void onMessage(MessageLite message) { - userExecutor.execute(() -> listener.onMessage(message)); - } - - @Override - public void onHalfClose() { - userExecutor.execute(listener::onHalfClose); + public void onMessageAndHalfClose(MessageLite message) { + userExecutor.execute(() -> listener.onMessageAndHalfClose(message)); } @Override diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ExceptionCatchingServerCallListener.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/GrpcServerCallListenerAdaptor.java similarity index 62% rename from sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ExceptionCatchingServerCallListener.java rename to sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/GrpcServerCallListenerAdaptor.java index 9426567e..6540b776 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ExceptionCatchingServerCallListener.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/GrpcServerCallListenerAdaptor.java @@ -1,37 +1,35 @@ package dev.restate.sdk.core.impl; -import io.grpc.ForwardingServerCallListener; import io.grpc.Metadata; import io.grpc.ServerCall; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -class ExceptionCatchingServerCallListener - extends ForwardingServerCallListener.SimpleForwardingServerCallListener { +/** + * Adapts a {@link ServerCall.Listener} to a {@link RestateServerCallListener}. + * + * @param type of the request + * @param type of the response + */ +class GrpcServerCallListenerAdaptor implements RestateServerCallListener { - private static final Logger LOG = LogManager.getLogger(ExceptionCatchingServerCallListener.class); + private static final Logger LOG = LogManager.getLogger(GrpcServerCallListenerAdaptor.class); private final ServerCall serverCall; - ExceptionCatchingServerCallListener( + private final ServerCall.Listener delegate; + + GrpcServerCallListenerAdaptor( ServerCall.Listener delegate, ServerCall serverCall) { - super(delegate); + this.delegate = delegate; this.serverCall = serverCall; } @Override - public void onMessage(ReqT message) { - try { - super.onMessage(message); - } catch (Throwable e) { - closeWithException(e); - } - } - - @Override - public void onHalfClose() { + public void onMessageAndHalfClose(ReqT message) { try { - super.onHalfClose(); + delegate.onMessage(message); + delegate.onHalfClose(); } catch (Throwable e) { closeWithException(e); } @@ -40,7 +38,7 @@ public void onHalfClose() { @Override public void onCancel() { try { - super.onCancel(); + delegate.onCancel(); } catch (Throwable e) { closeWithException(e); } @@ -49,7 +47,7 @@ public void onCancel() { @Override public void onComplete() { try { - super.onComplete(); + delegate.onComplete(); } catch (Throwable e) { closeWithException(e); } @@ -58,7 +56,7 @@ public void onComplete() { @Override public void onReady() { try { - super.onReady(); + delegate.onReady(); } catch (Throwable e) { closeWithException(e); } diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateGrpcServer.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateGrpcServer.java index a5800535..ffd2d6e5 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateGrpcServer.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateGrpcServer.java @@ -101,7 +101,7 @@ public void start() { stateMachine.start( invocationId -> { // Create the listener and create the decorators chain - ServerCall.Listener listener = + ServerCall.Listener grpcListener = Contexts.interceptCall( Context.current() .withValue(InvocationId.INVOCATION_ID_KEY, invocationId) @@ -109,14 +109,15 @@ public void start() { bridge, new Metadata(), method.getServerCallHandler()); - listener = new ExceptionCatchingServerCallListener<>(listener, bridge); + RestateServerCallListener restateListener = + new GrpcServerCallListenerAdaptor<>(grpcListener, bridge); if (serverCallListenerExecutor != null) { - listener = + restateListener = ExecutorSwitchingWrappers.serverCallListener( - listener, serverCallListenerExecutor); + restateListener, serverCallListenerExecutor); } - bridge.setListener(listener); + bridge.setListener(restateListener); }); } }; diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java index c6c3314c..1f63d37f 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCall.java @@ -30,7 +30,7 @@ class RestateServerCall extends ServerCall { // // The listener reference is volatile in order to guarantee its visibility when the ownership of // this object is transferred through threads. - private volatile ServerCall.Listener listener; + private volatile RestateServerCallListener listener; // These variables don't need to be volatile as they're accessed and mutated only by // #setListener() and #request() @@ -45,7 +45,7 @@ class RestateServerCall extends ServerCall { // --- Invoked in the State machine thread - void setListener(Listener listener) { + void setListener(RestateServerCallListener listener) { this.listener = listener; this.listener.onReady(); @@ -158,8 +158,7 @@ private void pollInput() { MessageLite message = deferredValue.toReadyResult().getResult(); LOG.trace("Read input message:\n{}", message); - listener.onMessage(message); - listener.onHalfClose(); + listener.onMessageAndHalfClose(message); }, this::onError)), this::onError)); diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCallListener.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCallListener.java new file mode 100644 index 00000000..23824dad --- /dev/null +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/RestateServerCallListener.java @@ -0,0 +1,18 @@ +package dev.restate.sdk.core.impl; + +/** + * Callbacks for incoming rpc messages. + * + *

This interface is strongly inspired by {@link io.grpc.ServerCall.Listener}. + * + * @param type of the incoming message + */ +public interface RestateServerCallListener { + void onMessageAndHalfClose(M message); + + void onCancel(); + + void onComplete(); + + void onReady(); +}