Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zen2: Add join validation #37203

Merged
merged 5 commits into from
Jan 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,14 @@
import org.elasticsearch.transport.TransportService;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
Expand Down Expand Up @@ -117,6 +119,7 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
private final LeaderChecker leaderChecker;
private final FollowersChecker followersChecker;
private final ClusterApplier clusterApplier;
private final Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators;
@Nullable
private Releasable electionScheduler;
@Nullable
Expand All @@ -139,13 +142,14 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
public Coordinator(String nodeName, Settings settings, ClusterSettings clusterSettings, TransportService transportService,
NamedWriteableRegistry namedWriteableRegistry, AllocationService allocationService, MasterService masterService,
Supplier<CoordinationState.PersistedState> persistedStateSupplier, UnicastHostsProvider unicastHostsProvider,
ClusterApplier clusterApplier, Random random) {
ClusterApplier clusterApplier, Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators, Random random) {
super(settings);
this.settings = settings;
this.transportService = transportService;
this.masterService = masterService;
this.onJoinValidators = JoinTaskExecutor.addBuiltInJoinValidators(onJoinValidators);
this.joinHelper = new JoinHelper(settings, allocationService, masterService, transportService,
this::getCurrentTerm, this::handleJoinRequest, this::joinLeaderInTerm);
this::getCurrentTerm, this::handleJoinRequest, this::joinLeaderInTerm, this.onJoinValidators);
this.persistedStateSupplier = persistedStateSupplier;
this.discoverySettings = new DiscoverySettings(settings, clusterSettings);
this.lastKnownLeader = Optional.empty();
Expand Down Expand Up @@ -277,6 +281,11 @@ PublishWithJoinResponse handlePublishRequest(PublishRequest publishRequest) {
+ lastKnownLeader + ", rejecting");
}

if (publishRequest.getAcceptedState().term() > coordinationState.get().getLastAcceptedState().term()) {
// only do join validation if we have not accepted state from this master yet
onJoinValidators.forEach(a -> a.accept(getLocalNode(), publishRequest.getAcceptedState()));
}

ensureTermAtLeast(sourceNode, publishRequest.getAcceptedState().term());
final PublishResponse publishResponse = coordinationState.get().handlePublishRequest(publishRequest);

Expand Down Expand Up @@ -389,6 +398,41 @@ private void handleJoinRequest(JoinRequest joinRequest, JoinHelper.JoinCallback
logger.trace("handleJoinRequest: as {}, handling {}", mode, joinRequest);
transportService.connectToNode(joinRequest.getSourceNode());

final ClusterState stateForJoinValidation = getStateForMasterService();

if (stateForJoinValidation.nodes().isLocalNodeElectedMaster()) {
onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation));
if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) {
// we do this in a couple of places including the cluster update thread. This one here is really just best effort
// to ensure we fail as fast as possible.
JoinTaskExecutor.ensureMajorVersionBarrier(joinRequest.getSourceNode().getVersion(),
stateForJoinValidation.getNodes().getMinNodeVersion());
}

// validate the join on the joining node, will throw a failure if it fails the validation
joinHelper.sendValidateJoinRequest(joinRequest.getSourceNode(), stateForJoinValidation, new ActionListener<Empty>() {
@Override
public void onResponse(Empty empty) {
try {
processJoinRequest(joinRequest, joinCallback);
} catch (Exception e) {
joinCallback.onFailure(e);
}
}

@Override
public void onFailure(Exception e) {
logger.warn(() -> new ParameterizedMessage("failed to validate incoming join request from node [{}]",
joinRequest.getSourceNode()), e);
joinCallback.onFailure(new IllegalStateException("failure when sending a validation request to node", e));
}
});
} else {
processJoinRequest(joinRequest, joinCallback);
}
}

private void processJoinRequest(JoinRequest joinRequest, JoinHelper.JoinCallback joinCallback) {
final Optional<Join> optionalJoin = joinRequest.getOptionalJoin();
synchronized (mutex) {
final CoordinationState coordState = coordinationState.get();
Expand Down Expand Up @@ -514,7 +558,7 @@ Mode getMode() {
}

// visible for testing
public DiscoveryNode getLocalNode() {
DiscoveryNode getLocalNode() {
return transportService.getLocalNode();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateTaskConfig;
import org.elasticsearch.cluster.ClusterStateTaskListener;
Expand All @@ -40,15 +41,18 @@
import org.elasticsearch.discovery.zen.ZenDiscovery;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPool.Names;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponse.Empty;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
Expand All @@ -64,6 +68,7 @@ public class JoinHelper {
private static final Logger logger = LogManager.getLogger(JoinHelper.class);

public static final String JOIN_ACTION_NAME = "internal:cluster/coordination/join";
public static final String VALIDATE_JOIN_ACTION_NAME = "internal:cluster/coordination/join/validate";
public static final String START_JOIN_ACTION_NAME = "internal:cluster/coordination/start_join";

// the timeout for each join attempt
Expand All @@ -80,7 +85,8 @@ public class JoinHelper {

public JoinHelper(Settings settings, AllocationService allocationService, MasterService masterService,
TransportService transportService, LongSupplier currentTermSupplier,
BiConsumer<JoinRequest, JoinCallback> joinHandler, Function<StartJoinRequest, Join> joinLeaderInTerm) {
BiConsumer<JoinRequest, JoinCallback> joinHandler, Function<StartJoinRequest, Join> joinLeaderInTerm,
Collection<BiConsumer<DiscoveryNode, ClusterState>> joinValidators) {
this.masterService = masterService;
this.transportService = transportService;
this.joinTimeout = JOIN_TIMEOUT_SETTING.get(settings);
Expand Down Expand Up @@ -123,9 +129,19 @@ public ClusterTasksResult<JoinTaskExecutor.Task> execute(ClusterState currentSta
channel.sendResponse(Empty.INSTANCE);
});

transportService.registerRequestHandler(VALIDATE_JOIN_ACTION_NAME,
MembershipAction.ValidateJoinRequest::new, ThreadPool.Names.GENERIC,
(request, channel, task) -> {
joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), request.getState()));
channel.sendResponse(Empty.INSTANCE);
});

transportService.registerRequestHandler(MembershipAction.DISCOVERY_JOIN_VALIDATE_ACTION_NAME,
() -> new MembershipAction.ValidateJoinRequest(), ThreadPool.Names.GENERIC,
(request, channel, task) -> channel.sendResponse(Empty.INSTANCE)); // TODO: implement join validation
MembershipAction.ValidateJoinRequest::new, ThreadPool.Names.GENERIC,
(request, channel, task) -> {
joinValidators.forEach(action -> action.accept(transportService.getLocalNode(), request.getState()));
channel.sendResponse(Empty.INSTANCE);
});

transportService.registerRequestHandler(
ZenDiscovery.DISCOVERY_REJOIN_ACTION_NAME, ZenDiscovery.RejoinClusterRequest::new, ThreadPool.Names.SAME,
Expand Down Expand Up @@ -244,6 +260,29 @@ public String executor() {
});
}

public void sendValidateJoinRequest(DiscoveryNode node, ClusterState state, ActionListener<TransportResponse.Empty> listener) {
final String actionName;
if (Coordinator.isZen1Node(node)) {
actionName = MembershipAction.DISCOVERY_JOIN_VALIDATE_ACTION_NAME;
} else {
actionName = VALIDATE_JOIN_ACTION_NAME;
}
transportService.sendRequest(node, actionName,
new MembershipAction.ValidateJoinRequest(state),
TransportRequestOptions.builder().withTimeout(joinTimeout).build(),
new EmptyTransportResponseHandler(ThreadPool.Names.GENERIC) {
@Override
public void handleResponse(TransportResponse.Empty response) {
listener.onResponse(response);
}

@Override
public void handleException(TransportException exp) {
listener.onFailure(exp);
}
});
}

public interface JoinCallback {
void onSuccess();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
import org.elasticsearch.cluster.routing.allocation.AllocationService;
import org.elasticsearch.discovery.DiscoverySettings;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.BiConsumer;

import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK;

Expand Down Expand Up @@ -259,4 +263,15 @@ public static void ensureMajorVersionBarrier(Version joiningNodeVersion, Version
"All nodes in the cluster are of a higher major [" + clusterMajor + "].");
}
}

public static Collection<BiConsumer<DiscoveryNode,ClusterState>> addBuiltInJoinValidators(
Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators) {
final Collection<BiConsumer<DiscoveryNode, ClusterState>> validators = new ArrayList<>();
validators.add((node, state) -> {
ensureNodesCompatibility(node.getVersion(), state.getNodes());
ensureIndexCompatibility(node.getVersion(), state.getMetaData());
});
validators.addAll(onJoinValidators);
return Collections.unmodifiableCollection(validators);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,13 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque
in.setVersion(request.version());
// If true we received full cluster state - otherwise diffs
if (in.readBoolean()) {
final ClusterState incomingState = ClusterState.readFrom(in, transportService.getLocalNode());
final ClusterState incomingState;
try {
incomingState = ClusterState.readFrom(in, transportService.getLocalNode());
} catch (Exception e){
logger.warn("unexpected error while deserializing an incoming cluster state", e);
throw e;
}
fullClusterStateReceivedCount.incrementAndGet();
logger.debug("received full cluster state version [{}] with size [{}]", incomingState.version(),
request.bytes().length());
Expand All @@ -400,10 +406,20 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque
final ClusterState lastSeen = lastSeenClusterState.get();
if (lastSeen == null) {
logger.debug("received diff for but don't have any local cluster state - requesting full state");
incompatibleClusterStateDiffReceivedCount.incrementAndGet();
throw new IncompatibleClusterStateVersionException("have no local cluster state");
} else {
Diff<ClusterState> diff = ClusterState.readDiffFrom(in, lastSeen.nodes().getLocalNode());
final ClusterState incomingState = diff.apply(lastSeen); // might throw IncompatibleClusterStateVersionException
final ClusterState incomingState;
try {
Diff<ClusterState> diff = ClusterState.readDiffFrom(in, lastSeen.nodes().getLocalNode());
incomingState = diff.apply(lastSeen); // might throw IncompatibleClusterStateVersionException
} catch (IncompatibleClusterStateVersionException e) {
incompatibleClusterStateDiffReceivedCount.incrementAndGet();
throw e;
} catch (Exception e){
logger.warn("unexpected error while deserializing an incoming cluster state", e);
throw e;
}
compatibleClusterStateDiffReceivedCount.incrementAndGet();
logger.debug("received diff cluster state version [{}] with uuid [{}], diff size [{}]",
incomingState.version(), incomingState.stateUUID(), request.bytes().length());
Expand All @@ -412,12 +428,6 @@ private PublishWithJoinResponse handleIncomingPublishRequest(BytesTransportReque
return response;
}
}
} catch (IncompatibleClusterStateVersionException e) {
incompatibleClusterStateDiffReceivedCount.incrementAndGet();
throw e;
} catch (Exception e) {
logger.warn("unexpected error while deserializing an incoming cluster state", e);
throw e;
} finally {
IOUtils.close(in);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ public DiscoveryModule(Settings settings, ThreadPool threadPool, TransportServic
Map<String, Supplier<Discovery>> discoveryTypes = new HashMap<>();
discoveryTypes.put(ZEN_DISCOVERY_TYPE,
() -> new ZenDiscovery(settings, threadPool, transportService, namedWriteableRegistry, masterService, clusterApplier,
clusterSettings, hostsProvider, allocationService, Collections.unmodifiableCollection(joinValidators), gatewayMetaState));
clusterSettings, hostsProvider, allocationService, joinValidators, gatewayMetaState));
discoveryTypes.put(ZEN2_DISCOVERY_TYPE, () -> new Coordinator(NODE_NAME_SETTING.get(settings), settings, clusterSettings,
transportService, namedWriteableRegistry, allocationService, masterService,
() -> gatewayMetaState.getPersistedState(settings, (ClusterApplierService) clusterApplier), hostsProvider, clusterApplier,
Randomness.get()));
joinValidators, Randomness.get()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We say Collections.unmodifiableCollection(joinValidators) above. Should we also do so here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not necessary because it's wrapped again with addBuiltInJoinValidators. I will remove the unmodifiableCollection call above for uniformity.

discoveryTypes.put("single-node", () -> new SingleNodeDiscovery(settings, transportService, masterService, clusterApplier,
gatewayMetaState));
for (DiscoveryPlugin plugin : plugins) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public static class ValidateJoinRequest extends TransportRequest {

public ValidateJoinRequest() {}

ValidateJoinRequest(ClusterState state) {
public ValidateJoinRequest(ClusterState state) {
this.state = state;
}

Expand All @@ -179,6 +179,10 @@ public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.state.writeTo(out);
}

public ClusterState getState() {
return state;
}
}

static class ValidateJoinRequestRequestHandler implements TransportRequestHandler<ValidateJoinRequest> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Set;
Expand Down Expand Up @@ -163,7 +162,7 @@ public ZenDiscovery(Settings settings, ThreadPool threadPool, TransportService t
ClusterSettings clusterSettings, UnicastHostsProvider hostsProvider, AllocationService allocationService,
Collection<BiConsumer<DiscoveryNode, ClusterState>> onJoinValidators, GatewayMetaState gatewayMetaState) {
super(settings);
this.onJoinValidators = addBuiltInJoinValidators(onJoinValidators);
this.onJoinValidators = JoinTaskExecutor.addBuiltInJoinValidators(onJoinValidators);
this.masterService = masterService;
this.clusterApplier = clusterApplier;
this.transportService = transportService;
Expand Down Expand Up @@ -235,17 +234,6 @@ public ZenDiscovery(Settings settings, ThreadPool threadPool, TransportService t
}
}

static Collection<BiConsumer<DiscoveryNode,ClusterState>> addBuiltInJoinValidators(
Collection<BiConsumer<DiscoveryNode,ClusterState>> onJoinValidators) {
Collection<BiConsumer<DiscoveryNode, ClusterState>> validators = new ArrayList<>();
validators.add((node, state) -> {
JoinTaskExecutor.ensureNodesCompatibility(node.getVersion(), state.getNodes());
JoinTaskExecutor.ensureIndexCompatibility(node.getVersion(), state.getMetaData());
});
validators.addAll(onJoinValidators);
return Collections.unmodifiableCollection(validators);
}

// protected to allow overriding in tests
protected ZenPing newZenPing(Settings settings, ThreadPool threadPool, TransportService transportService,
UnicastHostsProvider hostsProvider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.Collections;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -96,7 +97,7 @@ public void setupTest() {
ESAllocationTestCase.createAllocationService(Settings.EMPTY),
new MasterService("local", Settings.EMPTY, threadPool),
() -> new InMemoryPersistedState(0, ClusterState.builder(new ClusterName("cluster")).build()), r -> emptyList(),
new NoOpClusterApplier(), new Random(random().nextLong()));
new NoOpClusterApplier(), Collections.emptyList(), new Random(random().nextLong()));
}

public void testHandlesNonstandardDiscoveryImplementation() throws InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ protected void onSendRequest(long requestId, String action, TransportRequest req
ESAllocationTestCase.createAllocationService(settings),
new MasterService("local", settings, threadPool),
() -> new InMemoryPersistedState(0, ClusterState.builder(new ClusterName(clusterName)).build()), r -> emptyList(),
new NoOpClusterApplier(), new Random(random().nextLong()));
new NoOpClusterApplier(), Collections.emptyList(), new Random(random().nextLong()));
}

public void testHandlesNonstandardDiscoveryImplementation() throws InterruptedException {
Expand Down
Loading