Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support keep-alive pings in WebSocketGraphQlClient #608

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions spring-graphql-docs/src/docs/asciidoc/includes/client.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ existing `WebSocketGraphQlClient` to create a new instance with customized setti

----

If you'd like the client to send regular graphql ping messages to the server, you can add these by adding `keepalive(long seconds)` to the builder
[source,java,indent=0,subs="verbatim,quotes"]
----
WebSocketGraphQlClient graphQlClient = WebSocketGraphQlClient.builder(url, client)
.keepalive(30)
.build();
----

[[client.websocketgraphqlclient.interceptor]]
==== Interceptor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,18 @@

package org.springframework.graphql.client;

import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import reactor.core.publisher.Mono;

import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.ClientCodecConfigurer;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.util.DefaultUriBuilderFactory;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;


/**
Expand All @@ -51,6 +49,7 @@ final class DefaultWebSocketGraphQlClientBuilder

private final CodecConfigurer codecConfigurer;

private long keepalive;

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient)}.
Expand All @@ -59,13 +58,28 @@ final class DefaultWebSocketGraphQlClientBuilder
this(toURI(url), client);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(String url, WebSocketClient client, long keepalive) {
this(toURI(url), client, keepalive);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client) {
this(url, client, 0);
}

/**
* Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient, long)}.
*/
DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client, long keepalive) {
this.url = url;
this.webSocketClient = client;
this.codecConfigurer = ClientCodecConfigurer.create();
this.keepalive = keepalive;
}

/**
Expand All @@ -77,6 +91,7 @@ final class DefaultWebSocketGraphQlClientBuilder
this.headers.putAll(transport.getHeaders());
this.webSocketClient = transport.getWebSocketClient();
this.codecConfigurer = transport.getCodecConfigurer();
this.keepalive = transport.getKeepAlive();
}


Expand Down Expand Up @@ -121,18 +136,24 @@ public WebSocketGraphQlClient build() {
CodecDelegate.findJsonDecoder(this.codecConfigurer));

WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor());
this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor(), this.keepalive);

GraphQlClient graphQlClient = super.buildGraphQlClient(transport);
return new DefaultWebSocketGraphQlClient(graphQlClient, transport, getBuilderInitializer());
}

@Override
public WebSocketGraphQlClient.Builder<DefaultWebSocketGraphQlClientBuilder> keepalive(long keepalive) {
this.keepalive = keepalive;
return this;
}

private WebSocketGraphQlClientInterceptor getInterceptor() {

List<WebSocketGraphQlClientInterceptor> interceptors = getInterceptors().stream()
.filter(interceptor -> interceptor instanceof WebSocketGraphQlClientInterceptor)
.map(interceptor -> (WebSocketGraphQlClientInterceptor) interceptor)
.collect(Collectors.toList());
.toList();

Assert.state(interceptors.size() <= 1,
"Only a single interceptor of type WebSocketGraphQlClientInterceptor may be configured");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient) {
return builder(url, webSocketClient).build();
}

/**
* Create a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient, long keepalive) {
return builder(url, webSocketClient).keepalive(keepalive).build();
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -73,6 +83,16 @@ static Builder<?> builder(String url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(String url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
Expand All @@ -82,6 +102,16 @@ static Builder<?> builder(URI url, WebSocketClient webSocketClient) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient);
}

/**
* Return a builder for a {@link WebSocketGraphQlClient}.
* @param url the GraphQL endpoint URL
* @param webSocketClient the underlying transport client to use
* @param keepalive the delay in seconds between sending ping messages, or 0 to disable
*/
static Builder<?> builder(URI url, WebSocketClient webSocketClient, long keepalive) {
return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive);
}


/**
* Builder for a GraphQL over WebSocket client.
Expand All @@ -94,6 +124,8 @@ interface Builder<B extends Builder<B>> extends WebGraphQlClient.Builder<B> {
@Override
WebSocketGraphQlClient build();

Builder<B> keepalive(long keepalive);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.graphql.client;

import java.net.URI;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -67,10 +68,12 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {

private final Mono<GraphQlSession> graphQlSessionMono;

private final long keepalive;


WebSocketGraphQlTransport(
URI url, @Nullable HttpHeaders headers, WebSocketClient client, CodecConfigurer codecConfigurer,
WebSocketGraphQlClientInterceptor interceptor) {
WebSocketGraphQlClientInterceptor interceptor, long keepalive) {

Assert.notNull(url, "URI is required");
Assert.notNull(client, "WebSocketClient is required");
Expand All @@ -80,8 +83,9 @@ final class WebSocketGraphQlTransport implements GraphQlTransport {
this.url = url;
this.headers.putAll(headers != null ? headers : HttpHeaders.EMPTY);
this.webSocketClient = client;
this.keepalive = keepalive;

this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor);
this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor, keepalive);

this.graphQlSessionMono = initGraphQlSession(this.url, this.headers, client, this.graphQlSessionHandler)
.cacheInvalidateWhen(GraphQlSession::notifyWhenClosed);
Expand Down Expand Up @@ -154,6 +158,10 @@ public Flux<GraphQlResponse> executeSubscription(GraphQlRequest request) {
return this.graphQlSessionMono.flatMapMany(session -> session.executeSubscription(request));
}

public long getKeepAlive() {
return keepalive;
}


/**
* Client {@code WebSocketHandler} for GraphQL that deals with WebSocket
Expand All @@ -175,11 +183,15 @@ private static class GraphQlSessionHandler implements WebSocketHandler {

private final AtomicBoolean stopped = new AtomicBoolean();

private final long keepalive;

GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor) {

GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor,
long keepalive) {
this.codecDelegate = new CodecDelegate(codecConfigurer);
this.interceptor = interceptor;
this.graphQlSessionSink = Sinks.unsafe().one();
this.keepalive = keepalive;
}


Expand Down Expand Up @@ -236,7 +248,7 @@ public Mono<Void> handle(WebSocketSession session) {
session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux())
.map(message -> this.codecDelegate.encode(session, message)));

Mono<Void> receiveCompletion = session.receive()
Flux<Void> receiveCompletion = session.receive()
.flatMap(webSocketMessage -> {
if (sessionNotInitialized()) {
try {
Expand Down Expand Up @@ -277,6 +289,8 @@ public Mono<Void> handle(WebSocketSession session) {
case COMPLETE:
graphQlSession.handleComplete(message);
break;
case PONG:
break;
default:
throw new IllegalStateException(
"Unexpected message type: '" + message.getType() + "'");
Expand All @@ -290,10 +304,21 @@ public Mono<Void> handle(WebSocketSession session) {
}
}
return Mono.empty();
})
.then();
});

if (keepalive > 0) {
Duration keepAliveDuration = Duration.ofSeconds(keepalive);
receiveCompletion = receiveCompletion
.mergeWith(Flux.interval(keepAliveDuration, keepAliveDuration)
.flatMap(i -> {
graphQlSession.sendPing(null);
return Mono.empty();
})
);
}


return Mono.zip(sendCompletion, receiveCompletion).then();
return Mono.zip(sendCompletion, receiveCompletion.then()).then();
}

private boolean sessionNotInitialized() {
Expand Down Expand Up @@ -454,6 +479,11 @@ public void sendPong(@Nullable Map<String, Object> payload) {
this.requestSink.sendRequest(message);
}

public void sendPing(@Nullable Map<String, Object> payload) {
GraphQlWebSocketMessage message = GraphQlWebSocketMessage.ping(payload);
this.requestSink.sendRequest(message);
}


// Inbound messages

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ private Publisher<GraphQlWebSocketMessage> handleMessage(GraphQlWebSocketMessage
GraphQlWebSocketMessage.complete(id));
case COMPLETE:
return Flux.empty();
case PING:
return Mono.just(GraphQlWebSocketMessage.pong(null));
default:
return Flux.error(new IllegalStateException("Unexpected message: " + message));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class WebSocketGraphQlTransportTests {
private final static Duration TIMEOUT = Duration.ofSeconds(5);

private static final CodecDelegate CODEC_DELEGATE = new CodecDelegate(ClientCodecConfigurer.create());
public static final int KEEPALIVE = 1;


private final MockGraphQlWebSocketServer mockServer = new MockGraphQlWebSocketServer();
Expand Down Expand Up @@ -185,6 +186,22 @@ void pingHandling() {
GraphQlWebSocketMessage.subscribe("1", new DefaultGraphQlRequest("{Query1}")));
}

@Test
void pingSending() throws InterruptedException {

GraphQlRequest request = this.mockServer.expectOperation("{Sub1}").andStream(Flux.just(this.response1, response2));

StepVerifier.create(this.transport.executeSubscription(request))
.expectNext(this.response1, response2).expectComplete()
.verify(TIMEOUT);
Thread.sleep(KEEPALIVE*1000 + 50); // wait for ping

assertActualClientMessages(
GraphQlWebSocketMessage.connectionInit(null),
GraphQlWebSocketMessage.subscribe("1", request),
GraphQlWebSocketMessage.ping(null));
}

@Test
void start() {
MockGraphQlWebSocketServer handler = new MockGraphQlWebSocketServer();
Expand All @@ -210,7 +227,7 @@ public Mono<Void> handleConnectionAck(Map<String, Object> ackPayload) {


WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport(
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor);
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor, KEEPALIVE);

transport.start().block(TIMEOUT);

Expand Down Expand Up @@ -324,7 +341,7 @@ void errorDuringResponseHandling() {
private static WebSocketGraphQlTransport createTransport(WebSocketClient client) {
return new WebSocketGraphQlTransport(
URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(),
new WebSocketGraphQlClientInterceptor() {});
new WebSocketGraphQlClientInterceptor() {}, KEEPALIVE);
}

private void assertActualClientMessages(GraphQlWebSocketMessage... expectedMessages) {
Expand Down