From b1cb364ab4e7a42e32077e49c868856d01742534 Mon Sep 17 00:00:00 2001 From: rstoyanchev Date: Mon, 22 Apr 2024 16:02:53 +0100 Subject: [PATCH] Update GraphQlRequestPredicates to map application/graphql content Closes gh-948 --- .../webflux/GraphQlRequestPredicates.java | 43 +++++++++++-------- .../webmvc/GraphQlRequestPredicates.java | 43 +++++++++++-------- .../GraphQlRequestPredicatesTests.java | 12 ++++++ .../webmvc/GraphQlRequestPredicatesTests.java | 8 ++++ 4 files changed, 72 insertions(+), 34 deletions(-) diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicates.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicates.java index 28fd69e4..e63f3f69 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicates.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicates.java @@ -16,7 +16,6 @@ package org.springframework.graphql.server.webflux; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -40,6 +39,7 @@ * {@link RequestPredicate} implementations tailored for GraphQL reactive endpoints. * * @author Brian Clozel + * @author Rossen Stoyanchev * @since 1.3.0 */ public final class GraphQlRequestPredicates { @@ -56,7 +56,8 @@ private GraphQlRequestPredicates() { * @see GraphQlHttpHandler */ public static RequestPredicate graphQlHttp(String path) { - return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE); + return new GraphQlHttpRequestPredicate( + path, List.of(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE)); } /** @@ -65,59 +66,67 @@ public static RequestPredicate graphQlHttp(String path) { * @see GraphQlSseHandler */ public static RequestPredicate graphQlSse(String path) { - return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM); + return new GraphQlHttpRequestPredicate(path, List.of(MediaType.TEXT_EVENT_STREAM)); } private static class GraphQlHttpRequestPredicate implements RequestPredicate { private final PathPattern pattern; + private final List contentTypes; + private final List acceptedMediaTypes; - GraphQlHttpRequestPredicate(String path, MediaType... accepted) { + GraphQlHttpRequestPredicate(String path, List accepted) { Assert.notNull(path, "'path' must not be null"); Assert.notEmpty(accepted, "'accepted' must not be empty"); PathPatternParser parser = PathPatternParser.defaultInstance; path = parser.initFullPathPattern(path); this.pattern = parser.parse(path); - this.acceptedMediaTypes = Arrays.asList(accepted); + this.contentTypes = List.of(MediaType.APPLICATION_JSON, MediaType.parseMediaType("application/graphql")); + this.acceptedMediaTypes = accepted; } @Override public boolean test(ServerRequest request) { - return methodMatch(request, HttpMethod.POST) - && contentTypeMatch(request, MediaType.APPLICATION_JSON) + return httpMethodMatch(request, HttpMethod.POST) + && contentTypeMatch(request, this.contentTypes) && acceptMatch(request, this.acceptedMediaTypes) && pathMatch(request, this.pattern); } - private static boolean methodMatch(ServerRequest request, HttpMethod expected) { - HttpMethod actual = resolveMethod(request); + private static boolean httpMethodMatch(ServerRequest request, HttpMethod expected) { + HttpMethod actual = resolveHttpMethod(request); boolean methodMatch = expected.equals(actual); traceMatch("Method", expected, actual, methodMatch); return methodMatch; } - private static HttpMethod resolveMethod(ServerRequest request) { + private static HttpMethod resolveHttpMethod(ServerRequest request) { if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { - String accessControlRequestMethod = - request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); - if (accessControlRequestMethod != null) { - return HttpMethod.valueOf(accessControlRequestMethod); + String httpMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + if (httpMethod != null) { + return HttpMethod.valueOf(httpMethod); } } return request.method(); } - private static boolean contentTypeMatch(ServerRequest request, MediaType expected) { + private static boolean contentTypeMatch(ServerRequest request, List contentTypes) { if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { return true; } ServerRequest.Headers headers = request.headers(); MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); - boolean contentTypeMatch = expected.includes(actual); - traceMatch("Content-Type", expected, actual, contentTypeMatch); + boolean contentTypeMatch = false; + for (MediaType contentType : contentTypes) { + contentTypeMatch = contentType.includes(actual); + traceMatch("Content-Type", contentTypes, actual, contentTypeMatch); + if (contentTypeMatch) { + break; + } + } return contentTypeMatch; } diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicates.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicates.java index 96e6bcbb..d9b14601 100644 --- a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicates.java +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicates.java @@ -16,7 +16,6 @@ package org.springframework.graphql.server.webmvc; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -40,6 +39,7 @@ * {@link RequestPredicate} implementations tailored for GraphQL endpoints. * * @author Brian Clozel + * @author Rossen Stoyanchev * @since 1.3.0 */ public final class GraphQlRequestPredicates { @@ -56,7 +56,8 @@ private GraphQlRequestPredicates() { * @see GraphQlHttpHandler */ public static RequestPredicate graphQlHttp(String path) { - return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE); + return new GraphQlHttpRequestPredicate( + path, List.of(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE)); } /** @@ -65,59 +66,67 @@ public static RequestPredicate graphQlHttp(String path) { * @see GraphQlSseHandler */ public static RequestPredicate graphQlSse(String path) { - return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM); + return new GraphQlHttpRequestPredicate(path, List.of(MediaType.TEXT_EVENT_STREAM)); } private static class GraphQlHttpRequestPredicate implements RequestPredicate { private final PathPattern pattern; + private final List contentTypes; + private final List acceptedMediaTypes; - GraphQlHttpRequestPredicate(String path, MediaType... accepted) { + GraphQlHttpRequestPredicate(String path, List accepted) { Assert.notNull(path, "'path' must not be null"); Assert.notEmpty(accepted, "'accepted' must not be empty"); PathPatternParser parser = PathPatternParser.defaultInstance; path = parser.initFullPathPattern(path); this.pattern = parser.parse(path); - this.acceptedMediaTypes = Arrays.asList(accepted); + this.contentTypes = List.of(MediaType.APPLICATION_JSON, MediaType.parseMediaType("application/graphql")); + this.acceptedMediaTypes = accepted; } @Override public boolean test(ServerRequest request) { - return methodMatch(request, HttpMethod.POST) - && contentTypeMatch(request, MediaType.APPLICATION_JSON) + return httpMethodMatch(request, HttpMethod.POST) + && contentTypeMatch(request, this.contentTypes) && acceptMatch(request, this.acceptedMediaTypes) && pathMatch(request, this.pattern); } - private static boolean methodMatch(ServerRequest request, HttpMethod expected) { - HttpMethod actual = resolveMethod(request); + private static boolean httpMethodMatch(ServerRequest request, HttpMethod expected) { + HttpMethod actual = resolveHttpMethod(request); boolean methodMatch = expected.equals(actual); traceMatch("Method", expected, actual, methodMatch); return methodMatch; } - private static HttpMethod resolveMethod(ServerRequest request) { + private static HttpMethod resolveHttpMethod(ServerRequest request) { if (CorsUtils.isPreFlightRequest(request.servletRequest())) { - String accessControlRequestMethod = - request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); - if (accessControlRequestMethod != null) { - return HttpMethod.valueOf(accessControlRequestMethod); + String httpMethod = request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + if (httpMethod != null) { + return HttpMethod.valueOf(httpMethod); } } return request.method(); } - private static boolean contentTypeMatch(ServerRequest request, MediaType expected) { + private static boolean contentTypeMatch(ServerRequest request, List contentTypes) { if (CorsUtils.isPreFlightRequest(request.servletRequest())) { return true; } ServerRequest.Headers headers = request.headers(); MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); - boolean contentTypeMatch = expected.includes(actual); - traceMatch("Content-Type", expected, actual, contentTypeMatch); + boolean contentTypeMatch = false; + for (MediaType contentType : contentTypes) { + contentTypeMatch = contentType.includes(actual); + traceMatch("Content-Type", contentTypes, actual, contentTypeMatch); + if (contentTypeMatch) { + break; + } + } return contentTypeMatch; } diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicatesTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicatesTests.java index 710b1ab0..132147d9 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicatesTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicatesTests.java @@ -77,6 +77,18 @@ void shouldRejectRequestWithDifferentPath() { assertThat(httpPredicate.test(serverRequest)).isFalse(); } + @Test + void shouldMapApplicationGraphQlRequestContent() { + ServerWebExchange exchange = createMatchingHttpExchange() + .mutate().request(builder -> builder.headers(headers -> { + MediaType contentType = MediaType.parseMediaType("application/graphql"); + headers.setContentType(contentType); + })) + .build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isTrue(); + } + @Test void shouldRejectRequestWithDifferentContentType() { ServerWebExchange exchange = createMatchingHttpExchange() diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicatesTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicatesTests.java index 38ba2e98..77a04714 100644 --- a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicatesTests.java +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicatesTests.java @@ -74,6 +74,14 @@ void shouldRejectRequestWithDifferentPath() { assertThat(httpPredicate.test(serverRequest)).isFalse(); } + @Test + void shouldMapApplicationGraphQlRequestContent() { + MockHttpServletRequest request = createMatchingHttpRequest(); + request.setContentType("application/graphql"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isTrue(); + } + @Test void shouldRejectRequestWithDifferentContentType() { MockHttpServletRequest request = createMatchingHttpRequest();