diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/TimeoutPublisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/TimeoutPublisher.java index 07180a6724..c8bd93e578 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/TimeoutPublisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/TimeoutPublisher.java @@ -125,7 +125,8 @@ private TimeoutSubscriber(TimeoutPublisher parent, Subscriber target, AsyncContextProvider contextProvider) { this.parent = parent; - this.target = new ConcurrentTerminalSubscriber<>(target); + // Concurrent onSubscribe is protected by subscriptionUpdater, no need to double protect. + this.target = new ConcurrentTerminalSubscriber<>(target, false); this.contextProvider = contextProvider; } @@ -278,14 +279,22 @@ private void offloadTimeout(Throwable cause) { private void processTimeout(Throwable cause) { final Subscription subscription = subscriptionUpdater.getAndSet(this, EMPTY_SUBSCRIPTION); - // The timer is started before onSubscribe so the subscription may actually be null at this time. - if (subscription != null) { - subscription.cancel(); - // onErrorFromTimeout will protect against concurrent access on the Subscriber. - } else { - target.onSubscribe(EMPTY_SUBSCRIPTION); + // We need to deliver cancel upstream first (clear state for Publishers that + // allow sequential resubscribe) but we always want to force a TimeoutException downstream (because this is + // the source of the error, despite what any upstream operators/publishers may deliver). + final Subscriber localTarget = target.unwrapMarkTerminated(); + try { + // The timer is started before onSubscribe so the subscription may actually be null at this time. + if (subscription != null) { + subscription.cancel(); + } else if (localTarget != null) { + localTarget.onSubscribe(EMPTY_SUBSCRIPTION); + } + } finally { + if (localTarget != null) { + localTarget.onError(cause); + } } - target.processOnError(cause); } private void stopTimer() { diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/completable/TimeoutCompletableTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/completable/TimeoutCompletableTest.java index 2abcee7069..1e3e2c4734 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/completable/TimeoutCompletableTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/completable/TimeoutCompletableTest.java @@ -22,8 +22,8 @@ import io.servicetalk.concurrent.api.DelegatingExecutor; import io.servicetalk.concurrent.api.ExecutorExtension; import io.servicetalk.concurrent.api.TestCancellable; +import io.servicetalk.concurrent.api.TestCompletable; import io.servicetalk.concurrent.api.TestExecutor; -import io.servicetalk.concurrent.api.TestSingle; import io.servicetalk.concurrent.test.internal.TestCompletableSubscriber; import org.junit.jupiter.api.BeforeEach; @@ -39,6 +39,7 @@ import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static java.time.Duration.ofNanos; import static java.util.concurrent.TimeUnit.DAYS; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -53,9 +54,9 @@ class TimeoutCompletableTest { @RegisterExtension - final ExecutorExtension executorExtension = ExecutorExtension.withTestExecutor(); - private final TestCompletableSubscriber listener = new TestCompletableSubscriber(); - private final TestSingle source = new TestSingle<>(); + static final ExecutorExtension executorExtension = ExecutorExtension.withTestExecutor(); + private final TestCompletableSubscriber subscriber = new TestCompletableSubscriber(); + private final TestCompletable source = new TestCompletable(); private TestExecutor testExecutor; @BeforeEach @@ -63,16 +64,38 @@ void setup() { testExecutor = executorExtension.executor(); } + @Test + void timeoutExceptionDeliveredBeforeUpstreamException() { + toSource(new Completable() { + @Override + protected void handleSubscribe(final CompletableSource.Subscriber subscriber) { + subscriber.onSubscribe(new Cancellable() { + private boolean terminated; + @Override + public void cancel() { + if (!terminated) { + terminated = true; + subscriber.onError(new AssertionError("unexpected error, should have seen timeout")); + } + } + }); + } + }.timeout(ofNanos(1), testExecutor)) + .subscribe(subscriber); + testExecutor.advanceTimeBy(1, NANOSECONDS); + assertThat(subscriber.awaitOnError(), instanceOf(TimeoutException.class)); + } + @Test void executorScheduleThrows() { - toSource(source.ignoreElement().timeout(1, NANOSECONDS, new DelegatingExecutor(testExecutor) { + toSource(source.timeout(1, NANOSECONDS, new DelegatingExecutor(testExecutor) { @Override public Cancellable schedule(final Runnable task, final long delay, final TimeUnit unit) { throw DELIBERATE_EXCEPTION; } - })).subscribe(listener); + })).subscribe(subscriber); - assertThat(listener.awaitOnError(), is(DELIBERATE_EXCEPTION)); + assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); TestCancellable cancellable = new TestCancellable(); source.onSubscribe(cancellable); assertTrue(cancellable.isCancelled()); @@ -82,8 +105,8 @@ public Cancellable schedule(final Runnable task, final long delay, final TimeUni void noDataOnCompletionNoTimeout() { init(); - source.onSuccess(1); - listener.awaitOnComplete(); + source.onComplete(); + subscriber.awaitOnComplete(); assertThat(testExecutor.scheduledTasksPending(), is(0)); assertThat(testExecutor.scheduledTasksExecuted(), is(0)); @@ -94,7 +117,7 @@ void noDataOnErrorNoTimeout() { init(); source.onError(DELIBERATE_EXCEPTION); - assertThat(listener.awaitOnError(), is(DELIBERATE_EXCEPTION)); + assertThat(subscriber.awaitOnError(), is(DELIBERATE_EXCEPTION)); assertThat(testExecutor.scheduledTasksPending(), is(0)); assertThat(testExecutor.scheduledTasksExecuted(), is(0)); @@ -104,7 +127,7 @@ void noDataOnErrorNoTimeout() { void subscriptionCancelAlsoCancelsTimer() { init(); - listener.awaitSubscription().cancel(); + subscriber.awaitSubscription().cancel(); assertThat(testExecutor.scheduledTasksPending(), is(0)); assertThat(testExecutor.scheduledTasksExecuted(), is(0)); @@ -115,7 +138,7 @@ void noDataAndTimeout() { init(); testExecutor.advanceTimeBy(1, NANOSECONDS); - assertThat(listener.awaitOnError(), instanceOf(TimeoutException.class)); + assertThat(subscriber.awaitOnError(), instanceOf(TimeoutException.class)); assertThat(testExecutor.scheduledTasksPending(), is(0)); assertThat(testExecutor.scheduledTasksExecuted(), is(1)); @@ -136,7 +159,7 @@ void justSubscribeTimeout() { assertNotNull(subscriber); subscriber.onSubscribe(mockCancellable); verify(mockCancellable).cancel(); - assertThat(listener.awaitOnError(), instanceOf(TimeoutException.class)); + assertThat(this.subscriber.awaitOnError(), instanceOf(TimeoutException.class)); } @Test @@ -150,12 +173,12 @@ void defaultExecutorSubscribeTimeout() { // TODO(dariusz): Replace all executors created with the test instance // Executors.setFactory(AllExecutorFactory.create(() -> testExecutor)); - toSource(operationThatInternallyTimesOut).subscribe(listener); + toSource(operationThatInternallyTimesOut).subscribe(subscriber); testExecutor.advanceTimeBy(1, DAYS); CompletableSource.Subscriber subscriber = delayedCompletable.subscriber; assertNotNull(subscriber); - assertThat(listener.awaitOnError(), instanceOf(TimeoutException.class)); + assertThat(this.subscriber.awaitOnError(), instanceOf(TimeoutException.class)); } @Test @@ -173,19 +196,19 @@ void cancelDoesOnError() throws Exception { assertThat(testExecutor.scheduledTasksPending(), is(0)); assertThat(testExecutor.scheduledTasksExecuted(), is(1)); cancelLatch.await(); - Throwable error = listener.awaitOnError(); + Throwable error = this.subscriber.awaitOnError(); assertThat(error, instanceOf(TimeoutException.class)); } private void init() { - init(source.ignoreElement(), true); + init(source, true); } private void init(Completable source, boolean expectOnSubscribe) { - toSource(source.timeout(1, NANOSECONDS, testExecutor)).subscribe(listener); + toSource(source.timeout(1, NANOSECONDS, testExecutor)).subscribe(subscriber); assertThat(testExecutor.scheduledTasksPending(), is(1)); if (expectOnSubscribe) { - assertThat(listener.pollTerminal(10, MILLISECONDS), is(nullValue())); + assertThat(subscriber.pollTerminal(10, MILLISECONDS), is(nullValue())); } } diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/publisher/TimeoutPublisherTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/publisher/TimeoutPublisherTest.java index 6e1783c0c7..afb9206e04 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/publisher/TimeoutPublisherTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/publisher/TimeoutPublisherTest.java @@ -47,6 +47,7 @@ import static io.servicetalk.concurrent.Cancellable.IGNORE_CANCEL; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static java.time.Duration.ofNanos; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; import static org.hamcrest.MatcherAssert.assertThat; @@ -95,6 +96,32 @@ void setup() { testExecutor = executorExtension.executor(); } + @Test + void timeoutExceptionDeliveredBeforeUpstreamException() { + toSource(new Publisher() { + @Override + protected void handleSubscribe(final Subscriber subscriber) { + subscriber.onSubscribe(new Subscription() { + private boolean terminated; + @Override + public void request(final long n) { + } + + @Override + public void cancel() { + if (!terminated) { + terminated = true; + subscriber.onError(new AssertionError("unexpected error, should have seen timeout")); + } + } + }); + } + }.timeout(ofNanos(1), testExecutor)) + .subscribe(subscriber); + testExecutor.advanceTimeBy(1, NANOSECONDS); + assertThat(subscriber.awaitOnError(), instanceOf(TimeoutException.class)); + } + @Test void executorScheduleThrowsTerminalTimeout() { toSource(publisher.timeoutTerminal(1, NANOSECONDS, new DelegatingExecutor(testExecutor) { @@ -260,7 +287,7 @@ void dataAndTimeout(TimerBehaviorParam params) throws Exception { void justSubscribeTimeout(TimerBehaviorParam params) { DelayedOnSubscribePublisher delayedPublisher = new DelayedOnSubscribePublisher<>(); - init(delayedPublisher, params, Duration.ofNanos(1), false); + init(delayedPublisher, params, ofNanos(1), false); testExecutor.advanceTimeBy(1, NANOSECONDS); assertThat(testExecutor.scheduledTasksPending(), is(0)); @@ -344,11 +371,11 @@ public Cancellable execute(final Runnable task) throws RejectedExecutionExceptio } private void init(TimerBehaviorParam params) { - init(params, Duration.ofNanos(1)); + init(params, ofNanos(1)); } private void init(TimerBehaviorParam params, Duration duration) { - init(publisher, params, duration, true); + init(publisher, params, duration, true); } private void init(Publisher publisher, TimerBehaviorParam params, diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/TimeoutSingleTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/TimeoutSingleTest.java index 6ff685cd73..872b4019f7 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/TimeoutSingleTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/TimeoutSingleTest.java @@ -16,13 +16,14 @@ package io.servicetalk.concurrent.api.single; import io.servicetalk.concurrent.Cancellable; +import io.servicetalk.concurrent.SingleSource; import io.servicetalk.concurrent.SingleSource.Subscriber; import io.servicetalk.concurrent.api.DelegatingExecutor; import io.servicetalk.concurrent.api.ExecutorExtension; -import io.servicetalk.concurrent.api.LegacyTestSingle; import io.servicetalk.concurrent.api.Single; import io.servicetalk.concurrent.api.TestCancellable; import io.servicetalk.concurrent.api.TestExecutor; +import io.servicetalk.concurrent.api.TestSingle; import io.servicetalk.concurrent.test.internal.TestSingleSubscriber; import org.junit.jupiter.api.BeforeEach; @@ -35,6 +36,7 @@ import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static java.time.Duration.ofNanos; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.NANOSECONDS; import static org.hamcrest.MatcherAssert.assertThat; @@ -49,10 +51,9 @@ class TimeoutSingleTest { @RegisterExtension - final ExecutorExtension executorExtension = ExecutorExtension.withTestExecutor(); - - private LegacyTestSingle source = new LegacyTestSingle<>(false, false); - final TestSingleSubscriber subscriber = new TestSingleSubscriber<>(); + static final ExecutorExtension executorExtension = ExecutorExtension.withTestExecutor(); + private final TestSingle source = new TestSingle<>(); + private final TestSingleSubscriber subscriber = new TestSingleSubscriber<>(); private TestExecutor testExecutor; @BeforeEach @@ -60,6 +61,28 @@ void setup() { testExecutor = executorExtension.executor(); } + @Test + void timeoutExceptionDeliveredBeforeUpstreamException() { + toSource(new Single() { + @Override + protected void handleSubscribe(final SingleSource.Subscriber subscriber) { + subscriber.onSubscribe(new Cancellable() { + private boolean terminated; + @Override + public void cancel() { + if (!terminated) { + terminated = true; + subscriber.onError(new AssertionError("unexpected error, should have seen timeout")); + } + } + }); + } + }.timeout(ofNanos(1), testExecutor)) + .subscribe(subscriber); + testExecutor.advanceTimeBy(1, NANOSECONDS); + assertThat(subscriber.awaitOnError(), instanceOf(TimeoutException.class)); + } + @Test void executorScheduleThrows() { toSource(source.timeout(1, NANOSECONDS, new DelegatingExecutor(testExecutor) { diff --git a/servicetalk-concurrent-internal/src/main/java/io/servicetalk/concurrent/internal/ConcurrentTerminalSubscriber.java b/servicetalk-concurrent-internal/src/main/java/io/servicetalk/concurrent/internal/ConcurrentTerminalSubscriber.java index 6f3a32d303..bd9c3a9efc 100644 --- a/servicetalk-concurrent-internal/src/main/java/io/servicetalk/concurrent/internal/ConcurrentTerminalSubscriber.java +++ b/servicetalk-concurrent-internal/src/main/java/io/servicetalk/concurrent/internal/ConcurrentTerminalSubscriber.java @@ -37,6 +37,7 @@ public final class ConcurrentTerminalSubscriber implements Subscriber { private static final int SUBSCRIBER_STATE_TERMINATING = 2; private static final int SUBSCRIBER_STATE_TERMINATED = 3; + @SuppressWarnings("rawtypes") private static final AtomicIntegerFieldUpdater stateUpdater = AtomicIntegerFieldUpdater.newUpdater(ConcurrentTerminalSubscriber.class, "state"); @@ -210,4 +211,16 @@ public boolean processOnComplete() { } } } + + /** + * Used to terminate the delegate {@link Subscriber} managed by this class externally. This method will mark the + * internal state of this class as terminated so no more signals are propagated by this class. + * @return the delegate {@link Subscriber} managed by this class if not already terminated, otherwise {@code null}. + */ + @Nullable + public Subscriber unwrapMarkTerminated() { + final int localState = stateUpdater.getAndSet(this, SUBSCRIBER_STATE_TERMINATED); + return localState == SUBSCRIBER_STATE_TERMINATED || localState == SUBSCRIBER_STATE_TERMINATING ? + null : delegate; + } }