Browse Source

Fix ActionListener.map exception handling (#50886)

ActionListener.map would call listener.onFailure for exceptions from
listener.onResponse, but this means we could double trigger some
listeners which is generally unexpected. Instead, we should assume that
a listener's onResponse (and onFailure) implementation is responsible
for its own exception handling.
Henning Andersen 5 years ago
parent
commit
73b556484b

+ 39 - 3
server/src/main/java/org/elasticsearch/action/ActionListener.java

@@ -136,14 +136,50 @@ public interface ActionListener<Response> {
      * Creates a listener that wraps another listener, mapping response values via the given mapping function and passing along
      * exceptions to the delegate.
      *
-     * @param listener Listener to delegate to
+     * Notice that it is considered a bug if the listener's onResponse or onFailure fails. onResponse failures will not call onFailure.
+     *
+     * If the function fails, the listener's onFailure handler will be called. The principle is that the mapped listener will handle
+     * exceptions from the mapping function {@code fn} but it is the responsibility of {@code delegate} to handle its own exceptions
+     * inside `onResponse` and `onFailure`.
+     *
+     * @param delegate Listener to delegate to
      * @param fn Function to apply to listener response
      * @param <Response> Response type of the new listener
      * @param <T> Response type of the wrapped listener
      * @return a listener that maps the received response and then passes it to its delegate listener
      */
-    static <T, Response> ActionListener<Response> map(ActionListener<T> listener, CheckedFunction<Response, T, Exception> fn) {
-        return wrap(r -> listener.onResponse(fn.apply(r)), listener::onFailure);
+    static <T, Response> ActionListener<Response> map(ActionListener<T> delegate, CheckedFunction<Response, T, Exception> fn) {
+        return new ActionListener<>() {
+            @Override
+            public void onResponse(Response response) {
+                T mapped;
+                try {
+                    mapped = fn.apply(response);
+                } catch (Exception e) {
+                    onFailure(e);
+                    return;
+                }
+                try {
+                    delegate.onResponse(mapped);
+                } catch (RuntimeException e) {
+                    assert false : new AssertionError("map: listener.onResponse failed", e);
+                    throw e;
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                try {
+                    delegate.onFailure(e);
+                } catch (RuntimeException ex) {
+                    if (ex != e) {
+                        ex.addSuppressed(e);
+                    }
+                    assert false : new AssertionError("map: listener.onFailure failed", ex);
+                    throw ex;
+                }
+            }
+        };
     }
 
     /**

+ 5 - 22
server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

@@ -55,7 +55,6 @@ import org.elasticsearch.transport.TransportResponse;
 import org.elasticsearch.transport.TransportService;
 
 import java.io.IOException;
-import java.io.UncheckedIOException;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.function.BiFunction;
@@ -306,27 +305,11 @@ public class SearchTransportService {
             (in) -> TransportResponse.Empty.INSTANCE);
 
         transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,
-            (request, channel, task) -> {
-                searchService.executeDfsPhase(request, (SearchShardTask) task, new ActionListener<SearchPhaseResult>() {
-                    @Override
-                    public void onResponse(SearchPhaseResult searchPhaseResult) {
-                        try {
-                            channel.sendResponse(searchPhaseResult);
-                        } catch (IOException e) {
-                            throw new UncheckedIOException(e);
-                        }
-                    }
-
-                    @Override
-                    public void onFailure(Exception e) {
-                        try {
-                            channel.sendResponse(e);
-                        } catch (IOException e1) {
-                            throw new UncheckedIOException(e1);
-                        }
-                    }
-                });
-            });
+            (request, channel, task) ->
+                searchService.executeDfsPhase(request, (SearchShardTask) task,
+                    new ChannelActionListener<>(channel, DFS_ACTION_NAME, request))
+        );
+
         TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, DfsSearchResult::new);
 
         transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,

+ 49 - 0
server/src/test/java/org/elasticsearch/action/ActionListenerTests.java

@@ -234,4 +234,53 @@ public class ActionListenerTests extends ESTestCase {
         assertThat(onFailureListener.isDone(), equalTo(true));
         assertThat(expectThrows(ExecutionException.class, onFailureListener::get).getCause(), instanceOf(IOException.class));
     }
+
+    /**
+     * Test that map passes the output of the function to its delegate listener and that exceptions in the function are propagated to the
+     * onFailure handler. Also verify that exceptions from ActionListener.onResponse does not invoke onFailure, since it is the
+     * responsibility of the ActionListener implementation (the client of the API) to handle exceptions in onResponse and onFailure.
+     */
+    public void testMap() {
+        AtomicReference<Exception> exReference = new AtomicReference<>();
+
+        ActionListener<String> listener = new ActionListener<>() {
+            @Override
+            public void onResponse(String s) {
+                if (s == null) {
+                    throw new IllegalArgumentException("simulate onResponse exception");
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                exReference.set(e);
+                if (e instanceof IllegalArgumentException) {
+                    throw (IllegalArgumentException) e;
+                }
+            }
+        };
+        ActionListener<Boolean> mapped = ActionListener.map(listener, b -> {
+            if (b == null) {
+                return null;
+            } else if (b) {
+                throw new IllegalStateException("simulate map function exception");
+            } else {
+                return b.toString();
+            }
+        });
+
+        AssertionError assertionError = expectThrows(AssertionError.class, () -> mapped.onResponse(null));
+        assertThat(assertionError.getCause().getCause(), instanceOf(IllegalArgumentException.class));
+        assertNull(exReference.get());
+        mapped.onResponse(false);
+        assertNull(exReference.get());
+        mapped.onResponse(true);
+        assertThat(exReference.get(), instanceOf(IllegalStateException.class));
+
+        assertionError = expectThrows(AssertionError.class, () -> mapped.onFailure(new IllegalArgumentException()));
+        assertThat(assertionError.getCause().getCause(), instanceOf(IllegalArgumentException.class));
+        assertThat(exReference.get(), instanceOf(IllegalArgumentException.class));
+        mapped.onFailure(new IllegalStateException());
+        assertThat(exReference.get(), instanceOf(IllegalStateException.class));
+    }
 }