浏览代码

Support rewriteAndFetch on the shard in `DfsQueryPhase` (#97152)

The following phases already support `rewriteAndFetch` on the shard:  -
The `DfsPhase` for gathering term statistics  - QueryThenFetch phase

It seems like an oversight, and honestly prevents more complex rescorers
(like the inference_rescorer) from working in the DFSQuery phase.

This commit adds rewriteAndFetch for the `DfsQuery` phase to bring it
into parity with the other document shard query phases.
Benjamin Trent 2 年之前
父节点
当前提交
35cc35dda0

+ 30 - 26
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -713,33 +713,37 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest());
         final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest());
         final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
-        runAsync(getExecutor(readerContext.indexShard()), () -> {
-            readerContext.setAggregatedDfs(request.dfs());
-            try (
-                SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.QUERY, true);
-                SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)
-            ) {
-                searchContext.searcher().setAggregatedDfs(request.dfs());
-                QueryPhase.execute(searchContext);
-                if (searchContext.queryResult().hasSearchContext() == false && readerContext.singleSession()) {
-                    // no hits, we can release the context since there will be no fetch phase
-                    freeReaderContext(readerContext.id());
+        rewriteAndFetchShardRequest(readerContext.indexShard(), shardSearchRequest, listener.delegateFailure((l, rewritten) -> {
+            // fork the execution in the search thread pool
+            runAsync(getExecutor(readerContext.indexShard()), () -> {
+                readerContext.setAggregatedDfs(request.dfs());
+                try (
+                    SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.QUERY, true);
+                    SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)
+                ) {
+                    searchContext.searcher().setAggregatedDfs(request.dfs());
+                    QueryPhase.execute(searchContext);
+                    if (searchContext.queryResult().hasSearchContext() == false && readerContext.singleSession()) {
+                        // no hits, we can release the context since there will be no fetch phase
+                        freeReaderContext(readerContext.id());
+                    }
+                    executor.success();
+                    // Pass the rescoreDocIds to the queryResult to send them the coordinating node
+                    // and receive them back in the fetch phase.
+                    // We also pass the rescoreDocIds to the LegacyReaderContext in case the search state needs to stay in the data node.
+                    final RescoreDocIds rescoreDocIds = searchContext.rescoreDocIds();
+                    searchContext.queryResult().setRescoreDocIds(rescoreDocIds);
+                    readerContext.setRescoreDocIds(rescoreDocIds);
+                    searchContext.queryResult().incRef();
+                    return searchContext.queryResult();
+                } catch (Exception e) {
+                    assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
+                    logger.trace("Query phase failed", e);
+                    // we handle the failure in the failure listener below
+                    throw e;
                 }
-                executor.success();
-                // Pass the rescoreDocIds to the queryResult to send them the coordinating node and receive them back in the fetch phase.
-                // We also pass the rescoreDocIds to the LegacyReaderContext in case the search state needs to stay in the data node.
-                final RescoreDocIds rescoreDocIds = searchContext.rescoreDocIds();
-                searchContext.queryResult().setRescoreDocIds(rescoreDocIds);
-                readerContext.setRescoreDocIds(rescoreDocIds);
-                searchContext.queryResult().incRef();
-                return searchContext.queryResult();
-            } catch (Exception e) {
-                assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
-                logger.trace("Query phase failed", e);
-                // we handle the failure in the failure listener below
-                throw e;
-            }
-        }, wrapFailureListener(listener, readerContext, markAsUsed));
+            }, wrapFailureListener(l, readerContext, markAsUsed));
+        }));
     }
 
     private Executor getExecutor(IndexShard indexShard) {

+ 119 - 0
server/src/test/java/org/elasticsearch/search/SearchServiceTests.java

@@ -11,9 +11,12 @@ import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.FilterDirectoryReader;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
 import org.apache.lucene.store.AlreadyClosedException;
+import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchTimeoutException;
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.index.IndexResponse;
@@ -39,6 +42,7 @@ import org.elasticsearch.cluster.routing.ShardRoutingState;
 import org.elasticsearch.cluster.routing.TestShardRouting;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.TimeValue;
@@ -48,6 +52,7 @@ import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.index.IndexService;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.engine.Engine;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.MatchAllQueryBuilder;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
@@ -77,6 +82,7 @@ import org.elasticsearch.search.aggregations.support.AggregationContext;
 import org.elasticsearch.search.aggregations.support.ValueType;
 import org.elasticsearch.search.builder.PointInTimeBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.dfs.AggregatedDfs;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.ShardFetchRequest;
 import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
@@ -85,6 +91,7 @@ import org.elasticsearch.search.internal.ReaderContext;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.search.query.QuerySearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.suggest.SuggestBuilder;
 import org.elasticsearch.tasks.TaskCancelHelper;
@@ -112,6 +119,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.Supplier;
 
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
@@ -1861,6 +1869,47 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         }
     }
 
+    public void testDfsQueryPhaseRewrite() {
+        createIndex("index");
+        client().prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get();
+        final SearchService service = getInstanceFromNode(SearchService.class);
+        final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
+        final IndexService indexService = indicesService.indexServiceSafe(resolveIndex("index"));
+        final IndexShard indexShard = indexService.getShard(0);
+        SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
+        searchRequest.source(SearchSourceBuilder.searchSource().query(new TestRewriteCounterQueryBuilder()));
+        ShardSearchRequest request = new ShardSearchRequest(
+            OriginalIndices.NONE,
+            searchRequest,
+            indexShard.shardId(),
+            0,
+            1,
+            AliasFilter.EMPTY,
+            1.0f,
+            -1,
+            null
+        );
+        PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
+        final Engine.SearcherSupplier reader = indexShard.acquireSearcherSupplier();
+        ReaderContext context = service.createAndPutReaderContext(
+            request,
+            indexService,
+            indexShard,
+            reader,
+            SearchService.KEEPALIVE_INTERVAL_SETTING.get(Settings.EMPTY).millis()
+        );
+        service.executeQueryPhase(
+            new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)),
+            new SearchShardTask(42L, "", "", "", null, Collections.emptyMap()),
+            plainActionFuture
+        );
+
+        plainActionFuture.actionGet();
+        assertThat(((TestRewriteCounterQueryBuilder) request.source().query()).asyncRewriteCount, equalTo(1));
+        final ShardSearchContextId contextId = context.id();
+        assertTrue(service.freeReaderContext(contextId));
+    }
+
     private ReaderContext createReaderContext(IndexService indexService, IndexShard indexShard) {
         return new ReaderContext(
             new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()),
@@ -1871,4 +1920,74 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
             false
         );
     }
+
+    private static class TestRewriteCounterQueryBuilder extends AbstractQueryBuilder<TestRewriteCounterQueryBuilder> {
+
+        final int asyncRewriteCount;
+        final Supplier<Boolean> fetched;
+
+        TestRewriteCounterQueryBuilder() {
+            asyncRewriteCount = 0;
+            fetched = null;
+        }
+
+        private TestRewriteCounterQueryBuilder(int asyncRewriteCount, Supplier<Boolean> fetched) {
+            this.asyncRewriteCount = asyncRewriteCount;
+            this.fetched = fetched;
+        }
+
+        @Override
+        public String getWriteableName() {
+            return "test_query";
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            return TransportVersion.ZERO;
+        }
+
+        @Override
+        protected void doWriteTo(StreamOutput out) throws IOException {}
+
+        @Override
+        protected void doXContent(XContentBuilder builder, Params params) throws IOException {}
+
+        @Override
+        protected Query doToQuery(SearchExecutionContext context) throws IOException {
+            return new MatchAllDocsQuery();
+        }
+
+        @Override
+        protected boolean doEquals(TestRewriteCounterQueryBuilder other) {
+            return true;
+        }
+
+        @Override
+        protected int doHashCode() {
+            return 42;
+        }
+
+        @Override
+        protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
+            if (asyncRewriteCount > 0) {
+                return this;
+            }
+            if (fetched != null) {
+                if (fetched.get() == null) {
+                    return this;
+                }
+                assert fetched.get();
+                return new TestRewriteCounterQueryBuilder(1, null);
+            }
+            if (queryRewriteContext.convertToDataRewriteContext() != null) {
+                SetOnce<Boolean> awaitingFetch = new SetOnce<>();
+                queryRewriteContext.registerAsyncAction((c, l) -> {
+                    awaitingFetch.set(true);
+                    l.onResponse(null);
+                });
+                return new TestRewriteCounterQueryBuilder(0, awaitingFetch::get);
+            }
+            return this;
+        }
+    }
 }

+ 1 - 4
x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlRescorerIT.java

@@ -146,7 +146,6 @@ public class MlRescorerIT extends ESRestTestCase {
     }
 
     @SuppressWarnings("unchecked")
-    @AwaitsFix(bugUrl = "Fix DFS rewrite for rescorers")
     public void testLtrSimpleDFS() throws Exception {
         Response searchResponse = searchDfs("""
             {
@@ -222,8 +221,7 @@ public class MlRescorerIT extends ESRestTestCase {
         Map<String, Object> response = responseAsMap(searchResponse);
         assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(17.0, 17.0));
 
-        // TODO add DFS support for rescorer rewrites
-        /* searchResponse = searchCanMatch("""
+        searchResponse = searchCanMatch("""
             { "query": {
               "match": { "product": { "query": "TV"}}
             },
@@ -238,7 +236,6 @@ public class MlRescorerIT extends ESRestTestCase {
 
         response = responseAsMap(searchResponse);
         assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(17.0, 17.0));
-        */
     }
 
     private void indexData(String data) throws IOException {