Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors ClientHelper to combine header logic #30620

Merged
merged 2 commits into from
May 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,28 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.client.FilterClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField;

import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* Utility class to help with the execution of requests made using a {@link Client} such that they
* have the origin as a transient and listeners have the appropriate context upon invocation
*/
public final class ClientHelper {

/**
* List of headers that are related to security
*/
public static final Set<String> SECURITY_HEADER_FILTERS = Sets.newHashSet(AuthenticationServiceField.RUN_AS_USER_HEADER,
AuthenticationField.AUTHENTICATION_KEY);

public static final String ACTION_ORIGIN_TRANSIENT_NAME = "action.origin";
public static final String SECURITY_ORIGIN = "security";
public static final String WATCHER_ORIGIN = "watcher";
Expand Down Expand Up @@ -78,6 +90,82 @@ RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>>
}
}

/**
* Execute a client operation and return the response, try to run an action
* with least privileges, when headers exist
*
* @param headers
* Request headers, ideally including security headers
* @param origin
* The origin to fall back to if there are no security headers
* @param client
* The client used to query
* @param supplier
* The action to run
* @return An instance of the response class
*/
public static <T extends ActionResponse> T executeWithHeaders(Map<String, String> headers, String origin, Client client,
Supplier<T> supplier) {
Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

// no security headers, we will have to use the xpack internal user for
// our execution by specifying the origin
if (filteredHeaders.isEmpty()) {
try (ThreadContext.StoredContext ignore = stashWithOrigin(client.threadPool().getThreadContext(), origin)) {
return supplier.get();
}
} else {
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashContext()) {
client.threadPool().getThreadContext().copyHeaders(filteredHeaders.entrySet());
return supplier.get();
}
}
}

/**
* Execute a client operation asynchronously, try to run an action with
* least privileges, when headers exist
*
* @param headers
* Request headers, ideally including security headers
* @param origin
* The origin to fall back to if there are no security headers
* @param action
* The action to execute
* @param request
* The request object for the action
* @param listener
* The listener to call when the action is complete
*/
public static <Request extends ActionRequest, Response extends ActionResponse,
RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>> void executeWithHeadersAsync(
Map<String, String> headers, String origin, Client client, Action<Request, Response, RequestBuilder> action, Request request,
ActionListener<Response> listener) {

Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

final ThreadContext threadContext = client.threadPool().getThreadContext();

// No headers (e.g. security not installed/in use) so execute as origin
if (filteredHeaders.isEmpty()) {
ClientHelper.executeAsyncWithOrigin(client, origin, action, request, listener);
} else {
// Otherwise stash the context and copy in the saved headers before executing
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = stashWithHeaders(threadContext, filteredHeaders)) {
client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener));
}
}
}

private static ThreadContext.StoredContext stashWithHeaders(ThreadContext threadContext, Map<String, String> headers) {
final ThreadContext.StoredContext storedContext = threadContext.stashContext();
threadContext.copyHeaders(headers.entrySet());
return storedContext;
}

private static final class ClientWithOrigin extends FilterClient {

private final String origin;
Expand All @@ -98,5 +186,4 @@ RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>>
}
}
}

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData.PersistentTask;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedJobValidator;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
Expand All @@ -35,8 +38,6 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NameResolver;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData.PersistentTask;

import java.io.IOException;
import java.util.Collection;
Expand Down Expand Up @@ -303,7 +304,7 @@ public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadCo
// Adjust the request, adding security headers from the current thread context
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig);
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
.filter(e -> MlClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
datafeedConfig = builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.MlClientHelper;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down Expand Up @@ -304,7 +304,7 @@ public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadC
if (threadContext != null) {
// Adjust the request, adding security headers from the current thread context
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
.filter(e -> MlClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
}
Expand Down
Loading