From 7d5802bba77d0d0741c385f18721c5e5e9f29bb0 Mon Sep 17 00:00:00 2001 From: m-trieu Date: Tue, 15 Oct 2024 10:26:32 -0700 Subject: [PATCH 01/23] address PR comments --- .../worker/streaming/harness/GlobalDataStreamSender.java | 1 + 1 file changed, 1 insertion(+) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java index ce5f3a7b6bfc..1fbf59892e61 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import com.google.common.base.Suppliers; import java.io.Closeable; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; From a2d5c8434c76c69023b0e6e499b8c682086bfd3f Mon Sep 17 00:00:00 2001 From: m-trieu Date: Mon, 14 Oct 2024 15:16:39 -0700 Subject: [PATCH 02/23] simplify budget distribution logic moving it into GrpcDirectGetWorkStream, simplify worker metadata consumption --- .../worker/streaming/harness/GlobalDataStreamSender.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java index 1fbf59892e61..4dec7ead19f2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; -import com.google.common.base.Suppliers; import java.io.Closeable; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; @@ -47,7 +46,6 @@ public GetDataStream get() { if (!started) { started = true; } - return delegate.get(); } From d53310c6b13c22949ba96a4b222ab3ab25e2473c Mon Sep 17 00:00:00 2001 From: m-trieu Date: Mon, 14 Oct 2024 14:17:02 -0700 Subject: [PATCH 03/23] add shutdown and start mechanics to WindmillStreams --- .../worker/WorkItemCancelledException.java | 8 +- .../FanOutStreamingEngineWorkerHarness.java | 6 +- .../harness/WindmillStreamSender.java | 107 +++--- .../client/AbstractWindmillStream.java | 350 ++++++++++++------ .../windmill/client/WindmillStream.java | 13 +- .../client/grpc/AppendableInputStream.java | 8 +- .../client/grpc/GrpcCommitWorkStream.java | 201 ++++++---- .../client/grpc/GrpcDirectGetWorkStream.java | 47 +-- .../client/grpc/GrpcGetDataStream.java | 120 ++++-- .../grpc/GrpcGetDataStreamRequests.java | 74 +++- .../client/grpc/GrpcGetWorkStream.java | 52 +-- .../grpc/GrpcGetWorkerMetadataStream.java | 53 ++- .../client/grpc/GrpcWindmillServer.java | 35 +- .../grpc/GrpcWindmillStreamFactory.java | 88 ++++- .../grpc/WindmillStreamShutdownException.java | 35 ++ .../grpc/observers/DirectStreamObserver.java | 67 +++- .../dataflow/worker/FakeWindmillServer.java | 9 + ...anOutStreamingEngineWorkerHarnessTest.java | 4 +- .../harness/WindmillStreamSenderTest.java | 50 +-- .../client/WindmillStreamPoolTest.java | 3 + .../StreamingEngineWorkCommitterTest.java | 4 + .../client/grpc/GrpcCommitWorkStreamTest.java | 235 ++++++++++++ .../client/grpc/GrpcGetDataStreamTest.java | 202 ++++++++++ .../grpc/GrpcGetWorkerMetadataStreamTest.java | 44 +-- 24 files changed, 1350 insertions(+), 465 
deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index ec5122a8732a..a12a5075c5ee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -26,8 +26,12 @@ public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } - public WorkItemCancelledException(Throwable e) { - super(e); + public WorkItemCancelledException(String message, Throwable cause) { + super(message, cause); + } + + public WorkItemCancelledException(Throwable cause) { + super(cause); } /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. 
*/ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 458cf57ca8e7..5ed39b43dafe 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -111,7 +111,6 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker @GuardedBy("this") private boolean started; - @GuardedBy("this") private @Nullable GetWorkerMetadataStream getWorkerMetadataStream; private FanOutStreamingEngineWorkerHarness( @@ -204,9 +203,10 @@ public synchronized void start() { Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); getWorkerMetadataStream = streamFactory.createGetWorkerMetadataStream( - dispatcherClient.getWindmillMetadataServiceStubBlocking(), + dispatcherClient::getWindmillMetadataServiceStubBlocking, getWorkerMetadataThrottleTimer, this::consumeWorkerMetadata); + getWorkerMetadataStream.start(); started = true; } @@ -416,7 +416,7 @@ private WindmillStreamSender createAndStartWindmillStreamSender(Endpoint endpoin StreamGetDataClient.create( getDataStream, this::getGlobalDataStream, getDataMetricTracker), workCommitterFactory); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); return windmillStreamSender; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 744c3d74445f..5641ef6f648c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -18,10 +18,12 @@ package org.apache.beam.runners.dataflow.worker.streaming.harness; import java.io.Closeable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.WindmillConnection; @@ -37,20 +39,13 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudgetSpender; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.FixedStreamHeartbeatSender; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; /** * Owns and maintains a set of streams used to communicate with a specific Windmill worker. - * Underlying streams are "cached" in a threadsafe manner so that once {@link Supplier#get} is - * called, a stream that is already started is returned. - * - *

Holds references to {@link - * Supplier} because - * initializing the streams automatically start them, and we want to do so lazily here once the - * {@link GetWorkBudget} is set. * *

Once started, the underlying streams are "alive" until they are manually closed via {@link - * #close()} ()}. + * #close()}. * *

If closed, it means that the backend endpoint is no longer in the worker set. Once closed, * these instances are not reused. @@ -61,13 +56,15 @@ @Internal @ThreadSafe final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable { + private static final String STREAM_STARTER_THREAD_NAME = "StartWindmillStreamThread-%d"; private final AtomicBoolean started; private final AtomicReference getWorkBudget; - private final Supplier getWorkStream; - private final Supplier getDataStream; - private final Supplier commitWorkStream; - private final Supplier workCommitter; + private final GetWorkStream getWorkStream; + private final GetDataStream getDataStream; + private final CommitWorkStream commitWorkStream; + private final WorkCommitter workCommitter; private final StreamingEngineThrottleTimers streamingEngineThrottleTimers; + private final ExecutorService streamStarter; private WindmillStreamSender( WindmillConnection connection, @@ -81,33 +78,28 @@ private WindmillStreamSender( this.getWorkBudget = getWorkBudget; this.streamingEngineThrottleTimers = StreamingEngineThrottleTimers.create(); - // All streams are memoized/cached since they are expensive to create and some implementations - // perform side effects on construction (i.e. sending initial requests to the stream server to - // initiate the streaming RPC connection). Stream instances connect/reconnect internally, so we - // can reuse the same instance through the entire lifecycle of WindmillStreamSender. + // Stream instances connect/reconnect internally, so we can reuse the same instance through the + // entire lifecycle of WindmillStreamSender. 
this.getDataStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createGetDataStream( - connection.stub(), streamingEngineThrottleTimers.getDataThrottleTimer())); + streamingEngineStreamFactory.createDirectGetDataStream( + connection, streamingEngineThrottleTimers.getDataThrottleTimer()); this.commitWorkStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createCommitWorkStream( - connection.stub(), streamingEngineThrottleTimers.commitWorkThrottleTimer())); - this.workCommitter = - Suppliers.memoize(() -> workCommitterFactory.apply(commitWorkStream.get())); + streamingEngineStreamFactory.createDirectCommitWorkStream( + connection, streamingEngineThrottleTimers.commitWorkThrottleTimer()); + this.workCommitter = workCommitterFactory.apply(commitWorkStream); this.getWorkStream = - Suppliers.memoize( - () -> - streamingEngineStreamFactory.createDirectGetWorkStream( - connection, - withRequestBudget(getWorkRequest, getWorkBudget.get()), - streamingEngineThrottleTimers.getWorkThrottleTimer(), - FixedStreamHeartbeatSender.create(getDataStream.get()), - getDataClientFactory.apply(getDataStream.get()), - workCommitter.get(), - workItemScheduler)); + streamingEngineStreamFactory.createDirectGetWorkStream( + connection, + withRequestBudget(getWorkRequest, getWorkBudget.get()), + streamingEngineThrottleTimers.getWorkThrottleTimer(), + FixedStreamHeartbeatSender.create(getDataStream), + getDataClientFactory.apply(getDataStream), + workCommitter, + workItemScheduler); + // 3 threads, 1 for each stream type (GetWork, GetData, CommitWork). 
+ this.streamStarter = + Executors.newFixedThreadPool( + 3, new ThreadFactoryBuilder().setNameFormat(STREAM_STARTER_THREAD_NAME).build()); } static WindmillStreamSender create( @@ -132,34 +124,35 @@ private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkB return request.toBuilder().setMaxItems(budget.items()).setMaxBytes(budget.bytes()).build(); } - @SuppressWarnings("ReturnValueIgnored") - void startStreams() { - getWorkStream.get(); - getDataStream.get(); - commitWorkStream.get(); - workCommitter.get().start(); - // *stream.get() is all memoized in a threadsafe manner. - started.set(true); + synchronized void start() { + if (!started.get()) { + // Start these 3 streams in parallel since they each may perform blocking IO. + CompletableFuture.allOf( + CompletableFuture.runAsync(getWorkStream::start, streamStarter), + CompletableFuture.runAsync(getDataStream::start, streamStarter), + CompletableFuture.runAsync(commitWorkStream::start, streamStarter)) + .join(); + workCommitter.start(); + started.set(true); + } } @Override - public void close() { - // Supplier.get() starts the stream which is an expensive operation as it initiates the - // streaming RPCs by possibly making calls over the network. Do not close the streams unless - // they have already been started. 
+ public synchronized void close() { if (started.get()) { - getWorkStream.get().shutdown(); - getDataStream.get().shutdown(); - workCommitter.get().stop(); - commitWorkStream.get().shutdown(); + getWorkStream.shutdown(); + getDataStream.shutdown(); + workCommitter.stop(); + commitWorkStream.shutdown(); } } @Override public void setBudget(long items, long bytes) { - getWorkBudget.set(getWorkBudget.get().apply(items, bytes)); + GetWorkBudget adjustment = GetWorkBudget.builder().setItems(items).setBytes(bytes).build(); + getWorkBudget.set(adjustment); if (started.get()) { - getWorkStream.get().setBudget(items, bytes); + getWorkStream.setBudget(adjustment); } } @@ -168,6 +161,6 @@ long getAndResetThrottleTime() { } long getCurrentActiveCommitBytes() { - return started.get() ? workCommitter.get().currentActiveCommitBytes() : 0; + return workCommitter.currentActiveCommitBytes(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 58aecfc71e00..aa5a50db13b3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -21,8 +21,9 @@ import java.io.PrintWriter; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -30,17 +31,20 @@ import 
java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.api.client.util.Sleeper; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Status; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.StatusRuntimeException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Base class for persistent streams connecting to Windmill. @@ -49,11 +53,12 @@ * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on a * broken stream. * - *

Subclasses should override onResponse to handle responses from the server, and onNewStream to - * perform any work that must be done when a new stream is created, such as sending headers or - * retrying requests. + *

Subclasses should override {@link #onResponse(ResponseT)} to handle responses from the server, + * and {@link #onNewStream()} to perform any work that must be done when a new stream is created, + * such as sending headers or retrying requests. * - *

send and startStream should not be called from onResponse; use executor() instead. + *

{@link #send(RequestT)} and {@link #startStream()} should not be called from {@link + * #onResponse(ResponseT)}; use {@link #executeSafely(Runnable)} instead. * *

Synchronization on this is used to synchronize the gRpc stream state and internal data * structures. Since grpc channel operations may block, synchronization on this stream may also @@ -63,32 +68,42 @@ */ public abstract class AbstractWindmillStream implements WindmillStream { - public static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce // per-chunk overhead, and small enough that we can still perform granular flow-control. protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; - private static final Logger LOG = LoggerFactory.getLogger(AbstractWindmillStream.class); + private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); + protected final AtomicBoolean clientClosed; - private final AtomicBoolean isShutdown; + protected final Sleeper sleeper; private final AtomicLong lastSendTimeMs; - private final Executor executor; + private final ExecutorService executor; private final BackOff backoff; private final AtomicLong startTimeMs; private final AtomicLong lastResponseTimeMs; + private final AtomicInteger restartCount; private final AtomicInteger errorCount; - private final AtomicReference lastError; - private final AtomicReference lastErrorTime; + private final AtomicReference lastRestartReason; + private final AtomicReference lastRestartTime; private final AtomicLong sleepUntil; private final CountDownLatch finishLatch; private final Set> streamRegistry; private final int logEveryNStreamFailures; - private final Supplier> requestObserverSupplier; - // Indicates if the current stream in requestObserver is closed by calling close() method - private final AtomicBoolean streamClosed; private final String backendWorkerToken; - private @Nullable StreamObserver requestObserver; + private final ResettableRequestObserver requestObserver; + private final AtomicBoolean isShutdown; + private final AtomicBoolean started; + private final AtomicReference 
shutdownTime; + + /** + * Indicates if the current {@link ResettableRequestObserver} was closed by calling {@link + * #halfClose()}. + */ + private final AtomicBoolean streamClosed; + + private final Logger logger; protected AbstractWindmillStream( + Logger logger, String debugStreamType, Function, StreamObserver> clientFactory, BackOff backoff, @@ -107,20 +122,27 @@ protected AbstractWindmillStream( this.streamRegistry = streamRegistry; this.logEveryNStreamFailures = logEveryNStreamFailures; this.clientClosed = new AtomicBoolean(); - this.streamClosed = new AtomicBoolean(); + this.isShutdown = new AtomicBoolean(false); + this.started = new AtomicBoolean(false); + this.streamClosed = new AtomicBoolean(false); this.startTimeMs = new AtomicLong(); this.lastSendTimeMs = new AtomicLong(); this.lastResponseTimeMs = new AtomicLong(); + this.restartCount = new AtomicInteger(); this.errorCount = new AtomicInteger(); - this.lastError = new AtomicReference<>(); - this.lastErrorTime = new AtomicReference<>(); + this.lastRestartReason = new AtomicReference<>(); + this.lastRestartTime = new AtomicReference<>(); this.sleepUntil = new AtomicLong(); this.finishLatch = new CountDownLatch(1); - this.isShutdown = new AtomicBoolean(false); - this.requestObserverSupplier = - () -> - streamObserverFactory.from( - clientFactory, new AbstractWindmillStream.ResponseObserver()); + this.requestObserver = + new ResettableRequestObserver<>( + () -> + streamObserverFactory.from( + clientFactory, + new AbstractWindmillStream.ResponseObserver())); + this.sleeper = Sleeper.DEFAULT; + this.logger = logger; + this.shutdownTime = new AtomicReference<>(); } private static String createThreadName(String streamType, String backendWorkerToken) { @@ -130,10 +152,7 @@ private static String createThreadName(String streamType, String backendWorkerTo } private static long debugDuration(long nowMs, long startMs) { - if (startMs <= 0) { - return -1; - } - return Math.max(0, nowMs - startMs); + return startMs <= 
0 ? -1 : Math.max(0, nowMs - startMs); } /** Called on each response from the server. */ @@ -160,7 +179,7 @@ protected boolean isShutdown() { private StreamObserver requestObserver() { if (requestObserver == null) { throw new NullPointerException( - "requestObserver cannot be null. Missing a call to startStream() to initialize."); + "requestObserver cannot be null. Missing a call to start() to initialize stream."); } return requestObserver; @@ -168,28 +187,52 @@ private StreamObserver requestObserver() { /** Send a request to the server. */ protected final void send(RequestT request) { - lastSendTimeMs.set(Instant.now().getMillis()); synchronized (this) { + if (isShutdown()) { + return; + } + if (streamClosed.get()) { throw new IllegalStateException("Send called on a client closed stream."); } - requestObserver().onNext(request); + try { + lastSendTimeMs.set(Instant.now().getMillis()); + requestObserver.onNext(request); + } catch (StreamObserverCancelledException e) { + if (isShutdown()) { + logger.debug("Stream was closed or shutdown during send.", e); + return; + } + + requestObserver.onError(e); + } + } + } + + @Override + public final void start() { + if (!isShutdown.get() && started.compareAndSet(false, true)) { + // start() should only be executed once during the lifetime of the stream for idempotency and + // when shutdown() has not been called. + startStream(); } } /** Starts the underlying stream. */ - protected final void startStream() { + private void startStream() { // Add the stream to the registry after it has been fully constructed. streamRegistry.add(this); while (true) { try { synchronized (this) { + if (isShutdown.get()) { + break; + } startTimeMs.set(Instant.now().getMillis()); lastResponseTimeMs.set(0); streamClosed.set(false); - // lazily initialize the requestObserver. Gets reset whenever the stream is reopened. 
- requestObserver = requestObserverSupplier.get(); + requestObserver.reset(); onNewStream(); if (clientClosed.get()) { halfClose(); @@ -197,43 +240,62 @@ protected final void startStream() { return; } } catch (Exception e) { - LOG.error("Failed to create new stream, retrying: ", e); + logger.error("Failed to create new stream, retrying: ", e); try { long sleep = backoff.nextBackOffMillis(); sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException | IOException i) { + sleeper.sleep(sleep); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + logger.info( + "Interrupted during stream creation backoff. The stream will not be created."); + break; + } catch (IOException ioe) { // Keep trying to create the stream. } } } + + // We were never able to start the stream, remove it from the stream registry. Otherwise, it is + // removed when closed. + streamRegistry.remove(this); } - protected final Executor executor() { - return executor; + /** + * Execute the runnable using the {@link #executor} handling the executor being in a shutdown + * state. 
+ */ + protected final void executeSafely(Runnable runnable) { + try { + executor.execute(runnable); + } catch (RejectedExecutionException e) { + logger.debug("{}-{} has been shutdown.", getClass(), backendWorkerToken); + } } - public final synchronized void maybeSendHealthCheck(Instant lastSendThreshold) { - if (lastSendTimeMs.get() < lastSendThreshold.getMillis() && !clientClosed.get()) { + public final void maybeSendHealthCheck(Instant lastSendThreshold) { + if (!clientClosed.get() && lastSendTimeMs.get() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); } catch (RuntimeException e) { - LOG.debug("Received exception sending health check.", e); + logger.debug("Received exception sending health check.", e); } } } protected abstract void sendHealthCheck(); - // Care is taken that synchronization on this is unnecessary for all status page information. - // Blocking sends are made beneath this stream object's lock which could block status page - // rendering. + /** + * @implNote Care is taken that synchronization on this is unnecessary for all status page + * information. Blocking sends are made beneath this stream object's lock which could block + * status page rendering. 
+ */ public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); - if (errorCount.get() > 0) { + if (restartCount.get() > 0) { writer.format( - ", %d errors, last error [ %s ] at [%s]", - errorCount.get(), lastError.get(), lastErrorTime.get()); + ", %d restarts, last restart reason [ %s ] at [%s], %d errors", + restartCount.get(), lastRestartReason.get(), lastRestartTime.get(), errorCount.get()); } if (clientClosed.get()) { writer.write(", client closed"); @@ -244,21 +306,26 @@ public final void appendSummaryHtml(PrintWriter writer) { writer.format(", %dms backoff remaining", sleepLeft); } writer.format( - ", current stream is %dms old, last send %dms, last response %dms, closed: %s", + ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " + + "isShutdown: %s, shutdown time: %s", debugDuration(nowMs, startTimeMs.get()), debugDuration(nowMs, lastSendTimeMs.get()), debugDuration(nowMs, lastResponseTimeMs.get()), - streamClosed.get()); + streamClosed.get(), + isShutdown.get(), + shutdownTime.get()); } - // Don't require synchronization on stream, see the appendSummaryHtml comment. + /** + * @implNote Don't require synchronization on stream, see the {@link + * #appendSummaryHtml(PrintWriter)} comment. + */ protected abstract void appendSpecificHtml(PrintWriter writer); @Override public final synchronized void halfClose() { - // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. clientClosed.set(true); - requestObserver().onCompleted(); + requestObserver.onCompleted(); streamClosed.set(true); } @@ -278,24 +345,73 @@ public String backendWorkerToken() { } @Override - public void shutdown() { + public final void shutdown() { + // Don't lock here as isShutdown checks are used in the stream to free blocked + // threads or as exit conditions to loops. 
if (isShutdown.compareAndSet(false, true)) { requestObserver() .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + shutdownInternal(); + shutdownTime.set(DateTime.now()); } } - private void setLastError(String error) { - lastError.set(error); - lastErrorTime.set(DateTime.now()); + private void recordRestartReason(String error) { + lastRestartReason.set(error); + lastRestartTime.set(DateTime.now()); } + protected abstract void shutdownInternal(); + public static class WindmillStreamShutdownException extends RuntimeException { public WindmillStreamShutdownException(String message) { super(message); } } + /** + * Request observer that allows resetting its internal delegate using the given {@link + * #requestObserverSupplier}. + */ + @ThreadSafe + private static class ResettableRequestObserver implements StreamObserver { + + private final Supplier> requestObserverSupplier; + + @GuardedBy("this") + private @Nullable StreamObserver delegateRequestObserver; + + private ResettableRequestObserver(Supplier> requestObserverSupplier) { + this.requestObserverSupplier = requestObserverSupplier; + this.delegateRequestObserver = null; + } + + private synchronized StreamObserver delegate() { + return Preconditions.checkNotNull( + delegateRequestObserver, + "requestObserver cannot be null. 
Missing a call to startStream() to initialize."); + } + + private synchronized void reset() { + delegateRequestObserver = requestObserverSupplier.get(); + } + + @Override + public void onNext(RequestT requestT) { + delegate().onNext(requestT); + } + + @Override + public void onError(Throwable throwable) { + delegate().onError(throwable); + } + + @Override + public void onCompleted() { + delegate().onCompleted(); + } + } + private class ResponseObserver implements StreamObserver { @Override @@ -311,71 +427,91 @@ public void onNext(ResponseT response) { @Override public void onError(Throwable t) { - onStreamFinished(t); + if (maybeTeardownStream()) { + return; + } + + recordStreamStatus(Status.fromThrowable(t)); + + try { + long sleep = backoff.nextBackOffMillis(); + sleepUntil.set(Instant.now().getMillis() + sleep); + sleeper.sleep(sleep); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } catch (IOException e) { + // Ignore. + } + + executeSafely(AbstractWindmillStream.this::startStream); } @Override public void onCompleted() { - onStreamFinished(null); + if (maybeTeardownStream()) { + return; + } + recordStreamStatus(OK_STATUS); + executeSafely(AbstractWindmillStream.this::startStream); } - private void onStreamFinished(@Nullable Throwable t) { - synchronized (this) { - if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) { - streamRegistry.remove(AbstractWindmillStream.this); - finishLatch.countDown(); - return; - } - } - if (t != null) { - Status status = null; - if (t instanceof StatusRuntimeException) { - status = ((StatusRuntimeException) t).getStatus(); - } - String statusError = status == null ? 
"" : status.toString(); - setLastError(statusError); - if (errorCount.getAndIncrement() % logEveryNStreamFailures == 0) { + private void recordStreamStatus(Status status) { + int currentRestartCount = restartCount.incrementAndGet(); + if (status.isOk()) { + String restartReason = + "Stream completed successfully but did not complete requested operations, " + + "recreating"; + logger.warn(restartReason); + recordRestartReason(restartReason); + } else { + int currentErrorCount = errorCount.incrementAndGet(); + recordRestartReason(status.toString()); + Throwable t = status.getCause(); + if (t instanceof StreamObserverCancelledException) { + logger.error( + "StreamObserver was unexpectedly cancelled for stream={}, worker={}. stacktrace={}", + getClass(), + backendWorkerToken, + t.getStackTrace(), + t); + } else if (currentRestartCount % logEveryNStreamFailures == 0) { + // Don't log every restart since it will get noisy, and many errors transient. long nowMillis = Instant.now().getMillis(); - String responseDebug; - if (lastResponseTimeMs.get() == 0) { - responseDebug = "never received response"; - } else { - responseDebug = - "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; - } - LOG.debug( - "{} streaming Windmill RPC errors for {}, last was: {} with status {}." - + " created {}ms ago, {}. This is normal with autoscaling.", + String responseDebug = + lastResponseTimeMs.get() == 0 + ? "never received response" + : "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; + + logger.debug( + "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}" + + " with status: {}. created {}ms ago; {}. This is normal with autoscaling.", AbstractWindmillStream.this.getClass(), - errorCount.get(), + currentRestartCount, + currentErrorCount, t, - statusError, + status, nowMillis - startTimeMs.get(), responseDebug); } + // If the stream was stopped due to a resource exhausted error then we are throttled. 
- if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) { + if (status.getCode() == Status.Code.RESOURCE_EXHAUSTED) { startThrottleTimer(); } + } + } - try { - long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); - Thread.sleep(sleep); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } catch (IOException e) { - // Ignore. - } - } else { - errorCount.incrementAndGet(); - String error = - "Stream completed successfully but did not complete requested operations, " - + "recreating"; - LOG.warn(error); - setLastError(error); + /** Returns true if the stream was torn down and should not be restarted internally. */ + private synchronized boolean maybeTeardownStream() { + if (isShutdown() || (clientClosed.get() && !hasPendingRequests())) { + streamRegistry.remove(AbstractWindmillStream.this); + finishLatch.countDown(); + executor.shutdownNow(); + return true; } - executor.execute(AbstractWindmillStream.this::startStream); + + return false; } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index f26c56b14ec2..361531ce4f2d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -34,6 +34,12 @@ @ThreadSafe public interface WindmillStream { + /** + * Start the stream, opening a connection to the backend server. A call to start() is required for + * any further interactions on the stream. + */ + void start(); + /** An identifier for the backend worker where the stream is sending/receiving RPCs. 
*/ String backendWorkerToken(); @@ -47,8 +53,9 @@ public interface WindmillStream { Instant startTime(); /** - * Shutdown the stream. There should be no further interactions with the stream once this has been - * called. + * Shuts down the stream. No further interactions should be made with the stream, and the stream + * will no longer try to connect internally. Any pending retries or in-flight requests will be + * cancelled and all responses dropped and considered invalid. */ void shutdown(); @@ -86,7 +93,7 @@ interface CommitWorkStream extends WindmillStream { * Returns a builder that can be used for sending requests. Each builder is not thread-safe but * different builders for the same stream may be used simultaneously. */ - CommitWorkStream.RequestBatcher batcher(); + RequestBatcher batcher(); @NotThreadSafe interface RequestBatcher extends Closeable { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java index 98545a429461..b15f73645dee 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/AppendableInputStream.java @@ -134,6 +134,12 @@ public void close() throws IOException { stream.close(); } + static class InvalidInputStreamStateException extends IllegalStateException { + public InvalidInputStreamStateException() { + super("Got poison pill or timeout but stream is not done."); + } + } + @SuppressWarnings("NullableProblems") private class InputStreamEnumeration implements Enumeration { // The first stream is eagerly read on SequenceInputStream creation. 
For this reason @@ -159,7 +165,7 @@ public boolean hasMoreElements() { if (complete.get()) { return false; } - throw new IllegalStateException("Got poison pill or timeout but stream is not done."); + throw new InvalidInputStreamStateException(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new CancellationException(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 053843a8af25..7fcbb80631a8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -19,14 +19,18 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import com.google.auto.value.AutoValue; import java.io.PrintWriter; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; @@ -40,17 +44,19 @@ import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.EvictingQueue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public final class GrpcCommitWorkStream +final class GrpcCommitWorkStream extends AbstractWindmillStream implements CommitWorkStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStream.class); private static final long HEARTBEAT_REQUEST_ID = Long.MAX_VALUE; - private final Map pending; + private final ConcurrentMap pending; private final AtomicLong idGenerator; private final JobHeader jobHeader; private final ThrottleTimer commitWorkThrottleTimer; @@ -69,6 +75,7 @@ private GrpcCommitWorkStream( AtomicLong idGenerator, int streamingRpcBatchLimit) { super( + LOG, "CommitWorkStream", startCommitWorkRpcFn, backoff, @@ -83,7 +90,7 @@ private GrpcCommitWorkStream( this.streamingRpcBatchLimit = streamingRpcBatchLimit; } - public static GrpcCommitWorkStream create( + static GrpcCommitWorkStream create( String backendWorkerToken, Function, StreamObserver> startCommitWorkRpcFn, @@ -95,20 +102,17 @@ public static GrpcCommitWorkStream create( JobHeader jobHeader, AtomicLong idGenerator, int streamingRpcBatchLimit) { - GrpcCommitWorkStream commitWorkStream = - new GrpcCommitWorkStream( - backendWorkerToken, - startCommitWorkRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - commitWorkThrottleTimer, - jobHeader, - idGenerator, - streamingRpcBatchLimit); - commitWorkStream.startStream(); - return commitWorkStream; + return new GrpcCommitWorkStream( + backendWorkerToken, + startCommitWorkRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit); } @Override @@ -156,29 +160,44 @@ public void sendHealthCheck() { 
protected void onResponse(StreamingCommitResponse response) { commitWorkThrottleTimer.stop(); - RuntimeException finalException = null; + CommitCompletionException failures = new CommitCompletionException(); for (int i = 0; i < response.getRequestIdCount(); ++i) { long requestId = response.getRequestId(i); if (requestId == HEARTBEAT_REQUEST_ID) { continue; } - PendingRequest done = pending.remove(requestId); - if (done == null) { - LOG.error("Got unknown commit request ID: {}", requestId); + PendingRequest pendingRequest = pending.remove(requestId); + CommitStatus commitStatus = + i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK; + if (pendingRequest == null) { + if (!isShutdown()) { + // Skip responses when the stream is shutdown since they are now invalid. + LOG.error("Got unknown commit request ID: {}", requestId); + } } else { try { - done.onDone.accept( - (i < response.getStatusCount()) ? response.getStatus(i) : CommitStatus.OK); + pendingRequest.completeWithStatus(commitStatus); } catch (RuntimeException e) { // Catch possible exceptions to ensure that an exception for one commit does not prevent - // other commits from being processed. + // other commits from being processed. Aggregate all the failures to throw after + // processing the response if they exist. 
LOG.warn("Exception while processing commit response.", e); - finalException = e; + failures.recordError(commitStatus, e); } } } - if (finalException != null) { - throw finalException; + if (failures.hasErrors()) { + throw failures; + } + } + + @Override + protected void shutdownInternal() { + Iterator pendingRequests = pending.values().iterator(); + while (pendingRequests.hasNext()) { + PendingRequest pendingRequest = pendingRequests.next(); + pendingRequest.completeWithStatus(CommitStatus.ABORTED); + pendingRequests.remove(); } } @@ -187,13 +206,14 @@ protected void startThrottleTimer() { commitWorkThrottleTimer.start(); } - private void flushInternal(Map requests) { + private void flushInternal(Map requests) throws InterruptedException { if (requests.isEmpty()) { return; } + if (requests.size() == 1) { Map.Entry elem = requests.entrySet().iterator().next(); - if (elem.getValue().request.getSerializedSize() + if (elem.getValue().request().getSerializedSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { issueMultiChunkRequest(elem.getKey(), elem.getValue()); } else { @@ -204,14 +224,14 @@ private void flushInternal(Map requests) { } } - private void issueSingleRequest(final long id, PendingRequest pendingRequest) { + private void issueSingleRequest(long id, PendingRequest pendingRequest) { StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); requestBuilder .addCommitChunkBuilder() - .setComputationId(pendingRequest.computation) + .setComputationId(pendingRequest.computationId()) .setRequestId(id) - .setShardingKey(pendingRequest.request.getShardingKey()) - .setSerializedWorkItemCommit(pendingRequest.request.toByteString()); + .setShardingKey(pendingRequest.shardingKey()) + .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { pending.put(id, pendingRequest); @@ -229,13 +249,14 @@ private void issueBatchedRequest(Map requests) { 
for (Map.Entry entry : requests.entrySet()) { PendingRequest request = entry.getValue(); StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.computation)) { - chunkBuilder.setComputationId(request.computation); - lastComputation = request.computation; + if (lastComputation == null || !lastComputation.equals(request.computationId())) { + chunkBuilder.setComputationId(request.computationId()); + lastComputation = request.computationId(); } - chunkBuilder.setRequestId(entry.getKey()); - chunkBuilder.setShardingKey(request.request.getShardingKey()); - chunkBuilder.setSerializedWorkItemCommit(request.request.toByteString()); + chunkBuilder + .setRequestId(entry.getKey()) + .setShardingKey(request.shardingKey()) + .setSerializedWorkItemCommit(request.serializedCommit()); } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { @@ -249,9 +270,8 @@ private void issueBatchedRequest(Map requests) { } private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { - checkNotNull(pendingRequest.computation); - final ByteString serializedCommit = pendingRequest.request.toByteString(); - + checkNotNull(pendingRequest.computationId()); + final ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { pending.put(id, pendingRequest); for (int i = 0; @@ -264,8 +284,8 @@ private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest StreamingCommitRequestChunk.newBuilder() .setRequestId(id) .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computation) - .setShardingKey(pendingRequest.request.getShardingKey()); + .setComputationId(pendingRequest.computationId()) + .setShardingKey(pendingRequest.shardingKey()); int remaining = serializedCommit.size() - end; if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); @@ -283,21 +303,34 @@ private void 
issueMultiChunkRequest(final long id, PendingRequest pendingRequest } } - private static class PendingRequest { + @AutoValue + abstract static class PendingRequest { + + private static PendingRequest create( + String computationId, WorkItemCommitRequest request, Consumer onDone) { + return new AutoValue_GrpcCommitWorkStream_PendingRequest(computationId, request, onDone); + } - private final String computation; - private final WorkItemCommitRequest request; - private final Consumer onDone; + abstract String computationId(); + + abstract WorkItemCommitRequest request(); + + abstract Consumer onDone(); + + private long getBytes() { + return (long) request().getSerializedSize() + computationId().length(); + } - PendingRequest( - String computation, WorkItemCommitRequest request, Consumer onDone) { - this.computation = computation; - this.request = request; - this.onDone = onDone; + private ByteString serializedCommit() { + return request().toByteString(); } - long getBytes() { - return (long) request.getSerializedSize() + computation.length(); + private void completeWithStatus(CommitStatus commitStatus) { + onDone().accept(commitStatus); + } + + private long shardingKey() { + return request().getShardingKey(); } } @@ -317,7 +350,8 @@ public boolean commitWorkItem( if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { return false; } - PendingRequest request = new PendingRequest(computation, commitRequest, onDone); + + PendingRequest request = PendingRequest.create(computation, commitRequest, onDone); add(idGenerator.incrementAndGet(), request); return true; } @@ -325,21 +359,62 @@ public boolean commitWorkItem( /** Flushes any pending work items to the wire. 
*/ @Override public void flush() { - flushInternal(queue); - queuedBytes = 0; - queue.clear(); + try { + if (!isShutdown()) { + flushInternal(queue); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + queuedBytes = 0; + queue.clear(); + } } void add(long id, PendingRequest request) { - assert (canAccept(request.getBytes())); + Preconditions.checkState(canAccept(request.getBytes())); queuedBytes += request.getBytes(); queue.put(id, request); } private boolean canAccept(long requestBytes) { - return queue.isEmpty() - || (queue.size() < streamingRpcBatchLimit - && (requestBytes + queuedBytes) < AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE); + return !isShutdown() + && (queue.isEmpty() + || (queue.size() < streamingRpcBatchLimit + && (requestBytes + queuedBytes) < AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE)); + } + } + + private static class CommitCompletionException extends RuntimeException { + private static final int MAX_PRINTABLE_ERRORS = 10; + private final Map>, Integer> errorCounter; + private final EvictingQueue detailedErrors; + + private CommitCompletionException() { + super("Exception while processing commit response."); + this.errorCounter = new HashMap<>(); + this.detailedErrors = EvictingQueue.create(MAX_PRINTABLE_ERRORS); + } + + private void recordError(CommitStatus commitStatus, Throwable error) { + errorCounter.compute( + Pair.of(commitStatus, error.getClass()), + (ignored, current) -> current == null ? 
1 : current + 1); + detailedErrors.add(error); + } + + private boolean hasErrors() { + return !errorCounter.isEmpty(); + } + + @Override + public final String getMessage() { + return "CommitCompletionException{" + + "errorCounter=" + + errorCounter + + ", detailedErrors=" + + detailedErrors + + '}'; } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index b27ebc8e9eee..8fbcb1479d50 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -21,7 +21,6 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import javax.annotation.concurrent.GuardedBy; @@ -105,6 +104,7 @@ private GrpcDirectGetWorkStream( WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { super( + LOG, "GetWorkStream", startGetWorkRpcFn, backoff, @@ -144,22 +144,19 @@ static GrpcDirectGetWorkStream create( GetDataClient getDataClient, WorkCommitter workCommitter, WorkItemScheduler workItemScheduler) { - GrpcDirectGetWorkStream getWorkStream = - new GrpcDirectGetWorkStream( - backendWorkerToken, - startGetWorkRpcFn, - request, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getWorkThrottleTimer, - heartbeatSender, - getDataClient, - workCommitter, - workItemScheduler); - getWorkStream.startStream(); - return getWorkStream; + return new 
GrpcDirectGetWorkStream( + backendWorkerToken, + startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + heartbeatSender, + getDataClient, + workCommitter, + workItemScheduler); } private static Watermarks createWatermarks( @@ -207,8 +204,7 @@ protected synchronized void onNewStream() { StreamingGetWorkRequest request = StreamingGetWorkRequest.newBuilder() .setRequest( - requestHeader - .toBuilder() + requestHeader.toBuilder() .setMaxItems(initialGetWorkBudget.items()) .setMaxBytes(initialGetWorkBudget.bytes()) .build()) @@ -238,6 +234,11 @@ public void sendHealthCheck() { send(HEALTH_CHECK_REQUEST); } + @Override + protected void shutdownInternal() { + workItemAssemblers.clear(); + } + @Override protected void onResponse(StreamingGetWorkResponseChunk chunk) { getWorkThrottleTimer.stop(); @@ -277,14 +278,6 @@ public void setBudget(GetWorkBudget newBudget) { maybeSendRequestExtension(extension); } - private void executeSafely(Runnable runnable) { - try { - executor().execute(runnable); - } catch (RejectedExecutionException e) { - LOG.debug("{} has been shutdown.", getClass()); - } - } - /** * Tracks sent, received, max {@link GetWorkBudget} and uses this information to generate request * extensions. 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index c99e05a77074..d2fe8395b109 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -31,10 +31,11 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest; @@ -55,6 +56,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.VerifyException; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,6 +65,8 @@ final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStream.class); + private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = + 
StreamingGetDataRequest.newBuilder().build(); private final Deque batches; private final Map pending; @@ -90,6 +94,7 @@ private GrpcGetDataStream( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { super( + LOG, "GetDataStream", startGetDataRpcFn, backoff, @@ -107,7 +112,7 @@ private GrpcGetDataStream( this.processHeartbeatResponses = processHeartbeatResponses; } - public static GrpcGetDataStream create( + static GrpcGetDataStream create( String backendWorkerToken, Function, StreamObserver> startGetDataRpcFn, @@ -121,28 +126,29 @@ public static GrpcGetDataStream create( int streamingRpcBatchLimit, boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { - GrpcGetDataStream getDataStream = - new GrpcGetDataStream( - backendWorkerToken, - startGetDataRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getDataThrottleTimer, - jobHeader, - idGenerator, - streamingRpcBatchLimit, - sendKeyedGetDataRequests, - processHeartbeatResponses); - getDataStream.startStream(); - return getDataStream; + return new GrpcGetDataStream( + backendWorkerToken, + startGetDataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + jobHeader, + idGenerator, + streamingRpcBatchLimit, + sendKeyedGetDataRequests, + processHeartbeatResponses); } @Override protected synchronized void onNewStream() { + if (isShutdown()) { + return; + } + send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed.get()) { + if (clientClosed.get() && !isShutdown()) { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. 
@@ -188,14 +194,23 @@ private long uniqueId() { @Override public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) { - return issueRequest( - QueuedRequest.forComputation(uniqueId(), computation, request), - KeyedGetDataResponse::parseFrom); + try { + return issueRequest( + QueuedRequest.forComputation(uniqueId(), computation, request), + KeyedGetDataResponse::parseFrom); + } catch (WindmillStreamShutdownException e) { + throw new WorkItemCancelledException(request.getShardingKey()); + } } @Override public GlobalData requestGlobalData(GlobalDataRequest request) { - return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); + try { + return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); + } catch (WindmillStreamShutdownException e) { + throw new WorkItemCancelledException( + "SideInput fetch failed for request due to stream shutdown: " + request, e); + } } @Override @@ -273,10 +288,24 @@ public void onHeartbeatResponse(List resp @Override public void sendHealthCheck() { if (hasPendingRequests()) { - send(StreamingGetDataRequest.newBuilder().build()); + send(HEALTH_CHECK_REQUEST); } } + @Override + protected void shutdownInternal() { + // Stream has been explicitly closed. Drain pending input streams and request batches. + // Future calls to send RPCs will fail. 
+ pending.values().forEach(AppendableInputStream::cancel); + pending.clear(); + batches.forEach( + batch -> { + batch.markFinalized(); + batch.notifyFailed(); + }); + batches.clear(); + } + @Override public void appendSpecificHtml(PrintWriter writer) { writer.format( @@ -302,30 +331,49 @@ public void appendSpecificHtml(PrintWriter writer) { } private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) { - while (true) { + while (!isShutdown()) { request.resetResponseStream(); try { queueRequestAndWait(request); return parseFn.parse(request.getResponseStream()); - } catch (CancellationException e) { - // Retry issuing the request since the response stream was cancelled. - continue; + } catch (AppendableInputStream.InvalidInputStreamStateException + | VerifyException + | CancellationException e) { + handleShutdown(request, e); + if (!(e instanceof CancellationException)) { + throw e; + } } catch (IOException e) { LOG.error("Parsing GetData response failed: ", e); - continue; } catch (InterruptedException e) { Thread.currentThread().interrupt(); + handleShutdown(request, e); throw new RuntimeException(e); } finally { pending.remove(request.id()); } } + + // If we have exited the loop here, the stream has been shutdown. Cancel the response stream. 
+ request.getResponseStream().cancel(); + throw new WindmillStreamShutdownException( + "Cannot send request=[" + request + "] on closed stream."); + } + + private void handleShutdown(QueuedRequest request, Throwable cause) { + if (isShutdown()) { + WindmillStreamShutdownException shutdownException = + new WindmillStreamShutdownException( + "Cannot send request=[" + request + "] on closed stream."); + shutdownException.addSuppressed(cause); + throw shutdownException; + } } private void queueRequestAndWait(QueuedRequest request) throws InterruptedException { QueuedBatch batch; boolean responsibleForSend = false; - CountDownLatch waitForSendLatch = null; + @Nullable QueuedBatch prevBatch = null; synchronized (batches) { batch = batches.isEmpty() ? null : batches.getLast(); if (batch == null @@ -333,7 +381,7 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept || batch.requests().size() >= streamingRpcBatchLimit || batch.byteSize() + request.byteSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { if (batch != null) { - waitForSendLatch = batch.getLatch(); + prevBatch = batch; } batch = new QueuedBatch(); batches.addLast(batch); @@ -342,12 +390,12 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept batch.addRequest(request); } if (responsibleForSend) { - if (waitForSendLatch == null) { + if (prevBatch == null) { // If there was not a previous batch wait a little while to improve // batching. - Thread.sleep(1); + sleeper.sleep(1); } else { - waitForSendLatch.await(); + prevBatch.waitForSendOrFailNotification(); } // Finalize the batch so that no additional requests will be added. Leave the batch in the // queue so that a subsequent batch will wait for its completion. @@ -361,10 +409,10 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept } // Notify all waiters with requests in this batch as well as the sender // of the next batch (if one exists). 
- batch.countDown(); + batch.notifySent(); } else { // Wait for this batch to be sent before parsing the response. - batch.await(); + batch.waitForSendOrFailNotification(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java index cda9537127d9..53dd4a5b5294 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java @@ -17,18 +17,27 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; + import com.google.auto.value.AutoOneOf; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.stream.Stream; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** Utility data classes for {@link GrpcGetDataStream}. 
*/ final class GrpcGetDataStreamRequests { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetDataStreamRequests.class); + private static final int STREAM_CANCELLED_ERROR_LOG_LIMIT = 3; + private GrpcGetDataStreamRequests() {} static class QueuedRequest { @@ -81,6 +90,10 @@ void resetResponseStream() { this.responseStream = new AppendableInputStream(); } + public ComputationOrGlobalDataRequest getDataRequest() { + return dataRequest; + } + void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder builder) { builder.addRequestId(id); if (dataRequest.isForComputation()) { @@ -95,11 +108,8 @@ static class QueuedBatch { private final List requests = new ArrayList<>(); private final CountDownLatch sent = new CountDownLatch(1); private long byteSize = 0; - private boolean finalized = false; - - CountDownLatch getLatch() { - return sent; - } + private volatile boolean finalized = false; + private volatile boolean failed = false; List requests() { return requests; @@ -122,12 +132,62 @@ void addRequest(QueuedRequest request) { byteSize += request.byteSize(); } - void countDown() { + /** Let waiting for threads know that the request has been successfully sent. */ + synchronized void notifySent() { sent.countDown(); } - void await() throws InterruptedException { + /** Let waiting for threads know that a failure occurred. */ + synchronized void notifyFailed() { + failed = true; + sent.countDown(); + } + + /** + * Block until notified of a successful send via {@link #notifySent()} or a non-retryable + * failure via {@link #notifyFailed()}. On failure, throw an exception to on calling threads. 
+ */ + void waitForSendOrFailNotification() throws InterruptedException { sent.await(); + if (failed) { + ImmutableList cancelledRequests = createStreamCancelledErrorMessage(); + LOG.error("Requests failed for the following batches: {}", cancelledRequests); + throw new WindmillStreamShutdownException( + "Requests failed for batch containing " + + String.join(", ", cancelledRequests) + + " ... requests. This is most likely due to the stream being explicitly closed" + + " which happens when the work is marked as invalid on the streaming" + + " backend when key ranges shuffle around. This is transient and corresponding" + + " work will eventually be retried."); + } + } + + ImmutableList createStreamCancelledErrorMessage() { + return requests.stream() + .flatMap( + request -> { + switch (request.getDataRequest().getKind()) { + case GLOBAL: + return Stream.of("GetSideInput=" + request.getDataRequest().global()); + case COMPUTATION: + return request.getDataRequest().computation().getRequestsList().stream() + .map( + keyedRequest -> + "KeyedGetState=[" + + "shardingKey=" + + keyedRequest.getShardingKey() + + "cacheToken=" + + keyedRequest.getCacheToken() + + "workToken" + + keyedRequest.getWorkToken() + + "]"); + default: + // Will never happen switch is exhaustive. 
+ throw new IllegalStateException(); + } + }) + .limit(STREAM_CANCELLED_ERROR_LOG_LIMIT) + .collect(toImmutableList()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index a368f3fec235..fc38343afb98 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -36,11 +36,15 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; final class GrpcGetWorkStream extends AbstractWindmillStream implements GetWorkStream { + private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkStream.class); + private final GetWorkRequest request; private final WorkItemReceiver receiver; private final ThrottleTimer getWorkThrottleTimer; @@ -62,6 +66,7 @@ private GrpcGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { super( + LOG, "GetWorkStream", startGetWorkRpcFn, backoff, @@ -90,19 +95,16 @@ public static GrpcGetWorkStream create( int logEveryNStreamFailures, ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { - GrpcGetWorkStream getWorkStream = - new GrpcGetWorkStream( - backendWorkerToken, - startGetWorkRpcFn, - request, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - getWorkThrottleTimer, - receiver); - getWorkStream.startStream(); - return getWorkStream; + return new GrpcGetWorkStream( + backendWorkerToken, + 
startGetWorkRpcFn, + request, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + getWorkThrottleTimer, + receiver); } private void sendRequestExtension(long moreItems, long moreBytes) { @@ -114,15 +116,14 @@ private void sendRequestExtension(long moreItems, long moreBytes) { .setMaxBytes(moreBytes)) .build(); - executor() - .execute( - () -> { - try { - send(extension); - } catch (IllegalStateException e) { - // Stream was closed. - } - }); + executeSafely( + () -> { + try { + send(extension); + } catch (IllegalStateException e) { + // Stream was closed. + } + }); } @Override @@ -133,6 +134,9 @@ protected synchronized void onNewStream() { send(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); } + @Override + protected void shutdownInternal() {} + @Override protected boolean hasPendingRequests() { return false; @@ -194,7 +198,5 @@ protected void startThrottleTimer() { } @Override - public void setBudget(GetWorkBudget newBudget) { - // no-op - } + public void setBudget(GetWorkBudget newBudget) {} } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index 44e21a9b18ed..7fd1f011a4bb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -47,9 +47,6 @@ public final class GrpcGetWorkerMetadataStream private final Consumer serverMappingConsumer; private final Object metadataLock; - @GuardedBy("metadataLock") - private long metadataVersion; - @GuardedBy("metadataLock") private WorkerMetadataResponse 
latestResponse; @@ -61,10 +58,10 @@ private GrpcGetWorkerMetadataStream( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - long metadataVersion, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingConsumer) { super( + LOG, "GetWorkerMetadataStream", startGetWorkerMetadataRpcFn, backoff, @@ -73,7 +70,6 @@ private GrpcGetWorkerMetadataStream( logEveryNStreamFailures, ""); this.workerMetadataRequest = WorkerMetadataRequest.newBuilder().setHeader(jobHeader).build(); - this.metadataVersion = metadataVersion; this.getWorkerMetadataThrottleTimer = getWorkerMetadataThrottleTimer; this.serverMappingConsumer = serverMappingConsumer; this.latestResponse = WorkerMetadataResponse.getDefaultInstance(); @@ -88,23 +84,17 @@ public static GrpcGetWorkerMetadataStream create( Set> streamRegistry, int logEveryNStreamFailures, JobHeader jobHeader, - int metadataVersion, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingUpdater) { - GrpcGetWorkerMetadataStream getWorkerMetadataStream = - new GrpcGetWorkerMetadataStream( - startGetWorkerMetadataRpcFn, - backoff, - streamObserverFactory, - streamRegistry, - logEveryNStreamFailures, - jobHeader, - metadataVersion, - getWorkerMetadataThrottleTimer, - serverMappingUpdater); - LOG.info("Started GetWorkerMetadataStream. {}", getWorkerMetadataStream); - getWorkerMetadataStream.startStream(); - return getWorkerMetadataStream; + return new GrpcGetWorkerMetadataStream( + startGetWorkerMetadataRpcFn, + backoff, + streamObserverFactory, + streamRegistry, + logEveryNStreamFailures, + jobHeader, + getWorkerMetadataThrottleTimer, + serverMappingUpdater); } /** @@ -118,25 +108,23 @@ protected void onResponse(WorkerMetadataResponse response) { /** * Acquires the {@link #metadataLock} Returns {@link Optional} if the - * metadataVersion in the response is not stale (older or equal to {@link #metadataVersion}), else - * returns empty {@link Optional}. 
+ * metadataVersion in the response is not stale (older or equal to current {@link + * WorkerMetadataResponse#getMetadataVersion()}), else returns empty {@link Optional}. */ private Optional extractWindmillEndpointsFrom( WorkerMetadataResponse response) { synchronized (metadataLock) { - if (response.getMetadataVersion() > this.metadataVersion) { - this.metadataVersion = response.getMetadataVersion(); + if (response.getMetadataVersion() > latestResponse.getMetadataVersion()) { this.latestResponse = response; return Optional.of(WindmillEndpoints.from(response)); } else { // If the currentMetadataVersion is greater than or equal to one in the response, the // response data is stale, and we do not want to do anything. - LOG.info( - "Received WorkerMetadataResponse={}; Received metadata version={}; Current metadata version={}. " + LOG.debug( + "Received metadata version={}; Current metadata version={}. " + "Skipping update because received stale metadata", - response, response.getMetadataVersion(), - this.metadataVersion); + latestResponse.getMetadataVersion()); } } @@ -144,10 +132,13 @@ private Optional extractWindmillEndpointsFrom( } @Override - protected synchronized void onNewStream() { + protected void onNewStream() { send(workerMetadataRequest); } + @Override + protected void shutdownInternal() {} + @Override protected boolean hasPendingRequests() { return false; @@ -167,8 +158,8 @@ protected void sendHealthCheck() { protected void appendSpecificHtml(PrintWriter writer) { synchronized (metadataLock) { writer.format( - "GetWorkerMetadataStream: version=[%d] , job_header=[%s], latest_response=[%s]", - this.metadataVersion, workerMetadataRequest.getHeader(), this.latestResponse); + "GetWorkerMetadataStream: job_header=[%s], current_metadata=[%s]", + workerMetadataRequest.getHeader(), latestResponse); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java index 310495982679..1a105185b8a9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java @@ -332,27 +332,36 @@ public CommitWorkResponse commitWork(CommitWorkRequest request) { @Override public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { - return windmillStreamFactory.createGetWorkStream( - dispatcherClient.getWindmillServiceStub(), - GetWorkRequest.newBuilder(request) - .setJobId(options.getJobId()) - .setProjectId(options.getProject()) - .setWorkerId(options.getWorkerId()) - .build(), - throttleTimers.getWorkThrottleTimer(), - receiver); + GetWorkStream getWorkStream = + windmillStreamFactory.createGetWorkStream( + dispatcherClient.getWindmillServiceStub(), + GetWorkRequest.newBuilder(request) + .setJobId(options.getJobId()) + .setProjectId(options.getProject()) + .setWorkerId(options.getWorkerId()) + .build(), + throttleTimers.getWorkThrottleTimer(), + receiver); + getWorkStream.start(); + return getWorkStream; } @Override public GetDataStream getDataStream() { - return windmillStreamFactory.createGetDataStream( - dispatcherClient.getWindmillServiceStub(), throttleTimers.getDataThrottleTimer()); + GetDataStream getDataStream = + windmillStreamFactory.createGetDataStream( + dispatcherClient.getWindmillServiceStub(), throttleTimers.getDataThrottleTimer()); + getDataStream.start(); + return getDataStream; } @Override public CommitWorkStream commitWorkStream() { - return windmillStreamFactory.createCommitWorkStream( - dispatcherClient.getWindmillServiceStub(), throttleTimers.commitWorkThrottleTimer()); + CommitWorkStream commitWorkStream = + 
windmillStreamFactory.createCommitWorkStream( + dispatcherClient.getWindmillServiceStub(), throttleTimers.commitWorkThrottleTimer()); + commitWorkStream.start(); + return commitWorkStream; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 9e6a02d135e2..02c238fa3825 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -17,10 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; import com.google.auto.value.AutoBuilder; import java.io.PrintWriter; +import java.util.Collection; import java.util.List; import java.util.Set; import java.util.Timer; @@ -29,6 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; @@ -66,6 +68,8 @@ @ThreadSafe @Internal public class GrpcWindmillStreamFactory implements StatusDataProvider { + + private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300; private static final Duration MIN_BACKOFF = Duration.millis(1); private static final Duration 
DEFAULT_MAX_BACKOFF = Duration.standardSeconds(30); private static final int DEFAULT_LOG_EVERY_N_STREAM_FAILURES = 1; @@ -73,6 +77,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider { private static final int DEFAULT_WINDMILL_MESSAGES_BETWEEN_IS_READY_CHECKS = 1; private static final int NO_HEALTH_CHECKS = -1; private static final String NO_BACKEND_WORKER_TOKEN = ""; + private static final String DISPATCHER_DEBUG_NAME = "Dispatcher"; private final JobHeader jobHeader; private final int logEveryNStreamFailures; @@ -93,7 +98,8 @@ private GrpcWindmillStreamFactory( int windmillMessagesBetweenIsReadyChecks, boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses, - Supplier maxBackOffSupplier) { + Supplier maxBackOffSupplier, + Set> streamRegistry) { this.jobHeader = jobHeader; this.logEveryNStreamFailures = logEveryNStreamFailures; this.streamingRpcBatchLimit = streamingRpcBatchLimit; @@ -106,7 +112,7 @@ private GrpcWindmillStreamFactory( .withInitialBackoff(MIN_BACKOFF) .withMaxBackoff(maxBackOffSupplier.get()) .backoff()); - this.streamRegistry = ConcurrentHashMap.newKeySet(); + this.streamRegistry = streamRegistry; this.sendKeyedGetDataRequests = sendKeyedGetDataRequests; this.processHeartbeatResponses = processHeartbeatResponses; this.streamIdGenerator = new AtomicLong(); @@ -121,7 +127,8 @@ static GrpcWindmillStreamFactory create( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses, Supplier maxBackOffSupplier, - int healthCheckIntervalMillis) { + int healthCheckIntervalMillis, + Set> streamRegistry) { GrpcWindmillStreamFactory streamFactory = new GrpcWindmillStreamFactory( jobHeader, @@ -130,7 +137,8 @@ static GrpcWindmillStreamFactory create( windmillMessagesBetweenIsReadyChecks, sendKeyedGetDataRequests, processHeartbeatResponses, - maxBackOffSupplier); + maxBackOffSupplier, + streamRegistry); if (healthCheckIntervalMillis >= 0) { // Health checks are run on background daemon thread, which will only be 
cleaned up on JVM @@ -167,14 +175,27 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { .setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT) .setHealthCheckIntervalMillis(NO_HEALTH_CHECKS) .setSendKeyedGetDataRequests(true) - .setProcessHeartbeatResponses(ignored -> {}); + .setProcessHeartbeatResponses(ignored -> {}) + .setStreamRegistry(ConcurrentHashMap.newKeySet()); } private static > T withDefaultDeadline(T stub) { // Deadlines are absolute points in time, so generate a new one everytime this function is // called. - return stub.withDeadlineAfter( - AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); + return stub.withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS); + } + + private static void printSummaryHtmlForWorker( + String workerToken, Collection> streams, PrintWriter writer) { + writer.write( + "" + (workerToken.isEmpty() ? DISPATCHER_DEBUG_NAME : workerToken) + ""); + writer.write("
"); + streams.forEach( + stream -> { + stream.appendSummaryHtml(writer); + writer.write("
"); + }); + writer.write("
"); } public GetWorkStream createGetWorkStream( @@ -204,7 +225,7 @@ public GetWorkStream createDirectGetWorkStream( WorkItemScheduler workItemScheduler) { return GrpcDirectGetWorkStream.create( connection.backendWorkerToken(), - responseObserver -> withDefaultDeadline(connection.stub()).getWorkStream(responseObserver), + responseObserver -> connection.stub().getWorkStream(responseObserver), request, grpcBackOff.get(), newStreamObserverFactory(), @@ -234,6 +255,23 @@ public GetDataStream createGetDataStream( processHeartbeatResponses); } + public GetDataStream createDirectGetDataStream( + WindmillConnection connection, ThrottleTimer getDataThrottleTimer) { + return GrpcGetDataStream.create( + connection.backendWorkerToken(), + responseObserver -> connection.stub().getDataStream(responseObserver), + grpcBackOff.get(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + getDataThrottleTimer, + jobHeader, + streamIdGenerator, + streamingRpcBatchLimit, + sendKeyedGetDataRequests, + processHeartbeatResponses); + } + public CommitWorkStream createCommitWorkStream( CloudWindmillServiceV1Alpha1Stub stub, ThrottleTimer commitWorkThrottleTimer) { return GrpcCommitWorkStream.create( @@ -249,18 +287,32 @@ public CommitWorkStream createCommitWorkStream( streamingRpcBatchLimit); } + public CommitWorkStream createDirectCommitWorkStream( + WindmillConnection connection, ThrottleTimer commitWorkThrottleTimer) { + return GrpcCommitWorkStream.create( + connection.backendWorkerToken(), + responseObserver -> connection.stub().commitWorkStream(responseObserver), + grpcBackOff.get(), + newStreamObserverFactory(), + streamRegistry, + logEveryNStreamFailures, + commitWorkThrottleTimer, + jobHeader, + streamIdGenerator, + streamingRpcBatchLimit); + } + public GetWorkerMetadataStream createGetWorkerMetadataStream( - CloudWindmillMetadataServiceV1Alpha1Stub stub, + Supplier stub, ThrottleTimer getWorkerMetadataThrottleTimer, Consumer onNewWindmillEndpoints) { return 
GrpcGetWorkerMetadataStream.create( - responseObserver -> withDefaultDeadline(stub).getWorkerMetadata(responseObserver), + responseObserver -> withDefaultDeadline(stub.get()).getWorkerMetadata(responseObserver), grpcBackOff.get(), newStreamObserverFactory(), streamRegistry, logEveryNStreamFailures, jobHeader, - 0, getWorkerMetadataThrottleTimer, onNewWindmillEndpoints); } @@ -273,10 +325,12 @@ private StreamObserverFactory newStreamObserverFactory() { @Override public void appendSummaryHtml(PrintWriter writer) { writer.write("Active Streams:
"); - for (AbstractWindmillStream stream : streamRegistry) { - stream.appendSummaryHtml(writer); - writer.write("
"); - } + streamRegistry.stream() + .collect( + toImmutableListMultimap( + AbstractWindmillStream::backendWorkerToken, Function.identity())) + .asMap() + .forEach((workerToken, streams) -> printSummaryHtmlForWorker(workerToken, streams, writer)); } @Internal @@ -299,6 +353,8 @@ Builder setProcessHeartbeatResponses( Builder setHealthCheckIntervalMillis(int healthCheckIntervalMillis); + Builder setStreamRegistry(Set> streamRegistry); + GrpcWindmillStreamFactory build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java new file mode 100644 index 000000000000..0c146e298be0 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; +import org.apache.beam.sdk.annotations.Internal; + +/** + * Indicates that a {@link WindmillStream#shutdown()} was called while waiting for some internal + * operation to complete. Most common use of this exception should be conversion to a {@link + * org.apache.beam.runners.dataflow.worker.WorkItemCancelledException} as the {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} being processed by {@link + * WindmillStream}. + */ +@Internal +final class WindmillStreamShutdownException extends RuntimeException { + WindmillStreamShutdownException(String message) { + super(message); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 9d57df1af317..4d798e8d18ea 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -39,7 +39,9 @@ @ThreadSafe public final class DirectStreamObserver implements StreamObserver { private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class); - private final Phaser phaser; + private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30; + + private final Phaser isReadyNotifier; private final Object lock = new Object(); @@ -53,11 +55,11 @@ public final class DirectStreamObserver implements StreamObserver { private int messagesSinceReady = 0; public DirectStreamObserver( - Phaser 
phaser, + Phaser isReadyNotifier, CallStreamObserver outboundObserver, long deadlineSeconds, int messagesBetweenIsReadyChecks) { - this.phaser = phaser; + this.isReadyNotifier = isReadyNotifier; this.outboundObserver = outboundObserver; this.deadlineSeconds = deadlineSeconds; // We always let the first message pass through without blocking because it is performed under @@ -74,6 +76,16 @@ public void onNext(T value) { while (true) { try { synchronized (lock) { + // If we awaited previously and timed out, wait for the same phase. Otherwise we're + // careful to observe the phase before observing isReady. + if (awaitPhase < 0) { + awaitPhase = isReadyNotifier.getPhase(); + // If getPhase() returns a value less than 0, the phaser has been terminated. + if (awaitPhase < 0) { + return; + } + } + // We only check isReady periodically to effectively allow for increasing the outbound // buffer periodically. This reduces the overhead of blocking while still restricting // memory because there is a limited # of streams, and we have a max messages size of 2MB. @@ -81,53 +93,63 @@ public void onNext(T value) { outboundObserver.onNext(value); return; } - // If we awaited previously and timed out, wait for the same phase. Otherwise we're - // careful to observe the phase before observing isReady. - if (awaitPhase < 0) { - awaitPhase = phaser.getPhase(); - } + if (outboundObserver.isReady()) { messagesSinceReady = 0; outboundObserver.onNext(value); return; } } + // A callback has been registered to advance the phaser whenever the observer // transitions to is ready. Since we are waiting for a phase observed before the // outboundObserver.isReady() returned false, we expect it to advance after the // channel has become ready. This doesn't always seem to be the case (despite // documentation stating otherwise) so we poll periodically and enforce an overall // timeout related to the stream deadline. 
- phaser.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS); + int nextPhase = + isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS); + // If nextPhase is a value less than 0, the phaser has been terminated. + if (nextPhase < 0) { + return; + } + synchronized (lock) { messagesSinceReady = 0; outboundObserver.onNext(value); return; } } catch (TimeoutException e) { + if (isReadyNotifier.isTerminated()) { + return; + } + totalSecondsWaited += waitSeconds; if (totalSecondsWaited > deadlineSeconds) { - LOG.error( - "Exceeded timeout waiting for the outboundObserver to become ready meaning " - + "that the stream deadline was not respected."); - throw new RuntimeException(e); + String errorMessage = constructStreamCancelledErrorMessage(totalSecondsWaited); + LOG.error(errorMessage); + throw new StreamObserverCancelledException(errorMessage, e); } - if (totalSecondsWaited > 30) { + + if (totalSecondsWaited > OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS) { LOG.info( "Output channel stalled for {}s, outbound thread {}.", totalSecondsWaited, Thread.currentThread().getName()); } + waitSeconds = waitSeconds * 2; } catch (InterruptedException e) { Thread.currentThread().interrupt(); - throw new RuntimeException(e); + throw new StreamObserverCancelledException(e); } } } @Override public void onError(Throwable t) { + // Free the blocked threads in onNext(). + isReadyNotifier.forceTermination(); synchronized (lock) { outboundObserver.onError(t); } @@ -135,8 +157,23 @@ public void onError(Throwable t) { @Override public void onCompleted() { + // Free the blocked threads in onNext(). + isReadyNotifier.forceTermination(); synchronized (lock) { outboundObserver.onCompleted(); } } + + private String constructStreamCancelledErrorMessage(long totalSecondsWaited) { + return deadlineSeconds > 0 + ? 
"Waited " + + totalSecondsWaited + + "s which exceeds given deadline of " + + deadlineSeconds + + "s for the outboundObserver to become ready meaning " + + "that the stream deadline was not respected." + : "Output channel has been blocked for " + + totalSecondsWaited + + "s. Restarting stream internally."; + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 90ffb3d3fbcf..1da48bd2b7ce 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -236,6 +236,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} @@ -299,6 +302,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} @@ -380,6 +386,9 @@ public String backendWorkerToken() { return ""; } + @Override + public void start() {} + @Override public void shutdown() {} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index 0092fcc7bcd1..7d0658534080 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ 
-241,8 +241,8 @@ public void testStreamsStartCorrectly() throws InterruptedException { any(), eq(noOpProcessWorkItemFn())); - verify(streamFactory, times(2)).createGetDataStream(any(), any()); - verify(streamFactory, times(2)).createCommitWorkStream(any(), any()); + verify(streamFactory, times(2)).createDirectGetDataStream(any(), any()); + verify(streamFactory, times(2)).createDirectCommitWorkStream(any(), any()); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index 32d1f5738086..1aaad7c3c95d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -23,7 +23,6 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; @@ -96,14 +95,13 @@ public void testStartStream_startsAllStreams() { newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); verify(streamFactory) .createDirectGetWorkStream( eq(connection), eq( - GET_WORK_REQUEST - .toBuilder() + GET_WORK_REQUEST.toBuilder() .setMaxItems(itemBudget) .setMaxBytes(byteBudget) .build()), @@ -113,8 +111,8 @@ public void testStartStream_startsAllStreams() { any(), eq(workItemScheduler)); - verify(streamFactory).createGetDataStream(eq(connection.stub()), 
any(ThrottleTimer.class)); - verify(streamFactory).createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + verify(streamFactory).createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); + verify(streamFactory).createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test @@ -126,16 +124,15 @@ public void testStartStream_onlyStartsStreamsOnce() { newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - windmillStreamSender.startStreams(); - windmillStreamSender.startStreams(); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); + windmillStreamSender.start(); + windmillStreamSender.start(); verify(streamFactory, times(1)) .createDirectGetWorkStream( eq(connection), eq( - GET_WORK_REQUEST - .toBuilder() + GET_WORK_REQUEST.toBuilder() .setMaxItems(itemBudget) .setMaxBytes(byteBudget) .build()), @@ -146,9 +143,9 @@ public void testStartStream_onlyStartsStreamsOnce() { eq(workItemScheduler)); verify(streamFactory, times(1)) - .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); verify(streamFactory, times(1)) - .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test @@ -160,10 +157,10 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted newWindmillStreamSender( GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); - Thread startStreamThread = new Thread(windmillStreamSender::startStreams); + Thread startStreamThread = new Thread(windmillStreamSender::start); startStreamThread.start(); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); startStreamThread.join(); @@ -171,8 +168,7 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted .createDirectGetWorkStream( 
eq(connection), eq( - GET_WORK_REQUEST - .toBuilder() + GET_WORK_REQUEST.toBuilder() .setMaxItems(itemBudget) .setMaxBytes(byteBudget) .build()), @@ -183,19 +179,9 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted eq(workItemScheduler)); verify(streamFactory, times(1)) - .createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class)); + .createDirectGetDataStream(eq(connection), any(ThrottleTimer.class)); verify(streamFactory, times(1)) - .createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class)); - } - - @Test - public void testCloseAllStreams_doesNotCloseUnstartedStreams() { - WindmillStreamSender windmillStreamSender = - newWindmillStreamSender(GetWorkBudget.builder().setBytes(1L).setItems(1L).build()); - - windmillStreamSender.close(); - - verifyNoInteractions(streamFactory); + .createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class)); } @Test @@ -219,9 +205,9 @@ public void testCloseAllStreams_closesAllStreams() { eq(workItemScheduler))) .thenReturn(mockGetWorkStream); - when(mockStreamFactory.createGetDataStream(eq(connection.stub()), any(ThrottleTimer.class))) + when(mockStreamFactory.createDirectGetDataStream(eq(connection), any(ThrottleTimer.class))) .thenReturn(mockGetDataStream); - when(mockStreamFactory.createCommitWorkStream(eq(connection.stub()), any(ThrottleTimer.class))) + when(mockStreamFactory.createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class))) .thenReturn(mockCommitWorkStream); WindmillStreamSender windmillStreamSender = @@ -229,7 +215,7 @@ public void testCloseAllStreams_closesAllStreams() { GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(), mockStreamFactory); - windmillStreamSender.startStreams(); + windmillStreamSender.start(); windmillStreamSender.close(); verify(mockGetWorkStream).shutdown(); diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java index bdad382c9af2..fdd213223987 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java @@ -260,5 +260,8 @@ public String backendWorkerToken() { public void shutdown() { halfClose(); } + + @Override + public void start() {} } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 546a2883e3b2..e7317553af02 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -261,6 +261,10 @@ public void testStop_drainsCommitQueue() { Supplier fakeCommitWorkStream = () -> new CommitWorkStream() { + + @Override + public void start() {} + @Override public RequestBatcher batcher() { return new RequestBatcher() { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java new file mode 100644 index 000000000000..0563d69c0d9f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcCommitWorkStreamTest { + private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; + private static final Windmill.JobHeader TEST_JOB_HEADER = 
Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + private static final String COMPUTATION_ID = "computationId"; + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + + private static Windmill.WorkItemCommitRequest workItemCommitRequest(long value) { + return Windmill.WorkItemCommitRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(value) + .setWorkToken(value) + .setCacheToken(value) + .build(); + } + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamStreamTestStub testStub) { + serviceRegistry.addService(testStub); + GrpcCommitWorkStream commitWorkStream = + (GrpcCommitWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createCommitWorkStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer()); + commitWorkStream.start(); + return commitWorkStream; + } + + @Test + public void testShutdown_abortsQueuedCommits() throws InterruptedException { + int numCommits = 5; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + Set onDone = new HashSet<>(); + + TestCommitWorkStreamRequestObserver requestObserver = + spy(new TestCommitWorkStreamRequestObserver()); + 
CommitWorkStreamStreamTestStub testStub = new CommitWorkStreamStreamTestStub(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + commitStatus -> { + onDone.add(commitStatus); + commitProcessed.countDown(); + }); + } + } + + // Verify that we sent the commits above in a request + the initial header. + verify(requestObserver, times(2)).onNext(any(Windmill.StreamingCommitWorkRequest.class)); + // We won't get responses so we will have some pending requests. + assertTrue(commitWorkStream.hasPendingRequests()); + + commitWorkStream.shutdown(); + commitProcessed.await(); + + assertThat(onDone).containsExactly(Windmill.CommitStatus.ABORTED); + } + + @Test + public void testCommitWorkItem_afterShutdownFalse() { + int numCommits = 5; + + CommitWorkStreamStreamTestStub testStub = + new CommitWorkStreamStreamTestStub(new TestCommitWorkStreamRequestObserver()); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue(batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {})); + } + } + commitWorkStream.shutdown(); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertFalse( + batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {})); + } + } + } + + @Test + public void testSend_notCalledAfterShutdown() { + int numCommits = 5; + CountDownLatch commitProcessed = new CountDownLatch(numCommits); + + TestCommitWorkStreamRequestObserver requestObserver = + spy(new TestCommitWorkStreamRequestObserver()); + CommitWorkStreamStreamTestStub 
testStub = new CommitWorkStreamStreamTestStub(requestObserver); + GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + + try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { + for (int i = 0; i < numCommits; i++) { + assertTrue( + batcher.commitWorkItem( + COMPUTATION_ID, + workItemCommitRequest(i), + commitStatus -> commitProcessed.countDown())); + } + commitWorkStream.shutdown(); + } + + // send() uses the requestObserver to send requests. We expect 1 send since startStream() sends + // the header, which happens before we shutdown. + verify(requestObserver, times(1)).onNext(any(Windmill.StreamingCommitWorkRequest.class)); + } + + private static class TestCommitWorkStreamRequestObserver + implements StreamObserver { + private @Nullable StreamObserver responseObserver; + + @Override + public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } + } + + private static class CommitWorkStreamStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestCommitWorkStreamRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private CommitWorkStreamStreamTestStub(TestCommitWorkStreamRequestObserver requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + ((ServerCallStreamObserver) responseObserver) + .setOnCancelHandler(() -> {}); + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + } +} diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java new file mode 100644 index 000000000000..e4570f5afc9a --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.ServerCallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcGetDataStreamTest { + private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetDataStreamTest"; + private 
static final Windmill.JobHeader TEST_JOB_HEADER = + Windmill.JobHeader.newBuilder() + .setJobId("test_job") + .setWorkerId("test_worker") + .setProjectId("test_project") + .build(); + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); + private ManagedChannel inProcessChannel; + + @Before + public void setUp() throws IOException { + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); + } + + private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { + serviceRegistry.addService(testStub); + GrpcGetDataStream getDataStream = + (GrpcGetDataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setSendKeyedGetDataRequests(false) + .build() + .createGetDataStream( + CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer()); + getDataStream.start(); + return getDataStream; + } + + @Test + public void testRequestKeyedData_sendOnShutdownStreamThrowsWorkItemCancelledException() { + GetDataStreamTestStub testStub = + new GetDataStreamTestStub(new TestGetDataStreamRequestObserver()); + GrpcGetDataStream getDataStream = createGetDataStream(testStub); + int numSendThreads = 5; + ExecutorService getDataStreamSenders = Executors.newFixedThreadPool(numSendThreads); + CountDownLatch waitForSendAttempt = new CountDownLatch(1); + // These will block until they are successfully sent. 
+ List> sendFutures = + IntStream.range(0, 5) + .sequential() + .mapToObj( + i -> + (Runnable) + () -> { + // Prevent some threads from sending until we close the stream. + if (i % 2 == 0) { + try { + waitForSendAttempt.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + getDataStream.requestKeyedData( + "computationId", + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(i) + .setCacheToken(i) + .setWorkToken(i) + .build()); + }) + // Run the code above on multiple threads. + .map(runnable -> CompletableFuture.runAsync(runnable, getDataStreamSenders)) + .collect(Collectors.toList()); + + getDataStream.shutdown(); + + // Free up waiting threads so that they can try to send on a closed stream. + waitForSendAttempt.countDown(); + + for (int i = 0; i < numSendThreads; i++) { + CompletableFuture sendFuture = sendFutures.get(i); + try { + // Wait for future to complete. + sendFuture.join(); + } catch (Exception ignored) { + } + if (i % 2 == 0) { + assertTrue(sendFuture.isCompletedExceptionally()); + ExecutionException e = assertThrows(ExecutionException.class, sendFuture::get); + assertThat(e).hasCauseThat().isInstanceOf(WorkItemCancelledException.class); + } + } + } + + private static class TestGetDataStreamRequestObserver + implements StreamObserver { + private @Nullable StreamObserver responseObserver; + + @Override + public void onNext(Windmill.StreamingGetDataRequest streamingGetDataRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() { + if (responseObserver != null) { + responseObserver.onCompleted(); + } + } + } + + private static class GetDataStreamTestStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + private final TestGetDataStreamRequestObserver requestObserver; + private @Nullable StreamObserver responseObserver; + + private GetDataStreamTestStub(TestGetDataStreamRequestObserver 
requestObserver) { + this.requestObserver = requestObserver; + } + + @Override + public StreamObserver getDataStream( + StreamObserver responseObserver) { + if (this.responseObserver == null) { + ((ServerCallStreamObserver) responseObserver) + .setOnCancelHandler(() -> {}); + this.responseObserver = responseObserver; + requestObserver.responseObserver = this.responseObserver; + } + + return requestObserver; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 4439c409b32f..a8b828905e5e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream.DEFAULT_STREAM_RPC_DEADLINE_SECONDS; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.verify; @@ -30,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -39,9 +39,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import 
org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; -import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; @@ -82,28 +80,26 @@ public class GrpcGetWorkerMetadataStreamTest { private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); - private final Set> streamRegistry = new HashSet<>(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private GrpcGetWorkerMetadataStream stream; + private Set> streamRegistry; private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( GetWorkerMetadataTestStub getWorkerMetadataTestStub, - int metadataVersion, Consumer endpointsConsumer) { serviceRegistry.addService(getWorkerMetadataTestStub); - return GrpcGetWorkerMetadataStream.create( - responseObserver -> - CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel) - .getWorkerMetadata(responseObserver), - FluentBackoff.DEFAULT.backoff(), - StreamObserverFactory.direct(DEFAULT_STREAM_RPC_DEADLINE_SECONDS * 2, 1), - streamRegistry, - 1, // logEveryNStreamFailures - TEST_JOB_HEADER, - metadataVersion, - new ThrottleTimer(), - endpointsConsumer); + GrpcGetWorkerMetadataStream getWorkerMetadataStream = + (GrpcGetWorkerMetadataStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .setStreamRegistry(streamRegistry) + .build() + .createGetWorkerMetadataStream( + () -> 
CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer(), + endpointsConsumer); + getWorkerMetadataStream.start(); + return getWorkerMetadataStream; } @Before @@ -126,6 +122,7 @@ public void setUp() throws IOException { .setDirectEndpoint(IPV6_ADDRESS_1) .setBackendWorkerToken("worker_token") .build()); + streamRegistry = ConcurrentHashMap.newKeySet(); } @After @@ -146,8 +143,7 @@ public void testGetWorkerMetadata() { new TestWindmillEndpointsConsumer(); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = -1; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(mockResponse); assertThat(testWindmillEndpointsConsumer.globalDataEndpoints.keySet()) @@ -175,8 +171,7 @@ public void testGetWorkerMetadata_consumesSubsequentResponseMetadata() { GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = 0; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(initialResponse); List newDirectPathEndpoints = @@ -222,8 +217,7 @@ public void testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { Mockito.spy(new TestWindmillEndpointsConsumer()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - int metadataVersion = 0; - stream = getWorkerMetadataTestStream(testStub, metadataVersion, testWindmillEndpointsConsumer); + stream = getWorkerMetadataTestStream(testStub, testWindmillEndpointsConsumer); testStub.injectWorkerMetadata(freshEndpoints); List staleDirectPathEndpoints = @@ -252,7 +246,7 @@ public void 
testGetWorkerMetadata_doesNotConsumeResponseIfMetadataStale() { public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(new TestGetWorkMetadataRequestObserver()); - stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer()); + stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer()); testStub.injectWorkerMetadata( WorkerMetadataResponse.newBuilder() .setMetadataVersion(1) @@ -270,7 +264,7 @@ public void testSendHealthCheck() { TestGetWorkMetadataRequestObserver requestObserver = Mockito.spy(new TestGetWorkMetadataRequestObserver()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver); - stream = getWorkerMetadataTestStream(testStub, 0, new TestWindmillEndpointsConsumer()); + stream = getWorkerMetadataTestStream(testStub, new TestWindmillEndpointsConsumer()); stream.sendHealthCheck(); verify(requestObserver).onNext(WorkerMetadataRequest.getDefaultInstance()); From f30455bf960ba9eef4f3d241519bf34dc839f260 Mon Sep 17 00:00:00 2001 From: m-trieu Date: Tue, 15 Oct 2024 10:09:12 -0700 Subject: [PATCH 04/23] address PR comments 1 --- .../FanOutStreamingEngineWorkerHarness.java | 25 +++++++------ .../harness/GlobalDataStreamSender.java | 25 +++++++------ .../client/AbstractWindmillStream.java | 37 +++++++++++++------ .../client/grpc/GrpcCommitWorkStream.java | 4 +- .../client/grpc/GrpcGetDataStream.java | 17 +++++++-- .../client/grpc/GrpcGetWorkStream.java | 4 +- .../harness/WindmillStreamSenderTest.java | 9 +++-- 7 files changed, 76 insertions(+), 45 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index 
5ed39b43dafe..a43f03f6ff53 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -32,10 +32,10 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.CheckReturnValue; -import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; @@ -102,6 +102,8 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker /** Writes are guarded by synchronization, reads are lock free. */ private final AtomicReference backends; + private final GetWorkerMetadataStream getWorkerMetadataStream; + @GuardedBy("this") private long activeMetadataVersion; @@ -111,8 +113,6 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker @GuardedBy("this") private boolean started; - private @Nullable GetWorkerMetadataStream getWorkerMetadataStream; - private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, @@ -142,7 +142,14 @@ private FanOutStreamingEngineWorkerHarness( this.totalGetWorkBudget = totalGetWorkBudget; this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; - this.getWorkerMetadataStream = null; + // To satisfy CheckerFramework complaining about reference to "this" in constructor. 
+ @SuppressWarnings("methodref.receiver.bound") + Consumer newEndpointsConsumer = this::consumeWorkerMetadata; + this.getWorkerMetadataStream = + streamFactory.createGetWorkerMetadataStream( + dispatcherClient::getWindmillMetadataServiceStubBlocking, + getWorkerMetadataThrottleTimer, + newEndpointsConsumer); } /** @@ -201,11 +208,6 @@ static FanOutStreamingEngineWorkerHarness forTesting( @Override public synchronized void start() { Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); - getWorkerMetadataStream = - streamFactory.createGetWorkerMetadataStream( - dispatcherClient::getWindmillMetadataServiceStubBlocking, - getWorkerMetadataThrottleTimer, - this::consumeWorkerMetadata); getWorkerMetadataStream.start(); started = true; } @@ -393,9 +395,8 @@ private GlobalDataStreamSender getOrCreateGlobalDataSteam( .orElseGet( () -> new GlobalDataStreamSender( - () -> - streamFactory.createGetDataStream( - createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), + streamFactory.createGetDataStream( + createWindmillStub(keyedEndpoint.getValue()), new ThrottleTimer()), keyedEndpoint.getValue())); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java index 4dec7ead19f2..71418773de47 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -23,37 +23,40 @@ import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import 
org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; @Internal @ThreadSafe -// TODO (m-trieu): replace Supplier with Stream after github.com/apache/beam/pull/32774/ is -// merged final class GlobalDataStreamSender implements Closeable, Supplier { private final Endpoint endpoint; - private final Supplier delegate; + private final GetDataStream delegate; private volatile boolean started; - GlobalDataStreamSender(Supplier delegate, Endpoint endpoint) { - // Ensures that the Supplier is thread-safe - this.delegate = Suppliers.memoize(delegate::get); + GlobalDataStreamSender(GetDataStream delegate, Endpoint endpoint) { + this.delegate = delegate; this.started = false; this.endpoint = endpoint; } @Override public GetDataStream get() { + if (!started) { + startStream(); + } + + return delegate; + } + + private synchronized void startStream() { + // Check started again after we acquire the lock. if (!started) { started = true; + delegate.start(); } - return delegate.get(); } @Override public void close() { - if (started) { - delegate.get().shutdown(); - } + delegate.shutdown(); } Endpoint endpoint() { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index aa5a50db13b3..d00dff70d171 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -90,7 +90,13 @@ public abstract class AbstractWindmillStream implements Win private final int logEveryNStreamFailures; private final String backendWorkerToken; private final 
ResettableRequestObserver requestObserver; + + /** Guards {@link #start()} and {@link #shutdown()} methods. */ + private final Object shutdownLock = new Object(); + + /** Reads are lock free, writes are guarded by shutdownLock. */ private final AtomicBoolean isShutdown; + private final AtomicBoolean started; private final AtomicReference shutdownTime; @@ -193,6 +199,7 @@ protected final void send(RequestT request) { } if (streamClosed.get()) { + // TODO(m-trieu): throw a more specific exception here (i.e StreamClosedException) throw new IllegalStateException("Send called on a client closed stream."); } @@ -201,7 +208,7 @@ protected final void send(RequestT request) { requestObserver.onNext(request); } catch (StreamObserverCancelledException e) { if (isShutdown()) { - logger.debug("Stream was closed or shutdown during send.", e); + logger.debug("Stream was shutdown during send.", e); return; } @@ -212,10 +219,12 @@ protected final void send(RequestT request) { @Override public final void start() { - if (!isShutdown.get() && started.compareAndSet(false, true)) { - // start() should only be executed once during the lifetime of the stream for idempotency and - // when shutdown() has not been called. - startStream(); + synchronized (shutdownLock) { + if (!isShutdown.get() && started.compareAndSet(false, true)) { + // start() should only be executed once during the lifetime of the stream for idempotency + // and when shutdown() has not been called. + startStream(); + } } } @@ -248,7 +257,8 @@ private void startStream() { } catch (InterruptedException ie) { Thread.currentThread().interrupt(); logger.info( - "Interrupted during stream creation backoff. The stream will not be created."); + "Interrupted during {} creation backoff. The stream will not be created.", + getClass()); break; } catch (IOException ioe) { // Keep trying to create the stream. 
@@ -324,6 +334,7 @@ public final void appendSummaryHtml(PrintWriter writer) { @Override public final synchronized void halfClose() { + // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. clientClosed.set(true); requestObserver.onCompleted(); streamClosed.set(true); @@ -346,13 +357,15 @@ public String backendWorkerToken() { @Override public final void shutdown() { - // Don't lock here as isShutdown checks are used in the stream to free blocked + // Don't lock on "this" as isShutdown checks are used in the stream to free blocked // threads or as exit conditions to loops. - if (isShutdown.compareAndSet(false, true)) { - requestObserver() - .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); - shutdownInternal(); - shutdownTime.set(DateTime.now()); + synchronized (shutdownLock) { + if (isShutdown.compareAndSet(false, true)) { + shutdownTime.set(DateTime.now()); + requestObserver() + .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + shutdownInternal(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 7fcbb80631a8..eeee29829154 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -206,7 +206,7 @@ protected void startThrottleTimer() { commitWorkThrottleTimer.start(); } - private void flushInternal(Map requests) throws InterruptedException { + private void flushInternal(Map requests) { if (requests.isEmpty()) { return; } @@ -363,8 +363,6 @@ public void 
flush() { if (!isShutdown()) { flushInternal(queue); } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); } finally { queuedBytes = 0; queue.clear(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index d2fe8395b109..fb19289ce77e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -56,6 +56,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.VerifyException; import org.joda.time.Instant; import org.slf4j.Logger; @@ -403,6 +404,15 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept verify(batch == batches.peekFirst()); batch.markFinalized(); } + trySendBatch(batch); + } else { + // Wait for this batch to be sent before parsing the response. + batch.waitForSendOrFailNotification(); + } + } + + void trySendBatch(QueuedBatch batch) { + try { sendBatch(batch.requests()); synchronized (batches) { verify(batch == batches.pollFirst()); @@ -410,9 +420,9 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept // Notify all waiters with requests in this batch as well as the sender // of the next batch (if one exists). 
batch.notifySent(); - } else { - // Wait for this batch to be sent before parsing the response. - batch.waitForSendOrFailNotification(); + } catch (Exception e) { + LOG.error("Error occurred sending batch.", e); + batch.notifyFailed(); } } @@ -423,6 +433,7 @@ private void sendBatch(List requests) { // Synchronization of pending inserts is necessary with send to ensure duplicates are not // sent on stream reconnect. for (QueuedRequest request : requests) { + Preconditions.checkState(!isShutdown(), "Cannot send on shutdown stream."); // Map#put returns null if there was no previous mapping for the key, meaning we have not // seen it before. verify(pending.put(request.id(), request.getResponseStream()) == null); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index fc38343afb98..511b4a6b07bd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -198,5 +198,7 @@ protected void startThrottleTimer() { } @Override - public void setBudget(GetWorkBudget newBudget) {} + public void setBudget(GetWorkBudget newBudget) { + // no-op + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index 1aaad7c3c95d..647d45adebb4 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -101,7 +101,8 @@ public void testStartStream_startsAllStreams() { .createDirectGetWorkStream( eq(connection), eq( - GET_WORK_REQUEST.toBuilder() + GET_WORK_REQUEST + .toBuilder() .setMaxItems(itemBudget) .setMaxBytes(byteBudget) .build()), @@ -132,7 +133,8 @@ public void testStartStream_onlyStartsStreamsOnce() { .createDirectGetWorkStream( eq(connection), eq( - GET_WORK_REQUEST.toBuilder() + GET_WORK_REQUEST + .toBuilder() .setMaxItems(itemBudget) .setMaxBytes(byteBudget) .build()), @@ -168,7 +170,8 @@ public void testStartStream_onlyStartsStreamsOnceConcurrent() throws Interrupted .createDirectGetWorkStream( eq(connection), eq( - GET_WORK_REQUEST.toBuilder() + GET_WORK_REQUEST + .toBuilder() .setMaxItems(itemBudget) .setMaxBytes(byteBudget) .build()), From 219c880d469f48c3d76bf1b96efa77bdd090a367 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Fri, 18 Oct 2024 03:00:09 -0500 Subject: [PATCH 05/23] address PR comments --- .../FanOutStreamingEngineWorkerHarness.java | 40 +++++++++---------- .../harness/GlobalDataStreamSender.java | 9 ++--- .../streaming/harness/StreamSender.java | 22 ++++++++++ .../harness/WindmillStreamSender.java | 9 ++--- .../client/AbstractWindmillStream.java | 28 ++++++------- .../client/grpc/GrpcDirectGetWorkStream.java | 3 +- .../client/grpc/GrpcGetDataStream.java | 11 ++++- 7 files changed, 74 insertions(+), 48 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index a43f03f6ff53..b42223844b26 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,7 +20,6 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; -import java.io.Closeable; import java.util.HashSet; import java.util.Map.Entry; import java.util.NoSuchElementException; @@ -32,10 +31,10 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; @@ -102,8 +101,6 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker /** Writes are guarded by synchronization, reads are lock free. 
*/ private final AtomicReference backends; - private final GetWorkerMetadataStream getWorkerMetadataStream; - @GuardedBy("this") private long activeMetadataVersion; @@ -113,6 +110,9 @@ public final class FanOutStreamingEngineWorkerHarness implements StreamingWorker @GuardedBy("this") private boolean started; + @GuardedBy("this") + private @Nullable GetWorkerMetadataStream getWorkerMetadataStream = null; + private FanOutStreamingEngineWorkerHarness( JobHeader jobHeader, GetWorkBudget totalGetWorkBudget, @@ -142,14 +142,6 @@ private FanOutStreamingEngineWorkerHarness( this.totalGetWorkBudget = totalGetWorkBudget; this.activeMetadataVersion = Long.MIN_VALUE; this.workCommitterFactory = workCommitterFactory; - // To satisfy CheckerFramework complaining about reference to "this" in constructor. - @SuppressWarnings("methodref.receiver.bound") - Consumer newEndpointsConsumer = this::consumeWorkerMetadata; - this.getWorkerMetadataStream = - streamFactory.createGetWorkerMetadataStream( - dispatcherClient::getWindmillMetadataServiceStubBlocking, - getWorkerMetadataThrottleTimer, - newEndpointsConsumer); } /** @@ -208,6 +200,11 @@ static FanOutStreamingEngineWorkerHarness forTesting( @Override public synchronized void start() { Preconditions.checkState(!started, "FanOutStreamingEngineWorkerHarness cannot start twice."); + getWorkerMetadataStream = + streamFactory.createGetWorkerMetadataStream( + dispatcherClient::getWindmillMetadataServiceStubBlocking, + getWorkerMetadataThrottleTimer, + this::consumeWorkerMetadata); getWorkerMetadataStream.start(); started = true; } @@ -227,7 +224,7 @@ public ImmutableSet currentWindmillEndpoints() { */ private GetDataStream getGlobalDataStream(String globalDataKey) { return Optional.ofNullable(backends.get().globalDataStreams().get(globalDataKey)) - .map(GlobalDataStreamSender::get) + .map(GlobalDataStreamSender::stream) .orElseThrow( () -> new NoSuchElementException("No endpoint for global data tag: " + globalDataKey)); } @@ -322,7 +319,7 
@@ private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); } - private void closeStreamSender(Endpoint endpoint, Closeable sender) { + private void closeStreamSender(Endpoint endpoint, StreamSender sender) { LOG.debug("Closing streams to endpoint={}, sender={}", endpoint, sender); try { sender.close(); @@ -348,13 +345,14 @@ private void closeStreamSender(Endpoint endpoint, Closeable sender) { private CompletionStage> getOrCreateWindmillStreamSenderFuture( Endpoint endpoint, ImmutableMap currentStreams) { - return MoreFutures.supplyAsync( - () -> - Pair.of( - endpoint, - Optional.ofNullable(currentStreams.get(endpoint)) - .orElseGet(() -> createAndStartWindmillStreamSender(endpoint))), - windmillStreamManager); + return Optional.ofNullable(currentStreams.get(endpoint)) + .map(backend -> CompletableFuture.completedFuture(Pair.of(endpoint, backend))) + .orElseGet( + () -> + MoreFutures.supplyAsync( + () -> Pair.of(endpoint, createAndStartWindmillStreamSender(endpoint)), + windmillStreamManager) + .toCompletableFuture()); } /** Add up all the throttle times of all streams including GetWorkerMetadataStream. 
 */ diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java index 71418773de47..635482345807 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -17,8 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; -import java.io.Closeable; -import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints.Endpoint; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; @@ -26,7 +24,7 @@ @Internal @ThreadSafe -final class GlobalDataStreamSender implements Closeable, Supplier { +final class GlobalDataStreamSender implements StreamSender { private final Endpoint endpoint; private final GetDataStream delegate; private volatile boolean started; @@ -37,9 +35,10 @@ final class GlobalDataStreamSender implements Closeable, Supplier this.endpoint = endpoint; } - @Override - public GetDataStream get() { + GetDataStream stream() { if (!started) { + // Starting the stream may perform IO. Start the stream lazily since not all pipeline + // implementations need to fetch global/side input data.
startStream(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java new file mode 100644 index 000000000000..40a63571620f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/StreamSender.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.streaming.harness; + +interface StreamSender { + void close(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 5641ef6f648c..2a363885c7c5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; -import java.io.Closeable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -55,7 +54,7 @@ */ @Internal @ThreadSafe -final class WindmillStreamSender implements GetWorkBudgetSpender, Closeable { +final class WindmillStreamSender implements GetWorkBudgetSpender, StreamSender { private static final String STREAM_STARTER_THREAD_NAME = "StartWindmillStreamThread-%d"; private final AtomicBoolean started; private final AtomicReference getWorkBudget; @@ -149,10 +148,10 @@ public synchronized void close() { @Override public void setBudget(long items, long bytes) { - GetWorkBudget adjustment = GetWorkBudget.builder().setItems(items).setBytes(bytes).build(); - getWorkBudget.set(adjustment); + GetWorkBudget budget = GetWorkBudget.builder().setItems(items).setBytes(bytes).build(); + getWorkBudget.set(budget); if (started.get()) { - getWorkStream.setBudget(adjustment); + getWorkStream.setBudget(budget); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index d00dff70d171..5573844a571d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -94,10 +94,6 @@ public abstract class AbstractWindmillStream implements Win /** Guards {@link #start()} and {@link #shutdown()} methods. */ private final Object shutdownLock = new Object(); - /** Reads are lock free, writes are guarded by shutdownLock. */ - private final AtomicBoolean isShutdown; - - private final AtomicBoolean started; private final AtomicReference shutdownTime; /** @@ -107,6 +103,8 @@ public abstract class AbstractWindmillStream implements Win private final AtomicBoolean streamClosed; private final Logger logger; + private volatile boolean isShutdown; + private volatile boolean started; protected AbstractWindmillStream( Logger logger, @@ -128,8 +126,8 @@ protected AbstractWindmillStream( this.streamRegistry = streamRegistry; this.logEveryNStreamFailures = logEveryNStreamFailures; this.clientClosed = new AtomicBoolean(); - this.isShutdown = new AtomicBoolean(false); - this.started = new AtomicBoolean(false); + this.isShutdown = false; + this.started = false; this.streamClosed = new AtomicBoolean(false); this.startTimeMs = new AtomicLong(); this.lastSendTimeMs = new AtomicLong(); @@ -179,7 +177,7 @@ private static long debugDuration(long nowMs, long startMs) { /** Reflects that {@link #shutdown()} was explicitly called. */ protected boolean isShutdown() { - return isShutdown.get(); + return isShutdown; } private StreamObserver requestObserver() { @@ -194,7 +192,7 @@ private StreamObserver requestObserver() { /** Send a request to the server. 
*/ protected final void send(RequestT request) { synchronized (this) { - if (isShutdown()) { + if (isShutdown) { return; } @@ -207,7 +205,7 @@ protected final void send(RequestT request) { lastSendTimeMs.set(Instant.now().getMillis()); requestObserver.onNext(request); } catch (StreamObserverCancelledException e) { - if (isShutdown()) { + if (isShutdown) { logger.debug("Stream was shutdown during send.", e); return; } @@ -220,10 +218,11 @@ protected final void send(RequestT request) { @Override public final void start() { synchronized (shutdownLock) { - if (!isShutdown.get() && started.compareAndSet(false, true)) { + if (!isShutdown && !started) { // start() should only be executed once during the lifetime of the stream for idempotency // and when shutdown() has not been called. startStream(); + started = true; } } } @@ -235,7 +234,7 @@ private void startStream() { while (true) { try { synchronized (this) { - if (isShutdown.get()) { + if (isShutdown) { break; } startTimeMs.set(Instant.now().getMillis()); @@ -322,7 +321,7 @@ public final void appendSummaryHtml(PrintWriter writer) { debugDuration(nowMs, lastSendTimeMs.get()), debugDuration(nowMs, lastResponseTimeMs.get()), streamClosed.get(), - isShutdown.get(), + isShutdown, shutdownTime.get()); } @@ -360,7 +359,8 @@ public final void shutdown() { // Don't lock on "this" as isShutdown checks are used in the stream to free blocked // threads or as exit conditions to loops. synchronized (shutdownLock) { - if (isShutdown.compareAndSet(false, true)) { + if (!isShutdown) { + isShutdown = true; shutdownTime.set(DateTime.now()); requestObserver() .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); @@ -517,7 +517,7 @@ private void recordStreamStatus(Status status) { /** Returns true if the stream was torn down and should not be restarted internally. 
*/ private synchronized boolean maybeTeardownStream() { - if (isShutdown() || (clientClosed.get() && !hasPendingRequests())) { + if (isShutdown || (clientClosed.get() && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); executor.shutdownNow(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 8fbcb1479d50..be9d6c6d06d6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -204,7 +204,8 @@ protected synchronized void onNewStream() { StreamingGetWorkRequest request = StreamingGetWorkRequest.newBuilder() .setRequest( - requestHeader.toBuilder() + requestHeader + .toBuilder() .setMaxItems(initialGetWorkBudget.items()) .setMaxBytes(initialGetWorkBudget.bytes()) .build()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index fb19289ce77e..cd0c67a2cb26 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -355,8 +355,6 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn requests) { } catch 
(IllegalStateException e) { // The stream broke before this call went through; onNewStream will retry the fetch. LOG.warn("GetData stream broke before call started.", e); + } finally { + if (isShutdown()) { + // Stream was shutdown during send, clear all the pending requests. + pending.values().forEach(AppendableInputStream::cancel); + pending.clear(); + } } } } From 6b79fad409774f9659fb17c02062ffdbea7379e8 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Mon, 21 Oct 2024 16:03:35 -0700 Subject: [PATCH 06/23] address PR comments --- .../client/grpc/GrpcCommitWorkStream.java | 93 ++++++++++--------- .../client/grpc/GrpcGetDataStream.java | 37 ++++---- .../StreamingEngineWorkCommitterTest.java | 31 ++++++- .../client/grpc/GrpcCommitWorkStreamTest.java | 16 ++-- 4 files changed, 107 insertions(+), 70 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index eeee29829154..5d37c98f11cf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -196,7 +196,7 @@ protected void shutdownInternal() { Iterator pendingRequests = pending.values().iterator(); while (pendingRequests.hasNext()) { PendingRequest pendingRequest = pendingRequests.next(); - pendingRequest.completeWithStatus(CommitStatus.ABORTED); + pendingRequest.abort(); pendingRequests.remove(); } } @@ -332,6 +332,43 @@ private void completeWithStatus(CommitStatus commitStatus) { private long shardingKey() { return request().getShardingKey(); } + + private void abort() { + completeWithStatus(CommitStatus.ABORTED); 
+ } + } + + private static class CommitCompletionException extends RuntimeException { + private static final int MAX_PRINTABLE_ERRORS = 10; + private final Map>, Integer> errorCounter; + private final EvictingQueue detailedErrors; + + private CommitCompletionException() { + super("Exception while processing commit response."); + this.errorCounter = new HashMap<>(); + this.detailedErrors = EvictingQueue.create(MAX_PRINTABLE_ERRORS); + } + + private void recordError(CommitStatus commitStatus, Throwable error) { + errorCounter.compute( + Pair.of(commitStatus, error.getClass()), + (ignored, current) -> current == null ? 1 : current + 1); + detailedErrors.add(error); + } + + private boolean hasErrors() { + return !errorCounter.isEmpty(); + } + + @Override + public final String getMessage() { + return "CommitCompletionException{" + + "errorCounter=" + + errorCounter + + ", detailedErrors=" + + detailedErrors + + '}'; + } } private class Batcher implements CommitWorkStream.RequestBatcher { @@ -347,7 +384,7 @@ private Batcher() { @Override public boolean commitWorkItem( String computation, WorkItemCommitRequest commitRequest, Consumer onDone) { - if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length()) || isShutdown()) { return false; } @@ -362,6 +399,8 @@ public void flush() { try { if (!isShutdown()) { flushInternal(queue); + } else { + queue.forEach((ignored, request) -> request.onDone().accept(CommitStatus.ABORTED)); } } finally { queuedBytes = 0; @@ -370,49 +409,19 @@ public void flush() { } void add(long id, PendingRequest request) { - Preconditions.checkState(canAccept(request.getBytes())); - queuedBytes += request.getBytes(); - queue.put(id, request); + if (isShutdown()) { + request.abort(); + } else { + Preconditions.checkState(canAccept(request.getBytes())); + queuedBytes += request.getBytes(); + queue.put(id, request); + } } private boolean canAccept(long requestBytes) 
{ - return !isShutdown() - && (queue.isEmpty() - || (queue.size() < streamingRpcBatchLimit - && (requestBytes + queuedBytes) < AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE)); - } - } - - private static class CommitCompletionException extends RuntimeException { - private static final int MAX_PRINTABLE_ERRORS = 10; - private final Map>, Integer> errorCounter; - private final EvictingQueue detailedErrors; - - private CommitCompletionException() { - super("Exception while processing commit response."); - this.errorCounter = new HashMap<>(); - this.detailedErrors = EvictingQueue.create(MAX_PRINTABLE_ERRORS); - } - - private void recordError(CommitStatus commitStatus, Throwable error) { - errorCounter.compute( - Pair.of(commitStatus, error.getClass()), - (ignored, current) -> current == null ? 1 : current + 1); - detailedErrors.add(error); - } - - private boolean hasErrors() { - return !errorCounter.isEmpty(); - } - - @Override - public final String getMessage() { - return "CommitCompletionException{" - + "errorCounter=" - + errorCounter - + ", detailedErrors=" - + detailedErrors - + '}'; + return queue.isEmpty() + || (queue.size() < streamingRpcBatchLimit + && (requestBytes + queuedBytes) < AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index cd0c67a2cb26..04ac533b585b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -79,6 +79,7 @@ final class GrpcGetDataStream // newer ComputationHeartbeatRequests. 
private final boolean sendKeyedGetDataRequests; private final Consumer> processHeartbeatResponses; + private final Object shutdownLock = new Object(); private GrpcGetDataStream( String backendWorkerToken, @@ -297,14 +298,16 @@ public void sendHealthCheck() { protected void shutdownInternal() { // Stream has been explicitly closed. Drain pending input streams and request batches. // Future calls to send RPCs will fail. - pending.values().forEach(AppendableInputStream::cancel); - pending.clear(); - batches.forEach( - batch -> { - batch.markFinalized(); - batch.notifyFailed(); - }); - batches.clear(); + synchronized (shutdownLock) { + pending.values().forEach(AppendableInputStream::cancel); + pending.clear(); + batches.forEach( + batch -> { + batch.markFinalized(); + batch.notifyFailed(); + }); + batches.clear(); + } } @Override @@ -433,23 +436,19 @@ private void sendBatch(List requests) { synchronized (this) { // Synchronization of pending inserts is necessary with send to ensure duplicates are not // sent on stream reconnect. - for (QueuedRequest request : requests) { - Preconditions.checkState(!isShutdown(), "Cannot send on shutdown stream."); - // Map#put returns null if there was no previous mapping for the key, meaning we have not - // seen it before. - verify(pending.put(request.id(), request.getResponseStream()) == null); + synchronized (shutdownLock) { + for (QueuedRequest request : requests) { + Preconditions.checkState(!isShutdown(), "Cannot send on shutdown stream."); + // Map#put returns null if there was no previous mapping for the key, meaning we have not + // seen it before. + verify(pending.put(request.id(), request.getResponseStream()) == null); + } } try { send(batchedRequest); } catch (IllegalStateException e) { // The stream broke before this call went through; onNewStream will retry the fetch. 
LOG.warn("GetData stream broke before call started.", e); - } finally { - if (isShutdown()) { - // Stream was shutdown during send, clear all the pending requests. - pending.values().forEach(AppendableInputStream::cancel); - pending.clear(); - } } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index e7317553af02..6bd08706cd50 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,23 +53,35 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; +import org.junit.rules.Timeout; 
import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class StreamingEngineWorkCommitterTest { - + private static final String FAKE_SERVER_NAME = "Fake server for StreamingEngineWorkCommitterTest"; + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); @Rule public ErrorCollector errorCollector = new ErrorCollector(); + @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private WorkCommitter workCommitter; private FakeWindmillServer fakeWindmillServer; private Supplier> commitWorkStreamFactory; + private ManagedChannel inProcessChannel; private static Work createMockWork(long workToken) { return Work.create( @@ -117,6 +129,23 @@ public void setUp() throws IOException { WindmillStreamPool.create( 1, Duration.standardMinutes(1), fakeWindmillServer::commitWorkStream) ::getCloseableStream; + Server server = + InProcessServerBuilder.forName(FAKE_SERVER_NAME) + .fallbackHandlerRegistry(serviceRegistry) + .directExecutor() + .build() + .start(); + + inProcessChannel = + grpcCleanup.register( + InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); + grpcCleanup.register(server); + grpcCleanup.register(inProcessChannel); + } + + @After + public void cleanUp() { + inProcessChannel.shutdownNow(); } private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 0563d69c0d9f..5a281a06dc28 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -53,7 +53,7 @@ @RunWith(JUnit4.class) public class GrpcCommitWorkStreamTest { - private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetDataStreamTest"; + private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; private static final Windmill.JobHeader TEST_JOB_HEADER = Windmill.JobHeader.newBuilder() .setJobId("test_job") @@ -97,7 +97,7 @@ public void cleanUp() { inProcessChannel.shutdownNow(); } - private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamStreamTestStub testStub) { + private GrpcCommitWorkStream createCommitWorkStream(CommitWorkStreamTestStub testStub) { serviceRegistry.addService(testStub); GrpcCommitWorkStream commitWorkStream = (GrpcCommitWorkStream) @@ -118,7 +118,7 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { TestCommitWorkStreamRequestObserver requestObserver = spy(new TestCommitWorkStreamRequestObserver()); - CommitWorkStreamStreamTestStub testStub = new CommitWorkStreamStreamTestStub(requestObserver); + CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { @@ -147,8 +147,8 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { public void testCommitWorkItem_afterShutdownFalse() { int numCommits = 5; - CommitWorkStreamStreamTestStub testStub = - new CommitWorkStreamStreamTestStub(new TestCommitWorkStreamRequestObserver()); + CommitWorkStreamTestStub testStub = + new CommitWorkStreamTestStub(new TestCommitWorkStreamRequestObserver()); GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); try 
(WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { @@ -173,7 +173,7 @@ public void testSend_notCalledAfterShutdown() { TestCommitWorkStreamRequestObserver requestObserver = spy(new TestCommitWorkStreamRequestObserver()); - CommitWorkStreamStreamTestStub testStub = new CommitWorkStreamStreamTestStub(requestObserver); + CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { @@ -210,12 +210,12 @@ public void onCompleted() { } } - private static class CommitWorkStreamStreamTestStub + private static class CommitWorkStreamTestStub extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { private final TestCommitWorkStreamRequestObserver requestObserver; private @Nullable StreamObserver responseObserver; - private CommitWorkStreamStreamTestStub(TestCommitWorkStreamRequestObserver requestObserver) { + private CommitWorkStreamTestStub(TestCommitWorkStreamRequestObserver requestObserver) { this.requestObserver = requestObserver; } From 2426a6bbe47e62ff1c2d169a1b22f225dc6c1ce5 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Mon, 21 Oct 2024 17:20:35 -0700 Subject: [PATCH 07/23] guard stream shutdown behavior with shutdown lock --- .../client/AbstractWindmillStream.java | 12 ++++-- .../client/grpc/GrpcCommitWorkStream.java | 24 +++++++++-- .../client/grpc/GrpcGetDataStream.java | 30 +++++++------ .../grpc/GrpcDirectGetWorkStreamTest.java | 43 ++++++++++--------- 4 files changed, 69 insertions(+), 40 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 
5573844a571d..57331d749267 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -75,6 +75,14 @@ public abstract class AbstractWindmillStream implements Win protected final AtomicBoolean clientClosed; protected final Sleeper sleeper; + + /** + * Used to guard {@link #start()} and {@link #shutdown()} behavior. + * + * @implNote Should not be held when performing IO. + */ + protected final Object shutdownLock = new Object(); + private final AtomicLong lastSendTimeMs; private final ExecutorService executor; private final BackOff backoff; @@ -90,10 +98,6 @@ public abstract class AbstractWindmillStream implements Win private final int logEveryNStreamFailures; private final String backendWorkerToken; private final ResettableRequestObserver requestObserver; - - /** Guards {@link #start()} and {@link #shutdown()} methods. 
*/ - private final Object shutdownLock = new Object(); - private final AtomicReference shutdownTime; /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 5d37c98f11cf..826fb13b4db9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -234,7 +234,13 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) { .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { - pending.put(id, pendingRequest); + synchronized (shutdownLock) { + if (!isShutdown()) { + pending.put(id, pendingRequest); + } else { + return; + } + } try { send(chunk); } catch (IllegalStateException e) { @@ -260,7 +266,13 @@ private void issueBatchedRequest(Map requests) { } StreamingCommitWorkRequest request = requestBuilder.build(); synchronized (this) { - pending.putAll(requests); + synchronized (shutdownLock) { + if (!isShutdown()) { + pending.putAll(requests); + } else { + return; + } + } try { send(request); } catch (IllegalStateException e) { @@ -273,7 +285,13 @@ private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest checkNotNull(pendingRequest.computationId()); final ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { - pending.put(id, pendingRequest); + synchronized (shutdownLock) { + if (!isShutdown()) { + pending.put(id, pendingRequest); + } else { + return; + } + } for (int i = 0; i < serializedCommit.size(); i += 
AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 04ac533b585b..3b0bea713466 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -56,7 +56,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.VerifyException; import org.joda.time.Instant; import org.slf4j.Logger; @@ -79,7 +78,6 @@ final class GrpcGetDataStream // newer ComputationHeartbeatRequests. private final boolean sendKeyedGetDataRequests; private final Consumer> processHeartbeatResponses; - private final Object shutdownLock = new Object(); private GrpcGetDataStream( String backendWorkerToken, @@ -298,16 +296,14 @@ public void sendHealthCheck() { protected void shutdownInternal() { // Stream has been explicitly closed. Drain pending input streams and request batches. // Future calls to send RPCs will fail. 
- synchronized (shutdownLock) { - pending.values().forEach(AppendableInputStream::cancel); - pending.clear(); - batches.forEach( - batch -> { - batch.markFinalized(); - batch.notifyFailed(); - }); - batches.clear(); - } + pending.values().forEach(AppendableInputStream::cancel); + pending.clear(); + batches.forEach( + batch -> { + batch.markFinalized(); + batch.notifyFailed(); + }); + batches.clear(); } @Override @@ -363,6 +359,9 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn requests) { // Synchronization of pending inserts is necessary with send to ensure duplicates are not // sent on stream reconnect. synchronized (shutdownLock) { + // shutdown() clears pending, once the stream is shutdown, prevent values from being added + // to it. + if (isShutdown()) { + throw new WindmillStreamShutdownException( + "Stream was closed when attempting to send " + requests.size() + " requests."); + } for (QueuedRequest request : requests) { - Preconditions.checkState(!isShutdown(), "Cannot send on shutdown stream."); // Map#put returns null if there was no previous mapping for the key, meaning we have not // seen it before. 
verify(pending.put(request.id(), request.getResponseStream()) == null); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java index fd2b30238836..8a37958700c9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -127,26 +127,29 @@ private GrpcDirectGetWorkStream createGetWorkStream( ThrottleTimer throttleTimer, WorkItemScheduler workItemScheduler) { serviceRegistry.addService(testStub); - return (GrpcDirectGetWorkStream) - GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) - .build() - .createDirectGetWorkStream( - WindmillConnection.builder() - .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) - .build(), - Windmill.GetWorkRequest.newBuilder() - .setClientId(TEST_JOB_HEADER.getClientId()) - .setJobId(TEST_JOB_HEADER.getJobId()) - .setProjectId(TEST_JOB_HEADER.getProjectId()) - .setWorkerId(TEST_JOB_HEADER.getWorkerId()) - .setMaxItems(initialGetWorkBudget.items()) - .setMaxBytes(initialGetWorkBudget.bytes()) - .build(), - throttleTimer, - mock(HeartbeatSender.class), - mock(GetDataClient.class), - mock(WorkCommitter.class), - workItemScheduler); + GrpcDirectGetWorkStream getWorkStream = + (GrpcDirectGetWorkStream) + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) + .build() + .createDirectGetWorkStream( + WindmillConnection.builder() + .setStub(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel)) + .build(), + Windmill.GetWorkRequest.newBuilder() + .setClientId(TEST_JOB_HEADER.getClientId()) + 
.setJobId(TEST_JOB_HEADER.getJobId()) + .setProjectId(TEST_JOB_HEADER.getProjectId()) + .setWorkerId(TEST_JOB_HEADER.getWorkerId()) + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build(), + throttleTimer, + mock(HeartbeatSender.class), + mock(GetDataClient.class), + mock(WorkCommitter.class), + workItemScheduler); + getWorkStream.start(); + return getWorkStream; } private Windmill.StreamingGetWorkResponseChunk createResponse(Windmill.WorkItem workItem) { From 9679382e488d6b952436bb7edd7735193d7e5fc3 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Tue, 22 Oct 2024 11:59:05 -0700 Subject: [PATCH 08/23] address PR comments around deadlocking, move WindmillStreamShutdownException to its own top level class --- .../client/AbstractWindmillStream.java | 31 ++---- .../WindmillStreamShutdownException.java | 25 +++++ .../client/getdata/StreamGetDataClient.java | 6 +- .../client/grpc/GrpcCommitWorkStream.java | 95 ++++++++++++------- .../client/grpc/GrpcGetDataStream.java | 51 ++++++---- .../refresh/FixedStreamHeartbeatSender.java | 4 +- .../client/grpc/GrpcCommitWorkStreamTest.java | 4 +- 7 files changed, 135 insertions(+), 81 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 57331d749267..63d02104b136 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -79,7 
+79,8 @@ public abstract class AbstractWindmillStream implements Win /** * Used to guard {@link #start()} and {@link #shutdown()} behavior. * - * @implNote Should not be held when performing IO. + * @implNote Do not hold when performing IO. If also locking on {@code this} in the same context, + * should acquire shutdownLock first to prevent deadlocks. */ protected final Object shutdownLock = new Object(); @@ -184,15 +185,6 @@ protected boolean isShutdown() { return isShutdown; } - private StreamObserver requestObserver() { - if (requestObserver == null) { - throw new NullPointerException( - "requestObserver cannot be null. Missing a call to start() to initialize stream."); - } - - return requestObserver; - } - /** Send a request to the server. */ protected final void send(RequestT request) { synchronized (this) { @@ -221,14 +213,17 @@ protected final void send(RequestT request) { @Override public final void start() { + boolean shouldStartStream = false; synchronized (shutdownLock) { if (!isShutdown && !started) { - // start() should only be executed once during the lifetime of the stream for idempotency - // and when shutdown() has not been called. - startStream(); started = true; + shouldStartStream = true; } } + + if (shouldStartStream) { + startStream(); + } } /** Starts the underlying stream. 
*/ @@ -366,8 +361,8 @@ public final void shutdown() { if (!isShutdown) { isShutdown = true; shutdownTime.set(DateTime.now()); - requestObserver() - .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + requestObserver.onError( + new WindmillStreamShutdownException("Explicit call to shutdown stream.")); shutdownInternal(); } } @@ -380,12 +375,6 @@ private void recordRestartReason(String error) { protected abstract void shutdownInternal(); - public static class WindmillStreamShutdownException extends RuntimeException { - public WindmillStreamShutdownException(String message) { - super(message); - } - } - /** * Request observer that allows resetting its internal delegate using the given {@link * #requestObserverSupplier}. diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java new file mode 100644 index 000000000000..5f4387d6111f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +/** Thrown when operations are requested on a {@link WindmillStream} that has been shut down/closed. */ +public final class WindmillStreamShutdownException extends RuntimeException { + public WindmillStreamShutdownException(String message) { + super(message); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java index c8e058e7e230..ab12946ad18b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -21,8 +21,8 @@ import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.sdk.annotations.Internal; /** {@link GetDataClient} that fetches data directly from a specific
{@link GetDataStream}. */ @@ -61,7 +61,7 @@ public Windmill.KeyedGetDataResponse getStateData( String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException { try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { return getDataStream.requestKeyedData(computationId, request); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { throw new WorkItemCancelledException(request.getShardingKey()); } catch (Exception e) { throw new GetDataException( @@ -86,7 +86,7 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) sideInputGetDataStreamFactory.apply(request.getDataId().getTag()); try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return sideInputGetDataStream.requestGlobalData(request); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { throw new WorkItemCancelledException(e); } catch (Exception e) { throw new GetDataException( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 826fb13b4db9..e4c8947b38fd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -233,19 +233,16 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) { .setShardingKey(pendingRequest.shardingKey()) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); - 
synchronized (this) { - synchronized (shutdownLock) { - if (!isShutdown()) { - pending.put(id, pendingRequest); - } else { - return; - } - } - try { - send(chunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - } + if (shouldCancelRequest(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + + try { + send(chunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + } } @@ -265,33 +262,28 @@ private void issueBatchedRequest(Map requests) { .setSerializedWorkItemCommit(request.serializedCommit()); } StreamingCommitWorkRequest request = requestBuilder.build(); - synchronized (this) { - synchronized (shutdownLock) { - if (!isShutdown()) { - pending.putAll(requests); - } else { - return; - } - } - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - } + + if (shouldCancelRequest(requests)) { + requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); + return; + } + + try { + send(request); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. 
} } - private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) { + private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { checkNotNull(pendingRequest.computationId()); - final ByteString serializedCommit = pendingRequest.serializedCommit(); + ByteString serializedCommit = pendingRequest.serializedCommit(); + if (shouldCancelRequest(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + synchronized (this) { - synchronized (shutdownLock) { - if (!isShutdown()) { - pending.put(id, pendingRequest); - } else { - return; - } - } for (int i = 0; i < serializedCommit.size(); i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { @@ -321,6 +313,32 @@ private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest } } + private boolean shouldCancelRequest(long id, PendingRequest request) { + synchronized (shutdownLock) { + synchronized (this) { + if (!isShutdown()) { + pending.put(id, request); + return false; + } + + return true; + } + } + } + + private boolean shouldCancelRequest(Map requests) { + synchronized (shutdownLock) { + synchronized (this) { + if (!isShutdown()) { + pending.putAll(requests); + return false; + } + + return true; + } + } + } + @AutoValue abstract static class PendingRequest { @@ -402,6 +420,11 @@ private Batcher() { @Override public boolean commitWorkItem( String computation, WorkItemCommitRequest commitRequest, Consumer onDone) { + if (isShutdown()) { + onDone.accept(CommitStatus.ABORTED); + return false; + } + if (!canAccept(commitRequest.getSerializedSize() + computation.length()) || isShutdown()) { return false; } @@ -418,7 +441,7 @@ public void flush() { if (!isShutdown()) { flushInternal(queue); } else { - queue.forEach((ignored, request) -> request.onDone().accept(CommitStatus.ABORTED)); + queue.forEach((ignored, request) -> request.abort()); } } finally { queuedBytes = 0; diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 3b0bea713466..c2aa965de7c3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -50,6 +50,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedBatch; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; @@ -198,7 +199,8 @@ public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataReq return issueRequest( QueuedRequest.forComputation(uniqueId(), computation, request), KeyedGetDataResponse::parseFrom); - } catch (WindmillStreamShutdownException e) { + } catch ( + org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException e) { throw new WorkItemCancelledException(request.getShardingKey()); } } @@ -207,7 +209,8 @@ public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataReq public GlobalData requestGlobalData(GlobalDataRequest request) { try { return issueRequest(QueuedRequest.global(uniqueId(), request), 
GlobalData::parseFrom); - } catch (WindmillStreamShutdownException e) { + } catch ( + org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException e) { throw new WorkItemCancelledException( "SideInput fetch failed for request due to stream shutdown: " + request, e); } @@ -216,7 +219,8 @@ public GlobalData requestGlobalData(GlobalDataRequest request) { @Override public void refreshActiveWork(Map> heartbeats) { if (isShutdown()) { - throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); + throw new org.apache.beam.runners.dataflow.worker.windmill.client + .WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); } StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); @@ -354,18 +358,24 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn requests) { + if (requests.isEmpty()) { + return; + } + StreamingGetDataRequest batchedRequest = flushToBatch(requests); - synchronized (this) { + synchronized (shutdownLock) { // Synchronization of pending inserts is necessary with send to ensure duplicates are not // sent on stream reconnect. - synchronized (shutdownLock) { + synchronized (this) { // shutdown() clears pending, once the stream is shutdown, prevent values from being added // to it. if (isShutdown()) { @@ -448,12 +462,13 @@ private void sendBatch(List requests) { verify(pending.put(request.id(), request.getResponseStream()) == null); } } - try { - send(batchedRequest); - } catch (IllegalStateException e) { - // The stream broke before this call went through; onNewStream will retry the fetch. - LOG.warn("GetData stream broke before call started.", e); - } + } + + try { + send(batchedRequest); + } catch (IllegalStateException e) { + // The stream broke before this call went through; onNewStream will retry the fetch. 
+ LOG.warn("GetData stream broke before call started.", e); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java index 33a55d1927f8..ed5f2db7f480 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java @@ -20,8 +20,8 @@ import java.util.Objects; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.sdk.annotations.Internal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,7 +61,7 @@ public void sendHeartbeats(Heartbeats heartbeats) { Thread.currentThread().setName(originalThreadName + "-" + backendWorkerToken); } getDataStream.refreshActiveWork(heartbeats.heartbeatRequests().asMap()); - } catch (AbstractWindmillStream.WindmillStreamShutdownException e) { + } catch (WindmillStreamShutdownException e) { LOG.warn( "Trying to refresh work w/ {} heartbeats on stream={} after work has moved off of worker." 
+ " heartbeats", diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 5a281a06dc28..1ee16d774b0d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -158,10 +158,12 @@ public void testCommitWorkItem_afterShutdownFalse() { } commitWorkStream.shutdown(); + Set commitStatuses = new HashSet<>(); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { assertFalse( - batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), ignored -> {})); + batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add)); + assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED); } } } From 6e5ba0c686e216b16e326b97e9970eaa2d59401e Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Wed, 23 Oct 2024 16:43:25 -0700 Subject: [PATCH 09/23] address PR comments --- .../harness/GlobalDataStreamSender.java | 2 +- .../harness/WindmillStreamSender.java | 14 +-- .../client/AbstractWindmillStream.java | 16 +++- .../client/grpc/GrpcCommitWorkStream.java | 41 +++++---- .../client/grpc/GrpcGetDataStream.java | 91 ++++++++----------- .../grpc/GrpcGetDataStreamRequests.java | 62 +++++++++---- .../client/grpc/GrpcWindmillServer.java | 12 +++ .../grpc/GrpcWindmillStreamFactory.java | 23 ++--- .../grpc/WindmillStreamShutdownException.java | 35 ------- .../harness/WindmillStreamSenderTest.java | 56 ++++++++++++ .../StreamingEngineWorkCommitterTest.java | 26 ------ 
.../client/grpc/GrpcCommitWorkStreamTest.java | 15 ++- .../client/grpc/GrpcGetDataStreamTest.java | 6 +- .../grpc/GrpcGetWorkerMetadataStreamTest.java | 21 ++--- .../client/grpc/GrpcWindmillServerTest.java | 21 +++-- 15 files changed, 244 insertions(+), 197 deletions(-) delete mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java index 635482345807..d590e69c17d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/GlobalDataStreamSender.java @@ -48,8 +48,8 @@ GetDataStream stream() { private synchronized void startStream() { // Check started again after we acquire the lock. 
if (!started) { - started = true; delegate.start(); + started = true; } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 2a363885c7c5..89aaa0d8b640 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -125,6 +127,7 @@ private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkB synchronized void start() { if (!started.get()) { + checkState(!streamStarter.isShutdown(), "WindmillStreamSender has already been shutdown."); // Start these 3 streams in parallel since they each may perform blocking IO. 
CompletableFuture.allOf( CompletableFuture.runAsync(getWorkStream::start, streamStarter), @@ -138,12 +141,11 @@ synchronized void start() { @Override public synchronized void close() { - if (started.get()) { - getWorkStream.shutdown(); - getDataStream.shutdown(); - workCommitter.stop(); - commitWorkStream.shutdown(); - } + streamStarter.shutdownNow(); + getWorkStream.shutdown(); + getDataStream.shutdown(); + workCommitter.stop(); + commitWorkStream.shutdown(); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 63d02104b136..ca54ade05f0f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -65,6 +65,9 @@ * block. This is generally not a problem since streams are used in a single-threaded manner. * However, some accessors used for status page and other debugging need to take care not to require * synchronizing on this. + * + *

{@link #start()} and {@link #shutdown()} are called once in the lifetime of the stream. Once + * {@link #shutdown()}, a stream is considered invalid and cannot be restarted/reused. */ public abstract class AbstractWindmillStream implements WindmillStream { @@ -361,9 +364,13 @@ public final void shutdown() { if (!isShutdown) { isShutdown = true; shutdownTime.set(DateTime.now()); - requestObserver.onError( - new WindmillStreamShutdownException("Explicit call to shutdown stream.")); - shutdownInternal(); + if (started) { + // requestObserver is not set until the first startStream() is called. If the stream was + // never started there is nothing to clean up internally. + requestObserver.onError( + new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + shutdownInternal(); + } } } } @@ -378,6 +385,9 @@ private void recordRestartReason(String error) { /** * Request observer that allows resetting its internal delegate using the given {@link * #requestObserverSupplier}. + * + * @implNote {@link StreamObserver}s generated by {@link * #requestObserverSupplier} are expected + * to be {@link ThreadSafe}.
*/ @ThreadSafe private static class ResettableRequestObserver implements StreamObserver { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index e4c8947b38fd..f49da5a4f001 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -225,6 +225,11 @@ private void flushInternal(Map requests) { } private void issueSingleRequest(long id, PendingRequest pendingRequest) { + if (isPrepareForSendFailed(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); requestBuilder .addCommitChunkBuilder() @@ -233,11 +238,6 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) { .setShardingKey(pendingRequest.shardingKey()) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); - if (shouldCancelRequest(id, pendingRequest)) { - pendingRequest.abort(); - return; - } - try { send(chunk); } catch (IllegalStateException e) { @@ -247,6 +247,11 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) { } private void issueBatchedRequest(Map requests) { + if (isPrepareForSendFailed(requests)) { + requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); + return; + } + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); String lastComputation = null; for (Map.Entry entry : requests.entrySet()) { @@ -262,12 +267,6 @@ private void 
issueBatchedRequest(Map requests) { .setSerializedWorkItemCommit(request.serializedCommit()); } StreamingCommitWorkRequest request = requestBuilder.build(); - - if (shouldCancelRequest(requests)) { - requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); - return; - } - try { send(request); } catch (IllegalStateException e) { @@ -276,13 +275,13 @@ private void issueBatchedRequest(Map requests) { } private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { - checkNotNull(pendingRequest.computationId()); - ByteString serializedCommit = pendingRequest.serializedCommit(); - if (shouldCancelRequest(id, pendingRequest)) { + if (isPrepareForSendFailed(id, pendingRequest)) { pendingRequest.abort(); return; } + checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); + ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { for (int i = 0; i < serializedCommit.size(); @@ -313,7 +312,11 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { } } - private boolean shouldCancelRequest(long id, PendingRequest request) { + /** + * Returns true if the request should be failed due to stream shutdown, else tracks the request to + * be sent and returns false. + */ + private boolean isPrepareForSendFailed(long id, PendingRequest request) { synchronized (shutdownLock) { synchronized (this) { if (!isShutdown()) { @@ -326,7 +329,11 @@ private boolean shouldCancelRequest(long id, PendingRequest request) { } } - private boolean shouldCancelRequest(Map requests) { + /** + * Returns true if the request should be failed due to stream shutdown, else tracks the requests + * to be sent and returns false. 
+ */ + private boolean isPrepareForSendFailed(Map requests) { synchronized (shutdownLock) { synchronized (this) { if (!isShutdown()) { @@ -425,7 +432,7 @@ public boolean commitWorkItem( return false; } - if (!canAccept(commitRequest.getSerializedSize() + computation.length()) || isShutdown()) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { return false; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index c2aa965de7c3..a840d4dfbcfd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verify; import java.io.IOException; import java.io.InputStream; @@ -35,7 +34,7 @@ import java.util.function.Consumer; import java.util.function.Function; import javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; +import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatRequest; @@ -57,11 +56,12 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import 
org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.VerifyException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +@ThreadSafe final class GrpcGetDataStream extends AbstractWindmillStream implements GetDataStream { @@ -153,7 +153,7 @@ protected synchronized void onNewStream() { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. - verify(!hasPendingRequests()); + verify(!hasPendingRequests(), "Pending requests not expected on stream restart."); } else { for (AppendableInputStream responseStream : pending.values()) { responseStream.cancel(); @@ -167,7 +167,6 @@ protected boolean hasPendingRequests() { } @Override - @SuppressWarnings("dereference.of.nullable") protected void onResponse(StreamingGetDataResponse chunk) { checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount()); checkArgument(chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1); @@ -195,32 +194,20 @@ private long uniqueId() { @Override public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) { - try { - return issueRequest( - QueuedRequest.forComputation(uniqueId(), computation, request), - KeyedGetDataResponse::parseFrom); - } catch ( - org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException e) { - throw new WorkItemCancelledException(request.getShardingKey()); - } + return issueRequest( + QueuedRequest.forComputation(uniqueId(), computation, request), + KeyedGetDataResponse::parseFrom); } @Override public GlobalData requestGlobalData(GlobalDataRequest request) { - try { - return 
issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); - } catch ( - org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException e) { - throw new WorkItemCancelledException( - "SideInput fetch failed for request due to stream shutdown: " + request, e); - } + return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); } @Override public void refreshActiveWork(Map> heartbeats) { if (isShutdown()) { - throw new org.apache.beam.runners.dataflow.worker.windmill.client - .WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); + throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); } StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); @@ -302,12 +289,14 @@ protected void shutdownInternal() { // Future calls to send RPCs will fail. pending.values().forEach(AppendableInputStream::cancel); pending.clear(); - batches.forEach( - batch -> { - batch.markFinalized(); - batch.notifyFailed(); - }); - batches.clear(); + synchronized (batches) { + batches.forEach( + batch -> { + batch.markFinalized(); + batch.notifyFailed(); + }); + batches.clear(); + } } @Override @@ -340,9 +329,7 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn ResponseT issueRequest(QueuedRequest request, ParseFn= streamingRpcBatchLimit + || batch.requestsCount() >= streamingRpcBatchLimit || batch.byteSize() + request.byteSize() > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { if (batch != null) { prevBatch = batch; @@ -411,7 +389,7 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept // Finalize the batch so that no additional requests will be added. Leave the batch in the // queue so that a subsequent batch will wait for its completion. 
synchronized (batches) { - verify(batch == batches.peekFirst()); + verify(batch == batches.peekFirst(), "GetDataStream request batch removed before send()."); batch.markFinalized(); } trySendBatch(batch); @@ -423,9 +401,11 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept void trySendBatch(QueuedBatch batch) { try { - sendBatch(batch.requests()); + sendBatch(batch.sortedRequestsReadOnly()); synchronized (batches) { - verify(batch == batches.pollFirst()); + verify( + batch == batches.pollFirst(), + "Sent GetDataStream request batch removed before send() was complete."); } // Notify all waiters with requests in this batch as well as the sender // of the next batch (if one exists). @@ -439,7 +419,6 @@ void trySendBatch(QueuedBatch batch) { } } - @SuppressWarnings("NullableProblems") private void sendBatch(List requests) { if (requests.isEmpty()) { return; @@ -459,7 +438,9 @@ private void sendBatch(List requests) { for (QueuedRequest request : requests) { // Map#put returns null if there was no previous mapping for the key, meaning we have not // seen it before. - verify(pending.put(request.id(), request.getResponseStream()) == null); + verify( + pending.put(request.id(), request.getResponseStream()) == null, + "Request already sent."); } } } @@ -472,11 +453,7 @@ private void sendBatch(List requests) { } } - @SuppressWarnings("argument") private StreamingGetDataRequest flushToBatch(List requests) { - // Put all global data requests first because there is only a single repeated field for - // request ids and the initial ids correspond to global data requests if they are present. 
- requests.sort(QueuedRequest.globalRequestsFirst()); StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); for (QueuedRequest request : requests) { request.addToStreamingGetDataRequest(builder); @@ -484,6 +461,10 @@ private StreamingGetDataRequest flushToBatch(List requests) { return builder.build(); } + private void verify(boolean condition, String message) { + Verify.verify(condition || isShutdown(), message); + } + @FunctionalInterface private interface ParseFn { ResponseT parse(InputStream input) throws IOException; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java index 53dd4a5b5294..27c960448a23 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java @@ -21,6 +21,7 @@ import com.google.auto.value.AutoOneOf; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -29,6 +30,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,6 +42,10 @@ final class 
GrpcGetDataStreamRequests { private GrpcGetDataStreamRequests() {} + private static String debugFormat(long value) { + return String.format("%016x", value); + } + static class QueuedRequest { private final long id; private final ComputationOrGlobalDataRequest dataRequest; @@ -104,6 +110,7 @@ void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder build } } + /** Represents a batch of queued requests. Methods are thread-safe unless commented otherwise. */ static class QueuedBatch { private final List requests = new ArrayList<>(); private final CountDownLatch sent = new CountDownLatch(1); @@ -111,8 +118,18 @@ static class QueuedBatch { private volatile boolean finalized = false; private volatile boolean failed = false; - List requests() { - return requests; + /** + * Returns a read-only view of requests sorted with {@link QueuedRequest#globalRequestsFirst()}. + */ + List sortedRequestsReadOnly() { + // Put all global data requests first because there is only a single repeated field for + // request ids and the initial ids correspond to global data requests if they are present. + requests.sort(QueuedRequest.globalRequestsFirst()); + return Collections.unmodifiableList(requests); + } + + int requestsCount() { + return requests.size(); } long byteSize() { @@ -127,42 +144,51 @@ void markFinalized() { finalized = true; } + /** + * Adds a request to the batch. + * + * @implNote Requires external synchronization to be thread safe. + */ void addRequest(QueuedRequest request) { requests.add(request); byteSize += request.byteSize(); } /** Let waiting for threads know that the request has been successfully sent. */ - synchronized void notifySent() { + void notifySent() { sent.countDown(); } /** Let waiting for threads know that a failure occurred. 
*/ - synchronized void notifyFailed() { + void notifyFailed() { failed = true; sent.countDown(); } /** * Block until notified of a successful send via {@link #notifySent()} or a non-retryable - * failure via {@link #notifyFailed()}. On failure, throw an exception to on calling threads. + * failure via {@link #notifyFailed()}. On failure, throw an exception for waiters. */ void waitForSendOrFailNotification() throws InterruptedException { sent.await(); if (failed) { - ImmutableList cancelledRequests = createStreamCancelledErrorMessage(); - LOG.error("Requests failed for the following batches: {}", cancelledRequests); - throw new WindmillStreamShutdownException( - "Requests failed for batch containing " - + String.join(", ", cancelledRequests) - + " ... requests. This is most likely due to the stream being explicitly closed" - + " which happens when the work is marked as invalid on the streaming" - + " backend when key ranges shuffle around. This is transient and corresponding" - + " work will eventually be retried."); + ImmutableList cancelledRequests = createStreamCancelledErrorMessages(); + if (!cancelledRequests.isEmpty()) { + LOG.error("Requests failed for the following batches: {}", cancelledRequests); + throw new WindmillStreamShutdownException( + "Requests failed for batch containing " + + String.join(", ", cancelledRequests) + + " ... requests. This is most likely due to the stream being explicitly closed" + + " which happens when the work is marked as invalid on the streaming" + + " backend when key ranges shuffle around. 
This is transient and corresponding" + + " work will eventually be retried."); + } + + throw new WindmillStreamShutdownException("Stream was shutdown while waiting for send."); } } - ImmutableList createStreamCancelledErrorMessage() { + ImmutableList createStreamCancelledErrorMessages() { return requests.stream() .flatMap( request -> { @@ -175,11 +201,11 @@ ImmutableList createStreamCancelledErrorMessage() { keyedRequest -> "KeyedGetState=[" + "shardingKey=" - + keyedRequest.getShardingKey() + + debugFormat(keyedRequest.getShardingKey()) + "cacheToken=" - + keyedRequest.getCacheToken() + + debugFormat(keyedRequest.getCacheToken()) + "workToken" - + keyedRequest.getWorkToken() + + debugFormat(keyedRequest.getWorkToken()) + "]"); default: // Will never happen switch is exhaustive. diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java index 1a105185b8a9..770cd616a5ac 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java @@ -330,6 +330,10 @@ public CommitWorkResponse commitWork(CommitWorkRequest request) { throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("CommitWork")); } + /** + * @implNote Returns a {@link GetWorkStream} in the started state (w/ the initial header already + * sent). 
+ */ @Override public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { GetWorkStream getWorkStream = @@ -346,6 +350,10 @@ public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver rece return getWorkStream; } + /** + * @implNote Returns a {@link GetDataStream} in the started state (w/ the initial header already + * sent). + */ @Override public GetDataStream getDataStream() { GetDataStream getDataStream = @@ -355,6 +363,10 @@ public GetDataStream getDataStream() { return getDataStream; } + /** + * @implNote Returns a {@link CommitWorkStream} in the started state (w/ the initial header + * already sent). + */ @Override public CommitWorkStream commitWorkStream() { CommitWorkStream commitWorkStream = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index 02c238fa3825..af096d5441b4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -57,7 +57,9 @@ import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.AbstractStub; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.joda.time.Duration; import org.joda.time.Instant; @@ -98,8 +100,7 @@ private GrpcWindmillStreamFactory( int 
windmillMessagesBetweenIsReadyChecks, boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses, - Supplier maxBackOffSupplier, - Set> streamRegistry) { + Supplier maxBackOffSupplier) { this.jobHeader = jobHeader; this.logEveryNStreamFailures = logEveryNStreamFailures; this.streamingRpcBatchLimit = streamingRpcBatchLimit; @@ -112,7 +113,7 @@ private GrpcWindmillStreamFactory( .withInitialBackoff(MIN_BACKOFF) .withMaxBackoff(maxBackOffSupplier.get()) .backoff()); - this.streamRegistry = streamRegistry; + this.streamRegistry = ConcurrentHashMap.newKeySet(); this.sendKeyedGetDataRequests = sendKeyedGetDataRequests; this.processHeartbeatResponses = processHeartbeatResponses; this.streamIdGenerator = new AtomicLong(); @@ -127,8 +128,7 @@ static GrpcWindmillStreamFactory create( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses, Supplier maxBackOffSupplier, - int healthCheckIntervalMillis, - Set> streamRegistry) { + int healthCheckIntervalMillis) { GrpcWindmillStreamFactory streamFactory = new GrpcWindmillStreamFactory( jobHeader, @@ -137,8 +137,7 @@ static GrpcWindmillStreamFactory create( windmillMessagesBetweenIsReadyChecks, sendKeyedGetDataRequests, processHeartbeatResponses, - maxBackOffSupplier, - streamRegistry); + maxBackOffSupplier); if (healthCheckIntervalMillis >= 0) { // Health checks are run on background daemon thread, which will only be cleaned up on JVM @@ -175,8 +174,7 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) { .setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT) .setHealthCheckIntervalMillis(NO_HEALTH_CHECKS) .setSendKeyedGetDataRequests(true) - .setProcessHeartbeatResponses(ignored -> {}) - .setStreamRegistry(ConcurrentHashMap.newKeySet()); + .setProcessHeartbeatResponses(ignored -> {}); } private static > T withDefaultDeadline(T stub) { @@ -333,6 +331,11 @@ public void appendSummaryHtml(PrintWriter writer) { .forEach((workerToken, streams) -> 
printSummaryHtmlForWorker(workerToken, streams, writer)); } + @VisibleForTesting + ImmutableSet> streamRegistry() { + return ImmutableSet.copyOf(streamRegistry); + } + @Internal @AutoBuilder(callMethod = "create") public interface Builder { @@ -353,8 +356,6 @@ Builder setProcessHeartbeatResponses( Builder setHealthCheckIntervalMillis(int healthCheckIntervalMillis); - Builder setStreamRegistry(Set> streamRegistry); - GrpcWindmillStreamFactory build(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java deleted file mode 100644 index 0c146e298be0..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/WindmillStreamShutdownException.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; - -import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; -import org.apache.beam.sdk.annotations.Internal; - -/** - * Indicates that a {@link WindmillStream#shutdown()} was called while waiting for some internal - * operation to complete. Most common use of this exception should be conversion to a {@link - * org.apache.beam.runners.dataflow.worker.WorkItemCancelledException} as the {@link - * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} being processed by {@link - * WindmillStream}. - */ -@Internal -final class WindmillStreamShutdownException extends RuntimeException { - WindmillStreamShutdownException(String message) { - super(message); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java index 647d45adebb4..aa2767d5472d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSenderTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.streaming.harness; +import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -226,6 +227,61 @@ public void testCloseAllStreams_closesAllStreams() { verify(mockCommitWorkStream).shutdown(); } + @Test + public void testCloseAllStreams_doesNotStartStreamsAfterClose() { + long itemBudget = 1L; + long byteBudget = 1L; + GetWorkRequest getWorkRequestWithBudget = + 
GET_WORK_REQUEST.toBuilder().setMaxItems(itemBudget).setMaxBytes(byteBudget).build(); + GrpcWindmillStreamFactory mockStreamFactory = mock(GrpcWindmillStreamFactory.class); + GetWorkStream mockGetWorkStream = mock(GetWorkStream.class); + GetDataStream mockGetDataStream = mock(GetDataStream.class); + CommitWorkStream mockCommitWorkStream = mock(CommitWorkStream.class); + + when(mockStreamFactory.createDirectGetWorkStream( + eq(connection), + eq(getWorkRequestWithBudget), + any(ThrottleTimer.class), + any(), + any(), + any(), + eq(workItemScheduler))) + .thenReturn(mockGetWorkStream); + + when(mockStreamFactory.createDirectGetDataStream(eq(connection), any(ThrottleTimer.class))) + .thenReturn(mockGetDataStream); + when(mockStreamFactory.createDirectCommitWorkStream(eq(connection), any(ThrottleTimer.class))) + .thenReturn(mockCommitWorkStream); + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build(), + mockStreamFactory); + + windmillStreamSender.close(); + + verify(mockGetWorkStream, times(0)).start(); + verify(mockGetDataStream, times(0)).start(); + verify(mockCommitWorkStream, times(0)).start(); + + verify(mockGetWorkStream).shutdown(); + verify(mockGetDataStream).shutdown(); + verify(mockCommitWorkStream).shutdown(); + } + + @Test + public void testStartStream_afterCloseThrows() { + long itemBudget = 1L; + long byteBudget = 1L; + + WindmillStreamSender windmillStreamSender = + newWindmillStreamSender( + GetWorkBudget.builder().setBytes(byteBudget).setItems(itemBudget).build()); + + windmillStreamSender.close(); + assertThrows(IllegalStateException.class, windmillStreamSender::start); + } + private WindmillStreamSender newWindmillStreamSender(GetWorkBudget budget) { return newWindmillStreamSender(budget, streamFactory); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 6bd08706cd50..8bb057156d20 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -53,16 +53,10 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessChannelBuilder; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.inprocess.InProcessServerBuilder; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -73,15 +67,12 @@ @RunWith(JUnit4.class) public class StreamingEngineWorkCommitterTest { - private static final String FAKE_SERVER_NAME = "Fake server for StreamingEngineWorkCommitterTest"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private final MutableHandlerRegistry serviceRegistry = new 
MutableHandlerRegistry(); @Rule public ErrorCollector errorCollector = new ErrorCollector(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private WorkCommitter workCommitter; private FakeWindmillServer fakeWindmillServer; private Supplier> commitWorkStreamFactory; - private ManagedChannel inProcessChannel; private static Work createMockWork(long workToken) { return Work.create( @@ -129,23 +120,6 @@ public void setUp() throws IOException { WindmillStreamPool.create( 1, Duration.standardMinutes(1), fakeWindmillServer::commitWorkStream) ::getCloseableStream; - Server server = - InProcessServerBuilder.forName(FAKE_SERVER_NAME) - .fallbackHandlerRegistry(serviceRegistry) - .directExecutor() - .build() - .start(); - - inProcessChannel = - grpcCleanup.register( - InProcessChannelBuilder.forName(FAKE_SERVER_NAME).directExecutor().build()); - grpcCleanup.register(server); - grpcCleanup.register(inProcessChannel); - } - - @After - public void cleanUp() { - inProcessChannel.shutdownNow(); } private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 1ee16d774b0d..3baa31585a09 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.spy; import static 
org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -50,6 +51,7 @@ import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.InOrder; @RunWith(JUnit4.class) public class GrpcCommitWorkStreamTest { @@ -158,9 +160,9 @@ public void testCommitWorkItem_afterShutdownFalse() { } commitWorkStream.shutdown(); - Set commitStatuses = new HashSet<>(); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { + Set commitStatuses = new HashSet<>(); assertFalse( batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add)); assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED); @@ -175,9 +177,10 @@ public void testSend_notCalledAfterShutdown() { TestCommitWorkStreamRequestObserver requestObserver = spy(new TestCommitWorkStreamRequestObserver()); + InOrder requestObserverVerifier = inOrder(requestObserver); + CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); - try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { assertTrue( @@ -186,12 +189,18 @@ public void testSend_notCalledAfterShutdown() { workItemCommitRequest(i), commitStatus -> commitProcessed.countDown())); } + // Shutdown the stream before we exit the try-with-resources block which will try to send() + // the batched request. commitWorkStream.shutdown(); } // send() uses the requestObserver to send requests. We expect 1 send since startStream() sends // the header, which happens before we shutdown. 
- verify(requestObserver, times(1)).onNext(any(Windmill.StreamingCommitWorkRequest.class)); + requestObserverVerifier + .verify(requestObserver) + .onNext(any(Windmill.StreamingCommitWorkRequest.class)); + requestObserverVerifier.verify(requestObserver).onError(any()); + requestObserverVerifier.verifyNoMoreInteractions(); } private static class TestCommitWorkStreamRequestObserver diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index e4570f5afc9a..f71627160846 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -31,9 +31,9 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; -import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; @@ -103,7 +103,7 @@ private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { } @Test - public void testRequestKeyedData_sendOnShutdownStreamThrowsWorkItemCancelledException() { + public void 
testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdownException() { GetDataStreamTestStub testStub = new GetDataStreamTestStub(new TestGetDataStreamRequestObserver()); GrpcGetDataStream getDataStream = createGetDataStream(testStub); @@ -154,7 +154,7 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWorkItemCancelledExce if (i % 2 == 0) { assertTrue(sendFuture.isCompletedExceptionally()); ExecutionException e = assertThrows(ExecutionException.class, sendFuture::get); - assertThat(e).hasCauseThat().isInstanceOf(WorkItemCancelledException.class); + assertThat(e).hasCauseThat().isInstanceOf(WindmillStreamShutdownException.class); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index a8b828905e5e..40a63ee90d31 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -29,7 +29,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -38,7 +37,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; -import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; import 
org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; @@ -80,10 +78,11 @@ public class GrpcGetWorkerMetadataStreamTest { private static final String FAKE_SERVER_NAME = "Fake server for GrpcGetWorkerMetadataStreamTest"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + private final GrpcWindmillStreamFactory streamFactory = + GrpcWindmillStreamFactory.of(TEST_JOB_HEADER).build(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); private ManagedChannel inProcessChannel; private GrpcGetWorkerMetadataStream stream; - private Set> streamRegistry; private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( GetWorkerMetadataTestStub getWorkerMetadataTestStub, @@ -91,13 +90,10 @@ private GrpcGetWorkerMetadataStream getWorkerMetadataTestStream( serviceRegistry.addService(getWorkerMetadataTestStub); GrpcGetWorkerMetadataStream getWorkerMetadataStream = (GrpcGetWorkerMetadataStream) - GrpcWindmillStreamFactory.of(TEST_JOB_HEADER) - .setStreamRegistry(streamRegistry) - .build() - .createGetWorkerMetadataStream( - () -> CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel), - new ThrottleTimer(), - endpointsConsumer); + streamFactory.createGetWorkerMetadataStream( + () -> CloudWindmillMetadataServiceV1Alpha1Grpc.newStub(inProcessChannel), + new ThrottleTimer(), + endpointsConsumer); getWorkerMetadataStream.start(); return getWorkerMetadataStream; } @@ -122,7 +118,6 @@ public void setUp() throws IOException { .setDirectEndpoint(IPV6_ADDRESS_1) .setBackendWorkerToken("worker_token") .build()); - streamRegistry = ConcurrentHashMap.newKeySet(); } @After @@ -254,9 +249,9 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { 
.putAllGlobalDataEndpoints(GLOBAL_DATA_ENDPOINTS) .build()); - assertTrue(streamRegistry.contains(stream)); + assertTrue(streamFactory.streamRegistry().contains(stream)); stream.halfClose(); - assertFalse(streamRegistry.contains(stream)); + assertFalse(streamFactory.streamRegistry().contains(stream)); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 239e3979a3b7..c3f38a571b76 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -28,6 +28,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -115,6 +116,7 @@ public class GrpcWindmillServerTest { private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServerTest.class); private static final int STREAM_CHUNK_SIZE = 2 << 20; private final long clientId = 10L; + private final Set openedChannels = new HashSet<>(); private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Rule public GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); @@ -131,16 +133,18 @@ public void setUp() throws Exception { @After public void tearDown() throws Exception { server.shutdownNow(); + openedChannels.forEach(ManagedChannel::shutdownNow); } private void startServerAndClient(List experiments) throws Exception { String name = "Fake 
server for " + getClass(); this.server = - InProcessServerBuilder.forName(name) - .fallbackHandlerRegistry(serviceRegistry) - .executor(Executors.newFixedThreadPool(1)) - .build() - .start(); + grpcCleanup.register( + InProcessServerBuilder.forName(name) + .fallbackHandlerRegistry(serviceRegistry) + .executor(Executors.newFixedThreadPool(1)) + .build() + .start()); this.client = GrpcWindmillServer.newTestInstance( @@ -149,7 +153,12 @@ private void startServerAndClient(List experiments) throws Exception { clientId, new FakeWindmillStubFactoryFactory( new FakeWindmillStubFactory( - () -> grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name))))); + () -> { + ManagedChannel channel = + grpcCleanup.register(WindmillChannelFactory.inProcessChannel(name)); + openedChannels.add(channel); + return channel; + }))); } private void maybeInjectError(Stream stream) { From 99f0078759e0ad8097be9ecf52062c8aed0ce273 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Wed, 23 Oct 2024 16:52:40 -0700 Subject: [PATCH 10/23] address PR comments about DirectStreamObserver --- .../grpc/observers/DirectStreamObserver.java | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 4d798e8d18ea..fa9e9e15b440 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -76,6 +76,12 @@ public void onNext(T value) { while (true) { try { synchronized (lock) { + int currentPhase = 
isReadyNotifier.getPhase(); + // Phaser is terminated so don't use the outboundObserver. Since onError and onCompleted + // are synchronized after terminating the phaser if we observe that the phaser is not + // terminated the onNext calls below are guaranteed to not be called on a closed observer. + if (currentPhase < 0) return; + // If we awaited previously and timed out, wait for the same phase. Otherwise we're // careful to observe the phase before observing isReady. if (awaitPhase < 0) { @@ -115,15 +121,16 @@ public void onNext(T value) { } synchronized (lock) { + int currentPhase = isReadyNotifier.getPhase(); + // Phaser is terminated so don't use the outboundObserver. Since onError and onCompleted + // are synchronized after terminating the phaser if we observe that the phaser is not + // terminated the onNext calls below are guaranteed to not be called on a closed observer. + if (currentPhase < 0) return; messagesSinceReady = 0; outboundObserver.onNext(value); return; } } catch (TimeoutException e) { - if (isReadyNotifier.isTerminated()) { - return; - } - totalSecondsWaited += waitSeconds; if (totalSecondsWaited > deadlineSeconds) { String errorMessage = constructStreamCancelledErrorMessage(totalSecondsWaited); From e92786ee6d55dd74af26807b73d4f1f4b210b7b3 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Wed, 23 Oct 2024 18:41:44 -0700 Subject: [PATCH 11/23] address PR comments --- .../client/grpc/GrpcCommitWorkStream.java | 164 ++++++++---------- .../client/grpc/GrpcGetDataStream.java | 22 +-- .../grpc/GrpcGetDataStreamRequests.java | 26 ++- 3 files changed, 104 insertions(+), 108 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index f49da5a4f001..6472b717f8ef 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -225,123 +225,111 @@ private void flushInternal(Map requests) { } private void issueSingleRequest(long id, PendingRequest pendingRequest) { - if (isPrepareForSendFailed(id, pendingRequest)) { + if (prepareForSend(id, pendingRequest)) { + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + requestBuilder + .addCommitChunkBuilder() + .setComputationId(pendingRequest.computationId()) + .setRequestId(id) + .setShardingKey(pendingRequest.shardingKey()) + .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); + StreamingCommitWorkRequest chunk = requestBuilder.build(); + try { + send(chunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + } + } else { pendingRequest.abort(); - return; - } - - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - requestBuilder - .addCommitChunkBuilder() - .setComputationId(pendingRequest.computationId()) - .setRequestId(id) - .setShardingKey(pendingRequest.shardingKey()) - .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); - StreamingCommitWorkRequest chunk = requestBuilder.build(); - try { - send(chunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
- } } private void issueBatchedRequest(Map requests) { - if (isPrepareForSendFailed(requests)) { - requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); - return; - } - - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - String lastComputation = null; - for (Map.Entry entry : requests.entrySet()) { - PendingRequest request = entry.getValue(); - StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.computationId())) { - chunkBuilder.setComputationId(request.computationId()); - lastComputation = request.computationId(); + if (prepareForSend(requests)) { + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + String lastComputation = null; + for (Map.Entry entry : requests.entrySet()) { + PendingRequest request = entry.getValue(); + StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); + if (lastComputation == null || !lastComputation.equals(request.computationId())) { + chunkBuilder.setComputationId(request.computationId()); + lastComputation = request.computationId(); + } + chunkBuilder + .setRequestId(entry.getKey()) + .setShardingKey(request.shardingKey()) + .setSerializedWorkItemCommit(request.serializedCommit()); } - chunkBuilder - .setRequestId(entry.getKey()) - .setShardingKey(request.shardingKey()) - .setSerializedWorkItemCommit(request.serializedCommit()); - } - StreamingCommitWorkRequest request = requestBuilder.build(); - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. + StreamingCommitWorkRequest request = requestBuilder.build(); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. 
+ } + } else { + requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); } } private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { - if (isPrepareForSendFailed(id, pendingRequest)) { - pendingRequest.abort(); - return; - } - - checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); - ByteString serializedCommit = pendingRequest.serializedCommit(); - synchronized (this) { - for (int i = 0; - i < serializedCommit.size(); - i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { - int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; - ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - - StreamingCommitRequestChunk.Builder chunkBuilder = - StreamingCommitRequestChunk.newBuilder() - .setRequestId(id) - .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computationId()) - .setShardingKey(pendingRequest.shardingKey()); - int remaining = serializedCommit.size() - end; - if (remaining > 0) { - chunkBuilder.setRemainingBytesForWorkItem(remaining); - } - - StreamingCommitWorkRequest requestChunk = - StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
- break; + if (prepareForSend(id, pendingRequest)) { + checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); + ByteString serializedCommit = pendingRequest.serializedCommit(); + synchronized (this) { + for (int i = 0; + i < serializedCommit.size(); + i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { + int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; + ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); + + StreamingCommitRequestChunk.Builder chunkBuilder = + StreamingCommitRequestChunk.newBuilder() + .setRequestId(id) + .setSerializedWorkItemCommit(chunk) + .setComputationId(pendingRequest.computationId()) + .setShardingKey(pendingRequest.shardingKey()); + int remaining = serializedCommit.size() - end; + if (remaining > 0) { + chunkBuilder.setRemainingBytesForWorkItem(remaining); + } + + StreamingCommitWorkRequest requestChunk = + StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); + try { + send(requestChunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + break; + } } } + } else { + pendingRequest.abort(); } } - /** - * Returns true if the request should be failed due to stream shutdown, else tracks the request to - * be sent and returns false. - */ - private boolean isPrepareForSendFailed(long id, PendingRequest request) { + /** Returns true if prepare for send succeeded. */ + private boolean prepareForSend(long id, PendingRequest request) { synchronized (shutdownLock) { synchronized (this) { if (!isShutdown()) { pending.put(id, request); - return false; + return true; } - - return true; + return false; } } } - /** - * Returns true if the request should be failed due to stream shutdown, else tracks the requests - * to be sent and returns false. - */ - private boolean isPrepareForSendFailed(Map requests) { + /** Returns true if prepare for send succeeded. 
*/ + private boolean prepareForSend(Map requests) { synchronized (shutdownLock) { synchronized (this) { if (!isShutdown()) { pending.putAll(requests); - return false; + return true; } - - return true; + return false; } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index a840d4dfbcfd..5209cf25e8e4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -69,7 +69,9 @@ final class GrpcGetDataStream private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = StreamingGetDataRequest.newBuilder().build(); + /** @implNote insertion and removal is guarded by {@link #shutdownLock} */ private final Deque batches; + private final Map pending; private final AtomicLong idGenerator; private final ThrottleTimer getDataThrottleTimer; @@ -289,14 +291,12 @@ protected void shutdownInternal() { // Future calls to send RPCs will fail. pending.values().forEach(AppendableInputStream::cancel); pending.clear(); - synchronized (batches) { - batches.forEach( - batch -> { - batch.markFinalized(); - batch.notifyFailed(); - }); - batches.clear(); - } + batches.forEach( + batch -> { + batch.markFinalized(); + batch.notifyFailed(); + }); + batches.clear(); } @Override @@ -363,7 +363,7 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept QueuedBatch batch; boolean responsibleForSend = false; @Nullable QueuedBatch prevBatch = null; - synchronized (batches) { + synchronized (shutdownLock) { batch = batches.isEmpty() ? 
null : batches.getLast(); if (batch == null || batch.isFinalized() @@ -388,7 +388,7 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept } // Finalize the batch so that no additional requests will be added. Leave the batch in the // queue so that a subsequent batch will wait for its completion. - synchronized (batches) { + synchronized (shutdownLock) { verify(batch == batches.peekFirst(), "GetDataStream request batch removed before send()."); batch.markFinalized(); } @@ -402,7 +402,7 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept void trySendBatch(QueuedBatch batch) { try { sendBatch(batch.sortedRequestsReadOnly()); - synchronized (batches) { + synchronized (shutdownLock) { verify( batch == batches.pollFirst(), "Sent GetDataStream request batch removed before send() was complete."); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java index 27c960448a23..e3c327aa4389 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java @@ -110,7 +110,9 @@ void addToStreamingGetDataRequest(Windmill.StreamingGetDataRequest.Builder build } } - /** Represents a batch of queued requests. Methods are thread-safe unless commented otherwise. */ + /** + * Represents a batch of queued requests. Methods are not thread-safe unless commented otherwise. 
+ */ static class QueuedBatch { private final List requests = new ArrayList<>(); private final CountDownLatch sent = new CountDownLatch(1); @@ -144,22 +146,26 @@ void markFinalized() { finalized = true; } - /** - * Adds a request to the batch. - * - * @implNote Requires external synchronization to be thread safe. - */ + /** Adds a request to the batch. */ void addRequest(QueuedRequest request) { requests.add(request); byteSize += request.byteSize(); } - /** Let waiting for threads know that the request has been successfully sent. */ + /** + * Let waiting for threads know that the request has been successfully sent. + * + * @implNote Thread safe. + */ void notifySent() { sent.countDown(); } - /** Let waiting for threads know that a failure occurred. */ + /** + * Let waiting for threads know that a failure occurred. + * + * @implNote Thread safe. + */ void notifyFailed() { failed = true; sent.countDown(); @@ -168,6 +174,8 @@ void notifyFailed() { /** * Block until notified of a successful send via {@link #notifySent()} or a non-retryable * failure via {@link #notifyFailed()}. On failure, throw an exception for waiters. + * + * @implNote Thread safe. 
 */ void waitForSendOrFailNotification() throws InterruptedException { sent.await(); @@ -188,7 +196,7 @@ void waitForSendOrFailNotification() throws InterruptedException { } } - ImmutableList createStreamCancelledErrorMessages() { + private ImmutableList createStreamCancelledErrorMessages() { return requests.stream() .flatMap( request -> { From 6f052de7dee26cfbe78a5c3da25524b7509ce7f3 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Wed, 23 Oct 2024 23:48:01 -0700 Subject: [PATCH 12/23] add unit tests for GrpcGetDataStreamRequests utilities, moved converting QueuedBatch into StreamingGetDataRequest into QueuedBatch --- .../client/grpc/GrpcGetDataStream.java | 21 +-- .../grpc/GrpcGetDataStreamRequests.java | 28 +++- .../grpc/GrpcGetDataStreamRequestsTest.java | 149 ++++++++++++++++++ 3 files changed, 178 insertions(+), 20 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 5209cf25e8e4..4aff1c83fc62 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -401,7 +401,7 @@ private void queueRequestAndWait(QueuedRequest request) throws InterruptedExcept void trySendBatch(QueuedBatch batch) { try { - sendBatch(batch.sortedRequestsReadOnly()); + sendBatch(batch); synchronized (shutdownLock) { verify( batch == batches.pollFirst(), @@ -419,12 +419,12 @@ void trySendBatch(QueuedBatch batch) { } } - private
void sendBatch(List requests) { - if (requests.isEmpty()) { + private void sendBatch(QueuedBatch batch) { + if (batch.isEmpty()) { return; } - StreamingGetDataRequest batchedRequest = flushToBatch(requests); + StreamingGetDataRequest batchedRequest = batch.asGetDataRequest(); synchronized (shutdownLock) { // Synchronization of pending inserts is necessary with send to ensure duplicates are not // sent on stream reconnect. @@ -433,9 +433,10 @@ private void sendBatch(List requests) { // to it. if (isShutdown()) { throw new WindmillStreamShutdownException( - "Stream was closed when attempting to send " + requests.size() + " requests."); + "Stream was closed when attempting to send " + batch.requestsCount() + " requests."); } - for (QueuedRequest request : requests) { + + for (QueuedRequest request : batch.requestsReadOnly()) { // Map#put returns null if there was no previous mapping for the key, meaning we have not // seen it before. verify( @@ -453,14 +454,6 @@ private void sendBatch(List requests) { } } - private StreamingGetDataRequest flushToBatch(List requests) { - StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); - for (QueuedRequest request : requests) { - request.addToStreamingGetDataRequest(builder); - } - return builder.build(); - } - private void verify(boolean condition, String message) { Verify.verify(condition || isShutdown(), message); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java index e3c327aa4389..56d796e488c8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequests.java @@ -120,14 +120,30 @@ static class QueuedBatch { private volatile boolean finalized = false; private volatile boolean failed = false; + /** Returns a read-only view of requests. */ + List requestsReadOnly() { + return Collections.unmodifiableList(requests); + } + /** - * Returns a read-only view of requests sorted with {@link QueuedRequest#globalRequestsFirst()}. + * Converts the batch to a {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest}. */ - List sortedRequestsReadOnly() { - // Put all global data requests first because there is only a single repeated field for - // request ids and the initial ids correspond to global data requests if they are present. - requests.sort(QueuedRequest.globalRequestsFirst()); - return Collections.unmodifiableList(requests); + Windmill.StreamingGetDataRequest asGetDataRequest() { + Windmill.StreamingGetDataRequest.Builder builder = + Windmill.StreamingGetDataRequest.newBuilder(); + + requests.stream() + // Put all global data requests first because there is only a single repeated field for + // request ids and the initial ids correspond to global data requests if they are present. 
+ .sorted(QueuedRequest.globalRequestsFirst()) + .forEach(request -> request.addToStreamingGetDataRequest(builder)); + + return builder.build(); + } + + boolean isEmpty() { + return requests.isEmpty(); } int requestsCount() { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java new file mode 100644 index 000000000000..d8b787fe1020 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcGetDataStreamRequestsTest { + + @Test + public void testQueuedRequest_globalRequestsFirstComparator() { + List requests = new ArrayList<>(); + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(1L) + .setShardingKey(1L) + .setWorkToken(1L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + requests.add( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 1, "computation1", keyedGetDataRequest1)); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(2L) + .setShardingKey(2L) + .setWorkToken(2L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + requests.add( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 2, "computation2", keyedGetDataRequest2)); + + Windmill.GlobalDataRequest globalDataRequest = + Windmill.GlobalDataRequest.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag("globalData") + .setVersion(ByteString.EMPTY) + .build()) + .setComputationId("computation1") + .build(); + requests.add(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + + requests.sort(GrpcGetDataStreamRequests.QueuedRequest.globalRequestsFirst()); + + // First one 
should be the global request. + assertTrue(requests.get(0).getDataRequest().isGlobal()); + } + + @Test + public void testQueuedBatch_asGetDataRequest() { + GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new GrpcGetDataStreamRequests.QueuedBatch(); + + Windmill.KeyedGetDataRequest keyedGetDataRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(1L) + .setShardingKey(1L) + .setWorkToken(1L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + queuedBatch.addRequest( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 1, "computation1", keyedGetDataRequest1)); + + Windmill.KeyedGetDataRequest keyedGetDataRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setCacheToken(2L) + .setShardingKey(2L) + .setWorkToken(2L) + .setMaxBytes(Long.MAX_VALUE) + .build(); + queuedBatch.addRequest( + GrpcGetDataStreamRequests.QueuedRequest.forComputation( + 2, "computation2", keyedGetDataRequest2)); + + Windmill.GlobalDataRequest globalDataRequest = + Windmill.GlobalDataRequest.newBuilder() + .setDataId( + Windmill.GlobalDataId.newBuilder() + .setTag("globalData") + .setVersion(ByteString.EMPTY) + .build()) + .setComputationId("computation1") + .build(); + queuedBatch.addRequest(GrpcGetDataStreamRequests.QueuedRequest.global(3, globalDataRequest)); + + Windmill.StreamingGetDataRequest getDataRequest = queuedBatch.asGetDataRequest(); + + assertThat(getDataRequest.getRequestIdCount()).isEqualTo(3); + assertThat(getDataRequest.getGlobalDataRequestList()).containsExactly(globalDataRequest); + assertThat(getDataRequest.getStateRequestList()) + .containsExactly( + Windmill.ComputationGetDataRequest.newBuilder() + .setComputationId("computation1") + .addRequests(keyedGetDataRequest1) + .build(), + Windmill.ComputationGetDataRequest.newBuilder() + .setComputationId("computation2") + .addRequests(keyedGetDataRequest2) + .build()); + } + + @Test + public void 
testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOnWaiters() { + GrpcGetDataStreamRequests.QueuedBatch queuedBatch = new GrpcGetDataStreamRequests.QueuedBatch(); + CompletableFuture waitFuture = + CompletableFuture.supplyAsync( + () -> + assertThrows( + WindmillStreamShutdownException.class, + queuedBatch::waitForSendOrFailNotification)); + + queuedBatch.notifyFailed(); + waitFuture.join(); + } +} From fb8573a5f9cda44c2f0bb876956cbca74faedaaf Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Thu, 24 Oct 2024 13:53:22 -0700 Subject: [PATCH 13/23] address PR comments --- .../client/grpc/GrpcCommitWorkStream.java | 139 +++++++++--------- .../client/grpc/GrpcGetDataStream.java | 29 +++- .../grpc/GrpcGetDataStreamRequests.java | 5 + 3 files changed, 102 insertions(+), 71 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 6472b717f8ef..be834bf03bbd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -225,86 +225,89 @@ private void flushInternal(Map requests) { } private void issueSingleRequest(long id, PendingRequest pendingRequest) { - if (prepareForSend(id, pendingRequest)) { - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - requestBuilder - .addCommitChunkBuilder() - .setComputationId(pendingRequest.computationId()) - .setRequestId(id) - .setShardingKey(pendingRequest.shardingKey()) - .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); - StreamingCommitWorkRequest chunk = 
requestBuilder.build(); - try { - send(chunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - } - } else { + if (!prepareForSend(id, pendingRequest)) { pendingRequest.abort(); + return; + } + + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + requestBuilder + .addCommitChunkBuilder() + .setComputationId(pendingRequest.computationId()) + .setRequestId(id) + .setShardingKey(pendingRequest.shardingKey()) + .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); + StreamingCommitWorkRequest chunk = requestBuilder.build(); + try { + send(chunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. } } private void issueBatchedRequest(Map requests) { - if (prepareForSend(requests)) { - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); - String lastComputation = null; - for (Map.Entry entry : requests.entrySet()) { - PendingRequest request = entry.getValue(); - StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); - if (lastComputation == null || !lastComputation.equals(request.computationId())) { - chunkBuilder.setComputationId(request.computationId()); - lastComputation = request.computationId(); - } - chunkBuilder - .setRequestId(entry.getKey()) - .setShardingKey(request.shardingKey()) - .setSerializedWorkItemCommit(request.serializedCommit()); - } - StreamingCommitWorkRequest request = requestBuilder.build(); - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
- } - } else { + if (!prepareForSend(requests)) { requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); + return; + } + + StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); + String lastComputation = null; + for (Map.Entry entry : requests.entrySet()) { + PendingRequest request = entry.getValue(); + StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder(); + if (lastComputation == null || !lastComputation.equals(request.computationId())) { + chunkBuilder.setComputationId(request.computationId()); + lastComputation = request.computationId(); + } + chunkBuilder + .setRequestId(entry.getKey()) + .setShardingKey(request.shardingKey()) + .setSerializedWorkItemCommit(request.serializedCommit()); + } + StreamingCommitWorkRequest request = requestBuilder.build(); + try { + send(request); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. } } private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { - if (prepareForSend(id, pendingRequest)) { - checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); - ByteString serializedCommit = pendingRequest.serializedCommit(); - synchronized (this) { - for (int i = 0; - i < serializedCommit.size(); - i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { - int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; - ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - - StreamingCommitRequestChunk.Builder chunkBuilder = - StreamingCommitRequestChunk.newBuilder() - .setRequestId(id) - .setSerializedWorkItemCommit(chunk) - .setComputationId(pendingRequest.computationId()) - .setShardingKey(pendingRequest.shardingKey()); - int remaining = serializedCommit.size() - end; - if (remaining > 0) { - chunkBuilder.setRemainingBytesForWorkItem(remaining); - } - - StreamingCommitWorkRequest requestChunk = 
- StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. - break; - } + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + + checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); + ByteString serializedCommit = pendingRequest.serializedCommit(); + synchronized (this) { + for (int i = 0; + i < serializedCommit.size(); + i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { + int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; + ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); + + StreamingCommitRequestChunk.Builder chunkBuilder = + StreamingCommitRequestChunk.newBuilder() + .setRequestId(id) + .setSerializedWorkItemCommit(chunk) + .setComputationId(pendingRequest.computationId()) + .setShardingKey(pendingRequest.shardingKey()); + int remaining = serializedCommit.size() - end; + if (remaining > 0) { + chunkBuilder.setRemainingBytesForWorkItem(remaining); + } + + StreamingCommitWorkRequest requestChunk = + StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); + try { + send(requestChunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. 
+ break; } } - } else { - pendingRequest.abort(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 4aff1c83fc62..cda246065ab9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -69,7 +69,7 @@ final class GrpcGetDataStream private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = StreamingGetDataRequest.newBuilder().build(); - /** @implNote insertion and removal is guarded by {@link #shutdownLock} */ + /** @implNote {@link QueuedBatch} objects in the queue are guarded by {@link #shutdownLock} */ private final Deque batches; private final Map pending; @@ -349,21 +349,36 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn ResponseT issueRequest(QueuedRequest request, ParseFn Date: Mon, 28 Oct 2024 23:53:21 -0700 Subject: [PATCH 15/23] address PR comments --- .../client/AbstractWindmillStream.java | 185 +++++------------- .../client/ResettableStreamObserver.java | 98 ++++++++++ .../windmill/client/StreamDebugMetrics.java | 135 +++++++++++++ .../client/grpc/GrpcGetDataStream.java | 2 +- .../client/ResettableStreamObserverTest.java | 90 +++++++++ 5 files changed, 369 insertions(+), 141 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java create mode 100644 
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index ca54ade05f0f..37476546b6cc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -25,24 +25,15 @@ import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; -import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.com.google.api.client.util.Sleeper; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import 
org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -74,9 +65,9 @@ public abstract class AbstractWindmillStream implements Win // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce // per-chunk overhead, and small enough that we can still perform granular flow-control. protected static final int RPC_STREAM_CHUNK_SIZE = 2 << 20; + // Indicates that the logical stream has been half-closed and is waiting for clean server + // shutdown. private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); - - protected final AtomicBoolean clientClosed; protected final Sleeper sleeper; /** @@ -87,30 +78,25 @@ public abstract class AbstractWindmillStream implements Win */ protected final Object shutdownLock = new Object(); - private final AtomicLong lastSendTimeMs; + private final Logger logger; private final ExecutorService executor; private final BackOff backoff; - private final AtomicLong startTimeMs; - private final AtomicLong lastResponseTimeMs; - private final AtomicInteger restartCount; - private final AtomicInteger errorCount; - private final AtomicReference lastRestartReason; - private final AtomicReference lastRestartTime; - private final AtomicLong sleepUntil; private final CountDownLatch finishLatch; private final Set> streamRegistry; private final int logEveryNStreamFailures; private final String backendWorkerToken; - private final ResettableRequestObserver requestObserver; - private final AtomicReference shutdownTime; + private final ResettableStreamObserver requestObserver; + private final StreamDebugMetrics debugMetrics; + protected volatile boolean clientClosed; /** - * Indicates if the current {@link ResettableRequestObserver} was closed by calling {@link - * #halfClose()}. + * Indicates if the current {@link ResettableStreamObserver} was closed by calling {@link + * #halfClose()}. 
Separate from {@link #clientClosed} as this is specific to the requestObserver + * and is initially false on retry. */ - private final AtomicBoolean streamClosed; + @GuardedBy("this") + private boolean streamClosed; - private final Logger logger; private volatile boolean isShutdown; private volatile boolean started; @@ -133,28 +119,20 @@ protected AbstractWindmillStream( this.backoff = backoff; this.streamRegistry = streamRegistry; this.logEveryNStreamFailures = logEveryNStreamFailures; - this.clientClosed = new AtomicBoolean(); + this.clientClosed = false; this.isShutdown = false; this.started = false; - this.streamClosed = new AtomicBoolean(false); - this.startTimeMs = new AtomicLong(); - this.lastSendTimeMs = new AtomicLong(); - this.lastResponseTimeMs = new AtomicLong(); - this.restartCount = new AtomicInteger(); - this.errorCount = new AtomicInteger(); - this.lastRestartReason = new AtomicReference<>(); - this.lastRestartTime = new AtomicReference<>(); - this.sleepUntil = new AtomicLong(); + this.streamClosed = false; this.finishLatch = new CountDownLatch(1); this.requestObserver = - new ResettableRequestObserver<>( + new ResettableStreamObserver<>( () -> streamObserverFactory.from( clientFactory, new AbstractWindmillStream.ResponseObserver())); this.sleeper = Sleeper.DEFAULT; this.logger = logger; - this.shutdownTime = new AtomicReference<>(); + this.debugMetrics = new StreamDebugMetrics(); } private static String createThreadName(String streamType, String backendWorkerToken) { @@ -163,10 +141,6 @@ private static String createThreadName(String streamType, String backendWorkerTo : String.format("%s-WindmillStream-thread", streamType); } - private static long debugDuration(long nowMs, long startMs) { - return startMs <= 0 ? -1 : Math.max(0, nowMs - startMs); - } - /** Called on each response from the server. 
*/ protected abstract void onResponse(ResponseT response); @@ -195,13 +169,13 @@ protected final void send(RequestT request) { return; } - if (streamClosed.get()) { + if (streamClosed) { // TODO(m-trieu): throw a more specific exception here (i.e StreamClosedException) throw new IllegalStateException("Send called on a client closed stream."); } try { - lastSendTimeMs.set(Instant.now().getMillis()); + debugMetrics.recordSend(); requestObserver.onNext(request); } catch (StreamObserverCancelledException e) { if (isShutdown) { @@ -239,21 +213,22 @@ private void startStream() { if (isShutdown) { break; } - startTimeMs.set(Instant.now().getMillis()); - lastResponseTimeMs.set(0); - streamClosed.set(false); + debugMetrics.recordStart(); + streamClosed = false; requestObserver.reset(); onNewStream(); - if (clientClosed.get()) { + if (clientClosed) { halfClose(); } return; } + } catch (WindmillStreamShutdownException e) { + logger.debug("Stream was shutdown waiting to start.", e); } catch (Exception e) { logger.error("Failed to create new stream, retrying: ", e); try { long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); + debugMetrics.recordSleep(sleep); sleeper.sleep(sleep); } catch (InterruptedException ie) { Thread.currentThread().interrupt(); @@ -285,7 +260,7 @@ protected final void executeSafely(Runnable runnable) { } public final void maybeSendHealthCheck(Instant lastSendThreshold) { - if (!clientClosed.get() && lastSendTimeMs.get() < lastSendThreshold.getMillis()) { + if (!clientClosed && debugMetrics.lastSendTimeMs() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); } catch (RuntimeException e) { @@ -303,28 +278,19 @@ public final void maybeSendHealthCheck(Instant lastSendThreshold) { */ public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); - if (restartCount.get() > 0) { - writer.format( - ", %d restarts, last restart reason [ %s ] at [%s], %d errors", - restartCount.get(), 
lastRestartReason.get(), lastRestartTime.get(), errorCount.get()); - } - if (clientClosed.get()) { + debugMetrics.printRestartsHtml(writer); + if (clientClosed) { writer.write(", client closed"); } long nowMs = Instant.now().getMillis(); - long sleepLeft = sleepUntil.get() - nowMs; + long sleepLeft = debugMetrics.sleepLeft(); if (sleepLeft > 0) { writer.format(", %dms backoff remaining", sleepLeft); } + debugMetrics.printSummaryHtml(writer, nowMs); writer.format( - ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " - + "isShutdown: %s, shutdown time: %s", - debugDuration(nowMs, startTimeMs.get()), - debugDuration(nowMs, lastSendTimeMs.get()), - debugDuration(nowMs, lastResponseTimeMs.get()), - streamClosed.get(), - isShutdown, - shutdownTime.get()); + ", closed: %s, " + "isShutdown: %s, shutdown time: %s", + streamClosed, isShutdown, debugMetrics.shutdownTime()); } /** @@ -336,9 +302,9 @@ public final void appendSummaryHtml(PrintWriter writer) { @Override public final synchronized void halfClose() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. - clientClosed.set(true); + clientClosed = true; requestObserver.onCompleted(); - streamClosed.set(true); + streamClosed = true; } @Override @@ -348,7 +314,7 @@ public final boolean awaitTermination(int time, TimeUnit unit) throws Interrupte @Override public final Instant startTime() { - return new Instant(startTimeMs.get()); + return new Instant(debugMetrics.startTimeMs()); } @Override @@ -363,71 +329,15 @@ public final void shutdown() { synchronized (shutdownLock) { if (!isShutdown) { isShutdown = true; - shutdownTime.set(DateTime.now()); - if (started) { - // requestObserver is not set until the first startStream() is called. If the stream was - // never started there is nothing to clean up internally. 
- requestObserver.onError( - new WindmillStreamShutdownException("Explicit call to shutdown stream.")); - shutdownInternal(); - } + debugMetrics.recordShutdown(); + requestObserver.poison(); + shutdownInternal(); } } } - private void recordRestartReason(String error) { - lastRestartReason.set(error); - lastRestartTime.set(DateTime.now()); - } - protected abstract void shutdownInternal(); - /** - * Request observer that allows resetting its internal delegate using the given {@link - * #requestObserverSupplier}. - * - * @implNote {@link StreamObserver}s generated by {@link * #requestObserverSupplier} are expected - * to be {@link ThreadSafe}. - */ - @ThreadSafe - private static class ResettableRequestObserver implements StreamObserver { - - private final Supplier> requestObserverSupplier; - - @GuardedBy("this") - private @Nullable StreamObserver delegateRequestObserver; - - private ResettableRequestObserver(Supplier> requestObserverSupplier) { - this.requestObserverSupplier = requestObserverSupplier; - this.delegateRequestObserver = null; - } - - private synchronized StreamObserver delegate() { - return Preconditions.checkNotNull( - delegateRequestObserver, - "requestObserver cannot be null. Missing a call to startStream() to initialize."); - } - - private synchronized void reset() { - delegateRequestObserver = requestObserverSupplier.get(); - } - - @Override - public void onNext(RequestT requestT) { - delegate().onNext(requestT); - } - - @Override - public void onError(Throwable throwable) { - delegate().onError(throwable); - } - - @Override - public void onCompleted() { - delegate().onCompleted(); - } - } - private class ResponseObserver implements StreamObserver { @Override @@ -437,7 +347,7 @@ public void onNext(ResponseT response) { } catch (IOException e) { // Ignore. 
} - lastResponseTimeMs.set(Instant.now().getMillis()); + debugMetrics.recordResponse(); onResponse(response); } @@ -451,7 +361,7 @@ public void onError(Throwable t) { try { long sleep = backoff.nextBackOffMillis(); - sleepUntil.set(Instant.now().getMillis() + sleep); + debugMetrics.recordSleep(sleep); sleeper.sleep(sleep); } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -473,16 +383,16 @@ public void onCompleted() { } private void recordStreamStatus(Status status) { - int currentRestartCount = restartCount.incrementAndGet(); + int currentRestartCount = debugMetrics.incrementAndGetRestarts(); if (status.isOk()) { String restartReason = "Stream completed successfully but did not complete requested operations, " + "recreating"; logger.warn(restartReason); - recordRestartReason(restartReason); + debugMetrics.recordRestartReason(restartReason); } else { - int currentErrorCount = errorCount.incrementAndGet(); - recordRestartReason(status.toString()); + int currentErrorCount = debugMetrics.incrementAndGetErrors(); + debugMetrics.recordRestartReason(status.toString()); Throwable t = status.getCause(); if (t instanceof StreamObserverCancelledException) { logger.error( @@ -494,11 +404,6 @@ private void recordStreamStatus(Status status) { } else if (currentRestartCount % logEveryNStreamFailures == 0) { // Don't log every restart since it will get noisy, and many errors transient. long nowMillis = Instant.now().getMillis(); - String responseDebug = - lastResponseTimeMs.get() == 0 - ? "never received response" - : "received response " + (nowMillis - lastResponseTimeMs.get()) + "ms ago"; - logger.debug( "{} has been restarted {} times. Streaming Windmill RPC Error Count: {}; last was: {}" + " with status: {}. created {}ms ago; {}. 
This is normal with autoscaling.", @@ -507,8 +412,8 @@ private void recordStreamStatus(Status status) { currentErrorCount, t, status, - nowMillis - startTimeMs.get(), - responseDebug); + nowMillis - debugMetrics.startTimeMs(), + debugMetrics.responseDebugString(nowMillis)); } // If the stream was stopped due to a resource exhausted error then we are throttled. @@ -520,7 +425,7 @@ private void recordStreamStatus(Status status) { /** Returns true if the stream was torn down and should not be restarted internally. */ private synchronized boolean maybeTeardownStream() { - if (isShutdown || (clientClosed.get() && !hasPendingRequests())) { + if (isShutdown || (clientClosed && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); executor.shutdownNow(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java new file mode 100644 index 000000000000..cbe8d5122ee7 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; + +/** + * Request observer that allows resetting its internal delegate using the given {@link + * #streamObserverFactory}. + * + * @implNote {@link StreamObserver}s generated by {@link #streamObserverFactory} are expected to be + * {@link ThreadSafe}. + */ +@ThreadSafe +@Internal +final class ResettableStreamObserver implements StreamObserver { + private final Supplier> streamObserverFactory; + + @GuardedBy("this") + private @Nullable StreamObserver delegateStreamObserver; + + /** + * Indicates that the request observer should no longer be used. Attempts to perform operations on + * the request observer will throw an {@link WindmillStreamShutdownException}. 
+ */ + @GuardedBy("this") + private boolean isPoisoned; + + ResettableStreamObserver(Supplier> streamObserverFactory) { + this.streamObserverFactory = streamObserverFactory; + this.delegateStreamObserver = null; + this.isPoisoned = false; + } + + private synchronized StreamObserver delegate() { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Explicit call to shutdown stream."); + } + + return Preconditions.checkNotNull( + delegateStreamObserver, + "requestObserver cannot be null. Missing a call to startStream() to initialize."); + } + + synchronized void reset() { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Explicit call to shutdown stream."); + } + + delegateStreamObserver = streamObserverFactory.get(); + } + + synchronized void poison() { + if (!isPoisoned) { + isPoisoned = true; + if (delegateStreamObserver != null) { + delegateStreamObserver.onError( + new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + } + } + } + + @Override + public void onNext(T t) { + delegate().onNext(t); + } + + @Override + public void onError(Throwable throwable) { + delegate().onError(throwable); + } + + @Override + public void onCompleted() { + delegate().onCompleted(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java new file mode 100644 index 000000000000..9f5da3b417ff --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.joda.time.DateTime; +import org.joda.time.Instant; + +/** Records stream metrics for debugging. */ +@ThreadSafe +final class StreamDebugMetrics { + private final AtomicInteger restartCount = new AtomicInteger(); + private final AtomicInteger errorCount = new AtomicInteger(); + + @GuardedBy("this") + private long sleepUntil = 0; + + @GuardedBy("this") + private String lastRestartReason = ""; + + @GuardedBy("this") + private DateTime lastRestartTime = null; + + @GuardedBy("this") + private long lastResponseTimeMs = 0; + + @GuardedBy("this") + private long lastSendTimeMs = 0; + + @GuardedBy("this") + private long startTimeMs = 0; + + @GuardedBy("this") + private DateTime shutdownTime = null; + + private static long debugDuration(long nowMs, long startMs) { + return startMs <= 0 ? 
-1 : Math.max(0, nowMs - startMs); + } + + private static long nowMs() { + return Instant.now().getMillis(); + } + + synchronized void recordSend() { + lastSendTimeMs = nowMs(); + } + + synchronized void recordStart() { + startTimeMs = nowMs(); + lastResponseTimeMs = 0; + } + + synchronized void recordResponse() { + lastResponseTimeMs = nowMs(); + } + + synchronized void recordRestartReason(String error) { + lastRestartReason = error; + lastRestartTime = DateTime.now(); + } + + synchronized long startTimeMs() { + return startTimeMs; + } + + synchronized long lastSendTimeMs() { + return lastSendTimeMs; + } + + synchronized void recordSleep(long sleepMs) { + sleepUntil = nowMs() + sleepMs; + } + + synchronized long sleepLeft() { + return sleepUntil - nowMs(); + } + + int incrementAndGetRestarts() { + return restartCount.incrementAndGet(); + } + + int incrementAndGetErrors() { + return errorCount.incrementAndGet(); + } + + synchronized void recordShutdown() { + shutdownTime = DateTime.now(); + } + + synchronized String responseDebugString(long nowMillis) { + return lastResponseTimeMs == 0 + ? 
"never received response" + : "received response " + (nowMillis - lastResponseTimeMs) + "ms ago"; + } + + void printRestartsHtml(PrintWriter writer) { + if (restartCount.get() > 0) { + synchronized (this) { + writer.format( + ", %d restarts, last restart reason [ %s ] at [%s], %d errors", + restartCount.get(), lastRestartReason, lastRestartTime, errorCount.get()); + } + } + } + + synchronized DateTime shutdownTime() { + return shutdownTime; + } + + synchronized void printSummaryHtml(PrintWriter writer, long nowMs) { + writer.format( + ", current stream is %dms old, last send %dms, last response %dms", + debugDuration(nowMs, startTimeMs), + debugDuration(nowMs, lastSendTimeMs), + debugDuration(nowMs, lastResponseTimeMs)); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index a5d5b1882fd7..6a809712bd9f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -161,7 +161,7 @@ protected synchronized void onNewStream() { } send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed.get() && !isShutdown()) { + if (clientClosed && !isShutdown()) { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. 
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java new file mode 100644 index 000000000000..538da9607f8b --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ResettableStreamObserverTest { + private final StreamObserver delegate = + spy( + new StreamObserver() { + @Override + public void onNext(Integer integer) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }); + + @Test + public void testPoison_beforeDelegateSet() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + verifyNoInteractions(delegate); + } + + @Test + public void testPoison_afterDelegateSet() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.reset(); + observer.poison(); + verify(delegate).onError(isA(WindmillStreamShutdownException.class)); + } + + @Test + public void testReset_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::reset); + } + + @Test + public void onNext_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, () -> observer.onNext(1)); + } + + @Test + public void onError_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows( + WindmillStreamShutdownException.class, + () -> observer.onError(new 
RuntimeException("something bad happened."))); + } + + @Test + public void onCompleted_afterPoisonedThrows() { + ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + observer.poison(); + assertThrows(WindmillStreamShutdownException.class, observer::onCompleted); + } +} From 5a68e0ff791c08438a316df22c7354c91638cf5d Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Tue, 29 Oct 2024 10:28:09 -0700 Subject: [PATCH 16/23] don't block on rendering status pages --- .../client/AbstractWindmillStream.java | 33 ++++++-- .../windmill/client/StreamDebugMetrics.java | 79 +++++++++++++++---- .../client/ResettableStreamObserverTest.java | 54 ++++++++++--- 3 files changed, 129 insertions(+), 37 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 37476546b6cc..b9648fe4ab47 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -276,21 +276,38 @@ public final void maybeSendHealthCheck(Instant lastSendThreshold) { * information. Blocking sends are made beneath this stream object's lock which could block * status page rendering. 
*/ + @SuppressWarnings("GuardedBy") public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); - debugMetrics.printRestartsHtml(writer); + StreamDebugMetrics.Snapshot summaryMetrics = debugMetrics.getSummaryMetrics(); + summaryMetrics + .restartMetrics() + .ifPresent( + metrics -> + writer.format( + ", %d restarts, last restart reason [ %s ] at [%s], %d errors", + metrics.restartCount(), + metrics.lastRestartReason(), + metrics.lastRestartTime(), + metrics.errorCount())); + if (clientClosed) { writer.write(", client closed"); } - long nowMs = Instant.now().getMillis(); - long sleepLeft = debugMetrics.sleepLeft(); - if (sleepLeft > 0) { - writer.format(", %dms backoff remaining", sleepLeft); + + if (summaryMetrics.sleepLeft() > 0) { + writer.format(", %dms backoff remaining", summaryMetrics.sleepLeft()); } - debugMetrics.printSummaryHtml(writer, nowMs); + writer.format( - ", closed: %s, " + "isShutdown: %s, shutdown time: %s", - streamClosed, isShutdown, debugMetrics.shutdownTime()); + ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " + + "isShutdown: %s, shutdown time: %s", + summaryMetrics.streamAge(), + summaryMetrics.timeSinceLastSend(), + summaryMetrics.timeSinceLastResponse(), + streamClosed, + isShutdown, + summaryMetrics.shutdownTime().orElse(null)); } /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java index 9f5da3b417ff..48aaf9c5bf4b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -17,8 +17,10 @@ */ package 
org.apache.beam.runners.dataflow.worker.windmill.client; -import java.io.PrintWriter; +import com.google.auto.value.AutoValue; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; import org.joda.time.DateTime; @@ -89,10 +91,6 @@ synchronized void recordSleep(long sleepMs) { sleepUntil = nowMs() + sleepMs; } - synchronized long sleepLeft() { - return sleepUntil - nowMs(); - } - int incrementAndGetRestarts() { return restartCount.incrementAndGet(); } @@ -111,25 +109,74 @@ synchronized String responseDebugString(long nowMillis) { : "received response " + (nowMillis - lastResponseTimeMs) + "ms ago"; } - void printRestartsHtml(PrintWriter writer) { + private Optional getRestartMetrics() { if (restartCount.get() > 0) { synchronized (this) { - writer.format( - ", %d restarts, last restart reason [ %s ] at [%s], %d errors", - restartCount.get(), lastRestartReason, lastRestartTime, errorCount.get()); + return Optional.of( + RestartMetrics.create( + restartCount.get(), lastRestartReason, lastRestartTime, errorCount.get())); } } - } - synchronized DateTime shutdownTime() { - return shutdownTime; + return Optional.empty(); } - synchronized void printSummaryHtml(PrintWriter writer, long nowMs) { - writer.format( - ", current stream is %dms old, last send %dms, last response %dms", + synchronized Snapshot getSummaryMetrics() { + long nowMs = Instant.now().getMillis(); + return Snapshot.create( debugDuration(nowMs, startTimeMs), debugDuration(nowMs, lastSendTimeMs), - debugDuration(nowMs, lastResponseTimeMs)); + debugDuration(nowMs, lastResponseTimeMs), + getRestartMetrics(), + sleepUntil - nowMs(), + shutdownTime); + } + + @AutoValue + abstract static class Snapshot { + private static Snapshot create( + long streamAge, + long timeSinceLastSend, + long timeSinceLastResponse, + Optional restartMetrics, + long sleepLeft, + 
@Nullable DateTime shutdownTime) { + return new AutoValue_StreamDebugMetrics_Snapshot( + streamAge, + timeSinceLastSend, + timeSinceLastResponse, + restartMetrics, + sleepLeft, + Optional.ofNullable(shutdownTime)); + } + + abstract long streamAge(); + + abstract long timeSinceLastSend(); + + abstract long timeSinceLastResponse(); + + abstract Optional restartMetrics(); + + abstract long sleepLeft(); + + abstract Optional shutdownTime(); + } + + @AutoValue + abstract static class RestartMetrics { + private static RestartMetrics create( + int restartCount, String restartReason, DateTime lastRestartTime, int errorCount) { + return new AutoValue_StreamDebugMetrics_RestartMetrics( + restartCount, restartReason, lastRestartTime, errorCount); + } + + abstract int restartCount(); + + abstract String lastRestartReason(); + + abstract DateTime lastRestartTime(); + + abstract int errorCount(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java index 538da9607f8b..189a244c822e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java @@ -18,11 +18,14 @@ package org.apache.beam.runners.dataflow.worker.windmill.client; import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import java.util.ArrayList; +import java.util.List; import 
org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.junit.Test; import org.junit.runner.RunWith; @@ -30,18 +33,21 @@ @RunWith(JUnit4.class) public class ResettableStreamObserverTest { - private final StreamObserver delegate = - spy( - new StreamObserver() { - @Override - public void onNext(Integer integer) {} + private final StreamObserver delegate = newDelegate(); - @Override - public void onError(Throwable throwable) {} + private static StreamObserver newDelegate() { + return spy( + new StreamObserver() { + @Override + public void onNext(Integer integer) {} - @Override - public void onCompleted() {} - }); + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }); + } @Test public void testPoison_beforeDelegateSet() { @@ -66,14 +72,14 @@ public void testReset_afterPoisonedThrows() { } @Test - public void onNext_afterPoisonedThrows() { + public void testOnNext_afterPoisonedThrows() { ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); observer.poison(); assertThrows(WindmillStreamShutdownException.class, () -> observer.onNext(1)); } @Test - public void onError_afterPoisonedThrows() { + public void testOnError_afterPoisonedThrows() { ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); observer.poison(); assertThrows( @@ -82,9 +88,31 @@ public void onError_afterPoisonedThrows() { } @Test - public void onCompleted_afterPoisonedThrows() { + public void testOnCompleted_afterPoisonedThrows() { ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); observer.poison(); assertThrows(WindmillStreamShutdownException.class, observer::onCompleted); } + + @Test + public void testReset_usesNewDelegate() { + List> delegates = new ArrayList<>(); + ResettableStreamObserver observer = + new ResettableStreamObserver<>( + () -> { + StreamObserver delegate = newDelegate(); + delegates.add(delegate); + return delegate; + }); 
+ observer.reset(); + observer.onNext(1); + observer.reset(); + observer.onNext(2); + + StreamObserver firstObserver = delegates.get(0); + StreamObserver secondObserver = delegates.get(1); + + verify(firstObserver).onNext(eq(1)); + verify(secondObserver).onNext(eq(2)); + } } From 3fb76ac2830080c7382f0f20ae235295fa5bb011 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Thu, 31 Oct 2024 17:55:12 -0700 Subject: [PATCH 17/23] address PR comments --- .../client/AbstractWindmillStream.java | 47 ++++--- .../client/ResettableStreamObserver.java | 9 +- .../windmill/client/StreamDebugMetrics.java | 64 ++++++---- .../client/grpc/GrpcCommitWorkStream.java | 96 +++++++------- .../client/grpc/GrpcDirectGetWorkStream.java | 2 +- .../client/grpc/GrpcGetDataStream.java | 54 ++++---- .../client/StreamDebugMetricsTest.java | 117 ++++++++++++++++++ .../client/grpc/GrpcCommitWorkStreamTest.java | 2 +- .../grpc/GrpcDirectGetWorkStreamTest.java | 2 +- 9 files changed, 262 insertions(+), 131 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index b9648fe4ab47..9bc7c4e7d97a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verify; + import java.io.IOException; import 
java.io.PrintWriter; import java.util.Set; @@ -68,13 +70,14 @@ public abstract class AbstractWindmillStream implements Win // Indicates that the logical stream has been half-closed and is waiting for clean server // shutdown. private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); + private static final String NEVER_RECEIVED_RESPONSE_LOG_STRING = "never received response"; protected final Sleeper sleeper; /** * Used to guard {@link #start()} and {@link #shutdown()} behavior. * - * @implNote Do not hold when performing IO. If also locking on {@code this} in the same context, - * should acquire shutdownLock first to prevent deadlocks. + * @implNote Do NOT hold when performing IO. If also locking on {@code this} in the same context, + * should acquire shutdownLock after {@code this} to prevent deadlocks. */ protected final Object shutdownLock = new Object(); @@ -94,11 +97,13 @@ public abstract class AbstractWindmillStream implements Win * #halfClose()}. Separate from {@link #clientClosed} as this is specific to the requestObserver * and is initially false on retry. */ - @GuardedBy("this") - private boolean streamClosed; + private volatile boolean streamClosed; + + @GuardedBy("shutdownLock") + private boolean isShutdown; - private volatile boolean isShutdown; - private volatile boolean started; + @GuardedBy("shutdownLock") + private boolean started; protected AbstractWindmillStream( Logger logger, @@ -132,7 +137,7 @@ protected AbstractWindmillStream( new AbstractWindmillStream.ResponseObserver())); this.sleeper = Sleeper.DEFAULT; this.logger = logger; - this.debugMetrics = new StreamDebugMetrics(); + this.debugMetrics = StreamDebugMetrics.create(); } private static String createThreadName(String streamType, String backendWorkerToken) { @@ -158,14 +163,16 @@ private static String createThreadName(String streamType, String backendWorkerTo protected abstract void startThrottleTimer(); /** Reflects that {@link #shutdown()} was explicitly called. 
*/ - protected boolean isShutdown() { - return isShutdown; + protected boolean hasReceivedShutdownSignal() { + synchronized (shutdownLock) { + return isShutdown; + } } /** Send a request to the server. */ protected final void send(RequestT request) { synchronized (this) { - if (isShutdown) { + if (hasReceivedShutdownSignal()) { return; } @@ -175,10 +182,11 @@ protected final void send(RequestT request) { } try { + verify(!Thread.holdsLock(shutdownLock), "shutdownLock should not be held during send."); debugMetrics.recordSend(); requestObserver.onNext(request); } catch (StreamObserverCancelledException e) { - if (isShutdown) { + if (hasReceivedShutdownSignal()) { logger.debug("Stream was shutdown during send.", e); return; } @@ -210,7 +218,7 @@ private void startStream() { while (true) { try { synchronized (this) { - if (isShutdown) { + if (hasReceivedShutdownSignal()) { break; } debugMetrics.recordStart(); @@ -260,7 +268,7 @@ protected final void executeSafely(Runnable runnable) { } public final void maybeSendHealthCheck(Instant lastSendThreshold) { - if (!clientClosed && debugMetrics.lastSendTimeMs() < lastSendThreshold.getMillis()) { + if (!clientClosed && debugMetrics.getLastSendTimeMs() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); } catch (RuntimeException e) { @@ -276,7 +284,6 @@ public final void maybeSendHealthCheck(Instant lastSendThreshold) { * information. Blocking sends are made beneath this stream object's lock which could block * status page rendering. 
*/ - @SuppressWarnings("GuardedBy") public final void appendSummaryHtml(PrintWriter writer) { appendSpecificHtml(writer); StreamDebugMetrics.Snapshot summaryMetrics = debugMetrics.getSummaryMetrics(); @@ -306,7 +313,7 @@ public final void appendSummaryHtml(PrintWriter writer) { summaryMetrics.timeSinceLastSend(), summaryMetrics.timeSinceLastResponse(), streamClosed, - isShutdown, + hasReceivedShutdownSignal(), summaryMetrics.shutdownTime().orElse(null)); } @@ -331,7 +338,7 @@ public final boolean awaitTermination(int time, TimeUnit unit) throws Interrupte @Override public final Instant startTime() { - return new Instant(debugMetrics.startTimeMs()); + return new Instant(debugMetrics.getStartTimeMs()); } @Override @@ -429,8 +436,10 @@ private void recordStreamStatus(Status status) { currentErrorCount, t, status, - nowMillis - debugMetrics.startTimeMs(), - debugMetrics.responseDebugString(nowMillis)); + nowMillis - debugMetrics.getStartTimeMs(), + debugMetrics + .responseDebugString(nowMillis) + .orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING)); } // If the stream was stopped due to a resource exhausted error then we are throttled. @@ -442,7 +451,7 @@ private void recordStreamStatus(Status status) { /** Returns true if the stream was torn down and should not be restarted internally. 
*/ private synchronized boolean maybeTeardownStream() { - if (isShutdown || (clientClosed && !hasPendingRequests())) { + if (hasReceivedShutdownSignal() || (clientClosed && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); executor.shutdownNow(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java index cbe8d5122ee7..e5b99f2b6ae6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java @@ -40,10 +40,6 @@ final class ResettableStreamObserver implements StreamObserver { @GuardedBy("this") private @Nullable StreamObserver delegateStreamObserver; - /** - * Indicates that the request observer should no longer be used. Attempts to perform operations on - * the request observer will throw an {@link WindmillStreamShutdownException}. - */ @GuardedBy("this") private boolean isPoisoned; @@ -63,6 +59,7 @@ private synchronized StreamObserver delegate() { "requestObserver cannot be null. Missing a call to startStream() to initialize."); } + /** Creates a new delegate to use for future {@link StreamObserver} methods. */ synchronized void reset() { if (isPoisoned) { throw new WindmillStreamShutdownException("Explicit call to shutdown stream."); @@ -71,6 +68,10 @@ synchronized void reset() { delegateStreamObserver = streamObserverFactory.get(); } + /** + * Indicates that the request observer should no longer be used. Attempts to perform operations on + * the request observer will throw an {@link WindmillStreamShutdownException}. 
+ */ synchronized void poison() { if (!isPoisoned) { isPoisoned = true; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java index 48aaf9c5bf4b..d813b6d7c9c2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -19,18 +19,24 @@ import com.google.auto.value.AutoValue; import java.util.Optional; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.joda.time.DateTime; import org.joda.time.Instant; /** Records stream metrics for debugging. 
*/ @ThreadSafe final class StreamDebugMetrics { - private final AtomicInteger restartCount = new AtomicInteger(); - private final AtomicInteger errorCount = new AtomicInteger(); + private final Supplier clock; + + @GuardedBy("this") + private int errorCount = 0; + + @GuardedBy("this") + private int restartCount = 0; @GuardedBy("this") private long sleepUntil = 0; @@ -53,12 +59,25 @@ final class StreamDebugMetrics { @GuardedBy("this") private DateTime shutdownTime = null; + private StreamDebugMetrics(Supplier clock) { + this.clock = clock; + } + + static StreamDebugMetrics create() { + return new StreamDebugMetrics(Instant::now); + } + + @VisibleForTesting + static StreamDebugMetrics forTesting(Supplier fakeClock) { + return new StreamDebugMetrics(fakeClock); + } + private static long debugDuration(long nowMs, long startMs) { return startMs <= 0 ? -1 : Math.max(0, nowMs - startMs); } - private static long nowMs() { - return Instant.now().getMillis(); + private long nowMs() { + return clock.get().getMillis(); } synchronized void recordSend() { @@ -76,14 +95,14 @@ synchronized void recordResponse() { synchronized void recordRestartReason(String error) { lastRestartReason = error; - lastRestartTime = DateTime.now(); + lastRestartTime = clock.get().toDateTime(); } - synchronized long startTimeMs() { + synchronized long getStartTimeMs() { return startTimeMs; } - synchronized long lastSendTimeMs() { + synchronized long getLastSendTimeMs() { return lastSendTimeMs; } @@ -91,38 +110,35 @@ synchronized void recordSleep(long sleepMs) { sleepUntil = nowMs() + sleepMs; } - int incrementAndGetRestarts() { - return restartCount.incrementAndGet(); + synchronized int incrementAndGetRestarts() { + return restartCount++; } - int incrementAndGetErrors() { - return errorCount.incrementAndGet(); + synchronized int incrementAndGetErrors() { + return errorCount++; } synchronized void recordShutdown() { - shutdownTime = DateTime.now(); + shutdownTime = clock.get().toDateTime(); } - 
synchronized String responseDebugString(long nowMillis) { + synchronized Optional responseDebugString(long nowMillis) { return lastResponseTimeMs == 0 - ? "never received response" - : "received response " + (nowMillis - lastResponseTimeMs) + "ms ago"; + ? Optional.empty() + : Optional.of("received response " + (nowMillis - lastResponseTimeMs) + "ms ago"); } - private Optional getRestartMetrics() { - if (restartCount.get() > 0) { - synchronized (this) { - return Optional.of( - RestartMetrics.create( - restartCount.get(), lastRestartReason, lastRestartTime, errorCount.get())); - } + private synchronized Optional getRestartMetrics() { + if (restartCount > 0) { + return Optional.of( + RestartMetrics.create(restartCount, lastRestartReason, lastRestartTime, errorCount)); } return Optional.empty(); } synchronized Snapshot getSummaryMetrics() { - long nowMs = Instant.now().getMillis(); + long nowMs = clock.get().getMillis(); return Snapshot.create( debugDuration(nowMs, startTimeMs), debugDuration(nowMs, lastSendTimeMs), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index be834bf03bbd..8711bf5850d0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -170,7 +170,7 @@ protected void onResponse(StreamingCommitResponse response) { CommitStatus commitStatus = i < response.getStatusCount() ? 
response.getStatus(i) : CommitStatus.OK; if (pendingRequest == null) { - if (!isShutdown()) { + if (!hasReceivedShutdownSignal()) { // Skip responses when the stream is shutdown since they are now invalid. LOG.error("Got unknown commit request ID: {}", requestId); } @@ -225,11 +225,6 @@ private void flushInternal(Map requests) { } private void issueSingleRequest(long id, PendingRequest pendingRequest) { - if (!prepareForSend(id, pendingRequest)) { - pendingRequest.abort(); - return; - } - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); requestBuilder .addCommitChunkBuilder() @@ -238,19 +233,20 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) { .setShardingKey(pendingRequest.shardingKey()) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); - try { - send(chunk); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. + synchronized (this) { + try { + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + send(chunk); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + } } } private void issueBatchedRequest(Map requests) { - if (!prepareForSend(requests)) { - requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); - return; - } - StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); String lastComputation = null; for (Map.Entry entry : requests.entrySet()) { @@ -266,28 +262,33 @@ private void issueBatchedRequest(Map requests) { .setSerializedWorkItemCommit(request.serializedCommit()); } StreamingCommitWorkRequest request = requestBuilder.build(); - try { - send(request); - } catch (IllegalStateException e) { - // Stream was broken, request will be retried when stream is reopened. 
+ synchronized (this) { + if (!prepareForSend(requests)) { + requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); + return; + } + try { + send(request); + } catch (IllegalStateException e) { + // Stream was broken, request will be retried when stream is reopened. + } } } private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { - if (!prepareForSend(id, pendingRequest)) { - pendingRequest.abort(); - return; - } - checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; + } + for (int i = 0; i < serializedCommit.size(); i += AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE) { int end = i + AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE; ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size())); - StreamingCommitRequestChunk.Builder chunkBuilder = StreamingCommitRequestChunk.newBuilder() .setRequestId(id) @@ -298,7 +299,6 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { if (remaining > 0) { chunkBuilder.setRemainingBytesForWorkItem(remaining); } - StreamingCommitWorkRequest requestChunk = StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); try { @@ -312,28 +312,24 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { } /** Returns true if prepare for send succeeded. */ - private boolean prepareForSend(long id, PendingRequest request) { + private synchronized boolean prepareForSend(long id, PendingRequest request) { synchronized (shutdownLock) { - synchronized (this) { - if (!isShutdown()) { - pending.put(id, request); - return true; - } - return false; + if (!hasReceivedShutdownSignal()) { + pending.put(id, request); + return true; } + return false; } } /** Returns true if prepare for send succeeded. 
*/ - private boolean prepareForSend(Map requests) { + private synchronized boolean prepareForSend(Map requests) { synchronized (shutdownLock) { - synchronized (this) { - if (!isShutdown()) { - pending.putAll(requests); - return true; - } - return false; + if (!hasReceivedShutdownSignal()) { + pending.putAll(requests); + return true; } + return false; } } @@ -418,12 +414,8 @@ private Batcher() { @Override public boolean commitWorkItem( String computation, WorkItemCommitRequest commitRequest, Consumer onDone) { - if (isShutdown()) { - onDone.accept(CommitStatus.ABORTED); - return false; - } - - if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length()) + || hasReceivedShutdownSignal()) { return false; } @@ -436,7 +428,7 @@ public boolean commitWorkItem( @Override public void flush() { try { - if (!isShutdown()) { + if (!hasReceivedShutdownSignal()) { flushInternal(queue); } else { queue.forEach((ignored, request) -> request.abort()); @@ -448,13 +440,9 @@ public void flush() { } void add(long id, PendingRequest request) { - if (isShutdown()) { - request.abort(); - } else { - Preconditions.checkState(canAccept(request.getBytes())); - queuedBytes += request.getBytes(); - queue.put(id, request); - } + Preconditions.checkState(canAccept(request.getBytes())); + queuedBytes += request.getBytes(); + queue.put(id, request); } private boolean canAccept(long requestBytes) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index be9d6c6d06d6..ec26dfacc255 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -198,7 +198,7 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { @Override protected synchronized void onNewStream() { workItemAssemblers.clear(); - if (!isShutdown()) { + if (!hasReceivedShutdownSignal()) { budgetTracker.reset(); GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); StreamingGetWorkRequest request = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 6a809712bd9f..afc40ebabc17 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -156,12 +156,12 @@ private static WindmillStreamShutdownException shutdownException(QueuedRequest r @Override protected synchronized void onNewStream() { - if (isShutdown()) { + if (hasReceivedShutdownSignal()) { return; } send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed && !isShutdown()) { + if (clientClosed && !hasReceivedShutdownSignal()) { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. 
@@ -218,7 +218,7 @@ public GlobalData requestGlobalData(GlobalDataRequest request) { @Override public void refreshActiveWork(Map> heartbeats) { - if (isShutdown()) { + if (hasReceivedShutdownSignal()) { throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); } @@ -334,7 +334,7 @@ public void appendSpecificHtml(PrintWriter writer) { } private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) { - while (!isShutdown()) { + while (!hasReceivedShutdownSignal()) { request.resetResponseStream(); try { queueRequestAndWait(request); @@ -360,7 +360,7 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + streamDebugMetrics.recordSleep(sleepMs); + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertEquals(sleepMs, metricsSnapshot.sleepLeft()); + } + + @Test + public void testSummaryMetrics_withRestarts() { + String restartReason = "something bad happened"; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.incrementAndGetErrors(); + streamDebugMetrics.incrementAndGetRestarts(); + streamDebugMetrics.recordRestartReason(restartReason); + + StreamDebugMetrics.Snapshot metricsSnapshot = streamDebugMetrics.getSummaryMetrics(); + assertTrue(metricsSnapshot.restartMetrics().isPresent()); + StreamDebugMetrics.RestartMetrics restartMetrics = metricsSnapshot.restartMetrics().get(); + assertThat(restartMetrics.lastRestartReason()).isEqualTo(restartReason); + assertThat(restartMetrics.restartCount()).isEqualTo(1); + assertThat(restartMetrics.errorCount()).isEqualTo(1); + assertThat(restartMetrics.lastRestartTime()).isLessThan(DateTime.now()); + assertThat(restartMetrics.lastRestartTime().toInstant()).isGreaterThan(Instant.EPOCH); + } + + @Test + public void testResponseDebugString_neverReceivedResponse() { + 
StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + assertFalse(streamDebugMetrics.responseDebugString(Instant.now().getMillis()).isPresent()); + } + + @Test + public void testResponseDebugString_withResponse() { + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(Instant::now); + streamDebugMetrics.recordResponse(); + assertTrue(streamDebugMetrics.responseDebugString(Instant.now().getMillis()).isPresent()); + } + + @Test + public void testGetStartTime() { + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + assertEquals(0, streamDebugMetrics.getStartTimeMs()); + streamDebugMetrics.recordStart(); + assertThat(streamDebugMetrics.getStartTimeMs()).isEqualTo(aLongTimeAgo.getMillis()); + } + + @Test + public void testGetLastSendTime() { + Instant aLongTimeAgo = Instant.parse("1998-09-04T00:00:00Z"); + Supplier fakeClock = () -> aLongTimeAgo; + StreamDebugMetrics streamDebugMetrics = StreamDebugMetrics.forTesting(fakeClock); + assertEquals(0, streamDebugMetrics.getLastSendTimeMs()); + streamDebugMetrics.recordSend(); + assertThat(streamDebugMetrics.getLastSendTimeMs()).isEqualTo(aLongTimeAgo.getMillis()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 3baa31585a09..df6946ea8763 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -165,7 +165,7 @@ 
public void testCommitWorkItem_afterShutdownFalse() { Set commitStatuses = new HashSet<>(); assertFalse( batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add)); - assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED); + assertThat(commitStatuses).isEmpty(); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java index 8a37958700c9..6584ed1c5ae6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStreamTest.java @@ -117,8 +117,8 @@ public void setUp() throws IOException { @After public void cleanUp() { - inProcessChannel.shutdownNow(); checkNotNull(stream).shutdown(); + inProcessChannel.shutdownNow(); } private GrpcDirectGetWorkStream createGetWorkStream( From 3fafefa560e146363a97ef6353eb0613786e74e1 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Tue, 5 Nov 2024 15:53:27 -0800 Subject: [PATCH 18/23] address PR comments --- .../client/AbstractWindmillStream.java | 98 ++++-------- .../client/ResettableStreamObserver.java | 99 ------------ .../ResettableThrowingStreamObserver.java | 148 ++++++++++++++++++ .../client/StreamClosedException.java | 25 +++ .../windmill/client/WindmillStream.java | 9 +- .../WindmillStreamShutdownException.java | 2 +- .../client/grpc/GrpcCommitWorkStream.java | 47 +++--- .../client/grpc/GrpcDirectGetWorkStream.java | 39 ++--- .../client/grpc/GrpcGetDataStream.java | 77 +++++---- .../grpc/GrpcGetDataStreamRequests.java | 3 +- .../client/grpc/GrpcGetWorkStream.java | 9 +- 
.../grpc/GrpcGetWorkerMetadataStream.java | 6 +- .../grpc/observers/DirectStreamObserver.java | 20 ++- .../grpc/observers/StreamObserverFactory.java | 4 +- .../observers/TerminatingStreamObserver.java | 28 ++++ .../windmill/work/refresh/Heartbeats.java | 4 +- ...ResettableThrowingStreamObserverTest.java} | 44 ++++-- .../client/grpc/GrpcGetDataStreamTest.java | 26 +-- .../grpc/GrpcGetWorkerMetadataStreamTest.java | 4 +- .../client/grpc/GrpcWindmillServerTest.java | 21 ++- 20 files changed, 421 insertions(+), 292 deletions(-) delete mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java rename runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/{ResettableStreamObserverTest.java => ResettableThrowingStreamObserverTest.java} (63%) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 9bc7c4e7d97a..c791a36b6ec8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -17,8 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verify; - import java.io.IOException; import java.io.PrintWriter; import java.util.Set; @@ -73,14 +71,6 @@ public abstract class AbstractWindmillStream implements Win private static final String NEVER_RECEIVED_RESPONSE_LOG_STRING = "never received response"; protected final Sleeper sleeper; - /** - * Used to guard {@link #start()} and {@link #shutdown()} behavior. - * - * @implNote Do NOT hold when performing IO. If also locking on {@code this} in the same context, - * should acquire shutdownLock after {@code this} to prevent deadlocks. - */ - protected final Object shutdownLock = new Object(); - private final Logger logger; private final ExecutorService executor; private final BackOff backoff; @@ -88,21 +78,14 @@ public abstract class AbstractWindmillStream implements Win private final Set> streamRegistry; private final int logEveryNStreamFailures; private final String backendWorkerToken; - private final ResettableStreamObserver requestObserver; + private final ResettableThrowingStreamObserver requestObserver; private final StreamDebugMetrics debugMetrics; protected volatile boolean clientClosed; - /** - * Indicates if the current {@link ResettableStreamObserver} was closed by calling {@link - * #halfClose()}. Separate from {@link #clientClosed} as this is specific to the requestObserver - * and is initially false on retry. 
- */ - private volatile boolean streamClosed; - - @GuardedBy("shutdownLock") + @GuardedBy("this") private boolean isShutdown; - @GuardedBy("shutdownLock") + @GuardedBy("this") private boolean started; protected AbstractWindmillStream( @@ -127,16 +110,16 @@ protected AbstractWindmillStream( this.clientClosed = false; this.isShutdown = false; this.started = false; - this.streamClosed = false; this.finishLatch = new CountDownLatch(1); + this.logger = logger; this.requestObserver = - new ResettableStreamObserver<>( + new ResettableThrowingStreamObserver<>( () -> streamObserverFactory.from( clientFactory, - new AbstractWindmillStream.ResponseObserver())); + new AbstractWindmillStream.ResponseObserver()), + logger); this.sleeper = Sleeper.DEFAULT; - this.logger = logger; this.debugMetrics = StreamDebugMetrics.create(); } @@ -150,7 +133,8 @@ private static String createThreadName(String streamType, String backendWorkerTo protected abstract void onResponse(ResponseT response); /** Called when a new underlying stream to the server has been opened. */ - protected abstract void onNewStream(); + protected abstract void onNewStream() + throws StreamClosedException, WindmillStreamShutdownException; /** Returns whether there are any pending requests that should be retried on a stream break. */ protected abstract boolean hasPendingRequests(); @@ -163,43 +147,21 @@ private static String createThreadName(String streamType, String backendWorkerTo protected abstract void startThrottleTimer(); /** Reflects that {@link #shutdown()} was explicitly called. */ - protected boolean hasReceivedShutdownSignal() { - synchronized (shutdownLock) { - return isShutdown; - } + protected synchronized boolean hasReceivedShutdownSignal() { + return isShutdown; } /** Send a request to the server. 
*/ - protected final void send(RequestT request) { - synchronized (this) { - if (hasReceivedShutdownSignal()) { - return; - } - - if (streamClosed) { - // TODO(m-trieu): throw a more specific exception here (i.e StreamClosedException) - throw new IllegalStateException("Send called on a client closed stream."); - } - - try { - verify(!Thread.holdsLock(shutdownLock), "shutdownLock should not be held during send."); - debugMetrics.recordSend(); - requestObserver.onNext(request); - } catch (StreamObserverCancelledException e) { - if (hasReceivedShutdownSignal()) { - logger.debug("Stream was shutdown during send.", e); - return; - } - - requestObserver.onError(e); - } - } + protected final synchronized void send(RequestT request) + throws StreamClosedException, WindmillStreamShutdownException { + debugMetrics.recordSend(); + requestObserver.onNext(request); } @Override public final void start() { boolean shouldStartStream = false; - synchronized (shutdownLock) { + synchronized (this) { if (!isShutdown && !started) { started = true; shouldStartStream = true; @@ -218,11 +180,7 @@ private void startStream() { while (true) { try { synchronized (this) { - if (hasReceivedShutdownSignal()) { - break; - } debugMetrics.recordStart(); - streamClosed = false; requestObserver.reset(); onNewStream(); if (clientClosed) { @@ -231,7 +189,7 @@ private void startStream() { return; } } catch (WindmillStreamShutdownException e) { - logger.debug("Stream was shutdown waiting to start.", e); + logger.debug("Stream was shutdown while creating new stream.", e); } catch (Exception e) { logger.error("Failed to create new stream, retrying: ", e); try { @@ -271,13 +229,14 @@ public final void maybeSendHealthCheck(Instant lastSendThreshold) { if (!clientClosed && debugMetrics.getLastSendTimeMs() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); - } catch (RuntimeException e) { + } catch (Exception e) { logger.debug("Received exception sending health check.", e); } } } - protected abstract 
void sendHealthCheck(); + protected abstract void sendHealthCheck() + throws WindmillStreamShutdownException, StreamClosedException; /** * @implNote Care is taken that synchronization on this is unnecessary for all status page @@ -312,7 +271,7 @@ public final void appendSummaryHtml(PrintWriter writer) { summaryMetrics.streamAge(), summaryMetrics.timeSinceLastSend(), summaryMetrics.timeSinceLastResponse(), - streamClosed, + requestObserver.isClosed(), hasReceivedShutdownSignal(), summaryMetrics.shutdownTime().orElse(null)); } @@ -327,8 +286,11 @@ public final void appendSummaryHtml(PrintWriter writer) { public final synchronized void halfClose() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. clientClosed = true; - requestObserver.onCompleted(); - streamClosed = true; + try { + requestObserver.onCompleted(); + } catch (StreamClosedException | WindmillStreamShutdownException e) { + logger.warn("Stream was previously closed or shutdown."); + } } @Override @@ -348,13 +310,13 @@ public String backendWorkerToken() { @Override public final void shutdown() { - // Don't lock on "this" as isShutdown checks are used in the stream to free blocked - // threads or as exit conditions to loops. - synchronized (shutdownLock) { + // Don't lock on "this" before poisoning the request observer as allow IO to block shutdown. 
+ requestObserver.poison(); + synchronized (this) { if (!isShutdown) { isShutdown = true; debugMetrics.recordShutdown(); - requestObserver.poison(); + shutdownInternal(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java deleted file mode 100644 index e5b99f2b6ae6..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserver.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.runners.dataflow.worker.windmill.client; - -import java.util.function.Supplier; -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy; -import javax.annotation.concurrent.ThreadSafe; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; - -/** - * Request observer that allows resetting its internal delegate using the given {@link - * #streamObserverFactory}. - * - * @implNote {@link StreamObserver}s generated by {@link #streamObserverFactory} are expected to be - * {@link ThreadSafe}. - */ -@ThreadSafe -@Internal -final class ResettableStreamObserver implements StreamObserver { - private final Supplier> streamObserverFactory; - - @GuardedBy("this") - private @Nullable StreamObserver delegateStreamObserver; - - @GuardedBy("this") - private boolean isPoisoned; - - ResettableStreamObserver(Supplier> streamObserverFactory) { - this.streamObserverFactory = streamObserverFactory; - this.delegateStreamObserver = null; - this.isPoisoned = false; - } - - private synchronized StreamObserver delegate() { - if (isPoisoned) { - throw new WindmillStreamShutdownException("Explicit call to shutdown stream."); - } - - return Preconditions.checkNotNull( - delegateStreamObserver, - "requestObserver cannot be null. Missing a call to startStream() to initialize."); - } - - /** Creates a new delegate to use for future {@link StreamObserver} methods. */ - synchronized void reset() { - if (isPoisoned) { - throw new WindmillStreamShutdownException("Explicit call to shutdown stream."); - } - - delegateStreamObserver = streamObserverFactory.get(); - } - - /** - * Indicates that the request observer should no longer be used. Attempts to perform operations on - * the request observer will throw an {@link WindmillStreamShutdownException}. 
- */ - synchronized void poison() { - if (!isPoisoned) { - isPoisoned = true; - if (delegateStreamObserver != null) { - delegateStreamObserver.onError( - new WindmillStreamShutdownException("Explicit call to shutdown stream.")); - } - } - } - - @Override - public void onNext(T t) { - delegate().onNext(t); - } - - @Override - public void onError(Throwable throwable) { - delegate().onError(throwable); - } - - @Override - public void onCompleted() { - delegate().onCompleted(); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java new file mode 100644 index 000000000000..d0f1236f36e1 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import java.util.function.Supplier; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.slf4j.Logger; + +/** + * Request observer that allows resetting its internal delegate using the given {@link + * #streamObserverFactory}. + * + * @implNote {@link StreamObserver}s generated by {@link #streamObserverFactory} are expected to be + * {@link ThreadSafe}. Has same methods declared in {@link StreamObserver}, but they throw + * {@link StreamClosedException} and {@link WindmillStreamShutdownException}, which much be + * handled by callers. + */ +@ThreadSafe +@Internal +final class ResettableThrowingStreamObserver { + private final Supplier> streamObserverFactory; + private final Logger logger; + + @GuardedBy("this") + private @Nullable TerminatingStreamObserver delegateStreamObserver; + + @GuardedBy("this") + private boolean isPoisoned = false; + + /** + * Indicates that the current delegate is closed via {@link #poison() or {@link #onCompleted()}}. + * If not poisoned, a call to {@link #reset()} is required to perform future operations on the + * StreamObserver. 
+ */ + @GuardedBy("this") + private boolean isCurrentStreamClosed = false; + + ResettableThrowingStreamObserver( + Supplier> streamObserverFactory, Logger logger) { + this.streamObserverFactory = streamObserverFactory; + this.logger = logger; + this.delegateStreamObserver = null; + } + + private synchronized StreamObserver delegate() + throws WindmillStreamShutdownException, StreamClosedException { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Stream is already shutdown."); + } + + if (isCurrentStreamClosed) { + throw new StreamClosedException( + "Current stream is closed, requires reset for future stream operations."); + } + + return Preconditions.checkNotNull( + delegateStreamObserver, + "requestObserver cannot be null. Missing a call to startStream() to initialize."); + } + + /** Creates a new delegate to use for future {@link StreamObserver} methods. */ + synchronized void reset() throws WindmillStreamShutdownException { + if (isPoisoned) { + throw new WindmillStreamShutdownException("Stream is already shutdown."); + } + + delegateStreamObserver = streamObserverFactory.get(); + isCurrentStreamClosed = false; + } + + /** + * Indicates that the request observer should no longer be used. Attempts to perform operations on + * the request observer will throw an {@link WindmillStreamShutdownException}. + */ + synchronized void poison() { + if (!isPoisoned) { + isPoisoned = true; + if (delegateStreamObserver != null) { + delegateStreamObserver.terminate( + new WindmillStreamShutdownException("Explicit call to shutdown stream.")); + delegateStreamObserver = null; + isCurrentStreamClosed = true; + } + } + } + + public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownException { + // Make sure onNext and onError below to be called on the same StreamObserver instance. + StreamObserver delegate = delegate(); + try { + // Do NOT lock while sending message over the stream as this will block other StreamObserver + // operations. 
+ delegate.onNext(t); + } catch (StreamObserverCancelledException e) { + synchronized (this) { + if (isPoisoned) { + logger.debug("Stream was shutdown during send.", e); + return; + } + } + + try { + delegate.onError(e); + } catch (RuntimeException ignored) { + // If the delegate above was already terminated via onError or onComplete from another + // thread. + logger.warn("StreamObserver was previously cancelled.", e); + } + } + } + + public void onError(Throwable throwable) + throws StreamClosedException, WindmillStreamShutdownException { + delegate().onError(throwable); + } + + public synchronized void onCompleted() + throws StreamClosedException, WindmillStreamShutdownException { + delegate().onCompleted(); + isCurrentStreamClosed = true; + } + + synchronized boolean isClosed() { + return isCurrentStreamClosed; + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java new file mode 100644 index 000000000000..a5f7ef0b8312 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +/** Indicates that the current stream was closed. */ +public final class StreamClosedException extends Exception { + StreamClosedException(String s) { + super(s); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index 361531ce4f2d..51bc03e8e0e7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -75,13 +75,16 @@ default void setBudget(long newItems, long newBytes) { interface GetDataStream extends WindmillStream { /** Issues a keyed GetData fetch, blocking until the result is ready. */ Windmill.KeyedGetDataResponse requestKeyedData( - String computation, Windmill.KeyedGetDataRequest request); + String computation, Windmill.KeyedGetDataRequest request) + throws WindmillStreamShutdownException; /** Issues a global GetData fetch, blocking until the result is ready. */ - Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request); + Windmill.GlobalData requestGlobalData(Windmill.GlobalDataRequest request) + throws WindmillStreamShutdownException; /** Tells windmill processing is ongoing for the given keys. 
*/ - void refreshActiveWork(Map> heartbeats); + void refreshActiveWork(Map> heartbeats) + throws WindmillStreamShutdownException; void onHeartbeatResponse(List responses); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java index 5f4387d6111f..5d4e5de5eda3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -18,7 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.client; /** Thrown when operations are requested on a {@link WindmillStream} has been shutdown/closed. 
*/ -public final class WindmillStreamShutdownException extends RuntimeException { +public final class WindmillStreamShutdownException extends Exception { public WindmillStreamShutdownException(String message) { super(message); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 8711bf5850d0..40d4cf0750de 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -38,7 +38,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; @@ -121,7 +123,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - protected synchronized void onNewStream() { + protected synchronized void onNewStream() + throws StreamClosedException, WindmillStreamShutdownException { send(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); try (Batcher 
resendBatcher = new Batcher()) { for (Map.Entry entry : pending.entrySet()) { @@ -148,7 +151,7 @@ protected boolean hasPendingRequests() { } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { if (hasPendingRequests()) { StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder(); builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID); @@ -206,7 +209,8 @@ protected void startThrottleTimer() { commitWorkThrottleTimer.start(); } - private void flushInternal(Map requests) { + private void flushInternal(Map requests) + throws WindmillStreamShutdownException { if (requests.isEmpty()) { return; } @@ -224,7 +228,8 @@ private void flushInternal(Map requests) { } } - private void issueSingleRequest(long id, PendingRequest pendingRequest) { + private void issueSingleRequest(long id, PendingRequest pendingRequest) + throws WindmillStreamShutdownException { StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); requestBuilder .addCommitChunkBuilder() @@ -240,13 +245,14 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) { return; } send(chunk); - } catch (IllegalStateException e) { + } catch (StreamClosedException e) { // Stream was broken, request will be retried when stream is reopened. } } } - private void issueBatchedRequest(Map requests) { + private void issueBatchedRequest(Map requests) + throws WindmillStreamShutdownException { StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder(); String lastComputation = null; for (Map.Entry entry : requests.entrySet()) { @@ -269,13 +275,14 @@ private void issueBatchedRequest(Map requests) { } try { send(request); - } catch (IllegalStateException e) { + } catch (StreamClosedException e) { // Stream was broken, request will be retried when stream is reopened. 
} } } - private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { + private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) + throws WindmillStreamShutdownException { checkNotNull(pendingRequest.computationId(), "Cannot commit WorkItem w/o a computationId."); ByteString serializedCommit = pendingRequest.serializedCommit(); synchronized (this) { @@ -303,7 +310,7 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); try { send(requestChunk); - } catch (IllegalStateException e) { + } catch (StreamClosedException e) { // Stream was broken, request will be retried when stream is reopened. break; } @@ -313,24 +320,20 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) { /** Returns true if prepare for send succeeded. */ private synchronized boolean prepareForSend(long id, PendingRequest request) { - synchronized (shutdownLock) { - if (!hasReceivedShutdownSignal()) { - pending.put(id, request); - return true; - } - return false; + if (!hasReceivedShutdownSignal()) { + pending.put(id, request); + return true; } + return false; } /** Returns true if prepare for send succeeded. 
*/ private synchronized boolean prepareForSend(Map requests) { - synchronized (shutdownLock) { - if (!hasReceivedShutdownSignal()) { - pending.putAll(requests); - return true; - } - return false; + if (!hasReceivedShutdownSignal()) { + pending.putAll(requests); + return true; } + return false; } @AutoValue @@ -433,6 +436,8 @@ public void flush() { } else { queue.forEach((ignored, request) -> request.abort()); } + } catch (WindmillStreamShutdownException e) { + queue.forEach((ignored, request) -> request.abort()); } finally { queuedBytes = 0; queue.clear(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index ec26dfacc255..750d61ac2f29 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -33,7 +33,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GetWorkResponseChunkAssembler.AssembledWorkItem; @@ -188,7 +190,7 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { budgetTracker.recordBudgetRequested(extension); try { send(request); - } catch (IllegalStateException e) { + } catch (StreamClosedException | WindmillStreamShutdownException e) { // Stream was closed. } }); @@ -196,24 +198,23 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { } @Override - protected synchronized void onNewStream() { + protected synchronized void onNewStream() + throws WindmillStreamShutdownException, StreamClosedException { workItemAssemblers.clear(); - if (!hasReceivedShutdownSignal()) { - budgetTracker.reset(); - GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); - StreamingGetWorkRequest request = - StreamingGetWorkRequest.newBuilder() - .setRequest( - requestHeader - .toBuilder() - .setMaxItems(initialGetWorkBudget.items()) - .setMaxBytes(initialGetWorkBudget.bytes()) - .build()) - .build(); - lastRequest.set(request); - budgetTracker.recordBudgetRequested(initialGetWorkBudget); - send(request); - } + budgetTracker.reset(); + GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); + StreamingGetWorkRequest request = + StreamingGetWorkRequest.newBuilder() + .setRequest( + requestHeader + .toBuilder() + .setMaxItems(initialGetWorkBudget.items()) + .setMaxBytes(initialGetWorkBudget.bytes()) + .build()) + .build(); + lastRequest.set(request); + budgetTracker.recordBudgetRequested(initialGetWorkBudget); + send(request); } @Override @@ -231,7 +232,7 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws WindmillStreamShutdownException, StreamClosedException { send(HEALTH_CHECK_REQUEST); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index afc40ebabc17..ee403b100fa5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -48,6 +48,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedBatch; @@ -69,7 +70,7 @@ final class GrpcGetDataStream private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = StreamingGetDataRequest.newBuilder().build(); - /** @implNote {@link QueuedBatch} objects in the queue are is guarded by {@link #shutdownLock} */ + /** @implNote {@link QueuedBatch} objects in the queue are is guarded by {@code this} */ private final Deque batches; private final Map pending; @@ -144,22 +145,28 @@ static GrpcGetDataStream create( processHeartbeatResponses); } - private static WindmillStreamShutdownException shutdownException(QueuedBatch batch) { + private static WindmillStreamShutdownException shutdownExceptionFor(QueuedBatch batch) { return new 
WindmillStreamShutdownException( "Stream was closed when attempting to send " + batch.requestsCount() + " requests."); } - private static WindmillStreamShutdownException shutdownException(QueuedRequest request) { + private static WindmillStreamShutdownException shutdownExceptionFor(QueuedRequest request) { return new WindmillStreamShutdownException( "Cannot send request=[" + request + "] on closed stream."); } - @Override - protected synchronized void onNewStream() { - if (hasReceivedShutdownSignal()) { - return; + private void sendIgnoringClosed(StreamingGetDataRequest getDataRequest) + throws WindmillStreamShutdownException { + try { + send(getDataRequest); + } catch (StreamClosedException e) { + // Stream was closed on send, will be retried on stream restart. } + } + @Override + protected synchronized void onNewStream() + throws StreamClosedException, WindmillStreamShutdownException { send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); if (clientClosed && !hasReceivedShutdownSignal()) { // We rely on close only occurring after all methods on the stream have returned. 
@@ -205,19 +212,22 @@ private long uniqueId() { } @Override - public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) { + public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) + throws WindmillStreamShutdownException { return issueRequest( QueuedRequest.forComputation(uniqueId(), computation, request), KeyedGetDataResponse::parseFrom); } @Override - public GlobalData requestGlobalData(GlobalDataRequest request) { + public GlobalData requestGlobalData(GlobalDataRequest request) + throws WindmillStreamShutdownException { return issueRequest(QueuedRequest.global(uniqueId(), request), GlobalData::parseFrom); } @Override - public void refreshActiveWork(Map> heartbeats) { + public void refreshActiveWork(Map> heartbeats) + throws WindmillStreamShutdownException { if (hasReceivedShutdownSignal()) { throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); } @@ -232,7 +242,7 @@ public void refreshActiveWork(Map> heartbea if (builderBytes > 0 && (builderBytes + bytes > AbstractWindmillStream.RPC_STREAM_CHUNK_SIZE || builder.getRequestIdCount() >= streamingRpcBatchLimit)) { - send(builder.build()); + sendIgnoringClosed(builder.build()); builderBytes = 0; builder.clear(); } @@ -251,7 +261,7 @@ public void refreshActiveWork(Map> heartbea } if (builderBytes > 0) { - send(builder.build()); + sendIgnoringClosed(builder.build()); } } else { // No translation necessary, but we must still respect `RPC_STREAM_CHUNK_SIZE`. 
@@ -266,7 +276,7 @@ public void refreshActiveWork(Map> heartbea if (computationHeartbeatBuilder.getHeartbeatRequestsCount() > 0) { builder.addComputationHeartbeatRequest(computationHeartbeatBuilder.build()); } - send(builder.build()); + sendIgnoringClosed(builder.build()); builderBytes = 0; builder.clear(); computationHeartbeatBuilder.clear().setComputationId(entry.getKey()); @@ -278,7 +288,7 @@ public void refreshActiveWork(Map> heartbea } if (builderBytes > 0) { - send(builder.build()); + sendIgnoringClosed(builder.build()); } } } @@ -289,7 +299,7 @@ public void onHeartbeatResponse(List resp } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { if (hasPendingRequests()) { send(HEALTH_CHECK_REQUEST); } @@ -333,7 +343,8 @@ public void appendSpecificHtml(PrintWriter writer) { writer.append("]"); } - private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) { + private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) + throws WindmillStreamShutdownException { while (!hasReceivedShutdownSignal()) { request.resetResponseStream(); try { @@ -359,21 +370,23 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn cancelledRequests = createStreamCancelledErrorMessages(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 511b4a6b07bd..75f0d0065153 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -28,7 +28,9 @@ import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GetWorkResponseChunkAssembler.AssembledWorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; @@ -120,14 +122,15 @@ private void sendRequestExtension(long moreItems, long moreBytes) { () -> { try { send(extension); - } catch (IllegalStateException e) { + } catch (StreamClosedException | WindmillStreamShutdownException e) { // Stream was closed. 
} }); } @Override - protected synchronized void onNewStream() { + protected synchronized void onNewStream() + throws StreamClosedException, WindmillStreamShutdownException { workItemAssemblers.clear(); inflightMessages.set(request.getMaxItems()); inflightBytes.set(request.getMaxBytes()); @@ -151,7 +154,7 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - public void sendHealthCheck() { + public void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { send( StreamingGetWorkRequest.newBuilder() .setRequestExtension( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index 7fd1f011a4bb..a076b4f58258 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -28,7 +28,9 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import 
org.apache.beam.sdk.util.BackOff; @@ -132,7 +134,7 @@ private Optional extractWindmillEndpointsFrom( } @Override - protected void onNewStream() { + protected void onNewStream() throws StreamClosedException, WindmillStreamShutdownException { send(workerMetadataRequest); } @@ -150,7 +152,7 @@ protected void startThrottleTimer() { } @Override - protected void sendHealthCheck() { + protected void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { send(HEALTH_CHECK_REQUEST); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index fa9e9e15b440..e072ece0fac2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -37,7 +37,7 @@ * becomes ready. */ @ThreadSafe -public final class DirectStreamObserver implements StreamObserver { +public final class DirectStreamObserver implements TerminatingStreamObserver { private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class); private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30; @@ -69,7 +69,7 @@ public DirectStreamObserver( } @Override - public void onNext(T value) { + public void onNext(T value) throws StreamObserverCancelledException { int awaitPhase = -1; long totalSecondsWaited = 0; long waitSeconds = 1; @@ -155,8 +155,6 @@ public void onNext(T value) { @Override public void onError(Throwable t) { - // Free the blocked threads in onNext(). 
- isReadyNotifier.forceTermination(); synchronized (lock) { outboundObserver.onError(t); } @@ -164,13 +162,23 @@ public void onError(Throwable t) { @Override public void onCompleted() { - // Free the blocked threads in onNext(). - isReadyNotifier.forceTermination(); synchronized (lock) { outboundObserver.onCompleted(); } } + @Override + public void terminate(Throwable terminationException) { + // Free the blocked threads in onNext(). + isReadyNotifier.forceTermination(); + try { + onError(terminationException); + } catch (RuntimeException e) { + // If onError or onComplete was previously called, this will throw. + LOG.warn("StreamObserver was already terminated."); + } + } + private String constructStreamCancelledErrorMessage(long totalSecondsWaited) { return deadlineSeconds > 0 ? "Waited " diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java index cb4415bdab18..01e854492bf9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverFactory.java @@ -33,7 +33,7 @@ public static StreamObserverFactory direct( return new Direct(deadlineSeconds, messagesBetweenIsReadyChecks); } - public abstract StreamObserver from( + public abstract TerminatingStreamObserver from( Function, StreamObserver> clientFactory, StreamObserver responseObserver); @@ -47,7 +47,7 @@ private static class Direct extends StreamObserverFactory { } @Override - public StreamObserver from( + public TerminatingStreamObserver from( Function, StreamObserver> clientFactory, StreamObserver 
inboundObserver) { AdvancingPhaser phaser = new AdvancingPhaser(1); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java new file mode 100644 index 000000000000..fb2555c8454f --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers; + +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Internal; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; + +@Internal +public interface TerminatingStreamObserver extends StreamObserver { + + /** Terminates the StreamObserver. 
*/ + void terminate(Throwable terminationException); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java index 071bf7fa3d43..6d768e8a972c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java @@ -21,12 +21,14 @@ import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap; /** Heartbeat requests and the work that was used to generate the heartbeat requests. 
*/ +@Internal @AutoValue -abstract class Heartbeats { +public abstract class Heartbeats { static Heartbeats.Builder builder() { return new AutoValue_Heartbeats.Builder(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java similarity index 63% rename from runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java rename to runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java index 189a244c822e..2f1a45abd7a5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java @@ -26,18 +26,21 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.TerminatingStreamObserver; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.slf4j.LoggerFactory; @RunWith(JUnit4.class) -public class ResettableStreamObserverTest { - private final StreamObserver delegate = newDelegate(); +public class ResettableThrowingStreamObserverTest { + private final TerminatingStreamObserver delegate = newDelegate(); - private static StreamObserver newDelegate() { + private static TerminatingStreamObserver newDelegate() { return spy( - new StreamObserver() { + new 
TerminatingStreamObserver() { @Override public void onNext(Integer integer) {} @@ -46,41 +49,44 @@ public void onError(Throwable throwable) {} @Override public void onCompleted() {} + + @Override + public void terminate(Throwable terminationException) {} }); } @Test public void testPoison_beforeDelegateSet() { - ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); observer.poison(); verifyNoInteractions(delegate); } @Test - public void testPoison_afterDelegateSet() { - ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + public void testPoison_afterDelegateSet() throws WindmillStreamShutdownException { + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); observer.reset(); observer.poison(); - verify(delegate).onError(isA(WindmillStreamShutdownException.class)); + verify(delegate).terminate(isA(WindmillStreamShutdownException.class)); } @Test public void testReset_afterPoisonedThrows() { - ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); observer.poison(); assertThrows(WindmillStreamShutdownException.class, observer::reset); } @Test public void testOnNext_afterPoisonedThrows() { - ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); observer.poison(); assertThrows(WindmillStreamShutdownException.class, () -> observer.onNext(1)); } @Test public void testOnError_afterPoisonedThrows() { - ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); observer.poison(); assertThrows( WindmillStreamShutdownException.class, @@ -89,18 +95,19 @@ public void 
testOnError_afterPoisonedThrows() { @Test public void testOnCompleted_afterPoisonedThrows() { - ResettableStreamObserver observer = new ResettableStreamObserver<>(() -> delegate); + ResettableThrowingStreamObserver observer = newStreamObserver(() -> delegate); observer.poison(); assertThrows(WindmillStreamShutdownException.class, observer::onCompleted); } @Test - public void testReset_usesNewDelegate() { + public void testReset_usesNewDelegate() + throws WindmillStreamShutdownException, StreamClosedException { List> delegates = new ArrayList<>(); - ResettableStreamObserver observer = - new ResettableStreamObserver<>( + ResettableThrowingStreamObserver observer = + newStreamObserver( () -> { - StreamObserver delegate = newDelegate(); + TerminatingStreamObserver delegate = newDelegate(); delegates.add(delegate); return delegate; }); @@ -115,4 +122,9 @@ public void testReset_usesNewDelegate() { verify(firstObserver).onNext(eq(1)); verify(secondObserver).onNext(eq(2)); } + + private ResettableThrowingStreamObserver newStreamObserver( + Supplier> delegate) { + return new ResettableThrowingStreamObserver<>(delegate, LoggerFactory.getLogger(getClass())); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index f71627160846..0ce455ac1270 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -126,14 +126,19 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdow throw new RuntimeException(e); } } - getDataStream.requestKeyedData( - 
"computationId", - Windmill.KeyedGetDataRequest.newBuilder() - .setKey(ByteString.EMPTY) - .setShardingKey(i) - .setCacheToken(i) - .setWorkToken(i) - .build()); + try { + + getDataStream.requestKeyedData( + "computationId", + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(i) + .setCacheToken(i) + .setWorkToken(i) + .build()); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } }) // Run the code above on multiple threads. .map(runnable -> CompletableFuture.runAsync(runnable, getDataStreamSenders)) @@ -154,7 +159,10 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdow if (i % 2 == 0) { assertTrue(sendFuture.isCompletedExceptionally()); ExecutionException e = assertThrows(ExecutionException.class, sendFuture::get); - assertThat(e).hasCauseThat().isInstanceOf(WindmillStreamShutdownException.class); + assertThat(e) + .hasCauseThat() + .hasCauseThat() + .isInstanceOf(WindmillStreamShutdownException.class); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 40a63ee90d31..1239c90dc0c6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -37,6 +37,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; +import 
org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Server; @@ -255,7 +257,7 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { } @Test - public void testSendHealthCheck() { + public void testSendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { TestGetWorkMetadataRequestObserver requestObserver = Mockito.spy(new TestGetWorkMetadataRequestObserver()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index c3f38a571b76..23ad89345bca 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -72,6 +72,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.WindmillChannelFactory; import 
org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.testing.FakeWindmillStubFactoryFactory; @@ -489,16 +490,24 @@ private void flushResponse() { final String s = i % 5 == 0 ? largeString(i) : "tag"; executor.submit( () -> { - errorCollector.checkThat( - stream.requestKeyedData("computation", makeGetDataRequest(key, s)), - Matchers.equalTo(makeGetDataResponse(s))); + try { + errorCollector.checkThat( + stream.requestKeyedData("computation", makeGetDataRequest(key, s)), + Matchers.equalTo(makeGetDataResponse(s))); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } done.countDown(); }); executor.execute( () -> { - errorCollector.checkThat( - stream.requestGlobalData(makeGlobalDataRequest(key)), - Matchers.equalTo(makeGlobalDataResponse(key))); + try { + errorCollector.checkThat( + stream.requestGlobalData(makeGlobalDataRequest(key)), + Matchers.equalTo(makeGlobalDataResponse(key))); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } done.countDown(); }); } From ab34387b03fdf836edb9b52d2ec641c42ef39b3f Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Wed, 6 Nov 2024 22:45:10 -0800 Subject: [PATCH 19/23] address PR comments --- .../client/AbstractWindmillStream.java | 66 +++--- .../ResettableThrowingStreamObserver.java | 10 + .../client/StreamClosedException.java | 25 --- .../windmill/client/StreamDebugMetrics.java | 18 +- .../WindmillStreamShutdownException.java | 33 +-- .../client/grpc/GrpcCommitWorkStream.java | 71 +++---- .../client/grpc/GrpcDirectGetWorkStream.java | 19 +- .../client/grpc/GrpcGetDataStream.java | 59 +++--- .../client/grpc/GrpcGetWorkStream.java | 14 +- .../grpc/GrpcGetWorkerMetadataStream.java | 9 +- .../grpc/observers/DirectStreamObserver.java | 17 +- .../StreamObserverCancelledException.java | 8 +- .../client/AbstractWindmillStreamTest.java | 158 ++++++++++++++ 
.../ResettableThrowingStreamObserverTest.java | 2 +- .../client/grpc/GrpcCommitWorkStreamTest.java | 10 +- .../grpc/GrpcGetWorkerMetadataStreamTest.java | 3 +- .../client/grpc/GrpcWindmillServerTest.java | 3 +- .../observers/DirectStreamObserverTest.java | 198 ++++++++++++++++++ 18 files changed, 540 insertions(+), 183 deletions(-) delete mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java create mode 100644 runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index c791a36b6ec8..c8a00d7caea2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.io.IOException; import java.io.PrintWriter; import java.util.Set; @@ -27,6 +28,7 @@ import java.util.concurrent.TimeUnit; import java.util.function.Function; import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.runners.dataflow.worker.windmill.client.ResettableThrowingStreamObserver.StreamClosedException; import 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; @@ -48,7 +50,7 @@ * and {@link #onNewStream()} to perform any work that must be done when a new stream is created, * such as sending headers or retrying requests. * - *

{@link #send(RequestT)} and {@link #startStream()} should not be called from {@link + *

{@link #trySend(RequestT)} and {@link #startStream()} should not be called from {@link * #onResponse(ResponseT)}; use {@link #executeSafely(Runnable)} instead. * *

Synchronization on this is used to synchronize the gRpc stream state and internal data @@ -80,10 +82,12 @@ public abstract class AbstractWindmillStream implements Win private final String backendWorkerToken; private final ResettableThrowingStreamObserver requestObserver; private final StreamDebugMetrics debugMetrics; - protected volatile boolean clientClosed; @GuardedBy("this") - private boolean isShutdown; + protected boolean clientClosed; + + @GuardedBy("this") + protected boolean isShutdown; @GuardedBy("this") private boolean started; @@ -133,8 +137,7 @@ private static String createThreadName(String streamType, String backendWorkerTo protected abstract void onResponse(ResponseT response); /** Called when a new underlying stream to the server has been opened. */ - protected abstract void onNewStream() - throws StreamClosedException, WindmillStreamShutdownException; + protected abstract void onNewStream() throws WindmillStreamShutdownException; /** Returns whether there are any pending requests that should be retried on a stream break. */ protected abstract boolean hasPendingRequests(); @@ -146,16 +149,19 @@ protected abstract void onNewStream() */ protected abstract void startThrottleTimer(); - /** Reflects that {@link #shutdown()} was explicitly called. */ - protected synchronized boolean hasReceivedShutdownSignal() { - return isShutdown; - } - - /** Send a request to the server. */ - protected final synchronized void send(RequestT request) - throws StreamClosedException, WindmillStreamShutdownException { + /** Try to send a request to the server. Returns true if the request was successfully sent. */ + @CanIgnoreReturnValue + protected final synchronized boolean trySend(RequestT request) + throws WindmillStreamShutdownException { debugMetrics.recordSend(); - requestObserver.onNext(request); + try { + requestObserver.onNext(request); + return true; + } catch (StreamClosedException e) { + // Stream was broken, requests may be retried when stream is reopened. 
+ } + + return false; } @Override @@ -189,6 +195,7 @@ private void startStream() { return; } } catch (WindmillStreamShutdownException e) { + // shutdown() is responsible for cleaning up pending requests. logger.debug("Stream was shutdown while creating new stream.", e); } catch (Exception e) { logger.error("Failed to create new stream, retrying: ", e); @@ -201,6 +208,8 @@ private void startStream() { logger.info( "Interrupted during {} creation backoff. The stream will not be created.", getClass()); + // Shutdown the stream to clean up any dangling resources and pending requests. + shutdown(); break; } catch (IOException ioe) { // Keep trying to create the stream. @@ -225,7 +234,7 @@ protected final void executeSafely(Runnable runnable) { } } - public final void maybeSendHealthCheck(Instant lastSendThreshold) { + public final synchronized void maybeSendHealthCheck(Instant lastSendThreshold) { if (!clientClosed && debugMetrics.getLastSendTimeMs() < lastSendThreshold.getMillis()) { try { sendHealthCheck(); @@ -235,8 +244,7 @@ public final void maybeSendHealthCheck(Instant lastSendThreshold) { } } - protected abstract void sendHealthCheck() - throws WindmillStreamShutdownException, StreamClosedException; + protected abstract void sendHealthCheck() throws WindmillStreamShutdownException; /** * @implNote Care is taken that synchronization on this is unnecessary for all status page @@ -257,7 +265,7 @@ public final void appendSummaryHtml(PrintWriter writer) { metrics.lastRestartTime(), metrics.errorCount())); - if (clientClosed) { + if (summaryMetrics.isClientClosed()) { writer.write(", client closed"); } @@ -272,7 +280,7 @@ public final void appendSummaryHtml(PrintWriter writer) { summaryMetrics.timeSinceLastSend(), summaryMetrics.timeSinceLastResponse(), requestObserver.isClosed(), - hasReceivedShutdownSignal(), + summaryMetrics.shutdownTime().isPresent(), summaryMetrics.shutdownTime().orElse(null)); } @@ -285,6 +293,7 @@ public final void appendSummaryHtml(PrintWriter 
writer) { @Override public final synchronized void halfClose() { // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream. + debugMetrics.recordHalfClose(); clientClosed = true; try { requestObserver.onCompleted(); @@ -316,7 +325,6 @@ public final void shutdown() { if (!isShutdown) { isShutdown = true; debugMetrics.recordShutdown(); - shutdownInternal(); } } @@ -412,15 +420,17 @@ private void recordStreamStatus(Status status) { } /** Returns true if the stream was torn down and should not be restarted internally. */ - private synchronized boolean maybeTeardownStream() { - if (hasReceivedShutdownSignal() || (clientClosed && !hasPendingRequests())) { - streamRegistry.remove(AbstractWindmillStream.this); - finishLatch.countDown(); - executor.shutdownNow(); - return true; - } + private boolean maybeTeardownStream() { + synchronized (AbstractWindmillStream.this) { + if (isShutdown || (clientClosed && !hasPendingRequests())) { + streamRegistry.remove(AbstractWindmillStream.this); + finishLatch.countDown(); + executor.shutdownNow(); + return true; + } - return false; + return false; + } } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java index d0f1236f36e1..783f09687a37 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -145,4 +145,14 @@ public synchronized void onCompleted() synchronized boolean isClosed() { return isCurrentStreamClosed; } + + /** + * Indicates that the current stream was closed and the 
{@link StreamObserver} has finished via + * {@link StreamObserver#onCompleted()}. The stream may perform + */ + static final class StreamClosedException extends Exception { + private StreamClosedException(String s) { + super(s); + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java deleted file mode 100644 index a5f7ef0b8312..000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamClosedException.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker.windmill.client; - -/** Indicates that the current stream was closed. 
*/ -public final class StreamClosedException extends Exception { - StreamClosedException(String s) { - super(s); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java index d813b6d7c9c2..fb3bf0323c47 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -59,6 +59,9 @@ final class StreamDebugMetrics { @GuardedBy("this") private DateTime shutdownTime = null; + @GuardedBy("this") + private boolean clientClosed = false; + private StreamDebugMetrics(Supplier clock) { this.clock = clock; } @@ -122,6 +125,10 @@ synchronized void recordShutdown() { shutdownTime = clock.get().toDateTime(); } + synchronized void recordHalfClose() { + clientClosed = true; + } + synchronized Optional responseDebugString(long nowMillis) { return lastResponseTimeMs == 0 ? 
Optional.empty() @@ -145,7 +152,8 @@ synchronized Snapshot getSummaryMetrics() { debugDuration(nowMs, lastResponseTimeMs), getRestartMetrics(), sleepUntil - nowMs(), - shutdownTime); + shutdownTime, + clientClosed); } @AutoValue @@ -156,14 +164,16 @@ private static Snapshot create( long timeSinceLastResponse, Optional restartMetrics, long sleepLeft, - @Nullable DateTime shutdownTime) { + @Nullable DateTime shutdownTime, + boolean isClientClosed) { return new AutoValue_StreamDebugMetrics_Snapshot( streamAge, timeSinceLastSend, timeSinceLastResponse, restartMetrics, sleepLeft, - Optional.ofNullable(shutdownTime)); + Optional.ofNullable(shutdownTime), + isClientClosed); } abstract long streamAge(); @@ -177,6 +187,8 @@ private static Snapshot create( abstract long sleepLeft(); abstract Optional shutdownTime(); + + abstract boolean isClientClosed(); } @AutoValue diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java index 5d4e5de5eda3..8e401e4d2921 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -1,23 +1,28 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + package org.apache.beam.runners.dataflow.worker.windmill.client; -/** Thrown when operations are requested on a {@link WindmillStream} has been shutdown/closed. */ +/** + * Thrown when operations are requested on a {@link WindmillStream} has been shutdown. Future + * operations on the stream are not allowed and will throw an {@link + * WindmillStreamShutdownException}. 
+ */ public final class WindmillStreamShutdownException extends Exception { public WindmillStreamShutdownException(String message) { super(message); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 40d4cf0750de..6951a6cdf772 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -38,7 +38,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; @@ -123,9 +122,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - protected synchronized void onNewStream() - throws StreamClosedException, WindmillStreamShutdownException { - send(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); + protected synchronized void onNewStream() throws WindmillStreamShutdownException { + trySend(StreamingCommitWorkRequest.newBuilder().setHeader(jobHeader).build()); try (Batcher resendBatcher = new Batcher()) { for (Map.Entry entry : pending.entrySet()) { if 
(!resendBatcher.canAccept(entry.getValue().getBytes())) { @@ -151,11 +149,11 @@ protected boolean hasPendingRequests() { } @Override - public void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { + public void sendHealthCheck() throws WindmillStreamShutdownException { if (hasPendingRequests()) { StreamingCommitWorkRequest.Builder builder = StreamingCommitWorkRequest.newBuilder(); builder.addCommitChunkBuilder().setRequestId(HEARTBEAT_REQUEST_ID); - send(builder.build()); + trySend(builder.build()); } } @@ -173,9 +171,11 @@ protected void onResponse(StreamingCommitResponse response) { CommitStatus commitStatus = i < response.getStatusCount() ? response.getStatus(i) : CommitStatus.OK; if (pendingRequest == null) { - if (!hasReceivedShutdownSignal()) { - // Skip responses when the stream is shutdown since they are now invalid. - LOG.error("Got unknown commit request ID: {}", requestId); + synchronized (this) { + if (!isShutdown) { + // Missing responses are expected after shutdown() because it removes them. + LOG.error("Got unknown commit request ID: {}", requestId); + } } } else { try { @@ -185,13 +185,12 @@ protected void onResponse(StreamingCommitResponse response) { // other commits from being processed. Aggregate all the failures to throw after // processing the response if they exist. 
LOG.warn("Exception while processing commit response.", e); - failures.recordError(commitStatus, e); + failures.addError(commitStatus, e); } } } - if (failures.hasErrors()) { - throw failures; - } + + failures.throwIfNonEmpty(); } @Override @@ -239,15 +238,11 @@ private void issueSingleRequest(long id, PendingRequest pendingRequest) .setSerializedWorkItemCommit(pendingRequest.serializedCommit()); StreamingCommitWorkRequest chunk = requestBuilder.build(); synchronized (this) { - try { - if (!prepareForSend(id, pendingRequest)) { - pendingRequest.abort(); - return; - } - send(chunk); - } catch (StreamClosedException e) { - // Stream was broken, request will be retried when stream is reopened. + if (!prepareForSend(id, pendingRequest)) { + pendingRequest.abort(); + return; } + trySend(chunk); } } @@ -273,11 +268,7 @@ private void issueBatchedRequest(Map requests) requests.forEach((ignored, pendingRequest) -> pendingRequest.abort()); return; } - try { - send(request); - } catch (StreamClosedException e) { - // Stream was broken, request will be retried when stream is reopened. - } + trySend(request); } } @@ -308,10 +299,9 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) } StreamingCommitWorkRequest requestChunk = StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build(); - try { - send(requestChunk); - } catch (StreamClosedException e) { - // Stream was broken, request will be retried when stream is reopened. + + if (!trySend(requestChunk)) { + // The stream broke, don't try to send the rest of the chunks here. break; } } @@ -320,7 +310,7 @@ private void issueMultiChunkRequest(long id, PendingRequest pendingRequest) /** Returns true if prepare for send succeeded. 
*/ private synchronized boolean prepareForSend(long id, PendingRequest request) { - if (!hasReceivedShutdownSignal()) { + if (!isShutdown) { pending.put(id, request); return true; } @@ -329,7 +319,7 @@ private synchronized boolean prepareForSend(long id, PendingRequest request) { /** Returns true if prepare for send succeeded. */ private synchronized boolean prepareForSend(Map requests) { - if (!hasReceivedShutdownSignal()) { + if (!isShutdown) { pending.putAll(requests); return true; } @@ -382,15 +372,17 @@ private CommitCompletionException() { this.detailedErrors = EvictingQueue.create(MAX_PRINTABLE_ERRORS); } - private void recordError(CommitStatus commitStatus, Throwable error) { + private void addError(CommitStatus commitStatus, Throwable error) { errorCounter.compute( Pair.of(commitStatus, error.getClass()), (ignored, current) -> current == null ? 1 : current + 1); detailedErrors.add(error); } - private boolean hasErrors() { - return !errorCounter.isEmpty(); + private void throwIfNonEmpty() { + if (!errorCounter.isEmpty()) { + throw this; + } } @Override @@ -417,8 +409,7 @@ private Batcher() { @Override public boolean commitWorkItem( String computation, WorkItemCommitRequest commitRequest, Consumer onDone) { - if (!canAccept(commitRequest.getSerializedSize() + computation.length()) - || hasReceivedShutdownSignal()) { + if (!canAccept(commitRequest.getSerializedSize() + computation.length())) { return false; } @@ -431,11 +422,7 @@ public boolean commitWorkItem( @Override public void flush() { try { - if (!hasReceivedShutdownSignal()) { - flushInternal(queue); - } else { - queue.forEach((ignored, request) -> request.abort()); - } + flushInternal(queue); } catch (WindmillStreamShutdownException e) { queue.forEach((ignored, request) -> request.abort()); } finally { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 750d61ac2f29..86f635f07ac2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -33,7 +33,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; @@ -173,7 +172,7 @@ private static Watermarks createWatermarks( /** * @implNote Do not lock/synchronize here due to this running on grpc serial executor for message * which can deadlock since we send on the stream beneath the synchronization. {@link - * AbstractWindmillStream#send(Object)} is synchronized so the sends are already guarded. + * AbstractWindmillStream#trySend(Object)} is synchronized so the sends are already guarded. 
*/ private void maybeSendRequestExtension(GetWorkBudget extension) { if (extension.items() > 0 || extension.bytes() > 0) { @@ -189,8 +188,8 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { lastRequest.set(request); budgetTracker.recordBudgetRequested(extension); try { - send(request); - } catch (StreamClosedException | WindmillStreamShutdownException e) { + trySend(request); + } catch (WindmillStreamShutdownException e) { // Stream was closed. } }); @@ -198,23 +197,21 @@ private void maybeSendRequestExtension(GetWorkBudget extension) { } @Override - protected synchronized void onNewStream() - throws WindmillStreamShutdownException, StreamClosedException { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { workItemAssemblers.clear(); budgetTracker.reset(); GetWorkBudget initialGetWorkBudget = budgetTracker.computeBudgetExtension(); StreamingGetWorkRequest request = StreamingGetWorkRequest.newBuilder() .setRequest( - requestHeader - .toBuilder() + requestHeader.toBuilder() .setMaxItems(initialGetWorkBudget.items()) .setMaxBytes(initialGetWorkBudget.bytes()) .build()) .build(); lastRequest.set(request); budgetTracker.recordBudgetRequested(initialGetWorkBudget); - send(request); + trySend(request); } @Override @@ -232,8 +229,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - public void sendHealthCheck() throws WindmillStreamShutdownException, StreamClosedException { - send(HEALTH_CHECK_REQUEST); + public void sendHealthCheck() throws WindmillStreamShutdownException { + trySend(HEALTH_CHECK_REQUEST); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index ee403b100fa5..d354b2292ce2 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -48,7 +48,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; -import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcGetDataStreamRequests.QueuedBatch; @@ -70,7 +69,9 @@ final class GrpcGetDataStream private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = StreamingGetDataRequest.newBuilder().build(); - /** @implNote {@link QueuedBatch} objects in the queue are is guarded by {@code this} */ + /** + * @implNote {@link QueuedBatch} objects in the queue are is guarded by {@code this} + */ private final Deque batches; private final Map pending; @@ -157,18 +158,13 @@ private static WindmillStreamShutdownException shutdownExceptionFor(QueuedReques private void sendIgnoringClosed(StreamingGetDataRequest getDataRequest) throws WindmillStreamShutdownException { - try { - send(getDataRequest); - } catch (StreamClosedException e) { - // Stream was closed on send, will be retried on stream restart. 
- } + trySend(getDataRequest); } @Override - protected synchronized void onNewStream() - throws StreamClosedException, WindmillStreamShutdownException { - send(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); - if (clientClosed && !hasReceivedShutdownSignal()) { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { + trySend(StreamingGetDataRequest.newBuilder().setHeader(jobHeader).build()); + if (clientClosed) { // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. @@ -228,8 +224,10 @@ public GlobalData requestGlobalData(GlobalDataRequest request) @Override public void refreshActiveWork(Map> heartbeats) throws WindmillStreamShutdownException { - if (hasReceivedShutdownSignal()) { - throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); + synchronized (this) { + if (isShutdown) { + throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); + } } StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); @@ -299,9 +297,9 @@ public void onHeartbeatResponse(List resp } @Override - public void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { + public void sendHealthCheck() throws WindmillStreamShutdownException { if (hasPendingRequests()) { - send(HEALTH_CHECK_REQUEST); + trySend(HEALTH_CHECK_REQUEST); } } @@ -345,7 +343,7 @@ public void appendSpecificHtml(PrintWriter writer) { private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) throws WindmillStreamShutdownException { - while (!hasReceivedShutdownSignal()) { + while (!isShutdownLocked()) { request.resetResponseStream(); try { queueRequestAndWait(request); @@ -366,13 +364,12 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn { try { - send(extension); - } catch 
(StreamClosedException | WindmillStreamShutdownException e) { + trySend(extension); + } catch (WindmillStreamShutdownException e) { // Stream was closed. } }); } @Override - protected synchronized void onNewStream() - throws StreamClosedException, WindmillStreamShutdownException { + protected synchronized void onNewStream() throws WindmillStreamShutdownException { workItemAssemblers.clear(); inflightMessages.set(request.getMaxItems()); inflightBytes.set(request.getMaxBytes()); - send(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); + trySend(StreamingGetWorkRequest.newBuilder().setRequest(request).build()); } @Override @@ -154,8 +152,8 @@ public void appendSpecificHtml(PrintWriter writer) { } @Override - public void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { - send( + public void sendHealthCheck() throws WindmillStreamShutdownException { + trySend( StreamingGetWorkRequest.newBuilder() .setRequestExtension( StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build()) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index a076b4f58258..4ce2f651f0b7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -28,7 +28,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; import org.apache.beam.runners.dataflow.worker.windmill.client.AbstractWindmillStream; 
-import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; @@ -134,8 +133,8 @@ private Optional extractWindmillEndpointsFrom( } @Override - protected void onNewStream() throws StreamClosedException, WindmillStreamShutdownException { - send(workerMetadataRequest); + protected void onNewStream() throws WindmillStreamShutdownException { + trySend(workerMetadataRequest); } @Override @@ -152,8 +151,8 @@ protected void startThrottleTimer() { } @Override - protected void sendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { - send(HEALTH_CHECK_REQUEST); + protected void sendHealthCheck() throws WindmillStreamShutdownException { + trySend(HEALTH_CHECK_REQUEST); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index e072ece0fac2..98cede8fbd83 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -51,6 +51,9 @@ public final class DirectStreamObserver implements TerminatingStreamObserver< private final long deadlineSeconds; private final int messagesBetweenIsReadyChecks; + @GuardedBy("lock") + private boolean isClosed = false; + @GuardedBy("lock") private int messagesSinceReady = 
0; @@ -117,7 +120,7 @@ public void onNext(T value) throws StreamObserverCancelledException { isReadyNotifier.awaitAdvanceInterruptibly(awaitPhase, waitSeconds, TimeUnit.SECONDS); // If nextPhase is a value less than 0, the phaser has been terminated. if (nextPhase < 0) { - return; + throw new StreamObserverCancelledException("StreamObserver was terminated."); } synchronized (lock) { @@ -155,7 +158,9 @@ public void onNext(T value) throws StreamObserverCancelledException { @Override public void onError(Throwable t) { + isReadyNotifier.forceTermination(); synchronized (lock) { + isClosed = true; outboundObserver.onError(t); } } @@ -163,6 +168,7 @@ public void onError(Throwable t) { @Override public void onCompleted() { synchronized (lock) { + isClosed = true; outboundObserver.onCompleted(); } } @@ -171,11 +177,10 @@ public void onCompleted() { public void terminate(Throwable terminationException) { // Free the blocked threads in onNext(). isReadyNotifier.forceTermination(); - try { - onError(terminationException); - } catch (RuntimeException e) { - // If onError or onComplete was previously called, this will throw. 
- LOG.warn("StreamObserver was already terminated."); + synchronized (lock) { + if (!isClosed) { + onError(terminationException); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java index 4ea209f31b1d..70fd3497a37f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/StreamObserverCancelledException.java @@ -21,11 +21,15 @@ @Internal public final class StreamObserverCancelledException extends RuntimeException { - public StreamObserverCancelledException(Throwable cause) { + StreamObserverCancelledException(Throwable cause) { super(cause); } - public StreamObserverCancelledException(String message, Throwable cause) { + StreamObserverCancelledException(String message, Throwable cause) { super(message, cause); } + + StreamObserverCancelledException(String message) { + super(message); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java new file mode 100644 index 000000000000..c9700c7a8ac3 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.client; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import java.io.PrintWriter; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; +import org.apache.beam.sdk.util.FluentBackoff; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.LoggerFactory; + +@RunWith(JUnit4.class) +public class AbstractWindmillStreamTest { + private static final long DEADLINE_SECONDS = 10; + private final Set> streamRegistry = ConcurrentHashMap.newKeySet(); + private final StreamObserverFactory streamObserverFactory = + 
StreamObserverFactory.direct(DEADLINE_SECONDS, 1); + + @Before + public void setUp() { + streamRegistry.clear(); + } + + private TestStream newStream( + Function, StreamObserver> clientFactory) { + return new TestStream(clientFactory, streamRegistry, streamObserverFactory); + } + + @Test + public void testShutdown_notBlockedBySend() throws InterruptedException, ExecutionException { + CountDownLatch sendBlocker = new CountDownLatch(1); + Function, StreamObserver> clientFactory = + ignored -> + new CallStreamObserver() { + @Override + public void onNext(Integer integer) { + try { + sendBlocker.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setOnReadyHandler(Runnable runnable) {} + + @Override + public void disableAutoInboundFlowControl() {} + + @Override + public void request(int i) {} + + @Override + public void setMessageCompression(boolean b) {} + }; + + TestStream testStream = newStream(clientFactory); + testStream.start(); + ExecutorService sendExecutor = Executors.newSingleThreadExecutor(); + Future sendFuture = + sendExecutor.submit( + () -> + assertThrows(WindmillStreamShutdownException.class, () -> testStream.testSend(1))); + testStream.shutdown(); + sendBlocker.countDown(); + assertThat(sendFuture.get()).isInstanceOf(WindmillStreamShutdownException.class); + } + + private static class TestStream extends AbstractWindmillStream { + private final AtomicInteger numStarts = new AtomicInteger(); + + private TestStream( + Function, StreamObserver> clientFactory, + Set> streamRegistry, + StreamObserverFactory streamObserverFactory) { + super( + LoggerFactory.getLogger(AbstractWindmillStreamTest.class), + "Test", + clientFactory, + FluentBackoff.DEFAULT.backoff(), + streamObserverFactory, + streamRegistry, + 1, + "Test"); + } + + 
@Override + protected void onResponse(Integer response) {} + + @Override + protected void onNewStream() { + numStarts.incrementAndGet(); + } + + @Override + protected boolean hasPendingRequests() { + return false; + } + + @Override + protected void startThrottleTimer() {} + + public void testSend(Integer i) throws ResettableThrowingStreamObserver.StreamClosedException, WindmillStreamShutdownException { + trySend(i); + } + + @Override + protected void sendHealthCheck() {} + + @Override + protected void appendSpecificHtml(PrintWriter writer) {} + + @Override + protected void shutdownInternal() {} + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java index 2f1a45abd7a5..5cd612d14bcf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java @@ -102,7 +102,7 @@ public void testOnCompleted_afterPoisonedThrows() { @Test public void testReset_usesNewDelegate() - throws WindmillStreamShutdownException, StreamClosedException { + throws WindmillStreamShutdownException, ResettableThrowingStreamObserver.StreamClosedException { List> delegates = new ArrayList<>(); ResettableThrowingStreamObserver observer = newStreamObserver( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 
df6946ea8763..b6454d319f9f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.inOrder; @@ -146,7 +145,7 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { } @Test - public void testCommitWorkItem_afterShutdownFalse() { + public void testCommitWorkItem_afterShutdown() { int numCommits = 5; CommitWorkStreamTestStub testStub = @@ -160,14 +159,15 @@ public void testCommitWorkItem_afterShutdownFalse() { } commitWorkStream.shutdown(); + Set commitStatuses = new HashSet<>(); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { - Set commitStatuses = new HashSet<>(); - assertFalse( + assertTrue( batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add)); - assertThat(commitStatuses).isEmpty(); } } + + assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java index 1239c90dc0c6..d74735ee3052 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStreamTest.java @@ -37,7 +37,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataResponse; import org.apache.beam.runners.dataflow.worker.windmill.WindmillEndpoints; -import org.apache.beam.runners.dataflow.worker.windmill.client.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; @@ -257,7 +256,7 @@ public void testGetWorkerMetadata_correctlyAddsAndRemovesStreamFromRegistry() { } @Test - public void testSendHealthCheck() throws StreamClosedException, WindmillStreamShutdownException { + public void testSendHealthCheck() throws WindmillStreamShutdownException { TestGetWorkMetadataRequestObserver requestObserver = Mockito.spy(new TestGetWorkMetadataRequestObserver()); GetWorkerMetadataTestStub testStub = new GetWorkerMetadataTestStub(requestObserver); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 23ad89345bca..7a5056b7e157 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -470,8 +470,9 @@ private void flushResponse() { "Sending batched response of {} ids", responseBuilder.getRequestIdCount()); try { responseObserver.onNext(responseBuilder.build()); - } catch (IllegalStateException e) { + } catch (Exception e) { // Stream is already closed. + System.out.println("trieu: " + e); } responseBuilder.clear(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java new file mode 100644 index 000000000000..ee0a9280610a --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.beam.sdk.fn.stream.AdvancingPhaser; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.VerifyException; +import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class DirectStreamObserverTest { + + @Test + public void testTerminate_waitingForReady() throws ExecutionException, InterruptedException { + CountDownLatch sendBlocker = new CountDownLatch(1); + TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch blockLatch = new CountDownLatch(1); + Future onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + blockLatch.countDown(); + try { + // We will check isReady on the next message, will block here. 
+ streamObserver.onNext(1); + } catch (Throwable e) { + return e; + } + + return new VerifyException(); + }); + RuntimeException terminationException = new RuntimeException("terminated"); + + assertTrue(blockLatch.await(5, TimeUnit.SECONDS)); + streamObserver.terminate(terminationException); + assertThat(onNextFuture.get()).isInstanceOf(StreamObserverCancelledException.class); + verify(delegate).onError(same(terminationException)); + // onNext should only have been called once. + verify(delegate, times(1)).onNext(any()); + } + + @Test + public void testOnNext_interruption() throws ExecutionException, InterruptedException { + CountDownLatch sendBlocker = new CountDownLatch(1); + TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch streamObserverExitLatch = new CountDownLatch(1); + Future onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + // We will check isReady on the next message, will block here. + StreamObserverCancelledException e = + assertThrows( + StreamObserverCancelledException.class, () -> streamObserver.onNext(1)); + streamObserverExitLatch.countDown(); + return e; + }); + + // Assert that onNextFuture is blocked. + assertFalse(onNextFuture.isDone()); + assertThat(streamObserverExitLatch.getCount()).isEqualTo(1); + + onNextExecutor.shutdownNow(); + assertTrue(streamObserverExitLatch.await(5, TimeUnit.SECONDS)); + assertThat(onNextFuture.get()).hasCauseThat().isInstanceOf(InterruptedException.class); + + // onNext should only have been called once. 
+ verify(delegate, times(1)).onNext(any()); + } + + @Test + public void testOnNext_timeOut() throws ExecutionException, InterruptedException { + CountDownLatch sendBlocker = new CountDownLatch(1); + TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, 1, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch streamObserverExitLatch = new CountDownLatch(1); + Future onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + // We will check isReady on the next message, will block here. + StreamObserverCancelledException e = + assertThrows( + StreamObserverCancelledException.class, () -> streamObserver.onNext(1)); + streamObserverExitLatch.countDown(); + return e; + }); + + // Assert that onNextFuture is blocked. + assertFalse(onNextFuture.isDone()); + assertThat(streamObserverExitLatch.getCount()).isEqualTo(1); + + assertTrue(streamObserverExitLatch.await(10, TimeUnit.SECONDS)); + assertThat(onNextFuture.get()).hasCauseThat().isInstanceOf(TimeoutException.class); + + // onNext should only have been called once. 
+ verify(delegate, times(1)).onNext(any()); + } + + private static class TestStreamObserver extends CallStreamObserver { + private final CountDownLatch sendBlocker; + private final int blockAfter; + private final AtomicInteger seen = new AtomicInteger(0); + private volatile boolean isReady = false; + + private TestStreamObserver(CountDownLatch sendBlocker, int blockAfter) { + this.blockAfter = blockAfter; + this.sendBlocker = sendBlocker; + } + + @Override + public void onNext(Integer integer) { + try { + if (seen.incrementAndGet() == blockAfter) { + sendBlocker.await(); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + + @Override + public boolean isReady() { + return isReady; + } + + private void setIsReady(boolean isReadyOverride) { + isReady = isReadyOverride; + } + + @Override + public void setOnReadyHandler(Runnable runnable) {} + + @Override + public void disableAutoInboundFlowControl() {} + + @Override + public void request(int i) {} + + @Override + public void setMessageCompression(boolean b) {} + } +} From 4a9a86311a112e4857c0ade7736df16410bcff8c Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Mon, 11 Nov 2024 23:59:03 -0800 Subject: [PATCH 20/23] address PR comments part 1 --- .../client/AbstractWindmillStream.java | 59 ++++++++++--------- .../ResettableThrowingStreamObserver.java | 11 ++-- .../windmill/client/StreamDebugMetrics.java | 13 ++-- .../WindmillStreamShutdownException.java | 27 ++++----- .../client/grpc/GrpcDirectGetWorkStream.java | 3 +- .../client/grpc/GrpcGetDataStream.java | 4 +- .../grpc/observers/DirectStreamObserver.java | 1 + .../client/AbstractWindmillStreamTest.java | 4 +- .../ResettableThrowingStreamObserverTest.java | 3 +- .../client/StreamDebugMetricsTest.java | 5 +- .../client/grpc/GrpcWindmillServerTest.java | 8 ++- 11 files changed, 76 insertions(+), 62 deletions(-) diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index c8a00d7caea2..98566e0a9d39 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -36,6 +36,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.Status; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.joda.time.DateTime; import org.joda.time.Instant; import org.slf4j.Logger; @@ -71,6 +72,7 @@ public abstract class AbstractWindmillStream implements Win // shutdown. 
private static final Status OK_STATUS = Status.fromCode(Status.Code.OK); private static final String NEVER_RECEIVED_RESPONSE_LOG_STRING = "never received response"; + private static final String NOT_SHUTDOWN = "not shutdown"; protected final Sleeper sleeper; private final Logger logger; @@ -262,7 +264,7 @@ public final void appendSummaryHtml(PrintWriter writer) { ", %d restarts, last restart reason [ %s ] at [%s], %d errors", metrics.restartCount(), metrics.lastRestartReason(), - metrics.lastRestartTime(), + metrics.lastRestartTime().orElse(null), metrics.errorCount())); if (summaryMetrics.isClientClosed()) { @@ -275,13 +277,12 @@ public final void appendSummaryHtml(PrintWriter writer) { writer.format( ", current stream is %dms old, last send %dms, last response %dms, closed: %s, " - + "isShutdown: %s, shutdown time: %s", + + "shutdown time: %s", summaryMetrics.streamAge(), summaryMetrics.timeSinceLastSend(), summaryMetrics.timeSinceLastResponse(), requestObserver.isClosed(), - summaryMetrics.shutdownTime().isPresent(), - summaryMetrics.shutdownTime().orElse(null)); + summaryMetrics.shutdownTime().map(DateTime::toString).orElse(NOT_SHUTDOWN)); } /** @@ -297,8 +298,10 @@ public final synchronized void halfClose() { clientClosed = true; try { requestObserver.onCompleted(); - } catch (StreamClosedException | WindmillStreamShutdownException e) { - logger.warn("Stream was previously closed or shutdown."); + } catch (StreamClosedException e) { + logger.warn("Stream was previously closed."); + } catch (WindmillStreamShutdownException e) { + logger.warn("Stream was previously shutdown."); } } @@ -317,10 +320,13 @@ public String backendWorkerToken() { return backendWorkerToken; } + @SuppressWarnings("GuardedBy") @Override public final void shutdown() { - // Don't lock on "this" before poisoning the request observer as allow IO to block shutdown. + // Don't lock on "this" before poisoning the request observer since otherwise the observer may + // be blocking in send(). 
requestObserver.poison(); + isShutdown = true; synchronized (this) { if (!isShutdown) { isShutdown = true; @@ -332,6 +338,18 @@ public final void shutdown() { protected abstract void shutdownInternal(); + /** Returns true if the stream was torn down and should not be restarted internally. */ + private synchronized boolean maybeTeardownStream() { + if (isShutdown || (clientClosed && !hasPendingRequests())) { + streamRegistry.remove(AbstractWindmillStream.this); + finishLatch.countDown(); + executor.shutdownNow(); + return true; + } + + return false; + } + private class ResponseObserver implements StreamObserver { @Override @@ -351,7 +369,13 @@ public void onError(Throwable t) { return; } - recordStreamStatus(Status.fromThrowable(t)); + Status errorStatus = Status.fromThrowable(t); + recordStreamStatus(errorStatus); + + // If the stream was stopped due to a resource exhausted error then we are throttled. + if (errorStatus.getCode() == Status.Code.RESOURCE_EXHAUSTED) { + startThrottleTimer(); + } try { long sleep = backoff.nextBackOffMillis(); @@ -411,25 +435,6 @@ private void recordStreamStatus(Status status) { .responseDebugString(nowMillis) .orElse(NEVER_RECEIVED_RESPONSE_LOG_STRING)); } - - // If the stream was stopped due to a resource exhausted error then we are throttled. - if (status.getCode() == Status.Code.RESOURCE_EXHAUSTED) { - startThrottleTimer(); - } - } - } - - /** Returns true if the stream was torn down and should not be restarted internally. 
*/ - private boolean maybeTeardownStream() { - synchronized (AbstractWindmillStream.this) { - if (isShutdown || (clientClosed && !hasPendingRequests())) { - streamRegistry.remove(AbstractWindmillStream.this); - finishLatch.countDown(); - executor.shutdownNow(); - return true; - } - - return false; } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java index 783f09687a37..17f65f56c984 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -55,7 +55,7 @@ final class ResettableThrowingStreamObserver { * StreamObserver. */ @GuardedBy("this") - private boolean isCurrentStreamClosed = false; + private boolean isCurrentStreamClosed = true; ResettableThrowingStreamObserver( Supplier> streamObserverFactory, Logger logger) { @@ -72,12 +72,10 @@ private synchronized StreamObserver delegate() if (isCurrentStreamClosed) { throw new StreamClosedException( - "Current stream is closed, requires reset for future stream operations."); + "Current stream is closed, requires reset() for future stream operations."); } - return Preconditions.checkNotNull( - delegateStreamObserver, - "requestObserver cannot be null. Missing a call to startStream() to initialize."); + return Preconditions.checkNotNull(delegateStreamObserver, "requestObserver cannot be null."); } /** Creates a new delegate to use for future {@link StreamObserver} methods. 
*/ @@ -131,9 +129,10 @@ public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownExce } } - public void onError(Throwable throwable) + public synchronized void onError(Throwable throwable) throws StreamClosedException, WindmillStreamShutdownException { delegate().onError(throwable); + isCurrentStreamClosed = true; } public synchronized void onCompleted() diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java index fb3bf0323c47..4cda12a85ea2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetrics.java @@ -45,7 +45,7 @@ final class StreamDebugMetrics { private String lastRestartReason = ""; @GuardedBy("this") - private DateTime lastRestartTime = null; + private @Nullable DateTime lastRestartTime = null; @GuardedBy("this") private long lastResponseTimeMs = 0; @@ -57,7 +57,7 @@ final class StreamDebugMetrics { private long startTimeMs = 0; @GuardedBy("this") - private DateTime shutdownTime = null; + private @Nullable DateTime shutdownTime = null; @GuardedBy("this") private boolean clientClosed = false; @@ -194,16 +194,19 @@ private static Snapshot create( @AutoValue abstract static class RestartMetrics { private static RestartMetrics create( - int restartCount, String restartReason, DateTime lastRestartTime, int errorCount) { + int restartCount, + String restartReason, + @Nullable DateTime lastRestartTime, + int errorCount) { return new AutoValue_StreamDebugMetrics_RestartMetrics( - restartCount, restartReason, lastRestartTime, errorCount); + restartCount, restartReason, 
Optional.ofNullable(lastRestartTime), errorCount); } abstract int restartCount(); abstract String lastRestartReason(); - abstract DateTime lastRestartTime(); + abstract Optional lastRestartTime(); abstract int errorCount(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java index 8e401e4d2921..566c15c58036 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -1,21 +1,20 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ - package org.apache.beam.runners.dataflow.worker.windmill.client; /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 86f635f07ac2..f848489f51c6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -204,7 +204,8 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException StreamingGetWorkRequest request = StreamingGetWorkRequest.newBuilder() .setRequest( - requestHeader.toBuilder() + requestHeader + .toBuilder() .setMaxItems(initialGetWorkBudget.items()) .setMaxBytes(initialGetWorkBudget.bytes()) .build()) diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index d354b2292ce2..26a87e1c805b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -69,9 +69,7 @@ final class GrpcGetDataStream private static final StreamingGetDataRequest HEALTH_CHECK_REQUEST = StreamingGetDataRequest.newBuilder().build(); - /** - * @implNote {@link QueuedBatch} objects in the queue are is guarded by {@code this} - */ + /** @implNote {@link QueuedBatch} objects in the queue are guarded by {@code this} */ private final Deque batches; private final Map pending; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 98cede8fbd83..582acce17aae 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -146,6 +146,7 @@ public void onNext(T value) throws StreamObserverCancelledException { "Output channel stalled for {}s, outbound thread {}.", totalSecondsWaited, Thread.currentThread().getName()); + Thread.dumpStack(); } waitSeconds = waitSeconds * 2; diff --git
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java index c9700c7a8ac3..aada071416de 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -142,7 +142,9 @@ protected boolean hasPendingRequests() { @Override protected void startThrottleTimer() {} - public void testSend(Integer i) throws ResettableThrowingStreamObserver.StreamClosedException, WindmillStreamShutdownException { + public void testSend(Integer i) + throws ResettableThrowingStreamObserver.StreamClosedException, + WindmillStreamShutdownException { trySend(i); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java index 5cd612d14bcf..790c155d94d6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserverTest.java @@ -102,7 +102,8 @@ public void testOnCompleted_afterPoisonedThrows() { @Test public void testReset_usesNewDelegate() - throws WindmillStreamShutdownException, ResettableThrowingStreamObserver.StreamClosedException { + throws WindmillStreamShutdownException, + 
ResettableThrowingStreamObserver.StreamClosedException { List> delegates = new ArrayList<>(); ResettableThrowingStreamObserver observer = newStreamObserver( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java index e079be9b5663..564b2e664505 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/StreamDebugMetricsTest.java @@ -78,8 +78,9 @@ public void testSummaryMetrics_withRestarts() { assertThat(restartMetrics.lastRestartReason()).isEqualTo(restartReason); assertThat(restartMetrics.restartCount()).isEqualTo(1); assertThat(restartMetrics.errorCount()).isEqualTo(1); - assertThat(restartMetrics.lastRestartTime()).isLessThan(DateTime.now()); - assertThat(restartMetrics.lastRestartTime().toInstant()).isGreaterThan(Instant.EPOCH); + assertTrue(restartMetrics.lastRestartTime().isPresent()); + assertThat(restartMetrics.lastRestartTime().get()).isLessThan(DateTime.now()); + assertThat(restartMetrics.lastRestartTime().get().toInstant()).isGreaterThan(Instant.EPOCH); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 7a5056b7e157..bc978a72a8d3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -22,6 +22,7 @@ import java.io.InputStream; import java.io.SequenceInputStream; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -472,7 +473,8 @@ private void flushResponse() { responseObserver.onNext(responseBuilder.build()); } catch (Exception e) { // Stream is already closed. - System.out.println("trieu: " + e); + LOG.warn("trieu: ", e); + LOG.warn(Arrays.toString(e.getStackTrace())); } responseBuilder.clear(); } @@ -512,7 +514,9 @@ private void flushResponse() { done.countDown(); }); } - done.await(); + while (!done.await(5, TimeUnit.SECONDS)) { + LOG.info("trieu: {}", done.getCount()); + } stream.halfClose(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); executor.shutdown(); From 95a19f45dadc05d17caa1dd4bbd7719f3847466a Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Wed, 13 Nov 2024 02:25:44 -0800 Subject: [PATCH 21/23] address PR comments --- .../harness/SingleSourceWorkerHarness.java | 4 +- .../worker/windmill/WindmillServerStub.java | 12 +++-- .../client/AbstractWindmillStream.java | 1 - .../ResettableThrowingStreamObserver.java | 2 +- .../client/grpc/GrpcCommitWorkStream.java | 38 +++++++------- .../client/grpc/GrpcDirectGetWorkStream.java | 4 +- .../client/grpc/GrpcGetDataStream.java | 34 ++++--------- .../client/grpc/GrpcWindmillServer.java | 14 +++--- .../grpc/observers/DirectStreamObserver.java | 31 +++++++----- .../observers/TerminatingStreamObserver.java | 9 +++- .../client/AbstractWindmillStreamTest.java | 6 +++ .../client/grpc/GrpcCommitWorkStreamTest.java | 25 ++++++++-- .../grpc/GrpcGetDataStreamRequestsTest.java | 5 +- .../client/grpc/GrpcGetDataStreamTest.java | 49 +++++++++++++++++++ .../client/grpc/GrpcWindmillServerTest.java | 5 +- .../observers/DirectStreamObserverTest.java | 8 +-- 16
files changed, 166 insertions(+), 81 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java index bc93e6d89c41..22fba91e170a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/SingleSourceWorkerHarness.java @@ -33,7 +33,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.RpcException; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillRpcException; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; @@ -199,7 +199,7 @@ private void applianceDispatchLoop(Supplier getWorkFn) if (workResponse.getWorkCount() > 0) { break; } - } catch (RpcException e) { + } catch (WindmillRpcException e) { LOG.warn("GetWork failed, retrying:", e); } sleepUninterruptibly(backoff, TimeUnit.MILLISECONDS); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java index cd753cb8ec91..2ae97087fec7 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillServerStub.java @@ -30,10 +30,16 @@ public abstract class WindmillServerStub @Override public void appendSummaryHtml(PrintWriter writer) {} - /** Generic Exception type for implementors to use to represent errors while making RPCs. */ - public static final class RpcException extends RuntimeException { - public RpcException(Throwable cause) { + /** + * Generic Exception type for implementors to use to represent errors while making Windmill RPCs. + */ + public static final class WindmillRpcException extends RuntimeException { + public WindmillRpcException(Throwable cause) { super(cause); } + + public WindmillRpcException(String message, Throwable cause) { + super(message, cause); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 98566e0a9d39..df34797b647a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -326,7 +326,6 @@ public final void shutdown() { // Don't lock on "this" before poisoning the request observer since otherwise the observer may // be blocking in send(). 
requestObserver.poison(); - isShutdown = true; synchronized (this) { if (!isShutdown) { isShutdown = true; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java index 17f65f56c984..5eb691cbf55a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -121,7 +121,7 @@ public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownExce try { delegate.onError(e); - } catch (RuntimeException ignored) { + } catch (IllegalStateException ignored) { // If the delegate above was already terminated via onError or onComplete from another // thread. 
logger.warn("StreamObserver was previously cancelled.", e); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 6951a6cdf772..2dd069b9c443 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -30,6 +30,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import javax.annotation.Nullable; import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; @@ -161,15 +162,19 @@ public void sendHealthCheck() throws WindmillStreamShutdownException { protected void onResponse(StreamingCommitResponse response) { commitWorkThrottleTimer.stop(); - CommitCompletionException failures = new CommitCompletionException(); + CommitCompletionFailureHandler failureHandler = new CommitCompletionFailureHandler(); for (int i = 0; i < response.getRequestIdCount(); ++i) { long requestId = response.getRequestId(i); if (requestId == HEARTBEAT_REQUEST_ID) { continue; } - PendingRequest pendingRequest = pending.remove(requestId); + + // From windmill.proto: Indices must line up with the request_id field, but trailing OKs may + // be omitted. CommitStatus commitStatus = i < response.getStatusCount() ? 
response.getStatus(i) : CommitStatus.OK; + + @Nullable PendingRequest pendingRequest = pending.remove(requestId); if (pendingRequest == null) { synchronized (this) { if (!isShutdown) { @@ -185,12 +190,12 @@ protected void onResponse(StreamingCommitResponse response) { // other commits from being processed. Aggregate all the failures to throw after // processing the response if they exist. LOG.warn("Exception while processing commit response.", e); - failures.addError(commitStatus, e); + failureHandler.addError(commitStatus, e); } } } - failures.throwIfNonEmpty(); + failureHandler.throwIfNonEmpty(); } @Override @@ -362,12 +367,17 @@ private void abort() { } private static class CommitCompletionException extends RuntimeException { + private CommitCompletionException(String message) { + super(message); + } + } + + private static class CommitCompletionFailureHandler { private static final int MAX_PRINTABLE_ERRORS = 10; private final Map>, Integer> errorCounter; private final EvictingQueue detailedErrors; - private CommitCompletionException() { - super("Exception while processing commit response."); + private CommitCompletionFailureHandler() { this.errorCounter = new HashMap<>(); this.detailedErrors = EvictingQueue.create(MAX_PRINTABLE_ERRORS); } @@ -381,19 +391,13 @@ private void addError(CommitStatus commitStatus, Throwable error) { private void throwIfNonEmpty() { if (!errorCounter.isEmpty()) { - throw this; + String errorMessage = + String.format( + "Exception while processing commit response. 
ErrorCounter: %s; Details: %s", + errorCounter, detailedErrors); + throw new CommitCompletionException(errorMessage); } } - - @Override - public final String getMessage() { - return "CommitCompletionException{" - + "errorCounter=" - + errorCounter - + ", detailedErrors=" - + detailedErrors - + '}'; - } } private class Batcher implements CommitWorkStream.RequestBatcher { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index f848489f51c6..27f457900e6c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -235,9 +235,7 @@ public void sendHealthCheck() throws WindmillStreamShutdownException { } @Override - protected void shutdownInternal() { - workItemAssemblers.clear(); - } + protected void shutdownInternal() {} @Override protected void onResponse(StreamingGetWorkResponseChunk chunk) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 26a87e1c805b..19eb6dd4915a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -18,6 +18,7 @@ package 
org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify.verify; import java.io.IOException; import java.io.InputStream; @@ -56,7 +57,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Verify; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -166,7 +166,7 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException // We rely on close only occurring after all methods on the stream have returned. // Since the requestKeyedData and requestGlobalData methods are blocking this // means there should be no pending requests. 
- verify(!hasPendingRequests(), "Pending requests not expected on stream restart."); + verify(!hasPendingRequests(), "Pending requests not expected if we've half-closed."); } else { for (AppendableInputStream responseStream : pending.values()) { responseStream.cancel(); @@ -188,7 +188,9 @@ protected void onResponse(StreamingGetDataResponse chunk) { for (int i = 0; i < chunk.getRequestIdCount(); ++i) { AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - verify(responseStream != null, "No pending response stream"); + synchronized (this) { + verify(responseStream != null || isShutdown, "No pending response stream"); + } responseStream.append(chunk.getSerializedResponse(i).newInput()); if (chunk.getRemainingBytesForResponse() == 0) { responseStream.complete(); @@ -222,12 +224,6 @@ public GlobalData requestGlobalData(GlobalDataRequest request) @Override public void refreshActiveWork(Map> heartbeats) throws WindmillStreamShutdownException { - synchronized (this) { - if (isShutdown) { - throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); - } - } - StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); if (sendKeyedGetDataRequests) { long builderBytes = 0; @@ -302,7 +298,7 @@ public void sendHealthCheck() throws WindmillStreamShutdownException { } @Override - protected void shutdownInternal() { + protected synchronized void shutdownInternal() { // Stream has been explicitly closed. Drain pending input streams and request batches. // Future calls to send RPCs will fail. 
pending.values().forEach(AppendableInputStream::cancel); @@ -341,13 +337,13 @@ public void appendSpecificHtml(PrintWriter writer) { private ResponseT issueRequest(QueuedRequest request, ParseFn parseFn) throws WindmillStreamShutdownException { - while (!isShutdownLocked()) { + while (true) { request.resetResponseStream(); try { queueRequestAndWait(request); return parseFn.parse(request.getResponseStream()); } catch (AppendableInputStream.InvalidInputStreamStateException | CancellationException e) { - handleShutdown(request, e); + throwIfShutdown(request, e); if (!(e instanceof CancellationException)) { throw e; } @@ -355,17 +351,15 @@ private ResponseT issueRequest(QueuedRequest request, ParseFn { ResponseT parse(InputStream input) throws IOException; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java index 770cd616a5ac..f35b9b23d091 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServer.java @@ -290,13 +290,13 @@ private ResponseT callWithBackoff(Supplier function) { e.getStatus()); } if (!BackOffUtils.next(Sleeper.DEFAULT, backoff)) { - throw new RpcException(e); + throw new WindmillRpcException(e); } } catch (IOException | InterruptedException i) { if (i instanceof InterruptedException) { Thread.currentThread().interrupt(); } - RpcException rpcException = new RpcException(e); + WindmillRpcException rpcException = new WindmillRpcException(e); rpcException.addSuppressed(i); throw rpcException; } @@ -310,7 +310,7 @@ public GetWorkResponse getWork(GetWorkRequest request) { return 
callWithBackoff(() -> syncApplianceStub.getWork(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("GetWork")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("GetWork")); } @Override @@ -319,7 +319,7 @@ public GetDataResponse getData(GetDataRequest request) { return callWithBackoff(() -> syncApplianceStub.getData(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("GetData")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("GetData")); } @Override @@ -327,7 +327,7 @@ public CommitWorkResponse commitWork(CommitWorkRequest request) { if (syncApplianceStub != null) { return callWithBackoff(() -> syncApplianceStub.commitWork(request)); } - throw new RpcException(unsupportedUnaryRequestInStreamingEngineException("CommitWork")); + throw new WindmillRpcException(unsupportedUnaryRequestInStreamingEngineException("CommitWork")); } /** @@ -382,7 +382,7 @@ public GetConfigResponse getConfig(GetConfigRequest request) { return callWithBackoff(() -> syncApplianceStub.getConfig(request)); } - throw new RpcException( + throw new WindmillRpcException( new UnsupportedOperationException("GetConfig not supported in Streaming Engine.")); } @@ -392,7 +392,7 @@ public ReportStatsResponse reportStats(ReportStatsRequest request) { return callWithBackoff(() -> syncApplianceStub.reportStats(request)); } - throw new RpcException( + throw new WindmillRpcException( new UnsupportedOperationException("ReportStats not supported in Streaming Engine.")); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 582acce17aae..3a289e4dd48b 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -22,8 +22,10 @@ import java.util.concurrent.TimeoutException; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub.WindmillRpcException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -83,16 +85,14 @@ public void onNext(T value) throws StreamObserverCancelledException { // Phaser is terminated so don't use the outboundObserver. Since onError and onCompleted // are synchronized after terminating the phaser if we observe that the phaser is not // terminated the onNext calls below are guaranteed to not be called on a closed observer. - if (currentPhase < 0) return; + if (currentPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } // If we awaited previously and timed out, wait for the same phase. Otherwise we're // careful to observe the phase before observing isReady. if (awaitPhase < 0) { - awaitPhase = isReadyNotifier.getPhase(); - // If getPhase() returns a value less than 0, the phaser has been terminated. - if (awaitPhase < 0) { - return; - } + awaitPhase = currentPhase; } // We only check isReady periodically to effectively allow for increasing the outbound @@ -128,7 +128,9 @@ public void onNext(T value) throws StreamObserverCancelledException { // Phaser is terminated so don't use the outboundObserver. 
Since onError and onCompleted // are synchronized after terminating the phaser if we observe that the phaser is not // terminated the onNext calls below are guaranteed to not be called on a closed observer. - if (currentPhase < 0) return; + if (currentPhase < 0) { + throw new StreamObserverCancelledException("StreamObserver was terminated."); + } messagesSinceReady = 0; outboundObserver.onNext(value); return; @@ -138,7 +140,7 @@ public void onNext(T value) throws StreamObserverCancelledException { if (totalSecondsWaited > deadlineSeconds) { String errorMessage = constructStreamCancelledErrorMessage(totalSecondsWaited); LOG.error(errorMessage); - throw new StreamObserverCancelledException(errorMessage, e); + throw new WindmillRpcException(errorMessage, e); } if (totalSecondsWaited > OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS) { @@ -146,7 +148,6 @@ public void onNext(T value) throws StreamObserverCancelledException { "Output channel stalled for {}s, outbound thread {}.", totalSecondsWaited, Thread.currentThread().getName()); - Thread.dumpStack(); } waitSeconds = waitSeconds * 2; @@ -161,19 +162,27 @@ public void onNext(T value) throws StreamObserverCancelledException { public void onError(Throwable t) { isReadyNotifier.forceTermination(); synchronized (lock) { - isClosed = true; + markClosedOrThrow(); outboundObserver.onError(t); } } @Override public void onCompleted() { + isReadyNotifier.forceTermination(); synchronized (lock) { - isClosed = true; + markClosedOrThrow(); outboundObserver.onCompleted(); } } + private void markClosedOrThrow() { + synchronized (lock) { + Preconditions.checkState(!isClosed); + isClosed = true; + } + } + @Override public void terminate(Throwable terminationException) { // Free the blocked threads in onNext(). 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java index fb2555c8454f..5fb4f95e3e1e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/TerminatingStreamObserver.java @@ -23,6 +23,13 @@ @Internal public interface TerminatingStreamObserver extends StreamObserver { - /** Terminates the StreamObserver. */ + /** + * Terminates the StreamObserver. + * + * @implSpec Different then {@link #onError(Throwable)} and {@link #onCompleted()} which can only + * be called once during the lifetime of each {@link StreamObserver}, terminate() + * implementations are meant to be idempotent and can be called multiple times as well as + * being interleaved with other stream operations. 
+ */ void terminate(Throwable terminationException); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java index aada071416de..036158f5289e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -28,12 +28,14 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -104,6 +106,10 @@ public void setMessageCompression(boolean b) {} () -> assertThrows(WindmillStreamShutdownException.class, () -> testStream.testSend(1))); testStream.shutdown(); + + // Sleep a bit to give sendExecutor time to execute the send(). 
+ Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + sendBlocker.countDown(); assertThat(sendFuture.get()).isInstanceOf(WindmillStreamShutdownException.class); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index b6454d319f9f..316ff76eb929 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -20,6 +20,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -33,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; +import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.ManagedChannel; @@ -51,9 +53,12 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.InOrder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @RunWith(JUnit4.class) public class GrpcCommitWorkStreamTest 
{ + private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStreamTest.class); private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; private static final Windmill.JobHeader TEST_JOB_HEADER = Windmill.JobHeader.newBuilder() @@ -131,13 +136,27 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { commitProcessed.countDown(); }); } + } catch (StreamObserverCancelledException ignored) { } // Verify that we sent the commits above in a request + the initial header. - verify(requestObserver, times(2)).onNext(any(Windmill.StreamingCommitWorkRequest.class)); + verify(requestObserver, times(2)) + .onNext( + argThat( + request -> { + if (request.getHeader().equals(TEST_JOB_HEADER)) { + LOG.info("Header received."); + return true; + } else if (!request.getCommitChunkList().isEmpty()) { + LOG.info("Chunk received."); + return true; + } else { + LOG.error("Incorrect request."); + return false; + } + })); // We won't get responses so we will have some pending requests. assertTrue(commitWorkStream.hasPendingRequests()); - commitWorkStream.shutdown(); commitProcessed.await(); @@ -198,7 +217,7 @@ public void testSend_notCalledAfterShutdown() { // the header, which happens before we shutdown. 
requestObserverVerifier .verify(requestObserver) - .onNext(any(Windmill.StreamingCommitWorkRequest.class)); + .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); requestObserverVerifier.verify(requestObserver).onError(any()); requestObserverVerifier.verifyNoMoreInteractions(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java index d8b787fe1020..a6120f4052b3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java @@ -24,9 +24,11 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamShutdownException; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -142,7 +144,8 @@ public void testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOn assertThrows( WindmillStreamShutdownException.class, queuedBatch::waitForSendOrFailNotification)); - + // Wait a few seconds for the above future to get scheduled and run. 
+ Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); queuedBatch.notifyFailed(); waitFuture.join(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index 0ce455ac1270..252a73c92319 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.grpc; import static com.google.common.truth.Truth.assertThat; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; @@ -28,6 +29,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; import javax.annotation.Nullable; @@ -44,6 +46,7 @@ import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.testing.GrpcCleanupRule; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.util.MutableHandlerRegistry; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -102,6 +105,48 @@ private GrpcGetDataStream createGetDataStream(GetDataStreamTestStub testStub) { return getDataStream; } + @Test + public void testRequestKeyedData() { + 
GetDataStreamTestStub testStub = + new GetDataStreamTestStub(new TestGetDataStreamRequestObserver()); + GrpcGetDataStream getDataStream = createGetDataStream(testStub); + // These will block until they are successfully sent. + CompletableFuture sendFuture = + CompletableFuture.supplyAsync( + () -> { + try { + return getDataStream.requestKeyedData( + "computationId", + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setCacheToken(1) + .setWorkToken(1) + .build()); + } catch (WindmillStreamShutdownException e) { + throw new RuntimeException(e); + } + }); + + // Sleep a bit to allow future to run. + Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + + Windmill.KeyedGetDataResponse response = + Windmill.KeyedGetDataResponse.newBuilder() + .setShardingKey(1) + .setKey(ByteString.EMPTY) + .build(); + + testStub.injectResponse( + Windmill.StreamingGetDataResponse.newBuilder() + .addRequestId(1) + .addSerializedResponse(response.toByteString()) + .setRemainingBytesForResponse(0) + .build()); + + assertThat(sendFuture.join()).isEqualTo(response); + } + @Test public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdownException() { GetDataStreamTestStub testStub = @@ -206,5 +251,9 @@ public StreamObserver getDataStream( return requestObserver; } + + private void injectResponse(Windmill.StreamingGetDataResponse getDataResponse) { + checkNotNull(responseObserver).onNext(getDataResponse); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index bc978a72a8d3..4f0552959ee1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -473,7 +473,6 @@ private void flushResponse() { responseObserver.onNext(responseBuilder.build()); } catch (Exception e) { // Stream is already closed. - LOG.warn("trieu: ", e); LOG.warn(Arrays.toString(e.getStackTrace())); } responseBuilder.clear(); @@ -514,9 +513,7 @@ private void flushResponse() { done.countDown(); }); } - while (done.await(5, TimeUnit.SECONDS)) { - LOG.info("trieu: {}", done.getCount()); - } + while (done.await(5, TimeUnit.SECONDS)) {} stream.halfClose(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); executor.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java index ee0a9280610a..374c5aec3b5b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java @@ -35,6 +35,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.beam.runners.dataflow.worker.windmill.WindmillServerStub; import org.apache.beam.sdk.fn.stream.AdvancingPhaser; import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.VerifyException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; @@ -122,15 +123,16 @@ public void testOnNext_timeOut() throws ExecutionException, InterruptedException new DirectStreamObserver<>(new 
AdvancingPhaser(1), delegate, 1, 1); ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); CountDownLatch streamObserverExitLatch = new CountDownLatch(1); - Future onNextFuture = + Future onNextFuture = onNextExecutor.submit( () -> { // Won't block on the first one. streamObserver.onNext(1); // We will check isReady on the next message, will block here. - StreamObserverCancelledException e = + WindmillServerStub.WindmillRpcException e = assertThrows( - StreamObserverCancelledException.class, () -> streamObserver.onNext(1)); + WindmillServerStub.WindmillRpcException.class, + () -> streamObserver.onNext(1)); streamObserverExitLatch.countDown(); return e; }); From f75df0fdd35c8c3d05875327809d89de5fba4341 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Wed, 13 Nov 2024 13:23:57 -0800 Subject: [PATCH 22/23] address PR comments --- .../FanOutStreamingEngineWorkerHarness.java | 49 +++++-- .../harness/WindmillStreamSender.java | 1 + .../worker/windmill/WindmillEndpoints.java | 4 + .../client/AbstractWindmillStream.java | 10 +- .../ResettableThrowingStreamObserver.java | 6 + .../WindmillStreamShutdownException.java | 11 ++ .../client/grpc/GrpcGetDataStream.java | 11 +- .../grpc/GrpcWindmillStreamFactory.java | 2 +- .../grpc/observers/DirectStreamObserver.java | 44 +++--- ...anOutStreamingEngineWorkerHarnessTest.java | 68 +++++++-- .../client/AbstractWindmillStreamTest.java | 2 +- .../client/grpc/GrpcCommitWorkStreamTest.java | 29 ++-- .../grpc/GrpcGetDataStreamRequestsTest.java | 2 +- .../client/grpc/GrpcGetDataStreamTest.java | 2 +- .../client/grpc/GrpcWindmillServerTest.java | 2 +- .../observers/DirectStreamObserverTest.java | 137 +++++++++++++++++- 16 files changed, 298 insertions(+), 82 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index b42223844b26..f2fe49688bb5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -20,7 +20,9 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet.toImmutableSet; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; +import java.util.List; import java.util.Map.Entry; import java.util.NoSuchElementException; import java.util.Optional; @@ -62,6 +64,7 @@ import org.apache.beam.sdk.util.MoreFutures; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; @@ -235,7 +238,8 @@ public synchronized void shutdown() { Preconditions.checkState(started, "FanOutStreamingEngineWorkerHarness never started."); Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); workerMetadataConsumer.shutdownNow(); - closeStreamsNotIn(WindmillEndpoints.none()); + // Close all the streams blocking until this completes to not leak resources. 
+ closeStreamsNotIn(WindmillEndpoints.none()).forEach(CompletableFuture::join); channelCachingStubFactory.shutdown(); try { @@ -299,24 +303,39 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi } /** Close the streams that are no longer valid asynchronously. */ - private void closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { + @CanIgnoreReturnValue + private ImmutableList> closeStreamsNotIn( + WindmillEndpoints newWindmillEndpoints) { StreamingEngineBackends currentBackends = backends.get(); - currentBackends.windmillStreams().entrySet().stream() - .filter( - connectionAndStream -> - !newWindmillEndpoints.windmillEndpoints().contains(connectionAndStream.getKey())) - .forEach( - entry -> - windmillStreamManager.execute( - () -> closeStreamSender(entry.getKey(), entry.getValue()))); + List> closeStreamFutures = + currentBackends.windmillStreams().entrySet().stream() + .filter( + connectionAndStream -> + !newWindmillEndpoints + .windmillEndpoints() + .contains(connectionAndStream.getKey())) + .map( + entry -> + CompletableFuture.runAsync( + () -> closeStreamSender(entry.getKey(), entry.getValue()), + windmillStreamManager)) + .collect(Collectors.toList()); Set newGlobalDataEndpoints = new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); - currentBackends.globalDataStreams().values().stream() - .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) - .forEach( - sender -> - windmillStreamManager.execute(() -> closeStreamSender(sender.endpoint(), sender))); + List> closeGlobalDataStreamFutures = + currentBackends.globalDataStreams().values().stream() + .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) + .map( + sender -> + CompletableFuture.runAsync( + () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager)) + .collect(Collectors.toList()); + + return ImmutableList.>builder() + .addAll(closeStreamFutures) + .addAll(closeGlobalDataStreamFutures) + .build(); } 
private void closeStreamSender(Endpoint endpoint, StreamSender sender) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java index 89aaa0d8b640..2a2f49dff846 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/WindmillStreamSender.java @@ -128,6 +128,7 @@ private static GetWorkRequest withRequestBudget(GetWorkRequest request, GetWorkB synchronized void start() { if (!started.get()) { checkState(!streamStarter.isShutdown(), "WindmillStreamSender has already been shutdown."); + // Start these 3 streams in parallel since they each may perform blocking IO. CompletableFuture.allOf( CompletableFuture.runAsync(getWorkStream::start, streamStarter), diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index eb269eef848f..14e71d0d7c45 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -132,6 +132,10 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); } + public final boolean isEmpty() { + return equals(none()); + } + /** Version of the endpoints which increases with every modification. 
*/ public abstract long version(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index df34797b647a..939ea1ea75be 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -338,8 +338,10 @@ public final void shutdown() { protected abstract void shutdownInternal(); /** Returns true if the stream was torn down and should not be restarted internally. */ - private synchronized boolean maybeTeardownStream() { - if (isShutdown || (clientClosed && !hasPendingRequests())) { + private synchronized boolean maybeTearDownStream() { + if (requestObserver.hasReceivedPoisonPill() + || isShutdown + || (clientClosed && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); executor.shutdownNow(); @@ -364,7 +366,7 @@ public void onNext(ResponseT response) { @Override public void onError(Throwable t) { - if (maybeTeardownStream()) { + if (maybeTearDownStream()) { return; } @@ -392,7 +394,7 @@ public void onError(Throwable t) { @Override public void onCompleted() { - if (maybeTeardownStream()) { + if (maybeTearDownStream()) { return; } recordStreamStatus(OK_STATUS); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java index 5eb691cbf55a..0ea7680c2821 100644 --- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -104,6 +104,10 @@ synchronized void poison() { } } + synchronized boolean hasReceivedPoisonPill() { + return isPoisoned; + } + public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownException { // Make sure onNext and onError below to be called on the same StreamObserver instance. StreamObserver delegate = delegate(); @@ -125,6 +129,8 @@ public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownExce // If the delegate above was already terminated via onError or onComplete from another // thread. logger.warn("StreamObserver was previously cancelled.", e); + } catch (RuntimeException ignored) { + logger.warn("StreamObserver was unexpectedly cancelled.", e); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java index 566c15c58036..65c1a48ae982 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -26,4 +26,15 @@ public final class WindmillStreamShutdownException extends Exception { public WindmillStreamShutdownException(String message) { super(message); } + + /** Returns whether an exception was caused by a {@link WindmillStreamShutdownException}. 
*/ + public static boolean isCauseOf(Throwable t) { + while (t != null) { + if (t instanceof WindmillStreamShutdownException) { + return true; + } + t = t.getCause(); + } + return false; + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 19eb6dd4915a..b5b49c8ee976 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -187,9 +187,14 @@ protected void onResponse(StreamingGetDataResponse chunk) { onHeartbeatResponse(chunk.getComputationHeartbeatResponseList()); for (int i = 0; i < chunk.getRequestIdCount(); ++i) { - AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); - synchronized (this) { - verify(responseStream != null || isShutdown, "No pending response stream"); + @Nullable AppendableInputStream responseStream = pending.get(chunk.getRequestId(i)); + if (responseStream == null) { + synchronized (this) { + // shutdown()/shutdownInternal() cleans up pending, else we expect a pending + // responseStream for every response. 
+ verify(isShutdown, "No pending response stream"); + } + continue; } responseStream.append(chunk.getSerializedResponse(i).newInput()); if (chunk.getRemainingBytesForResponse() == 0) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java index af096d5441b4..df69af207899 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillStreamFactory.java @@ -332,7 +332,7 @@ public void appendSummaryHtml(PrintWriter writer) { } @VisibleForTesting - ImmutableSet> streamRegistry() { + final ImmutableSet> streamRegistry() { return ImmutableSet.copyOf(streamRegistry); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 3a289e4dd48b..25221c901444 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -39,27 +39,28 @@ * becomes ready. 
*/ @ThreadSafe -public final class DirectStreamObserver implements TerminatingStreamObserver { +final class DirectStreamObserver implements TerminatingStreamObserver { private static final Logger LOG = LoggerFactory.getLogger(DirectStreamObserver.class); private static final long OUTPUT_CHANNEL_CONSIDERED_STALLED_SECONDS = 30; private final Phaser isReadyNotifier; - + private final long deadlineSeconds; + private final int messagesBetweenIsReadyChecks; private final Object lock = new Object(); @GuardedBy("lock") private final CallStreamObserver outboundObserver; - private final long deadlineSeconds; - private final int messagesBetweenIsReadyChecks; - @GuardedBy("lock") private boolean isClosed = false; + @GuardedBy("lock") + private boolean isUserClosed = false; + @GuardedBy("lock") private int messagesSinceReady = 0; - public DirectStreamObserver( + DirectStreamObserver( Phaser isReadyNotifier, CallStreamObserver outboundObserver, long deadlineSeconds, @@ -89,6 +90,9 @@ public void onNext(T value) throws StreamObserverCancelledException { throw new StreamObserverCancelledException("StreamObserver was terminated."); } + // We close under "lock", so this should never happen. + assert !isClosed; + // If we awaited previously and timed out, wait for the same phase. Otherwise we're // careful to observe the phase before observing isReady. if (awaitPhase < 0) { @@ -131,6 +135,10 @@ public void onNext(T value) throws StreamObserverCancelledException { if (currentPhase < 0) { throw new StreamObserverCancelledException("StreamObserver was terminated."); } + + // We close under "lock", so this should never happen. 
+ assert !isClosed; + messagesSinceReady = 0; outboundObserver.onNext(value); return; @@ -162,8 +170,11 @@ public void onNext(T value) throws StreamObserverCancelledException { public void onError(Throwable t) { isReadyNotifier.forceTermination(); synchronized (lock) { - markClosedOrThrow(); - outboundObserver.onError(t); + if (!isClosed) { + Preconditions.checkState(!isUserClosed); + outboundObserver.onError(t); + isClosed = true; + } } } @@ -171,15 +182,11 @@ public void onError(Throwable t) { public void onCompleted() { isReadyNotifier.forceTermination(); synchronized (lock) { - markClosedOrThrow(); - outboundObserver.onCompleted(); - } - } - - private void markClosedOrThrow() { - synchronized (lock) { - Preconditions.checkState(!isClosed); - isClosed = true; + if (!isClosed) { + Preconditions.checkState(!isUserClosed); + outboundObserver.onCompleted(); + isClosed = true; + } } } @@ -188,8 +195,9 @@ public void terminate(Throwable terminationException) { // Free the blocked threads in onNext(). 
isReadyNotifier.forceTermination(); synchronized (lock) { - if (!isClosed) { + if (!isUserClosed) { onError(terminationException); + isUserClosed = true; } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java index 7d0658534080..bba6cad5529a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarnessTest.java @@ -40,6 +40,8 @@ import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillMetadataServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkerMetadataRequest; @@ -79,6 +81,7 @@ @RunWith(JUnit4.class) public class FanOutStreamingEngineWorkerHarnessTest { + private static final String CHANNEL_NAME = "FanOutStreamingEngineWorkerHarnessTest"; private static final WindmillServiceAddress DEFAULT_WINDMILL_SERVICE_ADDRESS = WindmillServiceAddress.create(HostAndPort.fromParts(WindmillChannelFactory.LOCALHOST, 443)); private static final ImmutableMap DEFAULT = @@ -105,9 +108,7 @@ public class FanOutStreamingEngineWorkerHarnessTest { 
spy(GrpcWindmillStreamFactory.of(JOB_HEADER).build()); private final ChannelCachingStubFactory stubFactory = new FakeWindmillStubFactory( - () -> - grpcCleanup.register( - WindmillChannelFactory.inProcessChannel("StreamingEngineClientTest"))); + () -> grpcCleanup.register(WindmillChannelFactory.inProcessChannel(CHANNEL_NAME))); private final GrpcDispatcherClient dispatcherClient = GrpcDispatcherClient.forTesting( PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class), @@ -148,7 +149,7 @@ public void setUp() throws IOException { stubFactory.shutdown(); fakeStreamingEngineServer = grpcCleanup.register( - InProcessServerBuilder.forName("StreamingEngineClientTest") + InProcessServerBuilder.forName(CHANNEL_NAME) .fallbackHandlerRegistry(serviceRegistry) .executor(Executors.newFixedThreadPool(1)) .build()); @@ -156,18 +157,18 @@ public void setUp() throws IOException { fakeStreamingEngineServer.start(); dispatcherClient.consumeWindmillDispatcherEndpoints( ImmutableSet.of( - HostAndPort.fromString( - new InProcessSocketAddress("StreamingEngineClientTest").toString()))); + HostAndPort.fromString(new InProcessSocketAddress(CHANNEL_NAME).toString()))); getWorkerMetadataReady = new CountDownLatch(1); fakeGetWorkerMetadataStub = new GetWorkerMetadataTestStub(getWorkerMetadataReady); serviceRegistry.addService(fakeGetWorkerMetadataStub); + serviceRegistry.addService(new WindmillServiceFakeStub()); } @After public void cleanUp() { Preconditions.checkNotNull(fanOutStreamingEngineWorkProvider).shutdown(); - fakeStreamingEngineServer.shutdownNow(); stubFactory.shutdown(); + fakeStreamingEngineServer.shutdownNow(); } private FanOutStreamingEngineWorkerHarness newFanOutStreamingEngineWorkerHarness( @@ -248,9 +249,8 @@ public void testStreamsStartCorrectly() throws InterruptedException { @Test public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() throws InterruptedException { - int metadataCount = 2; TestGetWorkBudgetDistributor getWorkBudgetDistributor = - 
spy(new TestGetWorkBudgetDistributor(metadataCount)); + spy(new TestGetWorkBudgetDistributor(1)); fanOutStreamingEngineWorkProvider = newFanOutStreamingEngineWorkerHarness( GetWorkBudget.builder().setItems(1).setBytes(1).build(), @@ -285,6 +285,8 @@ public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers() getWorkerMetadataReady.await(); fakeGetWorkerMetadataStub.injectWorkerMetadata(firstWorkerMetadata); + assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); + getWorkBudgetDistributor.expectNumDistributions(1); fakeGetWorkerMetadataStub.injectWorkerMetadata(secondWorkerMetadata); assertTrue(getWorkBudgetDistributor.waitForBudgetDistribution()); StreamingEngineBackends currentBackends = fanOutStreamingEngineWorkProvider.currentBackends(); @@ -342,6 +344,54 @@ public void testOnNewWorkerMetadata_redistributesBudget() throws InterruptedExce verify(getWorkBudgetDistributor, times(2)).distributeBudget(any(), any()); } + private static class WindmillServiceFakeStub + extends CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1ImplBase { + @Override + public StreamObserver getDataStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingGetDataRequest getDataRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }; + } + + @Override + public StreamObserver getWorkStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingGetWorkRequest getWorkRequest) {} + + @Override + public void onError(Throwable throwable) {} + + @Override + public void onCompleted() {} + }; + } + + @Override + public StreamObserver commitWorkStream( + StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(Windmill.StreamingCommitWorkRequest streamingCommitWorkRequest) {} + + @Override + public void onError(Throwable 
throwable) {} + + @Override + public void onCompleted() {} + }; + } + } + private static class GetWorkerMetadataTestStub extends CloudWindmillMetadataServiceV1Alpha1Grpc .CloudWindmillMetadataServiceV1Alpha1ImplBase { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java index 036158f5289e..05fbc6f969df 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStreamTest.java @@ -108,7 +108,7 @@ public void setMessageCompression(boolean b) {} testStream.shutdown(); // Sleep a bit to give sendExecutor time to execute the send(). 
- Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); sendBlocker.countDown(); assertThat(sendFuture.get()).isInstanceOf(WindmillStreamShutdownException.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index 316ff76eb929..b7e8f50f9249 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -23,8 +23,6 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import java.io.IOException; import java.util.HashSet; @@ -53,12 +51,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.InOrder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; @RunWith(JUnit4.class) public class GrpcCommitWorkStreamTest { - private static final Logger LOG = LoggerFactory.getLogger(GrpcCommitWorkStreamTest.class); private static final String FAKE_SERVER_NAME = "Fake server for GrpcCommitWorkStreamTest"; private static final Windmill.JobHeader TEST_JOB_HEADER = Windmill.JobHeader.newBuilder() @@ -126,6 +121,7 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { spy(new TestCommitWorkStreamRequestObserver()); CommitWorkStreamTestStub testStub = new CommitWorkStreamTestStub(requestObserver); GrpcCommitWorkStream commitWorkStream = createCommitWorkStream(testStub); + 
InOrder requestObserverVerifier = inOrder(requestObserver); try (WindmillStream.CommitWorkStream.RequestBatcher batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { batcher.commitWorkItem( @@ -140,21 +136,14 @@ public void testShutdown_abortsQueuedCommits() throws InterruptedException { } // Verify that we sent the commits above in a request + the initial header. - verify(requestObserver, times(2)) - .onNext( - argThat( - request -> { - if (request.getHeader().equals(TEST_JOB_HEADER)) { - LOG.info("Header received."); - return true; - } else if (!request.getCommitChunkList().isEmpty()) { - LOG.info("Chunk received."); - return true; - } else { - LOG.error("Incorrect request."); - return false; - } - })); + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> request.getHeader().equals(TEST_JOB_HEADER))); + requestObserverVerifier + .verify(requestObserver) + .onNext(argThat(request -> !request.getCommitChunkList().isEmpty())); + requestObserverVerifier.verifyNoMoreInteractions(); + // We won't get responses so we will have some pending requests. 
assertTrue(commitWorkStream.hasPendingRequests()); commitWorkStream.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java index a6120f4052b3..dc2dce7807a9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamRequestsTest.java @@ -145,7 +145,7 @@ public void testQueuedBatch_notifyFailed_throwsWindmillStreamShutdownExceptionOn WindmillStreamShutdownException.class, queuedBatch::waitForSendOrFailNotification)); // Wait a few seconds for the above future to get scheduled and run. - Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); queuedBatch.notifyFailed(); waitFuture.join(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index 252a73c92319..e5e77e16abef 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -129,7 +129,7 @@ public void testRequestKeyedData() { }); // Sleep a bit to allow future to run. 
- Uninterruptibles.sleepUninterruptibly(5, TimeUnit.SECONDS); + Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS); Windmill.KeyedGetDataResponse response = Windmill.KeyedGetDataResponse.newBuilder() diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java index 4f0552959ee1..a595524ca582 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcWindmillServerTest.java @@ -513,7 +513,7 @@ private void flushResponse() { done.countDown(); }); } - while (done.await(5, TimeUnit.SECONDS)) {} + done.await(); stream.halfClose(); assertTrue(stream.awaitTermination(60, TimeUnit.SECONDS)); executor.shutdown(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java index 374c5aec3b5b..6a51ddc07d1a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java @@ -19,9 +19,11 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static 
org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -39,6 +41,8 @@ import org.apache.beam.sdk.fn.stream.AdvancingPhaser; import org.apache.beam.vendor.grpc.v1p60p1.com.google.common.base.VerifyException; import org.apache.beam.vendor.grpc.v1p60p1.io.grpc.stub.CallStreamObserver; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Uninterruptibles; +import org.checkerframework.checker.nullness.qual.Nullable; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -46,10 +50,123 @@ @RunWith(JUnit4.class) public class DirectStreamObserverTest { + @Test + public void testOnNext_onCompleted() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>( + new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + Future onNextFuture = + onNextExecutor.submit( + () -> { + streamObserver.onNext(1); + streamObserver.onNext(1); + streamObserver.onNext(1); + }); + + // Wait for all of the onNext's to run. 
+ onNextFuture.get(); + + verify(delegate, times(3)).onNext(eq(1)); + + streamObserver.onCompleted(); + verify(delegate, times(1)).onCompleted(); + } + + @Test + public void testOnNext_onError() throws ExecutionException, InterruptedException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>( + new AdvancingPhaser(1), delegate, Long.MAX_VALUE, Integer.MAX_VALUE); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + Future onNextFuture = + onNextExecutor.submit( + () -> { + streamObserver.onNext(1); + streamObserver.onNext(1); + streamObserver.onNext(1); + }); + + // Wait for all of the onNext's to run. + onNextFuture.get(); + + verify(delegate, times(3)).onNext(eq(1)); + + RuntimeException error = new RuntimeException(); + streamObserver.onError(error); + verify(delegate, times(1)).onError(same(error)); + } + + @Test + public void testOnCompleted_executedOnce() { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + + streamObserver.onCompleted(); + streamObserver.onCompleted(); + streamObserver.onCompleted(); + + verify(delegate, times(1)).onCompleted(); + } + + @Test + public void testOnError_executedOnce() { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + + RuntimeException error = new RuntimeException(); + streamObserver.onError(error); + streamObserver.onError(error); + streamObserver.onError(error); + + verify(delegate, times(1)).onError(same(error)); + } + + @Test + public void testOnNext_waitForReady() throws InterruptedException, ExecutionException { + TestStreamObserver delegate = spy(new TestStreamObserver(Integer.MAX_VALUE)); + 
delegate.setIsReady(false); + DirectStreamObserver streamObserver = + new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); + ExecutorService onNextExecutor = Executors.newSingleThreadExecutor(); + CountDownLatch blockLatch = new CountDownLatch(1); + Future<@Nullable Object> onNextFuture = + onNextExecutor.submit( + () -> { + // Won't block on the first one. + streamObserver.onNext(1); + try { + // We will check isReady on the next message, will block here. + streamObserver.onNext(1); + streamObserver.onNext(1); + blockLatch.countDown(); + return null; + } catch (Throwable e) { + return e; + } + }); + + while (delegate.getNumIsReadyChecks() <= 1) { + // Wait for isReady check to block. + Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS); + } + + delegate.setIsReady(true); + blockLatch.await(); + verify(delegate, times(3)).onNext(eq(1)); + assertNull(onNextFuture.get()); + + streamObserver.onCompleted(); + verify(delegate, times(1)).onCompleted(); + } + @Test public void testTerminate_waitingForReady() throws ExecutionException, InterruptedException { - CountDownLatch sendBlocker = new CountDownLatch(1); - TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + TestStreamObserver delegate = spy(new TestStreamObserver(2)); delegate.setIsReady(false); DirectStreamObserver streamObserver = new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); @@ -82,8 +199,7 @@ public void testTerminate_waitingForReady() throws ExecutionException, Interrupt @Test public void testOnNext_interruption() throws ExecutionException, InterruptedException { - CountDownLatch sendBlocker = new CountDownLatch(1); - TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + TestStreamObserver delegate = spy(new TestStreamObserver(2)); delegate.setIsReady(false); DirectStreamObserver streamObserver = new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); @@ -116,8 
+232,7 @@ public void testOnNext_interruption() throws ExecutionException, InterruptedExce @Test public void testOnNext_timeOut() throws ExecutionException, InterruptedException { - CountDownLatch sendBlocker = new CountDownLatch(1); - TestStreamObserver delegate = spy(new TestStreamObserver(sendBlocker, 2)); + TestStreamObserver delegate = spy(new TestStreamObserver(2)); delegate.setIsReady(false); DirectStreamObserver streamObserver = new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, 1, 1); @@ -152,11 +267,12 @@ private static class TestStreamObserver extends CallStreamObserver { private final CountDownLatch sendBlocker; private final int blockAfter; private final AtomicInteger seen = new AtomicInteger(0); + private final AtomicInteger numIsReadyChecks = new AtomicInteger(0); private volatile boolean isReady = false; - private TestStreamObserver(CountDownLatch sendBlocker, int blockAfter) { + private TestStreamObserver(int blockAfter) { this.blockAfter = blockAfter; - this.sendBlocker = sendBlocker; + this.sendBlocker = new CountDownLatch(1); } @Override @@ -178,9 +294,14 @@ public void onCompleted() {} @Override public boolean isReady() { + numIsReadyChecks.incrementAndGet(); return isReady; } + public int getNumIsReadyChecks() { + return numIsReadyChecks.get(); + } + private void setIsReady(boolean isReadyOverride) { isReady = isReadyOverride; } From 74e503da1558a4c4f5296c3e10ccd894edbd5f04 Mon Sep 17 00:00:00 2001 From: Martin Trieu Date: Fri, 15 Nov 2024 12:01:24 -0800 Subject: [PATCH 23/23] address PR comments --- .../FanOutStreamingEngineWorkerHarness.java | 26 +++++------ .../worker/windmill/WindmillEndpoints.java | 16 +++---- .../client/AbstractWindmillStream.java | 13 +++--- .../ResettableThrowingStreamObserver.java | 26 ++++++----- .../WindmillStreamShutdownException.java | 11 ----- .../grpc/observers/DirectStreamObserver.java | 44 ++++++++++++------- .../client/grpc/GrpcCommitWorkStreamTest.java | 7 +-- 
.../client/grpc/GrpcGetDataStreamTest.java | 1 - .../observers/DirectStreamObserverTest.java | 9 +--- 9 files changed, 73 insertions(+), 80 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java index f2fe49688bb5..115142f98b9c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/harness/FanOutStreamingEngineWorkerHarness.java @@ -22,7 +22,6 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; -import java.util.List; import java.util.Map.Entry; import java.util.NoSuchElementException; import java.util.Optional; @@ -35,6 +34,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.CheckReturnValue; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -64,9 +64,9 @@ import org.apache.beam.sdk.util.MoreFutures; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.net.HostAndPort; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.slf4j.Logger; @@ -239,7 +239,7 @@ public synchronized void shutdown() { Preconditions.checkNotNull(getWorkerMetadataStream).shutdown(); workerMetadataConsumer.shutdownNow(); // Close all the streams blocking until this completes to not leak resources. - closeStreamsNotIn(WindmillEndpoints.none()).forEach(CompletableFuture::join); + closeStreamsNotIn(WindmillEndpoints.none()).join(); channelCachingStubFactory.shutdown(); try { @@ -304,10 +304,9 @@ private synchronized void consumeWindmillWorkerEndpoints(WindmillEndpoints newWi /** Close the streams that are no longer valid asynchronously. */ @CanIgnoreReturnValue - private ImmutableList> closeStreamsNotIn( - WindmillEndpoints newWindmillEndpoints) { + private CompletableFuture closeStreamsNotIn(WindmillEndpoints newWindmillEndpoints) { StreamingEngineBackends currentBackends = backends.get(); - List> closeStreamFutures = + Stream> closeStreamFutures = currentBackends.windmillStreams().entrySet().stream() .filter( connectionAndStream -> @@ -318,24 +317,21 @@ private ImmutableList> closeStreamsNotIn( entry -> CompletableFuture.runAsync( () -> closeStreamSender(entry.getKey(), entry.getValue()), - windmillStreamManager)) - .collect(Collectors.toList()); + windmillStreamManager)); Set newGlobalDataEndpoints = new HashSet<>(newWindmillEndpoints.globalDataEndpoints().values()); - List> closeGlobalDataStreamFutures = + Stream> closeGlobalDataStreamFutures = currentBackends.globalDataStreams().values().stream() .filter(sender -> !newGlobalDataEndpoints.contains(sender.endpoint())) .map( sender -> CompletableFuture.runAsync( - () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager)) - .collect(Collectors.toList()); + () -> closeStreamSender(sender.endpoint(), sender), windmillStreamManager)); - return 
ImmutableList.>builder() - .addAll(closeStreamFutures) - .addAll(closeGlobalDataStreamFutures) - .build(); + return CompletableFuture.allOf( + Streams.concat(closeStreamFutures, closeGlobalDataStreamFutures) + .toArray(CompletableFuture[]::new)); } private void closeStreamSender(Endpoint endpoint, StreamSender sender) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java index 14e71d0d7c45..13b3ea954198 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/WindmillEndpoints.java @@ -40,13 +40,15 @@ @AutoValue public abstract class WindmillEndpoints { private static final Logger LOG = LoggerFactory.getLogger(WindmillEndpoints.class); + private static final WindmillEndpoints NO_ENDPOINTS = + WindmillEndpoints.builder() + .setVersion(Long.MAX_VALUE) + .setWindmillEndpoints(ImmutableSet.of()) + .setGlobalDataEndpoints(ImmutableMap.of()) + .build(); public static WindmillEndpoints none() { - return WindmillEndpoints.builder() - .setVersion(Long.MAX_VALUE) - .setWindmillEndpoints(ImmutableSet.of()) - .setGlobalDataEndpoints(ImmutableMap.of()) - .build(); + return NO_ENDPOINTS; } public static WindmillEndpoints from( @@ -132,10 +134,6 @@ private static Optional tryParseDirectEndpointIntoIpV6Address( directEndpointAddress.getHostAddress(), (int) endpointProto.getPort())); } - public final boolean isEmpty() { - return equals(none()); - } - /** Version of the endpoints which increases with every modification. 
*/ public abstract long version(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index 939ea1ea75be..8b48459eba94 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -28,7 +28,6 @@ import java.util.concurrent.TimeUnit; import java.util.function.Function; import javax.annotation.concurrent.GuardedBy; -import org.apache.beam.runners.dataflow.worker.windmill.client.ResettableThrowingStreamObserver.StreamClosedException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.observers.StreamObserverFactory; import org.apache.beam.sdk.util.BackOff; @@ -159,7 +158,7 @@ protected final synchronized boolean trySend(RequestT request) try { requestObserver.onNext(request); return true; - } catch (StreamClosedException e) { + } catch (ResettableThrowingStreamObserver.StreamClosedException e) { // Stream was broken, requests may be retried when stream is reopened. } @@ -199,6 +198,7 @@ private void startStream() { } catch (WindmillStreamShutdownException e) { // shutdown() is responsible for cleaning up pending requests. 
logger.debug("Stream was shutdown while creating new stream.", e); + break; } catch (Exception e) { logger.error("Failed to create new stream, retrying: ", e); try { @@ -298,10 +298,12 @@ public final synchronized void halfClose() { clientClosed = true; try { requestObserver.onCompleted(); - } catch (StreamClosedException e) { + } catch (ResettableThrowingStreamObserver.StreamClosedException e) { logger.warn("Stream was previously closed."); } catch (WindmillStreamShutdownException e) { logger.warn("Stream was previously shutdown."); + } catch (IllegalStateException e) { + logger.warn("Unexpected error when trying to close stream", e); } } @@ -320,7 +322,6 @@ public String backendWorkerToken() { return backendWorkerToken; } - @SuppressWarnings("GuardedBy") @Override public final void shutdown() { // Don't lock on "this" before poisoning the request observer since otherwise the observer may @@ -339,9 +340,7 @@ public final void shutdown() { /** Returns true if the stream was torn down and should not be restarted internally. 
*/ private synchronized boolean maybeTearDownStream() { - if (requestObserver.hasReceivedPoisonPill() - || isShutdown - || (clientClosed && !hasPendingRequests())) { + if (isShutdown || (clientClosed && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); executor.shutdownNow(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java index 0ea7680c2821..1db6d8de791d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/ResettableThrowingStreamObserver.java @@ -104,10 +104,6 @@ synchronized void poison() { } } - synchronized boolean hasReceivedPoisonPill() { - return isPoisoned; - } - public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownException { // Make sure onNext and onError below to be called on the same StreamObserver instance. StreamObserver delegate = delegate(); @@ -115,22 +111,28 @@ public void onNext(T t) throws StreamClosedException, WindmillStreamShutdownExce // Do NOT lock while sending message over the stream as this will block other StreamObserver // operations. 
delegate.onNext(t); - } catch (StreamObserverCancelledException e) { + } catch (StreamObserverCancelledException cancellationException) { synchronized (this) { if (isPoisoned) { - logger.debug("Stream was shutdown during send.", e); + logger.debug("Stream was shutdown during send.", cancellationException); return; } } try { - delegate.onError(e); - } catch (IllegalStateException ignored) { + delegate.onError(cancellationException); + } catch (IllegalStateException onErrorException) { // If the delegate above was already terminated via onError or onComplete from another // thread. - logger.warn("StreamObserver was previously cancelled.", e); - } catch (RuntimeException ignored) { - logger.warn("StreamObserver was unexpectedly cancelled.", e); + logger.warn( + "StreamObserver was already cancelled {} due to error.", + onErrorException, + cancellationException); + } catch (RuntimeException onErrorException) { + logger.warn( + "Encountered unexpected error {} when cancelling due to error.", + onErrorException, + cancellationException); } } } @@ -156,7 +158,7 @@ synchronized boolean isClosed() { * {@link StreamObserver#onCompleted()}. 
The stream may perform */ static final class StreamClosedException extends Exception { - private StreamClosedException(String s) { + StreamClosedException(String s) { super(s); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java index 65c1a48ae982..566c15c58036 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamShutdownException.java @@ -26,15 +26,4 @@ public final class WindmillStreamShutdownException extends Exception { public WindmillStreamShutdownException(String message) { super(message); } - - /** Returns whether an exception was caused by a {@link WindmillStreamShutdownException}. 
*/ - public static boolean isCauseOf(Throwable t) { - while (t != null) { - if (t instanceof WindmillStreamShutdownException) { - return true; - } - t = t.getCause(); - } - return false; - } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java index 25221c901444..8710d66d2c80 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserver.java @@ -52,7 +52,7 @@ final class DirectStreamObserver implements TerminatingStreamObserver { private final CallStreamObserver outboundObserver; @GuardedBy("lock") - private boolean isClosed = false; + private boolean isOutboundObserverClosed = false; @GuardedBy("lock") private boolean isUserClosed = false; @@ -74,8 +74,14 @@ final class DirectStreamObserver implements TerminatingStreamObserver { this.messagesBetweenIsReadyChecks = Math.max(1, messagesBetweenIsReadyChecks); } + /** + * @throws StreamObserverCancelledException if the StreamObserver was closed via {@link + * #onError(Throwable)}, {@link #onCompleted()}, or {@link #terminate(Throwable)} while + * waiting for {@code outboundObserver#isReady}. + * @throws WindmillRpcException if we time out for waiting for {@code outboundObserver#isReady}. 
+ */ @Override - public void onNext(T value) throws StreamObserverCancelledException { + public void onNext(T value) { int awaitPhase = -1; long totalSecondsWaited = 0; long waitSeconds = 1; @@ -90,8 +96,9 @@ public void onNext(T value) throws StreamObserverCancelledException { throw new StreamObserverCancelledException("StreamObserver was terminated."); } - // We close under "lock", so this should never happen. - assert !isClosed; + // Closing is performed under "lock" after terminating, so if termination was not observed + // above, the observer should not be closed. + assert !isOutboundObserverClosed; // If we awaited previously and timed out, wait for the same phase. Otherwise we're // careful to observe the phase before observing isReady. @@ -136,8 +143,9 @@ public void onNext(T value) throws StreamObserverCancelledException { throw new StreamObserverCancelledException("StreamObserver was terminated."); } - // We close under "lock", so this should never happen. - assert !isClosed; + // Closing is performed under "lock" after terminating, so if termination was not observed + // above, the observer should not be closed. + assert !isOutboundObserverClosed; messagesSinceReady = 0; outboundObserver.onNext(value); @@ -166,26 +174,32 @@ public void onNext(T value) throws StreamObserverCancelledException { } } + /** @throws IllegalStateException if called multiple times or after {@link #onCompleted()}. */ @Override public void onError(Throwable t) { isReadyNotifier.forceTermination(); synchronized (lock) { - if (!isClosed) { - Preconditions.checkState(!isUserClosed); + Preconditions.checkState(!isUserClosed); + isUserClosed = true; + if (!isOutboundObserverClosed) { outboundObserver.onError(t); - isClosed = true; + isOutboundObserverClosed = true; } } } + /** + * @throws IllegalStateException if called multiple times or after {@link #onError(Throwable)}. 
+ */ @Override public void onCompleted() { isReadyNotifier.forceTermination(); synchronized (lock) { - if (!isClosed) { - Preconditions.checkState(!isUserClosed); + Preconditions.checkState(!isUserClosed); + isUserClosed = true; + if (!isOutboundObserverClosed) { outboundObserver.onCompleted(); - isClosed = true; + isOutboundObserverClosed = true; } } } @@ -195,9 +209,9 @@ public void terminate(Throwable terminationException) { // Free the blocked threads in onNext(). isReadyNotifier.forceTermination(); synchronized (lock) { - if (!isUserClosed) { - onError(terminationException); - isUserClosed = true; + if (!isOutboundObserverClosed) { + outboundObserver.onError(terminationException); + isOutboundObserverClosed = true; } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java index b7e8f50f9249..7de824b86fd2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStreamTest.java @@ -28,6 +28,7 @@ import java.util.HashSet; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.windmill.CloudWindmillServiceV1Alpha1Grpc; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -167,15 +168,15 @@ public void testCommitWorkItem_afterShutdown() { } commitWorkStream.shutdown(); - Set commitStatuses = new HashSet<>(); + AtomicReference commitStatus = new AtomicReference<>(); try (WindmillStream.CommitWorkStream.RequestBatcher 
batcher = commitWorkStream.batcher()) { for (int i = 0; i < numCommits; i++) { assertTrue( - batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatuses::add)); + batcher.commitWorkItem(COMPUTATION_ID, workItemCommitRequest(i), commitStatus::set)); } } - assertThat(commitStatuses).containsExactly(Windmill.CommitStatus.ABORTED); + assertThat(commitStatus.get()).isEqualTo(Windmill.CommitStatus.ABORTED); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java index e5e77e16abef..3125def64b32 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStreamTest.java @@ -172,7 +172,6 @@ public void testRequestKeyedData_sendOnShutdownStreamThrowsWindmillStreamShutdow } } try { - getDataStream.requestKeyedData( "computationId", Windmill.KeyedGetDataRequest.newBuilder() diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java index 6a51ddc07d1a..6bc713aa7747 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/observers/DirectStreamObserverTest.java @@ -106,10 +106,7 @@ 
public void testOnCompleted_executedOnce() { new DirectStreamObserver<>(new AdvancingPhaser(1), delegate, Long.MAX_VALUE, 1); streamObserver.onCompleted(); - streamObserver.onCompleted(); - streamObserver.onCompleted(); - - verify(delegate, times(1)).onCompleted(); + assertThrows(IllegalStateException.class, streamObserver::onCompleted); } @Test @@ -120,9 +117,7 @@ public void testOnError_executedOnce() { RuntimeException error = new RuntimeException(); streamObserver.onError(error); - streamObserver.onError(error); - streamObserver.onError(error); - + assertThrows(IllegalStateException.class, () -> streamObserver.onError(error)); verify(delegate, times(1)).onError(same(error)); }