Skip to content

Commit

Permalink
Support keepAlive in WebSocket client
Browse files Browse the repository at this point in the history
See gh-608
  • Loading branch information
toby200 authored and rstoyanchev committed Apr 11, 2024
1 parent d31b94e commit 3f5fc1a
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 10 deletions.
7 changes: 7 additions & 0 deletions spring-graphql-docs/modules/ROOT/pages/client.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ existing `WebSocketGraphQlClient` to create a new instance with customized setti
----

If you'd like the client to send regular graphql ping messages to the server, you can add these by adding `keepalive(long seconds)` to the builder
[source,java,indent=0,subs="verbatim,quotes"]
----
WebSocketGraphQlClient graphQlClient = WebSocketGraphQlClient.builder(url, client)
.keepalive(30)
.build();
----

[[client.websocketgraphqlclient.interceptor]]
==== Interceptor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ final class DefaultWebSocketGraphQlClientBuilder

private final CodecConfigurer codecConfigurer;

private long keepalive;

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient)}.
Expand All @@ -57,13 +58,28 @@ final class DefaultWebSocketGraphQlClientBuilder
this(toURI(url), client);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(String url, WebSocketClient client, long keepalive) {
this(toURI(url), client, keepalive);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client) {
this(url, client, 0);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client, long keepalive) {
this.url = url;
this.webSocketClient = client;
this.codecConfigurer = ClientCodecConfigurer.create();
this.keepalive = keepalive;
}

/**
Expand All @@ -75,6 +91,7 @@ final class DefaultWebSocketGraphQlClientBuilder
this.headers.putAll(transport.getHeaders());
this.webSocketClient = transport.getWebSocketClient();
this.codecConfigurer = transport.getCodecConfigurer();
this.keepalive = transport.getKeepAlive();
}


Expand Down Expand Up @@ -119,12 +136,18 @@ public WebSocketGraphQlClient build() {
CodecDelegate.findJsonDecoder(this.codecConfigurer));

WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor());
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor(), this.keepalive);

GraphQlClient graphQlClient = super.buildGraphQlClient(transport);
return new DefaultWebSocketGraphQlClient(graphQlClient, transport, getBuilderInitializer());
}

@Override
public WebSocketGraphQlClient.Builder<DefaultWebSocketGraphQlClientBuilder> keepalive(long keepalive) {
this.keepalive = keepalive;
return this;
}

private WebSocketGraphQlClientInterceptor getInterceptor() {

List<WebSocketGraphQlClientInterceptor> interceptors = getInterceptors().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient) {
return builder(url, webSocketClient).build();
}

/**
* Create a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient, long keepalive) {
return builder(url, webSocketClient).keepalive(keepalive).build();
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -73,6 +83,16 @@ static Builder<?> builder(String url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(String url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -82,6 +102,16 @@ static Builder<?> builder(URI url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(URI url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}


/**
* Builder for a GraphQL over WebSocket client.
Expand All @@ -95,6 +125,8 @@ interface Builder<B extends Builder<B>> extends WebGraphQlClient.Builder<B> {
@Override
WebSocketGraphQlClient build();

Builder<B> keepalive(long keepalive);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.graphql.client;

import java.net.URI;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -67,10 +68,12 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {

private final Mono<GraphQlSession> graphQlSessionMono;

private final long keepalive;


WebSocketGraphQlTransport(
URI url, @Nullable HttpHeaders headers, WebSocketClient client, CodecConfigurer codecConfigurer,
WebSocketGraphQlClientInterceptor interceptor) {
WebSocketGraphQlClientInterceptor interceptor, long keepalive) {

Assert.notNull(url, "URI is required");
Assert.notNull(client, "WebSocketClient is required");
Expand All @@ -80,8 +83,9 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
this.url = url;
this.headers.putAll((headers != null) ? headers : HttpHeaders.EMPTY);
this.webSocketClient = client;
this.keepalive = keepalive;

this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor);
this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor, keepalive);

this.graphQlSessionMono = initGraphQlSession(this.url, this.headers, client, this.graphQlSessionHandler)
.cacheInvalidateWhen(GraphQlSession::notifyWhenClosed);
Expand Down Expand Up @@ -162,6 +166,10 @@ public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
return this.graphQlSessionMono.flatMapMany((session) -> session.executeSubscription(request));
}

public long getKeepAlive() {
return keepalive;
}


/**
* Client {@code WebSocketHandler} for GraphQL that deals with WebSocket
Expand All @@ -183,11 +191,15 @@ private static class GraphQlSessionHandler implements WebSocketHandler {

private final AtomicBoolean stopped = new AtomicBoolean();

private final long keepalive;

GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor) {

GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor,
long keepalive) {
this.codecDelegate = new CodecDelegate(codecConfigurer);
this.interceptor = interceptor;
this.graphQlSessionSink = Sinks.unsafe().one();
this.keepalive = keepalive;
}


Expand Down Expand Up @@ -245,7 +257,7 @@ public Mono<Void> handle(WebSocketSession session) {
session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux())
.map((message) -> this.codecDelegate.encode(session, message)));

Mono<Void> receiveCompletion = session.receive()
Flux<Void> receiveCompletion = session.receive()
.flatMap((webSocketMessage) -> {
if (sessionNotInitialized()) {
try {
Expand Down Expand Up @@ -276,6 +288,7 @@ public Mono<Void> handle(WebSocketSession session) {
switch (message.resolvedType()) {
case NEXT -> graphQlSession.handleNext(message);
case PING -> graphQlSession.sendPong(null);
case PONG -> { }
case ERROR -> graphQlSession.handleError(message);
case COMPLETE -> graphQlSession.handleComplete(message);
default -> throw new IllegalStateException(
Expand All @@ -290,10 +303,21 @@ public Mono<Void> handle(WebSocketSession session) {
}
}
return Mono.empty();
})
.then();
});

if (keepalive > 0) {
Duration keepAliveDuration = Duration.ofSeconds(keepalive);
receiveCompletion = receiveCompletion
.mergeWith(Flux.interval(keepAliveDuration, keepAliveDuration)
.flatMap(i -> {
graphQlSession.sendPing(null);
return Mono.empty();
})
);
}


return Mono.zip(sendCompletion, receiveCompletion).then();
return Mono.zip(sendCompletion, receiveCompletion.then()).then();
}

private boolean sessionNotInitialized() {
Expand Down Expand Up @@ -459,6 +483,11 @@ void sendPong(@Nullable Map<String, Object> payload) {
this.requestSink.sendRequest(message);
}

public void sendPing(@Nullable Map<String, Object> payload) {
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.ping(payload);
this.requestSink.sendRequest(message);
}


// Inbound messages

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ private Publisher<GraphQlWebSocketMessage> handleMessage(GraphQlWebSocketMessage
GraphQlWebSocketMessage.error(id, Collections.singletonList(request.getError())) :
GraphQlWebSocketMessage.complete(id));
}
case PING -> {
return Mono.just(GraphQlWebSocketMessage.pong(null));
}
case COMPLETE -> {
return Flux.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class WebSocketGraphQlTransportTests {
private static final Duration TIMEOUT = Duration.ofSeconds(5);

private static final CodecDelegate CODEC_DELEGATE = new CodecDelegate(ClientCodecConfigurer.create());
public static final int KEEPALIVE = 1;


private final MockGraphQlWebSocketServer mockServer = new MockGraphQlWebSocketServer();
Expand Down Expand Up @@ -185,6 +186,22 @@ void pingHandling() {
GraphQlWebSocketMessage.subscribe("1", new DefaultGraphQlRequest("{Query1}")));
}

@Test
void pingSending() throws InterruptedException {

GraphQlRequest request = this.mockServer.expectOperation("{Sub1}").andStream(Flux.just(this.response1, response2));

StepVerifier.create(this.transport.executeSubscription(request))
.expectNext(this.response1, response2).expectComplete()
.verify(TIMEOUT);
Thread.sleep(KEEPALIVE*1000 + 50); // wait for ping

assertActualClientMessages(
GraphQlWebSocketMessage.connectionInit(null),
GraphQlWebSocketMessage.subscribe("1", request),
GraphQlWebSocketMessage.ping(null));
}

@Test
void start() {
MockGraphQlWebSocketServer handler = new MockGraphQlWebSocketServer();
Expand All @@ -210,7 +227,7 @@ public Mono<Void> handleConnectionAck(Map<String, Object> ackPayload) {


WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor);
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor, KEEPALIVE);

transport.start().block(TIMEOUT);

Expand Down Expand Up @@ -324,7 +341,7 @@ void errorDuringResponseHandling() {
private static WebSocketGraphQlTransport createTransport(WebSocketClient client) {
return new WebSocketGraphQlTransport(
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(),
new WebSocketGraphQlClientInterceptor() { });
new WebSocketGraphQlClientInterceptor() { }, KEEPALIVE);
}

private void assertActualClientMessages(GraphQlWebSocketMessage... expectedMessages) {
Expand Down

0 comments on commit 3f5fc1a

Please sign in to comment.