Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] [Cleanup] Remove duplication in Netty4HttpRequestHeaderVerifier #3564

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest
if (API_AUTHTOKEN_SUFFIX.equals(suffix)) {
// Verficiation of SAML ASC endpoint only works with RestRequests
if (!(request instanceof OpenSearchRequest)) {
throw new SecurityRequestChannelUnsupported();
throw new SecurityRequestChannelUnsupported(
API_AUTHTOKEN_SUFFIX + " not supported for request of type " + request.getClass().getName()
);
} else {
final OpenSearchRequest openSearchRequest = (OpenSearchRequest) request;
final RestRequest restRequest = openSearchRequest.breakEncapsulationForRequest();
Expand All @@ -200,6 +202,9 @@ public Optional<SecurityResponse> reRequestAuthentication(final SecurityRequest
new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)), "")
);
} catch (Exception e) {
if (e instanceof SecurityRequestChannelUnsupported) {
throw (SecurityRequestChannelUnsupported) e;
}
log.error("Error in reRequestAuthentication()", e);
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,29 @@
public interface SecurityRequest {

/** Collection of headers associated with the request */
public Map<String, List<String>> getHeaders();
Map<String, List<String>> getHeaders();

/** The SSLEngine associated with the request */
public SSLEngine getSSLEngine();
SSLEngine getSSLEngine();

/** The path of the request */
public String path();
String path();

/** The method type of this request */
public Method method();
Method method();

/** The remote address of the request, possible null */
public Optional<InetSocketAddress> getRemoteAddress();
Optional<InetSocketAddress> getRemoteAddress();

/** The full uri of the request */
public String uri();
String uri();

/** Finds the first value of the matching header or null */
default public String header(final String headerName) {
default String header(final String headerName) {
final Optional<Map<String, List<String>>> headersMap = Optional.ofNullable(getHeaders());
return headersMap.map(headers -> headers.get(headerName)).map(List::stream).flatMap(Stream::findFirst).orElse(null);
}

/** The parameters associated with this request */
public Map<String, String> params();
Map<String, String> params();
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@

package org.opensearch.security.filter;

import org.opensearch.OpenSearchException;

/** Thrown when a security rest channel is not supported */
public class SecurityRequestChannelUnsupported extends RuntimeException {
public class SecurityRequestChannelUnsupported extends OpenSearchException {

public SecurityRequestChannelUnsupported(String msg, Object... args) {
super(msg, args);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.net.ssl.SSLPeerUnverifiedException;
Expand All @@ -48,7 +47,6 @@
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestRequest.Method;
import org.opensearch.security.auditlog.AuditLog;
import org.opensearch.security.auditlog.AuditLog.Origin;
import org.opensearch.security.auth.BackendRegistry;
Expand Down Expand Up @@ -299,9 +297,7 @@ public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) t
return;
}

Matcher matcher = PATTERN_PATH_PREFIX.matcher(requestChannel.path());
final String suffix = matcher.matches() ? matcher.group(2) : null;
if (requestChannel.method() != Method.OPTIONS && !(HEALTH_SUFFIX.equals(suffix)) && !(WHO_AM_I_SUFFIX.equals(suffix))) {
if (!SecurityRestUtils.shouldSkipAuthentication(requestChannel)) {
if (!registry.authenticate(requestChannel)) {
// another roundtrip
org.apache.logging.log4j.ThreadContext.remove("user");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.filter;

import static org.opensearch.security.filter.SecurityRestFilter.HEALTH_SUFFIX;
import static org.opensearch.security.filter.SecurityRestFilter.PATTERN_PATH_PREFIX;
import static org.opensearch.security.filter.SecurityRestFilter.WHO_AM_I_SUFFIX;

import java.util.regex.Matcher;

import org.opensearch.rest.RestRequest.Method;

public class SecurityRestUtils {
public static String path(final String uri) {
final int index = uri.indexOf('?');
Expand All @@ -9,4 +28,15 @@ public static String path(final String uri) {
return uri;
}
}

public static boolean shouldSkipAuthentication(SecurityRequestChannel request) {
Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path());
final String suffix = matcher.matches() ? matcher.group(2) : null;

boolean shouldSkipAuthentication = (request.method() == Method.OPTIONS)
|| HEALTH_SUFFIX.equals(suffix)
|| WHO_AM_I_SUFFIX.equals(suffix);

return shouldSkipAuthentication;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.util.ReferenceCountUtil;
import org.opensearch.ExceptionsHelper;
Expand All @@ -20,7 +19,6 @@
import io.netty.channel.ChannelHandlerContext;
import org.opensearch.http.netty4.Netty4HttpChannel;
import org.opensearch.http.netty4.Netty4HttpServerTransport;
import org.opensearch.rest.RestUtils;
import org.opensearch.security.filter.SecurityRequestChannel;
import org.opensearch.security.filter.SecurityRequestChannelUnsupported;
import org.opensearch.security.filter.SecurityRequestFactory;
Expand All @@ -34,12 +32,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.OpenSearchSecurityException;

import java.util.regex.Matcher;

import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.API_AUTHTOKEN_SUFFIX;
import static org.opensearch.security.filter.SecurityRestFilter.HEALTH_SUFFIX;
import static org.opensearch.security.filter.SecurityRestFilter.PATTERN_PATH_PREFIX;
import static org.opensearch.security.filter.SecurityRestFilter.WHO_AM_I_SUFFIX;
import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE;
import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS;
Expand Down Expand Up @@ -83,34 +75,21 @@ public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) thro
ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.FALSE);

final Netty4HttpChannel httpChannel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
String rawPath = SecurityRestUtils.path(msg.uri());
String path = RestUtils.decodeComponent(rawPath);
Matcher matcher = PATTERN_PATH_PREFIX.matcher(path);
final String suffix = matcher.matches() ? matcher.group(2) : null;
if (API_AUTHTOKEN_SUFFIX.equals(suffix)) {
ctx.fireChannelRead(msg);
return;
}

final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(msg, httpChannel);
ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
injectUser(msg, threadContext);

boolean shouldSkipAuthentication = HttpMethod.OPTIONS.equals(msg.method())
|| HEALTH_SUFFIX.equals(suffix)
|| WHO_AM_I_SUFFIX.equals(suffix);

if (!shouldSkipAuthentication) {
// If request channel is completed and a response is sent, then there was a failure during authentication
restFilter.checkAndAuthenticateRequest(requestChannel);
}
// If request channel is completed and a response is sent, then there was a failure during authentication
restFilter.checkAndAuthenticateRequest(requestChannel);

ThreadContext.StoredContext contextToRestore = threadPool.getThreadContext().newStoredContext(false);
ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore);

requestChannel.getQueuedResponse().ifPresent(response -> ctx.channel().attr(EARLY_RESPONSE).set(response));

boolean shouldSkipAuthentication = SecurityRestUtils.shouldSkipAuthentication(requestChannel);
boolean shouldDecompress = !shouldSkipAuthentication && requestChannel.getQueuedResponse().isEmpty();

if (requestChannel.getQueuedResponse().isEmpty() || shouldSkipAuthentication) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package org.opensearch.security.filter;

import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
import org.junit.Test;
import org.opensearch.http.netty4.Netty4HttpChannel;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

public class SecurityRestUtilsTests {

@Test
public void testShouldSkipAuthentication_positive() {
FullHttpRequest request1 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/");
NettyRequestChannel requestChannel1 = new NettyRequestChannel(request1, mock(Netty4HttpChannel.class));

assertTrue(SecurityRestUtils.shouldSkipAuthentication(requestChannel1));

FullHttpRequest request2 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/_plugins/_security/health");
NettyRequestChannel requestChannel2 = new NettyRequestChannel(request2, mock(Netty4HttpChannel.class));

assertTrue(SecurityRestUtils.shouldSkipAuthentication(requestChannel2));

FullHttpRequest request3 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/_plugins/_security/whoami");
NettyRequestChannel requestChannel3 = new NettyRequestChannel(request3, mock(Netty4HttpChannel.class));

assertTrue(SecurityRestUtils.shouldSkipAuthentication(requestChannel3));
}

@Test
public void testShouldSkipAuthentication_negative() {
FullHttpRequest request1 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
NettyRequestChannel requestChannel1 = new NettyRequestChannel(request1, mock(Netty4HttpChannel.class));

assertFalse(SecurityRestUtils.shouldSkipAuthentication(requestChannel1));

FullHttpRequest request2 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/_cluster/health");
NettyRequestChannel requestChannel2 = new NettyRequestChannel(request2, mock(Netty4HttpChannel.class));

assertFalse(SecurityRestUtils.shouldSkipAuthentication(requestChannel2));

FullHttpRequest request3 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/my-index/_search");
NettyRequestChannel requestChannel3 = new NettyRequestChannel(request3, mock(Netty4HttpChannel.class));

assertFalse(SecurityRestUtils.shouldSkipAuthentication(requestChannel3));
}

@Test
public void testGetRawPath() {
String rawPathWithParams = "/_cluster/health?pretty";
String rawPathWithoutParams = "/my-index/search";

String path1 = SecurityRestUtils.path(rawPathWithParams);
String path2 = SecurityRestUtils.path(rawPathWithoutParams);

assertTrue("/_cluster/health".equals(path1));
assertTrue("/my-index/search".equals(path2));
}
}