diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicates.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicates.java new file mode 100644 index 00000000..f41159c3 --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicates.java @@ -0,0 +1,167 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.server.webflux; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.server.PathContainer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeTypeUtils; +import org.springframework.web.cors.reactive.CorsUtils; +import org.springframework.web.reactive.function.server.RequestPredicate; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.util.pattern.PathPattern; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * {@link RequestPredicate} implementations tailored for GraphQL reactive endpoints. + * + * @author Brian Clozel + * @since 1.3.0 + */ +public class GraphQlRequestPredicates { + + private static final Log logger = LogFactory.getLog(GraphQlRequestPredicates.class); + + /** + * Create a {@link RequestPredicate predicate} that matches GraphQL HTTP requests for the configured path. + * + * @param path the path on which the GraphQL HTTP endpoint is mapped + * @see GraphQlHttpHandler + */ + public static RequestPredicate graphQlHttp(String path) { + return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE); + } + + /** + * Create a {@link RequestPredicate predicate} that matches GraphQL SSE over HTTP requests for the configured path. + * + * @param path the path on which the GraphQL SSE endpoint is mapped + * @see GraphQlSseHandler + */ + public static RequestPredicate graphQlSse(String path) { + return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM); + } + + private static class GraphQlHttpRequestPredicate implements RequestPredicate { + + private final PathPattern pattern; + + private final List acceptedMediaTypes; + + + GraphQlHttpRequestPredicate(String path, MediaType... accepted) { + Assert.notNull(path, "'path' must not be null"); + Assert.notEmpty(accepted, "'accepted' must not be empty"); + PathPatternParser parser = PathPatternParser.defaultInstance; + path = parser.initFullPathPattern(path); + this.pattern = parser.parse(path); + this.acceptedMediaTypes = Arrays.asList(accepted); + } + + @Override + public boolean test(ServerRequest request) { + return methodMatch(request, HttpMethod.POST) + && contentTypeMatch(request, MediaType.APPLICATION_JSON) + && acceptMatch(request, this.acceptedMediaTypes) + && pathMatch(request, this.pattern); + } + } + + private static boolean methodMatch(ServerRequest request, HttpMethod expected) { + HttpMethod actual = resolveMethod(request); + boolean methodMatch = expected.equals(actual); + traceMatch("Method", expected, actual, methodMatch); + return methodMatch; + } + + private static HttpMethod resolveMethod(ServerRequest request) { + if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { + String accessControlRequestMethod = + request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + if (accessControlRequestMethod != null) { + return HttpMethod.valueOf(accessControlRequestMethod); + } + } + return request.method(); + } + + private static boolean contentTypeMatch(ServerRequest request, MediaType expected) { + if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { + return true; + } + ServerRequest.Headers headers = request.headers(); + MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean contentTypeMatch = expected.includes(actual); + traceMatch("Content-Type", expected, actual, contentTypeMatch); + return contentTypeMatch; + } + + private static boolean acceptMatch(ServerRequest request, List expected) { + if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) { + return true; + } + ServerRequest.Headers headers = request.headers(); + List acceptedMediaTypes = acceptedMediaTypes(headers); + boolean match = false; + outer: + for (MediaType acceptedMediaType : acceptedMediaTypes) { + for (MediaType mediaType : expected) { + if (acceptedMediaType.isCompatibleWith(mediaType)) { + match = true; + break outer; + } + } + } + traceMatch("Accept", expected, acceptedMediaTypes, match); + return match; + } + + private static List acceptedMediaTypes(ServerRequest.Headers headers) { + List acceptedMediaTypes = headers.accept(); + if (acceptedMediaTypes.isEmpty()) { + acceptedMediaTypes = Collections.singletonList(MediaType.ALL); + } else { + MimeTypeUtils.sortBySpecificity(acceptedMediaTypes); + } + return acceptedMediaTypes; + } + + private static boolean pathMatch(ServerRequest request, PathPattern pattern) { + PathContainer pathContainer = request.requestPath().pathWithinApplication(); + boolean pathMatch = pattern.matches(pathContainer); + traceMatch("Pattern", pattern.getPatternString(), request.path(), pathMatch); + return pathMatch; + } + + private static void traceMatch(String prefix, Object desired, @Nullable Object actual, boolean match) { + if (logger.isTraceEnabled()) { + logger.trace(String.format("%s \"%s\" %s against value \"%s\"", + prefix, desired, match ? "matches" : "does not match", actual)); + } + } + +} diff --git a/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicates.java b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicates.java new file mode 100644 index 00000000..e5f5103b --- /dev/null +++ b/spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicates.java @@ -0,0 +1,167 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.server.webmvc; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.server.PathContainer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.MimeTypeUtils; +import org.springframework.web.cors.CorsUtils; +import org.springframework.web.servlet.function.RequestPredicate; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.util.pattern.PathPattern; +import org.springframework.web.util.pattern.PathPatternParser; + +/** + * {@link RequestPredicate} implementations tailored for GraphQL endpoints. + * + * @author Brian Clozel + * @since 1.3.0 + */ +public class GraphQlRequestPredicates { + + private static final Log logger = LogFactory.getLog(GraphQlRequestPredicates.class); + + /** + * Create a {@link RequestPredicate predicate} that matches GraphQL HTTP requests for the configured path. + * + * @param path the path on which the GraphQL HTTP endpoint is mapped + * @see GraphQlHttpHandler + */ + public static RequestPredicate graphQlHttp(String path) { + return new GraphQlHttpRequestPredicate(path, MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE); + } + + /** + * Create a {@link RequestPredicate predicate} that matches GraphQL SSE over HTTP requests for the configured path. + * + * @param path the path on which the GraphQL SSE endpoint is mapped + * @see GraphQlSseHandler + */ + public static RequestPredicate graphQlSse(String path) { + return new GraphQlHttpRequestPredicate(path, MediaType.TEXT_EVENT_STREAM); + } + + private static class GraphQlHttpRequestPredicate implements RequestPredicate { + + private final PathPattern pattern; + + private final List acceptedMediaTypes; + + + GraphQlHttpRequestPredicate(String path, MediaType... accepted) { + Assert.notNull(path, "'path' must not be null"); + Assert.notEmpty(accepted, "'accepted' must not be empty"); + PathPatternParser parser = PathPatternParser.defaultInstance; + path = parser.initFullPathPattern(path); + this.pattern = parser.parse(path); + this.acceptedMediaTypes = Arrays.asList(accepted); + } + + @Override + public boolean test(ServerRequest request) { + return methodMatch(request, HttpMethod.POST) + && contentTypeMatch(request, MediaType.APPLICATION_JSON) + && acceptMatch(request, this.acceptedMediaTypes) + && pathMatch(request, this.pattern); + } + } + + private static boolean methodMatch(ServerRequest request, HttpMethod expected) { + HttpMethod actual = resolveMethod(request); + boolean methodMatch = expected.equals(actual); + traceMatch("Method", expected, actual, methodMatch); + return methodMatch; + } + + private static HttpMethod resolveMethod(ServerRequest request) { + if (CorsUtils.isPreFlightRequest(request.servletRequest())) { + String accessControlRequestMethod = + request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD); + if (accessControlRequestMethod != null) { + return HttpMethod.valueOf(accessControlRequestMethod); + } + } + return request.method(); + } + + private static boolean contentTypeMatch(ServerRequest request, MediaType expected) { + if (CorsUtils.isPreFlightRequest(request.servletRequest())) { + return true; + } + ServerRequest.Headers headers = request.headers(); + MediaType actual = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM); + boolean contentTypeMatch = expected.includes(actual); + traceMatch("Content-Type", expected, actual, contentTypeMatch); + return contentTypeMatch; + } + + private static boolean acceptMatch(ServerRequest request, List expected) { + if (CorsUtils.isPreFlightRequest(request.servletRequest())) { + return true; + } + ServerRequest.Headers headers = request.headers(); + List acceptedMediaTypes = acceptedMediaTypes(headers); + boolean match = false; + outer: + for (MediaType acceptedMediaType : acceptedMediaTypes) { + for (MediaType mediaType : expected) { + if (acceptedMediaType.isCompatibleWith(mediaType)) { + match = true; + break outer; + } + } + } + traceMatch("Accept", expected, acceptedMediaTypes, match); + return match; + } + + private static List acceptedMediaTypes(ServerRequest.Headers headers) { + List acceptedMediaTypes = headers.accept(); + if (acceptedMediaTypes.isEmpty()) { + acceptedMediaTypes = Collections.singletonList(MediaType.ALL); + } else { + MimeTypeUtils.sortBySpecificity(acceptedMediaTypes); + } + return acceptedMediaTypes; + } + + private static boolean pathMatch(ServerRequest request, PathPattern pattern) { + PathContainer pathContainer = request.requestPath().pathWithinApplication(); + boolean pathMatch = pattern.matches(pathContainer); + traceMatch("Pattern", pattern.getPatternString(), request.path(), pathMatch); + return pathMatch; + } + + private static void traceMatch(String prefix, Object desired, @Nullable Object actual, boolean match) { + if (logger.isTraceEnabled()) { + logger.trace(String.format("%s \"%s\" %s against value \"%s\"", + prefix, desired, match ? "matches" : "does not match", actual)); + } + } + +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicatesTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicatesTests.java new file mode 100644 index 00000000..0aaea59d --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webflux/GraphQlRequestPredicatesTests.java @@ -0,0 +1,172 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.server.webflux; + + +import java.util.Collections; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.reactive.function.server.RequestPredicate; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link GraphQlRequestPredicates}. + * + * @author Brian Clozel + */ +class GraphQlRequestPredicatesTests { + + @Nested + class HttpPredicatesTests { + + RequestPredicate httpPredicate = GraphQlRequestPredicates.graphQlHttp("/graphql"); + + @Test + void shouldAcceptGraphQlHttpRequest() { + ServerWebExchange exchange = createMatchingHttpExchange(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldAcceptCorsRequest() { + ServerWebExchange exchange = createMatchingHttpExchange() + .mutate().request(req -> req.method(HttpMethod.OPTIONS).header("Origin", "https://example.org") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")).build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldRejectRequestWithGetMethod() { + ServerWebExchange exchange = createMatchingHttpExchange() + .mutate().request(req -> req.method(HttpMethod.GET)).build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentPath() { + ServerWebExchange exchange = createMatchingHttpExchange() + .mutate().request(req -> req.path("/invalid")).build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentContentType() { + ServerWebExchange exchange = createMatchingHttpExchange() + .mutate().request(req -> req.headers(headers -> headers.setContentType(MediaType.TEXT_HTML))) + .build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithIncompatibleAccept() { + ServerWebExchange exchange = createMatchingHttpExchange() + .mutate().request(req -> req.headers(headers -> headers.setAccept(Collections.singletonList(MediaType.TEXT_HTML)))) + .build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + private MockServerWebExchange createMatchingHttpExchange() { + MockServerHttpRequest request = MockServerHttpRequest.post("/graphql") + .contentType(MediaType.APPLICATION_JSON) + .accept(MediaType.APPLICATION_JSON, MediaType.APPLICATION_GRAPHQL_RESPONSE) + .build(); + return MockServerWebExchange.from(request); + } + + } + + @Nested + class SsePredicatesTests { + + RequestPredicate ssePredicate = GraphQlRequestPredicates.graphQlSse("/graphql"); + + @Test + void shouldAcceptGraphQlSseRequest() { + ServerWebExchange exchange = createMatchingSseExchange(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldAcceptCorsRequest() { + ServerWebExchange exchange = createMatchingSseExchange() + .mutate().request(req -> req.method(HttpMethod.OPTIONS).header("Origin", "https://example.org") + .header(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST")).build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldRejectRequestWithGetMethod() { + ServerWebExchange exchange = createMatchingSseExchange() + .mutate().request(req -> req.method(HttpMethod.GET)).build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentPath() { + ServerWebExchange exchange = createMatchingSseExchange() + .mutate().request(req -> req.path("/invalid")).build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentContentType() { + ServerWebExchange exchange = createMatchingSseExchange() + .mutate().request(req -> req.headers(headers -> headers.setContentType(MediaType.TEXT_HTML))) + .build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithIncompatibleAccept() { + ServerWebExchange exchange = createMatchingSseExchange() + .mutate().request(req -> req.headers(headers -> headers.setAccept(Collections.singletonList(MediaType.TEXT_HTML)))) + .build(); + ServerRequest serverRequest = ServerRequest.create(exchange, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + private MockServerWebExchange createMatchingSseExchange() { + MockServerHttpRequest request = MockServerHttpRequest.post("/graphql") + .contentType(MediaType.APPLICATION_JSON) + .accept(MediaType.TEXT_EVENT_STREAM) + .build(); + return MockServerWebExchange.from(request); + } + } + +} diff --git a/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicatesTests.java b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicatesTests.java new file mode 100644 index 00000000..807f05bf --- /dev/null +++ b/spring-graphql/src/test/java/org/springframework/graphql/server/webmvc/GraphQlRequestPredicatesTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2020-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.graphql.server.webmvc; + + +import java.util.Collections; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import org.springframework.http.HttpHeaders; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.web.servlet.function.RequestPredicate; +import org.springframework.web.servlet.function.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link GraphQlRequestPredicates}. + * + * @author Brian Clozel + */ +class GraphQlRequestPredicatesTests { + + @Nested + class HttpPredicatesTests { + + RequestPredicate httpPredicate = GraphQlRequestPredicates.graphQlHttp("/graphql"); + + @Test + void shouldAcceptGraphQlHttpRequest() { + MockHttpServletRequest request = createMatchingHttpRequest(); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldAcceptCorsRequest() { + MockHttpServletRequest request = createMatchingHttpRequest(); + request.setMethod("OPTIONS"); + request.addHeader("Origin", "https://example.com"); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldRejectRequestWithGetMethod() { + MockHttpServletRequest request = createMatchingHttpRequest(); + request.setMethod("GET"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentPath() { + MockHttpServletRequest request = createMatchingHttpRequest(); + request.setRequestURI("/invalid"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentContentType() { + MockHttpServletRequest request = createMatchingHttpRequest(); + request.setContentType("text/xml"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithIncompatibleAccept() { + MockHttpServletRequest request = createMatchingHttpRequest(); + request.removeHeader("Accept"); + request.addHeader("Accept", "text/xml"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(httpPredicate.test(serverRequest)).isFalse(); + } + + private MockHttpServletRequest createMatchingHttpRequest() { + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/graphql"); + request.setContentType("application/json"); + request.addHeader("Accept", "application/graphql-response+json"); + return request; + } + + } + + @Nested + class SsePredicatesTests { + + RequestPredicate ssePredicate = GraphQlRequestPredicates.graphQlSse("/graphql"); + + @Test + void shouldAcceptGraphQlSseRequest() { + MockHttpServletRequest request = createMatchingSseRequest(); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldAcceptCorsRequest() { + MockHttpServletRequest request = createMatchingSseRequest(); + request.setMethod("OPTIONS"); + request.addHeader("Origin", "https://example.com"); + request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "POST"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isTrue(); + } + + @Test + void shouldRejectRequestWithGetMethod() { + MockHttpServletRequest request = createMatchingSseRequest(); + request.setMethod("GET"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentPath() { + MockHttpServletRequest request = createMatchingSseRequest(); + request.setRequestURI("/invalid"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithDifferentContentType() { + MockHttpServletRequest request = createMatchingSseRequest(); + request.setContentType("text/xml"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + @Test + void shouldRejectRequestWithIncompatibleAccept() { + MockHttpServletRequest request = createMatchingSseRequest(); + request.removeHeader("Accept"); + request.addHeader("Accept", "text/xml"); + ServerRequest serverRequest = ServerRequest.create(request, Collections.emptyList()); + assertThat(ssePredicate.test(serverRequest)).isFalse(); + } + + private MockHttpServletRequest createMatchingSseRequest() { + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/graphql"); + request.addHeader("Content-Type", "application/json"); + request.addHeader("Accept", "text/event-stream"); + return request; + } + } + +}