浏览代码

Fix concurrency issues in testCancelFailedSearchWhenPartialResultDisallowed (#99689)

Fix concurrency issues in the test testCancelFailedSearchWhenPartialResultDisallowed while additionally ensuring that cancellation propagates to the other shards.
John Verwolf 2 年之前
父节点
当前提交
7246624c6c

+ 58 - 44
server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java

@@ -16,7 +16,9 @@ import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchPhaseExecutionException;
 import org.elasticsearch.action.search.SearchPhaseExecutionException;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.search.SearchScrollAction;
 import org.elasticsearch.action.search.SearchScrollAction;
+import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.action.search.SearchTask;
 import org.elasticsearch.action.search.SearchTask;
+import org.elasticsearch.action.search.SearchTransportService;
 import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.Strings;
@@ -25,6 +27,7 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.script.ScriptType;
 import org.elasticsearch.script.ScriptType;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
 import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
+import org.elasticsearch.search.internal.ReaderContext;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.test.AbstractSearchCancellationTestCase;
 import org.elasticsearch.test.AbstractSearchCancellationTestCase;
@@ -221,43 +224,12 @@ public class SearchCancellationIT extends AbstractSearchCancellationTestCase {
         }
         }
     }
     }
 
 
-    /**
-     * The test `testCancelFailedSearchWhenPartialResultDisallowed` usually fails when concurrency is enabled unless
-     * the `cancelledLatch.await()` section is commented out. However, this approach seems prone to race conditions.
-     * Further investigation is needed to determine if this test just needs to be revised, or rather, if it is
-     * detecting a deeper issue.  For now, we will disable concurrency here.
-     */
-    @Override
-    protected boolean enableConcurrentSearch() {
-        return false;
-    }
-
     public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception {
     public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception {
-        final List<ScriptedBlockPlugin> plugins = initBlockFactory();
         int numberOfShards = between(2, 5);
         int numberOfShards = between(2, 5);
-        AtomicBoolean failed = new AtomicBoolean();
-        CountDownLatch queryLatch = new CountDownLatch(1);
-        CountDownLatch cancelledLatch = new CountDownLatch(1);
-        for (ScriptedBlockPlugin plugin : plugins) {
-            plugin.disableBlock();
-            plugin.setBeforeExecution(() -> {
-                try {
-                    queryLatch.await(); // block the query until we get a search task
-                } catch (InterruptedException e) {
-                    throw new AssertionError(e);
-                }
-                if (failed.compareAndSet(false, true)) {
-                    throw new IllegalStateException("simulated");
-                }
-                try {
-                    cancelledLatch.await(); // block the query until the search is cancelled
-                } catch (InterruptedException e) {
-                    throw new AssertionError(e);
-                }
-            });
-        }
         createIndex("test", numberOfShards, 0);
         createIndex("test", numberOfShards, 0);
         indexTestData();
         indexTestData();
+
+        // Define (but don't run) the search request, expecting a partial shard failure. We will run it later.
         Thread searchThread = new Thread(() -> {
         Thread searchThread = new Thread(() -> {
             SearchPhaseExecutionException e = expectThrows(
             SearchPhaseExecutionException e = expectThrows(
                 SearchPhaseExecutionException.class,
                 SearchPhaseExecutionException.class,
@@ -270,29 +242,59 @@ public class SearchCancellationIT extends AbstractSearchCancellationTestCase {
             );
             );
             assertThat(e.getMessage(), containsString("Partial shards failure"));
             assertThat(e.getMessage(), containsString("Partial shards failure"));
         });
         });
+
+        // When the search request executes, block all shards except 1.
+        final List<SearchShardBlockingPlugin> searchShardBlockingPlugins = initSearchShardBlockingPlugin();
+        AtomicBoolean letOneShardProceed = new AtomicBoolean();
+        CountDownLatch shardTaskLatch = new CountDownLatch(1);
+        for (SearchShardBlockingPlugin plugin : searchShardBlockingPlugins) {
+            plugin.setRunOnNewReaderContext((ReaderContext c) -> {
+                if (letOneShardProceed.compareAndSet(false, true)) {
+                    // Let one shard continue.
+                } else {
+                    try {
+                        shardTaskLatch.await(); // Block the other shards.
+                    } catch (InterruptedException e) {
+                        throw new AssertionError(e);
+                    }
+                }
+            });
+        }
+
+        // For the shard that was allowed to proceed, have a single query-execution thread throw an exception.
+        final List<ScriptedBlockPlugin> plugins = initBlockFactory();
+        AtomicBoolean oneThreadWillError = new AtomicBoolean();
+        for (ScriptedBlockPlugin plugin : plugins) {
+            plugin.disableBlock();
+            plugin.setBeforeExecution(() -> {
+                if (oneThreadWillError.compareAndSet(false, true)) {
+                    throw new IllegalStateException("This will cancel the ContextIndexSearcher.search task");
+                }
+            });
+        }
+
+        // Now run the search request.
         searchThread.start();
         searchThread.start();
+
         try {
         try {
-            assertBusy(() -> assertThat(getSearchTasks(), hasSize(1)));
-            queryLatch.countDown();
             assertBusy(() -> {
             assertBusy(() -> {
-                final List<SearchTask> searchTasks = getSearchTasks();
-                // The search request can complete before the "cancelledLatch" is latched if the second shard request is sent
-                // after the request was cancelled (i.e., the child task is not allowed to start after the parent was cancelled).
-                if (searchTasks.isEmpty() == false) {
-                    assertThat(searchTasks, hasSize(1));
-                    assertTrue(searchTasks.get(0).isCancelled());
+                final List<SearchTask> coordinatorSearchTask = getCoordinatorSearchTasks();
+                assertThat("The Coordinator should have one SearchTask.", coordinatorSearchTask, hasSize(1));
+                assertTrue("The SearchTask should be cancelled.", coordinatorSearchTask.get(0).isCancelled());
+                for (var shardQueryTask : getShardQueryTasks()) {
+                    assertTrue("All SearchShardTasks should then be cancelled", shardQueryTask.isCancelled());
                 }
                 }
             }, 30, TimeUnit.SECONDS);
             }, 30, TimeUnit.SECONDS);
+            shardTaskLatch.countDown(); // unblock the shardTasks, allowing the test to conclude.
         } finally {
         } finally {
+            searchThread.join();
             for (ScriptedBlockPlugin plugin : plugins) {
             for (ScriptedBlockPlugin plugin : plugins) {
                 plugin.setBeforeExecution(() -> {});
                 plugin.setBeforeExecution(() -> {});
             }
             }
-            cancelledLatch.countDown();
-            searchThread.join();
         }
         }
     }
     }
 
 
-    List<SearchTask> getSearchTasks() {
+    List<SearchTask> getCoordinatorSearchTasks() {
         List<SearchTask> tasks = new ArrayList<>();
         List<SearchTask> tasks = new ArrayList<>();
         for (String nodeName : internalCluster().getNodeNames()) {
         for (String nodeName : internalCluster().getNodeNames()) {
             TransportService transportService = internalCluster().getInstance(TransportService.class, nodeName);
             TransportService transportService = internalCluster().getInstance(TransportService.class, nodeName);
@@ -305,4 +307,16 @@ public class SearchCancellationIT extends AbstractSearchCancellationTestCase {
         return tasks;
         return tasks;
     }
     }
 
 
+    /**
+     * Collects, across all nodes of the internal cluster, the cancellable tasks whose action equals
+     * {@link SearchTransportService#QUERY_ACTION_NAME}, each cast to {@link SearchShardTask}.
+     *
+     * @return the matching shard-level query tasks; empty if none are currently registered
+     */
+    List<SearchShardTask> getShardQueryTasks() {
+        final List<SearchShardTask> shardQueryTasks = new ArrayList<>();
+        for (String node : internalCluster().getNodeNames()) {
+            final TransportService nodeTransport = internalCluster().getInstance(TransportService.class, node);
+            for (Task candidate : nodeTransport.getTaskManager().getCancellableTasks().values()) {
+                // Skip everything that is not a shard-level query task.
+                if (candidate.getAction().equals(SearchTransportService.QUERY_ACTION_NAME) == false) {
+                    continue;
+                }
+                shardQueryTasks.add((SearchShardTask) candidate);
+            }
+        }
+        return shardQueryTasks;
+    }
 }
 }

+ 32 - 2
test/framework/src/main/java/org/elasticsearch/test/AbstractSearchCancellationTestCase.java

@@ -19,23 +19,26 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexModule;
+import org.elasticsearch.index.shard.SearchOperationListener;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.PluginsService;
 import org.elasticsearch.plugins.PluginsService;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.script.MockScriptPlugin;
 import org.elasticsearch.script.MockScriptPlugin;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.SearchService;
+import org.elasticsearch.search.internal.ReaderContext;
 import org.elasticsearch.search.lookup.LeafStoredFieldsLookup;
 import org.elasticsearch.search.lookup.LeafStoredFieldsLookup;
 import org.elasticsearch.tasks.TaskInfo;
 import org.elasticsearch.tasks.TaskInfo;
 import org.junit.BeforeClass;
 import org.junit.BeforeClass;
 
 
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Function;
 
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
@@ -54,7 +57,7 @@ public class AbstractSearchCancellationTestCase extends ESIntegTestCase {
 
 
     @Override
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins() {
     protected Collection<Class<? extends Plugin>> nodePlugins() {
-        return Collections.singleton(ScriptedBlockPlugin.class);
+        return List.of(ScriptedBlockPlugin.class, SearchShardBlockingPlugin.class);
     }
     }
 
 
     @Override
     @Override
@@ -256,4 +259,31 @@ public class AbstractSearchCancellationTestCase extends ESIntegTestCase {
             return 1;
             return 1;
         }
         }
     }
     }
+
+    /**
+     * Gathers the {@link SearchShardBlockingPlugin} instances held by every {@link PluginsService}
+     * in the internal cluster.
+     *
+     * @return all blocking-plugin instances found, in the order the plugin services are iterated
+     */
+    protected List<SearchShardBlockingPlugin> initSearchShardBlockingPlugin() {
+        final List<SearchShardBlockingPlugin> blockingPlugins = new ArrayList<>();
+        internalCluster().getInstances(PluginsService.class)
+            .forEach(service -> blockingPlugins.addAll(service.filterPlugins(SearchShardBlockingPlugin.class)));
+        return blockingPlugins;
+    }
+
+    /**
+     * Test plugin that registers a {@link SearchOperationListener} on each index module and forwards
+     * every {@code onNewReaderContext} notification to a configurable callback.
+     * Until a callback is installed through {@link #setRunOnNewReaderContext}, the listener does nothing.
+     */
+    public static class SearchShardBlockingPlugin extends Plugin {
+        private final AtomicReference<Consumer<ReaderContext>> runOnNewReaderContext = new AtomicReference<>();
+
+        /**
+         * Installs the callback to run for each new reader context.
+         *
+         * @param consumer the callback to invoke; {@code null} removes any previously installed callback
+         */
+        public void setRunOnNewReaderContext(Consumer<ReaderContext> consumer) {
+            runOnNewReaderContext.set(consumer);
+        }
+
+        @Override
+        public void onIndexModule(IndexModule indexModule) {
+            super.onIndexModule(indexModule);
+            indexModule.addSearchOperationListener(new SearchOperationListener() {
+                @Override
+                public void onNewReaderContext(ReaderContext c) {
+                    // The reference starts out null, and this plugin is loaded for every test of the base class,
+                    // including those that never install a callback. Read it once so a concurrent reset cannot
+                    // slip in between the null check and the call.
+                    final Consumer<ReaderContext> callback = runOnNewReaderContext.get();
+                    if (callback != null) {
+                        callback.accept(c);
+                    }
+                }
+            });
+        }
+    }
 }
 }