Browse Source

Add runAfter and notifyOnce wrapper to ActionListener (#37331)

Relates #37291
Nhat Nguyen 6 years ago
parent
commit
360c430ad7

+ 44 - 0
server/src/main/java/org/elasticsearch/action/ActionListener.java

@@ -136,4 +136,48 @@ public interface ActionListener<Response> {
         }
         ExceptionsHelper.maybeThrowRuntimeAndSuppress(exceptionList);
     }
+
+    /**
+     * Wraps a given listener and returns a new listener which executes the provided {@code runAfter}
+     * callback when the listener is notified via either {@code #onResponse} or {@code #onFailure}.
+     */
+    static <Response> ActionListener<Response> runAfter(ActionListener<Response> delegate, Runnable runAfter) {
+        return new ActionListener<Response>() {
+            @Override
+            public void onResponse(Response response) {
+                try {
+                    delegate.onResponse(response);
+                } finally {
+                    runAfter.run();
+                }
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                try {
+                    delegate.onFailure(e);
+                } finally {
+                    runAfter.run();
+                }
+            }
+        };
+    }
+
+    /**
+     * Wraps a given listener and returns a new listener which makes sure {@link #onResponse(Object)}
+     * and {@link #onFailure(Exception)} of the provided listener will be called at most once.
+     */
+    static <Response> ActionListener<Response> notifyOnce(ActionListener<Response> delegate) {
+        return new NotifyOnceListener<Response>() {
+            @Override
+            protected void innerOnResponse(Response response) {
+                delegate.onResponse(response);
+            }
+
+            @Override
+            protected void innerOnFailure(Exception e) {
+                delegate.onFailure(e);
+            }
+        };
+    }
 }

+ 53 - 0
server/src/test/java/org/elasticsearch/action/ActionListenerTests.java

@@ -23,9 +23,12 @@ import org.elasticsearch.test.ESTestCase;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public class ActionListenerTests extends ESTestCase {
 
     public void testWrap() {
@@ -148,4 +151,54 @@ public class ActionListenerTests extends ESTestCase {
             assertEquals("listener index " + i, "booom", excList.get(i).get().getMessage());
         }
     }
+
+    public void testRunAfter() {
+        {
+            AtomicBoolean afterSuccess = new AtomicBoolean();
+            ActionListener<Object> listener = ActionListener.runAfter(ActionListener.wrap(r -> {}, e -> {}), () -> afterSuccess.set(true));
+            listener.onResponse(null);
+            assertThat(afterSuccess.get(), equalTo(true));
+        }
+        {
+            AtomicBoolean afterFailure = new AtomicBoolean();
+            ActionListener<Object> listener = ActionListener.runAfter(ActionListener.wrap(r -> {}, e -> {}), () -> afterFailure.set(true));
+            listener.onFailure(null);
+            assertThat(afterFailure.get(), equalTo(true));
+        }
+    }
+
+    public void testNotifyOnce() {
+        AtomicInteger onResponseTimes = new AtomicInteger();
+        AtomicInteger onFailureTimes = new AtomicInteger();
+        ActionListener<Object> listener = ActionListener.notifyOnce(new ActionListener<Object>() {
+            @Override
+            public void onResponse(Object o) {
+                onResponseTimes.getAndIncrement();
+            }
+            @Override
+            public void onFailure(Exception e) {
+                onFailureTimes.getAndIncrement();
+            }
+        });
+        boolean success = randomBoolean();
+        if (success) {
+            listener.onResponse(null);
+        } else {
+            listener.onFailure(new RuntimeException("test"));
+        }
+        for (int iters = between(0, 10), i = 0; i < iters; i++) {
+            if (randomBoolean()) {
+                listener.onResponse(null);
+            } else {
+                listener.onFailure(new RuntimeException("test"));
+            }
+        }
+        if (success) {
+            assertThat(onResponseTimes.get(), equalTo(1));
+            assertThat(onFailureTimes.get(), equalTo(0));
+        } else {
+            assertThat(onResponseTimes.get(), equalTo(0));
+            assertThat(onFailureTimes.get(), equalTo(1));
+        }
+    }
 }