
Leverage scorerSupplier when applicable. (#25109)

The `scorerSupplier` API makes it possible to hint to queries that they will
be consumed in a random-access fashion. We should use this for aggregations,
function_score and matched queries.
Adrien Grand 8 years ago
parent
commit
a8ea2f0df4
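
For context, a minimal sketch of the pattern this commit moves to, assuming the Lucene 6.x/7.0 API in which `ScorerSupplier.get` takes a boolean `randomAccess` hint (`filterQuery`, `searcher` and `leafContext` are placeholder names, not code from this patch):

```java
import java.io.IOException;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;

class ScorerSupplierSketch {
    static Scorer randomAccessScorer(Query filterQuery, IndexSearcher searcher,
            LeafReaderContext leafContext) throws IOException {
        // needsScores=false: the filter is only used for matching, not scoring
        Weight weight = filterQuery.createWeight(searcher, false, 1f);
        // weight.scorer(leafContext) would build a scorer eagerly, with no way
        // to say how it will be consumed; scorerSupplier defers that decision
        ScorerSupplier supplier = weight.scorerSupplier(leafContext);
        if (supplier == null) {
            return null; // the query matches no documents in this segment
        }
        // true = random access: the scorer will be advanced to arbitrary doc
        // IDs (e.g. behind a Bits view) rather than iterated exhaustively
        return supplier.get(true);
    }
}
```

`IndexOrDocValuesQuery`, for instance, uses this flag to decide between index-driven and doc-values-driven execution, which is exactly the kind of choice the eager `scorer(ctx)` path could never make well.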

+ 6 - 3
core/src/main/java/org/elasticsearch/common/lucene/Lucene.java

@@ -49,6 +49,7 @@ import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.SimpleCollector;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.SortedNumericSortField;
@@ -838,14 +839,16 @@ public class Lucene {
     }
 
     /**
-     * Given a {@link Scorer}, return a {@link Bits} instance that will match
+     * Given a {@link ScorerSupplier}, return a {@link Bits} instance that will match
      * all documents contained in the set. Note that the returned {@link Bits}
      * instance MUST be consumed in order.
      */
-    public static Bits asSequentialAccessBits(final int maxDoc, @Nullable Scorer scorer) throws IOException {
-        if (scorer == null) {
+    public static Bits asSequentialAccessBits(final int maxDoc, @Nullable ScorerSupplier scorerSupplier) throws IOException {
+        if (scorerSupplier == null) {
             return new Bits.MatchNoBits(maxDoc);
         }
+        // Since we want bits, we need random-access
+        final Scorer scorer = scorerSupplier.get(true); // this never returns null
         final TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
         final DocIdSetIterator iterator;
         if (twoPhase == null) {
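
To make the in-order contract concrete, here is a simplified sketch of what the iterator-backed `Bits` returned by `asSequentialAccessBits` does (two-phase handling elided; the bounds checks mirror what the LuceneTests hunk below exercises):

```java
import java.io.IOException;
import java.io.UncheckedIOException;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Bits;

class SequentialBitsSketch {
    // get() only ever moves the iterator forward, which is why the returned
    // Bits MUST be consumed in increasing doc-ID order.
    static Bits wrap(int maxDoc, DocIdSetIterator iterator) {
        return new Bits() {
            @Override
            public boolean get(int index) {
                if (index < 0 || index >= maxDoc) {
                    throw new IndexOutOfBoundsException(String.valueOf(index));
                }
                try {
                    if (iterator.docID() < index) {
                        iterator.advance(index);
                    }
                    return iterator.docID() == index;
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }

            @Override
            public int length() {
                return maxDoc;
            }
        };
    }
}
```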

+ 3 - 3
core/src/main/java/org/elasticsearch/common/lucene/search/FilteredCollector.java

@@ -22,7 +22,7 @@ import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Collector;
 import org.apache.lucene.search.FilterLeafCollector;
 import org.apache.lucene.search.LeafCollector;
-import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.common.lucene.Lucene;
@@ -41,9 +41,9 @@ public class FilteredCollector implements Collector {
 
     @Override
     public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
-        final Scorer filterScorer = filter.scorer(context);
+        final ScorerSupplier filterScorerSupplier = filter.scorerSupplier(context);
         final LeafCollector in = collector.getLeafCollector(context);
-        final Bits bits = Lucene.asSequentialAccessBits(context.reader().maxDoc(), filterScorer);
+        final Bits bits = Lucene.asSequentialAccessBits(context.reader().maxDoc(), filterScorerSupplier);
 
         return new FilterLeafCollector(in) {
             @Override
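
The remainder of the leaf collector, truncated in the hunk above, just consults those bits before delegating. A sketch of that pattern (`FilterLeafCollector` already forwards `setScorer` for us):

```java
import java.io.IOException;

import org.apache.lucene.search.FilterLeafCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.util.Bits;

class FilteredLeafCollectorSketch {
    // Collected doc IDs arrive in increasing order, which satisfies the
    // in-order contract of the sequential-access Bits.
    static LeafCollector filtered(LeafCollector in, Bits bits) {
        return new FilterLeafCollector(in) {
            @Override
            public void collect(int doc) throws IOException {
                if (bits.get(doc)) {
                    super.collect(doc);
                }
            }
        };
    }
}
```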

+ 4 - 3
core/src/main/java/org/elasticsearch/common/lucene/search/function/FiltersFunctionScoreQuery.java

@@ -27,6 +27,7 @@ import org.apache.lucene.search.FilterScorer;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -174,8 +175,8 @@ public class FiltersFunctionScoreQuery extends Query {
             for (int i = 0; i < filterFunctions.length; i++) {
                 FilterFunction filterFunction = filterFunctions[i];
                 functions[i] = filterFunction.function.getLeafScoreFunction(context);
-                Scorer filterScorer = filterWeights[i].scorer(context);
-                docSets[i] = Lucene.asSequentialAccessBits(context.reader().maxDoc(), filterScorer);
+                ScorerSupplier filterScorerSupplier = filterWeights[i].scorerSupplier(context);
+                docSets[i] = Lucene.asSequentialAccessBits(context.reader().maxDoc(), filterScorerSupplier);
             }
             return new FiltersFunctionFactorScorer(this, subQueryScorer, scoreMode, filterFunctions, maxBoost, functions, docSets, combineFunction, needsScores);
         }
@@ -200,7 +201,7 @@ public class FiltersFunctionScoreQuery extends Query {
             List<Explanation> filterExplanations = new ArrayList<>();
             for (int i = 0; i < filterFunctions.length; ++i) {
                 Bits docSet = Lucene.asSequentialAccessBits(context.reader().maxDoc(),
-                        filterWeights[i].scorer(context));
+                        filterWeights[i].scorerSupplier(context));
                 if (docSet.get(doc)) {
                     FilterFunction filterFunction = filterFunctions[i];
                     Explanation functionExplanation = filterFunction.function.getLeafScoreFunction(context).explainScore(doc, expl);
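
Why random access is the right hint here: inside the function_score scorer, the sub-query drives iteration, so each filter's `Bits` is probed at whatever increasing doc IDs the sub-query produces rather than being iterated exhaustively. A hedged sketch of that access pattern (score application elided):

```java
import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.Bits;

class FilterProbeSketch {
    // The sub-query scorer drives iteration; each filter's Bits is probed at
    // whatever (increasing) doc IDs the sub-query produces. Relative to each
    // filter's own iterator that is random access, hence the get(true) hint.
    static void scoreAll(Scorer subQueryScorer, Bits[] docSets) throws IOException {
        DocIdSetIterator it = subQueryScorer.iterator();
        for (int doc = it.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = it.nextDoc()) {
            for (Bits docSet : docSets) {
                if (docSet.get(doc)) {
                    // apply this filter's score function to doc (elided)
                }
            }
        }
    }
}
```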

+ 1 - 1
core/src/main/java/org/elasticsearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java

@@ -172,7 +172,7 @@ public class AdjacencyMatrixAggregator extends BucketsAggregator {
         // no need to provide deleted docs to the filter
         final Bits[] bits = new Bits[filters.length + totalNumIntersections];
         for (int i = 0; i < filters.length; ++i) {
-            bits[i] = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), filters[i].scorer(ctx));
+            bits[i] = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), filters[i].scorerSupplier(ctx));
         }
         // Add extra Bits for intersections
         int pos = filters.length;

+ 1 - 1
core/src/main/java/org/elasticsearch/search/aggregations/bucket/filter/FilterAggregator.java

@@ -56,7 +56,7 @@ public class FilterAggregator extends SingleBucketAggregator {
     public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
             final LeafBucketCollector sub) throws IOException {
         // no need to provide deleted docs to the filter
-        final Bits bits = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), filter.scorer(ctx));
+        final Bits bits = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), filter.scorerSupplier(ctx));
         return new LeafBucketCollectorBase(sub, null) {
             @Override
             public void collect(int doc, long bucket) throws IOException {
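
The body of that `collect` method, truncated just past this point, is essentially:

```java
// collectBucket is the helper inherited from BucketsAggregator; the doc
// contributes to the bucket only if the filter's Bits matches it.
if (bits.get(doc)) {
    collectBucket(sub, doc, bucket);
}
```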

+ 1 - 1
core/src/main/java/org/elasticsearch/search/aggregations/bucket/filters/FiltersAggregator.java

@@ -144,7 +144,7 @@ public class FiltersAggregator extends BucketsAggregator {
         // no need to provide deleted docs to the filter
         final Bits[] bits = new Bits[filters.length];
         for (int i = 0; i < filters.length; ++i) {
-            bits[i] = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), filters[i].scorer(ctx));
+            bits[i] = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), filters[i].scorerSupplier(ctx));
         }
         return new LeafBucketCollectorBase(sub, null) {
             @Override
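
The multi-filter variant follows the same shape, with one sub-bucket per (owning bucket, filter) pair. A sketch, where `collectBucket` is the inherited `BucketsAggregator` helper and `bucketOrd` stands in for the aggregator's own ordinal arithmetic (an assumed helper name, not verified against this exact version):

```java
// Sketch of the truncated leaf collector: each doc is tested against every
// filter's Bits and routed to that filter's sub-bucket.
@Override
public void collect(int doc, long bucket) throws IOException {
    for (int i = 0; i < bits.length; i++) {
        if (bits[i].get(doc)) {
            collectBucket(sub, doc, bucketOrd(bucket, i));
        }
    }
}
```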

+ 3 - 3
core/src/main/java/org/elasticsearch/search/fetch/subphase/MatchedQueriesFetchSubPhase.java

@@ -22,7 +22,7 @@ import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.ReaderUtil;
 import org.apache.lucene.search.Query;
-import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Bits;
 import org.elasticsearch.ExceptionsHelper;
@@ -78,8 +78,8 @@ public final class MatchedQueriesFetchSubPhase implements FetchSubPhase {
                         LeafReaderContext ctx = indexReader.leaves().get(readerIndex);
                         docBase = ctx.docBase;
                         // scorers can be costly to create, so reuse them across docs of the same segment
-                        Scorer scorer = weight.scorer(ctx);
-                        matchingDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), scorer);
+                        ScorerSupplier scorerSupplier = weight.scorerSupplier(ctx);
+                        matchingDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), scorerSupplier);
                     }
                     if (matchingDocs.get(hit.docId() - docBase)) {
                         matchedQueries[i].add(name);
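
The surrounding loop (not shown) resolves which segment each hit belongs to before probing the bits; roughly this pattern, using Lucene's `ReaderUtil` (a sketch: the per-segment caching visible in the hunk is elided, and hits must be visited in increasing doc-ID order for the `Bits` contract to hold):

```java
import java.io.IOException;
import java.util.List;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.lucene.Lucene;

class PerSegmentBitsSketch {
    // Resolve the segment owning a global doc ID, then probe the filter's
    // bits at the segment-local ID. The real sub-phase caches the Bits and
    // only rebuilds them when consecutive hits cross a segment boundary,
    // since creating scorers is costly.
    static boolean matches(IndexReader reader, Weight weight, int globalDocId) throws IOException {
        List<LeafReaderContext> leaves = reader.leaves();
        LeafReaderContext ctx = leaves.get(ReaderUtil.subIndex(globalDocId, leaves));
        Bits bits = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), weight.scorerSupplier(ctx));
        return bits.get(globalDocId - ctx.docBase);
    }
}
```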

+ 1 - 1
core/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java

@@ -377,7 +377,7 @@ public class LuceneTests extends ESTestCase {
             Weight termWeight = new TermQuery(new Term("foo", "bar")).createWeight(searcher, false, 1f);
             assertEquals(1, reader.leaves().size());
             LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0);
-            Bits bits = Lucene.asSequentialAccessBits(leafReaderContext.reader().maxDoc(), termWeight.scorer(leafReaderContext));
+            Bits bits = Lucene.asSequentialAccessBits(leafReaderContext.reader().maxDoc(), termWeight.scorerSupplier(leafReaderContext));
 
             expectThrows(IndexOutOfBoundsException.class, () -> bits.get(-1));
             expectThrows(IndexOutOfBoundsException.class, () -> bits.get(leafReaderContext.reader().maxDoc()));

+ 1 - 2
modules/parent-join/src/main/java/org/elasticsearch/join/aggregations/ParentToChildrenAggregator.java

@@ -104,8 +104,7 @@ public class ParentToChildrenAggregator extends SingleBucketAggregator {
             return LeafBucketCollector.NO_OP_COLLECTOR;
         }
         final SortedSetDocValues globalOrdinals = valuesSource.globalOrdinalsValues(ctx);
-        Scorer parentScorer = parentFilter.scorer(ctx);
-        final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), parentScorer);
+        final Bits parentDocs = Lucene.asSequentialAccessBits(ctx.reader().maxDoc(), parentFilter.scorerSupplier(ctx));
         return new LeafBucketCollector() {
 
             @Override

+ 2 - 1
modules/percolator/src/main/java/org/elasticsearch/percolator/PercolateQuery.java

@@ -27,6 +27,7 @@ import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TwoPhaseIterator;
 import org.apache.lucene.search.Weight;
@@ -139,7 +140,7 @@ final class PercolateQuery extends Query implements Accountable {
                         }
                     };
                 } else {
-                    Scorer verifiedDocsScorer = verifiedMatchesWeight.scorer(leafReaderContext);
+                    ScorerSupplier verifiedDocsScorer = verifiedMatchesWeight.scorerSupplier(leafReaderContext);
                     Bits verifiedDocsBits = Lucene.asSequentialAccessBits(leafReaderContext.reader().maxDoc(), verifiedDocsScorer);
                     return new BaseScorer(this, approximation, queries, percolatorIndexSearcher) {
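
For the percolator, the verified-matches bits serve a two-phase shortcut: documents whose match was proven at index time skip the expensive re-verification against the in-memory index. An illustrative sketch (names are assumptions, not the exact `PercolateQuery` internals):

```java
import java.io.IOException;

import org.apache.lucene.util.Bits;

class VerifiedMatchesSketch {
    interface DocCheck {
        boolean matches(int doc) throws IOException;
    }

    // Docs whose match was proven at index time ("verified") short-circuit
    // the confirmation step; everything else runs the expensive check.
    static boolean matches(int doc, Bits verifiedDocsBits, DocCheck expensiveCheck)
            throws IOException {
        return verifiedDocsBits.get(doc) || expensiveCheck.matches(doc);
    }
}
```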