Przeglądaj źródła

Async Search: correct shards counting (#55758)

Async search allows users to retrieve partial results for a running search. For partial results, the provided number of successful shards does not include the skipped shards, whereas the response returned to users should include them.

Also, we recently had a bug where async search would miss tracking shard failures. That bug would have been caught had assertions been in place to verify that, whenever we get the last response, the number of failures included in it is the same as the number of failures tracked through the listener notifications. This commit adds those assertions.
Luca Cavanna 5 lat temu
rodzic
commit
9ffd006ca0

+ 9 - 1
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java

@@ -94,7 +94,8 @@ class MutableSearchResponse {
             throw new IllegalStateException("received partial response out of order: "
                 + reducePhase + " < " + this.reducePhase);
         }
-        this.successfulShards = successfulShards;
+        //when we get partial results skipped shards are not included in the provided number of successful shards
+        this.successfulShards = successfulShards + skippedShards;
         this.totalHits = totalHits;
         this.reducedAggsSource = reducedAggs;
         this.reducePhase = reducePhase;
@@ -106,6 +107,11 @@ class MutableSearchResponse {
      */
     synchronized void updateFinalResponse(SearchResponse response) {
         failIfFrozen();
+        assert response.getTotalShards() == totalShards : "received number of total shards differs from the one " +
+            "notified through onListShards";
+        assert response.getSkippedShards() == skippedShards : "received number of skipped shards differs from the one " +
+            "notified through onListShards";
+        assert response.getFailedShards() == buildShardFailures().length : "number of tracked failures differs from failed shards";
         // copy the response headers from the current context
         this.responseHeaders = threadContext.getResponseHeaders();
         this.finalResponse = response;
@@ -121,6 +127,8 @@ class MutableSearchResponse {
         failIfFrozen();
         // copy the response headers from the current context
         this.responseHeaders = threadContext.getResponseHeaders();
+        //note that when search fails, we may have gotten partial results before the failure. In that case async
+        // search will return an error plus the last partial results that were collected.
         this.isPartial = true;
         this.failure = ElasticsearchException.guessRootCauses(exc)[0];
         this.frozen = true;

+ 1 - 1
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchResponseTests.java

@@ -121,7 +121,7 @@ public class AsyncSearchResponseTests extends ESTestCase {
         long tookInMillis = randomNonNegativeLong();
         int totalShards = randomIntBetween(1, Integer.MAX_VALUE);
         int successfulShards = randomIntBetween(0, totalShards);
-        int skippedShards = totalShards - successfulShards;
+        int skippedShards = randomIntBetween(0, successfulShards);
         InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
         return new SearchResponse(internalSearchResponse, null, totalShards,
             successfulShards, skippedShards, tookInMillis, ShardSearchFailure.EMPTY_ARRAY, SearchResponse.Clusters.EMPTY);

+ 52 - 20
x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java

@@ -134,26 +134,24 @@ public class AsyncSearchTaskTests extends ESTestCase {
         for (int i = 0; i < numSkippedShards; i++) {
             skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
         }
-
-        int numShardFailures = 0;
+        int totalShards = numShards + numSkippedShards;
         task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
         for (int i = 0; i < numShards; i++) {
             task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
                 new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-            assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true);
+            assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true);
         }
         task.getSearchProgressActionListener().onFinalReduce(shards,
             new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-        assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numShardFailures, true);
+        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, true);
         ((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
-            newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards));
-        assertCompletionListeners(task, numShards+numSkippedShards,
-            numSkippedShards, numShardFailures, false);
+            newSearchResponse(totalShards, totalShards, numSkippedShards));
+        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, 0, false);
     }
 
     public void testWithFetchFailures() throws InterruptedException {
         AsyncSearchTask task = createAsyncSearchTask();
-        int numShards = randomIntBetween(0, 10);
+        int numShards = randomIntBetween(2, 10);
         List<SearchShard> shards = new ArrayList<>();
         for (int i = 0; i < numShards; i++) {
             shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
@@ -163,27 +161,59 @@ public class AsyncSearchTaskTests extends ESTestCase {
         for (int i = 0; i < numSkippedShards; i++) {
             skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
         }
-
+        int totalShards = numShards + numSkippedShards;
         task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
         for (int i = 0; i < numShards; i++) {
             task.getSearchProgressActionListener().onPartialReduce(shards.subList(i, i+1),
                 new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-            assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, 0, true);
+            assertCompletionListeners(task, totalShards, 1 + numSkippedShards, numSkippedShards, 0, true);
         }
         task.getSearchProgressActionListener().onFinalReduce(shards,
             new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
-        int numFetchFailures = randomIntBetween(0, numShards);
-        ShardSearchFailure[] failures = new ShardSearchFailure[numFetchFailures];
+        int numFetchFailures = randomIntBetween(1, numShards - 1);
+        ShardSearchFailure[] shardSearchFailures = new ShardSearchFailure[numFetchFailures];
         for (int i = 0; i < numFetchFailures; i++) {
-            failures[i] = new ShardSearchFailure(new IOException("boum"),
-                new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE));
-            task.getSearchProgressActionListener().onFetchFailure(i, failures[i].shard(), (Exception) failures[i].getCause());
+            IOException failure = new IOException("boum");
+            task.getSearchProgressActionListener().onFetchFailure(i,
+                new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
+                failure);
+            shardSearchFailures[i] = new ShardSearchFailure(failure);
         }
-        assertCompletionListeners(task, numShards+numSkippedShards, numSkippedShards, numFetchFailures, true);
+        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numFetchFailures, true);
         ((AsyncSearchTask.Listener)task.getProgressListener()).onResponse(
-            newSearchResponse(numShards+numSkippedShards, numShards, numSkippedShards, failures));
-        assertCompletionListeners(task, numShards+numSkippedShards,
-            numSkippedShards, numFetchFailures, false);
+            newSearchResponse(totalShards, totalShards - numFetchFailures, numSkippedShards, shardSearchFailures));
+        assertCompletionListeners(task, totalShards, totalShards - numFetchFailures, numSkippedShards, numFetchFailures, false);
+    }
+
+    public void testFatalFailureDuringFetch() throws InterruptedException {
+        AsyncSearchTask task = createAsyncSearchTask();
+        int numShards = randomIntBetween(0, 10);
+        List<SearchShard> shards = new ArrayList<>();
+        for (int i = 0; i < numShards; i++) {
+            shards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+        }
+        List<SearchShard> skippedShards = new ArrayList<>();
+        int numSkippedShards = randomIntBetween(0, 10);
+        for (int i = 0; i < numSkippedShards; i++) {
+            skippedShards.add(new SearchShard(null, new ShardId("0", "0", 1)));
+        }
+        int totalShards = numShards + numSkippedShards;
+        task.getSearchProgressActionListener().onListShards(shards, skippedShards, SearchResponse.Clusters.EMPTY, false);
+        for (int i = 0; i < numShards; i++) {
+            task.getSearchProgressActionListener().onPartialReduce(shards.subList(0, i+1),
+                new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+            assertCompletionListeners(task, totalShards, i + 1 + numSkippedShards, numSkippedShards, 0, true);
+        }
+        task.getSearchProgressActionListener().onFinalReduce(shards,
+            new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), null, 0);
+        for (int i = 0; i < numShards; i++) {
+            task.getSearchProgressActionListener().onFetchFailure(i,
+                new SearchShardTarget("0", new ShardId("0", "0", 1), null, OriginalIndices.NONE),
+                new IOException("boum"));
+        }
+        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numShards, true);
+        ((AsyncSearchTask.Listener)task.getProgressListener()).onFailure(new IOException("boum"));
+        assertCompletionListeners(task, totalShards, totalShards, numSkippedShards, numShards, true);
     }
 
     private static SearchResponse newSearchResponse(int totalShards, int successfulShards, int skippedShards,
@@ -194,8 +224,9 @@ public class AsyncSearchTaskTests extends ESTestCase {
             100, failures, SearchResponse.Clusters.EMPTY);
     }
 
-    private void assertCompletionListeners(AsyncSearchTask task,
+    private static void assertCompletionListeners(AsyncSearchTask task,
                                            int expectedTotalShards,
+                                           int expectedSuccessfulShards,
                                            int expectedSkippedShards,
                                            int expectedShardFailures,
                                            boolean isPartial) throws InterruptedException {
@@ -206,6 +237,7 @@ public class AsyncSearchTaskTests extends ESTestCase {
                 @Override
                 public void onResponse(AsyncSearchResponse resp) {
                     assertThat(resp.getSearchResponse().getTotalShards(), equalTo(expectedTotalShards));
+                    assertThat(resp.getSearchResponse().getSuccessfulShards(), equalTo(expectedSuccessfulShards));
                     assertThat(resp.getSearchResponse().getSkippedShards(), equalTo(expectedSkippedShards));
                     assertThat(resp.getSearchResponse().getFailedShards(), equalTo(expectedShardFailures));
                     assertThat(resp.isPartial(), equalTo(isPartial));