From 95da1eddbc979f4ce78c9d1ac15bc4c1faba6dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 21 Aug 2024 07:01:48 +0200 Subject: [PATCH] feat: add option for cancelling queries when closing client (#3276) * feat: add option for cancelling queries when closing client Adds an option to keep track of all running queries and cancel these when the client is closed. * test: unflake test cases * fix: cancel query by default for batch transactions * chore: cleanup --- .../cloud/spanner/AbstractReadContext.java | 14 +- .../google/cloud/spanner/BatchClientImpl.java | 2 + .../cloud/spanner/GrpcStreamIterator.java | 21 ++- .../spanner/IsChannelShutdownException.java | 50 ++++++ .../spanner/SpannerExceptionFactory.java | 7 +- .../cloud/spanner/spi/v1/GapicSpannerRpc.java | 45 ++++- .../cloud/spanner/spi/v1/SpannerRpc.java | 9 + .../CloseSpannerWithOpenResultSetTest.java | 164 ++++++++++++++++++ .../cloud/spanner/GrpcResultSetTest.java | 2 +- .../cloud/spanner/MockSpannerServiceImpl.java | 14 +- .../cloud/spanner/ReadFormatTestRunner.java | 2 +- 11 files changed, 318 insertions(+), 12 deletions(-) create mode 100644 google-cloud-spanner/src/main/java/com/google/cloud/spanner/IsChannelShutdownException.java create mode 100644 google-cloud-spanner/src/test/java/com/google/cloud/spanner/CloseSpannerWithOpenResultSetTest.java diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java index 3c56bf0a738..caf0e06379e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/AbstractReadContext.java @@ -69,6 +69,7 @@ abstract class AbstractReadContext abstract static class Builder, T extends AbstractReadContext> { private SessionImpl session; + private boolean cancelQueryWhenClientIsClosed; private SpannerRpc rpc; private ISpan span; private TraceWrapper tracer; @@ -91,6 +92,11 @@ B setSession(SessionImpl session) { return self(); } + B setCancelQueryWhenClientIsClosed(boolean cancelQueryWhenClientIsClosed) { + this.cancelQueryWhenClientIsClosed = cancelQueryWhenClientIsClosed; + return self(); + } + B setRpc(SpannerRpc rpc) { this.rpc = rpc; return self(); @@ -450,6 +456,7 @@ void initTransaction() { final Object lock = new Object(); final SessionImpl session; + final boolean cancelQueryWhenClientIsClosed; final SpannerRpc rpc; final ExecutorProvider executorProvider; ISpan span; @@ -479,6 +486,7 @@ void initTransaction() { AbstractReadContext(Builder builder) { this.session = builder.session; + this.cancelQueryWhenClientIsClosed = builder.cancelQueryWhenClientIsClosed; this.rpc = builder.rpc; this.defaultPrefetchChunks = builder.defaultPrefetchChunks; this.defaultQueryOptions = builder.defaultQueryOptions; @@ -760,7 +768,8 @@ ResultSet executeQueryInternalWithOptions( rpc.getExecuteQueryRetryableCodes()) { @Override CloseableIterator startStream(@Nullable ByteString resumeToken) { - GrpcStreamIterator stream = new GrpcStreamIterator(statement, prefetchChunks); + GrpcStreamIterator stream = + new GrpcStreamIterator(statement, prefetchChunks, cancelQueryWhenClientIsClosed); if (partitionToken != null) { request.setPartitionToken(partitionToken); } @@ -943,7 +952,8 @@ ResultSet readInternalWithOptions( rpc.getReadRetryableCodes()) { @Override CloseableIterator startStream(@Nullable ByteString resumeToken) { - GrpcStreamIterator stream = new GrpcStreamIterator(prefetchChunks); + GrpcStreamIterator stream = + new GrpcStreamIterator(prefetchChunks, cancelQueryWhenClientIsClosed); TransactionSelector selector = null; if (resumeToken != null) { builder.setResumeToken(resumeToken); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java index 22fb9f710c1..3d886dd383b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/BatchClientImpl.java @@ -54,6 +54,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(TimestampBound bound) { return new BatchReadOnlyTransactionImpl( MultiUseReadOnlyTransaction.newBuilder() .setSession(session) + .setCancelQueryWhenClientIsClosed(true) .setRpc(sessionClient.getSpanner().getRpc()) .setTimestampBound(bound) .setDefaultQueryOptions( @@ -75,6 +76,7 @@ public BatchReadOnlyTransaction batchReadOnlyTransaction(BatchTransactionId batc return new BatchReadOnlyTransactionImpl( MultiUseReadOnlyTransaction.newBuilder() .setSession(session) + .setCancelQueryWhenClientIsClosed(true) .setRpc(sessionClient.getSpanner().getRpc()) .setTransactionId(batchTransactionId.getTransactionId()) .setTimestamp(batchTransactionId.getTimestamp()) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java index dde6b69c461..af6b5683502 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/GrpcStreamIterator.java @@ -38,7 +38,7 @@ class GrpcStreamIterator extends AbstractIterator private static final Logger logger = Logger.getLogger(GrpcStreamIterator.class.getName()); private static final PartialResultSet END_OF_STREAM = PartialResultSet.newBuilder().build(); - private final ConsumerImpl consumer = new ConsumerImpl(); + private final ConsumerImpl consumer; private final BlockingQueue stream; private final Statement statement; @@ -49,13 +49,15 @@ class GrpcStreamIterator extends AbstractIterator private SpannerException error; @VisibleForTesting - GrpcStreamIterator(int prefetchChunks) { - this(null, prefetchChunks); + GrpcStreamIterator(int prefetchChunks, boolean cancelQueryWhenClientIsClosed) { + this(null, prefetchChunks, cancelQueryWhenClientIsClosed); } @VisibleForTesting - GrpcStreamIterator(Statement statement, int prefetchChunks) { + GrpcStreamIterator( + Statement statement, int prefetchChunks, boolean cancelQueryWhenClientIsClosed) { this.statement = statement; + this.consumer = new ConsumerImpl(cancelQueryWhenClientIsClosed); // One extra to allow for END_OF_STREAM message. this.stream = new LinkedBlockingQueue<>(prefetchChunks + 1); } @@ -136,6 +138,12 @@ private void addToStream(PartialResultSet results) { } private class ConsumerImpl implements SpannerRpc.ResultStreamConsumer { + private final boolean cancelQueryWhenClientIsClosed; + + ConsumerImpl(boolean cancelQueryWhenClientIsClosed) { + this.cancelQueryWhenClientIsClosed = cancelQueryWhenClientIsClosed; + } + @Override public void onPartialResultSet(PartialResultSet results) { addToStream(results); @@ -168,5 +176,10 @@ public void onError(SpannerException e) { error = e; addToStream(END_OF_STREAM); } + + @Override + public boolean cancelQueryWhenClientIsClosed() { + return this.cancelQueryWhenClientIsClosed; + } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/IsChannelShutdownException.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/IsChannelShutdownException.java new file mode 100644 index 00000000000..367d75a13cb --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/IsChannelShutdownException.java @@ -0,0 +1,50 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import com.google.api.gax.rpc.UnavailableException; +import com.google.common.base.Predicate; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; + +/** + * Predicate that checks whether an exception is a ChannelShutdownException. This exception is + * thrown by gRPC if the underlying gRPC stub has been shut down and uses the UNAVAILABLE error + * code. This means that it would normally be retried by the Spanner client, but this specific + * UNAVAILABLE error should not be retried, as it would otherwise directly return the same error. + */ +class IsChannelShutdownException implements Predicate { + + @Override + public boolean apply(Throwable input) { + Throwable cause = input; + do { + if (isUnavailableError(cause) + && (cause.getMessage().contains("Channel shutdown invoked") + || cause.getMessage().contains("Channel shutdownNow invoked"))) { + return true; + } + } while ((cause = cause.getCause()) != null); + return false; + } + + private boolean isUnavailableError(Throwable cause) { + return (cause instanceof UnavailableException) + || (cause instanceof StatusRuntimeException + && ((StatusRuntimeException) cause).getStatus().getCode() == Code.UNAVAILABLE); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java index 11c2715b03c..39b254fe997 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java @@ -333,7 +333,9 @@ private static boolean isRetryable(ErrorCode code, @Nullable Throwable cause) { case UNAVAILABLE: // SSLHandshakeException is (probably) not retryable, as it is an indication that the server // certificate was not accepted by the client. - return !hasCauseMatching(cause, Matchers.isSSLHandshakeException); + // Channel shutdown is also not a retryable exception. + return !(hasCauseMatching(cause, Matchers.isSSLHandshakeException) + || hasCauseMatching(cause, Matchers.IS_CHANNEL_SHUTDOWN_EXCEPTION)); case RESOURCE_EXHAUSTED: return SpannerException.extractRetryDelay(cause) > 0; default: @@ -356,5 +358,8 @@ private static class Matchers { static final Predicate isRetryableInternalError = new IsRetryableInternalError(); static final Predicate isSSLHandshakeException = new IsSslHandshakeException(); + + static final Predicate IS_CHANNEL_SHUTDOWN_EXCEPTION = + new IsChannelShutdownException(); } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 00ae72f169a..e1e15b851b4 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -201,6 +201,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -262,6 +263,9 @@ public class GapicSpannerRpc implements SpannerRpc { private final ScheduledExecutorService spannerWatchdog; + private final ConcurrentLinkedDeque responseObservers = + new ConcurrentLinkedDeque<>(); + private final boolean throttleAdministrativeRequests; private final RetrySettings retryAdministrativeRequestsSettings; private static final double ADMINISTRATIVE_REQUESTS_RATE_LIMIT = 1.0D; @@ -2004,9 +2008,29 @@ GrpcCallContext newCallContext( return (GrpcCallContext) context.merge(apiCallContextFromContext); } + void registerResponseObserver(SpannerResponseObserver responseObserver) { + responseObservers.add(responseObserver); + } + + void unregisterResponseObserver(SpannerResponseObserver responseObserver) { + responseObservers.remove(responseObserver); + } + + void closeResponseObservers() { + responseObservers.forEach(SpannerResponseObserver::close); + responseObservers.clear(); + } + + @InternalApi + @VisibleForTesting + public int getNumActiveResponseObservers() { + return responseObservers.size(); + } + @Override public void shutdown() { this.rpcIsClosed = true; + closeResponseObservers(); if (this.spannerStub != null) { this.spannerStub.close(); this.partitionedDmlStub.close(); @@ -2028,6 +2052,7 @@ public void shutdown() { public void shutdownNow() { this.rpcIsClosed = true; + closeResponseObservers(); this.spannerStub.close(); this.partitionedDmlStub.close(); this.instanceAdminStub.close(); @@ -2085,7 +2110,7 @@ public void cancel(@Nullable String message) { * A {@code ResponseObserver} that exposes the {@code StreamController} and delegates callbacks to * the {@link ResultStreamConsumer}. */ - private static class SpannerResponseObserver implements ResponseObserver { + private class SpannerResponseObserver implements ResponseObserver { private StreamController controller; private final ResultStreamConsumer consumer; @@ -2094,13 +2119,21 @@ public SpannerResponseObserver(ResultStreamConsumer consumer) { this.consumer = consumer; } + void close() { + if (this.controller != null) { + this.controller.cancel(); + } + } + @Override public void onStart(StreamController controller) { - // Disable the auto flow control to allow client library // set the number of messages it prefers to request controller.disableAutoInboundFlowControl(); this.controller = controller; + if (this.consumer.cancelQueryWhenClientIsClosed()) { + registerResponseObserver(this); + } } @Override @@ -2110,11 +2143,19 @@ public void onResponse(PartialResultSet response) { @Override public void onError(Throwable t) { + // Unregister the response observer when the query has completed with an error. + if (this.consumer.cancelQueryWhenClientIsClosed()) { + unregisterResponseObserver(this); + } consumer.onError(newSpannerException(t)); } @Override public void onComplete() { + // Unregister the response observer when the query has completed normally. + if (this.consumer.cancelQueryWhenClientIsClosed()) { + unregisterResponseObserver(this); + } consumer.onCompleted(); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 8f9f10ddb57..0b040df4197 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -153,6 +153,15 @@ interface ResultStreamConsumer { void onCompleted(); void onError(SpannerException e); + + /** + * Returns true if the stream should be cancelled when the Spanner client is closed. This + * returns true for {@link com.google.cloud.spanner.BatchReadOnlyTransaction}, as these use a + * non-pooled session. Pooled sessions are deleted when the Spanner client is closed, and this + * automatically also cancels any query that uses the session, which means that we don't need to + * explicitly cancel those queries when the Spanner client is closed. + */ + boolean cancelQueryWhenClientIsClosed(); } /** Handle for cancellation of a streaming read or query call. */ diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CloseSpannerWithOpenResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CloseSpannerWithOpenResultSetTest.java new file mode 100644 index 00000000000..67b14f60a4e --- /dev/null +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/CloseSpannerWithOpenResultSetTest.java @@ -0,0 +1,164 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assume.assumeFalse; + +import com.google.cloud.NoCredentials; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.cloud.spanner.connection.AbstractMockServerTest; +import com.google.cloud.spanner.spi.v1.GapicSpannerRpc; +import com.google.spanner.v1.DeleteSessionRequest; +import com.google.spanner.v1.ExecuteSqlRequest; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Status; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.threeten.bp.Duration; + +@RunWith(JUnit4.class) +public class CloseSpannerWithOpenResultSetTest extends AbstractMockServerTest { + + Spanner createSpanner() { + return SpannerOptions.newBuilder() + .setProjectId("p") + .setHost(String.format("http://localhost:%d", getPort())) + .setChannelConfigurator(ManagedChannelBuilder::usePlaintext) + .setCredentials(NoCredentials.getInstance()) + .setSessionPoolOption( + SessionPoolOptions.newBuilder().setWaitForMinSessions(Duration.ofSeconds(5L)).build()) + .build() + .getService(); + } + + @After + public void cleanup() { + mockSpanner.unfreeze(); + mockSpanner.clearRequests(); + } + + @Test + public void testBatchClient_closedSpannerWithOpenResultSet_streamsAreCancelled() { + Spanner spanner = createSpanner(); + assumeFalse(spanner.getOptions().getSessionPoolOptions().getUseMultiplexedSession()); + + BatchClient client = spanner.getBatchClient(DatabaseId.of("p", "i", "d")); + try (BatchReadOnlyTransaction transaction = + client.batchReadOnlyTransaction(TimestampBound.strong()); + ResultSet resultSet = transaction.executeQuery(SELECT_RANDOM_STATEMENT)) { + mockSpanner.freezeAfterReturningNumRows(1); + assertTrue(resultSet.next()); + ((SpannerImpl) spanner).close(1, TimeUnit.MILLISECONDS); + // This should return an error as the stream is cancelled. + SpannerException exception = assertThrows(SpannerException.class, resultSet::next); + assertEquals(ErrorCode.CANCELLED, exception.getErrorCode()); + } + } + + @Test + public void testNormalDatabaseClient_closedSpannerWithOpenResultSet_sessionsAreDeleted() + throws Exception { + Spanner spanner = createSpanner(); + assumeFalse(spanner.getOptions().getSessionPoolOptions().getUseMultiplexedSession()); + + DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d")); + try (ReadOnlyTransaction transaction = client.readOnlyTransaction(TimestampBound.strong()); + ResultSet resultSet = transaction.executeQuery(SELECT_RANDOM_STATEMENT)) { + mockSpanner.freezeAfterReturningNumRows(1); + assertTrue(resultSet.next()); + List executeSqlRequests = + mockSpanner.getRequestsOfType(ExecuteSqlRequest.class).stream() + .filter(request -> request.getSql().equals(SELECT_RANDOM_STATEMENT.getSql())) + .collect(Collectors.toList()); + assertEquals(1, executeSqlRequests.size()); + ExecutorService service = Executors.newSingleThreadExecutor(); + service.submit(spanner::close); + // Verify that the session that is used by this transaction is deleted. + // That will automatically cancel the query. + mockSpanner.waitForRequestsToContain( + request -> + request instanceof DeleteSessionRequest + && ((DeleteSessionRequest) request) + .getName() + .equals(executeSqlRequests.get(0).getSession()), + /*timeoutMillis=*/ 1000L); + service.shutdownNow(); + } + } + + @Test + public void testStreamsAreCleanedUp() throws Exception { + String invalidSql = "select * from foo"; + Statement invalidStatement = Statement.of(invalidSql); + mockSpanner.putStatementResult( + StatementResult.exception( + invalidStatement, + Status.NOT_FOUND.withDescription("Table not found: foo").asRuntimeException())); + int numThreads = 16; + int numQueries = 32; + try (Spanner spanner = createSpanner()) { + BatchClient client = spanner.getBatchClient(DatabaseId.of("p", "i", "d")); + ExecutorService service = Executors.newFixedThreadPool(numThreads); + List> futures = new ArrayList<>(numQueries); + for (int n = 0; n < numQueries; n++) { + futures.add( + service.submit( + () -> { + try (BatchReadOnlyTransaction transaction = + client.batchReadOnlyTransaction(TimestampBound.strong())) { + if (ThreadLocalRandom.current().nextInt(10) < 2) { + try (ResultSet resultSet = transaction.executeQuery(invalidStatement)) { + SpannerException exception = + assertThrows(SpannerException.class, resultSet::next); + assertEquals(ErrorCode.NOT_FOUND, exception.getErrorCode()); + } + } else { + try (ResultSet resultSet = + transaction.executeQuery(SELECT_RANDOM_STATEMENT)) { + while (resultSet.next()) { + assertNotNull(resultSet.getCurrentRowAsStruct()); + } + } + } + } + })); + } + service.shutdown(); + for (Future fut : futures) { + fut.get(); + } + assertTrue(service.awaitTermination(1L, TimeUnit.MINUTES)); + // Verify that all response observers have been unregistered. + assertEquals( + 0, ((GapicSpannerRpc) ((SpannerImpl) spanner).getRpc()).getNumActiveResponseObservers()); + } + } +} diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java index 2051e006d81..62336163eaf 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/GrpcResultSetTest.java @@ -81,7 +81,7 @@ public void onDone(boolean withBeginTransaction) {} @Before public void setUp() { - stream = new GrpcStreamIterator(10); + stream = new GrpcStreamIterator(10, /*cancelQueryWhenClientIsClosed=*/ false); stream.setCall( new SpannerRpc.StreamingCall() { @Override diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java index 5266ecad7c8..9f0a2822d87 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/MockSpannerServiceImpl.java @@ -578,6 +578,7 @@ private static void checkStreamException( private final Object lock = new Object(); private Deque requests = new ConcurrentLinkedDeque<>(); private volatile CountDownLatch freezeLock = new CountDownLatch(0); + private final AtomicInteger freezeAfterReturningNumRows = new AtomicInteger(); private Queue exceptions = new ConcurrentLinkedQueue<>(); private boolean stickyGlobalExceptions = false; private ConcurrentMap statementResults = new ConcurrentHashMap<>(); @@ -784,6 +785,10 @@ public void unfreeze() { freezeLock.countDown(); } + public void freezeAfterReturningNumRows(int numRows) { + freezeAfterReturningNumRows.set(numRows); + } + public void setMaxSessionsInOneBatch(int max) { this.maxNumSessionsInOneBatch = max; } @@ -1678,7 +1683,8 @@ private void returnPartialResultSet( ByteString transactionId, TransactionSelector transactionSelector, StreamObserver responseObserver, - SimulatedExecutionTime executionTime) { + SimulatedExecutionTime executionTime) + throws Exception { ResultSetMetadata metadata = resultSet.getMetadata(); if (transactionId == null) { Transaction transaction = getTemporaryTransactionOrNull(transactionSelector); @@ -1700,6 +1706,12 @@ private void returnPartialResultSet( SimulatedExecutionTime.checkStreamException( index, executionTime.exceptions, executionTime.streamIndices); responseObserver.onNext(iterator.next()); + if (freezeAfterReturningNumRows.get() > 0) { + if (freezeAfterReturningNumRows.decrementAndGet() == 0) { + freeze(); + freezeLock.await(); + } + } index++; } responseObserver.onCompleted(); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java index 8d97d9d894b..c973b7e471e 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/ReadFormatTestRunner.java @@ -114,7 +114,7 @@ private static class TestCaseRunner { } private void run() throws Exception { - stream = new GrpcStreamIterator(10); + stream = new GrpcStreamIterator(10, /*cancelQueryWhenClientIsClosed=*/ false); stream.setCall( new SpannerRpc.StreamingCall() { @Override