Browse Source

Remove empty results before merging (#126770)

We addressed the empty top docs issue with #126385 specifically for scenarios where
empty top docs don't go through the wire. Yet they may be serialized from data node
back to the coord node, in which case they will no longer be equal to Lucene#EMPTY_TOP_DOCS.

This commit expands the existing filtering of empty top docs to include also those that
did go through serialization.

Closes #126742
Luca Cavanna 6 months ago
parent
commit
f274ab7402

+ 6 - 0
docs/changelog/126770.yaml

@@ -0,0 +1,6 @@
+pr: 126770
+summary: Remove empty results before merging
+area: Search
+type: bug
+issues:
+ - 126742

+ 0 - 3
muted-tests.yml

@@ -315,9 +315,6 @@ tests:
 - class: org.elasticsearch.search.CCSDuelIT
 - class: org.elasticsearch.search.CCSDuelIT
   method: testTerminateAfter
   method: testTerminateAfter
   issue: https://github.com/elastic/elasticsearch/issues/126085
   issue: https://github.com/elastic/elasticsearch/issues/126085
-- class: org.elasticsearch.search.sort.GeoDistanceIT
-  method: testDistanceSortingWithUnmappedField
-  issue: https://github.com/elastic/elasticsearch/issues/126118
 - class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT
 - class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT
   method: testSearchWithRandomDisconnects
   method: testSearchWithRandomDisconnects
   issue: https://github.com/elastic/elasticsearch/issues/122707
   issue: https://github.com/elastic/elasticsearch/issues/122707

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -225,6 +225,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_QUERY_PLANNING_DURATION = def(9_051_0_00);
     public static final TransportVersion ESQL_QUERY_PLANNING_DURATION = def(9_051_0_00);
     public static final TransportVersion ESQL_DOCUMENTS_FOUND_AND_VALUES_LOADED = def(9_052_0_00);
     public static final TransportVersion ESQL_DOCUMENTS_FOUND_AND_VALUES_LOADED = def(9_052_0_00);
     public static final TransportVersion BATCHED_QUERY_EXECUTION_DELAYABLE_WRITABLE = def(9_053_0_00);
     public static final TransportVersion BATCHED_QUERY_EXECUTION_DELAYABLE_WRITABLE = def(9_053_0_00);
+    public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL = def(9_054_0_00);
 
 
     /*
     /*
      * STOP! READ THIS FIRST! No, really,
      * STOP! READ THIS FIRST! No, really,

+ 10 - 4
server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java

@@ -303,13 +303,19 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
         Collection<DelayableWriteable<InternalAggregations>> aggsList
         Collection<DelayableWriteable<InternalAggregations>> aggsList
     ) {
     ) {
         if (topDocsList != null) {
         if (topDocsList != null) {
-            topDocsList.add(partialResult.reducedTopDocs);
+            addTopDocsToList(partialResult, topDocsList);
         }
         }
         if (aggsList != null) {
         if (aggsList != null) {
             addAggsToList(partialResult, aggsList);
             addAggsToList(partialResult, aggsList);
         }
         }
     }
     }
 
 
+    private static void addTopDocsToList(MergeResult partialResult, List<TopDocs> topDocsList) {
+        if (partialResult.reducedTopDocs != null) {
+            topDocsList.add(partialResult.reducedTopDocs);
+        }
+    }
+
     private static void addAggsToList(MergeResult partialResult, Collection<DelayableWriteable<InternalAggregations>> aggsList) {
     private static void addAggsToList(MergeResult partialResult, Collection<DelayableWriteable<InternalAggregations>> aggsList) {
         var aggs = partialResult.reducedAggs;
         var aggs = partialResult.reducedAggs;
         if (aggs != null) {
         if (aggs != null) {
@@ -340,7 +346,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
         if (hasTopDocs) {
         if (hasTopDocs) {
             topDocsList = new ArrayList<>(resultSetSize);
             topDocsList = new ArrayList<>(resultSetSize);
             if (lastMerge != null) {
             if (lastMerge != null) {
-                topDocsList.add(lastMerge.reducedTopDocs);
+                addTopDocsToList(lastMerge, topDocsList);
             }
             }
         } else {
         } else {
             topDocsList = null;
             topDocsList = null;
@@ -358,7 +364,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
                 }
                 }
             }
             }
             // we have to merge here in the same way we collect on a shard
             // we have to merge here in the same way we collect on a shard
-            newTopDocs = topDocsList == null ? Lucene.EMPTY_TOP_DOCS : mergeTopDocs(topDocsList, topNSize, 0);
+            newTopDocs = topDocsList == null ? null : mergeTopDocs(topDocsList, topNSize, 0);
             newAggs = hasAggs
             newAggs = hasAggs
                 ? aggregate(
                 ? aggregate(
                     toConsume.iterator(),
                     toConsume.iterator(),
@@ -636,7 +642,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
 
 
     record MergeResult(
     record MergeResult(
         List<SearchShard> processedShards,
         List<SearchShard> processedShards,
-        TopDocs reducedTopDocs,
+        @Nullable TopDocs reducedTopDocs,
         @Nullable DelayableWriteable<InternalAggregations> reducedAggs,
         @Nullable DelayableWriteable<InternalAggregations> reducedAggs,
         long estimatedSize
         long estimatedSize
     ) implements Writeable {
     ) implements Writeable {

+ 12 - 9
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

@@ -60,6 +60,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.Executor;
 import java.util.concurrent.Executor;
 import java.util.function.BiFunction;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Consumer;
@@ -140,24 +141,26 @@ public final class SearchPhaseController {
     }
     }
 
 
     static TopDocs mergeTopDocs(List<TopDocs> results, int topN, int from) {
     static TopDocs mergeTopDocs(List<TopDocs> results, int topN, int from) {
-        if (results.isEmpty()) {
+        List<TopDocs> topDocsList = results.stream().filter(Objects::nonNull).toList();
+        if (topDocsList.isEmpty()) {
             return null;
             return null;
         }
         }
-        final TopDocs topDocs = results.getFirst();
-        final TopDocs mergedTopDocs;
-        final int numShards = results.size();
+        final TopDocs topDocs = topDocsList.getFirst();
+        final int numShards = topDocsList.size();
         if (numShards == 1 && from == 0) { // only one shard and no pagination we can just return the topDocs as we got them.
         if (numShards == 1 && from == 0) { // only one shard and no pagination we can just return the topDocs as we got them.
             return topDocs;
             return topDocs;
-        } else if (topDocs instanceof TopFieldGroups firstTopDocs) {
+        }
+        final TopDocs mergedTopDocs;
+        if (topDocs instanceof TopFieldGroups firstTopDocs) {
             final Sort sort = new Sort(firstTopDocs.fields);
             final Sort sort = new Sort(firstTopDocs.fields);
-            final TopFieldGroups[] shardTopDocs = results.stream().filter(td -> td != Lucene.EMPTY_TOP_DOCS).toArray(TopFieldGroups[]::new);
+            TopFieldGroups[] shardTopDocs = topDocsList.toArray(new TopFieldGroups[0]);
             mergedTopDocs = TopFieldGroups.merge(sort, from, topN, shardTopDocs, false);
             mergedTopDocs = TopFieldGroups.merge(sort, from, topN, shardTopDocs, false);
         } else if (topDocs instanceof TopFieldDocs firstTopDocs) {
         } else if (topDocs instanceof TopFieldDocs firstTopDocs) {
-            final Sort sort = checkSameSortTypes(results, firstTopDocs.fields);
-            final TopFieldDocs[] shardTopDocs = results.stream().filter((td -> td != Lucene.EMPTY_TOP_DOCS)).toArray(TopFieldDocs[]::new);
+            TopFieldDocs[] shardTopDocs = topDocsList.toArray(new TopFieldDocs[0]);
+            final Sort sort = checkSameSortTypes(topDocsList, firstTopDocs.fields);
             mergedTopDocs = TopDocs.merge(sort, from, topN, shardTopDocs);
             mergedTopDocs = TopDocs.merge(sort, from, topN, shardTopDocs);
         } else {
         } else {
-            final TopDocs[] shardTopDocs = results.toArray(new TopDocs[numShards]);
+            final TopDocs[] shardTopDocs = topDocsList.toArray(new TopDocs[0]);
             mergedTopDocs = TopDocs.merge(from, topN, shardTopDocs);
             mergedTopDocs = TopDocs.merge(from, topN, shardTopDocs);
         }
         }
         return mergedTopDocs;
         return mergedTopDocs;

+ 7 - 6
server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

@@ -27,7 +27,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.common.util.concurrent.CountDown;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.ListenableFuture;
 import org.elasticsearch.common.util.concurrent.ListenableFuture;
@@ -722,7 +721,7 @@ public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<S
 
 
         private static final QueryPhaseResultConsumer.MergeResult EMPTY_PARTIAL_MERGE_RESULT = new QueryPhaseResultConsumer.MergeResult(
         private static final QueryPhaseResultConsumer.MergeResult EMPTY_PARTIAL_MERGE_RESULT = new QueryPhaseResultConsumer.MergeResult(
             List.of(),
             List.of(),
-            Lucene.EMPTY_TOP_DOCS,
+            null,
             null,
             null,
             0L
             0L
         );
         );
@@ -782,10 +781,12 @@ public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<S
                 // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other
                 // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other
                 // indices without a roundtrip to the coordinating node
                 // indices without a roundtrip to the coordinating node
                 final BitSet relevantShardIndices = new BitSet(searchRequest.shards.size());
                 final BitSet relevantShardIndices = new BitSet(searchRequest.shards.size());
-                for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
-                    final int localIndex = scoreDoc.shardIndex;
-                    scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex;
-                    relevantShardIndices.set(localIndex);
+                if (mergeResult.reducedTopDocs() != null) {
+                    for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) {
+                        final int localIndex = scoreDoc.shardIndex;
+                        scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex;
+                        relevantShardIndices.set(localIndex);
+                    }
                 }
                 }
                 final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()];
                 final Object[] results = new Object[queryPhaseResultConsumer.getNumShards()];
                 for (int i = 0; i < results.length; i++) {
                 for (int i = 0; i < results.length; i++) {

+ 13 - 1
server/src/main/java/org/elasticsearch/common/lucene/Lucene.java

@@ -64,6 +64,7 @@ import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.Version;
 import org.apache.lucene.util.Version;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ExceptionsHelper;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -384,6 +385,14 @@ public class Lucene {
      * by shard for sorting purposes.
      * by shard for sorting purposes.
      */
      */
     public static void writeTopDocsIncludingShardIndex(StreamOutput out, TopDocs topDocs) throws IOException {
     public static void writeTopDocsIncludingShardIndex(StreamOutput out, TopDocs topDocs) throws IOException {
+        if (topDocs == null) {
+            if (out.getTransportVersion().onOrAfter(TransportVersions.SEARCH_INCREMENTAL_TOP_DOCS_NULL)) {
+                out.writeByte((byte) -1);
+                return;
+            } else {
+                topDocs = Lucene.EMPTY_TOP_DOCS;
+            }
+        }
         if (topDocs instanceof TopFieldGroups topFieldGroups) {
         if (topDocs instanceof TopFieldGroups topFieldGroups) {
             out.writeByte((byte) 2);
             out.writeByte((byte) 2);
             writeTotalHits(out, topDocs.totalHits);
             writeTotalHits(out, topDocs.totalHits);
@@ -424,7 +433,10 @@ public class Lucene {
      */
      */
     public static TopDocs readTopDocsIncludingShardIndex(StreamInput in) throws IOException {
     public static TopDocs readTopDocsIncludingShardIndex(StreamInput in) throws IOException {
         byte type = in.readByte();
         byte type = in.readByte();
-        if (type == 0) {
+        if (type == -1) {
+            assert in.getTransportVersion().onOrAfter(TransportVersions.SEARCH_INCREMENTAL_TOP_DOCS_NULL);
+            return null;
+        } else if (type == 0) {
             TotalHits totalHits = readTotalHits(in);
             TotalHits totalHits = readTotalHits(in);
 
 
             final int scoreDocCount = in.readVInt();
             final int scoreDocCount = in.readVInt();