[FLINK-31963][state] Fix rescaling bug in recovery from unaligned checkpoints. (apache#22584)

This commit fixes problems in StateAssignmentOperation when recovering from unaligned checkpoints: stateless operators whose upstream operators hold output partition state, or whose downstream operators hold input channel state, were previously excluded from state repartitioning.
StefanRRichter authored May 16, 2023
1 parent 5ba3f2b commit 354c0f4
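
For context, a minimal sketch of the kind of job that can hit this bug. This is illustrative only; the class name and topology below are assumptions, not taken from the commit. A stateless operator sits between tasks whose in-flight channel state from an unaligned checkpoint must be repartitioned when the job is recovered at a different parallelism.

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class Flink31963TopologySketch {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.enableCheckpointing(1_000); // checkpoint every second
        // Unaligned checkpoints snapshot in-flight channel data, which is what
        // StateAssignmentOperation must repartition on rescaled recovery.
        env.getCheckpointConfig().enableUnalignedCheckpoints();

        env.fromSequence(0, Long.MAX_VALUE)
                .rebalance()            // upstream task: may snapshot output partition state
                .filter(x -> true)      // stateless operator: was skipped during repartitioning
                .keyBy(x -> x % 10)
                .reduce(Long::sum)      // downstream task: may snapshot input channel state
                .print();

        env.execute("FLINK-31963 sketch");
    }
}

Restoring such a job from an unaligned checkpoint with a changed parallelism exercises exactly the code paths patched below.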
Showing 6 changed files with 335 additions and 77 deletions.
StateAssignmentOperation.java
@@ -136,19 +136,24 @@ public void assignStates() {
 
         // repartition state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            if (stateAssignment.hasNonFinishedState) {
+            if (stateAssignment.hasNonFinishedState
+                    // FLINK-31963: We need to run repartitioning for stateless operators that have
+                    // upstream output or downstream input states.
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignAttemptState(stateAssignment);
             }
         }
 
         // actually assign the state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            // If upstream has output states, even the empty task state should be assigned for the
-            // current task in order to notify this task that the old states will send to it which
-            // likely should be filtered.
+            // If upstream has output states or downstream has input states, even the empty task
+            // state should be assigned for the current task in order to notify this task that the
+            // old states will send to it which likely should be filtered.
             if (stateAssignment.hasNonFinishedState
                     || stateAssignment.isFullyFinished
-                    || stateAssignment.hasUpstreamOutputStates()) {
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignTaskStateToExecutionJobVertices(stateAssignment);
             }
         }
@@ -345,9 +350,10 @@ public static <T extends StateObject> void reDistributePartitionableStates(
                                                 newParallelism)));
     }
 
-    public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpartitionStates(
-            TaskStateAssignment assignment) {
-        if (!assignment.hasOutputState) {
+    public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
+        // FLINK-31963: We can skip this phase if there is no output state AND downstream has no
+        // input states
+        if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
             return;
         }
 
@@ -394,7 +400,9 @@ public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpartitionStates(
     }
 
     public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
-        if (!stateAssignment.hasInputState) {
+        // FLINK-31963: We can skip this phase only if there is no input state AND upstream has no
+        // output states
+        if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
             return;
         }
 
@@ -435,7 +443,7 @@ public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment)
                             : getPartitionState(
                                     inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
             final MappingBasedRepartitioner<InputChannelStateHandle> repartitioner =
-                    new MappingBasedRepartitioner(mapping);
+                    new MappingBasedRepartitioner<>(mapping);
             final Map<OperatorInstanceID, List<InputChannelStateHandle>> repartitioned =
                     applyRepartitioner(
                             stateAssignment.inputOperatorID,
TaskStateAssignment.java
@@ -84,6 +84,8 @@ class TaskStateAssignment {
 
     @Nullable private TaskStateAssignment[] downstreamAssignments;
     @Nullable private TaskStateAssignment[] upstreamAssignments;
+    @Nullable private Boolean hasUpstreamOutputStates;
+    @Nullable private Boolean hasDownstreamInputStates;
 
     private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
     private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;
@@ -202,8 +204,21 @@ public OperatorSubtaskState getSubtaskState(OperatorInstanceID instanceID) {
     }
 
     public boolean hasUpstreamOutputStates() {
-        return Arrays.stream(getUpstreamAssignments())
-                .anyMatch(assignment -> assignment.hasOutputState);
+        if (hasUpstreamOutputStates == null) {
+            hasUpstreamOutputStates =
+                    Arrays.stream(getUpstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasOutputState);
+        }
+        return hasUpstreamOutputStates;
+    }
+
+    public boolean hasDownstreamInputStates() {
+        if (hasDownstreamInputStates == null) {
+            hasDownstreamInputStates =
+                    Arrays.stream(getDownstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasInputState);
+        }
+        return hasDownstreamInputStates;
     }
 
     private InflightDataGateOrPartitionRescalingDescriptor log(
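The two accessors above memoize their result in a @Nullable Boolean, so the scan over upstream or downstream assignments runs at most once per TaskStateAssignment even though the checks are now consulted from several places. A standalone sketch of the idiom, with hypothetical names not taken from the commit (javax.annotation.Nullable is the JSR-305 annotation Flink already uses):

import javax.annotation.Nullable;

class CachedPredicateSketch {
    @Nullable private Boolean cached; // null means "not computed yet"

    boolean evaluate() {
        if (cached == null) {
            cached = expensiveScan(); // the scan runs at most once
        }
        return cached; // auto-unboxed; never null past this point
    }

    private boolean expensiveScan() {
        // Stand-in for scanning the upstream/downstream assignment arrays.
        return System.nanoTime() % 2 == 0;
    }
}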
StateAssignmentOperationTest.java
@@ -22,6 +22,7 @@
 import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
 import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
@@ -51,6 +52,9 @@
 import org.junit.ClassRule;
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.EnumMap;
@@ -82,6 +86,7 @@
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -785,6 +790,129 @@ public void testOnlyUpstreamChannelStateAssignment()
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewResultSubpartitionStateHandle(10, random),
+                                                createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewInputChannelStateHandle(10, random),
+                                                createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        checkArgument(
+                upstreamOpState != downstreamOpState
+                        && (upstreamOpState == null || downstreamOpState == null),
+                "Either upstream or downstream state must exist, but not both");
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(downstreamExecutionJobVertex);
+
+        checkMappings(
+                upstreamTaskStateSnapshots,
+                TaskStateSnapshot::getOutputRescalingDescriptor,
+                expectedUpstreamCount);
+
+        checkMappings(
+                downstreamTaskStateSnapshots,
+                TaskStateSnapshot::getInputRescalingDescriptor,
+                expectedDownstreamCount);
+    }
+
+    private void checkMappings(
+            List<TaskStateSnapshot> taskStateSnapshots,
+            Function<TaskStateSnapshot, InflightDataRescalingDescriptor> extractFun,
+            int expectedCount) {
+        Assert.assertEquals(
+                expectedCount,
+                taskStateSnapshots.stream()
+                        .map(extractFun)
+                        .mapToInt(
+                                x -> {
+                                    int len = x.getOldSubtaskIndexes(0).length;
+                                    // Assert that there is a mapping.
+                                    Assert.assertTrue(len > 0);
+                                    return len;
+                                })
+                        .sum());
+    }
+
     @Test
     public void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
@@ -949,15 +1077,50 @@ private Map<OperatorID, OperatorState> buildOperatorStates(
                                 }));
     }
 
+    private static class OperatorIdWithParallelism {
+        private final OperatorID operatorID;
+        private final int parallelism;
+
+        public OperatorID getOperatorID() {
+            return operatorID;
+        }
+
+        public int getParallelism() {
+            return parallelism;
+        }
+
+        public OperatorIdWithParallelism(OperatorID operatorID, int parallelism) {
+            this.operatorID = operatorID;
+            this.parallelism = parallelism;
+        }
+    }
+
     private Map<OperatorID, ExecutionJobVertex> buildVertices(
             List<OperatorID> operatorIds,
-            int parallelism,
+            int parallelisms,
             SubtaskStateMapper downstreamRescaler,
             SubtaskStateMapper upstreamRescaler)
             throws JobException, JobExecutionException {
-        final JobVertex[] jobVertices =
+        List<OperatorIdWithParallelism> opIdsWithParallelism =
                 operatorIds.stream()
-                        .map(id -> createJobVertex(id, id, parallelism))
+                        .map(operatorID -> new OperatorIdWithParallelism(operatorID, parallelisms))
+                        .collect(Collectors.toList());
+        return buildVertices(opIdsWithParallelism, downstreamRescaler, upstreamRescaler);
+    }
+
+    private Map<OperatorID, ExecutionJobVertex> buildVertices(
+            List<OperatorIdWithParallelism> operatorIdsAndParallelism,
+            SubtaskStateMapper downstreamRescaler,
+            SubtaskStateMapper upstreamRescaler)
+            throws JobException, JobExecutionException {
+        final JobVertex[] jobVertices =
+                operatorIdsAndParallelism.stream()
+                        .map(
+                                idWithParallelism ->
+                                        createJobVertex(
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getParallelism()))
                         .toArray(JobVertex[]::new);
         for (int index = 1; index < jobVertices.length; index++) {
             connectVertices(
@@ -1029,6 +1192,15 @@ private JobVertex createJobVertex(
         return jobVertex;
     }
 
+    private List<TaskStateSnapshot> getTaskStateSnapshotFromVertex(
+            ExecutionJobVertex executionJobVertex) {
+        return Arrays.stream(executionJobVertex.getTaskVertices())
+                .map(ExecutionVertex::getCurrentExecutionAttempt)
+                .map(Execution::getTaskRestore)
+                .map(JobManagerTaskRestore::getTaskStateSnapshot)
+                .collect(Collectors.toList());
+    }
+
     private OperatorSubtaskState getAssignedState(
             ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
         return executionJobVertex
UnalignedCheckpointTestBase.java
@@ -113,15 +113,17 @@ public void create(
             StreamExecutionEnvironment env,
             int minCheckpoints,
             boolean slotSharing,
-            int expectedRestarts) {
+            int expectedRestarts,
+            long sourceSleepMs) {
         final int parallelism = env.getParallelism();
         final SingleOutputStreamOperator<Long> stream =
                 env.fromSource(
                                 new LongSource(
                                         minCheckpoints,
                                         parallelism,
                                         expectedRestarts,
-                                        env.getCheckpointInterval()),
+                                        env.getCheckpointInterval(),
+                                        sourceSleepMs),
                                 noWatermarks(),
                                 "source")
                         .slotSharingGroup(slotSharing ? "default" : "source")
@@ -144,7 +146,8 @@ public void create(
             StreamExecutionEnvironment env,
             int minCheckpoints,
             boolean slotSharing,
-            int expectedRestarts) {
+            int expectedRestarts,
+            long sourceSleepMs) {
         final int parallelism = env.getParallelism();
         DataStream<Long> combinedSource = null;
         for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -154,7 +157,8 @@
                                         minCheckpoints,
                                         parallelism,
                                         expectedRestarts,
-                                        env.getCheckpointInterval()),
+                                        env.getCheckpointInterval(),
+                                        sourceSleepMs),
                                 noWatermarks(),
                                 "source" + inputIndex)
                         .slotSharingGroup(
@@ -182,7 +186,8 @@ public void create(
             StreamExecutionEnvironment env,
             int minCheckpoints,
             boolean slotSharing,
-            int expectedRestarts) {
+            int expectedRestarts,
+            long sourceSleepMs) {
         final int parallelism = env.getParallelism();
         DataStream<Tuple2<Integer, Long>> combinedSource = null;
         for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -193,7 +198,8 @@
                                         minCheckpoints,
                                         parallelism,
                                         expectedRestarts,
-                                        env.getCheckpointInterval()),
+                                        env.getCheckpointInterval(),
+                                        sourceSleepMs),
                                 noWatermarks(),
                                 "source" + inputIndex)
                         .slotSharingGroup(
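The new sourceSleepMs parameter threads a throttle down into LongSource. The diff does not show LongSource's internals, so the following is a hedged sketch of where such a knob typically sits in a source's emission loop; every name here is hypothetical:

import java.util.function.LongConsumer;

final class ThrottledEmitSketch {
    // Hypothetical reader loop; none of these names are taken from LongSource.
    static void emit(LongConsumer output, long count, long sourceSleepMs)
            throws InterruptedException {
        for (long value = 0; value < count; value++) {
            output.accept(value);
            if (sourceSleepMs > 0) {
                // Slower emission keeps network buffers occupied, so rescaling tests
                // reliably capture in-flight data in the unaligned checkpoint.
                Thread.sleep(sourceSleepMs);
            }
        }
    }
}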