فهرست منبع

AbstractSearchAsyncAction shouldn't try to fork the final execution (#69807)

When search requests are throttled, the last pending execution is a noop.
However this noop operation can someone use the search thread pool
(if the previous execution didn't leave the current thread). This change ensures
that we use the thread pool for actual operations. This is mainly needed for unit tests
that assumes that all async operations are done when onNextPhase is called.

Closes #69730
Jim Ferenczi 4 سال پیش
والد
کامیت
093e78dc9d

+ 5 - 6
server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

@@ -719,18 +719,17 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
     protected abstract SearchPhase getNextPhase(SearchPhaseResults<Result> results, SearchPhaseContext context);
 
     private void executeNext(PendingExecutions pendingExecutions, Thread originalThread) {
-        executeNext(pendingExecutions == null ? null : pendingExecutions::finishAndRunNext, originalThread);
+        executeNext(pendingExecutions == null ? null : pendingExecutions.finishAndGetNext(), originalThread);
     }
 
     void executeNext(Runnable runnable, Thread originalThread) {
-        if (throttleConcurrentRequests) {
+        if (runnable != null) {
+            assert throttleConcurrentRequests;
             if (originalThread == Thread.currentThread()) {
                 fork(runnable);
             } else {
                 runnable.run();
             }
-        } else {
-            assert runnable == null;
         }
     }
 
@@ -744,12 +743,12 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
             this.permits = permits;
         }
 
-        void finishAndRunNext() {
+        Runnable finishAndGetNext() {
             synchronized (this) {
                 permitsTaken--;
                 assert permitsTaken >= 0 : "illegal taken permits: " + permitsTaken;
             }
-            tryRun(null);
+            return tryQueue(null);
         }
 
         void tryRun(Runnable runnable) {

+ 16 - 38
server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java

@@ -77,7 +77,7 @@ public class SearchAsyncActionTests extends ESTestCase {
                 numSkipped++;
             }
         }
-        CountDownLatch latch = new CountDownLatch(numShards - numSkipped);
+        CountDownLatch latch = new CountDownLatch(1);
         AtomicBoolean searchPhaseDidRun = new AtomicBoolean(false);
 
         SearchTransportService transportService = new SearchTransportService(null, null, null);
@@ -132,15 +132,10 @@ public class SearchAsyncActionTests extends ESTestCase {
                         @Override
                         public void run() {
                             assertTrue(searchPhaseDidRun.compareAndSet(false, true));
+                            latch.countDown();
                         }
                     };
                 }
-
-                @Override
-                protected void executeNext(Runnable runnable, Thread originalThread) {
-                    super.executeNext(runnable, originalThread);
-                    latch.countDown();
-                }
             };
         asyncAction.start();
         latch.await();
@@ -159,7 +154,6 @@ public class SearchAsyncActionTests extends ESTestCase {
         request.setMaxConcurrentShardRequests(numConcurrent);
         boolean doReplicas = randomBoolean();
         int numShards = randomIntBetween(5, 10);
-        int numShardAttempts = numShards;
         Boolean[] shardFailures = new Boolean[numShards];
         // at least one response otherwise the entire request fails
         shardFailures[randomIntBetween(0, shardFailures.length - 1)] = false;
@@ -167,12 +161,9 @@ public class SearchAsyncActionTests extends ESTestCase {
             if (shardFailures[i] == null) {
                 boolean failure = randomBoolean();
                 shardFailures[i] = failure;
-                if (failure && doReplicas) {
-                    numShardAttempts++;
-                }
             }
         }
-        CountDownLatch latch = new CountDownLatch(numShardAttempts);
+        CountDownLatch latch = new CountDownLatch(1);
         AtomicBoolean searchPhaseDidRun = new AtomicBoolean(false);
         ActionListener<SearchResponse> responseListener = ActionListener.wrap(response -> {},
             (e) -> { throw new AssertionError("unexpected", e);});
@@ -244,15 +235,10 @@ public class SearchAsyncActionTests extends ESTestCase {
                         @Override
                         public void run() {
                             assertTrue(searchPhaseDidRun.compareAndSet(false, true));
+                            latch.countDown();
                         }
                     };
                 }
-
-                @Override
-                protected void executeNext(Runnable runnable, Thread originalThread) {
-                    super.executeNext(runnable, originalThread);
-                    latch.countDown();
-                }
             };
         asyncAction.start();
         assertEquals(numConcurrent, numRequests.get());
@@ -293,7 +279,8 @@ public class SearchAsyncActionTests extends ESTestCase {
         lookup.put(replicaNode.getId(), new MockConnection(replicaNode));
         Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
         ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
-        final CountDownLatch latch = new CountDownLatch(numShards);
+        final CountDownLatch latch = new CountDownLatch(1);
+        final AtomicBoolean latchTriggered = new AtomicBoolean();
         AbstractSearchAsyncAction<TestSearchPhaseResult> asyncAction =
                 new AbstractSearchAsyncAction<TestSearchPhaseResult>(
                         "test",
@@ -343,15 +330,13 @@ public class SearchAsyncActionTests extends ESTestCase {
                             sendReleaseSearchContext(result.getContextId(), new MockConnection(result.node), OriginalIndices.NONE);
                         }
                         responseListener.onResponse(response);
+                        if (latchTriggered.compareAndSet(false, true) == false) {
+                            throw new AssertionError("latch triggered twice");
+                        }
+                        latch.countDown();
                     }
                 };
             }
-
-            @Override
-            protected void executeNext(Runnable runnable, Thread originalThread) {
-                super.executeNext(runnable, originalThread);
-                latch.countDown();
-            }
         };
         asyncAction.start();
         latch.await();
@@ -364,7 +349,8 @@ public class SearchAsyncActionTests extends ESTestCase {
         } else {
             assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty());
         }
-        executor.shutdown();
+        final List<Runnable> runnables = executor.shutdownNow();
+        assertThat(runnables, equalTo(Collections.emptyList()));
     }
 
     public void testFanOutAndFail() throws InterruptedException {
@@ -469,7 +455,8 @@ public class SearchAsyncActionTests extends ESTestCase {
         } else {
             assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty());
         }
-        executor.shutdown();
+        final List<Runnable> runnables = executor.shutdownNow();
+        assertThat(runnables, equalTo(Collections.emptyList()));
     }
 
     public void testAllowPartialResults() throws InterruptedException {
@@ -489,11 +476,7 @@ public class SearchAsyncActionTests extends ESTestCase {
         GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
             new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
             numShards, true, primaryNode, replicaNode);
-        int numShardAttempts = 0;
-        for (SearchShardIterator it : shardsIter) {
-            numShardAttempts += it.remaining();
-        }
-        CountDownLatch latch = new CountDownLatch(numShardAttempts);
+        CountDownLatch latch = new CountDownLatch(1);
 
         SearchTransportService transportService = new SearchTransportService(null, null, null);
         Map<String, Transport.Connection> lookup = new HashMap<>();
@@ -550,15 +533,10 @@ public class SearchAsyncActionTests extends ESTestCase {
                         @Override
                         public void run() {
                             assertTrue(searchPhaseDidRun.compareAndSet(false, true));
+                            latch.countDown();
                         }
                     };
                 }
-
-                @Override
-                protected void executeNext(Runnable runnable, Thread originalThread) {
-                    super.executeNext(runnable, originalThread);
-                    latch.countDown();
-                }
             };
         asyncAction.start();
         latch.await();