Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WriteStreamSubscriber: respect termination of the publisher #2387

Merged
merged 10 commits into from
Oct 12, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ final class WriteStreamSubscriber implements PublisherSource.Subscriber<Object>,
private static final byte CHANNEL_CLOSED = 1 << 1;
private static final byte CLOSE_OUTBOUND_ON_SUBSCRIBER_TERMINATION = 1 << 2;
private static final byte SUBSCRIBER_TERMINATED = 1 << 3;
private static final byte SOURCE_OUTBOUND_CLOSED = 1 << 4;
private static final byte SUBSCRIBER_OR_SOURCE_TERMINATED = SOURCE_TERMINATED | SUBSCRIBER_TERMINATED;
private static final Subscription CANCELLED = newEmptySubscription();
private static final AtomicReferenceFieldUpdater<WriteStreamSubscriber, Subscription> subscriptionUpdater =
Expand Down Expand Up @@ -177,8 +178,8 @@ void doWrite(Object msg) {
demandEstimator.onItemWrite(msg, capacityBefore, capacityAfter);
// Client-side always starts a request with request(1) to probe a Channel with meta-data before continuing
// to write the payload body, see https://github.com/apple/servicetalk/pull/1644.
// Requests that await feedback from the remote peer should not request more until they receive
// continueWriting() signal.
// Requests that await feedback from the remote peer should not request more data from the publisher until
// they receive continueWriting() signal.
if (!isClient || !(shouldWaitFlag = shouldWait.test(msg))) {
requestMoreIfRequired(subscription, capacityAfter);
}
Expand Down Expand Up @@ -246,15 +247,20 @@ public void channelOutboundClosed() {
// we may deadlock if we don't request enough onNext signals to see the terminal signal.
sub.request(Long.MAX_VALUE);
}
promise.sourceTerminated(null, true);
promise.outboundClosed();
}

@Override
public void terminateSource() {
assert eventLoop.inEventLoop();
// Terminate the source only if it awaits continuation.
if (shouldWaitFlag) {
assert promise.activeWrites == 0; // We never start sending payload body until we receive 100 (Continue)
assert promise.activeWrites == 0 :
channel + " Unexpected activeWrites=" + promise.activeWrites + " while waiting for continuation";
// Cancel the passed write Publisher to signal transport is not interested in more data.
final Subscription sub = this.subscription;
assert sub != null : channel + " Unexpected subscription=null while waiting for continuation";
sub.cancel();
promise.sourceTerminated(null, true);
}
}
Expand Down Expand Up @@ -408,6 +414,8 @@ boolean isWritable() {

void writeNext(Object msg) {
assert eventLoop.inEventLoop();
assert isWritable() : channel + " Unexpected writeNext: " + msg + " during non-writable state=" +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this assertion necessary? wondering if we need to prevent any write-after-close/terminated scenarios as we may want to pass on the object anyways (reference counted, ordered promos completion, ..)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, invocation of this method is currently guarded by isWritable() check:

    void doWrite(Object msg) {
        // Ignore onNext if the channel is already closed.
        if (promise.isWritable()) {
            long capacityBefore = channel.bytesBeforeUnwritable();
            promise.writeNext(msg);

Added this assection just in case the contract changes, to make sure we handle writes correctly as it will likely will also required changes in AllWritesPromise.

Integer.toString(state, 2);
activeWrites++;
listenersOnWriteBoundaries.addLast(WRITE_BOUNDARY);
channel.write(msg, this);
Expand All @@ -416,6 +424,16 @@ void writeNext(Object msg) {
}
}

void outboundClosed() {
assert eventLoop.inEventLoop();
if (isAnySet(state, SUBSCRIBER_OR_SOURCE_TERMINATED)) {
// We have terminated prematurely perhaps due to write failure.
return;
}
state = set(state, SOURCE_OUTBOUND_CLOSED); // Assign a state to mark the promise as not writable.
markCancelled();
}

void sourceTerminated(@Nullable Throwable cause, boolean markCancelled) {
assert eventLoop.inEventLoop();
if (isAnySet(state, SUBSCRIBER_OR_SOURCE_TERMINATED)) {
Expand All @@ -425,15 +443,10 @@ void sourceTerminated(@Nullable Throwable cause, boolean markCancelled) {
this.failureCause = cause;
state = set(state, SOURCE_TERMINATED);
if (markCancelled) {
// When we know that the source is effectively terminated and won't emit any new items, mark the
// subscription as CANCELLED to prevent any further interactions with it, like propagating `cancel` from
// `channelClosed(Throwable)` or `request(MAX_VALUE)` from `channelOutboundClosed()`. At this point we
// always have a non-null subscription because this is reachable only if publisher emitted some signals.
WriteStreamSubscriber.this.subscription = CANCELLED;
markCancelled();
}
if (activeWrites == 0) {
try {
state = set(state, SUBSCRIBER_TERMINATED);
terminateSubscriber(cause);
} catch (Throwable t) {
tryFailureOrLog(t);
Expand All @@ -449,6 +462,14 @@ void sourceTerminated(@Nullable Throwable cause, boolean markCancelled) {
}
}

void markCancelled() {
// When we know that the source is effectively terminated and won't emit any new items, mark the
// subscription as CANCELLED to prevent any further interactions with it, like propagating `cancel` from
// `channelClosed(Throwable)` or `request(MAX_VALUE)` from `channelOutboundClosed()`. At this point we
// always have a non-null subscription because this is reachable only if publisher emitted some signals.
WriteStreamSubscriber.this.subscription = CANCELLED;
}

void close(Throwable cause, boolean closeOutboundIfIdle) {
assert eventLoop.inEventLoop();
if (isAllSet(state, CHANNEL_CLOSED)) {
Expand All @@ -461,7 +482,7 @@ void close(Throwable cause, boolean closeOutboundIfIdle) {
// just close the channel now.
closeHandler.closeChannelOutbound(channel);
}
} else if (activeWrites > 0) {
} else if (activeWrites > 0 || isAllSet(state, SOURCE_OUTBOUND_CLOSED)) {
// Writes are pending, we will close the channel once writes are done.
state = set(state, CLOSE_OUTBOUND_ON_SUBSCRIBER_TERMINATION);
} else {
Expand Down Expand Up @@ -507,7 +528,6 @@ private boolean setSuccess0() {
}
observer.itemFlushed();
if (--activeWrites == 0 && isAllSet(state, SOURCE_TERMINATED)) {
state = set(state, SUBSCRIBER_TERMINATED);
try {
terminateSubscriber(failureCause);
} catch (Throwable t) {
Expand All @@ -531,7 +551,6 @@ private boolean setFailure0(Throwable cause) {
if (isAllSet(state, SUBSCRIBER_TERMINATED)) {
return nettySharedPromiseTryStatus();
}
state = set(state, SUBSCRIBER_TERMINATED);
Subscription oldVal = subscriptionUpdater.getAndSet(WriteStreamSubscriber.this, CANCELLED);
if (oldVal != null && !isAllSet(state, SOURCE_TERMINATED)) {
oldVal.cancel();
Expand All @@ -554,6 +573,7 @@ private boolean nettySharedPromiseTryStatus() {
}

private void terminateSubscriber(@Nullable Throwable cause) {
state = set(state, SUBSCRIBER_TERMINATED);
if (cause == null) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} Terminate subscriber, state: {}", channel, Integer.toString(state, 2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;

import javax.annotation.Nullable;

import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -36,18 +37,10 @@

public abstract class AbstractWriteTest {

protected EmbeddedChannel channel;
protected WriteDemandEstimator demandEstimator;
protected CompletableSource.Subscriber completableSubscriber;
protected FailingWriteHandler failingWriteHandler;

@BeforeEach
public void setUp() throws Exception {
completableSubscriber = mock(CompletableSource.Subscriber.class);
failingWriteHandler = new FailingWriteHandler();
channel = new EmbeddedChannel(failingWriteHandler);
demandEstimator = mock(WriteDemandEstimator.class);
}
protected final WriteDemandEstimator demandEstimator = mock(WriteDemandEstimator.class);
protected final CompletableSource.Subscriber completableSubscriber = mock(CompletableSource.Subscriber.class);
protected final FailingWriteHandler failingWriteHandler = new FailingWriteHandler();
protected final EmbeddedChannel channel = new EmbeddedChannel(failingWriteHandler);

@AfterEach
public void tearDown() throws Exception {
Expand All @@ -59,6 +52,9 @@ protected void verifyWriteSuccessful(String... items) {
channel.flushOutbound();
if (items.length > 0) {
assertThat("Message not written.", channel.outboundMessages(), contains((String[]) items));
for (int i = 0; i < items.length; ++i) {
channel.readOutbound(); // discard written items
}
} else {
assertThat("Unexpected message(s) written.", channel.outboundMessages(), is(empty()));
}
Expand All @@ -71,6 +67,14 @@ protected void verifyListenerSuccessful() {
verifyNoMoreInteractions(completableSubscriber);
}

protected void verifyListenerFailed(@Nullable Throwable t) {
channel.flushOutbound();
assertThat("Unexpected message(s) written.", channel.outboundMessages(), is(empty()));
verify(completableSubscriber).onSubscribe(any());
verify(completableSubscriber).onError(t != null ? t : any());
verifyNoMoreInteractions(completableSubscriber);
}

static final class FailingWriteHandler extends ChannelDuplexHandler {
private volatile boolean failNextWritePromise;
private volatile boolean throwFromNextWrite;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ private void sendRequest() {
.then(() -> {
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
writeSource.onComplete();
})
.expectComplete()
.verify();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,24 @@ void afterExchangeIdleConnection() {
.then(() -> {
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
writeSource.onComplete();
closeNotifyAndVerifyClosing();
})
.expectComplete()
.verify();
}

@Test
void afterRequestAndWritingResponseButBeforeCompletingWrite() {
receiveRequest();
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)))
.then(() -> {
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
closeNotifyAndVerifyClosing();
writeSource.onComplete();
})
.expectErrorConsumed(cause -> {
assertThat("Unexpected write failure cause", cause, instanceOf(CloseEventObservedException.class));
CloseEventObservedException ceoe = (CloseEventObservedException) cause;
Expand All @@ -55,7 +71,7 @@ void afterExchangeIdleConnection() {
}

@Test
void afterRequestBeforeSendingResponse() {
void afterRequestBeforeWritingResponse() {
receiveRequest();

PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
Expand All @@ -66,7 +82,7 @@ void afterRequestBeforeSendingResponse() {
}

@Test
void afterRequestWhileSendingResponse() {
void afterRequestWhileWritingResponse() {
receiveRequest();

PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
Expand All @@ -80,7 +96,7 @@ void afterRequestWhileSendingResponse() {
}

@Test
void whileReadingRequestBeforeSendingResponse() {
void whileReadingRequestBeforeWritingResponse() {
StepVerifiers.create(conn.write(fromSource(newPublisherProcessor())).merge(conn.read()))
.then(() -> {
// Start reading request
Expand All @@ -93,7 +109,7 @@ void whileReadingRequestBeforeSendingResponse() {
}

@Test
void whileReadingRequestAndSendingResponse() {
void whileReadingRequestAndWritingResponse() {
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)).merge(conn.read()))
.then(() -> {
Expand All @@ -109,7 +125,7 @@ void whileReadingRequestAndSendingResponse() {
}

@Test
void whileReadingRequestAfterSendingResponse() {
void whileReadingRequestAfterWritingResponse() {
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)).merge(conn.read()))
.then(() -> {
Expand All @@ -118,13 +134,34 @@ void whileReadingRequestAfterSendingResponse() {
// Send response
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
writeSource.onComplete();
})
.expectNext(BEGIN)
.then(this::closeNotifyAndVerifyClosing)
.expectError(ClosedChannelException.class)
.verify();
}

@Test
void whileReadingRequestAfterWritingResponseButBeforeCompletingWrite() {
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)).merge(conn.read()))
.then(() -> {
// Start reading request
channel.writeInbound(BEGIN);
// Send response
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
})
.expectNext(BEGIN)
.then(() -> {
closeNotifyAndVerifyClosing();
writeSource.onComplete();
})
.expectError(ClosedChannelException.class)
.verify();
}

private void receiveRequest() {
StepVerifiers.create(conn.read())
.then(() -> channel.writeInbound(BEGIN))
Expand Down
Loading