Skip to content

Commit

Permalink
Stop processing search requests when _msearch is canceled
Browse files Browse the repository at this point in the history
Prior to this fix, the _msearch API would keep running search requests
even after being canceled. With this change, we explicitly check if
the task has been canceled before kicking off subsequent requests.

Signed-off-by: Michael Froh <[email protected]>
  • Loading branch information
msfroh committed Jan 10, 2025
1 parent 5afb92f commit b88b6d5
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.tasks.TaskId;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -193,7 +195,7 @@ private void handleResponse(final int responseSlot, final MultiSearchResponse.It
if (responseCounter.decrementAndGet() == 0) {
assert requests.isEmpty();
finish();
} else {
} else if (isCancelled(request.request.getParentTask()) == false) {
if (thread == Thread.currentThread()) {
// we are on the same thread, we need to fork to another thread to avoid recursive stack overflow on a single thread
threadPool.generic()
Expand All @@ -220,6 +222,14 @@ private long buildTookInMillis() {
});
}

private boolean isCancelled(TaskId taskId) {
if (taskId.isSet()) {
CancellableTask task = taskManager.getCancellableTask(taskId.getId());
return task != null && task.isCancelled();
}
return false;
}

/**
* Slots a search request
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskListener;
import org.opensearch.tasks.TaskManager;
import org.opensearch.telemetry.tracing.noop.NoopTracer;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -289,4 +291,111 @@ public void testDefaultMaxConcurrentSearches() {
assertThat(result, equalTo(1));
}

public void testCancellation() {
// Initialize dependencies of TransportMultiSearchAction
Settings settings = Settings.builder().put("node.name", TransportMultiSearchActionTests.class.getSimpleName()).build();
ActionFilters actionFilters = mock(ActionFilters.class);
when(actionFilters.filters()).thenReturn(new ActionFilter[0]);
ThreadPool threadPool = new ThreadPool(settings);
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()),
null,
Collections.emptySet(),
NoopTracer.INSTANCE
) {
@Override
public TaskManager getTaskManager() {
return taskManager;
}
};
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("test")).build());

// Keep track of the number of concurrent searches started by multi search api,
// and if there are more searches than is allowed create an error and remember that.
int maxAllowedConcurrentSearches = 1; // Allow 1 search at a time.
AtomicInteger counter = new AtomicInteger();
AtomicReference<AssertionError> errorHolder = new AtomicReference<>();
// randomize whether or not requests are executed asynchronously
ExecutorService executorService = threadPool.executor(ThreadPool.Names.GENERIC);
final Set<SearchRequest> requests = Collections.newSetFromMap(Collections.synchronizedMap(new IdentityHashMap<>()));
NodeClient client = new NodeClient(settings, threadPool) {
@Override
public void search(final SearchRequest request, final ActionListener<SearchResponse> listener) {
try {
Thread.sleep(10);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
requests.add(request);
executorService.execute(() -> {
counter.decrementAndGet();
listener.onResponse(
new SearchResponse(
InternalSearchResponse.empty(),
null,
0,
0,
0,
0L,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
)
);
});
}

@Override
public String getLocalNodeId() {
return "local_node_id";
}
};

TransportMultiSearchAction action = new TransportMultiSearchAction(
threadPool,
actionFilters,
transportService,
clusterService,
10,
System::nanoTime,
client
);

// Execute the multi search api and fail if we find an error after executing:
try {
/*
* Allow for a large number of search requests in a single batch as previous implementations could stack overflow if the number
* of requests in a single batch was large
*/
int numSearchRequests = scaledRandomIntBetween(1024, 8192);
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
multiSearchRequest.maxConcurrentSearchRequests(maxAllowedConcurrentSearches);
for (int i = 0; i < numSearchRequests; i++) {
multiSearchRequest.add(new SearchRequest());
}
MultiSearchResponse[] responses = new MultiSearchResponse[1];
CancellableTask parentTask = (CancellableTask) action.execute(multiSearchRequest, new TaskListener<MultiSearchResponse>() {
@Override
public void onResponse(Task task, MultiSearchResponse items) {
responses[0] = items;
System.out.println("Got response: " + items);
}

@Override
public void onFailure(Task task, Exception e) {
e.printStackTrace();
}
});
parentTask.cancel("Giving up");

MultiSearchResponse response = responses[0];
assertNull(response);
} finally {
assertTrue(OpenSearchTestCase.terminate(threadPool));
}
}
}

0 comments on commit b88b6d5

Please sign in to comment.