|  | @@ -16,7 +16,9 @@ import org.elasticsearch.action.search.SearchAction;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.search.SearchPhaseExecutionException;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.search.SearchResponse;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.search.SearchScrollAction;
 | 
	
		
			
				|  |  | +import org.elasticsearch.action.search.SearchShardTask;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.search.SearchTask;
 | 
	
		
			
				|  |  | +import org.elasticsearch.action.search.SearchTransportService;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.search.SearchType;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.search.ShardSearchFailure;
 | 
	
		
			
				|  |  |  import org.elasticsearch.common.Strings;
 | 
	
	
		
			
				|  | @@ -25,6 +27,7 @@ import org.elasticsearch.script.Script;
 | 
	
		
			
				|  |  |  import org.elasticsearch.script.ScriptType;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 | 
	
		
			
				|  |  |  import org.elasticsearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;
 | 
	
		
			
				|  |  | +import org.elasticsearch.search.internal.ReaderContext;
 | 
	
		
			
				|  |  |  import org.elasticsearch.tasks.Task;
 | 
	
		
			
				|  |  |  import org.elasticsearch.tasks.TaskCancelledException;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.AbstractSearchCancellationTestCase;
 | 
	
	
		
			
				|  | @@ -221,43 +224,12 @@ public class SearchCancellationIT extends AbstractSearchCancellationTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    /**
 | 
	
		
			
				|  |  | -     * The test `testCancelFailedSearchWhenPartialResultDisallowed` usually fails when concurrency is enabled unless
 | 
	
		
			
				|  |  | -     * the `cancelledLatch.await()` section is commented out. However, this approach seems prone to race conditions.
 | 
	
		
			
				|  |  | -     * Further investigation is needed to determine if this test just needs to be revised, or rather, if it is
 | 
	
		
			
				|  |  | -     * detecting a deeper issue.  For now, we will disable concurrency here.
 | 
	
		
			
				|  |  | -     */
 | 
	
		
			
				|  |  | -    @Override
 | 
	
		
			
				|  |  | -    protected boolean enableConcurrentSearch() {
 | 
	
		
			
				|  |  | -        return false;
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception {
 | 
	
		
			
				|  |  | -        final List<ScriptedBlockPlugin> plugins = initBlockFactory();
 | 
	
		
			
				|  |  |          int numberOfShards = between(2, 5);
 | 
	
		
			
				|  |  | -        AtomicBoolean failed = new AtomicBoolean();
 | 
	
		
			
				|  |  | -        CountDownLatch queryLatch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | -        CountDownLatch cancelledLatch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | -        for (ScriptedBlockPlugin plugin : plugins) {
 | 
	
		
			
				|  |  | -            plugin.disableBlock();
 | 
	
		
			
				|  |  | -            plugin.setBeforeExecution(() -> {
 | 
	
		
			
				|  |  | -                try {
 | 
	
		
			
				|  |  | -                    queryLatch.await(); // block the query until we get a search task
 | 
	
		
			
				|  |  | -                } catch (InterruptedException e) {
 | 
	
		
			
				|  |  | -                    throw new AssertionError(e);
 | 
	
		
			
				|  |  | -                }
 | 
	
		
			
				|  |  | -                if (failed.compareAndSet(false, true)) {
 | 
	
		
			
				|  |  | -                    throw new IllegalStateException("simulated");
 | 
	
		
			
				|  |  | -                }
 | 
	
		
			
				|  |  | -                try {
 | 
	
		
			
				|  |  | -                    cancelledLatch.await(); // block the query until the search is cancelled
 | 
	
		
			
				|  |  | -                } catch (InterruptedException e) {
 | 
	
		
			
				|  |  | -                    throw new AssertionError(e);
 | 
	
		
			
				|  |  | -                }
 | 
	
		
			
				|  |  | -            });
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  |          createIndex("test", numberOfShards, 0);
 | 
	
		
			
				|  |  |          indexTestData();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Define (but don't run) the search request, expecting a partial shard failure. We will run it later.
 | 
	
		
			
				|  |  |          Thread searchThread = new Thread(() -> {
 | 
	
		
			
				|  |  |              SearchPhaseExecutionException e = expectThrows(
 | 
	
		
			
				|  |  |                  SearchPhaseExecutionException.class,
 | 
	
	
		
			
				|  | @@ -270,29 +242,59 @@ public class SearchCancellationIT extends AbstractSearchCancellationTestCase {
 | 
	
		
			
				|  |  |              );
 | 
	
		
			
				|  |  |              assertThat(e.getMessage(), containsString("Partial shards failure"));
 | 
	
		
			
				|  |  |          });
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // When the search request executes, block all shards except 1.
 | 
	
		
			
				|  |  | +        final List<SearchShardBlockingPlugin> searchShardBlockingPlugins = initSearchShardBlockingPlugin();
 | 
	
		
			
				|  |  | +        AtomicBoolean letOneShardProceed = new AtomicBoolean();
 | 
	
		
			
				|  |  | +        CountDownLatch shardTaskLatch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | +        for (SearchShardBlockingPlugin plugin : searchShardBlockingPlugins) {
 | 
	
		
			
				|  |  | +            plugin.setRunOnNewReaderContext((ReaderContext c) -> {
 | 
	
		
			
				|  |  | +                if (letOneShardProceed.compareAndSet(false, true)) {
 | 
	
		
			
				|  |  | +                    // Let one shard continue.
 | 
	
		
			
				|  |  | +                } else {
 | 
	
		
			
				|  |  | +                    try {
 | 
	
		
			
				|  |  | +                        shardTaskLatch.await(); // Block the other shards.
 | 
	
		
			
				|  |  | +                    } catch (InterruptedException e) {
 | 
	
		
			
				|  |  | +                        throw new AssertionError(e);
 | 
	
		
			
				|  |  | +                    }
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            });
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // For the shard that was allowed to proceed, have a single query-execution thread throw an exception.
 | 
	
		
			
				|  |  | +        final List<ScriptedBlockPlugin> plugins = initBlockFactory();
 | 
	
		
			
				|  |  | +        AtomicBoolean oneThreadWillError = new AtomicBoolean();
 | 
	
		
			
				|  |  | +        for (ScriptedBlockPlugin plugin : plugins) {
 | 
	
		
			
				|  |  | +            plugin.disableBlock();
 | 
	
		
			
				|  |  | +            plugin.setBeforeExecution(() -> {
 | 
	
		
			
				|  |  | +                if (oneThreadWillError.compareAndSet(false, true)) {
 | 
	
		
			
				|  |  | +                    throw new IllegalStateException("This will cancel the ContextIndexSearcher.search task");
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            });
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Now run the search request.
 | 
	
		
			
				|  |  |          searchThread.start();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          try {
 | 
	
		
			
				|  |  | -            assertBusy(() -> assertThat(getSearchTasks(), hasSize(1)));
 | 
	
		
			
				|  |  | -            queryLatch.countDown();
 | 
	
		
			
				|  |  |              assertBusy(() -> {
 | 
	
		
			
				|  |  | -                final List<SearchTask> searchTasks = getSearchTasks();
 | 
	
		
			
				|  |  | -                // The search request can complete before the "cancelledLatch" is latched if the second shard request is sent
 | 
	
		
			
				|  |  | -                // after the request was cancelled (i.e., the child task is not allowed to start after the parent was cancelled).
 | 
	
		
			
				|  |  | -                if (searchTasks.isEmpty() == false) {
 | 
	
		
			
				|  |  | -                    assertThat(searchTasks, hasSize(1));
 | 
	
		
			
				|  |  | -                    assertTrue(searchTasks.get(0).isCancelled());
 | 
	
		
			
				|  |  | +                final List<SearchTask> coordinatorSearchTask = getCoordinatorSearchTasks();
 | 
	
		
			
				|  |  | +                assertThat("The Coordinator should have one SearchTask.", coordinatorSearchTask, hasSize(1));
 | 
	
		
			
				|  |  | +                assertTrue("The SearchTask should be cancelled.", coordinatorSearchTask.get(0).isCancelled());
 | 
	
		
			
				|  |  | +                for (var shardQueryTask : getShardQueryTasks()) {
 | 
	
		
			
				|  |  | +                    assertTrue("All SearchShardTasks should then be cancelled", shardQueryTask.isCancelled());
 | 
	
		
			
				|  |  |                  }
 | 
	
		
			
				|  |  |              }, 30, TimeUnit.SECONDS);
 | 
	
		
			
				|  |  | +            shardTaskLatch.countDown(); // unblock the shardTasks, allowing the test to conclude.
 | 
	
		
			
				|  |  |          } finally {
 | 
	
		
			
				|  |  | +            searchThread.join();
 | 
	
		
			
				|  |  |              for (ScriptedBlockPlugin plugin : plugins) {
 | 
	
		
			
				|  |  |                  plugin.setBeforeExecution(() -> {});
 | 
	
		
			
				|  |  |              }
 | 
	
		
			
				|  |  | -            cancelledLatch.countDown();
 | 
	
		
			
				|  |  | -            searchThread.join();
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    List<SearchTask> getSearchTasks() {
 | 
	
		
			
				|  |  | +    List<SearchTask> getCoordinatorSearchTasks() {
 | 
	
		
			
				|  |  |          List<SearchTask> tasks = new ArrayList<>();
 | 
	
		
			
				|  |  |          for (String nodeName : internalCluster().getNodeNames()) {
 | 
	
		
			
				|  |  |              TransportService transportService = internalCluster().getInstance(TransportService.class, nodeName);
 | 
	
	
		
			
				|  | @@ -305,4 +307,16 @@ public class SearchCancellationIT extends AbstractSearchCancellationTestCase {
 | 
	
		
			
				|  |  |          return tasks;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    List<SearchShardTask> getShardQueryTasks() {
 | 
	
		
			
				|  |  | +        List<SearchShardTask> tasks = new ArrayList<>();
 | 
	
		
			
				|  |  | +        for (String nodeName : internalCluster().getNodeNames()) {
 | 
	
		
			
				|  |  | +            TransportService transportService = internalCluster().getInstance(TransportService.class, nodeName);
 | 
	
		
			
				|  |  | +            for (Task task : transportService.getTaskManager().getCancellableTasks().values()) {
 | 
	
		
			
				|  |  | +                if (task.getAction().equals(SearchTransportService.QUERY_ACTION_NAME)) {
 | 
	
		
			
				|  |  | +                    tasks.add((SearchShardTask) task);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        return tasks;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  |  }
 |