Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release connection on reactive beginTransaction cancellation #1342

Merged
merged 1 commit into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.neo4j.driver.reactive.RxTransactionWork;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class InternalRxSession extends AbstractRxQueryRunner implements RxSession {
private final NetworkSession session;
Expand Down Expand Up @@ -69,7 +70,8 @@ public Publisher<RxTransaction> beginTransaction(TransactionConfig config) {
return txFuture;
},
() -> new IllegalStateException(
"Unexpected condition, begin transaction call has completed successfully with transaction being null"));
"Unexpected condition, begin transaction call has completed successfully with transaction being null"),
(tx) -> Mono.fromDirect(tx.close()).subscribe());
}

private Publisher<InternalRxTransaction> beginTransaction(AccessMode mode, TransactionConfig config) {
Expand All @@ -86,7 +88,8 @@ private Publisher<InternalRxTransaction> beginTransaction(AccessMode mode, Trans
return txFuture;
},
() -> new IllegalStateException(
"Unexpected condition, begin transaction call has completed successfully with transaction being null"));
"Unexpected condition, begin transaction call has completed successfully with transaction being null"),
(tx) -> Mono.fromDirect(tx.close()).subscribe());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
*/
package org.neo4j.driver.internal.reactive;

import static java.util.Objects.requireNonNull;

import java.util.Optional;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.neo4j.driver.internal.util.Futures;
import org.reactivestreams.Publisher;
Expand All @@ -28,6 +31,7 @@
public class RxUtils {
/**
* The publisher created by this method will either succeed without publishing anything or fail with an error.
*
* @param supplier supplies a {@link CompletionStage<Void>}.
* @return A publisher that publishes nothing on completion or fails with an error.
*/
Expand All @@ -48,23 +52,79 @@ public static <T> Publisher<T> createEmptyPublisher(Supplier<CompletionStage<Voi
* @param supplier supplies a {@link CompletionStage<T>} that MUST produce a non-null result when completed successfully.
* @param nullResultThrowableSupplier supplies a {@link Throwable} that is used as an error when the supplied completion stage completes successfully with
* null.
* @param cancellationHandler handles cancellation, may be used to release associated resources
* @param <T> the type of the item to publish.
* @return A publisher that succeeds exactly one item or fails with an error.
*/
public static <T> Publisher<T> createSingleItemPublisher(
Supplier<CompletionStage<T>> supplier, Supplier<Throwable> nullResultThrowableSupplier) {
return Mono.create(sink -> supplier.get().whenComplete((item, completionError) -> {
if (completionError == null) {
if (item != null) {
sink.success(item);
} else {
sink.error(nullResultThrowableSupplier.get());
Supplier<CompletionStage<T>> supplier,
Supplier<Throwable> nullResultThrowableSupplier,
Consumer<T> cancellationHandler) {
requireNonNull(supplier, "supplier must not be null");
requireNonNull(nullResultThrowableSupplier, "nullResultThrowableSupplier must not be null");
requireNonNull(cancellationHandler, "cancellationHandler must not be null");
return Mono.create(sink -> {
SinkState state = new SinkState<T>();
sink.onRequest(ignored -> {
CompletionStage<T> stage;
synchronized (state) {
if (state.isCancelled()) {
return;
}
if (state.getStage() != null) {
return;
}
stage = supplier.get();
state.setStage(stage);
}
} else {
Throwable error = Optional.ofNullable(Futures.completionExceptionCause(completionError))
.orElse(completionError);
sink.error(error);
}
}));
stage.whenComplete((item, completionError) -> {
if (completionError == null) {
if (item != null) {
sink.success(item);
} else {
sink.error(nullResultThrowableSupplier.get());
}
} else {
Throwable error = Optional.ofNullable(Futures.completionExceptionCause(completionError))
.orElse(completionError);
sink.error(error);
}
});
});
sink.onCancel(() -> {
CompletionStage<T> stage;
synchronized (state) {
if (state.isCancelled()) {
return;
}
state.setCancelled(true);
stage = state.getStage();
}
if (stage != null) {
stage.whenComplete((value, ignored) -> cancellationHandler.accept(value));
}
});
});
}

private static class SinkState<T> {
private CompletionStage<T> stage;
private boolean cancelled;

public CompletionStage<T> getStage() {
return stage;
}

public void setStage(CompletionStage<T> stage) {
this.stage = stage;
}

public boolean isCancelled() {
return cancelled;
}

public void setCancelled(boolean cancelled) {
this.cancelled = cancelled;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@
*/
package org.neo4j.driver.internal.reactive;

import static org.mockito.BDDMockito.then;
import static org.mockito.Mockito.mock;
import static org.neo4j.driver.internal.reactive.RxUtils.createEmptyPublisher;
import static org.neo4j.driver.internal.reactive.RxUtils.createSingleItemPublisher;
import static org.neo4j.driver.internal.util.Futures.failedFuture;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.internal.util.Futures;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.test.StepVerifier;

class RxUtilsTest {
Expand All @@ -47,24 +53,58 @@ void emptyPublisherShouldErrorWhenSupplierErrors() {

@Test
void singleItemPublisherShouldCompleteWithValue() {
Publisher<String> publisher =
createSingleItemPublisher(() -> CompletableFuture.completedFuture("One"), () -> mock(Throwable.class));
Publisher<String> publisher = createSingleItemPublisher(
() -> CompletableFuture.completedFuture("One"), () -> mock(Throwable.class), (ignored) -> {});
StepVerifier.create(publisher).expectNext("One").verifyComplete();
}

@Test
void singleItemPublisherShouldErrorWhenFutureCompletesWithNull() {
Throwable error = mock(Throwable.class);
Publisher<String> publisher = createSingleItemPublisher(Futures::completedWithNull, () -> error);
Publisher<String> publisher =
createSingleItemPublisher(Futures::completedWithNull, () -> error, (ignored) -> {});

StepVerifier.create(publisher).verifyErrorMatches(actualError -> error == actualError);
}

@Test
void singleItemPublisherShouldErrorWhenSupplierErrors() {
RuntimeException error = mock(RuntimeException.class);
Publisher<String> publisher = createSingleItemPublisher(() -> failedFuture(error), () -> mock(Throwable.class));
Publisher<String> publisher =
createSingleItemPublisher(() -> failedFuture(error), () -> mock(Throwable.class), (ignored) -> {});

StepVerifier.create(publisher).verifyErrorMatches(actualError -> error == actualError);
}

@Test
void singleItemPublisherShouldHandleCancellationAfterRequestProcessingBegins() {
// GIVEN
String value = "value";
CompletableFuture<String> valueFuture = new CompletableFuture<>();
CompletableFuture<Void> supplierInvokedFuture = new CompletableFuture<>();
Supplier<CompletionStage<String>> valueFutureSupplier = () -> {
supplierInvokedFuture.complete(null);
return valueFuture;
};
@SuppressWarnings("unchecked")
Consumer<String> cancellationHandler = mock(Consumer.class);
Publisher<String> publisher =
createSingleItemPublisher(valueFutureSupplier, () -> mock(Throwable.class), cancellationHandler);

// WHEN
publisher.subscribe(new BaseSubscriber<String>() {
@Override
protected void hookOnSubscribe(Subscription subscription) {
subscription.request(1);
supplierInvokedFuture.thenAccept(ignored -> {
subscription.cancel();
valueFuture.complete(value);
});
}
});

// THEN
valueFuture.join();
then(cancellationHandler).should().accept(value);
}
}