Browse Source

Avoid double term construction in DfsPhase (#38716)

DfsPhase captures terms used for scoring a query in order to build global term statistics across
multiple shards for more accurate scoring. It currently does this by building the query's `Weight`
and calling `extractTerms` on it to collect terms, and then calling `IndexSearcher.termStatistics()`
for each collected term. This duplicates work, however, as the various `Weight` implementations 
will already have collected these statistics at construction time.

This commit replaces this round-about way of collecting stats, instead using a delegating
IndexSearcher that collects the term contexts and statistics when `IndexSearcher.termStatistics()`
is called from the Weight.

It also fixes a bug when using rescorers, where a `QueryRescorer` would calculate distributed term
statistics, but ignore field statistics.  `Rescorer.extractTerms` has been removed, and replaced with
a new method on `RescoreContext` that returns any queries used by the rescore implementation.
The delegating IndexSearcher then collects term contexts and statistics in the same way described
above for each Query.
Alan Woodward 6 years ago
parent
commit
38d29354ff

+ 0 - 6
plugins/examples/rescore/src/main/java/org/elasticsearch/example/rescore/ExampleRescoreBuilder.java

@@ -20,7 +20,6 @@
 package org.elasticsearch.example.rescore;
 
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.Term;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.ScoreDoc;
@@ -46,7 +45,6 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.Objects;
-import java.util.Set;
 
 import static java.util.Collections.singletonList;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@@ -224,9 +222,5 @@ public class ExampleRescoreBuilder extends RescorerBuilder<ExampleRescoreBuilder
             return Explanation.match(context.factor, "test", singletonList(sourceExplanation));
         }
 
-        @Override
-        public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) {
-            // Since we don't use queries there are no terms to extract.
-        }
     }
 }

+ 37 - 88
server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java

@@ -19,14 +19,12 @@
 
 package org.elasticsearch.search.dfs;
 
-import com.carrotsearch.hppc.ObjectHashSet;
 import com.carrotsearch.hppc.ObjectObjectHashMap;
-import com.carrotsearch.hppc.cursors.ObjectCursor;
-
-import org.apache.lucene.index.IndexReaderContext;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.TermStates;
 import org.apache.lucene.search.CollectionStatistics;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.TermStatistics;
 import org.elasticsearch.common.collect.HppcMaps;
@@ -36,9 +34,8 @@ import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.tasks.TaskCancelledException;
 
 import java.io.IOException;
-import java.util.AbstractSet;
-import java.util.Collection;
-import java.util.Iterator;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * Dfs phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
@@ -52,42 +49,46 @@ public class DfsPhase implements SearchPhase {
 
     @Override
     public void execute(SearchContext context) {
-        final ObjectHashSet<Term> termsSet = new ObjectHashSet<>();
         try {
-            context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1f)
-                .extractTerms(new DelegateSet(termsSet));
+            ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
+            Map<Term, TermStatistics> stats = new HashMap<>();
+            IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
+                @Override
+                public TermStatistics termStatistics(Term term, TermStates states) throws IOException {
+                    if (context.isCancelled()) {
+                        throw new TaskCancelledException("cancelled");
+                    }
+                    TermStatistics ts = super.termStatistics(term, states);
+                    if (ts != null) {
+                        stats.put(term, ts);
+                    }
+                    return ts;
+                }
+
+                @Override
+                public CollectionStatistics collectionStatistics(String field) throws IOException {
+                    if (context.isCancelled()) {
+                        throw new TaskCancelledException("cancelled");
+                    }
+                    CollectionStatistics cs = super.collectionStatistics(field);
+                    if (cs != null) {
+                        fieldStatistics.put(field, cs);
+                    }
+                    return cs;
+                }
+            };
+
+            searcher.createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1);
             for (RescoreContext rescoreContext : context.rescore()) {
-                try {
-                    rescoreContext.rescorer().extractTerms(context.searcher(), rescoreContext, new DelegateSet(termsSet));
-                } catch (IOException e) {
-                    throw new IllegalStateException("Failed to extract terms", e);
+                for (Query query : rescoreContext.getQueries()) {
+                    searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
                 }
             }
 
-            Term[] terms = termsSet.toArray(Term.class);
+            Term[] terms = stats.keySet().toArray(new Term[0]);
             TermStatistics[] termStatistics = new TermStatistics[terms.length];
-            IndexReaderContext indexReaderContext = context.searcher().getTopReaderContext();
             for (int i = 0; i < terms.length; i++) {
-                if(context.isCancelled()) {
-                    throw new TaskCancelledException("cancelled");
-                }
-                // LUCENE 4 UPGRADE: cache TermStates?
-                TermStates termContext = TermStates.build(indexReaderContext, terms[i], true);
-                termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
-            }
-
-            ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
-            for (Term term : terms) {
-                assert term.field() != null : "field is null";
-                if (fieldStatistics.containsKey(term.field()) == false) {
-                    final CollectionStatistics collectionStatistics = context.searcher().collectionStatistics(term.field());
-                    if (collectionStatistics != null) {
-                        fieldStatistics.put(term.field(), collectionStatistics);
-                    }
-                    if(context.isCancelled()) {
-                        throw new TaskCancelledException("cancelled");
-                    }
-                }
+                termStatistics[i] = stats.get(terms[i]);
             }
 
             context.dfsResult().termsStatistics(terms, termStatistics)
@@ -95,58 +96,6 @@ public class DfsPhase implements SearchPhase {
                     .maxDoc(context.searcher().getIndexReader().maxDoc());
         } catch (Exception e) {
             throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
-        } finally {
-            termsSet.clear(); // don't hold on to terms
-        }
-    }
-
-    // We need to bridge to JCF world, b/c of Query#extractTerms
-    private static class DelegateSet extends AbstractSet<Term> {
-
-        private final ObjectHashSet<Term> delegate;
-
-        private DelegateSet(ObjectHashSet<Term> delegate) {
-            this.delegate = delegate;
-        }
-
-        @Override
-        public boolean add(Term term) {
-            return delegate.add(term);
-        }
-
-        @Override
-        public boolean addAll(Collection<? extends Term> terms) {
-            boolean result = false;
-            for (Term term : terms) {
-                result = delegate.add(term);
-            }
-            return result;
-        }
-
-        @Override
-        public Iterator<Term> iterator() {
-            final Iterator<ObjectCursor<Term>> iterator = delegate.iterator();
-            return new Iterator<Term>() {
-                @Override
-                public boolean hasNext() {
-                    return iterator.hasNext();
-                }
-
-                @Override
-                public Term next() {
-                    return iterator.next().value;
-                }
-
-                @Override
-                public void remove() {
-                    throw new UnsupportedOperationException();
-                }
-            };
-        }
-
-        @Override
-        public int size() {
-            return delegate.size();
         }
     }
 

+ 8 - 9
server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java

@@ -19,19 +19,19 @@
 
 package org.elasticsearch.search.rescore;
 
-import org.apache.lucene.index.Term;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.TopDocs;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Comparator;
+import java.util.List;
 import java.util.Set;
-import java.util.Collections;
+
 import static java.util.stream.Collectors.toSet;
 
 public final class QueryRescorer implements Rescorer {
@@ -170,6 +170,11 @@ public final class QueryRescorer implements Rescorer {
             this.query = query;
         }
 
+        @Override
+        public List<Query> getQueries() {
+            return Collections.singletonList(query);
+        }
+
         public Query query() {
             return query;
         }
@@ -203,10 +208,4 @@ public final class QueryRescorer implements Rescorer {
         }
     }
 
-    @Override
-    public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException {
-        Query query = ((QueryRescoreContext) rescoreContext).query();
-        searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f).extractTerms(termsSet);
-    }
-
 }

+ 14 - 3
server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java

@@ -19,6 +19,10 @@
 
 package org.elasticsearch.search.rescore;
 
+import org.apache.lucene.search.Query;
+
+import java.util.Collections;
+import java.util.List;
 import java.util.Set;
 
 /**
@@ -29,7 +33,7 @@ import java.util.Set;
 public class RescoreContext {
     private final int windowSize;
     private final Rescorer rescorer;
-    private Set<Integer> resroredDocs; //doc Ids for which rescoring was applied
+    private Set<Integer> rescoredDocs; //doc Ids for which rescoring was applied
 
     /**
      * Build the context.
@@ -55,10 +59,17 @@ public class RescoreContext {
     }
 
     public void setRescoredDocs(Set<Integer> docIds) {
-        resroredDocs = docIds;
+        rescoredDocs = docIds;
     }
 
     public boolean isRescored(int docId) {
-        return resroredDocs.contains(docId);
+        return rescoredDocs.contains(docId);
+    }
+
+    /**
+     * Returns queries associated with the rescorer
+     */
+    public List<Query> getQueries() {
+        return Collections.emptyList();
     }
 }

+ 0 - 9
server/src/main/java/org/elasticsearch/search/rescore/Rescorer.java

@@ -19,14 +19,11 @@
 
 package org.elasticsearch.search.rescore;
 
-import org.apache.lucene.index.Term;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.TopDocs;
-import org.elasticsearch.action.search.SearchType;
 
 import java.io.IOException;
-import java.util.Set;
 
 /**
  * A query rescorer interface used to re-rank the Top-K results of a previously
@@ -61,10 +58,4 @@ public interface Rescorer {
     Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext,
                         Explanation sourceExplanation) throws IOException;
 
-    /**
-     * Extracts all terms needed to execute this {@link Rescorer}. This method
-     * is executed in a distributed frequency collection roundtrip for
-     * {@link SearchType#DFS_QUERY_THEN_FETCH}
-     */
-    void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException;
 }