Skip to content

Commit

Permalink
[grid] Purge Nodes if health check fails consistently
Browse files Browse the repository at this point in the history
  • Loading branch information
pujagani authored Aug 5, 2021
1 parent c498dad commit 1fad80a
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 2 deletions.
36 changes: 35 additions & 1 deletion java/src/org/openqa/selenium/grid/distributor/GridModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ public class GridModel {
private static final Logger LOG = Logger.getLogger(GridModel.class.getName());
// How many times a node's heartbeat duration needs to be exceeded before the node is considered purgeable.
private static final int PURGE_TIMEOUT_MULTIPLIER = 4;
private static final int UNHEALTHY_THRESHOLD = 4;
private final ReadWriteLock lock = new ReentrantReadWriteLock(/* fair */ true);
private final Set<NodeStatus> nodes = Collections.newSetFromMap(new ConcurrentHashMap<>());
private final Map<NodeId, Instant> nodePurgeTimes = new ConcurrentHashMap<>();
private final Map<NodeId, Integer> nodeHealthCount = new ConcurrentHashMap<>();
private final EventBus events;

public GridModel(EventBus events) {
Expand Down Expand Up @@ -96,6 +98,7 @@ public void add(NodeStatus node) {
NodeStatus refreshed = rewrite(node, next.getAvailability());
nodes.add(refreshed);
nodePurgeTimes.put(refreshed.getNodeId(), Instant.now());
updateHealthCheckCount(refreshed.getNodeId(), refreshed.getAvailability());

return;
}
Expand All @@ -117,6 +120,7 @@ public void add(NodeStatus node) {
NodeStatus refreshed = rewrite(node, DOWN);
nodes.add(refreshed);
nodePurgeTimes.put(refreshed.getNodeId(), Instant.now());
updateHealthCheckCount(refreshed.getNodeId(), refreshed.getAvailability());
} finally {
writeLock.unlock();
}
Expand Down Expand Up @@ -177,6 +181,7 @@ public void remove(NodeId id) {
try {
nodes.removeIf(n -> n.getNodeId().equals(id));
nodePurgeTimes.remove(id);
nodeHealthCount.remove(id);
} finally {
writeLock.unlock();
}
Expand All @@ -190,8 +195,13 @@ public void purgeDeadNodes() {
Set<NodeStatus> toRemove = new HashSet<>();

for (NodeStatus node : nodes) {
NodeId id = node.getNodeId();
if (nodeHealthCount.getOrDefault(id, 0) > UNHEALTHY_THRESHOLD) {
toRemove.add(node);
}

Instant now = Instant.now();
Instant lastTouched = nodePurgeTimes.getOrDefault(node.getNodeId(), Instant.now());
Instant lastTouched = nodePurgeTimes.getOrDefault(id, Instant.now());
Instant lostTime = lastTouched.plus(node.getHeartbeatPeriod().multipliedBy(PURGE_TIMEOUT_MULTIPLIER / 2));
Instant deadTime = lastTouched.plus(node.getHeartbeatPeriod().multipliedBy(PURGE_TIMEOUT_MULTIPLIER));

Expand All @@ -212,6 +222,7 @@ public void purgeDeadNodes() {
toRemove.forEach(node -> {
nodes.remove(node);
nodePurgeTimes.remove(node.getNodeId());
nodeHealthCount.remove(node.getNodeId());
});
} finally {
writeLock.unlock();
Expand Down Expand Up @@ -423,6 +434,29 @@ public void setSession(SlotId slotId, Session session) {
}
}

/**
 * Records the outcome of a health check for the given Node in {@code nodeHealthCount}.
 *
 * <p>A {@code DOWN} result increases the Node's stored failure count by one. An {@code UP}
 * result sets the count back to zero, but only while the stored count has not yet gone past
 * {@code UNHEALTHY_THRESHOLD}. Any other availability leaves the count as it is.
 *
 * <p>The update is performed while holding the model's write lock.
 *
 * @param id the Node whose failure count is updated; must not be null
 * @param availability the availability reported for the Node; must not be null
 */
public void updateHealthCheckCount(NodeId id, Availability availability) {
  Require.nonNull("Node ID", id);
  Require.nonNull("Availability", availability);

  Lock guard = lock.writeLock();
  guard.lock();
  try {
    // A Node with no recorded entry is treated as having zero failures so far.
    int failuresSoFar = nodeHealthCount.getOrDefault(id, 0);

    if (DOWN.equals(availability)) {
      // One more failed health check in the current run of failures.
      nodeHealthCount.put(id, failuresSoFar + 1);
    } else if (UP.equals(availability) && failuresSoFar <= UNHEALTHY_THRESHOLD) {
      // Healthy again before going past the threshold: start counting from scratch.
      nodeHealthCount.put(id, 0);
    }
  } finally {
    guard.unlock();
  }
}

private void amend(Availability availability, NodeStatus status, Slot slot) {
Set<Slot> newSlots = new HashSet<>(status.getSlots());
newSlots.removeIf(s -> s.getId().equals(slot.getId()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ private Runnable asRunnableHealthCheck(Node node) {
getDebugLogLevel(),
String.format("Health check result for %s was %s", node.getId(), result.getAvailability()));
model.setAvailability(id, result.getAvailability());
model.updateHealthCheckCount(id, result.getAvailability());
} finally {
writeLock.unlock();
}
Expand Down
114 changes: 113 additions & 1 deletion java/test/org/openqa/selenium/grid/distributor/DistributorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Logger;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -941,6 +942,117 @@ public void shouldReturnNodesThatWereDownToPoolOfNodesOnceTheyMarkTheirHealthChe
assertThatEither(result).isRight();
}

/**
 * A Node whose health check keeps reporting DOWN should end up removed from the Distributor,
 * after which a new session request is rejected.
 */
@Test
public void shouldRemoveNodeWhoseHealthCheckFailsConsistently() {
  // Read by the health check lambda below, so the test can flip the reported state later on.
  AtomicReference<Availability> reportedState = new AtomicReference<>(UP);

  CombinedHandler handler = new CombinedHandler();

  SessionMap sessions = new LocalSessionMap(tracer, bus);
  handler.addHandler(sessions);

  NewSessionQueue queue = new LocalNewSessionQueue(
    tracer,
    bus,
    new DefaultSlotMatcher(),
    Duration.ofSeconds(2),
    Duration.ofSeconds(2),
    registrationSecret);

  URI nodeUri = createUri();
  Node node = LocalNode.builder(tracer, bus, nodeUri, nodeUri, registrationSecret)
    .add(
      caps,
      new TestSessionFactory(
        (id, caps) -> new Session(id, nodeUri, stereotype, caps, Instant.now())))
    .advanced()
    .healthCheck(() -> new HealthCheck.Result(reportedState.get(), "TL;DR"))
    .build();
  handler.addHandler(node);

  LocalDistributor distributor = new LocalDistributor(
    tracer,
    bus,
    new PassthroughHttpClient.Factory(handler),
    sessions,
    queue,
    new DefaultSlotSelector(),
    registrationSecret,
    Duration.ofSeconds(1),
    false);
  handler.addHandler(distributor);
  distributor.add(node);

  waitToHaveCapacity(distributor);

  // While the health check reports UP, a session can be created.
  Either<SessionNotCreatedException, CreateSessionResponse> whileHealthy =
    distributor.newSession(createRequest(caps));
  assertThatEither(whileHealthy).isRight();

  // From here on every health check reports DOWN.
  reportedState.set(DOWN);
  waitTillNodesAreRemoved(distributor);

  // With the Node gone, the same request must fail.
  Either<SessionNotCreatedException, CreateSessionResponse> afterRemoval =
    distributor.newSession(createRequest(caps));
  assertThatEither(afterRemoval).isLeft();
}

/**
 * A Node whose health check reports DOWN a limited number of times and then recovers should
 * still be usable: once it reports UP, the Distributor has capacity and creates a session.
 *
 * @throws InterruptedException if the test thread is interrupted while waiting for the first
 *     UP health check result
 */
@Test
public void shouldNotRemoveNodeWhoseHealthCheckPassesBeforeThreshold()
  throws InterruptedException {
  CombinedHandler handler = new CombinedHandler();

  // Number of DOWN results handed out so far by the health check below.
  AtomicInteger count = new AtomicInteger(0);
  // Released the first time the health check reports UP.
  CountDownLatch latch = new CountDownLatch(1);

  SessionMap sessions = new LocalSessionMap(tracer, bus);
  handler.addHandler(sessions);
  NewSessionQueue queue = new LocalNewSessionQueue(
    tracer,
    bus,
    new DefaultSlotMatcher(),
    Duration.ofSeconds(2),
    Duration.ofSeconds(2),
    registrationSecret);

  URI uri = createUri();
  Node node = LocalNode.builder(tracer, bus, uri, uri, registrationSecret)
    .add(
      caps,
      new TestSessionFactory((id, caps) -> new Session(id, uri, stereotype, caps, Instant.now())))
    .advanced()
    .healthCheck(() -> {
      // Reports DOWN while the counter is 0..4, then UP on every later call.
      // NOTE(review): the literal 4 presumably mirrors GridModel's unhealthy threshold - confirm.
      if (count.get() <= 4) {
        count.incrementAndGet();
        return new HealthCheck.Result(DOWN, "Down");
      }
      latch.countDown();
      return new HealthCheck.Result(UP, "Up");
    })
    .build();
  handler.addHandler(node);

  LocalDistributor distributor = new LocalDistributor(
    tracer,
    bus,
    new PassthroughHttpClient.Factory(handler),
    sessions,
    queue,
    new DefaultSlotSelector(),
    registrationSecret,
    Duration.ofSeconds(1),
    false);
  handler.addHandler(distributor);
  distributor.add(node);

  // await() returns false on timeout; fail right here instead of silently carrying on and
  // failing later with a less direct message.
  if (!latch.await(60, TimeUnit.SECONDS)) {
    throw new AssertionError("Health check did not report UP within 60 seconds");
  }

  waitToHaveCapacity(distributor);

  Either<SessionNotCreatedException, CreateSessionResponse> result =
    distributor.newSession(createRequest(caps));
  assertThatEither(result).isRight();
}

private Set<Node> createNodeSet(CombinedHandler handler, Distributor distributor, int count, Capabilities...capabilities) {
Set<Node> nodeSet = new HashSet<>();
for (int i=0; i<count; i++) {
Expand Down Expand Up @@ -1253,7 +1365,7 @@ private void waitToHaveCapacity(Distributor distributor) {

private void waitTillNodesAreRemoved(Distributor distributor) {
new FluentWait<>(distributor)
.withTimeout(Duration.ofSeconds(5))
.withTimeout(Duration.ofSeconds(60))
.pollingEvery(Duration.ofMillis(100))
.until(d -> {
Set<NodeStatus> nodes = d.getStatus().getNodes();
Expand Down

0 comments on commit 1fad80a

Please sign in to comment.