Browse Source

Lucene#asSequentialBits gets the leadCost backwards. (#48335)

The comment says it needs random-access, but it passes `Long#MAX_VALUE` as the
lead cost, which forces sequential access, it should pass `0` instead. I took
advantage of this fix to improve the logic to leverage an estimation of the
number of times that `Bits#get` gets called to make better decisions.
Adrien Grand 6 years ago
parent
commit
09e9bff14c

+ 15 - 4
server/src/main/java/org/elasticsearch/common/lucene/Lucene.java

@@ -801,16 +801,27 @@ public class Lucene {
     }
 
     /**
-     * Given a {@link ScorerSupplier}, return a {@link Bits} instance that will match
-     * all documents contained in the set. Note that the returned {@link Bits}
-     * instance MUST be consumed in order.
+     * Return a {@link Bits} view of the provided scorer.
+     * <b>NOTE</b>: that the returned {@link Bits} instance MUST be consumed in order.
+     * @see #asSequentialAccessBits(int, ScorerSupplier, long)
      */
     public static Bits asSequentialAccessBits(final int maxDoc, @Nullable ScorerSupplier scorerSupplier) throws IOException {
+        return asSequentialAccessBits(maxDoc, scorerSupplier, 0L);
+    }
+
+    /**
+     * Given a {@link ScorerSupplier}, return a {@link Bits} instance that will match
+     * all documents contained in the set.
+     * <b>NOTE</b>: that the returned {@link Bits} instance MUST be consumed in order.
+     * @param estimatedGetCount an estimation of the number of times that {@link Bits#get} will get called
+     */
+    public static Bits asSequentialAccessBits(final int maxDoc, @Nullable ScorerSupplier scorerSupplier,
+            long estimatedGetCount) throws IOException {
         if (scorerSupplier == null) {
             return new Bits.MatchNoBits(maxDoc);
         }
         // Since we want bits, we need random-access
-        final Scorer scorer = scorerSupplier.get(Long.MAX_VALUE); // this never returns null
+        final Scorer scorer = scorerSupplier.get(estimatedGetCount); // this never returns null
         final TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
         final DocIdSetIterator iterator;
         if (twoPhase == null) {

+ 2 - 1
server/src/main/java/org/elasticsearch/common/lucene/search/function/FunctionScoreQuery.java

@@ -269,6 +269,7 @@ public class FunctionScoreQuery extends Query {
             if (subQueryScorer == null) {
                 return null;
             }
+            final long leadCost = subQueryScorer.iterator().cost();
             final LeafScoreFunction[] leafFunctions = new LeafScoreFunction[functions.length];
             final Bits[] docSets = new Bits[functions.length];
             for (int i = 0; i < functions.length; i++) {
@@ -276,7 +277,7 @@ public class FunctionScoreQuery extends Query {
                 leafFunctions[i] = function.getLeafScoreFunction(context);
                 if (filterWeights[i] != null) {
                     ScorerSupplier filterScorerSupplier = filterWeights[i].scorerSupplier(context);
-                    docSets[i] = Lucene.asSequentialAccessBits(context.reader().maxDoc(), filterScorerSupplier);
+                    docSets[i] = Lucene.asSequentialAccessBits(context.reader().maxDoc(), filterScorerSupplier, leadCost);
                 } else {
                     docSets[i] = new Bits.MatchAllBits(context.reader().maxDoc());
                 }

+ 102 - 0
server/src/test/java/org/elasticsearch/common/lucene/LuceneTests.java

@@ -24,9 +24,11 @@ import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.Field.Store;
 import org.apache.lucene.document.LatLonDocValuesField;
+import org.apache.lucene.document.NumericDocValuesField;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.document.TextField;
 import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReaderContext;
@@ -36,10 +38,15 @@ import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.SegmentInfos;
 import org.apache.lucene.index.SoftDeletesRetentionMergePolicy;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexOrDocValuesQuery;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.ScorerSupplier;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.SortedNumericSortField;
 import org.apache.lucene.search.SortedSetSelector;
@@ -420,6 +427,101 @@ public class LuceneTests extends ESTestCase {
         dir.close();
     }
 
+    private static class UnsupportedQuery extends Query {
+
+        @Override
+        public String toString(String field) {
+            return "Unsupported";
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            return obj instanceof UnsupportedQuery;
+        }
+
+        @Override
+        public int hashCode() {
+            return 42;
+        }
+
+        @Override
+        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+            return new Weight(this) {
+
+                @Override
+                public boolean isCacheable(LeafReaderContext ctx) {
+                    return true;
+                }
+
+                @Override
+                public void extractTerms(Set<Term> terms) {
+                    throw new UnsupportedOperationException();
+                }
+
+                @Override
+                public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+                    throw new UnsupportedOperationException();
+                }
+
+                @Override
+                public Scorer scorer(LeafReaderContext context) throws IOException {
+                    throw new UnsupportedOperationException();
+                }
+
+                @Override
+                public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
+                    return new ScorerSupplier() {
+
+                        @Override
+                        public Scorer get(long leadCost) throws IOException {
+                            throw new UnsupportedOperationException();
+                        }
+
+                        @Override
+                        public long cost() {
+                            return context.reader().maxDoc();
+                        }
+
+                    };
+                }
+
+            };
+        }
+
+    }
+
+    public void testAsSequentialBitsUsesRandomAccess() throws IOException {
+        try (Directory dir = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(dir, new IndexWriterConfig(new KeywordAnalyzer()))) {
+                Document doc = new Document();
+                doc.add(new NumericDocValuesField("foo", 5L));
+                // we need more than 8 documents because doc values are artificially penalized by IndexOrDocValuesQuery
+                for (int i = 0; i < 10; ++i) {
+                    w.addDocument(doc);
+                }
+                w.forceMerge(1);
+                try (IndexReader reader = DirectoryReader.open(w)) {
+                    IndexSearcher searcher = newSearcher(reader);
+                    searcher.setQueryCache(null);
+                    Query query = new IndexOrDocValuesQuery(
+                            new UnsupportedQuery(), NumericDocValuesField.newSlowRangeQuery("foo", 3L, 5L));
+                    Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1f);
+
+                    // Random access by default
+                    ScorerSupplier scorerSupplier = weight.scorerSupplier(reader.leaves().get(0));
+                    Bits bits = Lucene.asSequentialAccessBits(reader.maxDoc(), scorerSupplier);
+                    assertNotNull(bits);
+                    assertTrue(bits.get(0));
+
+                    // Moves to sequential access if Bits#get is called more than the number of matches
+                    ScorerSupplier scorerSupplier2 = weight.scorerSupplier(reader.leaves().get(0));
+                    expectThrows(UnsupportedOperationException.class,
+                            () -> Lucene.asSequentialAccessBits(reader.maxDoc(), scorerSupplier2, reader.maxDoc()));
+                }
+            }
+        }
+    }
+
     /**
      * Test that the "unmap hack" is detected as supported by lucene.
      * This works around the following bug: https://bugs.openjdk.java.net/browse/JDK-4724038