فهرست منبع

Adding new RankFeature phase (#107099)

In this PR we add a new search phase, in-between query and fetch, that is responsible for applying any reranking needed.

The idea is to trim down the query phase results down to rank_window_size, reach out to the shards to extract any additional feature data if needed, and then use this information to rerank the top results, trim them down to size and pass them to fetch phase.
Panagiotis Bailis 1 سال پیش
والد
کامیت
af0c9566e5

+ 13 - 9
server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

@@ -33,15 +33,21 @@ final class FetchSearchPhase extends SearchPhase {
     private final BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
     private final SearchPhaseContext context;
     private final Logger logger;
-    private final SearchPhaseResults<SearchPhaseResult> resultConsumer;
     private final SearchProgressListener progressListener;
     private final AggregatedDfs aggregatedDfs;
+    private final SearchPhaseController.ReducedQueryPhase reducedQueryPhase;
 
-    FetchSearchPhase(SearchPhaseResults<SearchPhaseResult> resultConsumer, AggregatedDfs aggregatedDfs, SearchPhaseContext context) {
+    FetchSearchPhase(
+        SearchPhaseResults<SearchPhaseResult> resultConsumer,
+        AggregatedDfs aggregatedDfs,
+        SearchPhaseContext context,
+        SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+    ) {
         this(
             resultConsumer,
             aggregatedDfs,
             context,
+            reducedQueryPhase,
             (response, queryPhaseResults) -> new ExpandSearchPhase(
                 context,
                 response.hits,
@@ -54,6 +60,7 @@ final class FetchSearchPhase extends SearchPhase {
         SearchPhaseResults<SearchPhaseResult> resultConsumer,
         AggregatedDfs aggregatedDfs,
         SearchPhaseContext context,
+        SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
         BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory
     ) {
         super("fetch");
@@ -72,18 +79,16 @@ final class FetchSearchPhase extends SearchPhase {
         this.nextPhaseFactory = nextPhaseFactory;
         this.context = context;
         this.logger = context.getLogger();
-        this.resultConsumer = resultConsumer;
         this.progressListener = context.getTask().getProgressListener();
+        this.reducedQueryPhase = reducedQueryPhase;
     }
 
     @Override
     public void run() {
         context.execute(new AbstractRunnable() {
+
             @Override
-            protected void doRun() throws Exception {
-                // we do the heavy lifting in this inner run method where we reduce aggs etc. that's why we fork this phase
-                // off immediately instead of forking when we send back the response to the user since there we only need
-                // to merge together the fetched results which is a linear operation.
+            protected void doRun() {
                 innerRun();
             }
 
@@ -94,9 +99,8 @@ final class FetchSearchPhase extends SearchPhase {
         });
     }
 
-    private void innerRun() throws Exception {
+    private void innerRun() {
         final int numShards = context.getNumShards();
-        final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = resultConsumer.reduce();
         // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
         // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
         final boolean queryAndFetchOptimization = queryResults.length() == 1

+ 77 - 0
server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+package org.elasticsearch.action.search;
+
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.dfs.AggregatedDfs;
+
+/**
+ * This search phase is responsible for executing any re-ranking needed for the given search request, iff that is applicable.
+ * It starts by retrieving {code num_shards * window_size} results from the query phase and reduces them to a global list of
+ * the top {@code window_size} results. It then reaches out to the shards to extract the needed feature data,
+ * and finally passes all this information to the appropriate {@code RankFeatureRankCoordinatorContext} which is responsible for reranking
+ * the results. If no rank query is specified, it proceeds directly to the next phase (FetchSearchPhase) by first reducing the results.
+ */
+public final class RankFeaturePhase extends SearchPhase {
+
+    private final SearchPhaseContext context;
+    private final SearchPhaseResults<SearchPhaseResult> queryPhaseResults;
+    private final SearchPhaseResults<SearchPhaseResult> rankPhaseResults;
+
+    private final AggregatedDfs aggregatedDfs;
+
+    RankFeaturePhase(SearchPhaseResults<SearchPhaseResult> queryPhaseResults, AggregatedDfs aggregatedDfs, SearchPhaseContext context) {
+        super("rank-feature");
+        if (context.getNumShards() != queryPhaseResults.getNumShards()) {
+            throw new IllegalStateException(
+                "number of shards must match the length of the query results but doesn't:"
+                    + context.getNumShards()
+                    + "!="
+                    + queryPhaseResults.getNumShards()
+            );
+        }
+        this.context = context;
+        this.queryPhaseResults = queryPhaseResults;
+        this.aggregatedDfs = aggregatedDfs;
+        this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards());
+        context.addReleasable(rankPhaseResults);
+    }
+
+    @Override
+    public void run() {
+        context.execute(new AbstractRunnable() {
+            @Override
+            protected void doRun() throws Exception {
+                // we need to reduce the results at this point instead of fetch phase, so we fork this process similarly to how
+                // was set up at FetchSearchPhase.
+
+                // we do the heavy lifting in this inner run method where we reduce aggs etc
+                innerRun();
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                context.onPhaseFailure(RankFeaturePhase.this, "", e);
+            }
+        });
+    }
+
+    private void innerRun() throws Exception {
+        // other than running reduce, this is currently close to a no-op
+        SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
+        moveToNextPhase(queryPhaseResults, reducedQueryPhase);
+    }
+
+    private void moveToNextPhase(
+        SearchPhaseResults<SearchPhaseResult> phaseResults,
+        SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+    ) {
+        context.executeNextPhase(this, new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase));
+    }
+}

+ 1 - 1
server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

@@ -100,7 +100,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
             aggregatedDfs,
             mergedKnnResults,
             queryPhaseResultConsumer,
-            (queryResults) -> new FetchSearchPhase(queryResults, aggregatedDfs, context),
+            (queryResults) -> new RankFeaturePhase(queryResults, aggregatedDfs, context),
             context
         );
     }

+ 1 - 1
server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

@@ -122,7 +122,7 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
 
     @Override
     protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
-        return new FetchSearchPhase(results, null, this);
+        return new RankFeaturePhase(results, null, this);
     }
 
     private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {

+ 18 - 7
server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java

@@ -47,7 +47,7 @@ import static org.hamcrest.Matchers.nullValue;
 public class FetchSearchPhaseTests extends ESTestCase {
     private static final long FETCH_PROFILE_TIME = 555;
 
-    public void testShortcutQueryAndFetchOptimization() {
+    public void testShortcutQueryAndFetchOptimization() throws Exception {
         SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder());
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
         try (
@@ -99,11 +99,12 @@ public class FetchSearchPhaseTests extends ESTestCase {
             } else {
                 numHits = 0;
             }
-
+            SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce();
             FetchSearchPhase phase = new FetchSearchPhase(
                 results,
                 null,
                 mockSearchPhaseContext,
+                reducedQueryPhase,
                 (searchResponse, scrollId) -> new SearchPhase("test") {
                     @Override
                     public void run() {
@@ -141,7 +142,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
         }
     }
 
-    public void testFetchTwoDocument() {
+    public void testFetchTwoDocument() throws Exception {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
         SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder());
         try (
@@ -231,10 +232,12 @@ public class FetchSearchPhaseTests extends ESTestCase {
                     }
                 }
             };
+            SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce();
             FetchSearchPhase phase = new FetchSearchPhase(
                 results,
                 null,
                 mockSearchPhaseContext,
+                reducedQueryPhase,
                 (searchResponse, scrollId) -> new SearchPhase("test") {
                     @Override
                     public void run() {
@@ -262,7 +265,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
         }
     }
 
-    public void testFailFetchOneDoc() {
+    public void testFailFetchOneDoc() throws Exception {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
         SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder());
         try (
@@ -343,10 +346,12 @@ public class FetchSearchPhaseTests extends ESTestCase {
                     }
                 }
             };
+            SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce();
             FetchSearchPhase phase = new FetchSearchPhase(
                 results,
                 null,
                 mockSearchPhaseContext,
+                reducedQueryPhase,
                 (searchResponse, scrollId) -> new SearchPhase("test") {
                     @Override
                     public void run() {
@@ -390,7 +395,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
         }
     }
 
-    public void testFetchDocsConcurrently() throws InterruptedException {
+    public void testFetchDocsConcurrently() throws Exception {
         int resultSetSize = randomIntBetween(0, 100);
         // we use at least 2 hits otherwise this is subject to single shard optimization and we trip an assert...
         int numHits = randomIntBetween(2, 100); // also numshards --> 1 hit per shard
@@ -454,10 +459,12 @@ public class FetchSearchPhaseTests extends ESTestCase {
                 }
             };
             CountDownLatch latch = new CountDownLatch(1);
+            SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce();
             FetchSearchPhase phase = new FetchSearchPhase(
                 results,
                 null,
                 mockSearchPhaseContext,
+                reducedQueryPhase,
                 (searchResponse, scrollId) -> new SearchPhase("test") {
                     @Override
                     public void run() {
@@ -509,7 +516,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
         }
     }
 
-    public void testExceptionFailsPhase() {
+    public void testExceptionFailsPhase() throws Exception {
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
         SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder());
         try (
@@ -600,10 +607,12 @@ public class FetchSearchPhaseTests extends ESTestCase {
                     }
                 }
             };
+            SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce();
             FetchSearchPhase phase = new FetchSearchPhase(
                 results,
                 null,
                 mockSearchPhaseContext,
+                reducedQueryPhase,
                 (searchResponse, scrollId) -> new SearchPhase("test") {
                     @Override
                     public void run() {
@@ -624,7 +633,7 @@ public class FetchSearchPhaseTests extends ESTestCase {
         }
     }
 
-    public void testCleanupIrrelevantContexts() { // contexts that are not fetched should be cleaned up
+    public void testCleanupIrrelevantContexts() throws Exception { // contexts that are not fetched should be cleaned up
         MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
         SearchPhaseController controller = new SearchPhaseController((t, s) -> InternalAggregationTestCase.emptyReduceContextBuilder());
         try (
@@ -705,10 +714,12 @@ public class FetchSearchPhaseTests extends ESTestCase {
                     }
                 }
             };
+            SearchPhaseController.ReducedQueryPhase reducedQueryPhase = results.reduce();
             FetchSearchPhase phase = new FetchSearchPhase(
                 results,
                 null,
                 mockSearchPhaseContext,
+                reducedQueryPhase,
                 (searchResponse, scrollId) -> new SearchPhase("test") {
                     @Override
                     public void run() {