Browse Source

Support semantic reranking using contextual snippets instead of entire field text (#129369)

Kathleen DeRusso 2 months ago
parent
commit
37c27e0368
34 changed files with 893 additions and 139 deletions
  1. 6 0
      docs/changelog/129369.yaml
  2. 2 1
      qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java
  3. 1 0
      qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java
  4. 1 0
      qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java
  5. 1 0
      rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java
  6. 2 2
      server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java
  7. 1 1
      server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java
  8. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  9. 16 5
      server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java
  10. 1 1
      server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java
  11. 1 1
      server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java
  12. 5 0
      server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java
  13. 10 0
      server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java
  14. 17 6
      server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java
  15. 4 4
      server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java
  16. 14 9
      server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java
  17. 20 19
      server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java
  18. 12 5
      server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java
  19. 14 2
      server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java
  20. 2 1
      test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java
  21. 38 29
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java
  22. 88 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java
  23. 88 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java
  24. 67 18
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java
  25. 54 7
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
  26. 97 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java
  27. 73 4
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java
  28. 3 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java
  29. 2 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java
  30. 2 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java
  31. 10 7
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java
  32. 5 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java
  33. 234 4
      x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml
  34. 1 0
      x-pack/qa/core-rest-tests-with-security/src/yamlRestTest/java/org/elasticsearch/xpack/security/CoreWithSecurityClientYamlTestSuiteIT.java

+ 6 - 0
docs/changelog/129369.yaml

@@ -0,0 +1,6 @@
+pr: 129369
+summary: Support semantic reranking using contextual snippets instead of entire field
+  text
+area: Relevance
+type: enhancement
+issues: []

+ 2 - 1
qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/CcsCommonYamlTestSuiteIT.java

@@ -92,7 +92,8 @@ public class CcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
         .feature(FeatureFlag.TIME_SERIES_MODE)
         .feature(FeatureFlag.SUB_OBJECTS_AUTO_ENABLED)
         .feature(FeatureFlag.IVF_FORMAT)
-        .feature(FeatureFlag.SYNTHETIC_VECTORS);
+        .feature(FeatureFlag.SYNTHETIC_VECTORS)
+        .feature(FeatureFlag.RERANK_SNIPPETS);
 
     private static ElasticsearchCluster remoteCluster = ElasticsearchCluster.local()
         .name(REMOTE_CLUSTER_NAME)

+ 1 - 0
qa/ccs-common-rest/src/yamlRestTest/java/org/elasticsearch/test/rest/yaml/RcsCcsCommonYamlTestSuiteIT.java

@@ -94,6 +94,7 @@ public class RcsCcsCommonYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
         .feature(FeatureFlag.SUB_OBJECTS_AUTO_ENABLED)
         .feature(FeatureFlag.IVF_FORMAT)
         .feature(FeatureFlag.SYNTHETIC_VECTORS)
+        .feature(FeatureFlag.RERANK_SNIPPETS)
         .user("test_admin", "x-pack-test-password");
 
     private static ElasticsearchCluster fulfillingCluster = ElasticsearchCluster.local()

+ 1 - 0
qa/smoke-test-multinode/src/yamlRestTest/java/org/elasticsearch/smoketest/SmokeTestMultiNodeClientYamlTestSuiteIT.java

@@ -40,6 +40,7 @@ public class SmokeTestMultiNodeClientYamlTestSuiteIT extends ESClientYamlSuiteTe
         .feature(FeatureFlag.USE_LUCENE101_POSTINGS_FORMAT)
         .feature(FeatureFlag.IVF_FORMAT)
         .feature(FeatureFlag.SYNTHETIC_VECTORS)
+        .feature(FeatureFlag.RERANK_SNIPPETS)
         .build();
 
     public SmokeTestMultiNodeClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {

+ 1 - 0
rest-api-spec/src/yamlRestTest/java/org/elasticsearch/test/rest/ClientYamlTestSuiteIT.java

@@ -40,6 +40,7 @@ public class ClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
         .feature(FeatureFlag.USE_LUCENE101_POSTINGS_FORMAT)
         .feature(FeatureFlag.IVF_FORMAT)
         .feature(FeatureFlag.SYNTHETIC_VECTORS)
+        .feature(FeatureFlag.RERANK_SNIPPETS)
         .build();
 
     public ClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {

+ 2 - 2
server/src/internalClusterTest/java/org/elasticsearch/search/rank/FieldBasedRerankerIT.java

@@ -193,7 +193,7 @@ public class FieldBasedRerankerIT extends AbstractRerankerIT {
                         RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
                         for (int i = 0; i < hits.getHits().length; i++) {
                             rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
-                            rankFeatureDocs[i].featureData(hits.getHits()[i].field(field).getValue().toString());
+                            rankFeatureDocs[i].featureData(List.of(hits.getHits()[i].field(field).getValue().toString()));
                         }
                         return new RankFeatureShardResult(rankFeatureDocs);
                     } catch (Exception ex) {
@@ -210,7 +210,7 @@ public class FieldBasedRerankerIT extends AbstractRerankerIT {
                 protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
                     float[] scores = new float[featureDocs.length];
                     for (int i = 0; i < featureDocs.length; i++) {
-                        scores[i] = Float.parseFloat(featureDocs[i].featureData);
+                        scores[i] = Float.parseFloat(featureDocs[i].featureData.get(0));
                     }
                     scoreListener.onResponse(scores);
                 }

+ 1 - 1
server/src/internalClusterTest/java/org/elasticsearch/search/rank/MockedRequestActionBasedRerankerIT.java

@@ -275,7 +275,7 @@ public class MockedRequestActionBasedRerankerIT extends AbstractRerankerIT {
                 l.onResponse(scores);
             });
 
-            List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
+            List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).flatMap(List::stream).toList();
             TestRerankingActionRequest request = generateRequest(featureData);
             try {
                 ActionType<TestRerankingActionResponse> action = actionType();

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -352,6 +352,7 @@ public class TransportVersions {
     public static final TransportVersion ESQL_SAMPLE_OPERATOR_STATUS = def(9_127_0_00);
     public static final TransportVersion ESQL_TOPN_TIMINGS = def(9_128_0_00);
     public static final TransportVersion NODE_WEIGHTS_ADDED_TO_NODE_BALANCE_STATS = def(9_129_0_00);
+    public static final TransportVersion RERANK_SNIPPETS = def(9_130_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 16 - 5
server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java

@@ -1187,9 +1187,11 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
                 sliceBuilder,
                 sorts,
                 rescoreBuilders,
-                highlightBuilder
+                highlightBuilder,
+                rankBuilder
             )
         ));
+
         if (retrieverBuilder != null) {
             var newRetriever = retrieverBuilder.rewrite(context);
             if (newRetriever != retrieverBuilder) {
@@ -1205,6 +1207,11 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             }
         }
 
+        RankBuilder rankBuilder = null;
+        if (this.rankBuilder != null) {
+            rankBuilder = this.rankBuilder.rewrite(context);
+        }
+
         List<SubSearchSourceBuilder> subSearchSourceBuilders = Rewriteable.rewrite(this.subSearchSourceBuilders, context);
         QueryBuilder postQueryBuilder = null;
         if (this.postQueryBuilder != null) {
@@ -1229,7 +1236,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             || aggregations != this.aggregations
             || rescoreBuilders != this.rescoreBuilders
             || sorts != this.sorts
-            || this.highlightBuilder != highlightBuilder;
+            || this.highlightBuilder != highlightBuilder
+            || this.rankBuilder != rankBuilder;
         if (rewritten) {
             return shallowCopy(
                 subSearchSourceBuilders,
@@ -1239,7 +1247,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
                 this.sliceBuilder,
                 sorts,
                 rescoreBuilders,
-                highlightBuilder
+                highlightBuilder,
+                rankBuilder
             );
         }
         return this;
@@ -1257,7 +1266,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             sliceBuilder,
             sorts,
             rescoreBuilders,
-            highlightBuilder
+            highlightBuilder,
+            rankBuilder
         );
     }
 
@@ -1274,7 +1284,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         SliceBuilder slice,
         List<SortBuilder<?>> sorts,
         List<RescorerBuilder> rescoreBuilders,
-        HighlightBuilder highlightBuilder
+        HighlightBuilder highlightBuilder,
+        RankBuilder rankBuilder
     ) {
         SearchSourceBuilder rewrittenBuilder = new SearchSourceBuilder();
         rewrittenBuilder.aggregations = aggregations;

+ 1 - 1
server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/HighlightBuilder.java

@@ -399,7 +399,7 @@ public final class HighlightBuilder extends AbstractHighlighterBuilder<Highlight
             this.name = name;
         }
 
-        private Field(Field template, QueryBuilder builder) {
+        Field(Field template, QueryBuilder builder) {
             super(template, builder);
             name = template.name;
             fragmentOffset = template.fragmentOffset;

+ 1 - 1
server/src/main/java/org/elasticsearch/search/fetch/subphase/highlight/SearchHighlightContext.java

@@ -40,7 +40,7 @@ public class SearchHighlightContext {
         private final String field;
         private final FieldOptions fieldOptions;
 
-        Field(String field, FieldOptions fieldOptions) {
+        public Field(String field, FieldOptions fieldOptions) {
             assert field != null;
             assert fieldOptions != null;
             this.field = field;

+ 5 - 0
server/src/main/java/org/elasticsearch/search/rank/RankBuilder.java

@@ -19,6 +19,7 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.UpdateForV10;
 import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
@@ -60,6 +61,10 @@ public abstract class RankBuilder implements VersionedNamedWriteable, ToXContent
         doWriteTo(out);
     }
 
+    public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException {
+        return this;
+    }
+
     protected abstract void doWriteTo(StreamOutput out) throws IOException;
 
     @Override

+ 10 - 0
server/src/main/java/org/elasticsearch/search/rank/context/RankFeaturePhaseRankShardContext.java

@@ -11,6 +11,7 @@ package org.elasticsearch.search.rank.context;
 
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.search.rank.RankShardResult;
 
 /**
@@ -37,4 +38,13 @@ public abstract class RankFeaturePhaseRankShardContext {
      */
     @Nullable
     public abstract RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId);
+
+    /**
+     * Prepares a SearchContext with any additional information needed before executing
+     * commands on shards.
+     * @param context SearchContext
+     */
+    public void prepareForFetch(SearchContext context) {
+        // Default no-op
+    }
 }

+ 17 - 6
server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureDoc.java

@@ -10,12 +10,14 @@
 package org.elasticsearch.search.rank.feature;
 
 import org.apache.lucene.search.Explanation;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;
 
 /**
@@ -26,7 +28,7 @@ public class RankFeatureDoc extends RankDoc {
     public static final String NAME = "rank_feature_doc";
 
     // TODO: update to support more than 1 fields; and not restrict to string data
-    public String featureData;
+    public List<String> featureData;
 
     public RankFeatureDoc(int doc, float score, int shardIndex) {
         super(doc, score, shardIndex);
@@ -34,7 +36,12 @@ public class RankFeatureDoc extends RankDoc {
 
     public RankFeatureDoc(StreamInput in) throws IOException {
         super(in);
-        featureData = in.readOptionalString();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            featureData = in.readOptionalStringCollectionAsList();
+        } else {
+            String featureDataString = in.readOptionalString();
+            featureData = featureDataString == null ? null : List.of(featureDataString);
+        }
     }
 
     @Override
@@ -42,13 +49,17 @@ public class RankFeatureDoc extends RankDoc {
         throw new UnsupportedOperationException("explain is not supported for {" + getClass() + "}");
     }
 
-    public void featureData(String featureData) {
+    public void featureData(List<String> featureData) {
         this.featureData = featureData;
     }
 
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
-        out.writeOptionalString(featureData);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            out.writeOptionalStringCollection(featureData);
+        } else {
+            out.writeOptionalString(featureData.get(0));
+        }
     }
 
     @Override
@@ -59,7 +70,7 @@ public class RankFeatureDoc extends RankDoc {
 
     @Override
     protected int doHashCode() {
-        return Objects.hashCode(featureData);
+        return Objects.hash(featureData);
     }
 
     @Override
@@ -69,6 +80,6 @@ public class RankFeatureDoc extends RankDoc {
 
     @Override
     protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.field("featureData", featureData);
+        builder.array("featureData", featureData);
     }
 }

+ 4 - 4
server/src/main/java/org/elasticsearch/search/rank/feature/RankFeatureShardPhase.java

@@ -48,10 +48,10 @@ public final class RankFeatureShardPhase {
 
         RankFeaturePhaseRankShardContext rankFeaturePhaseRankShardContext = shardContext(searchContext);
         if (rankFeaturePhaseRankShardContext != null) {
-            assert rankFeaturePhaseRankShardContext.getField() != null : "field must not be null";
-            searchContext.fetchFieldsContext(
-                new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(rankFeaturePhaseRankShardContext.getField(), null)))
-            );
+            String field = rankFeaturePhaseRankShardContext.getField();
+            assert field != null : "field must not be null";
+            searchContext.fetchFieldsContext(new FetchFieldsContext(Collections.singletonList(new FieldAndFormat(field, null))));
+            rankFeaturePhaseRankShardContext.prepareForFetch(searchContext);
             searchContext.storedFieldsContext(StoredFieldsContext.fromList(Collections.singletonList(StoredFieldsContext._NONE_)));
             searchContext.addFetchResult();
             Arrays.sort(request.getDocIds());

+ 14 - 9
server/src/main/java/org/elasticsearch/search/rank/rerank/RerankingRankFeaturePhaseRankShardContext.java

@@ -20,6 +20,7 @@ import org.elasticsearch.search.rank.feature.RankFeatureDoc;
 import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
 
 import java.util.Arrays;
+import java.util.List;
 
 /**
  * The {@code ReRankingRankFeaturePhaseRankShardContext} is handles the {@code SearchHits} generated from the {@code RankFeatureShardPhase}
@@ -37,15 +38,7 @@ public class RerankingRankFeaturePhaseRankShardContext extends RankFeaturePhaseR
     @Override
     public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId) {
         try {
-            RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
-            for (int i = 0; i < hits.getHits().length; i++) {
-                rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
-                DocumentField docField = hits.getHits()[i].field(field);
-                if (docField != null) {
-                    rankFeatureDocs[i].featureData(docField.getValue().toString());
-                }
-            }
-            return new RankFeatureShardResult(rankFeatureDocs);
+            return doBuildRankFeatureShardResult(hits, shardId);
         } catch (Exception ex) {
             logger.warn(
                 "Error while fetching feature data for {field: ["
@@ -58,4 +51,16 @@ public class RerankingRankFeaturePhaseRankShardContext extends RankFeaturePhaseR
             return null;
         }
     }
+
+    protected RankShardResult doBuildRankFeatureShardResult(SearchHits hits, int shardId) {
+        RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+        for (int i = 0; i < hits.getHits().length; i++) {
+            rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
+            DocumentField docField = hits.getHits()[i].field(field);
+            if (docField != null) {
+                rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
+            }
+        }
+        return new RankFeatureShardResult(rankFeatureDocs);
+    }
 }

+ 20 - 19
server/src/test/java/org/elasticsearch/action/search/RankFeaturePhaseTests.java

@@ -73,7 +73,7 @@ public class RankFeaturePhaseTests extends ESTestCase {
         defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE)
     );
 
-    private record ExpectedRankFeatureDoc(int doc, int rank, float score, String featureData) {}
+    private record ExpectedRankFeatureDoc(int doc, int rank, float score, List<String> featureData) {}
 
     public void testRankFeaturePhaseWith1Shard() {
         // request params used within SearchSourceBuilder and *RankContext classes
@@ -145,8 +145,8 @@ public class RankFeaturePhaseTests extends ESTestCase {
 
                 SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
                 List<ExpectedRankFeatureDoc> expectedShardResults = List.of(
-                    new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"),
-                    new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2")
+                    new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1")),
+                    new ExpectedRankFeatureDoc(2, 2, 109.0F, List.of("ranked_2"))
                 );
                 List<ExpectedRankFeatureDoc> expectedFinalResults = new ArrayList<>(expectedShardResults);
                 assertShardResults(shard1Result, expectedShardResults);
@@ -263,19 +263,19 @@ public class RankFeaturePhaseTests extends ESTestCase {
                 assertEquals(2, rankPhaseResults.getSuccessfulResults().count());
 
                 SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
-                List<ExpectedRankFeatureDoc> expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"));
+                List<ExpectedRankFeatureDoc> expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1")));
                 assertShardResults(shard1Result, expectedShard1Results);
 
                 SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
-                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2"));
+                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, List.of("ranked_2")));
                 assertShardResults(shard2Result, expectedShard2Results);
 
                 SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2);
                 assertNull(shard3Result);
 
                 List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(
-                    new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"),
-                    new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2")
+                    new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1")),
+                    new ExpectedRankFeatureDoc(2, 2, 109.0F, List.of("ranked_2"))
                 );
                 assertFinalResults(finalResults[0], expectedFinalResults);
             } finally {
@@ -379,7 +379,7 @@ public class RankFeaturePhaseTests extends ESTestCase {
                 assertNull(shard1Result);
 
                 SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
-                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, "ranked_2"));
+                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(2, 1, 109.0F, List.of("ranked_2")));
                 List<ExpectedRankFeatureDoc> expectedFinalResults = new ArrayList<>(expectedShard2Results);
                 assertShardResults(shard2Result, expectedShard2Results);
                 assertFinalResults(finalResults[0], expectedFinalResults);
@@ -609,22 +609,21 @@ public class RankFeaturePhaseTests extends ESTestCase {
                 assertEquals(2, rankPhaseResults.getSuccessfulResults().count());
 
                 SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
-                List<ExpectedRankFeatureDoc> expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"));
+                List<ExpectedRankFeatureDoc> expectedShard1Results = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1")));
                 assertShardResults(shard1Result, expectedShard1Results);
 
                 SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
                 List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(
-                    new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"),
-                    new ExpectedRankFeatureDoc(2, 2, 109.0F, "ranked_2"),
-                    new ExpectedRankFeatureDoc(200, 3, 101.0F, "ranked_200")
-
+                    new ExpectedRankFeatureDoc(11, 1, 200.0F, List.of("ranked_11")),
+                    new ExpectedRankFeatureDoc(2, 2, 109.0F, List.of("ranked_2")),
+                    new ExpectedRankFeatureDoc(200, 3, 101.0F, List.of("ranked_200"))
                 );
                 assertShardResults(shard2Result, expectedShard2Results);
 
                 SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2);
                 assertNull(shard3Result);
 
-                List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1"));
+                List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(new ExpectedRankFeatureDoc(1, 2, 110.0F, List.of("ranked_1")));
                 assertFinalResults(finalResults[0], expectedFinalResults);
             } finally {
                 rankFeaturePhase.rankPhaseResults.close();
@@ -748,19 +747,21 @@ public class RankFeaturePhaseTests extends ESTestCase {
                 assertEquals(2, rankPhaseResults.getSuccessfulResults().count());
 
                 SearchPhaseResult shard1Result = rankPhaseResults.getAtomicArray().get(0);
-                List<ExpectedRankFeatureDoc> expectedShardResults = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, "ranked_1"));
+                List<ExpectedRankFeatureDoc> expectedShardResults = List.of(new ExpectedRankFeatureDoc(1, 1, 110.0F, List.of("ranked_1")));
                 assertShardResults(shard1Result, expectedShardResults);
 
                 SearchPhaseResult shard2Result = rankPhaseResults.getAtomicArray().get(1);
-                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"));
+                List<ExpectedRankFeatureDoc> expectedShard2Results = List.of(
+                    new ExpectedRankFeatureDoc(11, 1, 200.0F, List.of("ranked_11"))
+                );
                 assertShardResults(shard2Result, expectedShard2Results);
 
                 SearchPhaseResult shard3Result = rankPhaseResults.getAtomicArray().get(2);
                 assertNull(shard3Result);
 
                 List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(
-                    new ExpectedRankFeatureDoc(11, 1, 200.0F, "ranked_11"),
-                    new ExpectedRankFeatureDoc(1, 2, 110.0F, "ranked_1")
+                    new ExpectedRankFeatureDoc(11, 1, 200.0F, List.of("ranked_11")),
+                    new ExpectedRankFeatureDoc(1, 2, 110.0F, List.of("ranked_1"))
                 );
                 assertFinalResults(finalResults[0], expectedFinalResults);
             } finally {
@@ -813,7 +814,7 @@ public class RankFeaturePhaseTests extends ESTestCase {
                     SearchHit hit = hits.getHits()[i];
                     rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
                     rankFeatureDocs[i].score += 100f;
-                    rankFeatureDocs[i].featureData("ranked_" + hit.docId());
+                    rankFeatureDocs[i].featureData(List.of("ranked_" + hit.docId()));
                     rankFeatureDocs[i].rank = i + 1;
                 }
                 return new RankFeatureShardResult(rankFeatureDocs);

+ 12 - 5
server/src/test/java/org/elasticsearch/search/SearchServiceSingleNodeTests.java

@@ -523,7 +523,7 @@ public class SearchServiceSingleNodeTests extends ESSingleNodeTestCase {
                                         for (int i = 0; i < hits.getHits().length; i++) {
                                             SearchHit hit = hits.getHits()[i];
                                             rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
-                                            rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                            rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName));
                                             rankFeatureDocs[i].score = (numDocs - i) + randomFloat();
                                             rankFeatureDocs[i].rank = i + 1;
                                         }
@@ -580,7 +580,7 @@ public class SearchServiceSingleNodeTests extends ESSingleNodeTestCase {
             assertEquals(sortedRankWindowDocs.size(), rankFeatureShardResult.rankFeatureDocs.length);
             for (int i = 0; i < sortedRankWindowDocs.size(); i++) {
                 assertEquals((long) sortedRankWindowDocs.get(i), rankFeatureShardResult.rankFeatureDocs[i].doc);
-                assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, "aardvark_" + sortedRankWindowDocs.get(i));
+                assertEquals(rankFeatureShardResult.rankFeatureDocs[i].featureData, List.of("aardvark_" + sortedRankWindowDocs.get(i)));
             }
 
             List<Integer> globalTopKResults = randomNonEmptySubsetOf(
@@ -760,7 +760,7 @@ public class SearchServiceSingleNodeTests extends ESSingleNodeTestCase {
                                             for (int i = 0; i < hits.getHits().length; i++) {
                                                 SearchHit hit = hits.getHits()[i];
                                                 rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
-                                                rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                                rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName));
                                                 rankFeatureDocs[i].score = randomFloat();
                                                 rankFeatureDocs[i].rank = i + 1;
                                             }
@@ -887,7 +887,7 @@ public class SearchServiceSingleNodeTests extends ESSingleNodeTestCase {
                                         for (int i = 0; i < hits.getHits().length; i++) {
                                             SearchHit hit = hits.getHits()[i];
                                             rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
-                                            rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                            rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName));
                                             rankFeatureDocs[i].score = randomFloat();
                                             rankFeatureDocs[i].rank = i + 1;
                                         }
@@ -1151,7 +1151,7 @@ public class SearchServiceSingleNodeTests extends ESSingleNodeTestCase {
                                                 for (int i = 0; i < hits.getHits().length; i++) {
                                                     SearchHit hit = hits.getHits()[i];
                                                     rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
-                                                    rankFeatureDocs[i].featureData(hit.getFields().get(rankFeatureFieldName).getValue());
+                                                    rankFeatureDocs[i].featureData(parseFeatureData(hit, rankFeatureFieldName));
                                                     rankFeatureDocs[i].score = randomFloat();
                                                     rankFeatureDocs[i].rank = i + 1;
                                                 }
@@ -2904,6 +2904,13 @@ public class SearchServiceSingleNodeTests extends ESSingleNodeTestCase {
         );
     }
 
+    private List<String> parseFeatureData(SearchHit hit, String fieldName) {
+        Object fieldValue = hit.getFields().get(fieldName).getValue();
+        @SuppressWarnings("unchecked")
+        List<String> fieldValues = fieldValue instanceof List ? (List<String>) fieldValue : List.of(String.valueOf(fieldValue));
+        return fieldValues;
+    }
+
     private static class TestRewriteCounterQueryBuilder extends AbstractQueryBuilder<TestRewriteCounterQueryBuilder> {
 
         final int asyncRewriteCount;

+ 14 - 2
server/src/test/java/org/elasticsearch/search/rank/RankFeatureShardPhaseTests.java

@@ -160,7 +160,12 @@ public class RankFeatureShardPhaseTests extends ESTestCase {
                         for (int i = 0; i < hits.getHits().length; i++) {
                             SearchHit hit = hits.getHits()[i];
                             rankFeatureDocs[i] = new RankFeatureDoc(hit.docId(), hit.getScore(), shardId);
-                            rankFeatureDocs[i].featureData(hit.getFields().get(field).getValue());
+                            Object fieldValue = hit.getFields().get(field).getValue();
+                            @SuppressWarnings("unchecked")
+                            List<String> featureData = fieldValue instanceof List
+                                ? (List<String>) fieldValue
+                                : List.of(String.valueOf(fieldValue));
+                            rankFeatureDocs[i].featureData(featureData);
                             rankFeatureDocs[i].rank = i + 1;
                         }
                         return new RankFeatureShardResult(rankFeatureDocs);
@@ -279,7 +284,14 @@ public class RankFeatureShardPhaseTests extends ESTestCase {
     public void testProcessFetch() {
         final String fieldName = "some_field";
         int numDocs = randomIntBetween(15, 30);
-        Map<Integer, String> expectedFieldData = Map.of(4, "doc_4_aardvark", 9, "doc_9_aardvark", numDocs - 1, "last_doc_aardvark");
+        Map<Integer, List<String>> expectedFieldData = Map.of(
+            4,
+            List.of("doc_4_aardvark"),
+            9,
+            List.of("doc_9_aardvark"),
+            numDocs - 1,
+            List.of("last_doc_aardvark")
+        );
 
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
         searchSourceBuilder.rankBuilder(getRankBuilder(fieldName));

+ 2 - 1
test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

@@ -23,7 +23,8 @@ public enum FeatureFlag {
     IVF_FORMAT("es.ivf_format_feature_flag_enabled=true", Version.fromString("9.1.0"), null),
     LOGS_STREAM("es.logs_stream_feature_flag_enabled=true", Version.fromString("9.1.0"), null),
     PATTERNED_TEXT("es.patterned_text_feature_flag_enabled=true", Version.fromString("9.1.0"), null),
-    SYNTHETIC_VECTORS("es.mapping_synthetic_vectors=true", Version.fromString("9.2.0"), null);
+    SYNTHETIC_VECTORS("es.mapping_synthetic_vectors=true", Version.fromString("9.2.0"), null),
+    RERANK_SNIPPETS("es.text_similarity_reranker_snippets=true", Version.fromString("9.2.0"), null);
 
     public final String systemProperty;
     public final Version from;

+ 38 - 29
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

@@ -13,6 +13,7 @@ import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsM
 import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
 
+import java.util.HashSet;
 import java.util.Set;
 
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS;
@@ -23,6 +24,8 @@ import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRe
 import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
 import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
 import static org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor.SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED;
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.RERANK_SNIPPETS;
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_SNIPPETS;
 
 /**
  * Provides inference features.
@@ -46,35 +49,41 @@ public class InferenceFeatures implements FeatureSpecification {
 
     @Override
     public Set<NodeFeature> getTestFeatures() {
-        return Set.of(
-            SemanticTextFieldMapper.SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS,
-            SEMANTIC_TEXT_HIGHLIGHTER,
-            SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
-            SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
-            SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES,
-            SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
-            TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_ALIAS_HANDLING_FIX,
-            TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_MINSCORE_FIX,
-            SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT,
-            SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT,
-            SEMANTIC_KNN_FILTER_FIX,
-            TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT,
-            TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS,
-            SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
-            SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
-            SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
-            SEMANTIC_TEXT_INDEX_OPTIONS,
-            COHERE_V2_API,
-            SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS,
-            SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX,
-            SEMANTIC_TEXT_HIGHLIGHTING_FLAT
+        var testFeatures = new HashSet<>(
+            Set.of(
+                SemanticTextFieldMapper.SEMANTIC_TEXT_IN_OBJECT_FIELD_FIX,
+                SemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX,
+                SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX,
+                SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX,
+                SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX,
+                SemanticTextFieldMapper.SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS,
+                SEMANTIC_TEXT_HIGHLIGHTER,
+                SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
+                SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
+                SemanticInferenceMetadataFieldsMapper.EXPLICIT_NULL_FIXES,
+                SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED,
+                TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_ALIAS_HANDLING_FIX,
+                TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_MINSCORE_FIX,
+                SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT,
+                SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT,
+                SEMANTIC_KNN_FILTER_FIX,
+                TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE,
+                SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT,
+                SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT,
+                TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS,
+                SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
+                SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
+                SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
+                SEMANTIC_TEXT_INDEX_OPTIONS,
+                COHERE_V2_API,
+                SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS,
+                SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX,
+                SEMANTIC_TEXT_HIGHLIGHTING_FLAT
+            )
         );
+        if (RERANK_SNIPPETS.isEnabled()) {
+            testFeatures.add(TEXT_SIMILARITY_RERANKER_SNIPPETS);
+        }
+        return testFeatures;
     }
 }

+ 88 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/SnippetConfig.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.rank.textsimilarity;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.index.query.QueryBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class SnippetConfig implements Writeable {
+
+    public final Integer numSnippets;
+    private final String inferenceText;
+    private final Integer tokenSizeLimit;
+    public final QueryBuilder snippetQueryBuilder;
+
+    public static final int DEFAULT_NUM_SNIPPETS = 1;
+
+    public SnippetConfig(StreamInput in) throws IOException {
+        this.numSnippets = in.readOptionalVInt();
+        this.inferenceText = in.readString();
+        this.tokenSizeLimit = in.readOptionalVInt();
+        this.snippetQueryBuilder = in.readOptionalNamedWriteable(QueryBuilder.class);
+    }
+
+    public SnippetConfig(Integer numSnippets) {
+        this(numSnippets, null, null);
+    }
+
+    public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit) {
+        this(numSnippets, inferenceText, tokenSizeLimit, null);
+    }
+
+    public SnippetConfig(Integer numSnippets, String inferenceText, Integer tokenSizeLimit, QueryBuilder snippetQueryBuilder) {
+        this.numSnippets = numSnippets;
+        this.inferenceText = inferenceText;
+        this.tokenSizeLimit = tokenSizeLimit;
+        this.snippetQueryBuilder = snippetQueryBuilder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalVInt(numSnippets);
+        out.writeString(inferenceText);
+        out.writeOptionalVInt(tokenSizeLimit);
+        out.writeOptionalNamedWriteable(snippetQueryBuilder);
+    }
+
+    public Integer numSnippets() {
+        return numSnippets;
+    }
+
+    public String inferenceText() {
+        return inferenceText;
+    }
+
+    public Integer tokenSizeLimit() {
+        return tokenSizeLimit;
+    }
+
+    public QueryBuilder snippetQueryBuilder() {
+        return snippetQueryBuilder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        SnippetConfig that = (SnippetConfig) o;
+        return Objects.equals(numSnippets, that.numSnippets)
+            && Objects.equals(inferenceText, that.inferenceText)
+            && Objects.equals(tokenSizeLimit, that.tokenSizeLimit)
+            && Objects.equals(snippetQueryBuilder, that.snippetQueryBuilder);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(numSnippets, inferenceText, tokenSizeLimit, snippetQueryBuilder);
+    }
+}

+ 88 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankBuilder.java

@@ -12,8 +12,12 @@ import org.apache.lucene.search.Query;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.index.query.MatchQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.LicensedFeature;
 import org.elasticsearch.search.rank.RankBuilder;
@@ -23,7 +27,6 @@ import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
 import org.elasticsearch.search.rank.context.RankFeaturePhaseRankShardContext;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
-import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
@@ -35,6 +38,7 @@ import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilari
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_ID_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.INFERENCE_TEXT_FIELD;
 import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.MIN_SCORE_FIELD;
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder.SNIPPETS_FIELD;
 
 /**
  * A {@code RankBuilder} that enables ranking with text similarity model inference. Supports parameters for configuring the inference call.
@@ -43,6 +47,11 @@ public class TextSimilarityRankBuilder extends RankBuilder {
 
     public static final String NAME = "text_similarity_reranker";
 
+    /**
+     * The default token size limit of the Elastic reranker is 512.
+     */
+    private static final int DEFAULT_TOKEN_SIZE_LIMIT = 512;
+
     public static final LicensedFeature.Momentary TEXT_SIMILARITY_RERANKER_FEATURE = LicensedFeature.momentary(
         null,
         "text-similarity-reranker",
@@ -54,6 +63,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
     private final String field;
     private final Float minScore;
     private final boolean failuresAllowed;
+    private final SnippetConfig snippetConfig;
 
     public TextSimilarityRankBuilder(
         String field,
@@ -61,7 +71,8 @@ public class TextSimilarityRankBuilder extends RankBuilder {
         String inferenceText,
         int rankWindowSize,
         Float minScore,
-        boolean failuresAllowed
+        boolean failuresAllowed,
+        SnippetConfig snippetConfig
     ) {
         super(rankWindowSize);
         this.inferenceId = inferenceId;
@@ -69,6 +80,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
         this.field = field;
         this.minScore = minScore;
         this.failuresAllowed = failuresAllowed;
+        this.snippetConfig = snippetConfig;
     }
 
     public TextSimilarityRankBuilder(StreamInput in) throws IOException {
@@ -84,6 +96,11 @@ public class TextSimilarityRankBuilder extends RankBuilder {
         } else {
             this.failuresAllowed = false;
         }
+        if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            this.snippetConfig = in.readOptionalWriteable(SnippetConfig::new);
+        } else {
+            this.snippetConfig = null;
+        }
     }
 
     @Override
@@ -107,6 +124,9 @@ public class TextSimilarityRankBuilder extends RankBuilder {
             || out.getTransportVersion().onOrAfter(TransportVersions.RERANKER_FAILURES_ALLOWED)) {
             out.writeBoolean(failuresAllowed);
         }
+        if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_SNIPPETS)) {
+            out.writeOptionalWriteable(snippetConfig);
+        }
     }
 
     @Override
@@ -122,6 +142,53 @@ public class TextSimilarityRankBuilder extends RankBuilder {
         if (failuresAllowed) {
             builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), true);
         }
+        if (snippetConfig != null) {
+            builder.field(SNIPPETS_FIELD.getPreferredName(), snippetConfig);
+        }
+    }
+
+    @Override
+    public RankBuilder rewrite(QueryRewriteContext queryRewriteContext) throws IOException {
+        TextSimilarityRankBuilder rewritten = this;
+        if (snippetConfig != null) {
+            QueryBuilder snippetQueryBuilder = snippetConfig.snippetQueryBuilder();
+            if (snippetQueryBuilder == null) {
+                rewritten = new TextSimilarityRankBuilder(
+                    field,
+                    inferenceId,
+                    inferenceText,
+                    rankWindowSize(),
+                    minScore,
+                    failuresAllowed,
+                    new SnippetConfig(
+                        snippetConfig.numSnippets(),
+                        snippetConfig.inferenceText(),
+                        snippetConfig.tokenSizeLimit(),
+                        new MatchQueryBuilder(field, inferenceText)
+                    )
+                );
+            } else {
+                QueryBuilder rewrittenSnippetQueryBuilder = snippetQueryBuilder.rewrite(queryRewriteContext);
+                if (snippetQueryBuilder != rewrittenSnippetQueryBuilder) {
+                    rewritten = new TextSimilarityRankBuilder(
+                        field,
+                        inferenceId,
+                        inferenceText,
+                        rankWindowSize(),
+                        minScore,
+                        failuresAllowed,
+                        new SnippetConfig(
+                            snippetConfig.numSnippets(),
+                            snippetConfig.inferenceText(),
+                            snippetConfig.tokenSizeLimit(),
+                            rewrittenSnippetQueryBuilder
+                        )
+                    );
+                }
+            }
+        }
+
+        return rewritten;
     }
 
     @Override
@@ -168,7 +235,7 @@ public class TextSimilarityRankBuilder extends RankBuilder {
 
     @Override
     public RankFeaturePhaseRankShardContext buildRankFeaturePhaseShardContext() {
-        return new RerankingRankFeaturePhaseRankShardContext(field);
+        return new TextSimilarityRerankingRankFeaturePhaseRankShardContext(field, snippetConfig);
     }
 
     @Override
@@ -181,10 +248,19 @@ public class TextSimilarityRankBuilder extends RankBuilder {
             inferenceId,
             inferenceText,
             minScore,
-            failuresAllowed
+            failuresAllowed,
+            snippetConfig != null ? new SnippetConfig(snippetConfig.numSnippets, inferenceText, tokenSizeLimit(inferenceId)) : null
         );
     }
 
+    /**
+     * @return The token size limit to apply to this rerank context.
+     * TODO: This should be pulled from the inference endpoint when available, not hardcoded.
+     */
+    public static Integer tokenSizeLimit(String inferenceId) {
+        return DEFAULT_TOKEN_SIZE_LIMIT;
+    }
+
     public String field() {
         return field;
     }
@@ -212,11 +288,17 @@ public class TextSimilarityRankBuilder extends RankBuilder {
             && Objects.equals(inferenceText, that.inferenceText)
             && Objects.equals(field, that.field)
             && Objects.equals(minScore, that.minScore)
-            && failuresAllowed == that.failuresAllowed;
+            && failuresAllowed == that.failuresAllowed
+            && Objects.equals(snippetConfig, that.snippetConfig);
     }
 
     @Override
     protected int doHashCode() {
-        return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed);
+        return Objects.hash(inferenceId, inferenceText, field, minScore, failuresAllowed, snippetConfig);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
     }
 }

+ 67 - 18
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContext.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.TaskType;
@@ -39,6 +40,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
     protected final String inferenceId;
     protected final String inferenceText;
     protected final Float minScore;
+    protected final SnippetConfig snippetConfig;
 
     public TextSimilarityRankFeaturePhaseRankCoordinatorContext(
         int size,
@@ -48,39 +50,56 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
         String inferenceId,
         String inferenceText,
         Float minScore,
-        boolean failuresAllowed
+        boolean failuresAllowed,
+        @Nullable SnippetConfig snippetConfig
     ) {
         super(size, from, rankWindowSize, failuresAllowed);
         this.client = client;
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.minScore = minScore;
+        this.snippetConfig = snippetConfig;
     }
 
     @Override
     protected void computeScores(RankFeatureDoc[] featureDocs, ActionListener<float[]> scoreListener) {
+
         // Wrap the provided rankListener to an ActionListener that would handle the response from the inference service
         // and then pass the results
         final ActionListener<InferenceAction.Response> inferenceListener = scoreListener.delegateFailureAndWrap((l, r) -> {
             InferenceServiceResults results = r.getResults();
             assert results instanceof RankedDocsResults;
 
-            // Ensure we get exactly as many scores as the number of docs we passed, otherwise we may return incorrect results
+            // If we have an empty list of ranked docs, simply return the original scores
             List<RankedDocsResults.RankedDoc> rankedDocs = ((RankedDocsResults) results).getRankedDocs();
-
-            if (rankedDocs.size() != featureDocs.length) {
-                l.onFailure(
-                    new IllegalStateException(
-                        "Reranker input document count and returned score count mismatch: ["
-                            + featureDocs.length
-                            + "] vs ["
-                            + rankedDocs.size()
-                            + "]"
-                    )
-                );
+            if (rankedDocs.isEmpty()) {
+                float[] originalScores = new float[featureDocs.length];
+                for (int i = 0; i < featureDocs.length; i++) {
+                    originalScores[i] = featureDocs[i].score;
+                }
+                l.onResponse(originalScores);
             } else {
-                float[] scores = extractScoresFromRankedDocs(rankedDocs);
-                l.onResponse(scores);
+                final float[] scores;
+                if (this.snippetConfig != null) {
+                    scores = extractScoresFromRankedSnippets(rankedDocs, featureDocs);
+                } else {
+                    scores = extractScoresFromRankedDocs(rankedDocs);
+                }
+
+                // Ensure we get exactly as many final scores as the number of docs we passed, otherwise we may return incorrect results
+                if (scores.length != featureDocs.length) {
+                    l.onFailure(
+                        new IllegalStateException(
+                            "Reranker input document count and returned score count mismatch: ["
+                                + featureDocs.length
+                                + "] vs ["
+                                + scores.length
+                                + "]"
+                        )
+                    );
+                } else {
+                    l.onResponse(scores);
+                }
             }
         });
 
@@ -118,8 +137,11 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
             if (featureDocs.length == 0) {
                 inferenceListener.onResponse(new InferenceAction.Response(new RankedDocsResults(List.of())));
             } else {
-                List<String> featureData = Arrays.stream(featureDocs).map(x -> x.featureData).toList();
-                InferenceAction.Request inferenceRequest = generateRequest(featureData);
+                List<String> inferenceInputs = Arrays.stream(featureDocs)
+                    .filter(featureDoc -> featureDoc.featureData != null)
+                    .flatMap(featureDoc -> featureDoc.featureData.stream())
+                    .toList();
+                InferenceAction.Request inferenceRequest = generateRequest(inferenceInputs);
                 try {
                     executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceRequest, inferenceListener);
                 } finally {
@@ -170,7 +192,7 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
         );
     }
 
-    private float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
+    float[] extractScoresFromRankedDocs(List<RankedDocsResults.RankedDoc> rankedDocs) {
         float[] scores = new float[rankedDocs.size()];
         for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
             scores[rankedDoc.index()] = rankedDoc.relevanceScore();
@@ -178,6 +200,33 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContext extends RankFe
         return scores;
     }
 
+    float[] extractScoresFromRankedSnippets(List<RankedDocsResults.RankedDoc> rankedDocs, RankFeatureDoc[] featureDocs) {
+        float[] scores = new float[featureDocs.length];
+        boolean[] hasScore = new boolean[featureDocs.length];
+
+        // We need to correlate the index/doc values of each RankedDoc in correlation with its associated RankFeatureDoc.
+        int[] rankedDocToFeatureDoc = Arrays.stream(featureDocs)
+            .flatMapToInt(
+                doc -> java.util.stream.IntStream.generate(() -> Arrays.asList(featureDocs).indexOf(doc)).limit(doc.featureData.size())
+            )
+            .limit(rankedDocs.size())
+            .toArray();
+
+        for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
+            int docId = rankedDocToFeatureDoc[rankedDoc.index()];
+            float score = rankedDoc.relevanceScore();
+            scores[docId] = hasScore[docId] == false ? score : Math.max(scores[docId], score);
+            hasScore[docId] = true;
+        }
+
+        float[] result = new float[featureDocs.length];
+        for (int i = 0; i < featureDocs.length; i++) {
+            result[i] = hasScore[i] ? scores[i] : 0f;
+        }
+
+        return result;
+    }
+
     private static float normalizeScore(float score) {
         // As some models might produce negative scores, we want to ensure that all scores will be positive
         // so we will make use of the following normalization formula:

+ 54 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.rank.textsimilarity;
 
 import org.apache.lucene.search.ScoreDoc;
+import org.elasticsearch.common.util.FeatureFlag;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.license.LicenseUtils;
@@ -41,12 +42,16 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         "text_similarity_reranker_alias_handling_fix"
     );
     public static final NodeFeature TEXT_SIMILARITY_RERANKER_MINSCORE_FIX = new NodeFeature("text_similarity_reranker_minscore_fix");
+    public static final NodeFeature TEXT_SIMILARITY_RERANKER_SNIPPETS = new NodeFeature("text_similarity_reranker_snippets");
+    public static final FeatureFlag RERANK_SNIPPETS = new FeatureFlag("text_similarity_reranker_snippets");
 
     public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
     public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
     public static final ParseField INFERENCE_TEXT_FIELD = new ParseField("inference_text");
     public static final ParseField FIELD_FIELD = new ParseField("field");
     public static final ParseField FAILURES_ALLOWED_FIELD = new ParseField("allow_rerank_failures");
+    public static final ParseField SNIPPETS_FIELD = new ParseField("snippets");
+    public static final ParseField NUM_SNIPPETS_FIELD = new ParseField("num_snippets");
 
     public static final ConstructingObjectParser<TextSimilarityRankRetrieverBuilder, RetrieverParserContext> PARSER =
         new ConstructingObjectParser<>(TextSimilarityRankBuilder.NAME, args -> {
@@ -56,6 +61,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             String field = (String) args[3];
             int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
             boolean failuresAllowed = args[5] != null && (Boolean) args[5];
+            SnippetConfig snippets = (SnippetConfig) args[6];
 
             return new TextSimilarityRankRetrieverBuilder(
                 retrieverBuilder,
@@ -63,10 +69,20 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
                 inferenceText,
                 field,
                 rankWindowSize,
-                failuresAllowed
+                failuresAllowed,
+                snippets
             );
         });
 
+    private static final ConstructingObjectParser<SnippetConfig, RetrieverParserContext> SNIPPETS_PARSER = new ConstructingObjectParser<>(
+        SNIPPETS_FIELD.getPreferredName(),
+        true,
+        args -> {
+            Integer numSnippets = (Integer) args[0];
+            return new SnippetConfig(numSnippets);
+        }
+    );
+
     static {
         PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
             RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
@@ -78,6 +94,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         PARSER.declareString(constructorArg(), FIELD_FIELD);
         PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
         PARSER.declareBoolean(optionalConstructorArg(), FAILURES_ALLOWED_FIELD);
+        PARSER.declareObject(optionalConstructorArg(), SNIPPETS_PARSER, SNIPPETS_FIELD);
+        if (RERANK_SNIPPETS.isEnabled()) {
+            SNIPPETS_PARSER.declareInt(optionalConstructorArg(), NUM_SNIPPETS_FIELD);
+        }
 
         RetrieverBuilder.declareBaseParserFields(PARSER);
     }
@@ -97,6 +117,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     private final String inferenceText;
     private final String field;
     private final boolean failuresAllowed;
+    private final SnippetConfig snippets;
 
     public TextSimilarityRankRetrieverBuilder(
         RetrieverBuilder retrieverBuilder,
@@ -104,13 +125,15 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         String inferenceText,
         String field,
         int rankWindowSize,
-        boolean failuresAllowed
+        boolean failuresAllowed,
+        SnippetConfig snippets
     ) {
         super(List.of(RetrieverSource.from(retrieverBuilder)), rankWindowSize);
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.field = field;
         this.failuresAllowed = failuresAllowed;
+        this.snippets = snippets;
     }
 
     public TextSimilarityRankRetrieverBuilder(
@@ -122,12 +145,16 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         Float minScore,
         boolean failuresAllowed,
         String retrieverName,
-        List<QueryBuilder> preFilterQueryBuilders
+        List<QueryBuilder> preFilterQueryBuilders,
+        SnippetConfig snippets
     ) {
         super(retrieverSource, rankWindowSize);
         if (retrieverSource.size() != 1) {
             throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever");
         }
+        if (snippets != null && snippets.numSnippets() != null && snippets.numSnippets() < 1) {
+            throw new IllegalArgumentException("num_snippets must be greater than 0, was: " + snippets.numSnippets());
+        }
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.field = field;
@@ -135,6 +162,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         this.failuresAllowed = failuresAllowed;
         this.retrieverName = retrieverName;
         this.preFilterQueryBuilders = preFilterQueryBuilders;
+        this.snippets = snippets;
     }
 
     @Override
@@ -151,7 +179,8 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             minScore,
             failuresAllowed,
             retrieverName,
-            newPreFilterQueryBuilders
+            newPreFilterQueryBuilders,
+            snippets
         );
     }
 
@@ -179,7 +208,17 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     @Override
     protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
         sourceBuilder.rankBuilder(
-            new TextSimilarityRankBuilder(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed)
+            new TextSimilarityRankBuilder(
+                field,
+                inferenceId,
+                inferenceText,
+                rankWindowSize,
+                minScore,
+                failuresAllowed,
+                snippets != null
+                    ? new SnippetConfig(snippets.numSnippets, inferenceText, TextSimilarityRankBuilder.tokenSizeLimit(inferenceId))
+                    : null
+            )
         );
         return sourceBuilder;
     }
@@ -207,6 +246,13 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
         if (failuresAllowed) {
             builder.field(FAILURES_ALLOWED_FIELD.getPreferredName(), failuresAllowed);
         }
+        if (snippets != null) {
+            builder.startObject(SNIPPETS_FIELD.getPreferredName());
+            if (snippets.numSnippets() != null) {
+                builder.field(NUM_SNIPPETS_FIELD.getPreferredName(), snippets.numSnippets());
+            }
+            builder.endObject();
+        }
     }
 
     @Override
@@ -218,11 +264,12 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             && Objects.equals(field, that.field)
             && rankWindowSize == that.rankWindowSize
             && Objects.equals(minScore, that.minScore)
-            && failuresAllowed == that.failuresAllowed;
+            && failuresAllowed == that.failuresAllowed
+            && Objects.equals(snippets, that.snippets);
     }
 
     @Override
     public int doHashCode() {
-        return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed);
+        return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore, failuresAllowed, snippets);
     }
 }

+ 97 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRerankingRankFeaturePhaseRankShardContext.java

@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.rank.textsimilarity;
+
+import org.elasticsearch.common.document.DocumentField;
+import org.elasticsearch.common.logging.HeaderWarning;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
+import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
+import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext;
+import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.search.rank.RankShardResult;
+import org.elasticsearch.search.rank.feature.RankFeatureDoc;
+import org.elasticsearch.search.rank.feature.RankFeatureShardResult;
+import org.elasticsearch.search.rank.rerank.RerankingRankFeaturePhaseRankShardContext;
+import org.elasticsearch.xcontent.Text;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.rank.textsimilarity.SnippetConfig.DEFAULT_NUM_SNIPPETS;
+
+public class TextSimilarityRerankingRankFeaturePhaseRankShardContext extends RerankingRankFeaturePhaseRankShardContext {
+
+    private final SnippetConfig snippetRankInput;
+
+    // Rough approximation of token size vs. characters in highlight fragments.
+    // TODO: highlighter should be able to set fragment size by token not length
+    private static final int TOKEN_SIZE_LIMIT_MULTIPLIER = 5;
+
+    public TextSimilarityRerankingRankFeaturePhaseRankShardContext(String field, @Nullable SnippetConfig snippetRankInput) {
+        super(field);
+        this.snippetRankInput = snippetRankInput;
+    }
+
+    @Override
+    public RankShardResult doBuildRankFeatureShardResult(SearchHits hits, int shardId) {
+        RankFeatureDoc[] rankFeatureDocs = new RankFeatureDoc[hits.getHits().length];
+        for (int i = 0; i < hits.getHits().length; i++) {
+            rankFeatureDocs[i] = new RankFeatureDoc(hits.getHits()[i].docId(), hits.getHits()[i].getScore(), shardId);
+            SearchHit hit = hits.getHits()[i];
+            DocumentField docField = hit.field(field);
+            if (snippetRankInput == null && docField != null) {
+                rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
+            } else {
+                Map<String, HighlightField> highlightFields = hit.getHighlightFields();
+                if (highlightFields != null && highlightFields.containsKey(field) && highlightFields.get(field).fragments().length > 0) {
+                    List<String> snippets = Arrays.stream(highlightFields.get(field).fragments()).map(Text::string).toList();
+                    rankFeatureDocs[i].featureData(snippets);
+                } else if (docField != null) {
+                    // If we did not get highlighting results, backfill with the doc field value
+                    // but pass in a warning because we are not reranking on snippets only
+                    rankFeatureDocs[i].featureData(List.of(docField.getValue().toString()));
+                    HeaderWarning.addWarning(
+                        "Reranking on snippets requested, but no snippets were found for field [" + field + "]. Using field value instead."
+                    );
+                }
+            }
+        }
+        return new RankFeatureShardResult(rankFeatureDocs);
+    }
+
+    @Override
+    public void prepareForFetch(SearchContext context) {
+        if (snippetRankInput != null) {
+            try {
+                HighlightBuilder highlightBuilder = new HighlightBuilder();
+                highlightBuilder.highlightQuery(snippetRankInput.snippetQueryBuilder());
+                // Stripping pre/post tags as they're not useful for snippet creation
+                highlightBuilder.field(field).preTags("").postTags("");
+                // Return highest scoring fragments
+                highlightBuilder.order(HighlightBuilder.Order.SCORE);
+                int numSnippets = snippetRankInput.numSnippets() != null ? snippetRankInput.numSnippets() : DEFAULT_NUM_SNIPPETS;
+                highlightBuilder.numOfFragments(numSnippets);
+                // Rely on the model to determine the fragment size
+                int tokenSizeLimit = snippetRankInput.tokenSizeLimit();
+                int fragmentSize = tokenSizeLimit * TOKEN_SIZE_LIMIT_MULTIPLIER;
+                highlightBuilder.fragmentSize(fragmentSize);
+                highlightBuilder.noMatchSize(fragmentSize);
+                SearchHighlightContext searchHighlightContext = highlightBuilder.build(context.getSearchExecutionContext());
+                context.highlight(searchHighlightContext);
+            } catch (IOException e) {
+                throw new RuntimeException("Failed to generate snippet request", e);
+            }
+        }
+    }
+
+}

+ 73 - 4
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankFeaturePhaseRankCoordinatorContextTests.java

@@ -12,6 +12,9 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.search.rank.feature.RankFeatureDoc;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
+
+import java.util.List;
 
 import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener;
 import static org.mockito.ArgumentMatchers.any;
@@ -32,16 +35,29 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
         "my-inference-id",
         "some query",
         0.0f,
-        false
+        false,
+        null
+    );
+
+    TextSimilarityRankFeaturePhaseRankCoordinatorContext withSnippets = new TextSimilarityRankFeaturePhaseRankCoordinatorContext(
+        10,
+        0,
+        100,
+        mockClient,
+        "my-inference-id",
+        "some query",
+        0.0f,
+        false,
+        new SnippetConfig(2, "some query", 10)
     );
 
     public void testComputeScores() {
         RankFeatureDoc featureDoc1 = new RankFeatureDoc(0, 1.0f, 0);
-        featureDoc1.featureData("text 1");
+        featureDoc1.featureData(List.of("text 1"));
         RankFeatureDoc featureDoc2 = new RankFeatureDoc(1, 3.0f, 1);
-        featureDoc2.featureData("text 2");
+        featureDoc2.featureData(List.of("text 2"));
         RankFeatureDoc featureDoc3 = new RankFeatureDoc(2, 2.0f, 0);
-        featureDoc3.featureData("text 3");
+        featureDoc3.featureData(List.of("text 3"));
         RankFeatureDoc[] featureDocs = new RankFeatureDoc[] { featureDoc1, featureDoc2, featureDoc3 };
 
         subject.computeScores(featureDocs, assertNoFailureListener(f -> assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, f, 0.0f)));
@@ -61,4 +77,57 @@ public class TextSimilarityRankFeaturePhaseRankCoordinatorContextTests extends E
         );
     }
 
+    public void testExtractScoresFromRankedDocs() {
+        List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
+            new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"),
+            new RankedDocsResults.RankedDoc(1, 3.0f, "text 2"),
+            new RankedDocsResults.RankedDoc(2, 2.0f, "text 3")
+        );
+        float[] scores = subject.extractScoresFromRankedDocs(rankedDocs);
+        assertArrayEquals(new float[] { 1.0f, 3.0f, 2.0f }, scores, 0.0f);
+    }
+
+    public void testExtractScoresFromSingleSnippets() {
+
+        List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
+            new RankedDocsResults.RankedDoc(0, 1.0f, "text 1"),
+            new RankedDocsResults.RankedDoc(1, 2.5f, "text 2"),
+            new RankedDocsResults.RankedDoc(2, 1.5f, "text 3")
+        );
+        RankFeatureDoc[] featureDocs = new RankFeatureDoc[] {
+            createRankFeatureDoc(0, 1.0f, 0, List.of("text 1")),
+            createRankFeatureDoc(1, 3.0f, 1, List.of("text 2")),
+            createRankFeatureDoc(2, 2.0f, 0, List.of("text 3")) };
+
+        float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs);
+        // Returned cores are from the snippet, not the whole text
+        assertArrayEquals(new float[] { 1.0f, 2.5f, 1.5f }, scores, 0.0f);
+    }
+
+    public void testExtractScoresFromMultipleSnippets() {
+
+        List<RankedDocsResults.RankedDoc> rankedDocs = List.of(
+            new RankedDocsResults.RankedDoc(0, 1.0f, "this is text 1"),
+            new RankedDocsResults.RankedDoc(1, 2.5f, "some more text"),
+            new RankedDocsResults.RankedDoc(2, 1.5f, "yet more text"),
+            new RankedDocsResults.RankedDoc(3, 3.0f, "this is text 2"),
+            new RankedDocsResults.RankedDoc(4, 2.0f, "this is text 3"),
+            new RankedDocsResults.RankedDoc(5, 1.5f, "oh look, more text")
+        );
+        RankFeatureDoc[] featureDocs = new RankFeatureDoc[] {
+            createRankFeatureDoc(0, 1.0f, 0, List.of("this is text 1", "some more text")),
+            createRankFeatureDoc(1, 3.0f, 1, List.of("yet more text", "this is text 2")),
+            createRankFeatureDoc(2, 2.0f, 0, List.of("this is text 3", "oh look, more text")) };
+
+        float[] scores = withSnippets.extractScoresFromRankedSnippets(rankedDocs, featureDocs);
+        // Returned scores are from the best-ranking snippet, not the whole text
+        assertArrayEquals(new float[] { 2.5f, 3.0f, 2.0f }, scores, 0.0f);
+    }
+
+    private RankFeatureDoc createRankFeatureDoc(int doc, float score, int shardIndex, List<String> featureData) {
+        RankFeatureDoc featureDoc = new RankFeatureDoc(doc, score, shardIndex);
+        featureDoc.featureData(featureData);
+        return featureDoc;
+    }
+
 }

+ 3 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankMultiNodeTests.java

@@ -32,7 +32,7 @@ public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT {
 
     @Override
     protected RankBuilder getRankBuilder(int rankWindowSize, String rankFeatureField) {
-        return new TextSimilarityRankBuilder(rankFeatureField, inferenceId, inferenceText, rankWindowSize, minScore, false);
+        return new TextSimilarityRankBuilder(rankFeatureField, inferenceId, inferenceText, rankWindowSize, minScore, false, null);
     }
 
     @Override
@@ -53,7 +53,8 @@ public class TextSimilarityRankMultiNodeTests extends AbstractRerankerIT {
             inferenceText,
             minScore,
             failuresAllowed,
-            type.name()
+            type.name(),
+            null
         );
     }
 

+ 2 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java

@@ -58,7 +58,8 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
             randomAlphaOfLength(20),
             randomAlphaOfLength(50),
             randomIntBetween(100, 10000),
-            randomBoolean()
+            randomBoolean(),
+            null
         );
     }
 

+ 2 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverTelemetryTests.java

@@ -139,7 +139,8 @@ public class TextSimilarityRankRetrieverTelemetryTests extends ESIntegTestCase {
                         "some_inference_text",
                         "some_field",
                         10,
-                        false
+                        false,
+                        null
                     )
                 )
             );

+ 10 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java

@@ -57,7 +57,7 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
             Float minScore,
             int topN
         ) {
-            super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore, false);
+            super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore, false, null);
         }
     }
 
@@ -76,7 +76,7 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
             Float minScore,
             int inferenceResultCount
         ) {
-            super(field, inferenceId, inferenceText, rankWindowSize, minScore, false);
+            super(field, inferenceId, inferenceText, rankWindowSize, minScore, false, null);
             this.inferenceResultCount = inferenceResultCount;
         }
 
@@ -90,7 +90,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
                 inferenceId,
                 inferenceText,
                 minScore,
-                failuresAllowed()
+                failuresAllowed(),
+                null
             ) {
                 @Override
                 protected InferenceAction.Request generateRequest(List<String> docFeatures) {
@@ -136,7 +137,7 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
         ElasticsearchAssertions.assertNoFailuresAndResponse(
             // Execute search with text similarity reranking
             client.prepareSearch()
-                .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f, false))
+                .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f, false, null))
                 .setQuery(QueryBuilders.matchAllQuery()),
             response -> {
                 // Verify order, rank and score of results
@@ -159,7 +160,7 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
         ElasticsearchAssertions.assertNoFailuresAndResponse(
             // Execute search with text similarity reranking
             client.prepareSearch()
-                .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false))
+                .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false, null))
                 .setQuery(QueryBuilders.matchAllQuery()),
             response -> {
                 // Verify order, rank and score of results
@@ -183,7 +184,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
                         "my query",
                         0.7f,
                         false,
-                        AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name()
+                        AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name(),
+                        null
                     )
                 )
                 .setQuery(QueryBuilders.matchAllQuery()),
@@ -204,7 +206,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
                         "my query",
                         null,
                         true,
-                        AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name()
+                        AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name(),
+                        null
                     )
                 )
                 .setQuery(

+ 5 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java

@@ -176,9 +176,10 @@ public class TextSimilarityTestPlugin extends Plugin implements ActionPlugin {
             String inferenceText,
             Float minScore,
             boolean failuresAllowed,
-            String throwingType
+            String throwingType,
+            SnippetConfig snippetConfig
         ) {
-            super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed);
+            super(field, inferenceId, inferenceText, rankWindowSize, minScore, failuresAllowed, snippetConfig);
             this.throwingRankBuilderType = AbstractRerankerIT.ThrowingRankBuilderType.valueOf(throwingType);
         }
 
@@ -218,7 +219,8 @@ public class TextSimilarityTestPlugin extends Plugin implements ActionPlugin {
                     inferenceId,
                     inferenceText,
                     minScore,
-                    failuresAllowed()
+                    failuresAllowed(),
+                    null
                 ) {
                     @Override
                     protected InferenceAction.Request generateRequest(List<String> docFeatures) {

+ 234 - 4
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

@@ -20,6 +20,21 @@ setup:
             }
           }
 
+  - do:
+      inference.put:
+        task_type: sparse_embedding
+        inference_id: sparse-inference-id
+        body: >
+          {
+            "service": "test_service",
+            "service_settings": {
+              "model": "my_model",
+              "api_key": "abc64"
+            },
+            "task_settings": {
+            }
+          }
+
   - do:
       indices.create:
         index: test-index
@@ -28,12 +43,20 @@ setup:
             properties:
               text:
                 type: text
+                copy_to: semantic_text_field
               topic:
                 type: keyword
               subtopic:
                 type: keyword
               inference_text_field:
                 type: text
+              semantic_text_field:
+                type: semantic_text
+                inference_id: sparse-inference-id
+                chunking_settings:
+                  strategy: word
+                  max_chunk_size: 10
+                  overlap: 1
 
   - do:
       index:
@@ -298,8 +321,10 @@ setup:
   - match: { hits.hits.0._id: "doc_2" }
   - match: { hits.hits.1._id: "doc_1" }
 
-  - match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[inference_text_field\\].*/" }
-  - match: {hits.hits.0._explanation.details.0.details.0.description: "/subtopic.*astronomy.*/" }
+  - match: { hits.hits.0._explanation.description: "sum of:" }
+  - match: { hits.hits.0._explanation.details.0.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[inference_text_field\\].*/" }
+  - match: { hits.hits.0._explanation.details.0.details.0.details.0.description: "/subtopic.*astronomy.*/" }
+  - match: { hits.hits.0._explanation.details.1.description: "/match.on.required.clause,.product.of:*/" }
 
 ---
 "text similarity reranker properly handles aliases":
@@ -448,7 +473,7 @@ setup:
               retriever:
                 standard:
                   query:
-                    match_all: {}
+                    match_all: { }
               rank_window_size: 10
               inference_id: my-rerank-model
               inference_text: "How often does the moon hide the sun?"
@@ -477,7 +502,7 @@ setup:
               retriever:
                 standard:
                   query:
-                    match_all: {}
+                    match_all: { }
               rank_window_size: 10
               inference_id: my-rerank-model
               inference_text: "How often does the moon hide the sun?"
@@ -487,3 +512,208 @@ setup:
 
   - match: { hits.total.value: 0 }
   - length: { hits.hits: 0 }
+
+
+---
+"Text similarity reranker specifying number of snippets must be > 0":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: snippets introduced in 9.2.0
+
+  - do:
+      catch: /num_snippets must be greater than 0/
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match_all: { }
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "How often does the moon hide the sun?"
+              field: inference_text_field
+              snippets:
+                num_snippets: 0
+          size: 10
+
+  - match: { status: 400 }
+
+---
+"Reranking based on snippets":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: snippets introduced in 9.2.0
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match:
+                      topic:
+                        query: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "How often does the moon hide the sun?"
+              field: text
+              snippets:
+                num_snippets: 2
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+
+---
+"Reranking based on snippets using defaults":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: snippets introduced in 9.2.0
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    term:
+                      topic: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "How often does the moon hide the sun?"
+              field: text
+              snippets: { }
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+
+---
+"Reranking based on snippets on a semantic_text field":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: snippets introduced in 9.2.0
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "semantic_text_field", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match:
+                      topic:
+                        query: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "how often does the moon hide the sun?"
+              field: semantic_text_field
+              snippets:
+                num_snippets: 2
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+
+---
+"Reranking based on snippets on a semantic_text field using defaults":
+
+  - requires:
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: snippets introduced in 9.2.0
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "semantic_text_field", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    match:
+                      topic:
+                        query: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "how often does the moon hide the sun?"
+              field: semantic_text_field
+              snippets: { }
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_2" }
+
+---
+"Reranking based on snippets when highlighter doesn't return results":
+
+  - requires:
+      test_runner_features: allowed_warnings
+      cluster_features: "text_similarity_reranker_snippets"
+      reason: snippets introduced in 9.2.0
+
+  - do:
+      allowed_warnings:
+        - "Reranking on snippets requested, but no snippets were found for field [inference_text_field]. Using field value instead."
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever:
+            text_similarity_reranker:
+              retriever:
+                standard:
+                  query:
+                    term:
+                      topic: "science"
+              rank_window_size: 10
+              inference_id: my-rerank-model
+              inference_text: "How often does the moon hide the sun?"
+              field: inference_text_field
+              snippets:
+                num_snippets: 2
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_2" }
+  - match: { hits.hits.1._id: "doc_1" }

+ 1 - 0
x-pack/qa/core-rest-tests-with-security/src/yamlRestTest/java/org/elasticsearch/xpack/security/CoreWithSecurityClientYamlTestSuiteIT.java

@@ -54,6 +54,7 @@ public class CoreWithSecurityClientYamlTestSuiteIT extends ESClientYamlSuiteTest
         .feature(FeatureFlag.USE_LUCENE101_POSTINGS_FORMAT)
         .feature(FeatureFlag.IVF_FORMAT)
         .feature(FeatureFlag.SYNTHETIC_VECTORS)
+        .feature(FeatureFlag.RERANK_SNIPPETS)
         .build();
 
     public CoreWithSecurityClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {