Skip to content

Commit

Permalink
Revert "Revert "WriteStreamSubscriber: respect termination of the p…
Browse files Browse the repository at this point in the history
…ublisher (#2387)""

This reverts commit 1a33f6c.
  • Loading branch information
idelpivnitskiy committed Oct 13, 2022
1 parent 1c1be8a commit 7fb6875
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ final class WriteStreamSubscriber implements PublisherSource.Subscriber<Object>,
private static final byte CHANNEL_CLOSED = 1 << 1;
private static final byte CLOSE_OUTBOUND_ON_SUBSCRIBER_TERMINATION = 1 << 2;
private static final byte SUBSCRIBER_TERMINATED = 1 << 3;
private static final byte SOURCE_OUTBOUND_CLOSED = 1 << 4;
private static final byte SUBSCRIBER_OR_SOURCE_TERMINATED = SOURCE_TERMINATED | SUBSCRIBER_TERMINATED;
private static final Subscription CANCELLED = newEmptySubscription();
private static final AtomicReferenceFieldUpdater<WriteStreamSubscriber, Subscription> subscriptionUpdater =
Expand Down Expand Up @@ -177,8 +178,8 @@ void doWrite(Object msg) {
demandEstimator.onItemWrite(msg, capacityBefore, capacityAfter);
// Client-side always starts a request with request(1) to probe a Channel with meta-data before continuing
// to write the payload body, see https://github.com/apple/servicetalk/pull/1644.
// Requests that await feedback from the remote peer should not request more until they receive
// continueWriting() signal.
// Requests that await feedback from the remote peer should not request more data from the publisher until
// they receive continueWriting() signal.
if (!isClient || !(shouldWaitFlag = shouldWait.test(msg))) {
requestMoreIfRequired(subscription, capacityAfter);
}
Expand Down Expand Up @@ -246,15 +247,20 @@ public void channelOutboundClosed() {
// we may deadlock if we don't request enough onNext signals to see the terminal signal.
sub.request(Long.MAX_VALUE);
}
promise.sourceTerminated(null, true);
promise.outboundClosed();
}

@Override
public void terminateSource() {
assert eventLoop.inEventLoop();
// Terminate the source only if it awaits continuation.
if (shouldWaitFlag) {
assert promise.activeWrites == 0; // We never start sending payload body until we receive 100 (Continue)
assert promise.activeWrites == 0 :
channel + " Unexpected activeWrites=" + promise.activeWrites + " while waiting for continuation";
// Cancel the passed write Publisher to signal transport is not interested in more data.
final Subscription sub = this.subscription;
assert sub != null : channel + " Unexpected subscription=null while waiting for continuation";
sub.cancel();
promise.sourceTerminated(null, true);
}
}
Expand Down Expand Up @@ -408,6 +414,8 @@ boolean isWritable() {

void writeNext(Object msg) {
assert eventLoop.inEventLoop();
assert isWritable() : channel + " Unexpected writeNext: " + msg + " during non-writable state=" +
Integer.toString(state, 2);
activeWrites++;
listenersOnWriteBoundaries.addLast(WRITE_BOUNDARY);
channel.write(msg, this);
Expand All @@ -416,6 +424,16 @@ void writeNext(Object msg) {
}
}

void outboundClosed() {
assert eventLoop.inEventLoop();
if (isAnySet(state, SUBSCRIBER_OR_SOURCE_TERMINATED)) {
// We have terminated prematurely perhaps due to write failure.
return;
}
state = set(state, SOURCE_OUTBOUND_CLOSED); // Assign a state to mark the promise as not writable.
markCancelled();
}

void sourceTerminated(@Nullable Throwable cause, boolean markCancelled) {
assert eventLoop.inEventLoop();
if (isAnySet(state, SUBSCRIBER_OR_SOURCE_TERMINATED)) {
Expand All @@ -425,15 +443,10 @@ void sourceTerminated(@Nullable Throwable cause, boolean markCancelled) {
this.failureCause = cause;
state = set(state, SOURCE_TERMINATED);
if (markCancelled) {
// When we know that the source is effectively terminated and won't emit any new items, mark the
// subscription as CANCELLED to prevent any further interactions with it, like propagating `cancel` from
// `channelClosed(Throwable)` or `request(MAX_VALUE)` from `channelOutboundClosed()`. At this point we
// always have a non-null subscription because this is reachable only if publisher emitted some signals.
WriteStreamSubscriber.this.subscription = CANCELLED;
markCancelled();
}
if (activeWrites == 0) {
try {
state = set(state, SUBSCRIBER_TERMINATED);
terminateSubscriber(cause);
} catch (Throwable t) {
tryFailureOrLog(t);
Expand All @@ -449,6 +462,14 @@ void sourceTerminated(@Nullable Throwable cause, boolean markCancelled) {
}
}

void markCancelled() {
// When we know that the source is effectively terminated and won't emit any new items, mark the
// subscription as CANCELLED to prevent any further interactions with it, like propagating `cancel` from
// `channelClosed(Throwable)` or `request(MAX_VALUE)` from `channelOutboundClosed()`. At this point we
// always have a non-null subscription because this is reachable only if publisher emitted some signals.
WriteStreamSubscriber.this.subscription = CANCELLED;
}

void close(Throwable cause, boolean closeOutboundIfIdle) {
assert eventLoop.inEventLoop();
if (isAllSet(state, CHANNEL_CLOSED)) {
Expand All @@ -461,7 +482,7 @@ void close(Throwable cause, boolean closeOutboundIfIdle) {
// just close the channel now.
closeHandler.closeChannelOutbound(channel);
}
} else if (activeWrites > 0) {
} else if (activeWrites > 0 || isAllSet(state, SOURCE_OUTBOUND_CLOSED)) {
// Writes are pending, we will close the channel once writes are done.
state = set(state, CLOSE_OUTBOUND_ON_SUBSCRIBER_TERMINATION);
} else {
Expand Down Expand Up @@ -507,7 +528,6 @@ private boolean setSuccess0() {
}
observer.itemFlushed();
if (--activeWrites == 0 && isAllSet(state, SOURCE_TERMINATED)) {
state = set(state, SUBSCRIBER_TERMINATED);
try {
terminateSubscriber(failureCause);
} catch (Throwable t) {
Expand All @@ -531,7 +551,6 @@ private boolean setFailure0(Throwable cause) {
if (isAllSet(state, SUBSCRIBER_TERMINATED)) {
return nettySharedPromiseTryStatus();
}
state = set(state, SUBSCRIBER_TERMINATED);
Subscription oldVal = subscriptionUpdater.getAndSet(WriteStreamSubscriber.this, CANCELLED);
if (oldVal != null && !isAllSet(state, SOURCE_TERMINATED)) {
oldVal.cancel();
Expand All @@ -554,6 +573,7 @@ private boolean nettySharedPromiseTryStatus() {
}

private void terminateSubscriber(@Nullable Throwable cause) {
state = set(state, SUBSCRIBER_TERMINATED);
if (cause == null) {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} Terminate subscriber, state: {}", channel, Integer.toString(state, 2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;

import javax.annotation.Nullable;

import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -36,18 +37,10 @@

public abstract class AbstractWriteTest {

protected EmbeddedChannel channel;
protected WriteDemandEstimator demandEstimator;
protected CompletableSource.Subscriber completableSubscriber;
protected FailingWriteHandler failingWriteHandler;

@BeforeEach
public void setUp() throws Exception {
completableSubscriber = mock(CompletableSource.Subscriber.class);
failingWriteHandler = new FailingWriteHandler();
channel = new EmbeddedChannel(failingWriteHandler);
demandEstimator = mock(WriteDemandEstimator.class);
}
protected final WriteDemandEstimator demandEstimator = mock(WriteDemandEstimator.class);
protected final CompletableSource.Subscriber completableSubscriber = mock(CompletableSource.Subscriber.class);
protected final FailingWriteHandler failingWriteHandler = new FailingWriteHandler();
protected final EmbeddedChannel channel = new EmbeddedChannel(failingWriteHandler);

@AfterEach
public void tearDown() throws Exception {
Expand All @@ -59,6 +52,9 @@ protected void verifyWriteSuccessful(String... items) {
channel.flushOutbound();
if (items.length > 0) {
assertThat("Message not written.", channel.outboundMessages(), contains((String[]) items));
for (int i = 0; i < items.length; ++i) {
channel.readOutbound(); // discard written items
}
} else {
assertThat("Unexpected message(s) written.", channel.outboundMessages(), is(empty()));
}
Expand All @@ -71,6 +67,14 @@ protected void verifyListenerSuccessful() {
verifyNoMoreInteractions(completableSubscriber);
}

protected void verifyListenerFailed(@Nullable Throwable t) {
channel.flushOutbound();
assertThat("Unexpected message(s) written.", channel.outboundMessages(), is(empty()));
verify(completableSubscriber).onSubscribe(any());
verify(completableSubscriber).onError(t != null ? t : any());
verifyNoMoreInteractions(completableSubscriber);
}

static final class FailingWriteHandler extends ChannelDuplexHandler {
private volatile boolean failNextWritePromise;
private volatile boolean throwFromNextWrite;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ private void sendRequest() {
.then(() -> {
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
writeSource.onComplete();
})
.expectComplete()
.verify();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,24 @@ void afterExchangeIdleConnection() {
.then(() -> {
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
writeSource.onComplete();
closeNotifyAndVerifyClosing();
})
.expectComplete()
.verify();
}

@Test
void afterRequestAndWritingResponseButBeforeCompletingWrite() {
receiveRequest();
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)))
.then(() -> {
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
closeNotifyAndVerifyClosing();
writeSource.onComplete();
})
.expectErrorConsumed(cause -> {
assertThat("Unexpected write failure cause", cause, instanceOf(CloseEventObservedException.class));
CloseEventObservedException ceoe = (CloseEventObservedException) cause;
Expand All @@ -55,7 +71,7 @@ void afterExchangeIdleConnection() {
}

@Test
void afterRequestBeforeSendingResponse() {
void afterRequestBeforeWritingResponse() {
receiveRequest();

PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
Expand All @@ -66,7 +82,7 @@ void afterRequestBeforeSendingResponse() {
}

@Test
void afterRequestWhileSendingResponse() {
void afterRequestWhileWritingResponse() {
receiveRequest();

PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
Expand All @@ -80,7 +96,7 @@ void afterRequestWhileSendingResponse() {
}

@Test
void whileReadingRequestBeforeSendingResponse() {
void whileReadingRequestBeforeWritingResponse() {
StepVerifiers.create(conn.write(fromSource(newPublisherProcessor())).merge(conn.read()))
.then(() -> {
// Start reading request
Expand All @@ -93,7 +109,7 @@ void whileReadingRequestBeforeSendingResponse() {
}

@Test
void whileReadingRequestAndSendingResponse() {
void whileReadingRequestAndWritingResponse() {
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)).merge(conn.read()))
.then(() -> {
Expand All @@ -109,7 +125,7 @@ void whileReadingRequestAndSendingResponse() {
}

@Test
void whileReadingRequestAfterSendingResponse() {
void whileReadingRequestAfterWritingResponse() {
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)).merge(conn.read()))
.then(() -> {
Expand All @@ -118,13 +134,34 @@ void whileReadingRequestAfterSendingResponse() {
// Send response
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
writeSource.onComplete();
})
.expectNext(BEGIN)
.then(this::closeNotifyAndVerifyClosing)
.expectError(ClosedChannelException.class)
.verify();
}

@Test
void whileReadingRequestAfterWritingResponseButBeforeCompletingWrite() {
PublisherSource.Processor<String, String> writeSource = newPublisherProcessor();
StepVerifiers.create(conn.write(fromSource(writeSource)).merge(conn.read()))
.then(() -> {
// Start reading request
channel.writeInbound(BEGIN);
// Send response
writeMsg(writeSource, BEGIN);
writeMsg(writeSource, END);
})
.expectNext(BEGIN)
.then(() -> {
closeNotifyAndVerifyClosing();
writeSource.onComplete();
})
.expectError(ClosedChannelException.class)
.verify();
}

private void receiveRequest() {
StepVerifiers.create(conn.read())
.then(() -> channel.writeInbound(BEGIN))
Expand Down
Loading

0 comments on commit 7fb6875

Please sign in to comment.