diff --git a/server/src/main/java/org/elasticsearch/action/ActionListener.java b/server/src/main/java/org/elasticsearch/action/ActionListener.java index 957f46e6116dc..3cb205f9847b3 100644 --- a/server/src/main/java/org/elasticsearch/action/ActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/ActionListener.java @@ -136,14 +136,50 @@ static ActionListener wrap(Runnable runnable) { * Creates a listener that wraps another listener, mapping response values via the given mapping function and passing along * exceptions to the delegate. * - * @param listener Listener to delegate to + * Notice that it is considered a bug if the listener's onResponse or onFailure fails. onResponse failures will not call onFailure. + * + * If the function fails, the listener's onFailure handler will be called. The principle is that the mapped listener will handle + * exceptions from the mapping function {@code fn} but it is the responsibility of {@code delegate} to handle its own exceptions + * inside `onResponse` and `onFailure`. + * + * @param delegate Listener to delegate to * @param fn Function to apply to listener response * @param Response type of the new listener * @param Response type of the wrapped listener * @return a listener that maps the received response and then passes it to its delegate listener */ - static ActionListener map(ActionListener listener, CheckedFunction fn) { - return wrap(r -> listener.onResponse(fn.apply(r)), listener::onFailure); + static ActionListener map(ActionListener delegate, CheckedFunction fn) { + return new ActionListener<>() { + @Override + public void onResponse(Response response) { + T mapped; + try { + mapped = fn.apply(response); + } catch (Exception e) { + onFailure(e); + return; + } + try { + delegate.onResponse(mapped); + } catch (RuntimeException e) { + assert false : new AssertionError("map: listener.onResponse failed", e); + throw e; + } + } + + @Override + public void onFailure(Exception e) { + try { + delegate.onFailure(e); + } catch (RuntimeException ex) { + if (ex != e) { + ex.addSuppressed(e); + } + assert false : new AssertionError("map: listener.onFailure failed", ex); + throw ex; + } + } + }; } /** diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index 822557696f7e5..16e8c17688906 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -55,7 +55,6 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; import java.util.function.BiFunction; @@ -306,27 +305,11 @@ public static void registerRequestHandler(TransportService transportService, Sea (in) -> TransportResponse.Empty.INSTANCE); transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, - (request, channel, task) -> { - searchService.executeDfsPhase(request, (SearchShardTask) task, new ActionListener() { - @Override - public void onResponse(SearchPhaseResult searchPhaseResult) { - try { - channel.sendResponse(searchPhaseResult); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - @Override - public void onFailure(Exception e) { - try { - channel.sendResponse(e); - } catch (IOException e1) { - throw new UncheckedIOException(e1); - } - } - }); - }); + (request, channel, task) -> + searchService.executeDfsPhase(request, (SearchShardTask) task, + new ChannelActionListener<>(channel, DFS_ACTION_NAME, request)) + ); + TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, DfsSearchResult::new); transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, diff --git a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java index 4f9b63fb75e6c..3577577dd9e7a 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java @@ -234,4 +234,53 @@ public void testCompleteWith() { assertThat(onFailureListener.isDone(), equalTo(true)); assertThat(expectThrows(ExecutionException.class, onFailureListener::get).getCause(), instanceOf(IOException.class)); } + + /** + * Test that map passes the output of the function to its delegate listener and that exceptions in the function are propagated to the + * onFailure handler. Also verify that exceptions from ActionListener.onResponse does not invoke onFailure, since it is the + * responsibility of the ActionListener implementation (the client of the API) to handle exceptions in onResponse and onFailure. + */ + public void testMap() { + AtomicReference exReference = new AtomicReference<>(); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(String s) { + if (s == null) { + throw new IllegalArgumentException("simulate onResponse exception"); + } + } + + @Override + public void onFailure(Exception e) { + exReference.set(e); + if (e instanceof IllegalArgumentException) { + throw (IllegalArgumentException) e; + } + } + }; + ActionListener mapped = ActionListener.map(listener, b -> { + if (b == null) { + return null; + } else if (b) { + throw new IllegalStateException("simulate map function exception"); + } else { + return b.toString(); + } + }); + + AssertionError assertionError = expectThrows(AssertionError.class, () -> mapped.onResponse(null)); + assertThat(assertionError.getCause().getCause(), instanceOf(IllegalArgumentException.class)); + assertNull(exReference.get()); + mapped.onResponse(false); + assertNull(exReference.get()); + mapped.onResponse(true); + assertThat(exReference.get(), instanceOf(IllegalStateException.class)); + + assertionError = expectThrows(AssertionError.class, () -> mapped.onFailure(new IllegalArgumentException())); + assertThat(assertionError.getCause().getCause(), instanceOf(IllegalArgumentException.class)); + assertThat(exReference.get(), instanceOf(IllegalArgumentException.class)); + mapped.onFailure(new IllegalStateException()); + assertThat(exReference.get(), instanceOf(IllegalStateException.class)); + } }