Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine onMessage and onHalfClose callbacks into single call to avoid race condition #75

Merged
merged 1 commit into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import dev.restate.sdk.core.TypeTag;
import dev.restate.sdk.core.syscalls.*;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.Executor;
Expand All @@ -17,8 +16,8 @@ class ExecutorSwitchingWrappers {

private ExecutorSwitchingWrappers() {}

static ServerCall.Listener<MessageLite> serverCallListener(
ServerCall.Listener<MessageLite> sc, Executor userExecutor) {
static RestateServerCallListener<MessageLite> serverCallListener(
RestateServerCallListener<MessageLite> sc, Executor userExecutor) {
return new ExecutorSwitchingServerCallListener(sc, userExecutor);
}

Expand All @@ -27,25 +26,20 @@ static SyscallsInternal syscalls(SyscallsInternal sc, Executor syscallsExecutor)
}

private static class ExecutorSwitchingServerCallListener
extends ServerCall.Listener<MessageLite> {
implements RestateServerCallListener<MessageLite> {

private final ServerCall.Listener<MessageLite> listener;
private final RestateServerCallListener<MessageLite> listener;
private final Executor userExecutor;

private ExecutorSwitchingServerCallListener(
ServerCall.Listener<MessageLite> listener, Executor userExecutor) {
RestateServerCallListener<MessageLite> listener, Executor userExecutor) {
this.listener = listener;
this.userExecutor = userExecutor;
}

@Override
public void onMessage(MessageLite message) {
userExecutor.execute(() -> listener.onMessage(message));
}

@Override
public void onHalfClose() {
userExecutor.execute(listener::onHalfClose);
public void onMessageAndHalfClose(MessageLite message) {
userExecutor.execute(() -> listener.onMessageAndHalfClose(message));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,35 @@
package dev.restate.sdk.core.impl;

import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

class ExceptionCatchingServerCallListener<ReqT, RespT>
extends ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT> {
/**
* Adapts a {@link ServerCall.Listener} to a {@link RestateServerCallListener}.
*
* @param <ReqT> type of the request
* @param <RespT> type of the response
*/
class GrpcServerCallListenerAdaptor<ReqT, RespT> implements RestateServerCallListener<ReqT> {

private static final Logger LOG = LogManager.getLogger(ExceptionCatchingServerCallListener.class);
private static final Logger LOG = LogManager.getLogger(GrpcServerCallListenerAdaptor.class);

private final ServerCall<ReqT, RespT> serverCall;

ExceptionCatchingServerCallListener(
private final ServerCall.Listener<ReqT> delegate;

GrpcServerCallListenerAdaptor(
ServerCall.Listener<ReqT> delegate, ServerCall<ReqT, RespT> serverCall) {
super(delegate);
this.delegate = delegate;
this.serverCall = serverCall;
}

@Override
public void onMessage(ReqT message) {
try {
super.onMessage(message);
} catch (Throwable e) {
closeWithException(e);
}
}

@Override
public void onHalfClose() {
public void onMessageAndHalfClose(ReqT message) {
try {
super.onHalfClose();
delegate.onMessage(message);
delegate.onHalfClose();
} catch (Throwable e) {
closeWithException(e);
}
Expand All @@ -40,7 +38,7 @@ public void onHalfClose() {
@Override
public void onCancel() {
try {
super.onCancel();
delegate.onCancel();
} catch (Throwable e) {
closeWithException(e);
}
Expand All @@ -49,7 +47,7 @@ public void onCancel() {
@Override
public void onComplete() {
try {
super.onComplete();
delegate.onComplete();
} catch (Throwable e) {
closeWithException(e);
}
Expand All @@ -58,7 +56,7 @@ public void onComplete() {
@Override
public void onReady() {
try {
super.onReady();
delegate.onReady();
} catch (Throwable e) {
closeWithException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,22 +101,23 @@ public void start() {
stateMachine.start(
invocationId -> {
// Create the listener and create the decorators chain
ServerCall.Listener<MessageLite> listener =
ServerCall.Listener<MessageLite> grpcListener =
Contexts.interceptCall(
Context.current()
.withValue(InvocationId.INVOCATION_ID_KEY, invocationId)
.withValue(Syscalls.SYSCALLS_KEY, syscalls),
bridge,
new Metadata(),
method.getServerCallHandler());
listener = new ExceptionCatchingServerCallListener<>(listener, bridge);
RestateServerCallListener<MessageLite> restateListener =
new GrpcServerCallListenerAdaptor<>(grpcListener, bridge);
if (serverCallListenerExecutor != null) {
listener =
restateListener =
ExecutorSwitchingWrappers.serverCallListener(
listener, serverCallListenerExecutor);
restateListener, serverCallListenerExecutor);
}

bridge.setListener(listener);
bridge.setListener(restateListener);
});
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class RestateServerCall extends ServerCall<MessageLite, MessageLite> {
//
// The listener reference is volatile in order to guarantee its visibility when the ownership of
// this object is transferred through threads.
private volatile ServerCall.Listener<MessageLite> listener;
private volatile RestateServerCallListener<MessageLite> listener;

// These variables don't need to be volatile as they're accessed and mutated only by
// #setListener() and #request()
Expand All @@ -45,7 +45,7 @@ class RestateServerCall extends ServerCall<MessageLite, MessageLite> {

// --- Invoked in the State machine thread

void setListener(Listener<MessageLite> listener) {
void setListener(RestateServerCallListener<MessageLite> listener) {
this.listener = listener;
this.listener.onReady();

Expand Down Expand Up @@ -158,8 +158,7 @@ private void pollInput() {
MessageLite message = deferredValue.toReadyResult().getResult();

LOG.trace("Read input message:\n{}", message);
listener.onMessage(message);
listener.onHalfClose();
listener.onMessageAndHalfClose(message);
},
this::onError)),
this::onError));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package dev.restate.sdk.core.impl;

/**
* Callbacks for incoming rpc messages.
*
* <p>This interface is strongly inspired by {@link io.grpc.ServerCall.Listener}.
*
* @param <M> type of the incoming message
*/
public interface RestateServerCallListener<M> {
void onMessageAndHalfClose(M message);
tillrohrmann marked this conversation as resolved.
Show resolved Hide resolved

void onCancel();

void onComplete();

void onReady();
}