Skip to content

Commit

Permalink
assign splits properly to source readers
Browse files Browse the repository at this point in the history
Signed-off-by: Tilak Raj <[email protected]>
  • Loading branch information
tilakraj94 committed Jan 24, 2025
1 parent 3ea2253 commit 4593bd9
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ public class NatsSourceEnumerator implements SplitEnumerator<NatsSubjectSplit, C
private final String id;
private final SplitEnumeratorContext<NatsSubjectSplit> context;
private final Queue<NatsSubjectSplit> remainingSplits;
private List<List<NatsSubjectSplit>> precomputedSplitAssignments;

// assumes splits to be less than or equal to parallelism
private int minimumSplitsToAssign = -1;


public NatsSourceEnumerator(String sourceId,
Expand All @@ -34,60 +33,80 @@ public NatsSourceEnumerator(String sourceId,
id = generatePrefixedId(sourceId);
this.context = checkNotNull(context);
this.remainingSplits = splits == null ? new ArrayDeque<>() : new ArrayDeque<>(splits);
this.precomputedSplitAssignments = Collections.synchronizedList(new LinkedList<>());
}

@Override
public void start() {
int noOfSplits = remainingSplits.size();
int totalSplits = remainingSplits.size();
int parallelism = context.currentParallelism();

// let the splits be evenly distributed
if (noOfSplits <= parallelism) {
this.minimumSplitsToAssign = -1;
return;
}
// Calculate the minimum splits per reader and leftover splits
int minimumSplitsPerReader = totalSplits / parallelism;
int leftoverSplits = totalSplits % parallelism;

// minimum splits that needs to be assigned to reader
this.minimumSplitsToAssign = noOfSplits / parallelism;
}
// Precompute split assignments
List<List<NatsSubjectSplit>>splitAssignments = preComputeSplitsAssignments(parallelism, minimumSplitsPerReader, leftoverSplits);

@Override
public void close() {
// Store precomputed split assignments
this.precomputedSplitAssignments = splitAssignments;
LOG.debug("{} | Precomputed split assignments: {}", id, splitAssignments);
}

@Override
public void handleSplitRequest(int subtaskId, @Nullable String requesterHostname) {
if (remainingSplits.isEmpty()) {
context.signalNoMoreSplits(subtaskId);
return;
private List<List<NatsSubjectSplit>> preComputeSplitsAssignments (int parallelism, int minimumSplitsPerReader, int leftoverSplits) {
List<List<NatsSubjectSplit>> splitAssignments = new ArrayList<>(parallelism);

// Initialize lists
for (int i = 0; i < parallelism; i++) {
splitAssignments.add(new ArrayList<>());
}

List<NatsSubjectSplit> nextSplits = new ArrayList<>();
for (int i = 0; i < this.minimumSplitsToAssign; i++) {
NatsSubjectSplit nextSplit = remainingSplits.poll();
if (nextSplit == null) {
break;
// Distribute splits evenly among subtasks
for (int j = 0; j < parallelism; j++) {
List<NatsSubjectSplit> readerSplits = splitAssignments.get(j);

// Assign minimum splits to each reader
for (int i = 0; i < minimumSplitsPerReader && !remainingSplits.isEmpty(); i++) {
readerSplits.add(remainingSplits.poll());
}

nextSplits.add(nextSplit);
// Assign one leftover split if available
if (leftoverSplits > 0 && !remainingSplits.isEmpty()) {
readerSplits.add(remainingSplits.poll());
leftoverSplits--;
}
}

if (!nextSplits.isEmpty()) {
Map<Integer, List<NatsSubjectSplit>> assignedSplits = new HashMap<>();
assignedSplits.put(subtaskId, nextSplits);
return splitAssignments;
}

// assign the splits back to the source reader
context.assignSplits(new SplitsAssignment<>(assignedSplits));
LOG.debug("{} | Assigned splits to subtask: {}", id, subtaskId);
}
@Override
public void close() {
// remove precomputed split assignments if any
precomputedSplitAssignments.clear();
}

// Perform round-robin assignment for leftover splits
// Assign only one split at a time since the number of leftover splits will always be less than the parallelism.
// Each leftover split can be assigned to any reader, and the list will be exhausted quickly.
NatsSubjectSplit nextSplit = remainingSplits.poll();
if (nextSplit != null) {
context.assignSplit(nextSplit, subtaskId);
LOG.debug("{} | Assigned split in round-robin to subtask: {}", id, subtaskId);
@Override
public void handleSplitRequest(int subtaskId, @Nullable String requesterHostname) {
int size = precomputedSplitAssignments.size();

if (size == 0) {
LOG.debug("{} | No more splits available for subtask {}", id, subtaskId);
context.signalNoMoreSplits(subtaskId);
} else {
// O(1) operation with LinkedList
// Remove the first element from the list
// and assign splits to subtask
List<NatsSubjectSplit> splits = precomputedSplitAssignments.remove(0);
if (splits.isEmpty()) {
LOG.debug("{} | Empty split assignment for subtask {}", id, subtaskId);
context.signalNoMoreSplits(subtaskId);
} else {

// Assign splits to subtask
LOG.debug("{} | Assigning splits {} to subtask {}", id, splits, subtaskId);
context.assignSplits(new SplitsAssignment<>(Map.of(subtaskId, splits)));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package io.synadia.io.synadia.flink.v0;

import io.synadia.flink.v0.enumerator.NatsSourceEnumerator;
import io.synadia.flink.v0.source.split.NatsSubjectSplit;
import io.synadia.io.synadia.flink.TestBase;
import org.apache.flink.api.connector.source.SplitEnumeratorContext;
import org.apache.flink.api.connector.source.SplitsAssignment;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class NatsSourceEnumeratorTests extends TestBase {

private SplitEnumeratorContext<NatsSubjectSplit> context;
private Queue<NatsSubjectSplit> splitsQueue;
private NatsSourceEnumerator enumerator;

@BeforeEach
@SuppressWarnings("unchecked")
void setup() {
context = mock(SplitEnumeratorContext.class);
splitsQueue = new ArrayDeque<>();
}

static Stream<TestParameters> provideTestParameters() {
return Stream.of(
new TestParameters(5, 3, generateSplits(3), "Splits < Parallelism"),
new TestParameters(3, 5, generateSplits(5), "Splits > Parallelism"),
new TestParameters(3, 3, generateSplits(3), "Splits == Parallelism"),
new TestParameters(89, 9, generateSplits(9), "High Parallelism with few splits, Splits <<<< Parallelism"),
new TestParameters(5, 100, generateSplits(100), "More Splits with less parallelism, Splits >>>>>> Parallelism")
);
}

@DisplayName("NatsSourceEnumerator Parameterized Test")
@ParameterizedTest(name = "{index} - {0}")
@MethodSource("provideTestParameters")
void testHandleSplitRequest(TestParameters params) {
// Setup splits queue
splitsQueue.addAll(params.splits.stream().map(NatsSubjectSplit::new).collect(Collectors.toList()));

when(context.currentParallelism()).thenReturn(params.parallelism);
enumerator = new NatsSourceEnumerator("test", context, splitsQueue);

// precompute split assignments
enumerator.start();

for (int subtaskId = 0; subtaskId < params.parallelism; subtaskId++) {
enumerator.handleSplitRequest(subtaskId, null);
}

// Capture arguments passed to assignSplits
ArgumentCaptor<SplitsAssignment<NatsSubjectSplit>> captor = ArgumentCaptor.forClass(SplitsAssignment.class);
verify(context, times(params.expectedInvocations)).assignSplits(captor.capture());

List<SplitsAssignment<NatsSubjectSplit>> assignments = captor.getAllValues();
assertEquals(params.expectedInvocations, assignments.size());

// Verify splits assigned
List<String> assignedSplits = new ArrayList<>();
for (SplitsAssignment<NatsSubjectSplit> assignment : assignments) {
assignment.assignment().values().forEach(splits -> {
assertTrue(splits.size() <= params.splits.size() / params.parallelism + 1, "Each subtask should get the correct number of splits");
splits.forEach(split -> assignedSplits.add(split.getSubject()));
});
}

assertEquals(params.splits, assignedSplits);
}

static List<String> generateSplits(int count) {
List<String> splits = new ArrayList<>();
for (int i = 0; i < count; i++) {
splits.add("split" + i);
}
return splits;
}

public static class TestParameters {
final int parallelism;
final int splitsCount;
final List<String> splits;
final int expectedInvocations;
final String description;

TestParameters(int parallelism, int splitsCount, List<String> splits, String description) {
this.parallelism = parallelism;
this.splitsCount = splitsCount;
this.splits = splits;
this.expectedInvocations = Math.min(parallelism, splitsCount);
this.description = description;
}

@Override
public String toString() {
return description;
}
}
}

0 comments on commit 4593bd9

Please sign in to comment.