Skip to content

Commit

Permalink
chore: simplify ChennelMonitor methods
Browse files Browse the repository at this point in the history
Signed-off-by: Bernd Warmuth <[email protected]>
  • Loading branch information
Bernd Warmuth authored and warber committed Jan 13, 2025
1 parent d2efff5 commit 6b9b10d
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import lombok.extern.slf4j.Slf4j;

Expand Down Expand Up @@ -32,65 +34,58 @@ public static void monitorChannelState(
ConnectivityState currentState = channel.getState(true);
log.info("Channel state changed to: {}", currentState);
if (currentState == ConnectivityState.READY) {
onConnectionReady.run();
if (onConnectionReady != null) {
onConnectionReady.run();
} else {
log.debug("onConnectionReady is null");
}
} else if (currentState == ConnectivityState.TRANSIENT_FAILURE
|| currentState == ConnectivityState.SHUTDOWN) {
onConnectionLost.run();
if (onConnectionLost != null) {
onConnectionLost.run();
} else {
log.debug("onConnectionLost is null");
}
}
// Re-register the state monitor to watch for the next state transition.
monitorChannelState(currentState, channel, onConnectionReady, onConnectionLost);
});
}

/**
* Waits for the channel to reach a desired state within a specified timeout period.
* Waits for the channel to reach the desired connectivity state within the specified timeout.
*
* @param channel the ManagedChannel to monitor.
* @param desiredState the ConnectivityState to wait for.
* @param connectCallback callback invoked when the desired state is reached.
* @param timeout the maximum amount of time to wait.
* @param unit the time unit of the timeout.
* @throws InterruptedException if the current thread is interrupted while waiting.
* @param desiredState the desired {@link ConnectivityState} to wait for
* @param channel the {@link ManagedChannel} to monitor
* @param connectCallback the {@link Runnable} to execute when the desired state is reached
* @param timeout the maximum time to wait
* @param unit the time unit of the timeout argument
* @throws InterruptedException if the current thread is interrupted while waiting
* @throws GeneralError if the desired state is not reached within the timeout
*/
public static void waitForDesiredState(
ManagedChannel channel,
ConnectivityState desiredState,
Runnable connectCallback,
long timeout,
TimeUnit unit)
throws InterruptedException {
waitForDesiredState(channel, desiredState, connectCallback, new CountDownLatch(1), timeout, unit);
}

private static void waitForDesiredState(
ManagedChannel channel,
ConnectivityState desiredState,
Runnable connectCallback,
CountDownLatch latch,
long timeout,
TimeUnit unit)
throws InterruptedException {
channel.notifyWhenStateChanged(ConnectivityState.SHUTDOWN, () -> {
try {
ConnectivityState state = channel.getState(true);
log.debug("Channel state changed to: {}", state);
CountDownLatch latch = new CountDownLatch(1);

if (state == desiredState) {
connectCallback.run();
latch.countDown();
return;
}
waitForDesiredState(channel, desiredState, connectCallback, latch, timeout, unit);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("Thread interrupted while waiting for desired state", e);
} catch (Exception e) {
log.error("Error occurred while waiting for desired state", e);
Runnable waitForStateTask = () -> {
ConnectivityState currentState = channel.getState(true);
if (currentState == desiredState) {
connectCallback.run();
latch.countDown();
}
});
};

ScheduledFuture<?> scheduledFuture = Executors.newSingleThreadScheduledExecutor()
.scheduleWithFixedDelay(waitForStateTask, 0, 100, TimeUnit.MILLISECONDS);

// Await the latch or timeout for the state change
if (!latch.await(timeout, unit)) {
boolean success = latch.await(timeout, unit);
scheduledFuture.cancel(true);
if (!success) {
throw new GeneralError(String.format(
"Deadline exceeded. Condition did not complete within the %d " + "deadline", timeout));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public GrpcConnector(
public void initialize() throws Exception {
log.info("Initializing GRPC connection...");
ChannelMonitor.waitForDesiredState(
channel, ConnectivityState.READY, this::onInitialConnect, deadline, TimeUnit.MILLISECONDS);
ConnectivityState.READY, channel, this::onInitialConnect, deadline, TimeUnit.MILLISECONDS);
ChannelMonitor.monitorChannelState(ConnectivityState.READY, channel, this::onReady, this::onConnectionLost);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package dev.openfeature.contrib.providers.flagd.resolver.common;

import static dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor.monitorChannelState;
import static dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor.waitForDesiredState;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import dev.openfeature.sdk.exceptions.GeneralError;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

class ChannelMonitorTest {
@Test
void testWaitForDesiredState() throws InterruptedException {
ManagedChannel channel = mock(ManagedChannel.class);
Runnable connectCallback = mock(Runnable.class);

// Set up the desired state
ConnectivityState desiredState = ConnectivityState.READY;
when(channel.getState(anyBoolean())).thenReturn(desiredState);

// Call the method
waitForDesiredState(desiredState, channel, connectCallback, 1, TimeUnit.SECONDS);

// Verify that the callback was run
verify(connectCallback, times(1)).run();
}

@Test
void testWaitForDesiredStateTimeout() {
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
Runnable connectCallback = mock(Runnable.class);

// Set up the desired state
ConnectivityState desiredState = ConnectivityState.READY;
when(channel.getState(anyBoolean())).thenReturn(ConnectivityState.IDLE);

// Call the method and expect a timeout
assertThrows(GeneralError.class, () -> {
waitForDesiredState(desiredState, channel, connectCallback, 1, TimeUnit.SECONDS);
});
}

@ParameterizedTest
@EnumSource(ConnectivityState.class)
void testMonitorChannelState(ConnectivityState state) {
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
Runnable onConnectionReady = mock(Runnable.class);
Runnable onConnectionLost = mock(Runnable.class);

// Set up the expected state
ConnectivityState expectedState = ConnectivityState.IDLE;
when(channel.getState(anyBoolean())).thenReturn(state);

// Capture the callback
ArgumentCaptor<Runnable> callbackCaptor = ArgumentCaptor.forClass(Runnable.class);
doNothing().when(channel).notifyWhenStateChanged(eq(expectedState), callbackCaptor.capture());

// Call the method
monitorChannelState(expectedState, channel, onConnectionReady, onConnectionLost);

// Simulate state change
callbackCaptor.getValue().run();

// Verify the callbacks based on the state
if (state == ConnectivityState.READY) {
verify(onConnectionReady, times(1)).run();
verify(onConnectionLost, never()).run();
} else if (state == ConnectivityState.TRANSIENT_FAILURE || state == ConnectivityState.SHUTDOWN) {
verify(onConnectionReady, never()).run();
verify(onConnectionLost, times(1)).run();
} else {
verify(onConnectionReady, never()).run();
verify(onConnectionLost, never()).run();
}
}
}

0 comments on commit 6b9b10d

Please sign in to comment.