Skip to content

Commit

Permalink
Update TestkitRequestProcessorHandler to adhere to req/res pattern (#…
Browse files Browse the repository at this point in the history
…1500)

This update aims to make sure the Teskit backend does not break the request / response pattern. For instance, by sending 2 responses together in case of internal callbacks.
  • Loading branch information
injectives authored Nov 1, 2023
1 parent 30a5123 commit 99d32cb
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package neo4j.org.testkit.backend;

import java.util.ArrayDeque;
import java.util.Queue;
import java.util.function.Consumer;
import neo4j.org.testkit.backend.messages.responses.TestkitResponse;

public class ResponseQueueHanlder {
private final Consumer<TestkitResponse> responseWriter;
private final Queue<TestkitResponse> responseQueue = new ArrayDeque<>();
private boolean responseReady;

ResponseQueueHanlder(Consumer<TestkitResponse> responseWriter) {
this.responseWriter = responseWriter;
}

public synchronized void setResponseReadyAndDispatchFirst() {
responseReady = true;
dispatchFirst();
}

public synchronized void offerAndDispatchFirst(TestkitResponse response) {
responseQueue.offer(response);
if (responseReady) {
dispatchFirst();
}
}

private synchronized void dispatchFirst() {
var response = responseQueue.poll();
if (response != null) {
responseReady = false;
responseWriter.accept(response);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ public static void main(String[] args) throws InterruptedException {
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel channel) {
var responseQueueHanlder = new ResponseQueueHanlder(channel::writeAndFlush);
channel.pipeline().addLast(new TestkitMessageInboundHandler());
channel.pipeline().addLast(new TestkitMessageOutboundHandler());
channel.pipeline().addLast(new TestkitRequestResponseMapperHandler(logging));
channel.pipeline().addLast(new TestkitRequestProcessorHandler(backendMode, logging));
channel.pipeline()
.addLast(new TestkitRequestResponseMapperHandler(logging, responseQueueHanlder));
channel.pipeline()
.addLast(new TestkitRequestProcessorHandler(
backendMode, logging, responseQueueHanlder));
}
});
var server = bootstrap.bind().sync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.function.BiFunction;
import neo4j.org.testkit.backend.CustomDriverError;
import neo4j.org.testkit.backend.FrontendError;
import neo4j.org.testkit.backend.ResponseQueueHanlder;
import neo4j.org.testkit.backend.TestkitState;
import neo4j.org.testkit.backend.messages.requests.TestkitRequest;
import neo4j.org.testkit.backend.messages.responses.BackendError;
Expand All @@ -47,9 +48,11 @@ public class TestkitRequestProcessorHandler extends ChannelInboundHandlerAdapter
private final BiFunction<TestkitRequest, TestkitState, CompletionStage<TestkitResponse>> processorImpl;
// Some requests require multiple threads
private final Executor requestExecutorService = Executors.newFixedThreadPool(10);
private final ResponseQueueHanlder responseQueueHanlder;
private Channel channel;

public TestkitRequestProcessorHandler(BackendMode backendMode, Logging logging) {
public TestkitRequestProcessorHandler(
BackendMode backendMode, Logging logging, ResponseQueueHanlder responseQueueHanlder) {
switch (backendMode) {
case ASYNC -> processorImpl = TestkitRequest::processAsync;
case REACTIVE_LEGACY -> processorImpl =
Expand All @@ -59,6 +62,7 @@ public TestkitRequestProcessorHandler(BackendMode backendMode, Logging logging)
default -> processorImpl = TestkitRequestProcessorHandler::wrapSyncRequest;
}
testkitState = new TestkitState(this::writeAndFlush, logging);
this.responseQueueHanlder = responseQueueHanlder;
}

@Override
Expand All @@ -74,14 +78,14 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
requestExecutorService.execute(() -> {
try {
var request = (TestkitRequest) msg;
var responseStage = processorImpl.apply(request, testkitState);
responseStage.whenComplete((response, throwable) -> {
if (throwable != null) {
ctx.writeAndFlush(createErrorResponse(throwable));
} else if (response != null) {
ctx.writeAndFlush(response);
}
});
processorImpl
.apply(request, testkitState)
.exceptionally(this::createErrorResponse)
.whenComplete((response, ignored) -> {
if (response != null) {
responseQueueHanlder.offerAndDispatchFirst(response);
}
});
} catch (Throwable throwable) {
exceptionCaught(ctx, throwable);
}
Expand All @@ -101,7 +105,8 @@ private static CompletionStage<TestkitResponse> wrapSyncRequest(

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
ctx.writeAndFlush(createErrorResponse(cause));
var response = createErrorResponse(cause);
responseQueueHanlder.offerAndDispatchFirst(response);
}

private TestkitResponse createErrorResponse(Throwable throwable) {
Expand Down Expand Up @@ -165,7 +170,7 @@ private void writeAndFlush(TestkitResponse response) {
if (channel == null) {
throw new IllegalStateException("Called before channel is initialized");
}
channel.writeAndFlush(response);
responseQueueHanlder.offerAndDispatchFirst(response);
}

public enum BackendMode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import neo4j.org.testkit.backend.ResponseQueueHanlder;
import neo4j.org.testkit.backend.messages.TestkitModule;
import neo4j.org.testkit.backend.messages.requests.TestkitRequest;
import neo4j.org.testkit.backend.messages.responses.TestkitResponse;
Expand All @@ -32,17 +33,19 @@
public class TestkitRequestResponseMapperHandler extends ChannelDuplexHandler {
private final Logger log;
private final ObjectMapper objectMapper = newObjectMapper();
private final ResponseQueueHanlder responseQueueHanlder;

public TestkitRequestResponseMapperHandler(Logging logging) {
public TestkitRequestResponseMapperHandler(Logging logging, ResponseQueueHanlder responseQueueHanlder) {
log = logging.getLog(getClass());
this.responseQueueHanlder = responseQueueHanlder;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
var testkitMessage = (String) msg;
log.debug("Inbound Testkit message '%s'", testkitMessage.trim());
TestkitRequest testkitRequest;
testkitRequest = objectMapper.readValue(testkitMessage, TestkitRequest.class);
responseQueueHanlder.setResponseReadyAndDispatchFirst();
var testkitRequest = objectMapper.readValue(testkitMessage, TestkitRequest.class);
ctx.fireChannelRead(testkitRequest);
}

Expand Down

0 comments on commit 99d32cb

Please sign in to comment.