Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize failed key processing by indexing workitems by shardingkey + workid #33755

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@
package org.apache.beam.runners.dataflow.worker.streaming;

import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap;

import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand All @@ -36,14 +33,13 @@
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
Expand All @@ -63,11 +59,11 @@ public final class ActiveWorkState {
private static final int MAX_PRINTABLE_COMMIT_PENDING_KEYS = 50;

/**
* Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link
* Queue<Work>} is actively processing.
* Map from shardingKey to {@link Work} for the key. The first item in the {@link LinkedHashMap}
* is actively processing.
*/
@GuardedBy("this")
private final Map<ShardedKey, Deque<ExecutableWork>> activeWork;
private final Map<Long /*shardingKey*/, LinkedHashMap<WorkId, ExecutableWork>> activeWork;

@GuardedBy("this")
private final WindmillStateCache.ForComputation computationStateCache;
Expand All @@ -81,8 +77,8 @@ public final class ActiveWorkState {
private GetWorkBudget activeGetWorkBudget;

private ActiveWorkState(
Map<ShardedKey, Deque<ExecutableWork>> activeWork,
WindmillStateCache.ForComputation computationStateCache) {
Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
ForComputation computationStateCache) {
this.activeWork = activeWork;
this.computationStateCache = computationStateCache;
this.activeGetWorkBudget = GetWorkBudget.noBudget();
Expand All @@ -94,7 +90,7 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState

@VisibleForTesting
static ActiveWorkState forTesting(
Map<ShardedKey, Deque<ExecutableWork>> activeWork,
Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
WindmillStateCache.ForComputation computationStateCache) {
return new ActiveWorkState(activeWork, computationStateCache);
}
Expand Down Expand Up @@ -124,28 +120,30 @@ private static String elapsedString(Instant start, Instant end) {
*/
synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork) {
ShardedKey shardedKey = executableWork.work().getShardedKey();
Deque<ExecutableWork> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
long shardingKey = shardedKey.shardingKey();
LinkedHashMap<WorkId, ExecutableWork> workQueue =
activeWork.computeIfAbsent(shardingKey, (unused) -> new LinkedHashMap<>());
// This key does not have any work queued up on it. Create one, insert Work, and mark the work
// to be executed.
if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
workQueue.addLast(executableWork);
activeWork.put(shardedKey, workQueue);
if (workQueue.isEmpty()) {
workQueue.put(executableWork.id(), executableWork);
incrementActiveWorkBudget(executableWork.work());
return ActivateWorkResult.EXECUTE;
}

// Check to see if we have this work token queued.
Iterator<ExecutableWork> workIterator = workQueue.iterator();
Iterator<Entry<WorkId, ExecutableWork>> workIterator = workQueue.entrySet().iterator();
while (workIterator.hasNext()) {
ExecutableWork queuedWork = workIterator.next();
ExecutableWork queuedWork = workIterator.next().getValue();
if (queuedWork.id().equals(executableWork.id())) {
return ActivateWorkResult.DUPLICATE;
}
if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()) {
if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()
&& queuedWork.work().getShardedKey().equals(executableWork.work().getShardedKey())) {
if (executableWork.id().workToken() > queuedWork.id().workToken()) {
arunpandianp marked this conversation as resolved.
Show resolved Hide resolved
// Check to see if the queuedWork is active. We only want to remove it if it is NOT
// currently active.
if (!queuedWork.equals(workQueue.peek())) {
if (!queuedWork.equals(Preconditions.checkNotNull(firstEntry(workQueue)).getValue())) {
scwhittle marked this conversation as resolved.
Show resolved Hide resolved
workIterator.remove();
decrementActiveWorkBudget(queuedWork.work());
}
Expand All @@ -157,7 +155,7 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
}

// Queue the work for later processing.
workQueue.addLast(executableWork);
workQueue.put(executableWork.id(), executableWork);
incrementActiveWorkBudget(executableWork.work());
return ActivateWorkResult.QUEUED;
}
Expand All @@ -167,54 +165,28 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
*
* @param failedWork a map from sharding_key to tokens for the corresponding work.
*/
synchronized void failWorkForKey(Multimap<Long, WorkId> failedWork) {
// Note we can't construct a ShardedKey and look it up in activeWork directly since
// HeartbeatResponse doesn't include the user key.
for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
Collection<WorkId> failedWorkIds = failedWork.get(entry.getKey().shardingKey());
for (WorkId failedWorkId : failedWorkIds) {
for (ExecutableWork queuedWork : entry.getValue()) {
WorkItem workItem = queuedWork.work().getWorkItem();
if (workItem.getWorkToken() == failedWorkId.workToken()
&& workItem.getCacheToken() == failedWorkId.cacheToken()) {
LOG.debug(
"Failing work "
+ computationStateCache.getComputation()
+ " "
+ entry.getKey().shardingKey()
+ " "
+ failedWorkId.workToken()
+ " "
+ failedWorkId.cacheToken()
+ ". The work will be retried and is not lost.");
queuedWork.work().setFailed();
break;
}
}
synchronized void failWorkForKey(ImmutableList<WorkIdWithShardingKey> failedWork) {
for (WorkIdWithShardingKey failedId : failedWork) {
LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(failedId.shardingKey());
if (workQueue == null) {
// Work could complete/fail before heartbeat response arrives
continue;
}
ExecutableWork executableWork = workQueue.get(failedId.workId());
if (executableWork == null) {
continue;
}
executableWork.work().setFailed();
LOG.debug(
"Failing work {} {} The work will be retried and is not lost.",
arunpandianp marked this conversation as resolved.
Show resolved Hide resolved
computationStateCache.getComputation(),
failedId);
}
}

/**
* Returns a read only view of current active work.
*
* @implNote Do not return a reference to the underlying workQueue as iterations over it will
* cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data
* structure.
*/
synchronized ImmutableListMultimap<ShardedKey, RefreshableWork> getReadOnlyActiveWork() {
return activeWork.entrySet().stream()
.collect(
flatteningToImmutableListMultimap(
Entry::getKey,
e ->
e.getValue().stream()
.map(executableWork -> (RefreshableWork) executableWork.work())));
}

synchronized ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
return activeWork.values().stream()
.flatMap(Deque::stream)
.flatMap(workMap -> workMap.values().stream())
.map(ExecutableWork::work)
.filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline))
.collect(toImmutableList());
Expand All @@ -236,7 +208,8 @@ private synchronized void decrementActiveWorkBudget(Work work) {
*/
synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
ShardedKey shardedKey, WorkId workId) {
@Nullable Queue<ExecutableWork> workQueue = activeWork.get(shardedKey);
@Nullable
LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(shardedKey.shardingKey());
if (workQueue == null) {
// Work may have been completed due to clearing of stuck commits.
LOG.warn(
Expand All @@ -251,14 +224,15 @@ synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
}

private synchronized void removeCompletedWorkFromQueue(
Queue<ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
@Nullable ExecutableWork completedWork = workQueue.peek();
if (completedWork == null) {
LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
Iterator<Entry<WorkId, ExecutableWork>> completedWorkIterator = workQueue.entrySet().iterator();
if (!completedWorkIterator.hasNext()) {
// Work may have been completed due to clearing of stuck commits.
LOG.warn("Active key {} without work, expected token {}", shardedKey, workId);
return;
}

ExecutableWork completedWork = completedWorkIterator.next().getValue();
if (!completedWork.id().equals(workId)) {
// Work may have been completed due to clearing of stuck commits.
LOG.warn(
Expand All @@ -271,19 +245,19 @@ private synchronized void removeCompletedWorkFromQueue(
completedWork.id());
return;
}

// We consumed the matching work item.
workQueue.remove();
completedWorkIterator.remove();
decrementActiveWorkBudget(completedWork.work());
}

@SuppressWarnings("ReferenceEquality")
private synchronized Optional<ExecutableWork> getNextWork(
Queue<ExecutableWork> workQueue, ShardedKey shardedKey) {
Optional<ExecutableWork> nextWork = Optional.ofNullable(workQueue.peek());
LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey) {
Optional<ExecutableWork> nextWork =
Optional.ofNullable(firstEntry(workQueue)).map(Entry::getValue);
if (!nextWork.isPresent()) {
Preconditions.checkState(workQueue == activeWork.remove(shardedKey));
Preconditions.checkState(workQueue == activeWork.remove(shardedKey.shardingKey()));
}

return nextWork;
}

Expand All @@ -302,22 +276,27 @@ synchronized void invalidateStuckCommits(
}
}

private static @Nullable Entry<WorkId, ExecutableWork> firstEntry(
arunpandianp marked this conversation as resolved.
Show resolved Hide resolved
Map<WorkId, ExecutableWork> map) {
Iterator<Entry<WorkId, ExecutableWork>> iterator = map.entrySet().iterator();
return iterator.hasNext() ? iterator.next() : null;
}

private synchronized ImmutableMap<ShardedKey, WorkId> getStuckCommitsAt(
Instant stuckCommitDeadline) {
// Determine the stuck commit keys but complete them outside the loop iterating over
// activeWork as completeWork may delete the entry from activeWork.
ImmutableMap.Builder<ShardedKey, WorkId> stuckCommits = ImmutableMap.builder();
for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
ShardedKey shardedKey = entry.getKey();
@Nullable ExecutableWork executableWork = entry.getValue().peek();
for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
@Nullable Entry<WorkId, ExecutableWork> executableWork = firstEntry(entry.getValue());
if (executableWork != null) {
Work work = executableWork.work();
Work work = executableWork.getValue().work();
if (work.isStuckCommittingAt(stuckCommitDeadline)) {
LOG.error(
"Detected key {} stuck in COMMITTING state since {}, completing it with error.",
shardedKey,
work.getShardedKey(),
work.getStateStartTime());
stuckCommits.put(shardedKey, work.id());
stuckCommits.put(work.getShardedKey(), work.id());
}
}
}
Expand Down Expand Up @@ -353,9 +332,10 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
// Use StringBuilder because we are appending in loop.
StringBuilder activeWorkStatus = new StringBuilder();
int commitsPendingCount = 0;
for (Map.Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
Queue<ExecutableWork> workQueue = Preconditions.checkNotNull(entry.getValue());
Work activeWork = Preconditions.checkNotNull(workQueue.peek()).work();
for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
LinkedHashMap<WorkId, ExecutableWork> workQueue =
Preconditions.checkNotNull(entry.getValue());
Work activeWork = Preconditions.checkNotNull(firstEntry(workQueue)).getValue().work();
WorkItem workItem = activeWork.getWorkItem();
if (activeWork.isCommitPending()) {
if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
import org.joda.time.Instant;

/**
Expand Down Expand Up @@ -120,7 +118,7 @@ public boolean activateWork(ExecutableWork executableWork) {
}
}

public void failWork(Multimap<Long, WorkId> failedWork) {
public void failWork(ImmutableList<WorkIdWithShardingKey> failedWork) {
activeWorkState.failWorkForKey(failedWork);
}

Expand All @@ -146,10 +144,6 @@ private void forceExecute(ExecutableWork executableWork) {
executor.forceExecute(executableWork, executableWork.work().getSerializedWorkItemSize());
}

public ImmutableListMultimap<ShardedKey, RefreshableWork> currentActiveWorkReadOnly() {
return activeWorkState.getReadOnlyActiveWork();
}

public ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
return activeWorkState.getRefreshableWork(refreshDeadline);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatResponse;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;

/**
* Processes {@link ComputationHeartbeatResponse}(s). Marks {@link Work} that is invalid from
Expand All @@ -34,6 +33,7 @@
@Internal
public final class WorkHeartbeatResponseProcessor
implements Consumer<List<ComputationHeartbeatResponse>> {

/** Fetches a {@link ComputationState} for a computationId. */
private final Function<String, Optional<ComputationState>> computationStateFetcher;

Expand All @@ -46,23 +46,23 @@ public WorkHeartbeatResponseProcessor(
@Override
public void accept(List<ComputationHeartbeatResponse> responses) {
for (ComputationHeartbeatResponse computationHeartbeatResponse : responses) {
// Maps sharding key to (work token, cache token) for work that should be marked failed.
Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
ImmutableList.Builder<WorkIdWithShardingKey> failedWorkBuilder = ImmutableList.builder();
for (HeartbeatResponse heartbeatResponse :
computationHeartbeatResponse.getHeartbeatResponsesList()) {
if (heartbeatResponse.getFailed()) {
failedWork.put(
heartbeatResponse.getShardingKey(),
WorkId workId =
WorkId.builder()
.setWorkToken(heartbeatResponse.getWorkToken())
.setCacheToken(heartbeatResponse.getCacheToken())
.build());
.build();
failedWorkBuilder.add(
WorkIdWithShardingKey.create(heartbeatResponse.getShardingKey(), workId));
}
}

computationStateFetcher
.apply(computationHeartbeatResponse.getComputationId())
.ifPresent(state -> state.failWork(failedWork));
.ifPresent(state -> state.failWork(failedWorkBuilder.build()));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.streaming;

import com.google.auto.value.AutoValue;

@AutoValue
abstract class WorkIdWithShardingKey {

public static WorkIdWithShardingKey create(long shardingKey, WorkId workId) {
return new AutoValue_WorkIdWithShardingKey(shardingKey, workId);
}

public abstract long shardingKey();

public abstract WorkId workId();
}
Loading
Loading