diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index cb94195cce15..3714b29a6e5b 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -1060,6 +1060,7 @@ public void messagesAvailable(final MessageProducer producer) { checkState( savedState.winningSubstream != null, "Headers should be received prior to messages."); if (savedState.winningSubstream != substream) { + GrpcUtil.closeQuietly(producer); return; } listenerSerializeExecutor.execute( diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index f20e772e92b3..2d5dca2969f2 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -60,6 +60,8 @@ import io.grpc.internal.StreamListener.MessageProducer; import java.io.InputStream; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.concurrent.Executor; @@ -991,6 +993,78 @@ public void messageAvailable() { verify(masterListener).messagesAvailable(messageProducer); } + @Test + public void inboundMessagesClosedOnCancel() throws Exception { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + + retriableStream.start(masterListener); + retriableStream.request(1); + retriableStream.cancel(Status.CANCELLED.withDescription("on purpose")); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + ClientStreamListener listener = sublistenerCaptor1.getValue(); + listener.headersRead(new Metadata()); + InputStream is = mock(InputStream.class); + listener.messagesAvailable(new FakeMessageProducer(is)); + verify(masterListener, never()).messagesAvailable(any(MessageProducer.class)); + verify(is).close(); + } + + @Test + public void notAdd0PrevRetryAttemptsToRespHeaders() { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor.capture()); + + sublistenerCaptor.getValue().headersRead(new Metadata()); + + ArgumentCaptor metadataCaptor = + ArgumentCaptor.forClass(Metadata.class); + verify(masterListener).headersRead(metadataCaptor.capture()); + assertEquals(null, metadataCaptor.getValue().get(GRPC_PREVIOUS_RPC_ATTEMPTS)); + } + + @Test + public void addPrevRetryAttemptsToRespHeaders() { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + // retry + ClientStream mockStream2 = mock(ClientStream.class); + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream2).start(sublistenerCaptor2.capture()); + Metadata headers = new Metadata(); + headers.put(GRPC_PREVIOUS_RPC_ATTEMPTS, "3"); + sublistenerCaptor2.getValue().headersRead(headers); + + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(masterListener).headersRead(metadataCaptor.capture()); + Iterable iterable = metadataCaptor.getValue().getAll(GRPC_PREVIOUS_RPC_ATTEMPTS); + assertEquals(1, Iterables.size(iterable)); + assertEquals("1", iterable.iterator().next()); + } + @Test public void closedWhileDraining() { ClientStream mockStream1 = mock(ClientStream.class); @@ -2718,4 +2792,22 @@ private interface RetriableStreamRecorder { Status prestart(); } + + private static final class FakeMessageProducer implements MessageProducer { + private final Iterator iterator; + + public FakeMessageProducer(InputStream... iss) { + this.iterator = Arrays.asList(iss).iterator(); + } + + @Override + @Nullable + public InputStream next() { + if (iterator.hasNext()) { + return iterator.next(); + } else { + return null; + } + } + } }