Skip to content

Commit

Permalink
[ML][Data Frame] only complete task after state persistence (#43230)
Browse files Browse the repository at this point in the history
* [ML][Data Frame] only complete task after state persistence

There is a race condition where the task could be marked as complete while there
is still a pending document write. This change moves
the task cancellation into the ActionListener of the state persistence.

intermediate commit

intermediate commit

* removing unused import

* removing unused const

* refreshing internal index after waiting for task to complete

* adjusting test data generation
  • Loading branch information
benwtrent authored Jun 17, 2019
1 parent 37e4008 commit 551353d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@

abstract class DataFrameIntegTestCase extends ESRestTestCase {

protected static final String REVIEWS_INDEX_NAME = "data_frame_reviews";

private Map<String, DataFrameTransformConfig> transformConfigs = new HashMap<>();

protected void cleanUp() throws IOException {
Expand Down Expand Up @@ -213,8 +211,7 @@ protected DataFrameTransformConfig createTransformConfig(String id,
.build();
}

protected void createReviewsIndex() throws Exception {
final int numDocs = 1000;
protected void createReviewsIndex(String indexName, int numDocs) throws Exception {
RestHighLevelClient restClient = new TestRestHighLevelClient();

// create mapping
Expand All @@ -241,12 +238,12 @@ protected void createReviewsIndex() throws Exception {
}
builder.endObject();
CreateIndexResponse response =
restClient.indices().create(new CreateIndexRequest(REVIEWS_INDEX_NAME).mapping(builder), RequestOptions.DEFAULT);
restClient.indices().create(new CreateIndexRequest(indexName).mapping(builder), RequestOptions.DEFAULT);
assertThat(response.isAcknowledged(), is(true));
}

// create index
BulkRequest bulk = new BulkRequest(REVIEWS_INDEX_NAME);
BulkRequest bulk = new BulkRequest(indexName);
int day = 10;
for (int i = 0; i < numDocs; i++) {
long user = i % 28;
Expand All @@ -256,7 +253,7 @@ protected void createReviewsIndex() throws Exception {
int min = 10 + (i % 49);
int sec = 10 + (i % 49);

String date_string = "2017-01-" + day + "T" + hour + ":" + min + ":" + sec + "Z";
String date_string = "2017-01-" + (day < 10 ? "0" + day : day) + "T" + hour + ":" + min + ":" + sec + "Z";

StringBuilder sourceBuilder = new StringBuilder();
sourceBuilder.append("{\"user_id\":\"")
Expand All @@ -277,13 +274,13 @@ protected void createReviewsIndex() throws Exception {
if (i % 50 == 0) {
BulkResponse response = restClient.bulk(bulk, RequestOptions.DEFAULT);
assertThat(response.buildFailureMessage(), response.hasFailures(), is(false));
bulk = new BulkRequest(REVIEWS_INDEX_NAME);
day += 1;
bulk = new BulkRequest(indexName);
day = (day + 1) % 28;
}
}
BulkResponse response = restClient.bulk(bulk, RequestOptions.DEFAULT);
assertThat(response.buildFailureMessage(), response.hasFailures(), is(false));
restClient.indices().refresh(new RefreshRequest(REVIEWS_INDEX_NAME), RequestOptions.DEFAULT);
restClient.indices().refresh(new RefreshRequest(indexName), RequestOptions.DEFAULT);
}

protected Map<String, Object> toLazy(ToXContent parsedObject) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public void cleanTransforms() throws IOException {
}

public void testDataFrameTransformCrud() throws Exception {
createReviewsIndex();
String indexName = "basic-crud-reviews";
createReviewsIndex(indexName, 100);

Map<String, SingleGroupSource> groups = new HashMap<>();
groups.put("by-day", createDateHistogramGroupSourceWithCalendarInterval("timestamp", DateHistogramInterval.DAY, null, null));
Expand All @@ -45,7 +46,7 @@ public void testDataFrameTransformCrud() throws Exception {
groups,
aggs,
"reviews-by-user-business-day",
REVIEWS_INDEX_NAME);
indexName);

assertTrue(putDataFrameTransform(config, RequestOptions.DEFAULT).isAcknowledged());
assertTrue(startDataFrameTransform(config.getId(), RequestOptions.DEFAULT).isAcknowledged());
Expand All @@ -56,7 +57,8 @@ public void testDataFrameTransformCrud() throws Exception {
assertBusy(() ->
assertThat(getDataFrameTransformStats(config.getId()).getTransformsStateAndStats().get(0).getTransformState().getIndexerState(),
equalTo(IndexerState.STOPPED)));
stopDataFrameTransform(config.getId());
deleteDataFrameTransform(config.getId());
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
*/
package org.elasticsearch.xpack.dataframe.action;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
Expand All @@ -26,6 +29,7 @@
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.dataframe.action.StopDataFrameTransformAction;
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformTaskState;
import org.elasticsearch.xpack.dataframe.persistence.DataFrameInternalIndex;
import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager;
import org.elasticsearch.xpack.dataframe.transforms.DataFrameTransformTask;

Expand All @@ -38,20 +42,25 @@ public class TransportStopDataFrameTransformAction extends
TransportTasksAction<DataFrameTransformTask, StopDataFrameTransformAction.Request,
StopDataFrameTransformAction.Response, StopDataFrameTransformAction.Response> {

private static final Logger logger = LogManager.getLogger(TransportStopDataFrameTransformAction.class);

private final ThreadPool threadPool;
private final DataFrameTransformsConfigManager dataFrameTransformsConfigManager;
private final PersistentTasksService persistentTasksService;
private final Client client;

@Inject
public TransportStopDataFrameTransformAction(TransportService transportService, ActionFilters actionFilters,
ClusterService clusterService, ThreadPool threadPool,
PersistentTasksService persistentTasksService,
DataFrameTransformsConfigManager dataFrameTransformsConfigManager) {
DataFrameTransformsConfigManager dataFrameTransformsConfigManager,
Client client) {
super(StopDataFrameTransformAction.NAME, clusterService, transportService, actionFilters, StopDataFrameTransformAction.Request::new,
StopDataFrameTransformAction.Response::new, StopDataFrameTransformAction.Response::new, ThreadPool.Names.SAME);
this.threadPool = threadPool;
this.dataFrameTransformsConfigManager = dataFrameTransformsConfigManager;
this.persistentTasksService = persistentTasksService;
this.client = client;
}

@Override
Expand Down Expand Up @@ -132,12 +141,26 @@ protected StopDataFrameTransformAction.Response newResponse(StopDataFrameTransfo
waitForStopListener(StopDataFrameTransformAction.Request request,
ActionListener<StopDataFrameTransformAction.Response> listener) {

ActionListener<StopDataFrameTransformAction.Response> onStopListener = ActionListener.wrap(
waitResponse ->
client.admin()
.indices()
.prepareRefresh(DataFrameInternalIndex.INDEX_NAME)
.execute(ActionListener.wrap(
r -> listener.onResponse(waitResponse),
e -> {
logger.info("Failed to refresh internal index after delete", e);
listener.onResponse(waitResponse);
})
),
listener::onFailure
);
return ActionListener.wrap(
response -> {
// Wait until the persistent task is stopped
// Switch over to Generic threadpool so we don't block the network thread
threadPool.generic().execute(() ->
waitForDataFrameStopped(request.getExpandedIds(), request.getTimeout(), listener));
waitForDataFrameStopped(request.getExpandedIds(), request.getTimeout(), onStopListener));
},
listener::onFailure
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public synchronized void stop() {

IndexerState state = getIndexer().stop();
if (state == IndexerState.STOPPED) {
//doSaveState calls `onStop` when the task state is `STOPPED`
getIndexer().onStop();
getIndexer().doSaveState(state, getIndexer().getPosition(), () -> {});
}
}
Expand Down Expand Up @@ -610,7 +610,7 @@ protected void doSaveState(IndexerState indexerState, Map<String, Object> positi
r -> {
// for auto stop shutdown the task
if (state.getTaskState().equals(DataFrameTransformTaskState.STOPPED)) {
onStop();
transformTask.shutdown();
}
next.run();
},
Expand All @@ -620,7 +620,7 @@ protected void doSaveState(IndexerState indexerState, Map<String, Object> positi
"Failure updating stats of transform: " + statsExc.getMessage());
// for auto stop shutdown the task
if (state.getTaskState().equals(DataFrameTransformTaskState.STOPPED)) {
onStop();
transformTask.shutdown();
}
next.run();
}
Expand Down Expand Up @@ -666,7 +666,6 @@ protected void onFinish(ActionListener<Void> listener) {
protected void onStop() {
auditor.info(transformConfig.getId(), "Data frame transform has stopped.");
logger.info("Data frame transform [{}] indexer has stopped", transformConfig.getId());
transformTask.shutdown();
}

@Override
Expand Down

0 comments on commit 551353d

Please sign in to comment.