浏览代码

Fail query when a sort is provided in conjunction with rescorers (#26510)

This change fixes a regression introduced in 6 that removes the skipping of the rescore phase
when a sort other than _score is used.
We now fail the request when a sort is provided in conjunction with rescore instead of just skipping the rescore phase
This commit also adds an assert that checks if the topdocs are sorted by _score after the rescoring.
This is the responsibility of the rescorer to make sure that topdocs are sorted after rescore so we
just check that this condition is met in the rescore phase.
Jim Ferenczi 8 年之前
父节点
当前提交
abe83c4fac

+ 3 - 0
core/src/main/java/org/elasticsearch/search/DefaultSearchContext.java

@@ -212,6 +212,9 @@ final class DefaultSearchContext extends SearchContext {
                             + IndexSettings.MAX_RESULT_WINDOW_SETTING.getKey() + "] index level setting.");
                             + IndexSettings.MAX_RESULT_WINDOW_SETTING.getKey() + "] index level setting.");
         }
         }
         if (rescore != null) {
         if (rescore != null) {
+            if (sort != null) {
+                throw new QueryPhaseExecutionException(this, "Cannot use [sort] option in conjunction with [rescore].");
+            }
             int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow();
             int maxWindow = indexService.getIndexSettings().getMaxRescoreWindow();
             for (RescoreContext rescoreContext: rescore) {
             for (RescoreContext rescoreContext: rescore) {
                 if (rescoreContext.getWindowSize() > maxWindow) {
                 if (rescoreContext.getWindowSize() > maxWindow) {

+ 5 - 2
core/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java

@@ -289,8 +289,11 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
         } else {
         } else {
             int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
             int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
             final boolean rescore = searchContext.rescore().isEmpty() == false;
             final boolean rescore = searchContext.rescore().isEmpty() == false;
-            for (RescoreContext rescoreContext : searchContext.rescore()) {
-                numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
+            if (rescore) {
+                assert searchContext.sort() == null;
+                for (RescoreContext rescoreContext : searchContext.rescore()) {
+                    numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
+                }
             }
             }
             return new SimpleTopDocsCollectorContext(searchContext.sort(),
             return new SimpleTopDocsCollectorContext(searchContext.sort(),
                                                      searchContext.searchAfter(),
                                                      searchContext.searchAfter(),

+ 24 - 0
core/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java

@@ -19,6 +19,9 @@
 
 
 package org.elasticsearch.search.rescore;
 package org.elasticsearch.search.rescore;
 
 
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.component.AbstractComponent;
 import org.elasticsearch.common.component.AbstractComponent;
@@ -47,10 +50,31 @@ public class RescorePhase extends AbstractComponent implements SearchPhase {
             TopDocs topDocs = context.queryResult().topDocs();
             TopDocs topDocs = context.queryResult().topDocs();
             for (RescoreContext ctx : context.rescore()) {
             for (RescoreContext ctx : context.rescore()) {
                 topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
                 topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
+                // It is the responsibility of the rescorer to sort the resulted top docs,
+                // here we only assert that this condition is met.
+                assert context.sort() == null && topDocsSortedByScore(topDocs): "topdocs should be sorted after rescore";
             }
             }
             context.queryResult().topDocs(topDocs, context.queryResult().sortValueFormats());
             context.queryResult().topDocs(topDocs, context.queryResult().sortValueFormats());
         } catch (IOException e) {
         } catch (IOException e) {
             throw new ElasticsearchException("Rescore Phase Failed", e);
             throw new ElasticsearchException("Rescore Phase Failed", e);
         }
         }
     }
     }
+
+    /**
+     * Returns true if the provided docs are sorted by score.
+     */
+    private boolean topDocsSortedByScore(TopDocs topDocs) {
+        if (topDocs == null || topDocs.scoreDocs == null || topDocs.scoreDocs.length < 2) {
+            return true;
+        }
+        float lastScore = topDocs.scoreDocs[0].score;
+        for (int i = 1; i < topDocs.scoreDocs.length; i++) {
+            ScoreDoc doc = topDocs.scoreDocs[i];
+            if (Float.compare(doc.score, lastScore) > 0) {
+                return false;
+            }
+            lastScore = doc.score;
+        }
+        return true;
+    }
 }
 }

+ 45 - 0
core/src/test/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java

@@ -38,6 +38,9 @@ import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.rescore.QueryRescoreMode;
 import org.elasticsearch.search.rescore.QueryRescoreMode;
 import org.elasticsearch.search.rescore.QueryRescorerBuilder;
 import org.elasticsearch.search.rescore.QueryRescorerBuilder;
+import org.elasticsearch.search.sort.SortBuilder;
+import org.elasticsearch.search.sort.SortBuilders;
+import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.ESIntegTestCase;
 
 
 import java.util.Arrays;
 import java.util.Arrays;
@@ -66,6 +69,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSeco
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertThirdHit;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertThirdHit;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
@@ -705,4 +709,45 @@ public class QueryRescorerIT extends ESIntegTestCase {
 
 
         assertEquals(4, request.get().getHits().getHits().length);
         assertEquals(4, request.get().getHits().getHits().length);
     }
     }
+
+    public void testRescorePhaseWithInvalidSort() throws Exception {
+        assertAcked(prepareCreate("test"));
+        for(int i=0;i<5;i++) {
+            client().prepareIndex("test", "type", ""+i).setSource("number", 0).get();
+        }
+        refresh();
+
+        Exception exc = expectThrows(Exception.class,
+            () -> client().prepareSearch()
+                .addSort(SortBuilders.fieldSort("number"))
+                .setTrackScores(true)
+                .addRescorer(new QueryRescorerBuilder(matchAllQuery()), 50)
+                .get()
+        );
+        assertNotNull(exc.getCause());
+        assertThat(exc.getCause().getMessage(),
+            containsString("Cannot use [sort] option in conjunction with [rescore]."));
+
+        exc = expectThrows(Exception.class,
+            () -> client().prepareSearch()
+                .addSort(SortBuilders.fieldSort("number"))
+                .addSort(SortBuilders.scoreSort())
+                .setTrackScores(true)
+                .addRescorer(new QueryRescorerBuilder(matchAllQuery()), 50)
+                .get()
+        );
+        assertNotNull(exc.getCause());
+        assertThat(exc.getCause().getMessage(),
+            containsString("Cannot use [sort] option in conjunction with [rescore]."));
+
+        SearchResponse resp = client().prepareSearch().addSort(SortBuilders.scoreSort())
+            .setTrackScores(true)
+            .addRescorer(new QueryRescorerBuilder(matchAllQuery()).setRescoreQueryWeight(100.0f), 50)
+            .get();
+        assertThat(resp.getHits().totalHits, equalTo(5L));
+        assertThat(resp.getHits().getHits().length, equalTo(5));
+        for (SearchHit hit : resp.getHits().getHits()) {
+            assertThat(hit.getScore(), equalTo(101f));
+        }
+    }
 }
 }