Browse Source

percolator: Add scoring support to the percolator query

Percolator query documents are scored based on how well they match with the document being percolated.

Closes #13827
Martijn van Groningen 9 years ago
parent
commit
ccb009e45f

+ 5 - 1
core/src/main/java/org/elasticsearch/action/percolate/TransportPercolateAction.java

@@ -38,6 +38,7 @@ import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
 import org.elasticsearch.index.query.PercolatorQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
@@ -207,7 +208,10 @@ public class TransportPercolateAction extends HandledTransportAction<PercolateRe
             boolQueryBuilder.filter(percolatorQueryBuilder);
             searchSource.field("query", boolQueryBuilder);
         } else {
-            searchSource.field("query", percolatorQueryBuilder);
+            // wrapping in a constant score query with boost 0 for bwc reason.
+            // percolator api didn't emit scores before and never included scores
+            // for how well percolator queries matched with the document being percolated
+            searchSource.field("query", new ConstantScoreQueryBuilder(percolatorQueryBuilder).boost(0f));
         }
 
         searchSource.endObject();

+ 92 - 44
core/src/main/java/org/elasticsearch/index/query/PercolatorQuery.java

@@ -28,6 +28,8 @@ import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.SimpleCollector;
+import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TwoPhaseIterator;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.Accountable;
@@ -36,12 +38,10 @@ import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.index.percolator.ExtractQueryTermsService;
 
 import java.io.IOException;
-import java.util.Collection;
 import java.util.Objects;
 import java.util.Set;
 
 import static org.apache.lucene.search.BooleanClause.Occur.FILTER;
-import static org.apache.lucene.search.BooleanClause.Occur.MUST;
 
 public final class PercolatorQuery extends Query implements Accountable {
 
@@ -139,7 +139,14 @@ public final class PercolatorQuery extends Query implements Accountable {
                     int result = twoPhaseIterator.approximation().advance(docId);
                     if (result == docId) {
                         if (twoPhaseIterator.matches()) {
-                            return Explanation.match(scorer.score(), "PercolatorQuery");
+                            if (needsScores) {
+                                QueryRegistry.Leaf percolatorQueries = queryRegistry.getQueries(leafReaderContext);
+                                Query query = percolatorQueries.getQuery(docId);
+                                Explanation detail = percolatorIndexSearcher.explain(query, 0);
+                                return Explanation.match(scorer.score(), "PercolatorQuery", detail);
+                            } else {
+                                return Explanation.match(scorer.score(), "PercolatorQuery");
+                            }
                         }
                     }
                 }
@@ -164,52 +171,46 @@ public final class PercolatorQuery extends Query implements Accountable {
                 }
 
                 final QueryRegistry.Leaf percolatorQueries = queryRegistry.getQueries(leafReaderContext);
-                return new Scorer(this) {
-
-                    @Override
-                    public DocIdSetIterator iterator() {
-                        return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator());
-                    }
-
-                    @Override
-                    public TwoPhaseIterator twoPhaseIterator() {
-                        return new TwoPhaseIterator(approximation.iterator()) {
-                            @Override
-                            public boolean matches() throws IOException {
-                                return matchDocId(approximation.docID());
+                if (needsScores) {
+                    return new BaseScorer(this, approximation, percolatorQueries, percolatorIndexSearcher) {
+
+                        float score;
+
+                        @Override
+                        boolean matchDocId(int docId) throws IOException {
+                            Query query = percolatorQueries.getQuery(docId);
+                            if (query != null) {
+                                TopDocs topDocs = percolatorIndexSearcher.search(query, 1);
+                                if (topDocs.totalHits > 0) {
+                                    score = topDocs.scoreDocs[0].score;
+                                    return true;
+                                } else {
+                                    return false;
+                                }
+                            } else {
+                                return false;
                             }
+                        }
 
-                            @Override
-                            public float matchCost() {
-                                return MATCH_COST;
-                            }
-                        };
-                    }
-
-                    @Override
-                    public float score() throws IOException {
-                        return approximation.score();
-                    }
-
-                    @Override
-                    public int freq() throws IOException {
-                        return approximation.freq();
-                    }
+                        @Override
+                        public float score() throws IOException {
+                            return score;
+                        }
+                    };
+                } else {
+                    return new BaseScorer(this, approximation, percolatorQueries, percolatorIndexSearcher) {
 
-                    @Override
-                    public int docID() {
-                        return approximation.docID();
-                    }
+                        @Override
+                        public float score() throws IOException {
+                            return 0f;
+                        }
 
-                    boolean matchDocId(int docId) throws IOException {
-                        Query query = percolatorQueries.getQuery(docId);
-                        if (query != null) {
-                            return Lucene.exists(percolatorIndexSearcher, query);
-                        } else {
-                            return false;
+                        boolean matchDocId(int docId) throws IOException {
+                            Query query = percolatorQueries.getQuery(docId);
+                            return query != null && Lucene.exists(percolatorIndexSearcher, query);
                         }
-                    }
-                };
+                    };
+                }
             }
         };
     }
@@ -276,4 +277,51 @@ public final class PercolatorQuery extends Query implements Accountable {
 
     }
 
+    static abstract class BaseScorer extends Scorer {
+
+        final Scorer approximation;
+        final QueryRegistry.Leaf percolatorQueries;
+        final IndexSearcher percolatorIndexSearcher;
+
+        BaseScorer(Weight weight, Scorer approximation, QueryRegistry.Leaf percolatorQueries, IndexSearcher percolatorIndexSearcher) {
+            super(weight);
+            this.approximation = approximation;
+            this.percolatorQueries = percolatorQueries;
+            this.percolatorIndexSearcher = percolatorIndexSearcher;
+        }
+
+        @Override
+        public final DocIdSetIterator iterator() {
+            return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator());
+        }
+
+        @Override
+        public final TwoPhaseIterator twoPhaseIterator() {
+            return new TwoPhaseIterator(approximation.iterator()) {
+                @Override
+                public boolean matches() throws IOException {
+                    return matchDocId(approximation.docID());
+                }
+
+                @Override
+                public float matchCost() {
+                    return MATCH_COST;
+                }
+            };
+        }
+
+        @Override
+        public final int freq() throws IOException {
+            return approximation.freq();
+        }
+
+        @Override
+        public final int docID() {
+            return approximation.docID();
+        }
+
+        abstract boolean matchDocId(int docId) throws IOException;
+
+    }
+
 }

+ 71 - 9
core/src/test/java/org/elasticsearch/index/query/PercolatorQueryTests.java

@@ -36,6 +36,7 @@ import org.apache.lucene.queries.BlendedTermQuery;
 import org.apache.lucene.queries.CommonTermsQuery;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
@@ -64,6 +65,7 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.hamcrest.Matchers.arrayWithSize;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
@@ -149,44 +151,104 @@ public class PercolatorQueryTests extends ESTestCase {
                 new MatchAllDocsQuery()
         );
         builder.extractQueryTermsQuery(EXTRACTED_TERMS_FIELD_NAME, UNKNOWN_QUERY_FIELD_NAME);
-        TopDocs topDocs = shardSearcher.search(builder.build(), 10);
+        // no scoring, wrapping it in a constant score query:
+        Query query = new ConstantScoreQuery(builder.build());
+        TopDocs topDocs = shardSearcher.search(query, 10);
         assertThat(topDocs.totalHits, equalTo(5));
         assertThat(topDocs.scoreDocs.length, equalTo(5));
         assertThat(topDocs.scoreDocs[0].doc, equalTo(0));
-        Explanation explanation = shardSearcher.explain(builder.build(), 0);
+        Explanation explanation = shardSearcher.explain(query, 0);
         assertThat(explanation.isMatch(), is(true));
         assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[0].score));
 
-        explanation = shardSearcher.explain(builder.build(), 1);
+        explanation = shardSearcher.explain(query, 1);
         assertThat(explanation.isMatch(), is(false));
 
         assertThat(topDocs.scoreDocs[1].doc, equalTo(2));
-        explanation = shardSearcher.explain(builder.build(), 2);
+        explanation = shardSearcher.explain(query, 2);
         assertThat(explanation.isMatch(), is(true));
         assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[1].score));
 
         assertThat(topDocs.scoreDocs[2].doc, equalTo(3));
-        explanation = shardSearcher.explain(builder.build(), 3);
+        explanation = shardSearcher.explain(query, 3);
         assertThat(explanation.isMatch(), is(true));
         assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[2].score));
 
-        explanation = shardSearcher.explain(builder.build(), 4);
+        explanation = shardSearcher.explain(query, 4);
         assertThat(explanation.isMatch(), is(false));
 
         assertThat(topDocs.scoreDocs[3].doc, equalTo(5));
-        explanation = shardSearcher.explain(builder.build(), 5);
+        explanation = shardSearcher.explain(query, 5);
         assertThat(explanation.isMatch(), is(true));
         assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[3].score));
 
-        explanation = shardSearcher.explain(builder.build(), 6);
+        explanation = shardSearcher.explain(query, 6);
         assertThat(explanation.isMatch(), is(false));
 
         assertThat(topDocs.scoreDocs[4].doc, equalTo(7));
-        explanation = shardSearcher.explain(builder.build(), 7);
+        explanation = shardSearcher.explain(query, 7);
         assertThat(explanation.isMatch(), is(true));
         assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[4].score));
     }
 
+    public void testVariousQueries_withScoring() throws Exception {
+        SpanNearQuery.Builder snp = new SpanNearQuery.Builder("field", true);
+        snp.addClause(new SpanTermQuery(new Term("field", "jumps")));
+        snp.addClause(new SpanTermQuery(new Term("field", "lazy")));
+        snp.addClause(new SpanTermQuery(new Term("field", "dog")));
+        snp.setSlop(2);
+        addPercolatorQuery("1", snp.build());
+        PhraseQuery.Builder pq1 = new PhraseQuery.Builder();
+        pq1.add(new Term("field", "quick"));
+        pq1.add(new Term("field", "brown"));
+        pq1.add(new Term("field", "jumps"));
+        pq1.setSlop(1);
+        addPercolatorQuery("2", pq1.build());
+        BooleanQuery.Builder bq1 = new BooleanQuery.Builder();
+        bq1.add(new TermQuery(new Term("field", "quick")), BooleanClause.Occur.MUST);
+        bq1.add(new TermQuery(new Term("field", "brown")), BooleanClause.Occur.MUST);
+        bq1.add(new TermQuery(new Term("field", "fox")), BooleanClause.Occur.MUST);
+        addPercolatorQuery("3", bq1.build());
+
+        indexWriter.close();
+        directoryReader = DirectoryReader.open(directory);
+        IndexSearcher shardSearcher = newSearcher(directoryReader);
+
+        MemoryIndex memoryIndex = new MemoryIndex();
+        memoryIndex.addField("field", "the quick brown fox jumps over the lazy dog", new WhitespaceAnalyzer());
+        IndexSearcher percolateSearcher = memoryIndex.createSearcher();
+
+        PercolatorQuery.Builder builder = new PercolatorQuery.Builder(
+                "docType",
+                queryRegistry,
+                new BytesArray("{}"),
+                percolateSearcher,
+                new MatchAllDocsQuery()
+        );
+        builder.extractQueryTermsQuery(EXTRACTED_TERMS_FIELD_NAME, UNKNOWN_QUERY_FIELD_NAME);
+        Query query = builder.build();
+        TopDocs topDocs = shardSearcher.search(query, 10);
+        assertThat(topDocs.totalHits, equalTo(3));
+
+        assertThat(topDocs.scoreDocs[0].doc, equalTo(2));
+        Explanation explanation = shardSearcher.explain(query, 2);
+        assertThat(explanation.isMatch(), is(true));
+        assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[0].score));
+        assertThat(explanation.getDetails(), arrayWithSize(1));
+
+        assertThat(topDocs.scoreDocs[1].doc, equalTo(1));
+        explanation = shardSearcher.explain(query, 1);
+        assertThat(explanation.isMatch(), is(true));
+        assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[1].score));
+        assertThat(explanation.getDetails(), arrayWithSize(1));
+
+        assertThat(topDocs.scoreDocs[2].doc, equalTo(0));
+        explanation = shardSearcher.explain(query, 0);
+        assertThat(explanation.isMatch(), is(true));
+        assertThat(explanation.getValue(), equalTo(topDocs.scoreDocs[2].score));
+        assertThat(explanation.getDetails(), arrayWithSize(1));
+    }
+
     public void testDuel() throws Exception {
         int numQueries = scaledRandomIntBetween(32, 256);
         for (int i = 0; i < numQueries; i++) {

+ 5 - 0
core/src/test/java/org/elasticsearch/search/percolator/PercolatorQuerySearchIT.java

@@ -148,12 +148,17 @@ public class PercolatorQuerySearchIT extends ESSingleNodeTestCase {
                 .endObject().bytes();
         SearchResponse response = client().prepareSearch()
                 .setQuery(percolatorQuery("type", source))
+                .addSort("_uid", SortOrder.ASC)
                 .get();
         assertHitCount(response, 4);
         assertThat(response.getHits().getAt(0).getId(), equalTo("1"));
+        assertThat(response.getHits().getAt(0).score(), equalTo(Float.NaN));
         assertThat(response.getHits().getAt(1).getId(), equalTo("2"));
+        assertThat(response.getHits().getAt(1).score(), equalTo(Float.NaN));
         assertThat(response.getHits().getAt(2).getId(), equalTo("3"));
+        assertThat(response.getHits().getAt(2).score(), equalTo(Float.NaN));
         assertThat(response.getHits().getAt(3).getId(), equalTo("4"));
+        assertThat(response.getHits().getAt(3).score(), equalTo(Float.NaN));
     }
 
     public void testPercolatorQueryWithHighlighting() throws Exception {