
Adding RankFeature implementation (#108538)

Panagiotis Bailis 1 year ago
commit 4a1d7426d7
46 changed files with 4207 additions and 109 deletions
  1. + 5 - 0  docs/changelog/108538.yaml
  2. + 811 - 0  server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java
  3. + 1 - 0  server/src/main/java/module-info.java
  4. + 1 - 1  server/src/main/java/org/elasticsearch/TransportVersions.java
  5. + 18 - 0  server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java
  6. + 23 - 45  server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java
  7. + 166 - 9  server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java
  8. + 34 - 0  server/src/main/java/org/elasticsearch/action/search/SearchPhase.java
  9. + 2 - 2  server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
  10. + 32 - 0  server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java
  11. + 1 - 1  server/src/main/java/org/elasticsearch/action/search/SearchRequest.java
  12. + 1 - 0  server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java
  13. + 30 - 0  server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java
  14. + 1 - 0  server/src/main/java/org/elasticsearch/node/NodeConstruction.java
  15. + 3 - 0  server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java
  16. + 13 - 0  server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java
  17. + 5 - 0  server/src/main/java/org/elasticsearch/search/SearchModule.java
  18. + 16 - 0  server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java
  19. + 39 - 0  server/src/main/java/org/elasticsearch/search/SearchService.java
  20. + 0 - 1  server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java
  21. + 11 - 0  server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java
  22. + 5 - 0  server/src/main/java/org/elasticsearch/search/internal/SearchContext.java
  23. + 29 - 26  server/src/main/java/org/elasticsearch/search/query/QueryPhase.java
  24. + 22 - 1  server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java
  25. + 15 - 4  server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java
  26. + 96 - 0  server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java
  27. + 39 - 0  server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java
  28. + 54 - 0  server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java
  29. + 70 - 0  server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java
  30. + 99 - 0  server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java
  31. + 101 - 0  server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java
  32. + 68 - 0  server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardResult.java
  33. + 1170 - 0  server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java
  34. + 2 - 2  server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java
  35. + 739 - 14  server/src/test/java/org/elasticsearch/search/SearchServiceTests.java
  36. + 409 - 0  server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java
  37. + 2 - 0  server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java
  38. + 4 - 0  test/framework/src/main/java/org/elasticsearch/node/MockNode.java
  39. + 3 - 0  test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java
  40. + 18 - 1  test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java
  41. + 11 - 0  test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java
  42. + 4 - 0  test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java
  43. + 14 - 0  x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java
  44. + 18 - 1  x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java
  45. + 1 - 1  x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java
  46. + 1 - 0  x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java

+ 5 - 0
docs/changelog/108538.yaml

@@ -0,0 +1,5 @@
+pr: 108538
+summary: Adding RankFeature search phase implementation
+area: Search
+type: feature
+issues: []

+ 811 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java

@@ -0,0 +1,811 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank;
+
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.search.SearchPhaseController;
+import org.elasticsearch.action.search.SearchPhaseExecutionException;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.SearchPlugin;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.query.QuerySearchResult;
+import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
+import static org.elasticsearch.index.query.QueryBuilders.constantScoreQuery;
+import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailuresAndResponse;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasId;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
+import static org.hamcrest.Matchers.equalTo;
+
+@ESIntegTestCase.ClusterScope(minNumDataNodes = 3)
+public class FieldBasedRerankerIT extends ESIntegTestCase {
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return List.of(FieldBasedRerankerPlugin.class);
+    }
+
+    public void testFieldBasedReranker() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName);
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        assertNoFailuresAndResponse(
+            prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField))
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(10),
+            response -> {
+                assertHitCount(response, 5L);
+                int rank = 1;
+                for (SearchHit searchHit : response.getHits().getHits()) {
+                    assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
+                    assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f);
+                    assertThat(searchHit, hasRank(rank));
+                    assertNotNull(searchHit.getFields().get(searchField));
+                    rank++;
+                }
+            }
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testFieldBasedRerankerPagination() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName);
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        assertResponse(
+            prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField))
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(2)
+                .setFrom(2),
+            response -> {
+                assertHitCount(response, 5L);
+                int rank = 3;
+                for (SearchHit searchHit : response.getHits().getHits()) {
+                    assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
+                    assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f);
+                    assertThat(searchHit, hasRank(rank));
+                    assertNotNull(searchHit.getFields().get(searchField));
+                    rank++;
+                }
+            }
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testFieldBasedRerankerPaginationOutsideOfBounds() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName);
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        assertNoFailuresAndResponse(
+            prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField))
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(2)
+                .setFrom(10),
+            response -> {
+                assertHitCount(response, 5L);
+                assertEquals(0, response.getHits().getHits().length);
+            }
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testNotAllShardsArePresentInFetchPhase() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 10).build());
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A").setRouting("A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B").setRouting("B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C").setRouting("C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D").setRouting("C"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E").setRouting("C")
+        );
+
+        assertNoFailuresAndResponse(
+            prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(0.1f))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(0.3f))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(0.3f))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(0.3f))
+            )
+                .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField))
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(2),
+            response -> {
+                assertHitCount(response, 4L);
+                assertEquals(2, response.getHits().getHits().length);
+                int rank = 1;
+                for (SearchHit searchHit : response.getHits().getHits()) {
+                    assertThat(searchHit, hasId(String.valueOf(5 - (rank - 1))));
+                    assertEquals(searchHit.getScore(), (0.5f - ((rank - 1) * 0.1f)), 1e-5f);
+                    assertThat(searchHit, hasRank(rank));
+                    assertNotNull(searchHit.getFields().get(searchField));
+                    rank++;
+                }
+            }
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testFieldBasedRerankerNoMatchingDocs() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName);
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        assertNoFailuresAndResponse(
+            prepareSearch().setQuery(boolQuery().should(constantScoreQuery(matchQuery(searchField, "F")).boost(randomFloat())))
+                .setRankBuilder(new FieldBasedRankBuilder(rankWindowSize, rankFeatureField))
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(10),
+            response -> {
+                assertHitCount(response, 0L);
+            }
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testQueryPhaseShardThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        // this test is independent of the number of shards, as we will always reach QueryPhaseRankShardContext#combineQueryPhaseResults
+        // even with no results. So, when we get back to the coordinator, all shards will have failed, and the whole response
+        // will be marked as a failure
+        createIndex(indexName);
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        expectThrows(
+            SearchPhaseExecutionException.class,
+            () -> prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(
+                    new ThrowingRankBuilder(
+                        rankWindowSize,
+                        rankFeatureField,
+                        ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_QUERY_PHASE_SHARD_CONTEXT.name()
+                    )
+                )
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(10)
+                .get()
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testQueryPhaseCoordinatorThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName);
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        // when we throw on the coordinator, the onPhaseFailure handler will be invoked, which in turn will mark the whole
+        // search request as a failure (i.e. no partial results)
+        expectThrows(
+            SearchPhaseExecutionException.class,
+            () -> prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(
+                    new ThrowingRankBuilder(
+                        rankWindowSize,
+                        rankFeatureField,
+                        ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_QUERY_PHASE_COORDINATOR_CONTEXT.name()
+                    )
+                )
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(10)
+                .get()
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testRankFeaturePhaseShardThrowingRankBuilderAllContextsAreClosedPartialFailures() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 10).build());
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        // we have 10 shards and 5 documents, so when the exception is thrown we know that not all shards will report failures
+        assertResponse(
+            prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(
+                    new ThrowingRankBuilder(
+                        rankWindowSize,
+                        rankFeatureField,
+                        ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT.name()
+                    )
+                )
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(10),
+            response -> {
+                assertTrue(response.getFailedShards() > 0);
+                assertTrue(
+                    Arrays.stream(response.getShardFailures())
+                        .allMatch(failure -> failure.getCause().getMessage().contains("rfs - simulated failure"))
+                );
+                assertHitCount(response, 5);
+                assertTrue(response.getHits().getHits().length == 0);
+            }
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testRankFeaturePhaseShardThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        // we have 1 shard and 5 documents, so when the exception is thrown we know that all shards will have failed
+        createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1).build());
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        expectThrows(
+            SearchPhaseExecutionException.class,
+            () -> prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(
+                    new ThrowingRankBuilder(
+                        rankWindowSize,
+                        rankFeatureField,
+                        ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT.name()
+                    )
+                )
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(10)
+                .get()
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    public void testRankFeaturePhaseCoordinatorThrowingRankBuilderAllContextsAreClosedAllShardsFail() throws Exception {
+        final String indexName = "test_index";
+        final String rankFeatureField = "rankFeatureField";
+        final String searchField = "searchField";
+        final int rankWindowSize = 10;
+
+        createIndex(indexName);
+        indexRandom(
+            true,
+            prepareIndex(indexName).setId("1").setSource(rankFeatureField, 0.1, searchField, "A"),
+            prepareIndex(indexName).setId("2").setSource(rankFeatureField, 0.2, searchField, "B"),
+            prepareIndex(indexName).setId("3").setSource(rankFeatureField, 0.3, searchField, "C"),
+            prepareIndex(indexName).setId("4").setSource(rankFeatureField, 0.4, searchField, "D"),
+            prepareIndex(indexName).setId("5").setSource(rankFeatureField, 0.5, searchField, "E")
+        );
+
+        expectThrows(
+            SearchPhaseExecutionException.class,
+            () -> prepareSearch().setQuery(
+                boolQuery().should(constantScoreQuery(matchQuery(searchField, "A")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "B")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "C")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "D")).boost(randomFloat()))
+                    .should(constantScoreQuery(matchQuery(searchField, "E")).boost(randomFloat()))
+            )
+                .setRankBuilder(
+                    new ThrowingRankBuilder(
+                        rankWindowSize,
+                        rankFeatureField,
+                        ThrowingRankBuilder.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name()
+                    )
+                )
+                .addFetchField(searchField)
+                .setTrackTotalHits(true)
+                .setAllowPartialSearchResults(true)
+                .setSize(10)
+                .get()
+        );
+        assertNoOpenContext(indexName);
+    }
+
+    private void assertNoOpenContext(final String indexName) throws Exception {
+        assertBusy(
+            () -> assertThat(indicesAdmin().prepareStats(indexName).get().getTotal().getSearch().getOpenContexts(), equalTo(0L)),
+            1,
+            TimeUnit.SECONDS
+        );
+    }
+
+    public static class FieldBasedRankBuilder extends RankBuilder {
+
+        public static final ParseField FIELD_FIELD = new ParseField("field");
+        static final ConstructingObjectParser<FieldBasedRankBuilder, Void> PARSER = new ConstructingObjectParser<>(
+            "field-based-rank",
+            args -> {
+                int rankWindowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0];
+                String field = (String) args[1];
+                if (field == null || field.isEmpty()) {
+                    throw new IllegalArgumentException("Field cannot be null or empty");
+                }
+                return new FieldBasedRankBuilder(rankWindowSize, field);
+            }
+        );
+
+        static {
+            PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
+            PARSER.declareString(constructorArg(), FIELD_FIELD);
+        }
+
+        protected final String field;
+
+        public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws IOException {
+            return PARSER.parse(parser, null);
+        }
+
+        public FieldBasedRankBuilder(final int rankWindowSize, final String field) {
+            super(rankWindowSize);
+            this.field = field;
+        }
+
+        public FieldBasedRankBuilder(StreamInput in) throws IOException {
+            super(in);
+            this.field = in.readString();
+        }
+
+        @Override
+        protected void doWriteTo(StreamOutput out) throws IOException {
+            out.writeString(field);
+        }
+
+        @Override
+        protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(FIELD_FIELD.getPreferredName(), field);
+        }
+
+        @Override
+        public boolean isCompoundBuilder() {
+            return false;
+        }
+
+        @Override
+        public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+            return new QueryPhaseRankShardContext(queries, rankWindowSize()) {
+                @Override
+                public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                    Map<Integer, RankFeatureDoc> rankDocs = new HashMap<>();
+                    rankResults.forEach(topDocs -> {
+                        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+                            rankDocs.compute(scoreDoc.doc, (key, value) -> {
+                                if (value == null) {
+                                    return new RankFeatureDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
+                                } else {
+                                    value.score = Math.max(scoreDoc.score, rankDocs.get(scoreDoc.doc).score);
+                                    return value;
+                                }
+                            });
+                        }
+                    });
+                    RankFeatureDoc[] sortedResults = rankDocs.values().toArray(RankFeatureDoc[]::new);
+                    Arrays.sort(sortedResults, (o1, o2) -> Float.compare(o2.score, o1.score));
+                    return new RankFeatureShardResult(sortedResults);
+                }
+            };
+        }
+
+        @Override
+        public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+            return new QueryPhaseRankCoordinatorContext(rankWindowSize()) {
+                @Override
+                public ScoreDoc[] rankQueryPhaseResults(
+                    List<QuerySearchResult> querySearchResults,
+                    SearchPhaseController.TopDocsStats topDocStats
+                ) {
+                    List<RankFeatureDoc> rankDocs = new ArrayList<>();
+                    for (int i = 0; i < querySearchResults.size(); i++) {
+                        QuerySearchResult querySearchResult = querySearchResults.get(i);
+                        RankFeatureShardResult shardResult = (RankFeatureShardResult) querySearchResult.getRankShardResult();
+                        for (RankFeatureDoc frd : shardResult.rankFeatureDocs) {
+                            frd.shardIndex = i;
+                            rankDocs.add(frd);
+                        }
+                    }
+                    // no support for sort field atm
+                    // should pass needed info to make use of org.elasticsearch.action.search.SearchPhaseController.sortDocs?
+                    rankDocs.sort(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
+                    RankFeatureDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankFeatureDoc[]::new);
+
+                    assert topDocStats.fetchHits == 0;
+                    topDocStats.fetchHits = topResults.length;
+
+                    return topResults;
+                }
+            };
+        }
+
+        @Override
+        public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+            return new RankFeaturePhaseRankShardContext(field) {
+                @Override
+                public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                    RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+                    for (int i = 0; i < hits.getHits().length; i++) {
+                        rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
+                        rankFeatureDocs[i].featureData(hits.getHits()[i].field(field).getValue().toString());
+                    }
+                    return new RankFeatureShardResult(rankFeatureDocs);
+                }
+            };
+        }
+
+        @Override
+        public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+            return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) {
+                @Override
+                protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                    float[] scores = new float[featureDocs.length];
+                    for (int i = 0; i < featureDocs.length; i++) {
+                        scores[i] = Float.parseFloat(featureDocs[i].featureData);
+                    }
+                    scoreListener.onResponse(scores);
+                }
+            };
+        }
+
+        @Override
+        protected boolean doEquals(RankBuilder other) {
+            return other instanceof FieldBasedRankBuilder && Objects.equals(field, ((FieldBasedRankBuilder) other).field);
+        }
+
+        @Override
+        protected int doHashCode() {
+            return Objects.hash(field);
+        }
+
+        @Override
+        public String getWriteableName() {
+            return "field-based-rank";
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            return TransportVersions.RANK_FEATURE_PHASE_ADDED;
+        }
+    }
+
+    public static class ThrowingRankBuilder extends FieldBasedRankBuilder {
+
+        public enum ThrowingRankBuilderType {
+            THROWING_QUERY_PHASE_SHARD_CONTEXT,
+            THROWING_QUERY_PHASE_COORDINATOR_CONTEXT,
+            THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT,
+            THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT;
+        }
+
+        protected final ThrowingRankBuilderType throwingRankBuilderType;
+
+        public static final ParseField FIELD_FIELD = new ParseField("field");
+        public static final ParseField THROWING_TYPE_FIELD = new ParseField("throwing-type");
+        static final ConstructingObjectParser<ThrowingRankBuilder, Void> PARSER = new ConstructingObjectParser<>("throwing-rank", args -> {
+            int rankWindowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0];
+            String field = (String) args[1];
+            if (field == null || field.isEmpty()) {
+                throw new IllegalArgumentException("Field cannot be null or empty");
+            }
+            String throwingType = (String) args[2];
+            return new ThrowingRankBuilder(rankWindowSize, field, throwingType);
+        });
+
+        static {
+            PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
+            PARSER.declareString(constructorArg(), FIELD_FIELD);
+            PARSER.declareString(constructorArg(), THROWING_TYPE_FIELD);
+        }
+
+        public static FieldBasedRankBuilder fromXContent(XContentParser parser) throws IOException {
+            return PARSER.parse(parser, null);
+        }
+
+        public ThrowingRankBuilder(final int rankWindowSize, final String field, final String throwingType) {
+            super(rankWindowSize, field);
+            this.throwingRankBuilderType = ThrowingRankBuilderType.valueOf(throwingType);
+        }
+
+        public ThrowingRankBuilder(StreamInput in) throws IOException {
+            super(in);
+            this.throwingRankBuilderType = in.readEnum(ThrowingRankBuilderType.class);
+        }
+
+        @Override
+        protected void doWriteTo(StreamOutput out) throws IOException {
+            super.doWriteTo(out);
+            out.writeEnum(throwingRankBuilderType);
+        }
+
+        @Override
+        protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+            super.doXContent(builder, params);
+            builder.field(THROWING_TYPE_FIELD.getPreferredName(), throwingRankBuilderType);
+        }
+
+        @Override
+        public String getWriteableName() {
+            return "throwing-rank";
+        }
+
+        @Override
+        public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+            if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_QUERY_PHASE_SHARD_CONTEXT)
+                return new QueryPhaseRankShardContext(queries, rankWindowSize()) {
+                    @Override
+                    public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                        throw new UnsupportedOperationException("qps - simulated failure");
+                    }
+                };
+            else {
+                return super.buildQueryPhaseShardContext(queries, from);
+            }
+        }
+
+        @Override
+        public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+            if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_QUERY_PHASE_COORDINATOR_CONTEXT)
+                return new QueryPhaseRankCoordinatorContext(rankWindowSize()) {
+                    @Override
+                    public ScoreDoc[] rankQueryPhaseResults(
+                        List<QuerySearchResult> querySearchResults,
+                        SearchPhaseController.TopDocsStats topDocStats
+                    ) {
+                        throw new UnsupportedOperationException("qpc - simulated failure");
+                    }
+                };
+            else {
+                return super.buildQueryPhaseCoordinatorContext(size, from);
+            }
+        }
+
+        @Override
+        public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+            if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_SHARD_CONTEXT)
+                return new RankFeaturePhaseRankShardContext(field) {
+                    @Override
+                    public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                        throw new UnsupportedOperationException("rfs - simulated failure");
+                    }
+                };
+            else {
+                return super.buildRankFeaturePhaseShardContext();
+            }
+        }
+
+        @Override
+        public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+            if (this.throwingRankBuilderType == ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT)
+                return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize()) {
+                    @Override
+                    protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                        throw new UnsupportedOperationException("rfc - simulated failure");
+                    }
+                };
+            else {
+                return super.buildRankFeaturePhaseCoordinatorContext(size, from);
+            }
+        }
+    }
+
+    public static class FieldBasedRerankerPlugin extends Plugin implements SearchPlugin {
+
+        private static final String FIELD_BASED_RANK_BUILDER_NAME = "field-based-rank";
+        private static final String THROWING_RANK_BUILDER_NAME = "throwing-rank";
+
+        @Override
+        public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
+            return List.of(
+                new NamedWriteableRegistry.Entry(RankBuilder.class, FIELD_BASED_RANK_BUILDER_NAME, FieldBasedRankBuilder::new),
+                new NamedWriteableRegistry.Entry(RankBuilder.class, THROWING_RANK_BUILDER_NAME, ThrowingRankBuilder::new),
+                new NamedWriteableRegistry.Entry(RankShardResult.class, "rank_feature_shard", RankFeatureShardResult::new)
+            );
+        }
+
+        @Override
+        public List<NamedXContentRegistry.Entry> getNamedXContent() {
+            return List.of(
+                new NamedXContentRegistry.Entry(
+                    RankBuilder.class,
+                    new ParseField(FIELD_BASED_RANK_BUILDER_NAME),
+                    FieldBasedRankBuilder::fromXContent
+                ),
+                new NamedXContentRegistry.Entry(
+                    RankBuilder.class,
+                    new ParseField(THROWING_RANK_BUILDER_NAME),
+                    ThrowingRankBuilder::fromXContent
+                )
+            );
+        }
+    }
+}

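The FieldBasedRankBuilder above completes the computeScores listener synchronously, which keeps the test self-contained. The ActionListener-based contract exists so that a real RankFeaturePhaseRankCoordinatorContext can hand the collected feature data to an external scorer without blocking a search thread. The sketch below illustrates that pattern; RerankService and its score(...) method are hypothetical placeholders, while RankFeaturePhaseRankCoordinatorContext, RankFeatureDoc and ActionListener are the types introduced or used by this change.

    import org.elasticsearch.action.ActionListener;
    import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
    import org.elasticsearch.search.rank.feature.RankFeatureDoc;

    import java.util.Arrays;
    import java.util.List;

    /** Hypothetical asynchronous scorer, e.g. backed by an inference endpoint. */
    interface RerankService {
        void score(List<String> featureData, ActionListener<float[]> listener);
    }

    /** Sketch of a coordinator context that defers scoring to an asynchronous service. */
    class RemoteRerankCoordinatorContext extends RankFeaturePhaseRankCoordinatorContext {

        private final RerankService rerankService;

        RemoteRerankCoordinatorContext(int size, int from, int rankWindowSize, RerankService rerankService) {
            super(size, from, rankWindowSize);
            this.rerankService = rerankService;
        }

        @Override
        protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
            // the per-document feature data was extracted by the shard-side RankFeaturePhaseRankShardContext
            List<String> featureData = Arrays.stream(featureDocs).map(doc -> doc.featureData).toList();
            // fire the asynchronous call; the rank-feature phase resumes once the listener completes,
            // and any failure propagates through scoreListener.onFailure
            rerankService.score(featureData, scoreListener);
        }
    }

As in the synchronous test implementation, the scores handed to the listener are index-aligned with featureDocs.
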
+ 1 - 0
server/src/main/java/module-info.java

@@ -362,6 +362,7 @@ module org.elasticsearch.server {
     exports org.elasticsearch.search.query;
     exports org.elasticsearch.search.rank;
     exports org.elasticsearch.search.rank.context;
+    exports org.elasticsearch.search.rank.feature;
     exports org.elasticsearch.search.rescore;
     exports org.elasticsearch.search.retriever;
     exports org.elasticsearch.search.runtime;

+ 1 - 1
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -184,7 +184,7 @@ public class TransportVersions {
     public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_EMBEDDINGS_ADDED = def(8_675_00_0);
     public static final TransportVersion ADD_MISTRAL_EMBEDDINGS_INFERENCE = def(8_676_00_0);
     public static final TransportVersion ML_CHUNK_INFERENCE_OPTION = def(8_677_00_0);
-
+    public static final TransportVersion RANK_FEATURE_PHASE_ADDED = def(8_678_00_0);
     /*
      * STOP! READ THIS FIRST! No, really,
      *        ____ _____ ___  ____  _        ____  _____    _    ____    _____ _   _ ___ ____    _____ ___ ____  ____ _____ _

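A TransportVersion constant like RANK_FEATURE_PHASE_ADDED is how serialization code decides whether the node on the other end of a connection understands the new wire format during a rolling upgrade (the FieldBasedRankBuilder test above returns it from getMinimalSupportedVersion for the same reason). A hypothetical illustration of the usual gating pattern; the rankBuilder field and the enclosing writeTo method are placeholders, not code from this commit:

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        if (out.getTransportVersion().onOrAfter(TransportVersions.RANK_FEATURE_PHASE_ADDED)) {
            // only nodes that already know about the rank-feature phase can read this part of the stream
            out.writeOptionalNamedWriteable(rankBuilder);
        }
    }
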
+ 18 - 0
server/src/main/java/org/elasticsearch/action/search/CCSSingleCoordinatorSearchProgressListener.java

@@ -260,6 +260,24 @@ public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressLi
         }
     }
 
+    /**
+     * Executed when a shard returns a rank feature result.
+     *
+     * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}.
+     */
+    @Override
+    public void onRankFeatureResult(int shardIndex) {}
+
+    /**
+     * Executed when a shard reports a rank feature failure.
+     *
+     * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}.
+     * @param shardTarget The last shard target that threw an exception.
+     * @param exc The cause of the failure.
+     */
+    @Override
+    public void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
+
     /**
      * Executed when a shard returns a fetch result.
      *

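The CCS-specific listener opts out of the two new callbacks with empty bodies. Other SearchProgressListener implementations can hook the same events, for example to track how many shards completed the rank-feature round trip. A minimal sketch, assuming the base class declares onRankFeatureResult and onRankFeatureFailure as overridable no-op hooks with the signatures shown above (the SearchProgressListener.java change is listed in this commit but not reproduced in this excerpt):

    import org.elasticsearch.action.search.SearchProgressListener;
    import org.elasticsearch.search.SearchShardTarget;

    import java.util.concurrent.atomic.AtomicInteger;

    /** Purely illustrative listener that counts rank-feature results per search. */
    class RankFeatureProgressCounter extends SearchProgressListener {

        private final AtomicInteger successfulShards = new AtomicInteger();
        private final AtomicInteger failedShards = new AtomicInteger();

        @Override
        public void onRankFeatureResult(int shardIndex) {
            successfulShards.incrementAndGet();
        }

        @Override
        public void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
            // a shard-level failure here surfaces as a partial failure in the final search response
            failedShards.incrementAndGet();
        }
    }
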
+ 23 - 45
server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

@@ -17,8 +17,6 @@ import org.elasticsearch.search.dfs.AggregatedDfs;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.ShardFetchSearchRequest;
 import org.elasticsearch.search.internal.ShardSearchContextId;
-import org.elasticsearch.search.query.QuerySearchResult;
-import org.elasticsearch.transport.Transport;
 
 import java.util.List;
 import java.util.function.BiFunction;
@@ -29,7 +27,7 @@ import java.util.function.BiFunction;
  */
 final class FetchSearchPhase extends SearchPhase {
     private final ArraySearchPhaseResults<FetchSearchResult> fetchResults;
-    private final AtomicArray<SearchPhaseResult> queryResults;
+    private final AtomicArray<SearchPhaseResult> searchPhaseShardResults;
     private final BiFunction<SearchResponseSections, AtomicArray<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
     private final SearchPhaseContext context;
     private final Logger logger;
@@ -74,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase {
         }
         this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards());
         context.addReleasable(fetchResults);
-        this.queryResults = resultConsumer.getAtomicArray();
+        this.searchPhaseShardResults = resultConsumer.getAtomicArray();
         this.aggregatedDfs = aggregatedDfs;
         this.nextPhaseFactory = nextPhaseFactory;
         this.context = context;
@@ -103,19 +101,20 @@ final class FetchSearchPhase extends SearchPhase {
         final int numShards = context.getNumShards();
         // Usually when there is a single shard, we force the search type QUERY_THEN_FETCH. But when there's kNN, we might
         // still use DFS_QUERY_THEN_FETCH, which does not perform the "query and fetch" optimization during the query phase.
-        final boolean queryAndFetchOptimization = queryResults.length() == 1
+        final boolean queryAndFetchOptimization = searchPhaseShardResults.length() == 1
             && context.getRequest().hasKnnSearch() == false
-            && reducedQueryPhase.rankCoordinatorContext() == null;
+            && reducedQueryPhase.queryPhaseRankCoordinatorContext() == null;
         if (queryAndFetchOptimization) {
             assert assertConsistentWithQueryAndFetchOptimization();
             // query AND fetch optimization
-            moveToNextPhase(queryResults);
+            moveToNextPhase(searchPhaseShardResults);
         } else {
             ScoreDoc[] scoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
             // no docs to fetch -- sidestep everything and return
             if (scoreDocs.length == 0) {
                 // we have to release contexts here to free up resources
-                queryResults.asList().stream().map(SearchPhaseResult::queryResult).forEach(this::releaseIrrelevantSearchContext);
+                searchPhaseShardResults.asList()
+                    .forEach(searchPhaseShardResult -> releaseIrrelevantSearchContext(searchPhaseShardResult, context));
                 moveToNextPhase(fetchResults.getAtomicArray());
             } else {
                 final ScoreDoc[] lastEmittedDocPerShard = context.getRequest().scroll() != null
@@ -130,19 +129,19 @@ final class FetchSearchPhase extends SearchPhase {
                 );
                 for (int i = 0; i < docIdsToLoad.length; i++) {
                     List<Integer> entry = docIdsToLoad[i];
-                    SearchPhaseResult queryResult = queryResults.get(i);
+                    SearchPhaseResult shardPhaseResult = searchPhaseShardResults.get(i);
                     if (entry == null) { // no results for this shard ID
-                        if (queryResult != null) {
+                        if (shardPhaseResult != null) {
                             // if we got some hits from this shard we have to release the context there
                             // we do this as we go since it will free up resources and passing on the request on the
                             // transport layer is cheap.
-                            releaseIrrelevantSearchContext(queryResult.queryResult());
+                            releaseIrrelevantSearchContext(shardPhaseResult, context);
                             progressListener.notifyFetchResult(i);
                         }
                         // in any case we count down this result since we don't talk to this shard anymore
                         counter.countDown();
                     } else {
-                        executeFetch(queryResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null);
+                        executeFetch(shardPhaseResult, counter, entry, (lastEmittedDocPerShard != null) ? lastEmittedDocPerShard[i] : null);
                     }
                 }
             }
@@ -150,31 +149,33 @@ final class FetchSearchPhase extends SearchPhase {
     }
 
     private boolean assertConsistentWithQueryAndFetchOptimization() {
-        var phaseResults = queryResults.asList();
+        var phaseResults = searchPhaseShardResults.asList();
         assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null
             : "phaseResults empty [" + phaseResults.isEmpty() + "], single result: " + phaseResults.get(0).fetchResult();
         return true;
     }
 
     private void executeFetch(
-        SearchPhaseResult queryResult,
+        SearchPhaseResult shardPhaseResult,
         final CountedCollector<FetchSearchResult> counter,
         final List<Integer> entry,
         ScoreDoc lastEmittedDocForShard
     ) {
-        final SearchShardTarget shardTarget = queryResult.getSearchShardTarget();
-        final int shardIndex = queryResult.getShardIndex();
-        final ShardSearchContextId contextId = queryResult.queryResult().getContextId();
+        final SearchShardTarget shardTarget = shardPhaseResult.getSearchShardTarget();
+        final int shardIndex = shardPhaseResult.getShardIndex();
+        final ShardSearchContextId contextId = shardPhaseResult.queryResult() != null
+            ? shardPhaseResult.queryResult().getContextId()
+            : shardPhaseResult.rankFeatureResult().getContextId();
         context.getSearchTransport()
             .sendExecuteFetch(
                 context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()),
                 new ShardFetchSearchRequest(
-                    context.getOriginalIndices(queryResult.getShardIndex()),
+                    context.getOriginalIndices(shardPhaseResult.getShardIndex()),
                     contextId,
-                    queryResult.getShardSearchRequest(),
+                    shardPhaseResult.getShardSearchRequest(),
                     entry,
                     lastEmittedDocForShard,
-                    queryResult.getRescoreDocIds(),
+                    shardPhaseResult.getRescoreDocIds(),
                     aggregatedDfs
                 ),
                 context.getTask(),
@@ -199,40 +200,17 @@ final class FetchSearchPhase extends SearchPhase {
                             // the search context might not be cleared on the node where the fetch was executed for example
                             // because the action was rejected by the thread pool. in this case we need to send a dedicated
                             // request to clear the search context.
-                            releaseIrrelevantSearchContext(queryResult.queryResult());
+                            releaseIrrelevantSearchContext(shardPhaseResult, context);
                         }
                     }
                 }
             );
     }
 
-    /**
-     * Releases shard targets that are not used in the docsIdsToLoad.
-     */
-    private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) {
-        // we only release search context that we did not fetch from, if we are not scrolling
-        // or using a PIT and if it has at least one hit that didn't make it to the global topDocs
-        if (queryResult.hasSearchContext()
-            && context.getRequest().scroll() == null
-            && (context.isPartOfPointInTime(queryResult.getContextId()) == false)) {
-            try {
-                SearchShardTarget shardTarget = queryResult.getSearchShardTarget();
-                Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
-                context.sendReleaseSearchContext(
-                    queryResult.getContextId(),
-                    connection,
-                    context.getOriginalIndices(queryResult.getShardIndex())
-                );
-            } catch (Exception e) {
-                logger.trace("failed to release context", e);
-            }
-        }
-    }
-
     private void moveToNextPhase(AtomicArray<? extends SearchPhaseResult> fetchResultsArr) {
         var resp = SearchPhaseController.merge(context.getRequest().scroll() != null, reducedQueryPhase, fetchResultsArr);
         context.addReleasable(resp::decRef);
         fetchResults.close();
-        context.executeNextPhase(this, nextPhaseFactory.apply(resp, queryResults));
+        context.executeNextPhase(this, nextPhaseFactory.apply(resp, searchPhaseShardResults));
     }
 }

+ 166 - 9
server/src/main/java/org/elasticsearch/action/search/RankFeaturePhase.java

@@ -7,23 +7,39 @@
  */
 package org.elasticsearch.action.search;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.search.ScoreDoc;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ThreadedActionListener;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.dfs.AggregatedDfs;
+import org.elasticsearch.search.internal.ShardSearchContextId;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
+import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
+
+import java.util.List;
 
 /**
  * This search phase is responsible for executing any re-ranking needed for the given search request, iff that is applicable.
- * It starts by retrieving {code num_shards * window_size} results from the query phase and reduces them to a global list of
+ * It starts by retrieving {@code num_shards * window_size} results from the query phase and reduces them to a global list of
  * the top {@code window_size} results. It then reaches out to the shards to extract the needed feature data,
  * and finally passes all this information to the appropriate {@code RankFeaturePhaseRankCoordinatorContext}, which is responsible for reranking
  * the results. If no rank query is specified, it proceeds directly to the next phase (FetchSearchPhase) by first reducing the results.
  */
-public final class RankFeaturePhase extends SearchPhase {
+public class RankFeaturePhase extends SearchPhase {
 
+    private static final Logger logger = LogManager.getLogger(RankFeaturePhase.class);
     private final SearchPhaseContext context;
-    private final SearchPhaseResults<SearchPhaseResult> queryPhaseResults;
-
+    final SearchPhaseResults<SearchPhaseResult> queryPhaseResults;
+    final SearchPhaseResults<SearchPhaseResult> rankPhaseResults;
     private final AggregatedDfs aggregatedDfs;
+    private final SearchProgressListener progressListener;
 
     RankFeaturePhase(SearchPhaseResults<SearchPhaseResult> queryPhaseResults, AggregatedDfs aggregatedDfs, SearchPhaseContext context) {
         super("rank-feature");
@@ -38,6 +54,9 @@ public final class RankFeaturePhase extends SearchPhase {
         this.context = context;
         this.queryPhaseResults = queryPhaseResults;
         this.aggregatedDfs = aggregatedDfs;
+        this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards());
+        context.addReleasable(rankPhaseResults);
+        this.progressListener = context.getTask().getProgressListener();
     }
 
     @Override
@@ -59,16 +78,154 @@ public final class RankFeaturePhase extends SearchPhase {
         });
     }
 
-    private void innerRun() throws Exception {
-        // other than running reduce, this is currently close to a no-op
+    void innerRun() throws Exception {
+        // if the RankBuilder specifies a QueryPhaseRankCoordinatorContext, it will be called as part of the reduce call
+        // to operate on the first `window_size * num_shards` results and merge them appropriately.
         SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
-        moveToNextPhase(queryPhaseResults, reducedQueryPhase);
+        RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
+        if (rankFeaturePhaseRankCoordinatorContext != null) {
+            ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs(); // rank_window_size
+            final List<Integer>[] docIdsToLoad = SearchPhaseController.fillDocIdsToLoad(context.getNumShards(), queryScoreDocs);
+            final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
+                rankPhaseResults,
+                context.getNumShards(),
+                () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
+                context
+            );
+
+            // we send out a request to each shard in order to fetch the needed feature info
+            for (int i = 0; i < docIdsToLoad.length; i++) {
+                List<Integer> entry = docIdsToLoad[i];
+                SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(i);
+                if (entry == null || entry.isEmpty()) {
+                    if (queryResult != null) {
+                        releaseIrrelevantSearchContext(queryResult, context);
+                        progressListener.notifyRankFeatureResult(i);
+                    }
+                    rankRequestCounter.countDown();
+                } else {
+                    executeRankFeatureShardPhase(queryResult, rankRequestCounter, entry);
+                }
+            }
+        } else {
+            moveToNextPhase(queryPhaseResults, reducedQueryPhase);
+        }
+    }
+
+    private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) {
+        return source == null || source.rankBuilder() == null
+            ? null
+            : context.getRequest()
+                .source()
+                .rankBuilder()
+                .buildRankFeaturePhaseCoordinatorContext(context.getRequest().source().size(), context.getRequest().source().from());
     }
 
-    private void moveToNextPhase(
-        SearchPhaseResults<SearchPhaseResult> phaseResults,
+    private void executeRankFeatureShardPhase(
+        SearchPhaseResult queryResult,
+        final CountedCollector<SearchPhaseResult> rankRequestCounter,
+        final List<Integer> entry
+    ) {
+        final SearchShardTarget shardTarget = queryResult.queryResult().getSearchShardTarget();
+        final ShardSearchContextId contextId = queryResult.queryResult().getContextId();
+        final int shardIndex = queryResult.getShardIndex();
+        context.getSearchTransport()
+            .sendExecuteRankFeature(
+                context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId()),
+                new RankFeatureShardRequest(
+                    context.getOriginalIndices(queryResult.getShardIndex()),
+                    queryResult.getContextId(),
+                    queryResult.getShardSearchRequest(),
+                    entry
+                ),
+                context.getTask(),
+                new SearchActionListener<>(shardTarget, shardIndex) {
+                    @Override
+                    protected void innerOnResponse(RankFeatureResult response) {
+                        try {
+                            progressListener.notifyRankFeatureResult(shardIndex);
+                            rankRequestCounter.onResult(response);
+                        } catch (Exception e) {
+                            context.onPhaseFailure(RankFeaturePhase.this, "", e);
+                        }
+                    }
+
+                    @Override
+                    public void onFailure(Exception e) {
+                        try {
+                            logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e);
+                            progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e);
+                            rankRequestCounter.onFailure(shardIndex, shardTarget, e);
+                        } finally {
+                            releaseIrrelevantSearchContext(queryResult, context);
+                        }
+                    }
+                }
+            );
+    }
+
+    private void onPhaseDone(
+        RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext,
         SearchPhaseController.ReducedQueryPhase reducedQueryPhase
     ) {
+        assert rankFeaturePhaseRankCoordinatorContext != null;
+        ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(context, new ActionListener<>() {
+            @Override
+            public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
+                RankFeatureDoc[] topResults = rankFeaturePhaseRankCoordinatorContext.rankAndPaginate(docsWithUpdatedScores);
+                SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults(
+                    reducedQueryPhase,
+                    topResults
+                );
+                moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase);
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                context.onPhaseFailure(RankFeaturePhase.this, "Computing updated ranks for results failed", e);
+            }
+        });
+        rankFeaturePhaseRankCoordinatorContext.rankGlobalResults(
+            rankPhaseResults.getAtomicArray().asList().stream().map(SearchPhaseResult::rankFeatureResult).toList(),
+            rankResultListener
+        );
+    }
+
+    private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults(
+        SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
+        ScoreDoc[] scoreDocs
+    ) {
+
+        return new SearchPhaseController.ReducedQueryPhase(
+            reducedQueryPhase.totalHits(),
+            reducedQueryPhase.fetchHits(),
+            maxScore(scoreDocs),
+            reducedQueryPhase.timedOut(),
+            reducedQueryPhase.terminatedEarly(),
+            reducedQueryPhase.suggest(),
+            reducedQueryPhase.aggregations(),
+            reducedQueryPhase.profileBuilder(),
+            new SearchPhaseController.SortedTopDocs(scoreDocs, false, null, null, null, 0),
+            reducedQueryPhase.sortValueFormats(),
+            reducedQueryPhase.queryPhaseRankCoordinatorContext(),
+            reducedQueryPhase.numReducePhases(),
+            reducedQueryPhase.size(),
+            reducedQueryPhase.from(),
+            reducedQueryPhase.isEmptyResult()
+        );
+    }
+
+    private float maxScore(ScoreDoc[] scoreDocs) {
+        float maxScore = Float.NaN;
+        for (ScoreDoc scoreDoc : scoreDocs) {
+            if (Float.isNaN(maxScore) || scoreDoc.score > maxScore) {
+                maxScore = scoreDoc.score;
+            }
+        }
+        return maxScore;
+    }
+
+    void moveToNextPhase(SearchPhaseResults<SearchPhaseResult> phaseResults, SearchPhaseController.ReducedQueryPhase reducedQueryPhase) {
         context.executeNextPhase(this, new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase));
     }
 }
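Note on the coordinator flow above: after the reduce produces the global top `rank_window_size` hits, `innerRun` groups their doc ids per shard (via `SearchPhaseController.fillDocIdsToLoad`) and sends each shard a `RankFeatureShardRequest` containing only its own documents; shards that contributed nothing only have their search context released. A minimal, self-contained sketch of that grouping step, using plain Java (`TopDoc` and `groupDocIdsByShard` are illustrative names, not Elasticsearch APIs):

import java.util.ArrayList;
import java.util.List;

public class RankFeatureFanOutSketch {

    // minimal stand-in for a Lucene ScoreDoc that also carries the shard index
    record TopDoc(int docId, float score, int shardIndex) {}

    // group the globally reduced top docs by shard, mirroring SearchPhaseController.fillDocIdsToLoad
    static List<List<Integer>> groupDocIdsByShard(List<TopDoc> globalTopDocs, int numShards) {
        List<List<Integer>> docIdsPerShard = new ArrayList<>(numShards);
        for (int i = 0; i < numShards; i++) {
            docIdsPerShard.add(new ArrayList<>());
        }
        for (TopDoc doc : globalTopDocs) {
            docIdsPerShard.get(doc.shardIndex()).add(doc.docId());
        }
        return docIdsPerShard;
    }

    public static void main(String[] args) {
        List<TopDoc> reduced = List.of(new TopDoc(3, 9.1f, 0), new TopDoc(7, 8.4f, 1), new TopDoc(1, 7.9f, 0));
        // prints [[3, 1], [7]]: shard 0 is asked for docs 3 and 1, shard 1 for doc 7;
        // a shard with an empty list would only get its search context released
        System.out.println(groupDocIdsByShard(reduced, 2));
    }
}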

+ 34 - 0
server/src/main/java/org/elasticsearch/action/search/SearchPhase.java

@@ -9,6 +9,9 @@ package org.elasticsearch.action.search;
 
 import org.elasticsearch.cluster.routing.GroupShardsIterator;
 import org.elasticsearch.core.CheckedRunnable;
+import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.transport.Transport;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -62,4 +65,35 @@ abstract class SearchPhase implements CheckedRunnable<IOException> {
             }
         }
     }
+
+    /**
+     * Releases shard targets that are not used in the docIdsToLoad.
+     */
+    protected void releaseIrrelevantSearchContext(SearchPhaseResult searchPhaseResult, SearchPhaseContext context) {
+        // we only release search context that we did not fetch from, if we are not scrolling
+        // or using a PIT and if it has at least one hit that didn't make it to the global topDocs
+        if (searchPhaseResult == null) {
+            return;
+        }
+        // phaseResult.getContextId() is the same for query & rank feature results
+        SearchPhaseResult phaseResult = searchPhaseResult.queryResult() != null
+            ? searchPhaseResult.queryResult()
+            : searchPhaseResult.rankFeatureResult();
+        if (phaseResult != null
+            && phaseResult.hasSearchContext()
+            && context.getRequest().scroll() == null
+            && (context.isPartOfPointInTime(phaseResult.getContextId()) == false)) {
+            try {
+                SearchShardTarget shardTarget = phaseResult.getSearchShardTarget();
+                Transport.Connection connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
+                context.sendReleaseSearchContext(
+                    phaseResult.getContextId(),
+                    connection,
+                    context.getOriginalIndices(phaseResult.getShardIndex())
+                );
+            } catch (Exception e) {
+                context.getLogger().trace("failed to release context", e);
+            }
+        }
+    }
 }

+ 2 - 2
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

@@ -456,7 +456,7 @@ public final class SearchPhaseController {
                     : "not enough hits fetched. index [" + index + "] length: " + fetchResult.hits().getHits().length;
                 SearchHit searchHit = fetchResult.hits().getHits()[index];
                 searchHit.shard(fetchResult.getSearchShardTarget());
-                if (reducedQueryPhase.rankCoordinatorContext != null) {
+                if (reducedQueryPhase.queryPhaseRankCoordinatorContext != null) {
                     assert shardDoc instanceof RankDoc;
                     searchHit.setRank(((RankDoc) shardDoc).rank);
                     searchHit.score(shardDoc.score);
@@ -747,7 +747,7 @@ public final class SearchPhaseController {
         // sort value formats used to sort / format the result
         DocValueFormat[] sortValueFormats,
         // the rank context if ranking is used
-        QueryPhaseRankCoordinatorContext rankCoordinatorContext,
+        QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext,
         // the number of reduces phases
         int numReducePhases,
         // the size of the top hits to return

+ 32 - 0
server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java

@@ -88,6 +88,22 @@ public abstract class SearchProgressListener {
      */
     protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {}
 
+    /**
+     * Executed when a shard returns a rank feature result.
+     *
+     * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}.
+     */
+    protected void onRankFeatureResult(int shardIndex) {}
+
+    /**
+     * Executed when a shard reports a rank feature failure.
+     *
+     * @param shardIndex The index of the shard in the list provided by {@link SearchProgressListener#onListShards}.
+     * @param shardTarget The last shard target that threw an exception.
+     * @param exc The cause of the failure.
+     */
+    protected void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {}
+
     /**
      * Executed when a shard returns a fetch result.
      *
@@ -160,6 +176,22 @@ public abstract class SearchProgressListener {
         }
     }
 
+    final void notifyRankFeatureResult(int shardIndex) {
+        try {
+            onRankFeatureResult(shardIndex);
+        } catch (Exception e) {
+            logger.warn(() -> "[" + shards.get(shardIndex) + "] Failed to execute progress listener on rank-feature result", e);
+        }
+    }
+
+    final void notifyRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
+        try {
+            onRankFeatureFailure(shardIndex, shardTarget, exc);
+        } catch (Exception e) {
+            logger.warn(() -> "[" + shards.get(shardIndex) + "] Failed to execute progress listener on rank-feature failure", e);
+        }
+    }
+
     final void notifyFetchResult(int shardIndex) {
         try {
             onFetchResult(shardIndex);
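As a usage sketch only (not part of this change), a custom `SearchProgressListener` can override the two new hooks to track per-shard rank-feature progress; the callbacks shown in this diff default to no-ops, so only the methods of interest need overriding:

import org.elasticsearch.action.search.SearchProgressListener;
import org.elasticsearch.search.SearchShardTarget;

public class LoggingRankFeatureProgressListener extends SearchProgressListener {

    @Override
    protected void onRankFeatureResult(int shardIndex) {
        // called once for every shard that returned its rank-feature data
        System.out.println("rank-feature result from shard " + shardIndex);
    }

    @Override
    protected void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
        // called when a shard-level rank-feature request fails
        System.err.println("rank-feature failure on shard " + shardIndex + " (" + shardTarget + "): " + exc.getMessage());
    }
}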

+ 1 - 1
server/src/main/java/org/elasticsearch/action/search/SearchRequest.java

@@ -407,7 +407,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
                     );
                 }
                 int queryCount = source.subSearches().size() + source.knnSearch().size();
-                if (queryCount < 2) {
+                if (source.rankBuilder().isCompoundBuilder() && queryCount < 2) {
                     validationException = addValidationError(
                         "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches",
                         validationException

+ 1 - 0
server/src/main/java/org/elasticsearch/action/search/SearchTransportAPMMetrics.java

@@ -19,6 +19,7 @@ public class SearchTransportAPMMetrics {
     public static final String DFS_ACTION_METRIC = "dfs_query_then_fetch/shard_dfs_phase";
     public static final String QUERY_ID_ACTION_METRIC = "dfs_query_then_fetch/shard_query_phase";
     public static final String QUERY_ACTION_METRIC = "query_then_fetch/shard_query_phase";
+    public static final String RANK_SHARD_FEATURE_ACTION_METRIC = "rank/shard_feature_phase";
     public static final String FREE_CONTEXT_ACTION_METRIC = "shard_release_context";
     public static final String FETCH_ID_ACTION_METRIC = "shard_fetch_phase";
     public static final String QUERY_SCROLL_ACTION_METRIC = "scroll/shard_query_phase";

+ 30 - 0
server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

@@ -39,6 +39,8 @@ import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.query.QuerySearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.query.ScrollQuerySearchResult;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
+import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.RemoteClusterService;
@@ -70,6 +72,7 @@ import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_CA
 import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_FETCH_SCROLL_ACTION_METRIC;
 import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_ID_ACTION_METRIC;
 import static org.elasticsearch.action.search.SearchTransportAPMMetrics.QUERY_SCROLL_ACTION_METRIC;
+import static org.elasticsearch.action.search.SearchTransportAPMMetrics.RANK_SHARD_FEATURE_ACTION_METRIC;
 
 /**
  * An encapsulation of {@link org.elasticsearch.search.SearchService} operations exposed through
@@ -96,6 +99,8 @@ public class SearchTransportService {
     public static final String FETCH_ID_SCROLL_ACTION_NAME = "indices:data/read/search[phase/fetch/id/scroll]";
     public static final String FETCH_ID_ACTION_NAME = "indices:data/read/search[phase/fetch/id]";
 
+    public static final String RANK_FEATURE_SHARD_ACTION_NAME = "indices:data/read/search[phase/rank/feature]";
+
     /**
      * The Can-Match phase. It is executed to pre-filter shards that a search request hits. It rewrites the query on
      * the shard and checks whether the result of the rewrite matches no documents, in which case the shard can be
@@ -250,6 +255,21 @@ public class SearchTransportService {
         );
     }
 
+    public void sendExecuteRankFeature(
+        Transport.Connection connection,
+        final RankFeatureShardRequest request,
+        SearchTask task,
+        final SearchActionListener<RankFeatureResult> listener
+    ) {
+        transportService.sendChildRequest(
+            connection,
+            RANK_FEATURE_SHARD_ACTION_NAME,
+            request,
+            task,
+            new ConnectionCountingHandler<>(listener, RankFeatureResult::new, connection)
+        );
+    }
+
     public void sendExecuteScrollFetch(
         Transport.Connection connection,
         final InternalScrollSearchRequest request,
@@ -539,6 +559,16 @@ public class SearchTransportService {
         );
         TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new);
 
+        final TransportRequestHandler<RankFeatureShardRequest> rankShardFeatureRequest = (request, channel, task) -> searchService
+            .executeRankFeaturePhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
+        transportService.registerRequestHandler(
+            RANK_FEATURE_SHARD_ACTION_NAME,
+            EsExecutors.DIRECT_EXECUTOR_SERVICE,
+            RankFeatureShardRequest::new,
+            instrumentedHandler(RANK_SHARD_FEATURE_ACTION_METRIC, transportService, searchTransportMetrics, rankShardFeatureRequest)
+        );
+        TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new);
+
         final TransportRequestHandler<ShardFetchRequest> shardFetchRequestHandler = (request, channel, task) -> searchService
             .executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
         transportService.registerRequestHandler(

+ 1 - 0
server/src/main/java/org/elasticsearch/node/NodeConstruction.java

@@ -1044,6 +1044,7 @@ class NodeConstruction {
             threadPool,
             scriptService,
             bigArrays,
+            searchModule.getRankFeatureShardPhase(),
             searchModule.getFetchPhase(),
             responseCollectorService,
             circuitBreakerService,

+ 3 - 0
server/src/main/java/org/elasticsearch/node/NodeServiceProvider.java

@@ -33,6 +33,7 @@ import org.elasticsearch.script.ScriptEngine;
 import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.fetch.FetchPhase;
+import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.telemetry.tracing.Tracer;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -116,6 +117,7 @@ class NodeServiceProvider {
         ThreadPool threadPool,
         ScriptService scriptService,
         BigArrays bigArrays,
+        RankFeatureShardPhase rankFeatureShardPhase,
         FetchPhase fetchPhase,
         ResponseCollectorService responseCollectorService,
         CircuitBreakerService circuitBreakerService,
@@ -128,6 +130,7 @@ class NodeServiceProvider {
             threadPool,
             scriptService,
             bigArrays,
+            rankFeatureShardPhase,
             fetchPhase,
             responseCollectorService,
             circuitBreakerService,

+ 13 - 0
server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java

@@ -70,6 +70,7 @@ import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.profile.Profilers;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.slice.SliceBuilder;
 import org.elasticsearch.search.sort.SortAndFormats;
@@ -102,6 +103,7 @@ final class DefaultSearchContext extends SearchContext {
     private final ContextIndexSearcher searcher;
     private DfsSearchResult dfsResult;
     private QuerySearchResult queryResult;
+    private RankFeatureResult rankFeatureResult;
     private FetchSearchResult fetchResult;
     private final float queryBoost;
     private final boolean lowLevelCancellation;
@@ -308,6 +310,17 @@ final class DefaultSearchContext extends SearchContext {
         return false;
     }
 
+    @Override
+    public void addRankFeatureResult() {
+        this.rankFeatureResult = new RankFeatureResult(this.readerContext.id(), this.shardTarget, this.request);
+        addReleasable(rankFeatureResult::decRef);
+    }
+
+    @Override
+    public RankFeatureResult rankFeatureResult() {
+        return rankFeatureResult;
+    }
+
     @Override
     public void addFetchResult() {
         this.fetchResult = new FetchSearchResult(this.readerContext.id(), this.shardTarget);

+ 5 - 0
server/src/main/java/org/elasticsearch/search/SearchModule.java

@@ -226,6 +226,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.HighlightPhase;
 import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
 import org.elasticsearch.search.fetch.subphase.highlight.PlainHighlighter;
 import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
 import org.elasticsearch.search.rescore.QueryRescorerBuilder;
 import org.elasticsearch.search.rescore.RescorerBuilder;
 import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
@@ -1252,6 +1253,10 @@ public class SearchModule {
         );
     }
 
+    public RankFeatureShardPhase getRankFeatureShardPhase() {
+        return new RankFeatureShardPhase();
+    }
+
     public FetchPhase getFetchPhase() {
         return new FetchPhase(fetchSubPhases);
     }

+ 16 - 0
server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java

@@ -15,6 +15,7 @@ import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.internal.ShardSearchContextId;
 import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.transport.TransportResponse;
 
 import java.io.IOException;
@@ -43,6 +44,14 @@ public abstract class SearchPhaseResult extends TransportResponse {
         super(in);
     }
 
+    /**
+     * Specifies whether the specific search phase results are associated with an opened SearchContext on the shards that
+     * executed the request.
+     */
+    public boolean hasSearchContext() {
+        return false;
+    }
+
     /**
      * Returns the search context ID that is used to reference the search context on the executing node
      * or <code>null</code> if no context was created.
@@ -81,6 +90,13 @@ public abstract class SearchPhaseResult extends TransportResponse {
         return null;
     }
 
+    /**
+     * Returns the rank feature result iff it's included in this response otherwise <code>null</code>
+     */
+    public RankFeatureResult rankFeatureResult() {
+        return null;
+    }
+
     /**
      * Returns the fetch result iff it's included in this response otherwise <code>null</code>
      */

+ 39 - 0
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -112,6 +112,9 @@ import org.elasticsearch.search.query.QueryPhase;
 import org.elasticsearch.search.query.QuerySearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.query.ScrollQuerySearchResult;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
+import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
+import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
 import org.elasticsearch.search.rescore.RescorerBuilder;
 import org.elasticsearch.search.searchafter.SearchAfterBuilder;
 import org.elasticsearch.search.sort.FieldSortBuilder;
@@ -151,6 +154,7 @@ import static org.elasticsearch.core.TimeValue.timeValueHours;
 import static org.elasticsearch.core.TimeValue.timeValueMillis;
 import static org.elasticsearch.core.TimeValue.timeValueMinutes;
 import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
+import static org.elasticsearch.search.rank.feature.RankFeatureShardPhase.EMPTY_RESULT;
 
 public class SearchService extends AbstractLifecycleComponent implements IndexEventListener {
     private static final Logger logger = LogManager.getLogger(SearchService.class);
@@ -276,6 +280,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
     private final DfsPhase dfsPhase = new DfsPhase();
 
     private final FetchPhase fetchPhase;
+    private final RankFeatureShardPhase rankFeatureShardPhase;
     private volatile boolean enableSearchWorkerThreads;
     private volatile boolean enableQueryPhaseParallelCollection;
 
@@ -314,6 +319,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         ThreadPool threadPool,
         ScriptService scriptService,
         BigArrays bigArrays,
+        RankFeatureShardPhase rankFeatureShardPhase,
         FetchPhase fetchPhase,
         ResponseCollectorService responseCollectorService,
         CircuitBreakerService circuitBreakerService,
@@ -327,6 +333,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         this.scriptService = scriptService;
         this.responseCollectorService = responseCollectorService;
         this.bigArrays = bigArrays;
+        this.rankFeatureShardPhase = rankFeatureShardPhase;
         this.fetchPhase = fetchPhase;
         this.multiBucketConsumerService = new MultiBucketConsumerService(
             clusterService,
@@ -713,6 +720,32 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         }
     }
 
+    public void executeRankFeaturePhase(RankFeatureShardRequest request, SearchShardTask task, ActionListener<RankFeatureResult> listener) {
+        final ReaderContext readerContext = findReaderContext(request.contextId(), request);
+        final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
+        final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
+        runAsync(getExecutor(readerContext.indexShard()), () -> {
+            try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.RANK_FEATURE, false)) {
+                int[] docIds = request.getDocIds();
+                if (docIds == null || docIds.length == 0) {
+                    searchContext.rankFeatureResult().shardResult(EMPTY_RESULT);
+                    searchContext.rankFeatureResult().incRef();
+                    return searchContext.rankFeatureResult();
+                }
+                rankFeatureShardPhase.prepareForFetch(searchContext, request);
+                fetchPhase.execute(searchContext, docIds);
+                rankFeatureShardPhase.processFetch(searchContext);
+                var rankFeatureResult = searchContext.rankFeatureResult();
+                rankFeatureResult.incRef();
+                return rankFeatureResult;
+            } catch (Exception e) {
+                assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
+                // we handle the failure in the failure listener below
+                throw e;
+            }
+        }, wrapFailureListener(listener, readerContext, markAsUsed));
+    }
+
     private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchContext context, long afterQueryTime) {
         try (
             Releasable scope = tracer.withScope(context.getTask());
@@ -1559,6 +1592,12 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
                 context.addQueryResult();
             }
         },
+        RANK_FEATURE {
+            @Override
+            void addResultsObject(SearchContext context) {
+                context.addRankFeatureResult();
+            }
+        },
         FETCH {
             @Override
             void addResultsObject(SearchContext context) {

+ 0 - 1
server/src/main/java/org/elasticsearch/search/fetch/FetchPhase.java

@@ -98,7 +98,6 @@ public final class FetchPhase {
     }
 
     private SearchHits buildSearchHits(SearchContext context, int[] docIdsToLoad, Profiler profiler) {
-
         FetchContext fetchContext = new FetchContext(context);
         SourceLoader sourceLoader = context.newSourceLoader();
 

+ 11 - 0
server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java

@@ -35,6 +35,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext;
 import org.elasticsearch.search.profile.Profilers;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.suggest.SuggestionSearchContext;
@@ -374,6 +375,16 @@ public abstract class FilteredSearchContext extends SearchContext {
         return in.getMaxScore();
     }
 
+    @Override
+    public void addRankFeatureResult() {
+        in.addRankFeatureResult();
+    }
+
+    @Override
+    public RankFeatureResult rankFeatureResult() {
+        return in.rankFeatureResult();
+    }
+
     @Override
     public FetchSearchResult fetchResult() {
         return in.fetchResult();

+ 5 - 0
server/src/main/java/org/elasticsearch/search/internal/SearchContext.java

@@ -42,6 +42,7 @@ import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext;
 import org.elasticsearch.search.profile.Profilers;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.suggest.SuggestionSearchContext;
@@ -332,6 +333,10 @@ public abstract class SearchContext implements Releasable {
 
     public abstract float getMaxScore();
 
+    public abstract void addRankFeatureResult();
+
+    public abstract RankFeatureResult rankFeatureResult();
+
     public abstract FetchPhase fetchPhase();
 
     public abstract FetchSearchResult fetchResult();

+ 29 - 26
server/src/main/java/org/elasticsearch/search/query/QueryPhase.java

@@ -87,35 +87,38 @@ public class QueryPhase {
         boolean searchTimedOut = querySearchResult.searchTimedOut();
         long serviceTimeEWMA = querySearchResult.serviceTimeEWMA();
         int nodeQueueSize = querySearchResult.nodeQueueSize();
-
-        // run each of the rank queries
-        for (Query rankQuery : queryPhaseRankShardContext.queries()) {
-            // if a search timeout occurs, exit with partial results
-            if (searchTimedOut) {
-                break;
-            }
-            try (
-                RankSearchContext rankSearchContext = new RankSearchContext(
-                    searchContext,
-                    rankQuery,
-                    queryPhaseRankShardContext.rankWindowSize()
-                )
-            ) {
-                QueryPhase.addCollectorsAndSearch(rankSearchContext);
-                QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult();
-                rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs);
-                serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA();
-                nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize());
-                searchTimedOut = rrfQuerySearchResult.searchTimedOut();
+        try {
+            // run each of the rank queries
+            for (Query rankQuery : queryPhaseRankShardContext.queries()) {
+                // if a search timeout occurs, exit with partial results
+                if (searchTimedOut) {
+                    break;
+                }
+                try (
+                    RankSearchContext rankSearchContext = new RankSearchContext(
+                        searchContext,
+                        rankQuery,
+                        queryPhaseRankShardContext.rankWindowSize()
+                    )
+                ) {
+                    QueryPhase.addCollectorsAndSearch(rankSearchContext);
+                    QuerySearchResult rrfQuerySearchResult = rankSearchContext.queryResult();
+                    rrfRankResults.add(rrfQuerySearchResult.topDocs().topDocs);
+                    serviceTimeEWMA += rrfQuerySearchResult.serviceTimeEWMA();
+                    nodeQueueSize = Math.max(nodeQueueSize, rrfQuerySearchResult.nodeQueueSize());
+                    searchTimedOut = rrfQuerySearchResult.searchTimedOut();
+                }
             }
-        }
 
-        querySearchResult.setRankShardResult(queryPhaseRankShardContext.combineQueryPhaseResults(rrfRankResults));
+            querySearchResult.setRankShardResult(queryPhaseRankShardContext.combineQueryPhaseResults(rrfRankResults));
 
-        // record values relevant to all queries
-        querySearchResult.searchTimedOut(searchTimedOut);
-        querySearchResult.serviceTimeEWMA(serviceTimeEWMA);
-        querySearchResult.nodeQueueSize(nodeQueueSize);
+            // record values relevant to all queries
+            querySearchResult.searchTimedOut(searchTimedOut);
+            querySearchResult.serviceTimeEWMA(serviceTimeEWMA);
+            querySearchResult.nodeQueueSize(nodeQueueSize);
+        } catch (Exception e) {
+            throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Failed to execute rank query", e);
+        }
     }
 
     static void executeQuery(SearchContext searchContext) throws QueryPhaseExecutionException {

+ 22 - 1
server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java

@@ -16,6 +16,8 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -32,7 +34,7 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContent
 
     public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
 
-    public static final int DEFAULT_WINDOW_SIZE = SearchService.DEFAULT_SIZE;
+    public static final int DEFAULT_RANK_WINDOW_SIZE = SearchService.DEFAULT_SIZE;
 
     private final int rankWindowSize;
 
@@ -68,6 +70,12 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContent
         return rankWindowSize;
     }
 
+    /**
+     * Specify whether this rank builder is a compound builder or not. A compound builder is a rank builder that requires
+     * two or more queries to be executed in order to generate the final result.
+     */
+    public abstract boolean isCompoundBuilder();
+
     /**
      * Generates a context used to execute required searches during the query phase on the shard.
      */
@@ -78,6 +86,19 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContent
      */
     public abstract QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from);
 
+    /**
+     * Generates a context used to execute the rank feature phase on the shard. This is responsible for retrieving any needed
+     * feature data, and passing it back to the coordinator through the appropriate {@link RankShardResult}.
+     */
+    public abstract RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext();
+
+    /**
+     * Generates a context used to perform global ranking during the RankFeature phase,
+     * on the coordinator, based on all the individual shard results. The output of this will be an ordered list of the top `size` ranked results,
+     * which will then be passed to the fetch phase.
+     */
+    public abstract RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from);
+
     @Override
     public final boolean equals(Object obj) {
         if (this == obj) {

+ 15 - 4
server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java

@@ -43,6 +43,7 @@ import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.profile.Profilers;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.suggest.SuggestionSearchContext;
@@ -57,14 +58,14 @@ public class RankSearchContext extends SearchContext {
 
     private final SearchContext parent;
     private final Query rankQuery;
-    private final int windowSize;
+    private final int rankWindowSize;
     private final QuerySearchResult querySearchResult;
 
     @SuppressWarnings("this-escape")
-    public RankSearchContext(SearchContext parent, Query rankQuery, int windowSize) {
+    public RankSearchContext(SearchContext parent, Query rankQuery, int rankWindowSize) {
         this.parent = parent;
         this.rankQuery = parent.buildFilteredQuery(rankQuery);
-        this.windowSize = windowSize;
+        this.rankWindowSize = rankWindowSize;
         this.querySearchResult = new QuerySearchResult(parent.readerContext().id(), parent.shardTarget(), parent.request());
         this.addReleasable(querySearchResult::decRef);
     }
@@ -182,7 +183,7 @@ public class RankSearchContext extends SearchContext {
 
     @Override
     public int size() {
-        return windowSize;
+        return rankWindowSize;
     }
 
     /**
@@ -492,6 +493,16 @@ public class RankSearchContext extends SearchContext {
         throw new UnsupportedOperationException();
     }
 
+    @Override
+    public void addRankFeatureResult() {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public RankFeatureResult rankFeatureResult() {
+        throw new UnsupportedOperationException();
+    }
+
     @Override
     public FetchSearchResult fetchResult() {
         throw new UnsupportedOperationException();

+ 96 - 0
server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankCoordinatorContext.java

@@ -0,0 +1,96 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank.context;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
+import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.elasticsearch.search.SearchService.DEFAULT_FROM;
+import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
+
+/**
+ * {@code RankFeaturePhaseRankCoordinatorContext} is a base class that runs on the coordinating node and is responsible for retrieving
+ * {@code window_size} total results from all shards, ranking them, and then producing a final paginated response of [from, from+size] results.
+ */
+public abstract class RankFeaturePhaseRankCoordinatorContext {
+
+    protected final int size;
+    protected final int from;
+    protected final int rankWindowSize;
+
+    public RankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
+        this.size = size < 0 ? DEFAULT_SIZE : size;
+        this.from = from < 0 ? DEFAULT_FROM : from;
+        this.rankWindowSize = rankWindowSize;
+    }
+
+    /**
+     * Computes the updated scores for a list of features (i.e. document-based data). We also pass along an ActionListener
+     * that should be called with the new scores, and will continue execution to the next phase
+     */
+    protected abstract void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener);
+
+    /**
+     * This method is responsible for ranking the global results based on the provided rank feature results from each shard.
+     * <p>
+     * We first start by extracting ordered feature data through a {@code List<RankFeatureDoc>}
+     * from the provided rankSearchResults, and then compute the updated score for each of the documents.
+ * Once all the scores have been computed, we sort the results, perform any pagination needed, and then notify the provided {@code rankListener}
+     * with the final array of {@link ScoreDoc} results.
+     *
+     * @param rankSearchResults a list of rank feature results from each shard
+     * @param rankListener      a rankListener to handle the global ranking result
+     */
+    public void rankGlobalResults(List<RankFeatureResult> rankSearchResults, ActionListener<RankFeatureDoc[]> rankListener) {
+        // extract feature data from each shard rank-feature phase result
+        RankFeatureDoc[] featureDocs = extractFeatureDocs(rankSearchResults);
+
+        // generate the final `topResults` paginated results, and pass them to fetch phase through the `rankListener`
+        computeScores(featureDocs, rankListener.delegateFailureAndWrap((listener, scores) -> {
+            for (int i = 0; i < featureDocs.length; i++) {
+                featureDocs[i].score = scores[i];
+            }
+            listener.onResponse(featureDocs);
+        }));
+    }
+
+    /**
+     * Ranks the provided {@link RankFeatureDoc} array and paginates the results based on the `from` and `size` parameters.
+     */
+    public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) {
+        Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
+        RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))];
+        for (int rank = 0; rank < topResults.length; ++rank) {
+            topResults[rank] = rankFeatureDocs[from + rank];
+            topResults[rank].rank = from + rank + 1;
+        }
+        return topResults;
+    }
+
+    private RankFeatureDoc[] extractFeatureDocs(List<RankFeatureResult> rankSearchResults) {
+        List<RankFeatureDoc> docFeatures = new ArrayList<>();
+        for (RankFeatureResult rankFeatureResult : rankSearchResults) {
+            RankFeatureShardResult shardResult = rankFeatureResult.shardResult();
+            for (RankFeatureDoc rankFeatureDoc : shardResult.rankFeatureDocs) {
+                if (rankFeatureDoc.featureData != null) {
+                    docFeatures.add(rankFeatureDoc);
+                }
+            }
+        }
+        return docFeatures.toArray(new RankFeatureDoc[0]);
+    }
+}
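To make the computeScores contract concrete, here is a hypothetical coordinator context (not part of this change) that scores each document by the length of its extracted feature text; a real reranker would typically make an asynchronous call to an inference service here instead:

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;

public class FeatureLengthRankCoordinatorContext extends RankFeaturePhaseRankCoordinatorContext {

    public FeatureLengthRankCoordinatorContext(int size, int from, int rankWindowSize) {
        super(size, from, rankWindowSize);
    }

    @Override
    protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
        // toy scoring function: longer feature text scores higher; a production implementation
        // would usually dispatch an async request and call the listener from its response handler
        float[] scores = new float[featureDocs.length];
        for (int i = 0; i < featureDocs.length; i++) {
            String feature = featureDocs[i].featureData;
            scores[i] = feature == null ? 0f : feature.length();
        }
        scoreListener.onResponse(scores);
    }
}

rankGlobalResults copies the reported scores back onto the RankFeatureDoc instances, and rankAndPaginate then sorts them by score and returns the [from, from + size) slice, assigning ranks starting at from + 1.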

+ 39 - 0
server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank.context;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.rank.RankShardResult;
+
+/**
+ * {@link RankFeaturePhaseRankShardContext} is a base class used to execute the RankFeature phase on each shard.
+ * In this class, we can fetch the feature data for a given set of documents and pass them back to the coordinator
+ * through the {@link RankShardResult}.
+ */
+public abstract class RankFeaturePhaseRankShardContext {
+
+    protected final String field;
+
+    public RankFeaturePhaseRankShardContext(final String field) {
+        this.field = field;
+    }
+
+    public String getField() {
+        return field;
+    }
+
+    /**
+     * This is used to fetch the feature data for a given set of documents, using the {@link org.elasticsearch.search.fetch.FetchPhase}
+     * and the {@link org.elasticsearch.search.fetch.subphase.FetchFieldsPhase} subphase.
+     * The feature data is then stored in a {@link org.elasticsearch.search.rank.feature.RankFeatureDoc} and passed back to the coordinator.
+     */
+    @Nullable
+    public abstract RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId);
+}
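A hedged sketch of a shard-side counterpart (hypothetical, assuming SearchHit#docId, SearchHit#getScore and SearchHit#field as they exist today): it copies a single string field from each fetched hit into RankFeatureDoc#featureData for the coordinator to rerank on:

import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.rank.RankShardResult;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.search.rank.feature.RankFeatureShardResult;

public class FieldCopyingRankShardContext extends RankFeaturePhaseRankShardContext {

    public FieldCopyingRankShardContext(String field) {
        super(field);
    }

    @Override
    public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
        RankFeatureDoc[] docs = new RankFeatureDoc[hits.getHits().length];
        for (int i = 0; i < docs.length; i++) {
            SearchHit hit = hits.getHits()[i];
            // keep the query-phase score; the coordinator overwrites it after reranking
            docs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
            if (hit.field(field) != null) {
                docs[i].featureData(hit.field(field).getValue().toString());
            }
        }
        return new RankFeatureShardResult(docs);
    }
}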

+ 54 - 0
server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank.feature;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.search.rank.RankDoc;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * A {@link RankDoc} that contains field data to be used later by the reranker on the coordinator node.
+ */
+public class RankFeatureDoc extends RankDoc {
+
+    // TODO: update to support more than one field, and to not restrict to string data
+    public String featureData;
+
+    public RankFeatureDoc(int doc, float score, int shardIndex) {
+        super(doc, score, shardIndex);
+    }
+
+    public RankFeatureDoc(StreamInput in) throws IOException {
+        super(in);
+        featureData = in.readOptionalString();
+    }
+
+    public void featureData(String featureData) {
+        this.featureData = featureData;
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeOptionalString(featureData);
+    }
+
+    @Override
+    protected boolean doEquals(RankDoc rd) {
+        RankFeatureDoc other = (RankFeatureDoc) rd;
+        return Objects.equals(this.featureData, other.featureData);
+    }
+
+    @Override
+    protected int doHashCode() {
+        return Objects.hashCode(featureData);
+    }
+}
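A tiny illustration of the payload defined above, using only the public constructor and field (hypothetical example, not from this change): the shard attaches the raw feature text to a doc id / score / shard index triple before streaming it back to the coordinator.

import org.elasticsearch.search.rank.feature.RankFeatureDoc;

public class RankFeatureDocExample {
    public static void main(String[] args) {
        // doc id within the shard, query-phase score, shard index
        RankFeatureDoc doc = new RankFeatureDoc(42, 3.14f, 0);
        doc.featureData("passage text the coordinator-side reranker will score");
        System.out.println(doc.featureData);
    }
}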

+ 70 - 0
server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureResult.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank.feature;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.internal.ShardSearchContextId;
+import org.elasticsearch.search.internal.ShardSearchRequest;
+
+import java.io.IOException;
+
+/**
+ * The result of a rank feature search phase.
+ * Each instance holds a {@code RankFeatureShardResult} along with the references associated with it.
+ */
+public class RankFeatureResult extends SearchPhaseResult {
+
+    private RankFeatureShardResult rankShardResult;
+
+    public RankFeatureResult() {}
+
+    public RankFeatureResult(ShardSearchContextId id, SearchShardTarget shardTarget, ShardSearchRequest request) {
+        this.contextId = id;
+        setSearchShardTarget(shardTarget);
+        setShardSearchRequest(request);
+    }
+
+    public RankFeatureResult(StreamInput in) throws IOException {
+        super(in);
+        contextId = new ShardSearchContextId(in);
+        rankShardResult = in.readOptionalWriteable(RankFeatureShardResult::new);
+        setShardSearchRequest(in.readOptionalWriteable(ShardSearchRequest::new));
+        setSearchShardTarget(in.readOptionalWriteable(SearchShardTarget::new));
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        assert hasReferences();
+        contextId.writeTo(out);
+        out.writeOptionalWriteable(rankShardResult);
+        out.writeOptionalWriteable(getShardSearchRequest());
+        out.writeOptionalWriteable(getSearchShardTarget());
+    }
+
+    @Override
+    public RankFeatureResult rankFeatureResult() {
+        return this;
+    }
+
+    public void shardResult(RankFeatureShardResult shardResult) {
+        this.rankShardResult = shardResult;
+    }
+
+    public RankFeatureShardResult shardResult() {
+        return rankShardResult;
+    }
+
+    @Override
+    public boolean hasSearchContext() {
+        return rankShardResult != null;
+    }
+}

+ 99 - 0
server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java

@@ -0,0 +1,99 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank.feature;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.search.SearchContextSourcePrinter;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.fetch.FetchSearchResult;
+import org.elasticsearch.search.fetch.StoredFieldsContext;
+import org.elasticsearch.search.fetch.subphase.FetchFieldsContext;
+import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
+import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
+import org.elasticsearch.tasks.TaskCancelledException;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+/**
+ * The {@code RankFeatureShardPhase} executes the rank feature phase on the shard, iff there is a {@code RankBuilder} that requires it.
+ * This phase is responsible for reading field data for a set of doc ids. To do this, it reuses the {@code FetchPhase} to read the required
+ * fields for all requested documents using the `FetchFieldsPhase` sub-phase.
+ */
+public final class RankFeatureShardPhase {
+
+    private static final Logger logger = LogManager.getLogger(RankFeatureShardPhase.class);
+
+    public static final RankFeatureShardResult EMPTY_RESULT = new RankFeatureShardResult(new RankFeatureDoc[0]);
+
+    public RankFeatureShardPhase() {}
+
+    public void prepareForFetch(SearchContext searchContext, RankFeatureShardRequest request) {
+        if (logger.isTraceEnabled()) {
+            logger.trace("{}", new SearchContextSourcePrinter(searchContext));
+        }
+
+        if (searchContext.isCancelled()) {
+            throw new TaskCancelledException("cancelled");
+        }
+
+        RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext);
+        if (rankFeaturePhaseRankShardContext != null) {
+            assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null";
+            searchContext.fetchFieldsContext(
+                new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null)))
+            );
+            searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_)));
+            searchContext.addFetchResult();
+            Arrays.sort(request.getDocIds());
+        }
+    }
+
+    public void processFetch(SearchContext searchContext) {
+        if (logger.isTraceEnabled()) {
+            logger.trace("{}", new SearchContextSourcePrinter(searchContext));
+        }
+
+        if (searchContext.isCancelled()) {
+            throw new TaskCancelledException("cancelled");
+        }
+
+        RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext);
+        if (rankFeaturePhaseRankShardContext != null) {
+            // TODO: here we should also populate the profile part of the fetchResult.
+            // We need to decide what info to include in the overall profiling section. This is per-shard, so most likely
+            // we will still care about the `FetchFieldsPhase` profiling info, as we could potentially operate on
+            // `rank_window_size` instead of just `size` results, which could be much more expensive.
+            FetchSearchResult fetchSearchResult = searchContext.fetchResult();
+            if (fetchSearchResult == null || fetchSearchResult.hits() == null) {
+                return;
+            }
+            // hits cannot be null here: we have either already checked for it above, or we would have thrown in
+            // FetchSearchResult#shardResult()
+            SearchHits hits = fetchSearchResult.hits();
+            RankFeatureShardResult featureRankShardResult = (RankFeatureShardResult) rankFeaturePhaseRankShardContext
+                .buildRankFeatureShardResult(hits, searchContext.shardTarget().getShardId().id());
+            // save the result in the search context
+            // TODO: we also need to add the profiling info that is available from fetch
+            if (featureRankShardResult != null) {
+                searchContext.rankFeatureResult().shardResult(featureRankShardResult);
+            }
+        }
+    }
+
+    private RankFeaturePhaseRankShardContext shardContext(SearchContext searchContext) {
+        return searchContext.request().source() != null && searchContext.request().source().rankBuilder() != null
+            ? searchContext.request().source().rankBuilder().buildRankFeaturePhaseShardContext()
+            : null;
+    }
+}
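
For illustration only (not part of this change): a minimal sketch of a field-based RankFeaturePhaseRankShardContext implementation, inferred from the calls above (getField() and buildRankFeatureShardResult(SearchHits, int)); the super(field) constructor and the RankFeatureDoc#featureData(String) setter are assumptions made for the sketch.

import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.rank.RankShardResult;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.search.rank.feature.RankFeatureShardResult;

public class FieldValueRankShardContext extends RankFeaturePhaseRankShardContext {

    public FieldValueRankShardContext(String field) {
        super(field); // assumed: the base class stores the field name exposed through getField()
    }

    @Override
    public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
        // one RankFeatureDoc per fetched hit, carrying the extracted field value for the coordinator-side reranker
        RankFeatureDoc[] docs = new RankFeatureDoc[hits.getHits().length];
        for (int i = 0; i < docs.length; i++) {
            SearchHit hit = hits.getHits()[i];
            docs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
            DocumentField field = hit.field(getField());
            // featureData(String) is assumed here as the hook where the feature value is stored
            docs[i].featureData(field == null ? null : field.getValue().toString());
        }
        return new RankFeatureShardResult(docs);
    }
}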

+ 101 - 0
server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardRequest.java

@@ -0,0 +1,101 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank.feature;
+
+import org.elasticsearch.action.IndicesRequest;
+import org.elasticsearch.action.OriginalIndices;
+import org.elasticsearch.action.search.SearchShardTask;
+import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.search.internal.ShardSearchContextId;
+import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.tasks.TaskId;
+import org.elasticsearch.transport.TransportRequest;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * Shard-level request for extracting all the features needed by a global reranker.
+ */
+
+public class RankFeatureShardRequest extends TransportRequest implements IndicesRequest {
+
+    private final OriginalIndices originalIndices;
+    private final ShardSearchRequest shardSearchRequest;
+
+    private final ShardSearchContextId contextId;
+
+    private final int[] docIds;
+
+    public RankFeatureShardRequest(
+        OriginalIndices originalIndices,
+        ShardSearchContextId contextId,
+        ShardSearchRequest shardSearchRequest,
+        List<Integer> docIds
+    ) {
+        this.originalIndices = originalIndices;
+        this.shardSearchRequest = shardSearchRequest;
+        this.docIds = docIds.stream().flatMapToInt(IntStream::of).toArray();
+        this.contextId = contextId;
+    }
+
+    public RankFeatureShardRequest(StreamInput in) throws IOException {
+        super(in);
+        originalIndices = OriginalIndices.readOriginalIndices(in);
+        shardSearchRequest = in.readOptionalWriteable(ShardSearchRequest::new);
+        docIds = in.readIntArray();
+        contextId = in.readOptionalWriteable(ShardSearchContextId::new);
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        OriginalIndices.writeOriginalIndices(originalIndices, out);
+        out.writeOptionalWriteable(shardSearchRequest);
+        out.writeIntArray(docIds);
+        out.writeOptionalWriteable(contextId);
+    }
+
+    @Override
+    public String[] indices() {
+        if (originalIndices == null) {
+            return null;
+        }
+        return originalIndices.indices();
+    }
+
+    @Override
+    public IndicesOptions indicesOptions() {
+        if (originalIndices == null) {
+            return null;
+        }
+        return originalIndices.indicesOptions();
+    }
+
+    public ShardSearchRequest getShardSearchRequest() {
+        return shardSearchRequest;
+    }
+
+    public int[] getDocIds() {
+        return docIds;
+    }
+
+    public ShardSearchContextId contextId() {
+        return contextId;
+    }
+
+    @Override
+    public SearchShardTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
+        return new SearchShardTask(id, type, action, getDescription(), parentTaskId, headers);
+    }
+}
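
For illustration only (not part of this change): a minimal sketch of how a coordinator could construct this request for one shard, using only the constructor and accessors defined above; the concrete values and OriginalIndices.NONE are placeholders.

import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;

import java.util.Arrays;
import java.util.List;

public class RankFeatureShardRequestSketch {
    public static void main(String[] args) {
        // the reader context that the query phase left open on the target shard
        ShardSearchContextId contextId = new ShardSearchContextId(UUIDs.base64UUID(), 123);
        // doc ids (within rank_window_size) that the coordinator wants features for
        RankFeatureShardRequest request = new RankFeatureShardRequest(
            OriginalIndices.NONE,   // placeholder; a real caller forwards the search request's original indices
            contextId,
            null,                   // the ShardSearchRequest is optional and serialized as such
            List.of(11, 2, 200)
        );
        // the doc ids are stored internally as a primitive int[]
        assert Arrays.equals(request.getDocIds(), new int[] { 11, 2, 200 });
    }
}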

+ 68 - 0
server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardResult.java

@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank.feature;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.search.rank.RankShardResult;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * The result set of {@link RankFeatureDoc} docs for the shard.
+ */
+public class RankFeatureShardResult implements RankShardResult {
+
+    public final RankFeatureDoc[] rankFeatureDocs;
+
+    public RankFeatureShardResult(RankFeatureDoc[] rankFeatureDocs) {
+        this.rankFeatureDocs = Objects.requireNonNull(rankFeatureDocs);
+    }
+
+    public RankFeatureShardResult(StreamInput in) throws IOException {
+        rankFeatureDocs = in.readArray(RankFeatureDoc::new, RankFeatureDoc[]::new);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return "rank_feature_shard";
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.RANK_FEATURE_PHASE_ADDED;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeArray(rankFeatureDocs);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        RankFeatureShardResult that = (RankFeatureShardResult) o;
+        return Arrays.equals(rankFeatureDocs, that.rankFeatureDocs);
+    }
+
+    @Override
+    public int hashCode() {
+        return 31 * Arrays.hashCode(rankFeatureDocs);
+    }
+
+    @Override
+    public String toString() {
+        return this.getClass().getSimpleName() + "{rankFeatureDocs=" + Arrays.toString(rankFeatureDocs) + '}';
+    }
+}
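
For illustration only (not part of this change): a quick round-trip sketch of the stream serialization defined above; it assumes RankFeatureDoc provides value equality, which RankFeatureShardResult#equals already relies on.

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.search.rank.feature.RankFeatureShardResult;

public class RankFeatureShardResultRoundTrip {
    public static void main(String[] args) throws Exception {
        // two docs from shard 0, as scored by the shard-level rank feature phase
        RankFeatureDoc[] docs = new RankFeatureDoc[] { new RankFeatureDoc(1, 110.0f, 0), new RankFeatureDoc(2, 109.0f, 0) };
        RankFeatureShardResult original = new RankFeatureShardResult(docs);

        // round-trip through the stream serialization defined above
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            original.writeTo(out);
            try (StreamInput in = out.bytes().streamInput()) {
                RankFeatureShardResult copy = new RankFeatureShardResult(in);
                // expected to hold as long as RankFeatureDoc defines value equality
                assert original.equals(copy);
            }
        }
    }
}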

+ 1170 - 0
server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java

@@ -0,0 +1,1170 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+package org.elasticsearch.action.search;
+
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.tests.store.MockDirectoryWrapper;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.breaker.CircuitBreaker;
+import org.elasticsearch.common.breaker.NoopCircuitBreaker;
+import org.elasticsearch.common.document.DocumentField;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.search.DocValueFormat;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.internal.ShardSearchContextId;
+import org.elasticsearch.search.query.QuerySearchResult;
+import org.elasticsearch.search.rank.RankBuilder;
+import org.elasticsearch.search.rank.RankShardResult;
+import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
+import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
+import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.InternalAggregationTestCase;
+import org.elasticsearch.transport.Transport;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class RankFeaturePhaseTests extends ESTestCase {
+
+    private static final int DEFAULT_RANK_WINDOW_SIZE = 10;
+    private static final int DEFAULT_FROM = 0;
+    private static final int DEFAULT_SIZE = 10;
+    private static final String DEFAULT_FIELD = "some_field";
+
+    private final RankBuilder DEFAULT_RANK_BUILDER = rankBuilder(
+        DEFAULT_RANK_WINDOW_SIZE,
+        defaultQueryPhaseRankShardContext(new ArrayList<>(), DEFAULT_RANK_WINDOW_SIZE),
+        defaultQueryPhaseRankCoordinatorContext(DEFAULT_RANK_WINDOW_SIZE),
+        defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD),
+        defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE)
+    );
+
+    private record ExpectedRankFeatureDoc(int doc, int rank, float score, String featureData) {}
+
+    public void testRankFeaturePhaseWith1Shard() {
+        // request params used within SearchSourceBuilder and *RankContext classes
+        AtomicBoolean phaseDone = new AtomicBoolean(false);
+        final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
+
+        // create a SearchSource to attach to the request
+        SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER);
+
+        SearchPhaseController controller = searchPhaseController();
+        SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
+
+        MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
+        mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
+        try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
+            final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123);
+            QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null);
+            try {
+                queryResult.setShardIndex(shard1Target.getShardId().getId());
+                // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
+                // here we have 2 results, with doc ids 1 and 2
+                int totalHits = randomIntBetween(2, 100);
+                final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) };
+                populateQuerySearchResult(queryResult, totalHits, shard1Docs);
+                results.consumeResult(queryResult, () -> {});
+                // do not make an actual transport request, but rather generate the response
+                // as if we had read it from the RankFeatureShardPhase
+                mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
+                    @Override
+                    public void sendExecuteRankFeature(
+                        Transport.Connection connection,
+                        final RankFeatureShardRequest request,
+                        SearchTask task,
+                        final SearchActionListener<RankFeatureResult> listener
+                    ) {
+                        // make sure to match the context id generated above, otherwise we throw
+                        if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) {
+                            RankFeatureResult rankFeatureResult = new RankFeatureResult();
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard1Target,
+                                totalHits,
+                                shard1Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else {
+                            listener.onFailure(new MockDirectoryWrapper.FakeIOException());
+                        }
+                    }
+                };
+            } finally {
+                queryResult.decRef();
+            }
+
+            RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
+            try {
+                rankFeaturePhase.run();
+
+                mockSearchPhaseContext.assertNoFailure();
+                assertTrue(mockSearchPhaseContext.failures.isEmpty());
+                assertTrue(phaseDone.get());
+                assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty());
+
+                SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.rankPhaseResults;
+                assertNotNull(rankPhaseResults.getAtomicArray());
+                assertEquals(1, rankPhaseResults.getAtomicArray().length());
+                assertEquals(1, rankPhaseResults.getSuccessfulResults().count());
+
+                SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
+                List<ExpectedRankFeatureDoc> expectedShardResults = List.of(
+                    new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"),
+                    new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2")
+                );
+                List<ExpectedRankFeatureDoc> expectedFinalResults = new ArrayList<>(expectedShardResults);
+                assertShardResults(shard1Result, expectedShardResults);
+                assertFinalResults(finalResults[0], expectedFinalResults);
+            } finally {
+                rankFeaturePhase.rankPhaseResults.close();
+            }
+        } finally {
+            if (mockSearchPhaseContext.searchResponse.get() != null) {
+                mockSearchPhaseContext.searchResponse.get().decRef();
+            }
+        }
+    }
+
+    public void testRankFeaturePhaseWithMultipleShardsOneEmpty() {
+        AtomicBoolean phaseDone = new AtomicBoolean(false);
+        final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
+
+        // create a SearchSource to attach to the request
+        SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER);
+
+        SearchPhaseController controller = searchPhaseController();
+        SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
+        SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null);
+        SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null);
+
+        MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3);
+        mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
+        try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
+            // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
+            // here we have 2 results in total: doc id 1 on shard 0 and doc id 2 on shard 1; the third shard returns no results
+            final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123);
+            final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456);
+            final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789);
+
+            QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null);
+            QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null);
+            QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard3Target, null);
+            try {
+                queryResultShard1.setShardIndex(shard1Target.getShardId().getId());
+                queryResultShard2.setShardIndex(shard2Target.getShardId().getId());
+                queryResultShard3.setShardIndex(shard3Target.getShardId().getId());
+
+                final int shard1Results = randomIntBetween(1, 100);
+                final int shard2Results = randomIntBetween(1, 100);
+                final int shard3Results = 0;
+
+                final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) };
+                populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs);
+                final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(2, 9.0F) };
+                populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs);
+                final ScoreDoc[] shard3Docs = new ScoreDoc[0];
+                populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs);
+
+                results.consumeResult(queryResultShard2, () -> {});
+                results.consumeResult(queryResultShard3, () -> {});
+                results.consumeResult(queryResultShard1, () -> {});
+
+                // do not make an actual transport request, but rather generate the response
+                // as if we had read it from the RankFeatureShardPhase
+                mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
+                    @Override
+                    public void sendExecuteRankFeature(
+                        Transport.Connection connection,
+                        final RankFeatureShardRequest request,
+                        SearchTask task,
+                        final SearchActionListener<RankFeatureResult> listener
+                    ) {
+                        // make sure to match the context id generated above, otherwise we throw
+                        // first shard
+                        RankFeatureResult rankFeatureResult = new RankFeatureResult();
+                        if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) {
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard1Target,
+                                shard1Results,
+                                shard1Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 2 })) {
+                            // second shard
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard2Target,
+                                shard2Results,
+                                shard2Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else if (request.contextId().getId() == 789) {
+                            listener.onResponse(rankFeatureResult);
+                        } else {
+                            listener.onFailure(new MockDirectoryWrapper.FakeIOException());
+                        }
+                    }
+                };
+            } finally {
+                queryResultShard1.decRef();
+                queryResultShard2.decRef();
+                queryResultShard3.decRef();
+            }
+            RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
+            try {
+                rankFeaturePhase.run();
+                mockSearchPhaseContext.assertNoFailure();
+                assertTrue(mockSearchPhaseContext.failures.isEmpty());
+                assertTrue(phaseDone.get());
+                SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.rankPhaseResults;
+                assertNotNull(rankPhaseResults.getAtomicArray());
+                assertEquals(3, rankPhaseResults.getAtomicArray().length());
+                // one result is null
+                assertEquals(2, rankPhaseResults.getSuccessfulResults().count());
+
+                SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
+                List<ExpectedRankFeatureDoc> expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"));
+                assertShardResults(shard1Result, expectedShard1Results);
+
+                SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
+                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2"));
+                assertShardResults(shard2Result, expectedShard2Results);
+
+                SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2);
+                assertNull(shard3Result);
+
+                List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(
+                    new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"),
+                    new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2")
+                );
+                assertFinalResults(finalResults[0], expectedFinalResults);
+            } finally {
+                rankFeaturePhase.rankPhaseResults.close();
+            }
+        } finally {
+            if (mockSearchPhaseContext.searchResponse.get() != null) {
+                mockSearchPhaseContext.searchResponse.get().decRef();
+            }
+        }
+    }
+
+    public void testRankFeaturePhaseNoNeedForFetchingFieldData() {
+        AtomicBoolean phaseDone = new AtomicBoolean(false);
+        final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
+
+        // build the appropriate RankBuilder; using null rank-feature-phase shard and coordinator contexts,
+        // and a query-phase coordinator context that simply negates the scores
+        RankBuilder rankBuilder = rankBuilder(
+            DEFAULT_RANK_WINDOW_SIZE,
+            defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE),
+            negatingScoresQueryFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE),
+            null,
+            null
+        );
+        // create a SearchSource to attach to the request
+        SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder);
+
+        SearchPhaseController controller = searchPhaseController();
+        SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
+
+        MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
+        mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
+        try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
+            // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
+            // here we have 2 results, with doc ids 1 and 2
+            final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123);
+            QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null);
+
+            try {
+                queryResult.setShardIndex(shard1Target.getShardId().getId());
+                int totalHits = randomIntBetween(2, 100);
+                final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) };
+                populateQuerySearchResult(queryResult, totalHits, shard1Docs);
+                results.consumeResult(queryResult, () -> {});
+                // do not make an actual transport request, but rather generate the response
+                // as if we had read it from the RankFeatureShardPhase
+                mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
+                    @Override
+                    public void sendExecuteRankFeature(
+                        Transport.Connection connection,
+                        final RankFeatureShardRequest request,
+                        SearchTask task,
+                        final SearchActionListener<RankFeatureResult> listener
+                    ) {
+                        // make sure to match the context id generated above, otherwise we throw
+                        if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) {
+                            listener.onFailure(new UnsupportedOperationException("should not have reached here"));
+                        } else {
+                            listener.onFailure(new MockDirectoryWrapper.FakeIOException());
+                        }
+                    }
+                };
+            } finally {
+                queryResult.decRef();
+            }
+            // override the RankFeaturePhase to skip moving to next phase
+            RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
+            try {
+                rankFeaturePhase.run();
+                mockSearchPhaseContext.assertNoFailure();
+                assertTrue(mockSearchPhaseContext.failures.isEmpty());
+                assertTrue(phaseDone.get());
+
+                // in this case there were no additional "RankFeature" results on the shards, so we shortcut directly to queryPhaseResults
+                SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.queryPhaseResults;
+                assertNotNull(rankPhaseResults.getAtomicArray());
+                assertEquals(1, rankPhaseResults.getAtomicArray().length());
+                assertEquals(1, rankPhaseResults.getSuccessfulResults().count());
+
+                SearchPhaseResult shardResult = rankPhaseResults.getAtomicArray().get(0);
+                assertTrue(shardResult instanceof QuerySearchResult);
+                QuerySearchResult rankResult = (QuerySearchResult) shardResult;
+                assertNull(rankResult.rankFeatureResult());
+                assertNotNull(rankResult.queryResult());
+
+                List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(
+                    new ExpectedRankFeatureDoc(2, 1, -9.0F, null),
+                    new ExpectedRankFeatureDoc(1, 2, -10.0F, null)
+                );
+                assertFinalResults(finalResults[0], expectedFinalResults);
+            } finally {
+                rankFeaturePhase.rankPhaseResults.close();
+            }
+        } finally {
+            if (mockSearchPhaseContext.searchResponse.get() != null) {
+                mockSearchPhaseContext.searchResponse.get().decRef();
+            }
+        }
+    }
+
+    public void testRankFeaturePhaseOneShardFails() {
+        AtomicBoolean phaseDone = new AtomicBoolean(false);
+        final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
+
+        // create a SearchSource to attach to the request
+        SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER);
+
+        SearchPhaseController controller = searchPhaseController();
+        SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
+        SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null);
+
+        MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2);
+        mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
+        try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
+            // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
+            // here we have 2 results, with doc ids 1 and 2 found on shards 0 and 1 respectively
+            final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123);
+            final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456);
+
+            QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null);
+            QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null);
+            try {
+                queryResultShard1.setShardIndex(shard1Target.getShardId().getId());
+                queryResultShard2.setShardIndex(shard2Target.getShardId().getId());
+
+                final int shard1Results = randomIntBetween(1, 100);
+                final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) };
+                populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs);
+
+                final int shard2Results = randomIntBetween(1, 100);
+                final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(2, 9.0F) };
+                populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs);
+
+                results.consumeResult(queryResultShard2, () -> {});
+                results.consumeResult(queryResultShard1, () -> {});
+
+                // do not make an actual transport request, but rather generate the response
+                // as if we had read it from the RankFeatureShardPhase
+                mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
+                    @Override
+                    public void sendExecuteRankFeature(
+                        Transport.Connection connection,
+                        final RankFeatureShardRequest request,
+                        SearchTask task,
+                        final SearchActionListener<RankFeatureResult> listener
+                    ) {
+                        // make sure to match the context id generated above, otherwise we throw
+                        // first shard
+                        if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 2 })) {
+                            RankFeatureResult rankFeatureResult = new RankFeatureResult();
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard2Target,
+                                shard2Results,
+                                shard2Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+
+                        } else if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) {
+                            // other shard; this one throws an exception
+                            listener.onFailure(new IllegalArgumentException("simulated failure"));
+                        } else {
+                            listener.onFailure(new MockDirectoryWrapper.FakeIOException());
+                        }
+                    }
+                };
+            } finally {
+                queryResultShard1.decRef();
+                queryResultShard2.decRef();
+            }
+            RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
+            try {
+                rankFeaturePhase.run();
+
+                mockSearchPhaseContext.assertNoFailure();
+                assertEquals(1, mockSearchPhaseContext.failures.size());
+                assertTrue(mockSearchPhaseContext.failures.get(0).getCause().getMessage().contains("simulated failure"));
+                assertTrue(phaseDone.get());
+
+                SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.rankPhaseResults;
+                assertNotNull(rankPhaseResults.getAtomicArray());
+                assertEquals(2, rankPhaseResults.getAtomicArray().length());
+                // one shard failed
+                assertEquals(1, rankPhaseResults.getSuccessfulResults().count());
+
+                SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
+                assertNull(shard1Result);
+
+                SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
+                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2"));
+                List<ExpectedRankFeatureDoc> expectedFinalResults = new ArrayList<>(expectedShard2Results);
+                assertShardResults(shard2Result, expectedShard2Results);
+                assertFinalResults(finalResults[0], expectedFinalResults);
+            } finally {
+                rankFeaturePhase.rankPhaseResults.close();
+            }
+        } finally {
+            if (mockSearchPhaseContext.searchResponse.get() != null) {
+                mockSearchPhaseContext.searchResponse.get().decRef();
+            }
+        }
+    }
+
+    public void testRankFeaturePhaseExceptionThrownOnPhase() {
+        AtomicBoolean phaseDone = new AtomicBoolean(false);
+        final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
+
+        // create a SearchSource to attach to the request
+        SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(DEFAULT_RANK_BUILDER);
+
+        SearchPhaseController controller = searchPhaseController();
+        SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
+
+        MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
+        mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
+        try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
+            // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
+            // here we have 2 results, with doc ids 1 and 2
+            final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123);
+            QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null);
+            try {
+                queryResult.setShardIndex(shard1Target.getShardId().getId());
+                int totalHits = randomIntBetween(2, 100);
+                final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) };
+                populateQuerySearchResult(queryResult, totalHits, shard1Docs);
+                results.consumeResult(queryResult, () -> {});
+
+                // do not make an actual transport request, but rather generate the response
+                // as if we had read it from the RankFeatureShardPhase
+                mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
+                    @Override
+                    public void sendExecuteRankFeature(
+                        Transport.Connection connection,
+                        final RankFeatureShardRequest request,
+                        SearchTask task,
+                        final SearchActionListener<RankFeatureResult> listener
+                    ) {
+                        // make sure to match the context id generated above, otherwise we throw
+                        if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) {
+                            RankFeatureResult rankFeatureResult = new RankFeatureResult();
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard1Target,
+                                totalHits,
+                                shard1Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else {
+                            listener.onFailure(new MockDirectoryWrapper.FakeIOException());
+                        }
+                    }
+                };
+            } finally {
+                queryResult.decRef();
+            }
+            // override the RankFeaturePhase to raise an exception
+            RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext) {
+                @Override
+                void innerRun() {
+                    throw new IllegalArgumentException("simulated failure");
+                }
+
+                @Override
+                public void moveToNextPhase(
+                    SearchPhaseResults<SearchPhaseResult> phaseResults,
+                    SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+                ) {
+                    // this is called after the RankFeaturePhaseCoordinatorContext has been executed
+                    phaseDone.set(true);
+                    finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs();
+                    logger.debug("Skipping moving to next phase");
+                }
+            };
+            assertEquals("rank-feature", rankFeaturePhase.getName());
+            try {
+                rankFeaturePhase.run();
+                assertNotNull(mockSearchPhaseContext.phaseFailure.get());
+                assertTrue(mockSearchPhaseContext.phaseFailure.get().getMessage().contains("simulated failure"));
+                assertTrue(mockSearchPhaseContext.failures.isEmpty());
+                assertFalse(phaseDone.get());
+                assertTrue(rankFeaturePhase.rankPhaseResults.getAtomicArray().asList().isEmpty());
+                assertNull(finalResults[0][0]);
+            } finally {
+                rankFeaturePhase.rankPhaseResults.close();
+            }
+        } finally {
+            if (mockSearchPhaseContext.searchResponse.get() != null) {
+                mockSearchPhaseContext.searchResponse.get().decRef();
+            }
+        }
+    }
+
+    public void testRankFeatureWithPagination() {
+        // request params used within SearchSourceBuilder and *RankContext classes
+        final int from = 1;
+        final int size = 1;
+        AtomicBoolean phaseDone = new AtomicBoolean(false);
+        final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
+
+        // build the appropriate RankBuilder
+        RankBuilder rankBuilder = rankBuilder(
+            DEFAULT_RANK_WINDOW_SIZE,
+            defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE),
+            defaultQueryPhaseRankCoordinatorContext(DEFAULT_RANK_WINDOW_SIZE),
+            defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD),
+            defaultRankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE)
+        );
+        // create a SearchSource to attach to the request
+        SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder);
+
+        SearchPhaseController controller = searchPhaseController();
+        SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
+        SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null);
+        SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null);
+
+        MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3);
+        mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
+        try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
+            // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
+            // here we have 4 results in total: doc id 1 on shard 0 and doc ids (11, 2, 200) on shard 1; the third shard returns no results
+            final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123);
+            final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456);
+            final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789);
+
+            QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null);
+            QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null);
+            QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard3Target, null);
+
+            try {
+                queryResultShard1.setShardIndex(shard1Target.getShardId().getId());
+                queryResultShard2.setShardIndex(shard2Target.getShardId().getId());
+                queryResultShard3.setShardIndex(shard3Target.getShardId().getId());
+
+                final int shard1Results = randomIntBetween(1, 100);
+                final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) };
+                populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs);
+
+                final int shard2Results = randomIntBetween(1, 100);
+                final ScoreDoc[] shard2Docs = new ScoreDoc[] {
+                    new ScoreDoc(11, 100.0F, -1),
+                    new ScoreDoc(2, 9.0F),
+                    new ScoreDoc(200, 1F, -1) };
+                populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs);
+
+                final int shard3Results = 0;
+                final ScoreDoc[] shard3Docs = new ScoreDoc[0];
+                populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs);
+
+                results.consumeResult(queryResultShard2, () -> {});
+                results.consumeResult(queryResultShard3, () -> {});
+                results.consumeResult(queryResultShard1, () -> {});
+
+                // do not make an actual transport request, but rather generate the response
+                // as if we had read it from the RankFeatureShardPhase
+                mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
+                    @Override
+                    public void sendExecuteRankFeature(
+                        Transport.Connection connection,
+                        final RankFeatureShardRequest request,
+                        SearchTask task,
+                        final SearchActionListener<RankFeatureResult> listener
+                    ) {
+
+                        RankFeatureResult rankFeatureResult = new RankFeatureResult();
+                        // make sure to match the context id generated above, otherwise we throw
+                        // first shard
+                        if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) {
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard1Target,
+                                shard1Results,
+                                shard1Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 11, 2, 200 })) {
+                            // second shard
+
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard2Target,
+                                shard2Results,
+                                shard2Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else {
+                            listener.onFailure(new MockDirectoryWrapper.FakeIOException());
+                        }
+
+                    }
+                };
+            } finally {
+                queryResultShard1.decRef();
+                queryResultShard2.decRef();
+                queryResultShard3.decRef();
+            }
+            RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
+            try {
+                rankFeaturePhase.run();
+
+                mockSearchPhaseContext.assertNoFailure();
+                assertTrue(mockSearchPhaseContext.failures.isEmpty());
+                assertTrue(phaseDone.get());
+                SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.rankPhaseResults;
+                assertNotNull(rankPhaseResults.getAtomicArray());
+                assertEquals(3, rankPhaseResults.getAtomicArray().length());
+                // one result is null
+                assertEquals(2, rankPhaseResults.getSuccessfulResults().count());
+
+                SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
+                List<ExpectedRankFeatureDoc> expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"));
+                assertShardResults(shard1Result, expectedShard1Results);
+
+                SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
+                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(
+                    new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"),
+                    new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2"),
+                    new ExpectedRankFeatureDoc(200, 3, 101.0F, "ranked_200")
+
+                );
+                assertShardResults(shard2Result, expectedShard2Results);
+
+                SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2);
+                assertNull(shard3Result);
+
+                List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1"));
+                assertFinalResults(finalResults[0], expectedFinalResults);
+            } finally {
+                rankFeaturePhase.rankPhaseResults.close();
+            }
+        } finally {
+            if (mockSearchPhaseContext.searchResponse.get() != null) {
+                mockSearchPhaseContext.searchResponse.get().decRef();
+            }
+        }
+    }
+
+    public void testRankFeatureCollectOnlyRankWindowSizeFeatures() {
+        // request params used within SearchSourceBuilder and *RankContext classes
+        final int rankWindowSize = 2;
+        AtomicBoolean phaseDone = new AtomicBoolean(false);
+        final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
+
+        // build the appropriate RankBuilder
+        RankBuilder rankBuilder = rankBuilder(
+            rankWindowSize,
+            defaultQueryPhaseRankShardContext(Collections.emptyList(), rankWindowSize),
+            defaultQueryPhaseRankCoordinatorContext(rankWindowSize),
+            defaultRankFeaturePhaseRankShardContext(DEFAULT_FIELD),
+            defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, rankWindowSize)
+        );
+        // create a SearchSource to attach to the request
+        SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder);
+
+        SearchPhaseController controller = searchPhaseController();
+        SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);
+        SearchShardTarget shard2Target = new SearchShardTarget("node1", new ShardId("test", "na", 1), null);
+        SearchShardTarget shard3Target = new SearchShardTarget("node2", new ShardId("test", "na", 2), null);
+
+        MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(3);
+        mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
+        try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
+            // generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
+            // here we have 3 results in total: doc id 1 on shard 0 and doc ids (11, 2) on shard 1; the third shard returns no results
+            final ShardSearchContextId ctxShard1 = new ShardSearchContextId(UUIDs.base64UUID(), 123);
+            final ShardSearchContextId ctxShard2 = new ShardSearchContextId(UUIDs.base64UUID(), 456);
+            final ShardSearchContextId ctxShard3 = new ShardSearchContextId(UUIDs.base64UUID(), 789);
+
+            QuerySearchResult queryResultShard1 = new QuerySearchResult(ctxShard1, shard1Target, null);
+            QuerySearchResult queryResultShard2 = new QuerySearchResult(ctxShard2, shard2Target, null);
+            QuerySearchResult queryResultShard3 = new QuerySearchResult(ctxShard3, shard3Target, null);
+
+            try {
+                queryResultShard1.setShardIndex(shard1Target.getShardId().getId());
+                queryResultShard2.setShardIndex(shard2Target.getShardId().getId());
+                queryResultShard3.setShardIndex(shard3Target.getShardId().getId());
+
+                final int shard1Results = randomIntBetween(1, 100);
+                final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F) };
+                populateQuerySearchResult(queryResultShard1, shard1Results, shard1Docs);
+
+                final int shard2Results = randomIntBetween(1, 100);
+                final ScoreDoc[] shard2Docs = new ScoreDoc[] { new ScoreDoc(11, 100.0F), new ScoreDoc(2, 9.0F) };
+                populateQuerySearchResult(queryResultShard2, shard2Results, shard2Docs);
+
+                final int shard3Results = 0;
+                final ScoreDoc[] shard3Docs = new ScoreDoc[0];
+                populateQuerySearchResult(queryResultShard3, shard3Results, shard3Docs);
+
+                results.consumeResult(queryResultShard2, () -> {});
+                results.consumeResult(queryResultShard3, () -> {});
+                results.consumeResult(queryResultShard1, () -> {});
+
+                // do not make an actual transport request, but rather generate the response
+                // as if we had read it from the RankFeatureShardPhase
+                mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
+                    @Override
+                    public void sendExecuteRankFeature(
+                        Transport.Connection connection,
+                        final RankFeatureShardRequest request,
+                        SearchTask task,
+                        final SearchActionListener<RankFeatureResult> listener
+                    ) {
+                        RankFeatureResult rankFeatureResult = new RankFeatureResult();
+                        // make sure to match the context id generated above, otherwise we throw
+                        // first shard
+                        if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1 })) {
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard1Target,
+                                shard1Results,
+                                shard1Docs
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else if (request.contextId().getId() == 456 && Arrays.equals(request.getDocIds(), new int[] { 11 })) {
+                            // second shard
+                            buildRankFeatureResult(
+                                mockSearchPhaseContext.getRequest().source().rankBuilder(),
+                                rankFeatureResult,
+                                shard2Target,
+                                shard2Results,
+                                new ScoreDoc[] { shard2Docs[0] }
+                            );
+                            listener.onResponse(rankFeatureResult);
+                        } else {
+                            listener.onFailure(new MockDirectoryWrapper.FakeIOException());
+                        }
+                    }
+                };
+            } finally {
+                queryResultShard1.decRef();
+                queryResultShard2.decRef();
+                queryResultShard3.decRef();
+            }
+            RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
+            try {
+                rankFeaturePhase.run();
+                mockSearchPhaseContext.assertNoFailure();
+                assertTrue(mockSearchPhaseContext.failures.isEmpty());
+                assertTrue(phaseDone.get());
+                SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.rankPhaseResults;
+                assertNotNull(rankPhaseResults.getAtomicArray());
+                assertEquals(3, rankPhaseResults.getAtomicArray().length());
+                // one result is null
+                assertEquals(2, rankPhaseResults.getSuccessfulResults().count());
+
+                SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
+                List<ExpectedRankFeatureDoc> expectedShardResults = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"));
+                assertShardResults(shard1Result, expectedShardResults);
+
+                SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
+                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"));
+                assertShardResults(shard2Result, expectedShard2Results);
+
+                SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2);
+                assertNull(shard3Result);
+
+                List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(
+                    new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"),
+                    new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1")
+                );
+                assertFinalResults(finalResults[0], expectedFinalResults);
+            } finally {
+                rankFeaturePhase.rankPhaseResults.close();
+            }
+        } finally {
+            if (mockSearchPhaseContext.searchResponse.get() != null) {
+                mockSearchPhaseContext.searchResponse.get().decRef();
+            }
+        }
+    }
+
+    private RankFeaturePhaseRankCoordinatorContext defaultRankFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
+        return new RankFeaturePhaseRankCoordinatorContext(size, from, rankWindowSize) {
+
+            @Override
+            protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                // no-op: this is handled directly in rankGlobalResults, which builds the RankFeatureDocs
+                // without modifying the ScoreDoc's rank in place
+            }
+
+            @Override
+            public void rankGlobalResults(List<RankFeatureResult> rankSearchResults, ActionListener<RankFeatureDoc[]> rankListener) {
+                List<RankFeatureDoc> features = new ArrayList<>();
+                for (RankFeatureResult rankFeatureResult : rankSearchResults) {
+                    RankFeatureShardResult shardResult = rankFeatureResult.shardResult();
+                    features.addAll(Arrays.stream(shardResult.rankFeatureDocs).toList());
+                }
+                rankListener.onResponse(features.toArray(new RankFeatureDoc[0]));
+            }
+
+            @Override
+            public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) {
+                Arrays.sort(rankFeatureDocs, Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
+                RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, rankFeatureDocs.length - from))];
+                // perform pagination
+                for (int rank = 0; rank < topResults.length; ++rank) {
+                    RankFeatureDoc rfd = rankFeatureDocs[from + rank];
+                    topResults[rank] = new RankFeatureDoc(rfd.doc, rfd.score, rfd.shardIndex);
+                    topResults[rank].rank = from + rank + 1;
+                }
+                return topResults;
+            }
+        };
+    }
+
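+    // Coordinator-level query-phase context that negates every incoming score before sorting, so the merged order is the reverse
+    // of the original ranking; the negated docs are then limited to rankWindowSize and paginated into RankFeatureDocs.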
+    private QueryPhaseRankCoordinatorContext negatingScoresQueryFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
+        return new QueryPhaseRankCoordinatorContext(rankWindowSize) {
+            @Override
+            public ScoreDoc[] rankQueryPhaseResults(
+                List<QuerySearchResult> rankSearchResults,
+                SearchPhaseController.TopDocsStats topDocsStats
+            ) {
+                List<ScoreDoc> docScores = new ArrayList<>();
+                for (QuerySearchResult phaseResults : rankSearchResults) {
+                    docScores.addAll(Arrays.asList(phaseResults.topDocs().topDocs.scoreDocs));
+                }
+                ScoreDoc[] sortedDocs = docScores.toArray(new ScoreDoc[0]);
+                // negating scores
+                Arrays.stream(sortedDocs).forEach(doc -> doc.score *= -1);
+
+                Arrays.sort(sortedDocs, Comparator.comparing((ScoreDoc doc) -> doc.score).reversed());
+                sortedDocs = Arrays.stream(sortedDocs).limit(rankWindowSize).toArray(ScoreDoc[]::new);
+                RankFeatureDoc[] topResults = new RankFeatureDoc[Math.max(0, Math.min(size, sortedDocs.length - from))];
+                // perform pagination
+                for (int rank = 0; rank < topResults.length; ++rank) {
+                    ScoreDoc base = sortedDocs[from + rank];
+                    topResults[rank] = new RankFeatureDoc(base.doc, base.score, base.shardIndex);
+                    topResults[rank].rank = from + rank + 1;
+                }
+                topDocsStats.fetchHits = topResults.length;
+                return topResults;
+            }
+        };
+    }
+
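+    // Shard-level rank-feature context that boosts each hit's score by 100, attaches "ranked_<docId>" feature data,
+    // and assigns ranks following the incoming hit order.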
+    private RankFeaturePhaseRankShardContext defaultRankFeaturePhaseRankShardContext(String field) {
+        return new RankFeaturePhaseRankShardContext(field) {
+            @Override
+            public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+                for (int i = 0; i < hits.getHits().length; i++) {
+                    SearchHit hit = hits.getHits()[i];
+                    rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
+                    rankFeatureDocs[i].score += 100f;
+                    rankFeatureDocs[i].featureData("ranked_" + hit.docId());
+                    rankFeatureDocs[i].rank = i + 1;
+                }
+                return new RankFeatureShardResult(rankFeatureDocs);
+            }
+        };
+    }
+
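+    // Coordinator-level query-phase context that merges the per-shard rank docs, tags each with its shard index,
+    // sorts them by score descending, and keeps the top rankWindowSize entries.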
+    private QueryPhaseRankCoordinatorContext defaultQueryPhaseRankCoordinatorContext(int rankWindowSize) {
+        return new QueryPhaseRankCoordinatorContext(rankWindowSize) {
+            @Override
+            public ScoreDoc[] rankQueryPhaseResults(
+                List<QuerySearchResult> querySearchResults,
+                SearchPhaseController.TopDocsStats topDocStats
+            ) {
+                List<RankFeatureDoc> rankDocs = new ArrayList<>();
+                for (int i = 0; i < querySearchResults.size(); i++) {
+                    QuerySearchResult querySearchResult = querySearchResults.get(i);
+                    RankFeatureShardResult shardResult = (RankFeatureShardResult) querySearchResult.getRankShardResult();
+                    for (RankFeatureDoc frd : shardResult.rankFeatureDocs) {
+                        frd.shardIndex = i;
+                        rankDocs.add(frd);
+                    }
+                }
+                rankDocs.sort(Comparator.comparing((RankFeatureDoc doc) -> doc.score).reversed());
+                RankFeatureDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(RankFeatureDoc[]::new);
+                topDocStats.fetchHits = topResults.length;
+                return topResults;
+            }
+        };
+    }
+
+    private QueryPhaseRankShardContext defaultQueryPhaseRankShardContext(List<Query> queries, int rankWindowSize) {
+        return new QueryPhaseRankShardContext(queries, rankWindowSize) {
+            @Override
+            public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                throw new UnsupportedOperationException(
+                    "shard-level QueryPhase context should not be accessed as part of the RankFeature phase"
+                );
+            }
+        };
+    }
+
+    private SearchPhaseController searchPhaseController() {
+        return new SearchPhaseController((task, request) -> InternalAggregationTestCase.emptyReduceContextBuilder());
+    }
+
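+    // Minimal RankBuilder that simply hands back the supplied phase contexts; serialization and xContent are no-ops
+    // since the builder never leaves these tests.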
+    private RankBuilder rankBuilder(
+        int rankWindowSize,
+        QueryPhaseRankShardContext queryPhaseRankShardContext,
+        QueryPhaseRankCoordinatorContext queryPhaseRankCoordinatorContext,
+        RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext,
+        RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext
+    ) {
+        return new RankBuilder(rankWindowSize) {
+            @Override
+            protected void doWriteTo(StreamOutput out) throws IOException {
+                // no-op
+            }
+
+            @Override
+            protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+                // no-op
+            }
+
+            @Override
+            public boolean isCompoundBuilder() {
+                return true;
+            }
+
+            @Override
+            public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+                return queryPhaseRankShardContext;
+            }
+
+            @Override
+            public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+                return queryPhaseRankCoordinatorContext;
+            }
+
+            @Override
+            public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+                return rankFeaturePhaseRankShardContext;
+            }
+
+            @Override
+            public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+                return rankFeaturePhaseRankCoordinatorContext;
+            }
+
+            @Override
+            protected boolean doEquals(RankBuilder other) {
+                return other != null && other.rankWindowSize() == rankWindowSize;
+            }
+
+            @Override
+            protected int doHashCode() {
+                return 0;
+            }
+
+            @Override
+            public String getWriteableName() {
+                return "test-rank-builder";
+            }
+
+            @Override
+            public TransportVersion getMinimalSupportedVersion() {
+                return TransportVersions.V_8_12_0;
+            }
+        };
+    }
+
+    private SearchSourceBuilder searchSourceWithRankBuilder(RankBuilder rankBuilder) {
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.rankBuilder(rankBuilder);
+        return searchSourceBuilder;
+    }
+
+    private SearchPhaseResults<SearchPhaseResult> searchPhaseResults(
+        SearchPhaseController controller,
+        MockSearchPhaseContext mockSearchPhaseContext
+    ) {
+        return controller.newSearchPhaseResults(
+            EsExecutors.DIRECT_EXECUTOR_SERVICE,
+            new NoopCircuitBreaker(CircuitBreaker.REQUEST),
+            () -> false,
+            SearchProgressListener.NOOP,
+            mockSearchPhaseContext.getRequest(),
+            mockSearchPhaseContext.numShards,
+            exc -> {}
+        );
+    }
+
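+    // Populates a RankFeatureResult the way a shard would: builds unpooled SearchHits from the given ScoreDocs
+    // (each carrying DEFAULT_FIELD) and runs them through the rank builder's shard-level rank-feature context.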
+    private void buildRankFeatureResult(
+        RankBuilder shardRankBuilder,
+        RankFeatureResult rankFeatureResult,
+        SearchShardTarget shardTarget,
+        int totalHits,
+        ScoreDoc[] scoreDocs
+    ) {
+        rankFeatureResult.setSearchShardTarget(shardTarget);
+        // these are the SearchHits generated by the FetchFieldsPhase processor
+        SearchHit[] searchHits = new SearchHit[scoreDocs.length];
+        float maxScore = Float.MIN_VALUE;
+        for (int i = 0; i < searchHits.length; i++) {
+            searchHits[i] = SearchHit.unpooled(scoreDocs[i].doc);
+            searchHits[i].shard(shardTarget);
+            searchHits[i].score(scoreDocs[i].score);
+            searchHits[i].setDocumentField(DEFAULT_FIELD, new DocumentField(DEFAULT_FIELD, Collections.singletonList(scoreDocs[i].doc)));
+            if (scoreDocs[i].score > maxScore) {
+                maxScore = scoreDocs[i].score;
+            }
+        }
+        SearchHits hits = null;
+        try {
+            hits = SearchHits.unpooled(searchHits, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), maxScore);
+            // construct the appropriate RankFeatureDoc objects based on the rank builder
+            RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardRankBuilder.buildRankFeaturePhaseShardContext();
+            RankFeatureShardResult rankShardResult = (RankFeatureShardResult) rankFeaturePhaseRankShardContext.buildRankFeatureShardResult(
+                hits,
+                shardTarget.getShardId().id()
+            );
+            rankFeatureResult.shardResult(rankShardResult);
+        } finally {
+            if (hits != null) {
+                hits.decRef();
+            }
+        }
+    }
+
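+    // Seeds a QuerySearchResult with the RankFeatureShardResult and top docs that the query phase would have
+    // produced for the given ScoreDocs.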
+    private void populateQuerySearchResult(QuerySearchResult queryResult, int totalHits, ScoreDoc[] scoreDocs) {
+        // this would have been populated during the QueryPhase by the appropriate QueryPhaseShardContext
+        float maxScore = Float.MIN_VALUE;
+        RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[scoreDocs.length];
+        for (int i = 0; i < scoreDocs.length; i++) {
+            if (scoreDocs[i].score > maxScore) {
+                maxScore = scoreDocs[i].score;
+            }
+            rankFeatureDocs[i] = new RankFeatureDoc(scoreDocs[i].doc, scoreDocs[i].score, scoreDocs[i].shardIndex);
+        }
+        queryResult.setRankShardResult(new RankFeatureShardResult(rankFeatureDocs));
+        queryResult.topDocs(
+            new TopDocsAndMaxScore(
+                new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs),
+                maxScore
+            ),
+            new DocValueFormat[0]
+        );
+        queryResult.size(totalHits);
+    }
+
+    private RankFeaturePhase rankFeaturePhase(
+        SearchPhaseResults<SearchPhaseResult> results,
+        MockSearchPhaseContext mockSearchPhaseContext,
+        ScoreDoc[][] finalResults,
+        AtomicBoolean phaseDone
+    ) {
+        // override the RankFeaturePhase to skip moving to the next phase
+        return new RankFeaturePhase(results, null, mockSearchPhaseContext) {
+            @Override
+            public void moveToNextPhase(
+                SearchPhaseResults<SearchPhaseResult> phaseResults,
+                SearchPhaseController.ReducedQueryPhase reducedQueryPhase
+            ) {
+                // this is called after the RankFeaturePhaseCoordinatorContext has been executed
+                phaseDone.set(true);
+                finalResults[0] = reducedQueryPhase.sortedTopDocs().scoreDocs();
+                logger.debug("Skipping moving to next phase");
+            }
+        };
+    }
+
+    private void assertRankFeatureResults(RankFeatureShardResult rankFeatureShardResult, List<ExpectedRankFeatureDoc> expectedResults) {
+        assertEquals(expectedResults.size(), rankFeatureShardResult.rankFeatureDocs.length);
+        for (int i = 0; i < expectedResults.size(); i++) {
+            ExpectedRankFeatureDoc expected = expectedResults.get(i);
+            RankFeatureDoc actual = rankFeatureShardResult.rankFeatureDocs[i];
+            assertEquals(expected.doc, actual.doc);
+            assertEquals(expected.rank, actual.rank);
+            assertEquals(expected.score, actual.score, 10E-5);
+            assertEquals(expected.featureData, actual.featureData);
+        }
+    }
+
+    private void assertFinalResults(ScoreDoc[] finalResults, List<ExpectedRankFeatureDoc> expectedResults) {
+        assertEquals(expectedResults.size(), finalResults.length);
+        for (int i = 0; i < expectedResults.size(); i++) {
+            ExpectedRankFeatureDoc expected = expectedResults.get(i);
+            RankFeatureDoc actual = (RankFeatureDoc) finalResults[i];
+            assertEquals(expected.doc, actual.doc);
+            assertEquals(expected.rank, actual.rank);
+            assertEquals(expected.score, actual.score, 10E-5);
+        }
+    }
+
+    private void assertShardResults(SearchPhaseResult shardResult, List<ExpectedRankFeatureDoc> expectedShardResults) {
+        assertTrue(shardResult instanceof RankFeatureResult);
+        RankFeatureResult rankResult = (RankFeatureResult) shardResult;
+        assertNotNull(rankResult.rankFeatureResult());
+        assertNull(rankResult.queryResult());
+        assertNotNull(rankResult.rankFeatureResult().shardResult());
+        RankFeatureShardResult rankFeatureShardResult = rankResult.rankFeatureResult().shardResult();
+        assertRankFeatureResults(rankFeatureShardResult, expectedShardResults);
+    }
+}

+ 2 - 2
server/src/test/java/org/elasticsearch/search/DefaultSearchContextTests.java

@@ -644,8 +644,8 @@ public class DefaultSearchContextTests extends MapperServiceTestCase {
         ToLongFunction<String> fieldCardinality = name -> -1;
         for (var resultsType : SearchService.ResultsType.values()) {
             switch (resultsType) {
-                case NONE, FETCH -> assertFalse(
-                    "NONE and FETCH phases do not support parallel collection.",
+                case NONE, RANK_FEATURE, FETCH -> assertFalse(
+                    "NONE, RANK_FEATURE, and FETCH phases do not support parallel collection.",
                     DefaultSearchContext.isParallelCollectionSupportedForResults(
                         resultsType,
                         searchSourceBuilderOrNull,

+ 739 - 14
server/src/test/java/org/elasticsearch/search/SearchServiceTests.java

@@ -13,6 +13,8 @@ import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHitCountCollectorManager;
 import org.apache.lucene.store.AlreadyClosedException;
 import org.apache.lucene.util.SetOnce;
@@ -27,6 +29,7 @@ import org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsResp
 import org.elasticsearch.action.search.ClearScrollRequest;
 import org.elasticsearch.action.search.ClosePointInTimeRequest;
 import org.elasticsearch.action.search.OpenPointInTimeRequest;
+import org.elasticsearch.action.search.SearchPhaseController;
 import org.elasticsearch.action.search.SearchPhaseExecutionException;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
@@ -92,6 +95,7 @@ import org.elasticsearch.search.collapse.CollapseBuilder;
 import org.elasticsearch.search.dfs.AggregatedDfs;
 import org.elasticsearch.search.fetch.FetchSearchResult;
 import org.elasticsearch.search.fetch.ShardFetchRequest;
+import org.elasticsearch.search.fetch.ShardFetchSearchRequest;
 import org.elasticsearch.search.fetch.subphase.FieldAndFormat;
 import org.elasticsearch.search.internal.AliasFilter;
 import org.elasticsearch.search.internal.ContextIndexSearcher;
@@ -102,12 +106,26 @@ import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.query.NonCountingTermQuery;
 import org.elasticsearch.search.query.QuerySearchRequest;
 import org.elasticsearch.search.query.QuerySearchResult;
+import org.elasticsearch.search.rank.RankBuilder;
+import org.elasticsearch.search.rank.RankShardResult;
+import org.elasticsearch.search.rank.TestRankBuilder;
+import org.elasticsearch.search.rank.TestRankDoc;
+import org.elasticsearch.search.rank.TestRankShardResult;
+import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
+import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
+import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
 import org.elasticsearch.search.slice.SliceBuilder;
 import org.elasticsearch.search.suggest.SuggestBuilder;
 import org.elasticsearch.tasks.TaskCancelHelper;
 import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESSingleNodeTestCase;
+import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.json.JsonXContent;
@@ -115,8 +133,10 @@ import org.junit.Before;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Comparator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Locale;
@@ -136,8 +156,8 @@ import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonList;
 import static org.elasticsearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
-import static org.elasticsearch.indices.cluster.AbstractIndicesClusterStateServiceTestCase.awaitIndexShardCloseAsyncTasks;
 import static org.elasticsearch.indices.cluster.IndicesClusterStateService.AllocatedIndices.IndexRemovalReason.DELETED;
+import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
 import static org.elasticsearch.search.SearchService.QUERY_PHASE_PARALLEL_COLLECTION_ENABLED;
 import static org.elasticsearch.search.SearchService.SEARCH_WORKER_THREADS_ENABLED;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
@@ -371,7 +391,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
                                 -1,
                                 null
                             ),
-                            new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()),
+                            new SearchShardTask(123L, "", "", "", null, emptyMap()),
                             result.delegateFailure((l, r) -> {
                                 r.incRef();
                                 l.onResponse(r);
@@ -387,7 +407,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
                                 null/* not a scroll */
                             );
                             PlainActionFuture<FetchSearchResult> listener = new PlainActionFuture<>();
-                            service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()), listener);
+                            service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener);
                             listener.get();
                             if (useScroll) {
                                 // have to free context since this test does not remove the index from IndicesService.
@@ -422,6 +442,711 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         assertEquals(0, totalStats.getFetchCurrent());
     }
 
+    public void testRankFeaturePhaseSearchPhases() throws InterruptedException, ExecutionException {
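+        // Drives the query, rank-feature, and fetch phases directly against the SearchService for a single shard,
+        // verifying that feature data is extracted shard-side between the phases.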
+        final String indexName = "index";
+        final String rankFeatureFieldName = "field";
+        final String searchFieldName = "search_field";
+        final String searchFieldValue = "some_value";
+        final String fetchFieldName = "fetch_field";
+        final String fetchFieldValue = "fetch_value";
+
+        final int minDocs = 3;
+        final int maxDocs = 10;
+        int numDocs = between(minDocs, maxDocs);
+        createIndex(indexName);
+        // index some documents
+        for (int i = 0; i < numDocs; i++) {
+            prepareIndex(indexName).setId(String.valueOf(i))
+                .setSource(
+                    rankFeatureFieldName,
+                    "aardvark_" + i,
+                    searchFieldName,
+                    searchFieldValue,
+                    fetchFieldName,
+                    fetchFieldValue + "_" + i
+                )
+                .get();
+        }
+        indicesAdmin().prepareRefresh(indexName).get();
+
+        final SearchService service = getInstanceFromNode(SearchService.class);
+
+        final IndicesService indicesService = getInstanceFromNode(IndicesService.class);
+        final IndexService indexService = indicesService.indexServiceSafe(resolveIndex(indexName));
+        final IndexShard indexShard = indexService.getShard(0);
+        SearchShardTask searchTask = new SearchShardTask(123L, "", "", "", null, emptyMap());
+
+        // create a SearchRequest that returns all documents and defines a TestRankBuilder with shard-level operations only
+        SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true)
+            .source(
+                new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue))
+                    .size(DEFAULT_SIZE)
+                    .fetchField(fetchFieldName)
+                    .rankBuilder(
+                        // here we override only the shard-level contexts
+                        new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+                            @Override
+                            public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+                                return new QueryPhaseRankShardContext(queries, from) {
+
+                                    @Override
+                                    public int rankWindowSize() {
+                                        return DEFAULT_RANK_WINDOW_SIZE;
+                                    }
+
+                                    @Override
+                                    public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                                        // we know we have just 1 query, so return all the docs from it
+                                        return new TestRankShardResult(
+                                            Arrays.stream(rankResults.get(0).scoreDocs)
+                                                .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex))
+                                                .limit(rankWindowSize())
+                                                .toArray(TestRankDoc[]::new)
+                                        );
+                                    }
+                                };
+                            }
+
+                            @Override
+                            public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+                                return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
+                                    @Override
+                                    public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                                        RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+                                        for (int i = 0; i < hits.getHits().length; i++) {
+                                            SearchHit hit = hits.getHits()[i];
+                                            rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
+                                            rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                            rankFeatureDocs[i].score = (numDocs - i) + randomFloat();
+                                            rankFeatureDocs[i].rank = i + 1;
+                                        }
+                                        return new RankFeatureShardResult(rankFeatureDocs);
+                                    }
+                                };
+                            }
+                        }
+                    )
+            );
+
+        ShardSearchRequest request = new ShardSearchRequest(
+            OriginalIndices.NONE,
+            searchRequest,
+            indexShard.shardId(),
+            0,
+            1,
+            AliasFilter.EMPTY,
+            1.0f,
+            -1,
+            null
+        );
+        QuerySearchResult queryResult = null;
+        RankFeatureResult rankResult = null;
+        try {
+            // Execute the query phase and store the result in a SearchPhaseResult container using a PlainActionFuture
+            PlainActionFuture<SearchPhaseResult> queryPhaseResults = new PlainActionFuture<>();
+            service.executeQueryPhase(request, searchTask, queryPhaseResults);
+            queryResult = (QuerySearchResult) queryPhaseResults.get();
+
+            // these are the matched docs from the query phase
+            final TestRankDoc[] queryRankDocs = ((TestRankShardResult) queryResult.getRankShardResult()).testRankDocs;
+
+            // assume the coordinator node has cut the results down to these top docs, which the rank feature phase will run on
+            List<Integer> topRankWindowSizeDocs = randomNonEmptySubsetOf(Arrays.stream(queryRankDocs).map(x -> x.doc).toList());
+
+            // now we create a RankFeatureShardRequest to extract feature info for the top-docs above
+            RankFeatureShardRequest rankFeatureShardRequest = new RankFeatureShardRequest(
+                OriginalIndices.NONE,
+                queryResult.getContextId(), // use the context from the query phase
+                request,
+                topRankWindowSizeDocs
+            );
+            PlainActionFuture<RankFeatureResult> rankPhaseResults = new PlainActionFuture<>();
+            service.executeRankFeaturePhase(rankFeatureShardRequest, searchTask, rankPhaseResults);
+            rankResult = rankPhaseResults.get();
+
+            assertNotNull(rankResult);
+            assertNotNull(rankResult.rankFeatureResult());
+            RankFeatureShardResult rankFeatureShardResult = rankResult.rankFeatureResult().shardResult();
+            assertNotNull(rankFeatureShardResult);
+
+            List<Integer> sortedRankWindowDocs = topRankWindowSizeDocs.stream().sorted().toList();
+            assertEquals(sortedRankWindowDocs.size(), rankFeatureShardResult.rankFeatureDocs.length);
+            for (int i = 0; i < sortedRankWindowDocs.size(); i++) {
+                assertEquals((long) sortedRankWindowDocs.get(i), rankFeatureShardResult.rankFeatureDocs[i].doc);
+                assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, "aardvark_" + sortedRankWindowDocs.get(i));
+            }
+
+            List<Integer> globalTopKResults = randomNonEmptySubsetOf(
+                Arrays.stream(rankFeatureShardResult.rankFeatureDocs).map(x -> x.doc).toList()
+            );
+
+            // finally let's create a fetch request to bring back fetch info for the top results
+            ShardFetchSearchRequest fetchRequest = new ShardFetchSearchRequest(
+                OriginalIndices.NONE,
+                rankResult.getContextId(),
+                request,
+                globalTopKResults,
+                null,
+                rankResult.getRescoreDocIds(),
+                null
+            );
+
+            // execute fetch phase and perform any validations once we retrieve the response
+            // the difference in how we do assertions here is needed because once the transport service sends back the response
+            // it decrements the reference to the FetchSearchResult (through the ActionListener#respondAndRelease) and sets hits to null
+            service.executeFetchPhase(fetchRequest, searchTask, new ActionListener<>() {
+                @Override
+                public void onResponse(FetchSearchResult fetchSearchResult) {
+                    assertNotNull(fetchSearchResult);
+                    assertNotNull(fetchSearchResult.hits());
+
+                    int totalHits = fetchSearchResult.hits().getHits().length;
+                    assertEquals(globalTopKResults.size(), totalHits);
+                    for (int i = 0; i < totalHits; i++) {
+                        // rank and score are set by the SearchPhaseController#merge so no need to validate that here
+                        SearchHit hit = fetchSearchResult.hits().getAt(i);
+                        assertNotNull(hit.getFields().get(fetchFieldName));
+                        assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId());
+                    }
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    throw new AssertionError("No failure should have been raised", e);
+                }
+            });
+        } catch (Exception ex) {
+            if (queryResult != null) {
+                if (queryResult.hasReferences()) {
+                    queryResult.decRef();
+                }
+                service.freeReaderContext(queryResult.getContextId());
+            }
+            if (rankResult != null && rankResult.hasReferences()) {
+                rankResult.decRef();
+            }
+            throw ex;
+        }
+    }
+
+    public void testRankFeaturePhaseUsingClient() {
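+        // Runs a full search through the client with both shard-level and coordinator-level contexts overridden,
+        // paginating the re-ranked results with size=2 and from=2.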
+        final String indexName = "index";
+        final String rankFeatureFieldName = "field";
+        final String searchFieldName = "search_field";
+        final String searchFieldValue = "some_value";
+        final String fetchFieldName = "fetch_field";
+        final String fetchFieldValue = "fetch_value";
+
+        final int minDocs = 4;
+        final int maxDocs = 10;
+        int numDocs = between(minDocs, maxDocs);
+        createIndex(indexName);
+        // index some documents
+        for (int i = 0; i < numDocs; i++) {
+            prepareIndex(indexName).setId(String.valueOf(i))
+                .setSource(
+                    rankFeatureFieldName,
+                    "aardvark_" + i,
+                    searchFieldName,
+                    searchFieldValue,
+                    fetchFieldName,
+                    fetchFieldValue + "_" + i
+                )
+                .get();
+        }
+        indicesAdmin().prepareRefresh(indexName).get();
+
+        ElasticsearchAssertions.assertResponse(
+            client().prepareSearch(indexName)
+                .setSource(
+                    new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue))
+                        .size(2)
+                        .from(2)
+                        .fetchField(fetchFieldName)
+                        .rankBuilder(
+                            // here we override both the shard-level and the coordinator-level contexts
+                            new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+
+                                // no need for more than one query
+                                @Override
+                                public boolean isCompoundBuilder() {
+                                    return false;
+                                }
+
+                                @Override
+                                public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+                                    return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
+                                        @Override
+                                        protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                                            float[] scores = new float[featureDocs.length];
+                                            for (int i = 0; i < featureDocs.length; i++) {
+                                                scores[i] = featureDocs[i].score;
+                                            }
+                                            scoreListener.onResponse(scores);
+                                        }
+                                    };
+                                }
+
+                                @Override
+                                public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+                                    return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+                                        @Override
+                                        public ScoreDoc[] rankQueryPhaseResults(
+                                            List<QuerySearchResult> querySearchResults,
+                                            SearchPhaseController.TopDocsStats topDocStats
+                                        ) {
+                                            List<TestRankDoc> rankDocs = new ArrayList<>();
+                                            for (int i = 0; i < querySearchResults.size(); i++) {
+                                                QuerySearchResult querySearchResult = querySearchResults.get(i);
+                                                TestRankShardResult shardResult = (TestRankShardResult) querySearchResult
+                                                    .getRankShardResult();
+                                                for (TestRankDoc trd : shardResult.testRankDocs) {
+                                                    trd.shardIndex = i;
+                                                    rankDocs.add(trd);
+                                                }
+                                            }
+                                            rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed());
+                                            TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new);
+                                            topDocStats.fetchHits = topResults.length;
+                                            return topResults;
+                                        }
+                                    };
+                                }
+
+                                @Override
+                                public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+                                    return new QueryPhaseRankShardContext(queries, from) {
+
+                                        @Override
+                                        public int rankWindowSize() {
+                                            return DEFAULT_RANK_WINDOW_SIZE;
+                                        }
+
+                                        @Override
+                                        public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                                            // we know we have just 1 query, so return all the docs from it
+                                            return new TestRankShardResult(
+                                                Arrays.stream(rankResults.get(0).scoreDocs)
+                                                    .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex))
+                                                    .limit(rankWindowSize())
+                                                    .toArray(TestRankDoc[]::new)
+                                            );
+                                        }
+                                    };
+                                }
+
+                                @Override
+                                public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+                                    return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
+                                        @Override
+                                        public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                                            RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+                                            for (int i = 0; i < hits.getHits().length; i++) {
+                                                SearchHit hit = hits.getHits()[i];
+                                                rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
+                                                rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                                rankFeatureDocs[i].score = randomFloat();
+                                                rankFeatureDocs[i].rank = i + 1;
+                                            }
+                                            return new RankFeatureShardResult(rankFeatureDocs);
+                                        }
+                                    };
+                                }
+                            }
+                        )
+                ),
+            (response) -> {
+                SearchHits hits = response.getHits();
+                assertEquals(hits.getTotalHits().value, numDocs);
+                assertEquals(hits.getHits().length, 2);
+                int index = 0;
+                for (SearchHit hit : hits.getHits()) {
+                    assertEquals(hit.getRank(), 3 + index);
+                    assertTrue(hit.getScore() >= 0);
+                    assertEquals(hit.getFields().get(fetchFieldName).getValue(), fetchFieldValue + "_" + hit.docId());
+                    index++;
+                }
+            }
+        );
+    }
+
+    public void testRankFeaturePhaseExceptionOnCoordinatingNode() {
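+        // The coordinator-level query-phase ranking throws, so the search is expected to fail
+        // with a SearchPhaseExecutionException before computeScores is ever reached.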
+        final String indexName = "index";
+        final String rankFeatureFieldName = "field";
+        final String searchFieldName = "search_field";
+        final String searchFieldValue = "some_value";
+        final String fetchFieldName = "fetch_field";
+        final String fetchFieldValue = "fetch_value";
+
+        final int minDocs = 3;
+        final int maxDocs = 10;
+        int numDocs = between(minDocs, maxDocs);
+        createIndex(indexName);
+        // index some documents
+        for (int i = 0; i < numDocs; i++) {
+            prepareIndex(indexName).setId(String.valueOf(i))
+                .setSource(
+                    rankFeatureFieldName,
+                    "aardvark_" + i,
+                    searchFieldName,
+                    searchFieldValue,
+                    fetchFieldName,
+                    fetchFieldValue + "_" + i
+                )
+                .get();
+        }
+        indicesAdmin().prepareRefresh(indexName).get();
+
+        expectThrows(
+            SearchPhaseExecutionException.class,
+            () -> client().prepareSearch(indexName)
+                .setSource(
+                    new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue))
+                        .size(2)
+                        .from(2)
+                        .fetchField(fetchFieldName)
+                        .rankBuilder(new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+
+                            // no need for more than one query
+                            @Override
+                            public boolean isCompoundBuilder() {
+                                return false;
+                            }
+
+                            @Override
+                            public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+                                return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
+                                    @Override
+                                    protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                                        throw new IllegalStateException("should have failed earlier");
+                                    }
+                                };
+                            }
+
+                            @Override
+                            public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+                                return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+                                    @Override
+                                    public ScoreDoc[] rankQueryPhaseResults(
+                                        List<QuerySearchResult> querySearchResults,
+                                        SearchPhaseController.TopDocsStats topDocStats
+                                    ) {
+                                        throw new UnsupportedOperationException("simulated failure");
+                                    }
+                                };
+                            }
+
+                            @Override
+                            public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+                                return new QueryPhaseRankShardContext(queries, from) {
+
+                                    @Override
+                                    public int rankWindowSize() {
+                                        return DEFAULT_RANK_WINDOW_SIZE;
+                                    }
+
+                                    @Override
+                                    public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                                        // we know we have just 1 query, so return all the docs from it
+                                        return new TestRankShardResult(
+                                            Arrays.stream(rankResults.get(0).scoreDocs)
+                                                .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex))
+                                                .limit(rankWindowSize())
+                                                .toArray(TestRankDoc[]::new)
+                                        );
+                                    }
+                                };
+                            }
+
+                            @Override
+                            public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+                                return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
+                                    @Override
+                                    public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                                        RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+                                        for (int i = 0; i < hits.getHits().length; i++) {
+                                            SearchHit hit = hits.getHits()[i];
+                                            rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
+                                            rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                            rankFeatureDocs[i].score = randomFloat();
+                                            rankFeatureDocs[i].rank = i + 1;
+                                        }
+                                        return new RankFeatureShardResult(rankFeatureDocs);
+                                    }
+                                };
+                            }
+                        })
+                )
+                .get()
+        );
+    }
+
+    public void testRankFeaturePhaseExceptionAllShardFail() {
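+        // Every shard throws while building its rank-feature result, so the search fails as a whole
+        // even though partial results are allowed.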
+        final String indexName = "index";
+        final String rankFeatureFieldName = "field";
+        final String searchFieldName = "search_field";
+        final String searchFieldValue = "some_value";
+        final String fetchFieldName = "fetch_field";
+        final String fetchFieldValue = "fetch_value";
+
+        final int minDocs = 3;
+        final int maxDocs = 10;
+        int numDocs = between(minDocs, maxDocs);
+        createIndex(indexName);
+        // index some documents
+        for (int i = 0; i < numDocs; i++) {
+            prepareIndex(indexName).setId(String.valueOf(i))
+                .setSource(
+                    rankFeatureFieldName,
+                    "aardvark_" + i,
+                    searchFieldName,
+                    searchFieldValue,
+                    fetchFieldName,
+                    fetchFieldValue + "_" + i
+                )
+                .get();
+        }
+        indicesAdmin().prepareRefresh(indexName).get();
+
+        expectThrows(
+            SearchPhaseExecutionException.class,
+            () -> client().prepareSearch(indexName)
+                .setAllowPartialSearchResults(true)
+                .setSource(
+                    new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue))
+                        .fetchField(fetchFieldName)
+                        .rankBuilder(
+                            // here we override both the shard-level and the coordinator-level contexts
+                            new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+
+                                // no need for more than one query
+                                @Override
+                                public boolean isCompoundBuilder() {
+                                    return false;
+                                }
+
+                                @Override
+                                public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+                                    return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
+                                        @Override
+                                        protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                                            float[] scores = new float[featureDocs.length];
+                                            for (int i = 0; i < featureDocs.length; i++) {
+                                                scores[i] = featureDocs[i].score;
+                                            }
+                                            scoreListener.onResponse(scores);
+                                        }
+                                    };
+                                }
+
+                                @Override
+                                public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+                                    return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+                                        @Override
+                                        public ScoreDoc[] rankQueryPhaseResults(
+                                            List<QuerySearchResult> querySearchResults,
+                                            SearchPhaseController.TopDocsStats topDocStats
+                                        ) {
+                                            List<TestRankDoc> rankDocs = new ArrayList<>();
+                                            for (int i = 0; i < querySearchResults.size(); i++) {
+                                                QuerySearchResult querySearchResult = querySearchResults.get(i);
+                                                TestRankShardResult shardResult = (TestRankShardResult) querySearchResult
+                                                    .getRankShardResult();
+                                                for (TestRankDoc trd : shardResult.testRankDocs) {
+                                                    trd.shardIndex = i;
+                                                    rankDocs.add(trd);
+                                                }
+                                            }
+                                            rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed());
+                                            TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new);
+                                            topDocStats.fetchHits = topResults.length;
+                                            return topResults;
+                                        }
+                                    };
+                                }
+
+                                @Override
+                                public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+                                    return new QueryPhaseRankShardContext(queries, from) {
+
+                                        @Override
+                                        public int rankWindowSize() {
+                                            return DEFAULT_RANK_WINDOW_SIZE;
+                                        }
+
+                                        @Override
+                                        public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                                            // we know we have just 1 query, so return all the docs from it
+                                            return new TestRankShardResult(
+                                                Arrays.stream(rankResults.get(0).scoreDocs)
+                                                    .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex))
+                                                    .limit(rankWindowSize())
+                                                    .toArray(TestRankDoc[]::new)
+                                            );
+                                        }
+                                    };
+                                }
+
+                                @Override
+                                public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+                                    return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
+                                        @Override
+                                        public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                                            throw new UnsupportedOperationException("simulated failure");
+                                        }
+                                    };
+                                }
+                            }
+                        )
+                )
+                .get()
+        );
+    }
+
+    public void testRankFeaturePhaseExceptionOneShardFails() {
+        // if we have only one shard and it fails, it will fall back to context.onPhaseFailure which will eventually clean up all contexts.
+        // in this test we want to make sure that even if one shard (of many) fails during the RankFeaturePhase, the appropriate
+        // context will have been cleaned up.
+        final String indexName = "index";
+        final String rankFeatureFieldName = "field";
+        final String searchFieldName = "search_field";
+        final String searchFieldValue = "some_value";
+        final String fetchFieldName = "fetch_field";
+        final String fetchFieldValue = "fetch_value";
+
+        final int minDocs = 3;
+        final int maxDocs = 10;
+        int numDocs = between(minDocs, maxDocs);
+        createIndex(indexName, Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 2).build());
+        // index some documents
+        for (int i = 0; i < numDocs; i++) {
+            prepareIndex(indexName).setId(String.valueOf(i))
+                .setSource(
+                    rankFeatureFieldName,
+                    "aardvark_" + i,
+                    searchFieldName,
+                    searchFieldValue,
+                    fetchFieldName,
+                    fetchFieldValue + "_" + i
+                )
+                .get();
+        }
+        indicesAdmin().prepareRefresh(indexName).get();
+
+        assertResponse(
+            client().prepareSearch(indexName)
+                .setAllowPartialSearchResults(true)
+                .setSource(
+                    new SearchSourceBuilder().query(new TermQueryBuilder(searchFieldName, searchFieldValue))
+                        .fetchField(fetchFieldName)
+                        .rankBuilder(
+                            // here we override both the shard-level and the coordinator-level contexts
+                            new TestRankBuilder(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+
+                                // no need for more than one query
+                                @Override
+                                public boolean isCompoundBuilder() {
+                                    return false;
+                                }
+
+                                @Override
+                                public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+                                    return new RankFeaturePhaseRankCoordinatorContext(size, from, DEFAULT_RANK_WINDOW_SIZE) {
+                                        @Override
+                                        protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+                                            float[] scores = new float[featureDocs.length];
+                                            for (int i = 0; i < featureDocs.length; i++) {
+                                                scores[i] = featureDocs[i].score;
+                                            }
+                                            scoreListener.onResponse(scores);
+                                        }
+                                    };
+                                }
+
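+                                // merge the per-shard rank docs, sort them by score descending, and keep the top rankWindowSize results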
+                                @Override
+                                public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+                                    return new QueryPhaseRankCoordinatorContext(RankBuilder.DEFAULT_RANK_WINDOW_SIZE) {
+                                        @Override
+                                        public ScoreDoc[] rankQueryPhaseResults(
+                                            List<QuerySearchResult> querySearchResults,
+                                            SearchPhaseController.TopDocsStats topDocStats
+                                        ) {
+                                            List<TestRankDoc> rankDocs = new ArrayList<>();
+                                            for (int i = 0; i < querySearchResults.size(); i++) {
+                                                QuerySearchResult querySearchResult = querySearchResults.get(i);
+                                                TestRankShardResult shardResult = (TestRankShardResult) querySearchResult
+                                                    .getRankShardResult();
+                                                for (TestRankDoc trd : shardResult.testRankDocs) {
+                                                    trd.shardIndex = i;
+                                                    rankDocs.add(trd);
+                                                }
+                                            }
+                                            rankDocs.sort(Comparator.comparing((TestRankDoc doc) -> doc.score).reversed());
+                                            TestRankDoc[] topResults = rankDocs.stream().limit(rankWindowSize).toArray(TestRankDoc[]::new);
+                                            topDocStats.fetchHits = topResults.length;
+                                            return topResults;
+                                        }
+                                    };
+                                }
+
+                                @Override
+                                public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+                                    return new QueryPhaseRankShardContext(queries, from) {
+
+                                        @Override
+                                        public int rankWindowSize() {
+                                            return DEFAULT_RANK_WINDOW_SIZE;
+                                        }
+
+                                        @Override
+                                        public RankShardResult combineQueryPhaseResults(List<TopDocs> rankResults) {
+                                            // we know we have just 1 query, so return all the docs from it
+                                            return new TestRankShardResult(
+                                                Arrays.stream(rankResults.get(0).scoreDocs)
+                                                    .map(x -> new TestRankDoc(x.doc, x.score, x.shardIndex))
+                                                    .limit(rankWindowSize())
+                                                    .toArray(TestRankDoc[]::new)
+                                            );
+                                        }
+                                    };
+                                }
+
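+                                // simulate a failure on shard 0 only; the other shard builds its feature docs normally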
+                                @Override
+                                public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+                                    return new RankFeaturePhaseRankShardContext(rankFeatureFieldName) {
+                                        @Override
+                                        public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                                            if (shardId == 0) {
+                                                throw new UnsupportedOperationException("simulated failure");
+                                            } else {
+                                                RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+                                                for (int i = 0; i < hits.getHits().length; i++) {
+                                                    SearchHit hit = hits.getHits()[i];
+                                                    rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
+                                                    rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                                    rankFeatureDocs[i].score = randomFloat();
+                                                    rankFeatureDocs[i].rank = i + 1;
+                                                }
+                                                return new RankFeatureShardResult(rankFeatureDocs);
+                                            }
+                                        }
+                                    };
+                                }
+                            }
+                        )
+                ),
+            (searchResponse) -> {
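+                // one shard failed with the simulated exception while the other succeeded, so all returned hits come from shard 1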
+                assertEquals(1, searchResponse.getSuccessfulShards());
+                assertEquals("simulated failure", searchResponse.getShardFailures()[0].getCause().getMessage());
+                assertNotEquals(0, searchResponse.getHits().getHits().length);
+                for (SearchHit hit : searchResponse.getHits().getHits()) {
+                    assertEquals(fetchFieldValue + "_" + hit.getId(), hit.getFields().get(fetchFieldName).getValue());
+                    assertEquals(1, hit.getShard().getShardId().id());
+                }
+            }
+        );
+    }
+
     public void testSearchWhileIndexDeletedDoesNotLeakSearchContext() throws ExecutionException, InterruptedException {
         createIndex("index");
         prepareIndex("index").setId("1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get();
@@ -457,7 +1182,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
                 -1,
                 null
             ),
-            new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()),
+            new SearchShardTask(123L, "", "", "", null, emptyMap()),
             result
         );
 
@@ -694,7 +1419,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         for (int i = 0; i < maxScriptFields; i++) {
             searchSourceBuilder.scriptField(
                 "field" + i,
-                new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap())
+                new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap())
             );
         }
         final ShardSearchRequest request = new ShardSearchRequest(
@@ -723,7 +1448,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
             }
             searchSourceBuilder.scriptField(
                 "anotherScriptField",
-                new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap())
+                new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap())
             );
             IllegalArgumentException ex = expectThrows(
                 IllegalArgumentException.class,
@@ -752,7 +1477,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         searchRequest.source(searchSourceBuilder);
         searchSourceBuilder.scriptField(
             "field" + 0,
-            new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, Collections.emptyMap())
+            new Script(ScriptType.INLINE, MockScriptEngine.NAME, CustomScriptPlugin.DUMMY_SCRIPT, emptyMap())
         );
         searchSourceBuilder.size(0);
         final ShardSearchRequest request = new ShardSearchRequest(
@@ -1036,7 +1761,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         );
 
         CountDownLatch latch = new CountDownLatch(1);
-        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap());
         // Because the foo field used in alias filter is unmapped the term query builder rewrite can resolve to a match no docs query,
         // without acquiring a searcher and that means the wrapper is not called
         assertEquals(5, numWrapInvocations.get());
@@ -1330,7 +2055,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
             0,
             null
         );
-        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap());
 
         {
             CountDownLatch latch = new CountDownLatch(1);
@@ -1705,7 +2430,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
 
-        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap());
         ShardSearchRequest request = new ShardSearchRequest(
             OriginalIndices.NONE,
             searchRequest,
@@ -1740,7 +2465,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
 
-        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap());
         PlainActionFuture<SearchPhaseResult> future = new PlainActionFuture<>();
         ShardSearchRequest request = new ShardSearchRequest(
             OriginalIndices.NONE,
@@ -1778,7 +2503,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
 
-        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap());
         PlainActionFuture<SearchPhaseResult> future = new PlainActionFuture<>();
         ShardSearchRequest request = new ShardSearchRequest(
             OriginalIndices.NONE,
@@ -1815,7 +2540,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         final DocWriteResponse response = prepareIndex("index").setSource("id", "1").get();
         assertEquals(RestStatus.CREATED, response.status());
 
-        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+        SearchShardTask task = new SearchShardTask(123L, "", "", "", null, emptyMap());
         PlainActionFuture<SearchPhaseResult> future = new PlainActionFuture<>();
         ShardSearchRequest request = new ShardSearchRequest(
             OriginalIndices.NONE,
@@ -1901,7 +2626,7 @@ public class SearchServiceTests extends ESSingleNodeTestCase {
         PlainActionFuture<QuerySearchResult> plainActionFuture = new PlainActionFuture<>();
         service.executeQueryPhase(
             new QuerySearchRequest(null, context.id(), request, new AggregatedDfs(Map.of(), Map.of(), 10)),
-            new SearchShardTask(42L, "", "", "", null, Collections.emptyMap()),
+            new SearchShardTask(42L, "", "", "", null, emptyMap()),
             plainActionFuture
         );
 

+ 409 - 0
server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java

@@ -0,0 +1,409 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.rank;
+
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TotalHits;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.document.DocumentField;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.Index;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.index.shard.ShardId;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.SearchShardTarget;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.fetch.FetchSearchResult;
+import org.elasticsearch.search.fetch.StoredFieldsContext;
+import org.elasticsearch.search.fetch.subphase.FetchFieldsContext;
+import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
+import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
+import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
+import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.TestSearchContext;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+
+import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+public class RankFeatureShardPhaseTests extends ESTestCase {
+
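+    // minimal SearchContext that records the fetch and rank-feature results, as well as the field contexts set by the phase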
+    private SearchContext getSearchContext() {
+        return new TestSearchContext((SearchExecutionContext) null) {
+
+            private FetchSearchResult fetchResult;
+            private RankFeatureResult rankFeatureResult;
+            private FetchFieldsContext fetchFieldsContext;
+            private StoredFieldsContext storedFieldsContext;
+
+            @Override
+            public FetchSearchResult fetchResult() {
+                return fetchResult;
+            }
+
+            @Override
+            public void addFetchResult() {
+                this.fetchResult = new FetchSearchResult();
+                this.addReleasable(fetchResult::decRef);
+            }
+
+            @Override
+            public RankFeatureResult rankFeatureResult() {
+                return rankFeatureResult;
+            }
+
+            @Override
+            public void addRankFeatureResult() {
+                this.rankFeatureResult = new RankFeatureResult();
+                this.addReleasable(rankFeatureResult::decRef);
+            }
+
+            @Override
+            public SearchContext fetchFieldsContext(FetchFieldsContext fetchFieldsContext) {
+                this.fetchFieldsContext = fetchFieldsContext;
+                return this;
+            }
+
+            @Override
+            public FetchFieldsContext fetchFieldsContext() {
+                return fetchFieldsContext;
+            }
+
+            @Override
+            public SearchContext storedFieldsContext(StoredFieldsContext storedFieldsContext) {
+                this.storedFieldsContext = storedFieldsContext;
+                return this;
+            }
+
+            @Override
+            public StoredFieldsContext storedFieldsContext() {
+                return storedFieldsContext;
+            }
+
+            @Override
+            public boolean isCancelled() {
+                return false;
+            }
+        };
+    }
+
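+    // rank builder that only implements the shard-level rank feature context; query-phase and coordinator contexts are no-ops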
+    private RankBuilder getRankBuilder(final String field) {
+        return new RankBuilder(DEFAULT_RANK_WINDOW_SIZE) {
+            @Override
+            protected void doWriteTo(StreamOutput out) throws IOException {
+                // no-op
+            }
+
+            @Override
+            protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+                // no-op
+            }
+
+            @Override
+            public boolean isCompoundBuilder() {
+                return false;
+            }
+
+            // no work to be done on the query phase
+            @Override
+            public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
+                return null;
+            }
+
+            // no work to be done on the query phase
+            @Override
+            public QueryPhaseRankCoordinatorContext buildQueryPhaseCoordinatorContext(int size, int from) {
+                return null;
+            }
+
+            @Override
+            public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+                return new RankFeaturePhaseRankShardContext(field) {
+                    @Override
+                    public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
+                        RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+                        for (int i = 0; i < hits.getHits().length; i++) {
+                            SearchHit hit = hits.getHits()[i];
+                            rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
+                            rankFeatureDocs[i].featureData(hit.getFields().get(field).getValue());
+                            rankFeatureDocs[i].rank = i + 1;
+                        }
+                        return new RankFeatureShardResult(rankFeatureDocs);
+                    }
+                };
+            }
+
+            // no work to be done on the coordinator node for the rank feature phase
+            @Override
+            public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+                return null;
+            }
+
+            @Override
+            protected boolean doEquals(RankBuilder other) {
+                return false;
+            }
+
+            @Override
+            protected int doHashCode() {
+                return 0;
+            }
+
+            @Override
+            public String getWriteableName() {
+                return "rank_builder_rank_feature_shard_phase_enabled";
+            }
+
+            @Override
+            public TransportVersion getMinimalSupportedVersion() {
+                return TransportVersions.RANK_FEATURE_PHASE_ADDED;
+            }
+        };
+    }
+
+    public void testPrepareForFetch() {
+        final String fieldName = "some_field";
+        int numDocs = randomIntBetween(10, 30);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.rankBuilder(getRankBuilder(fieldName));
+
+        ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
+        when(searchRequest.source()).thenReturn(searchSourceBuilder);
+
+        try (SearchContext searchContext = spy(getSearchContext())) {
+            when(searchContext.isCancelled()).thenReturn(false);
+            when(searchContext.request()).thenReturn(searchRequest);
+
+            RankFeatureShardRequest request = mock(RankFeatureShardRequest.class);
+            when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 });
+
+            RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase();
+            rankFeatureShardPhase.prepareForFetch(searchContext, request);
+
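+            // the phase registers a fetch-fields context for the rank feature field, disables stored fields, and sets up the fetch result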
+            assertNotNull(searchContext.fetchFieldsContext());
+            assertEquals(1, searchContext.fetchFieldsContext().fields().size());
+            assertEquals(fieldName, searchContext.fetchFieldsContext().fields().get(0).field);
+            assertNotNull(searchContext.storedFieldsContext());
+            assertNull(searchContext.storedFieldsContext().fieldNames());
+            assertFalse(searchContext.storedFieldsContext().fetchFields());
+            assertNotNull(searchContext.fetchResult());
+        }
+    }
+
+    public void testPrepareForFetchNoRankFeatureContext() {
+        int numDocs = randomIntBetween(10, 30);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.rankBuilder(null);
+
+        ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
+        when(searchRequest.source()).thenReturn(searchSourceBuilder);
+
+        try (SearchContext searchContext = spy(getSearchContext())) {
+            when(searchContext.isCancelled()).thenReturn(false);
+            when(searchContext.request()).thenReturn(searchRequest);
+
+            RankFeatureShardRequest request = mock(RankFeatureShardRequest.class);
+            when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 });
+
+            RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase();
+            rankFeatureShardPhase.prepareForFetch(searchContext, request);
+
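+            // without a rank builder the phase is a no-op, so no fetch contexts or results are created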
+            assertNull(searchContext.fetchFieldsContext());
+            assertNull(searchContext.fetchResult());
+        }
+    }
+
+    public void testPrepareForFetchWhileTaskIsCancelled() {
+        final String fieldName = "some_field";
+        int numDocs = randomIntBetween(10, 30);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.rankBuilder(getRankBuilder(fieldName));
+
+        ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
+        when(searchRequest.source()).thenReturn(searchSourceBuilder);
+
+        try (SearchContext searchContext = spy(getSearchContext())) {
+            when(searchContext.isCancelled()).thenReturn(true);
+            when(searchContext.request()).thenReturn(searchRequest);
+
+            RankFeatureShardRequest request = mock(RankFeatureShardRequest.class);
+            when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 });
+
+            RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase();
+            expectThrows(TaskCancelledException.class, () -> rankFeatureShardPhase.prepareForFetch(searchContext, request));
+        }
+    }
+
+    public void testProcessFetch() {
+        final String fieldName = "some_field";
+        int numDocs = randomIntBetween(10, 30);
+        Map<Integer, String> expectedFieldData = Map.of(4, "doc_4_aardvark", 9, "doc_9_aardvark", numDocs - 1, "last_doc_aardvark");
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.rankBuilder(getRankBuilder(fieldName));
+
+        ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
+        when(searchRequest.source()).thenReturn(searchSourceBuilder);
+
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            "node_id",
+            new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0),
+            null
+        );
+        SearchHits searchHits = null;
+        try (SearchContext searchContext = spy(getSearchContext())) {
+            searchContext.addFetchResult();
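+            // build three fetched hits whose field values should end up as the feature data of the rank feature docs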
+            SearchHit[] hits = new SearchHit[3];
+            hits[0] = SearchHit.unpooled(4);
+            hits[0].setDocumentField(fieldName, new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(4))));
+
+            hits[1] = SearchHit.unpooled(9);
+            hits[1].setDocumentField(fieldName, new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(9))));
+
+            hits[2] = SearchHit.unpooled(numDocs - 1);
+            hits[2].setDocumentField(
+                fieldName,
+                new DocumentField(fieldName, Collections.singletonList(expectedFieldData.get(numDocs - 1)))
+            );
+            searchHits = SearchHits.unpooled(hits, new TotalHits(3, TotalHits.Relation.EQUAL_TO), 1.0f);
+            searchContext.fetchResult().shardResult(searchHits, null);
+            when(searchContext.isCancelled()).thenReturn(false);
+            when(searchContext.request()).thenReturn(searchRequest);
+            when(searchContext.shardTarget()).thenReturn(shardTarget);
+            RankFeatureShardRequest request = mock(RankFeatureShardRequest.class);
+            when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 });
+
+            RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase();
+            // this is called as part of the search context initialization
+            // with the ResultsType.RANK_FEATURE type
+            searchContext.addRankFeatureResult();
+            rankFeatureShardPhase.processFetch(searchContext);
+
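+            // each rank feature doc should carry the field value of the corresponding fetched hit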
+            assertNotNull(searchContext.rankFeatureResult());
+            assertNotNull(searchContext.rankFeatureResult().rankFeatureResult());
+            for (RankFeatureDoc rankFeatureDoc : searchContext.rankFeatureResult().rankFeatureResult().shardResult().rankFeatureDocs) {
+                assertTrue(expectedFieldData.containsKey(rankFeatureDoc.doc));
+                assertEquals(expectedFieldData.get(rankFeatureDoc.doc), rankFeatureDoc.featureData);
+            }
+        } finally {
+            if (searchHits != null) {
+                searchHits.decRef();
+            }
+        }
+    }
+
+    public void testProcessFetchEmptyHits() {
+        final String fieldName = "some_field";
+        int numDocs = randomIntBetween(10, 30);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.rankBuilder(getRankBuilder(fieldName));
+
+        ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
+        when(searchRequest.source()).thenReturn(searchSourceBuilder);
+
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            "node_id",
+            new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0),
+            null
+        );
+
+        SearchHits searchHits = null;
+        try (SearchContext searchContext = spy(getSearchContext())) {
+            searchContext.addFetchResult();
+            SearchHit[] hits = new SearchHit[0];
+            searchHits = SearchHits.unpooled(hits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f);
+            searchContext.fetchResult().shardResult(searchHits, null);
+            when(searchContext.isCancelled()).thenReturn(false);
+            when(searchContext.request()).thenReturn(searchRequest);
+            when(searchContext.shardTarget()).thenReturn(shardTarget);
+            RankFeatureShardRequest request = mock(RankFeatureShardRequest.class);
+            when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 });
+
+            RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase();
+            // this is called as part of the search context initialization
+            // with the ResultsType.RANK_FEATURE type
+            searchContext.addRankFeatureResult();
+            rankFeatureShardPhase.processFetch(searchContext);
+
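+            // an empty fetch result should still produce an empty shard-level rank feature result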
+            assertNotNull(searchContext.rankFeatureResult());
+            assertNotNull(searchContext.rankFeatureResult().rankFeatureResult());
+            assertEquals(0, searchContext.rankFeatureResult().rankFeatureResult().shardResult().rankFeatureDocs.length);
+        } finally {
+            if (searchHits != null) {
+                searchHits.decRef();
+            }
+        }
+    }
+
+    public void testProcessFetchWhileTaskIsCancelled() {
+        final String fieldName = "some_field";
+        int numDocs = randomIntBetween(10, 30);
+
+        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
+        searchSourceBuilder.rankBuilder(getRankBuilder(fieldName));
+
+        ShardSearchRequest searchRequest = mock(ShardSearchRequest.class);
+        when(searchRequest.source()).thenReturn(searchSourceBuilder);
+
+        SearchShardTarget shardTarget = new SearchShardTarget(
+            "node_id",
+            new ShardId(new Index("some_index", UUID.randomUUID().toString()), 0),
+            null
+        );
+
+        SearchHits searchHits = null;
+        try (SearchContext searchContext = spy(getSearchContext())) {
+            searchContext.addFetchResult();
+            SearchHit[] hits = new SearchHit[0];
+            searchHits = SearchHits.unpooled(hits, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 1.0f);
+            searchContext.fetchResult().shardResult(searchHits, null);
+            when(searchContext.isCancelled()).thenReturn(true);
+            when(searchContext.request()).thenReturn(searchRequest);
+            when(searchContext.shardTarget()).thenReturn(shardTarget);
+            RankFeatureShardRequest request = mock(RankFeatureShardRequest.class);
+            when(request.getDocIds()).thenReturn(new int[] { 4, 9, numDocs - 1 });
+
+            RankFeatureShardPhase rankFeatureShardPhase = new RankFeatureShardPhase();
+            // this is called as part of the search context initialization
+            // with the ResultsType.RANK_FEATURE type
+            searchContext.addRankFeatureResult();
+            expectThrows(TaskCancelledException.class, () -> rankFeatureShardPhase.processFetch(searchContext));
+        } finally {
+            if (searchHits != null) {
+                searchHits.decRef();
+            }
+        }
+    }
+}

+ 2 - 0
server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java

@@ -178,6 +178,7 @@ import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.fetch.FetchPhase;
+import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
 import org.elasticsearch.telemetry.TelemetryProvider;
 import org.elasticsearch.telemetry.tracing.Tracer;
 import org.elasticsearch.test.ClusterServiceUtils;
@@ -2249,6 +2250,7 @@ public class SnapshotResiliencyTests extends ESTestCase {
                     threadPool,
                     scriptService,
                     bigArrays,
+                    new RankFeatureShardPhase(),
                     new FetchPhase(Collections.emptyList()),
                     responseCollectorService,
                     new NoneCircuitBreakerService(),

+ 4 - 0
test/framework/src/main/java/org/elasticsearch/node/MockNode.java

@@ -40,6 +40,7 @@ import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.search.MockSearchService;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.fetch.FetchPhase;
+import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
 import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.telemetry.tracing.Tracer;
 import org.elasticsearch.test.ESTestCase;
@@ -97,6 +98,7 @@ public class MockNode extends Node {
             ThreadPool threadPool,
             ScriptService scriptService,
             BigArrays bigArrays,
+            RankFeatureShardPhase rankFeatureShardPhase,
             FetchPhase fetchPhase,
             ResponseCollectorService responseCollectorService,
             CircuitBreakerService circuitBreakerService,
@@ -111,6 +113,7 @@ public class MockNode extends Node {
                     threadPool,
                     scriptService,
                     bigArrays,
+                    rankFeatureShardPhase,
                     fetchPhase,
                     responseCollectorService,
                     circuitBreakerService,
@@ -124,6 +127,7 @@ public class MockNode extends Node {
                 threadPool,
                 scriptService,
                 bigArrays,
+                rankFeatureShardPhase,
                 fetchPhase,
                 responseCollectorService,
                 circuitBreakerService,

+ 3 - 0
test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java

@@ -23,6 +23,7 @@ import org.elasticsearch.search.fetch.FetchPhase;
 import org.elasticsearch.search.internal.ReaderContext;
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.internal.ShardSearchRequest;
+import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
 import org.elasticsearch.telemetry.tracing.Tracer;
 import org.elasticsearch.threadpool.ThreadPool;
 
@@ -81,6 +82,7 @@ public class MockSearchService extends SearchService {
         ThreadPool threadPool,
         ScriptService scriptService,
         BigArrays bigArrays,
+        RankFeatureShardPhase rankFeatureShardPhase,
         FetchPhase fetchPhase,
         ResponseCollectorService responseCollectorService,
         CircuitBreakerService circuitBreakerService,
@@ -93,6 +95,7 @@ public class MockSearchService extends SearchService {
             threadPool,
             scriptService,
             bigArrays,
+            rankFeatureShardPhase,
             fetchPhase,
             responseCollectorService,
             circuitBreakerService,

+ 18 - 1
test/framework/src/main/java/org/elasticsearch/search/rank/TestRankBuilder.java

@@ -15,6 +15,8 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
@@ -31,7 +33,7 @@ public class TestRankBuilder extends RankBuilder {
 
     static final ConstructingObjectParser<TestRankBuilder, Void> PARSER = new ConstructingObjectParser<>(
         NAME,
-        args -> new TestRankBuilder(args[0] == null ? DEFAULT_WINDOW_SIZE : (int) args[0])
+        args -> new TestRankBuilder(args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0])
     );
 
     static {
@@ -74,6 +76,11 @@ public class TestRankBuilder extends RankBuilder {
         // do nothing
     }
 
+    @Override
+    public boolean isCompoundBuilder() {
+        return true;
+    }
+
     @Override
     public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
         throw new UnsupportedOperationException();
@@ -84,6 +91,16 @@ public class TestRankBuilder extends RankBuilder {
         throw new UnsupportedOperationException();
     }
 
+    @Override
+    public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+        throw new UnsupportedOperationException();
+    }
+
     @Override
     protected boolean doEquals(RankBuilder other) {
         return true;

+ 11 - 0
test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java

@@ -44,6 +44,7 @@ import org.elasticsearch.search.internal.ShardSearchRequest;
 import org.elasticsearch.search.profile.Profilers;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.feature.RankFeatureResult;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.sort.SortAndFormats;
 import org.elasticsearch.search.suggest.SuggestionSearchContext;
@@ -463,6 +464,16 @@ public class TestSearchContext extends SearchContext {
         return queryResult.getMaxScore();
     }
 
+    @Override
+    public void addRankFeatureResult() {
+        // this space intentionally left blank
+    }
+
+    @Override
+    public RankFeatureResult rankFeatureResult() {
+        return null;
+    }
+
     @Override
     public FetchSearchResult fetchResult() {
         return null;

+ 4 - 0
test/framework/src/main/java/org/elasticsearch/test/hamcrest/ElasticsearchAssertions.java

@@ -687,6 +687,10 @@ public class ElasticsearchAssertions {
         return transformedMatch(SearchHit::getScore, equalTo(score));
     }
 
+    public static Matcher<SearchHit> hasRank(final int rank) {
+        return transformedMatch(SearchHit::getRank, equalTo(rank));
+    }
+
     public static <T extends Query> T assertBooleanSubQuery(Query query, Class<T> subqueryType, int i) {
         assertThat(query, instanceOf(BooleanQuery.class));
         BooleanQuery q = (BooleanQuery) query;

+ 14 - 0
x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

@@ -397,6 +397,14 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
             }
         }
 
+        @Override
+        protected void onRankFeatureResult(int shardIndex) {
+            checkCancellation();
+            if (delegate != null) {
+                delegate.onRankFeatureResult(shardIndex);
+            }
+        }
+
         @Override
         protected void onFetchResult(int shardIndex) {
             checkCancellation();
@@ -420,6 +428,12 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
                 );
         }
 
+        @Override
+        protected void onRankFeatureFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
+            // best effort to cancel expired tasks
+            checkCancellation();
+        }
+
         @Override
         protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
             // best effort to cancel expired tasks

+ 18 - 1
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankBuilder.java

@@ -16,6 +16,8 @@ import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.search.rank.RankBuilder;
 import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
 import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
+import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -38,7 +40,7 @@ public class RRFRankBuilder extends RankBuilder {
     public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
 
     static final ConstructingObjectParser<RRFRankBuilder, Void> PARSER = new ConstructingObjectParser<>(RRFRankPlugin.NAME, args -> {
-        int windowSize = args[0] == null ? DEFAULT_WINDOW_SIZE : (int) args[0];
+        int windowSize = args[0] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[0];
         int rankConstant = args[1] == null ? DEFAULT_RANK_CONSTANT : (int) args[1];
         if (rankConstant < 1) {
             throw new IllegalArgumentException("[rank_constant] must be greater than [0] for [rrf]");
@@ -94,6 +96,11 @@ public class RRFRankBuilder extends RankBuilder {
         return rankConstant;
     }
 
+    @Override
+    public boolean isCompoundBuilder() {
+        return true;
+    }
+
     public QueryPhaseRankShardContext buildQueryPhaseShardContext(List<Query> queries, int from) {
         return new RRFQueryPhaseRankShardContext(queries, rankWindowSize(), rankConstant);
     }
@@ -103,6 +110,16 @@ public class RRFRankBuilder extends RankBuilder {
         return new RRFQueryPhaseRankCoordinatorContext(size, from, rankWindowSize(), rankConstant);
     }
 
+    @Override
+    public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
+        return null;
+    }
+
+    @Override
+    public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorContext(int size, int from) {
+        return null;
+    }
+
     @Override
     protected boolean doEquals(RankBuilder other) {
         return Objects.equals(rankConstant, ((RRFRankBuilder) other).rankConstant);

+ 1 - 1
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

@@ -71,7 +71,7 @@ public final class RRFRetrieverBuilder extends RetrieverBuilder {
     }
 
     List<RetrieverBuilder> retrieverBuilders = Collections.emptyList();
-    int rankWindowSize = RRFRankBuilder.DEFAULT_WINDOW_SIZE;
+    int rankWindowSize = RRFRankBuilder.DEFAULT_RANK_WINDOW_SIZE;
     int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT;
 
     @Override

+ 1 - 0
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/PreAuthorizationUtils.java

@@ -45,6 +45,7 @@ public final class PreAuthorizationUtils {
             SearchTransportService.QUERY_ACTION_NAME,
             SearchTransportService.QUERY_ID_ACTION_NAME,
             SearchTransportService.FETCH_ID_ACTION_NAME,
+            SearchTransportService.RANK_FEATURE_SHARD_ACTION_NAME,
             SearchTransportService.QUERY_CAN_MATCH_NODE_NAME
         )
     );