소스 검색

Implement count for wrapped Weight in ContextIndexSearcher (#88396)

Implements Weight#count() for wrapped Weights that don't change matching documents.

Relatess #88284
Nhat Nguyen 3 년 전
부모
커밋
4732fc2343

+ 2 - 1
docs/reference/search/search-your-data/search-your-data.asciidoc

@@ -503,7 +503,8 @@ be set to `true` in the response.
 }
 --------------------------------------------------
 // TESTRESPONSE[s/"took": 3/"took": $body.took/]
-
+// TESTRESPONSE[s/"value": 1/"value": $body.hits.total.value/]
+// TESTRESPONSE[s/"relation": "eq"/"relation": $body.hits.total.relation/]
 
 The `took` time in the response contains the milliseconds that this request
 took for processing, beginning quickly after the node received the query, up

+ 5 - 0
server/src/main/java/org/elasticsearch/common/lucene/search/NoRewriteMatchNoDocsQuery.java

@@ -50,6 +50,11 @@ public class NoRewriteMatchNoDocsQuery extends Query {
             public boolean isCacheable(LeafReaderContext ctx) {
                 return true;
             }
+
+            @Override
+            public int count(LeafReaderContext context) {
+                return 0;
+            }
         };
     }
 

+ 5 - 0
server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java

@@ -245,6 +245,11 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
                         return null;
                     }
                 }
+
+                @Override
+                public int count(LeafReaderContext context) throws IOException {
+                    return weight.count(context);
+                }
             };
         } else {
             return weight;

+ 25 - 8
server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java

@@ -81,9 +81,11 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.function.IntUnaryOperator;
 
 import static org.elasticsearch.search.query.TopDocsCollectorContext.hasInfMaxScore;
 import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -337,6 +339,17 @@ public class QueryPhaseTests extends IndexShardTestCase {
         }
         w.close();
         final IndexReader reader = DirectoryReader.open(dir);
+        // TotalHitCountCollector can shortcut count until the terminate_after
+        IntUnaryOperator countDocUpTo = terminateAfter -> {
+            int total = 0;
+            for (LeafReaderContext leaf : reader.leaves()) {
+                total += leaf.reader().numDocs();
+                if (total >= terminateAfter) {
+                    break;
+                }
+            }
+            return total;
+        };
         TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader));
         context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
         context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
@@ -364,7 +377,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
             context.setSize(0);
             QueryPhase.executeInternal(context);
             assertTrue(context.queryResult().terminatedEarly());
-            assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
+            assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) countDocUpTo.applyAsInt(1)));
             assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
         }
 
@@ -390,8 +403,8 @@ public class QueryPhaseTests extends IndexShardTestCase {
             context.parsedQuery(new ParsedQuery(bq));
             QueryPhase.executeInternal(context);
             assertTrue(context.queryResult().terminatedEarly());
-            assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
             assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
+            context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
         }
         {
             context.setSize(1);
@@ -401,7 +414,8 @@ public class QueryPhaseTests extends IndexShardTestCase {
             assertTrue(context.queryResult().terminatedEarly());
             assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
             assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
-            assertThat(collector.getTotalHits(), equalTo(1));
+            // TotalHitCountCollector counts num docs in the first leaf
+            assertThat(collector.getTotalHits(), equalTo(reader.leaves().get(0).reader().numDocs()));
             context.queryCollectors().clear();
         }
         {
@@ -410,9 +424,11 @@ public class QueryPhaseTests extends IndexShardTestCase {
             context.queryCollectors().put(TotalHitCountCollector.class, collector);
             QueryPhase.executeInternal(context);
             assertTrue(context.queryResult().terminatedEarly());
-            assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
+            // TotalHitCountCollector counts num docs in the first leaf
+            int numDocsInFirstLeaf = reader.leaves().get(0).reader().numDocs();
+            assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocsInFirstLeaf));
             assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
-            assertThat(collector.getTotalHits(), equalTo(1));
+            assertThat(collector.getTotalHits(), equalTo(numDocsInFirstLeaf));
         }
 
         // tests with trackTotalHits and terminateAfter
@@ -427,10 +443,10 @@ public class QueryPhaseTests extends IndexShardTestCase {
             if (trackTotalHits == -1) {
                 assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L));
             } else {
-                assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) Math.min(trackTotalHits, 10)));
+                assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) countDocUpTo.applyAsInt(10)));
             }
             assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
-            assertThat(collector.getTotalHits(), equalTo(10));
+            assertThat(collector.getTotalHits(), equalTo(countDocUpTo.applyAsInt(10)));
         }
 
         context.terminateAfter(7);
@@ -769,7 +785,8 @@ public class QueryPhaseTests extends IndexShardTestCase {
             searchContext.setSize(0);
             QueryPhase.executeInternal(searchContext);
             assertTrue(searchContext.sort().sort.getSort()[0].getOptimizeSortWithPoints());
-            assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false);
+            assertThat(searchContext.queryResult().topDocs().topDocs.scoreDocs, arrayWithSize(0));
+            assertThat(searchContext.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
         }
 
         // 7. Test that sort optimization doesn't break a case where from = 0 and size= 0