diff --git a/spring-graphql-docs/src/docs/asciidoc/includes/client.adoc b/spring-graphql-docs/src/docs/asciidoc/includes/client.adoc index 088f99c27..ebb319cdf 100644 --- a/spring-graphql-docs/src/docs/asciidoc/includes/client.adoc +++ b/spring-graphql-docs/src/docs/asciidoc/includes/client.adoc @@ -120,6 +120,13 @@ existing `WebSocketGraphQlClient` to create a new instance with customized setti ---- +If you'd like the client to send regular graphql ping messages to the server, you can add these by adding `keepalive(long seconds)` to the builder +[source,java,indent=0,subs="verbatim,quotes"] +---- + WebSocketGraphQlClient graphQlClient = WebSocketGraphQlClient.builder(url, client) + .keepalive(30) + .build(); +---- [[client.websocketgraphqlclient.interceptor]] ==== Interceptor diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultWebSocketGraphQlClientBuilder.java b/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultWebSocketGraphQlClientBuilder.java index fc519cacb..bb708e802 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultWebSocketGraphQlClientBuilder.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/DefaultWebSocketGraphQlClientBuilder.java @@ -16,20 +16,18 @@ package org.springframework.graphql.client; -import java.net.URI; -import java.util.Arrays; -import java.util.List; -import java.util.function.Consumer; -import java.util.stream.Collectors; - -import reactor.core.publisher.Mono; - import org.springframework.http.HttpHeaders; import org.springframework.http.codec.ClientCodecConfigurer; import org.springframework.http.codec.CodecConfigurer; import org.springframework.util.Assert; import org.springframework.web.reactive.socket.client.WebSocketClient; import org.springframework.web.util.DefaultUriBuilderFactory; +import reactor.core.publisher.Mono; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.function.Consumer; /** @@ -51,6 +49,7 @@ final class DefaultWebSocketGraphQlClientBuilder private final CodecConfigurer codecConfigurer; + private long keepalive; /** * Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient)}. @@ -59,13 +58,28 @@ final class DefaultWebSocketGraphQlClientBuilder this(toURI(url), client); } + /** + * Constructor to start via {@link WebSocketGraphQlClient#builder(String, WebSocketClient, long)}. + */ + DefaultWebSocketGraphQlClientBuilder(String url, WebSocketClient client, long keepalive) { + this(toURI(url), client, keepalive); + } + /** * Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient)}. */ DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client) { + this(url, client, 0); + } + + /** + * Constructor to start via {@link WebSocketGraphQlClient#builder(URI, WebSocketClient, long)}. + */ + DefaultWebSocketGraphQlClientBuilder(URI url, WebSocketClient client, long keepalive) { this.url = url; this.webSocketClient = client; this.codecConfigurer = ClientCodecConfigurer.create(); + this.keepalive = keepalive; } /** @@ -77,6 +91,7 @@ final class DefaultWebSocketGraphQlClientBuilder this.headers.putAll(transport.getHeaders()); this.webSocketClient = transport.getWebSocketClient(); this.codecConfigurer = transport.getCodecConfigurer(); + this.keepalive = transport.getKeepAlive(); } @@ -121,18 +136,24 @@ public WebSocketGraphQlClient build() { CodecDelegate.findJsonDecoder(this.codecConfigurer)); WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport( - this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor()); + this.url, this.headers, this.webSocketClient, this.codecConfigurer, getInterceptor(), this.keepalive); GraphQlClient graphQlClient = super.buildGraphQlClient(transport); return new DefaultWebSocketGraphQlClient(graphQlClient, transport, getBuilderInitializer()); } + @Override + public WebSocketGraphQlClient.Builder keepalive(long keepalive) { + this.keepalive = keepalive; + return this; + } + private WebSocketGraphQlClientInterceptor getInterceptor() { List interceptors = getInterceptors().stream() .filter(interceptor -> interceptor instanceof WebSocketGraphQlClientInterceptor) .map(interceptor -> (WebSocketGraphQlClientInterceptor) interceptor) - .collect(Collectors.toList()); + .toList(); Assert.state(interceptors.size() <= 1, "Only a single interceptor of type WebSocketGraphQlClientInterceptor may be configured"); diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlClient.java b/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlClient.java index 2350ec0ac..73b0f7bf7 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlClient.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlClient.java @@ -64,6 +64,16 @@ static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient) { return builder(url, webSocketClient).build(); } + /** + * Create a {@link WebSocketGraphQlClient}. + * @param url the GraphQL endpoint URL + * @param webSocketClient the underlying transport client to use + * @param keepalive the delay in seconds between sending ping messages, or 0 to disable + */ + static WebSocketGraphQlClient create(URI url, WebSocketClient webSocketClient, long keepalive) { + return builder(url, webSocketClient).keepalive(keepalive).build(); + } + /** * Return a builder for a {@link WebSocketGraphQlClient}. * @param url the GraphQL endpoint URL @@ -73,6 +83,16 @@ static Builder builder(String url, WebSocketClient webSocketClient) { return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient); } + /** + * Return a builder for a {@link WebSocketGraphQlClient}. + * @param url the GraphQL endpoint URL + * @param webSocketClient the underlying transport client to use + * @param keepalive the delay in seconds between sending ping messages, or 0 to disable + */ + static Builder builder(String url, WebSocketClient webSocketClient, long keepalive) { + return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive); + } + /** * Return a builder for a {@link WebSocketGraphQlClient}. * @param url the GraphQL endpoint URL @@ -82,6 +102,16 @@ static Builder builder(URI url, WebSocketClient webSocketClient) { return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient); } + /** + * Return a builder for a {@link WebSocketGraphQlClient}. + * @param url the GraphQL endpoint URL + * @param webSocketClient the underlying transport client to use + * @param keepalive the delay in seconds between sending ping messages, or 0 to disable + */ + static Builder builder(URI url, WebSocketClient webSocketClient, long keepalive) { + return new DefaultWebSocketGraphQlClientBuilder(url, webSocketClient, keepalive); + } + /** * Builder for a GraphQL over WebSocket client. @@ -94,6 +124,8 @@ interface Builder> extends WebGraphQlClient.Builder { @Override WebSocketGraphQlClient build(); + Builder keepalive(long keepalive); + } } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java b/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java index f89a00417..6dd004fbc 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/client/WebSocketGraphQlTransport.java @@ -16,6 +16,7 @@ package org.springframework.graphql.client; import java.net.URI; +import java.time.Duration; import java.util.Collections; import java.util.List; import java.util.Map; @@ -67,10 +68,12 @@ final class WebSocketGraphQlTransport implements GraphQlTransport { private final Mono graphQlSessionMono; + private final long keepalive; + WebSocketGraphQlTransport( URI url, @Nullable HttpHeaders headers, WebSocketClient client, CodecConfigurer codecConfigurer, - WebSocketGraphQlClientInterceptor interceptor) { + WebSocketGraphQlClientInterceptor interceptor, long keepalive) { Assert.notNull(url, "URI is required"); Assert.notNull(client, "WebSocketClient is required"); @@ -80,8 +83,9 @@ final class WebSocketGraphQlTransport implements GraphQlTransport { this.url = url; this.headers.putAll(headers != null ? headers : HttpHeaders.EMPTY); this.webSocketClient = client; + this.keepalive = keepalive; - this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor); + this.graphQlSessionHandler = new GraphQlSessionHandler(codecConfigurer, interceptor, keepalive); this.graphQlSessionMono = initGraphQlSession(this.url, this.headers, client, this.graphQlSessionHandler) .cacheInvalidateWhen(GraphQlSession::notifyWhenClosed); @@ -154,6 +158,10 @@ public Flux executeSubscription(GraphQlRequest request) { return this.graphQlSessionMono.flatMapMany(session -> session.executeSubscription(request)); } + public long getKeepAlive() { + return keepalive; + } + /** * Client {@code WebSocketHandler} for GraphQL that deals with WebSocket @@ -175,11 +183,15 @@ private static class GraphQlSessionHandler implements WebSocketHandler { private final AtomicBoolean stopped = new AtomicBoolean(); + private final long keepalive; - GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor) { + + GraphQlSessionHandler(CodecConfigurer codecConfigurer, WebSocketGraphQlClientInterceptor interceptor, + long keepalive) { this.codecDelegate = new CodecDelegate(codecConfigurer); this.interceptor = interceptor; this.graphQlSessionSink = Sinks.unsafe().one(); + this.keepalive = keepalive; } @@ -236,7 +248,7 @@ public Mono handle(WebSocketSession session) { session.send(connectionInitMono.concatWith(graphQlSession.getRequestFlux()) .map(message -> this.codecDelegate.encode(session, message))); - Mono receiveCompletion = session.receive() + Flux receiveCompletion = session.receive() .flatMap(webSocketMessage -> { if (sessionNotInitialized()) { try { @@ -277,6 +289,8 @@ public Mono handle(WebSocketSession session) { case COMPLETE: graphQlSession.handleComplete(message); break; + case PONG: + break; default: throw new IllegalStateException( "Unexpected message type: '" + message.getType() + "'"); @@ -290,10 +304,21 @@ public Mono handle(WebSocketSession session) { } } return Mono.empty(); - }) - .then(); + }); + + if (keepalive > 0) { + Duration keepAliveDuration = Duration.ofSeconds(keepalive); + receiveCompletion = receiveCompletion + .mergeWith(Flux.interval(keepAliveDuration, keepAliveDuration) + .flatMap(i -> { + graphQlSession.sendPing(null); + return Mono.empty(); + }) + ); + } + - return Mono.zip(sendCompletion, receiveCompletion).then(); + return Mono.zip(sendCompletion, receiveCompletion.then()).then(); } private boolean sessionNotInitialized() { @@ -454,6 +479,11 @@ public void sendPong(@Nullable Map payload) { this.requestSink.sendRequest(message); } + public void sendPing(@Nullable Map payload) { + GraphQlWebSocketMessage message = GraphQlWebSocketMessage.ping(payload); + this.requestSink.sendRequest(message); + } + // Inbound messages diff --git a/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java b/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java index 712a3b3bd..1e093c976 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/client/MockGraphQlWebSocketServer.java @@ -110,6 +110,8 @@ private Publisher handleMessage(GraphQlWebSocketMessage GraphQlWebSocketMessage.complete(id)); case COMPLETE: return Flux.empty(); + case PING: + return Mono.just(GraphQlWebSocketMessage.pong(null)); default: return Flux.error(new IllegalStateException("Unexpected message: " + message)); } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/client/WebSocketGraphQlTransportTests.java b/spring-graphql/src/test/java/org/springframework/graphql/client/WebSocketGraphQlTransportTests.java index bd6371bc7..93818489f 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/client/WebSocketGraphQlTransportTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/client/WebSocketGraphQlTransportTests.java @@ -60,6 +60,7 @@ public class WebSocketGraphQlTransportTests { private final static Duration TIMEOUT = Duration.ofSeconds(5); private static final CodecDelegate CODEC_DELEGATE = new CodecDelegate(ClientCodecConfigurer.create()); + public static final int KEEPALIVE = 1; private final MockGraphQlWebSocketServer mockServer = new MockGraphQlWebSocketServer(); @@ -185,6 +186,22 @@ void pingHandling() { GraphQlWebSocketMessage.subscribe("1", new DefaultGraphQlRequest("{Query1}"))); } + @Test + void pingSending() throws InterruptedException { + + GraphQlRequest request = this.mockServer.expectOperation("{Sub1}").andStream(Flux.just(this.response1, response2)); + + StepVerifier.create(this.transport.executeSubscription(request)) + .expectNext(this.response1, response2).expectComplete() + .verify(TIMEOUT); + Thread.sleep(KEEPALIVE*1000 + 50); // wait for ping + + assertActualClientMessages( + GraphQlWebSocketMessage.connectionInit(null), + GraphQlWebSocketMessage.subscribe("1", request), + GraphQlWebSocketMessage.ping(null)); + } + @Test void start() { MockGraphQlWebSocketServer handler = new MockGraphQlWebSocketServer(); @@ -210,7 +227,7 @@ public Mono handleConnectionAck(Map ackPayload) { WebSocketGraphQlTransport transport = new WebSocketGraphQlTransport( - URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor); + URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), interceptor, KEEPALIVE); transport.start().block(TIMEOUT); @@ -324,7 +341,7 @@ void errorDuringResponseHandling() { private static WebSocketGraphQlTransport createTransport(WebSocketClient client) { return new WebSocketGraphQlTransport( URI.create("/"), HttpHeaders.EMPTY, client, ClientCodecConfigurer.create(), - new WebSocketGraphQlClientInterceptor() {}); + new WebSocketGraphQlClientInterceptor() {}, KEEPALIVE); } private void assertActualClientMessages(GraphQlWebSocketMessage... expectedMessages) {