Browse Source

Backporting text_similarity_reranker retriever rework to be evaluated during rewrite phase to 8.x (#114282)

* backporting text_similarity_reranker rework to 8.x

* fixing comp

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Panagiotis Bailis 1 year ago
parent
commit
e745c92ebc
19 changed files with 760 additions and 196 deletions
  1. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  2. 9 0
      server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
  3. 9 2
      server/src/main/java/org/elasticsearch/search/rank/RankDoc.java
  4. 2 2
      server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java
  5. 5 1
      server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java
  6. 57 0
      server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java
  7. 9 7
      server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java
  8. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java
  9. 3 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
  10. 103 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java
  11. 73 54
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
  12. 88 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java
  13. 2 119
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java
  14. 37 3
      x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml
  15. 4 0
      x-pack/plugin/rank-rrf/build.gradle
  16. 6 0
      x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java
  17. 14 7
      x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java
  18. 2 0
      x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java
  19. 334 0
      x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -236,6 +236,7 @@ public class TransportVersions {
     public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0);
     public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0);
     public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0);
+    public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 9 - 0
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

@@ -36,6 +36,7 @@ import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchService;
+import org.elasticsearch.search.SearchSortValues;
 import org.elasticsearch.search.aggregations.AggregationReduceContext;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.InternalAggregations;
@@ -51,6 +52,7 @@ import org.elasticsearch.search.profile.SearchProfileResultsBuilder;
 import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.rank.context.QueryPhaseRankCoordinatorContext;
+import org.elasticsearch.search.sort.ShardDocSortField;
 import org.elasticsearch.search.suggest.Suggest;
 import org.elasticsearch.search.suggest.Suggest.Suggestion;
 import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
@@ -464,6 +466,13 @@ public final class SearchPhaseController {
                     assert shardDoc instanceof RankDoc;
                     searchHit.setRank(((RankDoc) shardDoc).rank);
                     searchHit.score(shardDoc.score);
+                    long shardAndDoc = ShardDocSortField.encodeShardAndDoc(shardDoc.shardIndex, shardDoc.doc);
+                    searchHit.sortValues(
+                        new SearchSortValues(
+                            new Object[] { shardDoc.score, shardAndDoc },
+                            new DocValueFormat[] { DocValueFormat.RAW, DocValueFormat.RAW }
+                        )
+                    );
                 } else if (sortedTopDocs.isSortedByField) {
                     FieldDoc fieldDoc = (FieldDoc) shardDoc;
                     searchHit.sortValues(fieldDoc.fields, reducedQueryPhase.sortValueFormats);

+ 9 - 2
server/src/main/java/org/elasticsearch/search/rank/RankDoc.java

@@ -11,9 +11,11 @@ package org.elasticsearch.search.rank;
 
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.ScoreDoc;
-import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
 import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.XContentBuilder;
 
@@ -24,7 +26,7 @@ import java.util.Objects;
  * {@code RankDoc} is the base class for all ranked results.
  * Subclasses should extend this with additional information required for their global ranking method.
  */
-public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragment, Comparable<RankDoc> {
+public class RankDoc extends ScoreDoc implements VersionedNamedWriteable, ToXContentFragment, Comparable<RankDoc> {
 
     public static final String NAME = "rank_doc";
 
@@ -40,6 +42,11 @@ public class RankDoc extends ScoreDoc implements NamedWriteable, ToXContentFragm
         return NAME;
     }
 
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.RANK_DOCS_RETRIEVER;
+    }
+
     @Override
     public final int compareTo(RankDoc other) {
         if (score != other.score) {

+ 2 - 2
server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

@@ -160,7 +160,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
 
     @Override
     public final QueryBuilder topDocsQuery() {
-        throw new IllegalStateException(getName() + " cannot be nested");
+        throw new IllegalStateException("Should not be called, missing a rewrite?");
     }
 
     @Override
@@ -208,7 +208,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
         return Objects.hash(innerRetrievers);
     }
 
-    private SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
+    protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
         var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
             .trackTotalHits(false)
             .storedFields(new StoredFieldsContext(false))

+ 5 - 1
server/src/main/java/org/elasticsearch/search/sort/ShardDocSortField.java

@@ -64,7 +64,7 @@ public class ShardDocSortField extends SortField {
 
             @Override
             public Long value(int slot) {
-                return (((long) shardRequestIndex) << 32) | (delegate.value(slot) & 0xFFFFFFFFL);
+                return encodeShardAndDoc(shardRequestIndex, delegate.value(slot));
             }
 
             @Override
@@ -87,4 +87,8 @@ public class ShardDocSortField extends SortField {
     public static int decodeShardRequestIndex(long value) {
         return (int) (value >> 32);
     }
+
+    public static long encodeShardAndDoc(int shardIndex, int doc) {
+        return (((long) shardIndex) << 32) | (doc & 0xFFFFFFFFL);
+    }
 }

+ 57 - 0
server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java

@@ -0,0 +1,57 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.rank;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.hamcrest.Matchers.equalTo;
+
+public abstract class AbstractRankDocWireSerializingTestCase<T extends RankDoc> extends AbstractWireSerializingTestCase<T> {
+
+    protected abstract T createTestRankDoc();
+
+    @Override
+    protected NamedWriteableRegistry writableRegistry() {
+        SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList());
+        List<NamedWriteableRegistry.Entry> entries = searchModule.getNamedWriteables();
+        entries.addAll(getAdditionalNamedWriteables());
+        return new NamedWriteableRegistry(entries);
+    }
+
+    protected abstract List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables();
+
+    @Override
+    protected T createTestInstance() {
+        return createTestRankDoc();
+    }
+
+    @SuppressWarnings({ "unchecked", "rawtypes" })
+    public void testRankDocSerialization() throws IOException {
+        int totalDocs = randomIntBetween(10, 100);
+        Set<T> docs = new HashSet<>();
+        for (int i = 0; i < totalDocs; i++) {
+            docs.add(createTestRankDoc());
+        }
+        RankDocsQueryBuilder rankDocsQueryBuilder = new RankDocsQueryBuilder(docs.toArray((T[]) new RankDoc[0]), null, randomBoolean());
+        RankDocsQueryBuilder copy = (RankDocsQueryBuilder) copyNamedWriteable(rankDocsQueryBuilder, writableRegistry(), QueryBuilder.class);
+        assertThat(rankDocsQueryBuilder, equalTo(copy));
+    }
+}

+ 9 - 7
server/src/test/java/org/elasticsearch/search/rank/RankDocTests.java

@@ -9,27 +9,29 @@
 
 package org.elasticsearch.search.rank;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
 
-public class RankDocTests extends AbstractWireSerializingTestCase<RankDoc> {
+public class RankDocTests extends AbstractRankDocWireSerializingTestCase<RankDoc> {
 
-    static RankDoc createTestRankDoc() {
+    protected RankDoc createTestRankDoc() {
         RankDoc rankDoc = new RankDoc(randomNonNegativeInt(), randomFloat(), randomIntBetween(0, 1));
         rankDoc.rank = randomNonNegativeInt();
         return rankDoc;
     }
 
     @Override
-    protected Writeable.Reader<RankDoc> instanceReader() {
-        return RankDoc::new;
+    protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
+        return Collections.emptyList();
     }
 
     @Override
-    protected RankDoc createTestInstance() {
-        return createTestRankDoc();
+    protected Writeable.Reader<RankDoc> instanceReader() {
+        return RankDoc::new;
     }
 
     @Override

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

@@ -25,7 +25,8 @@ public class InferenceFeatures implements FeatureSpecification {
         return Set.of(
             TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
             RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
-            SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID
+            SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID,
+            TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED
         );
     }
 

+ 3 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -36,6 +36,7 @@ import org.elasticsearch.plugins.SystemIndexPlugin;
 import org.elasticsearch.rest.RestController;
 import org.elasticsearch.rest.RestHandler;
 import org.elasticsearch.search.rank.RankBuilder;
+import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.threadpool.ExecutorBuilder;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
 import org.elasticsearch.xcontent.ParseField;
@@ -66,6 +67,7 @@ import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
 import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
 import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
+import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankDoc;
 import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceEndpointAction;
@@ -253,6 +255,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
         var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables());
         entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new));
         entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new));
+        entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new));
         return entries;
     }
 

+ 103 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java

@@ -0,0 +1,103 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.rank.textsimilarity;
+
+import org.apache.lucene.search.Explanation;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.search.rank.RankDoc;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class TextSimilarityRankDoc extends RankDoc {
+
+    public static final String NAME = "text_similarity_rank_doc";
+
+    public final String inferenceId;
+    public final String field;
+
+    public TextSimilarityRankDoc(int doc, float score, int shardIndex, String inferenceId, String field) {
+        super(doc, score, shardIndex);
+        this.inferenceId = inferenceId;
+        this.field = field;
+    }
+
+    public TextSimilarityRankDoc(StreamInput in) throws IOException {
+        super(in);
+        inferenceId = in.readString();
+        field = in.readString();
+    }
+
+    @Override
+    public Explanation explain(Explanation[] sources, String[] queryNames) {
+        final String queryAlias = queryNames[0] == null ? "" : "[" + queryNames[0] + "]";
+        return Explanation.match(
+            score,
+            "text_similarity_reranker match using inference endpoint: ["
+                + inferenceId
+                + "] on document field: ["
+                + field
+                + "] matching on source query "
+                + queryAlias,
+            sources
+        );
+    }
+
+    @Override
+    public void doWriteTo(StreamOutput out) throws IOException {
+        out.writeString(inferenceId);
+        out.writeString(field);
+    }
+
+    @Override
+    public boolean doEquals(RankDoc rd) {
+        TextSimilarityRankDoc tsrd = (TextSimilarityRankDoc) rd;
+        return Objects.equals(inferenceId, tsrd.inferenceId) && Objects.equals(field, tsrd.field);
+    }
+
+    @Override
+    public int doHashCode() {
+        return Objects.hash(inferenceId, field);
+    }
+
+    @Override
+    public String toString() {
+        return "TextSimilarityRankDoc{"
+            + "doc="
+            + doc
+            + ", shardIndex="
+            + shardIndex
+            + ", score="
+            + score
+            + ", inferenceId="
+            + inferenceId
+            + ", field="
+            + field
+            + '}';
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field("inferenceId", inferenceId);
+        builder.field("field", field);
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE;
+    }
+}

+ 73 - 54
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

@@ -7,14 +7,20 @@
 
 package org.elasticsearch.xpack.inference.rank.textsimilarity;
 
+import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.license.LicenseUtils;
+import org.elasticsearch.search.builder.PointInTimeBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.fetch.StoredFieldsContext;
+import org.elasticsearch.search.rank.RankDoc;
+import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
+import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -32,11 +38,14 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
 /**
  * A {@code RetrieverBuilder} for parsing and constructing a text similarity reranker retriever.
  */
-public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
+public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder<TextSimilarityRankRetrieverBuilder> {
 
     public static final NodeFeature TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED = new NodeFeature(
         "text_similarity_reranker_retriever_supported"
     );
+    public static final NodeFeature TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED = new NodeFeature(
+        "text_similarity_reranker_retriever_composition_supported"
+    );
 
     public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
     public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
@@ -51,7 +60,6 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
             String inferenceText = (String) args[2];
             String field = (String) args[3];
             int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
-
             return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize);
         });
 
@@ -70,17 +78,20 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
         if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED) == false) {
             throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + TextSimilarityRankBuilder.NAME + "]");
         }
+        if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED) == false) {
+            throw new UnsupportedOperationException(
+                "[text_similarity_reranker] retriever composition feature is not supported by all nodes in the cluster"
+            );
+        }
         if (TextSimilarityRankBuilder.TEXT_SIMILARITY_RERANKER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
             throw LicenseUtils.newComplianceException(TextSimilarityRankBuilder.NAME);
         }
         return PARSER.apply(parser, context);
     }
 
-    private final RetrieverBuilder retrieverBuilder;
     private final String inferenceId;
     private final String inferenceText;
     private final String field;
-    private final int rankWindowSize;
 
     public TextSimilarityRankRetrieverBuilder(
         RetrieverBuilder retrieverBuilder,
@@ -89,15 +100,14 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
         String field,
         int rankWindowSize
     ) {
-        this.retrieverBuilder = retrieverBuilder;
+        super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.field = field;
-        this.rankWindowSize = rankWindowSize;
     }
 
     public TextSimilarityRankRetrieverBuilder(
-        RetrieverBuilder retrieverBuilder,
+        List<RetrieverSource> retrieverSource,
         String inferenceId,
         String inferenceText,
         String field,
@@ -106,66 +116,75 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
         String retrieverName,
         List<QueryBuilder> preFilterQueryBuilders
     ) {
-        this.retrieverBuilder = retrieverBuilder;
+        super(retrieverSource, rankWindowSize);
+        if (retrieverSource.size() != 1) {
+            throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever");
+        }
         this.inferenceId = inferenceId;
         this.inferenceText = inferenceText;
         this.field = field;
-        this.rankWindowSize = rankWindowSize;
         this.minScore = minScore;
         this.retrieverName = retrieverName;
         this.preFilterQueryBuilders = preFilterQueryBuilders;
     }
 
     @Override
-    public QueryBuilder topDocsQuery() {
-        // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever
-        return retrieverBuilder.topDocsQuery();
+    protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
+        return new TextSimilarityRankRetrieverBuilder(
+            newChildRetrievers,
+            inferenceId,
+            inferenceText,
+            field,
+            rankWindowSize,
+            minScore,
+            retrieverName,
+            preFilterQueryBuilders
+        );
     }
 
     @Override
-    public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
-        // rewrite prefilters
-        boolean hasChanged = false;
-        var newPreFilters = rewritePreFilters(ctx);
-        hasChanged |= newPreFilters != preFilterQueryBuilders;
-
-        // rewrite nested retriever
-        RetrieverBuilder newRetriever = retrieverBuilder.rewrite(ctx);
-        hasChanged |= newRetriever != retrieverBuilder;
-        if (hasChanged) {
-            return new TextSimilarityRankRetrieverBuilder(
-                newRetriever,
-                field,
-                inferenceText,
-                inferenceId,
-                rankWindowSize,
-                minScore,
-                this.retrieverName,
-                newPreFilters
-            );
+    protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
+        assert rankResults.size() == 1;
+        ScoreDoc[] scoreDocs = rankResults.get(0);
+        TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
+        for (int i = 0; i < scoreDocs.length; i++) {
+            ScoreDoc scoreDoc = scoreDocs[i];
+            textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field);
         }
-        return this;
+        return textSimilarityRankDocs;
     }
 
     @Override
-    public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
-        retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
-        retrieverBuilder.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed);
-        // Combining with other rank builder (such as RRF) is not supported yet
-        if (searchSourceBuilder.rankBuilder() != null) {
-            throw new IllegalArgumentException("text similarity rank builder cannot be combined with other rank builders");
-        }
+    public QueryBuilder explainQuery() {
+        // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever
+        return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.get(0).retriever().explainQuery() }, true);
+    }
 
-        searchSourceBuilder.rankBuilder(
+    @Override
+    protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
+        var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
+            .trackTotalHits(false)
+            .storedFields(new StoredFieldsContext(false))
+            .size(rankWindowSize);
+        if (preFilterQueryBuilders.isEmpty() == false) {
+            retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
+        }
+        retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);
+
+        // apply the pre-filters
+        if (preFilterQueryBuilders.size() > 0) {
+            QueryBuilder query = sourceBuilder.query();
+            BoolQueryBuilder newQuery = new BoolQueryBuilder();
+            if (query != null) {
+                newQuery.must(query);
+            }
+            preFilterQueryBuilders.forEach(newQuery::filter);
+            sourceBuilder.query(newQuery);
+        }
+        sourceBuilder.rankBuilder(
             new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore)
         );
-    }
-
-    /**
-     * Determines if this retriever contains sub-retrievers that need to be executed prior to search.
-     */
-    public boolean isCompound() {
-        return retrieverBuilder.isCompound();
+        return sourceBuilder;
     }
 
     @Override
@@ -179,7 +198,7 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
 
     @Override
     protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
+        builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.get(0).retriever());
         builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
         builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText);
         builder.field(FIELD_FIELD.getPreferredName(), field);
@@ -187,9 +206,9 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
     }
 
     @Override
-    protected boolean doEquals(Object other) {
+    public boolean doEquals(Object other) {
         TextSimilarityRankRetrieverBuilder that = (TextSimilarityRankRetrieverBuilder) other;
-        return Objects.equals(retrieverBuilder, that.retrieverBuilder)
+        return super.doEquals(other)
             && Objects.equals(inferenceId, that.inferenceId)
             && Objects.equals(inferenceText, that.inferenceText)
             && Objects.equals(field, that.field)
@@ -198,7 +217,7 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
     }
 
     @Override
-    protected int doHashCode() {
-        return Objects.hash(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, minScore);
+    public int doHashCode() {
+        return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore);
     }
 }

+ 88 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDocTests.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.rank.textsimilarity;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.InferencePlugin;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.search.rank.RankDoc.NO_RANK;
+
+public class TextSimilarityRankDocTests extends AbstractRankDocWireSerializingTestCase<TextSimilarityRankDoc> {
+
+    static TextSimilarityRankDoc createTestTextSimilarityRankDoc() {
+        TextSimilarityRankDoc instance = new TextSimilarityRankDoc(
+            randomNonNegativeInt(),
+            randomFloat(),
+            randomBoolean() ? -1 : randomNonNegativeInt(),
+            randomAlphaOfLength(randomIntBetween(2, 5)),
+            randomAlphaOfLength(randomIntBetween(2, 5))
+        );
+        instance.rank = randomBoolean() ? NO_RANK : randomIntBetween(1, 10000);
+        return instance;
+    }
+
+    @Override
+    protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
+        try (InferencePlugin plugin = new InferencePlugin(Settings.EMPTY)) {
+            return plugin.getNamedWriteables();
+        }
+    }
+
+    @Override
+    protected Writeable.Reader<TextSimilarityRankDoc> instanceReader() {
+        return TextSimilarityRankDoc::new;
+    }
+
+    @Override
+    protected TextSimilarityRankDoc createTestRankDoc() {
+        return createTestTextSimilarityRankDoc();
+    }
+
+    @Override
+    protected TextSimilarityRankDoc mutateInstance(TextSimilarityRankDoc instance) throws IOException {
+        int doc = instance.doc;
+        int shardIndex = instance.shardIndex;
+        float score = instance.score;
+        int rank = instance.rank;
+        String inferenceId = instance.inferenceId;
+        String field = instance.field;
+
+        switch (randomInt(5)) {
+            case 0:
+                doc = randomValueOtherThan(doc, ESTestCase::randomNonNegativeInt);
+                break;
+            case 1:
+                shardIndex = shardIndex == -1 ? randomNonNegativeInt() : -1;
+                break;
+            case 2:
+                score = randomValueOtherThan(score, ESTestCase::randomFloat);
+                break;
+            case 3:
+                rank = rank == NO_RANK ? randomIntBetween(1, 10000) : NO_RANK;
+                break;
+            case 4:
+                inferenceId = randomValueOtherThan(inferenceId, () -> randomAlphaOfLength(randomIntBetween(2, 5)));
+                break;
+            case 5:
+                field = randomValueOtherThan(field, () -> randomAlphaOfLength(randomIntBetween(2, 5)));
+                break;
+            default:
+                throw new AssertionError();
+        }
+        TextSimilarityRankDoc mutated = new TextSimilarityRankDoc(doc, score, shardIndex, inferenceId, field);
+        mutated.rank = rank;
+        return mutated;
+    }
+}

+ 2 - 119
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilderTests.java

@@ -9,17 +9,9 @@ package org.elasticsearch.xpack.inference.rank.textsimilarity;
 
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.index.query.BoolQueryBuilder;
-import org.elasticsearch.index.query.MatchAllQueryBuilder;
-import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryRewriteContext;
-import org.elasticsearch.index.query.RandomQueryBuilder;
-import org.elasticsearch.index.query.RangeQueryBuilder;
-import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.search.builder.SubSearchSourceBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
 import org.elasticsearch.search.retriever.TestRetrieverBuilder;
@@ -38,10 +30,8 @@ import java.util.ArrayList;
 import java.util.List;
 
 import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
-import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
-import static org.mockito.Mockito.mock;
 
 public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTestCase<TextSimilarityRankRetrieverBuilder> {
 
@@ -82,6 +72,7 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
                 new SearchUsage(),
                 nf -> nf == RetrieverBuilder.RETRIEVERS_SUPPORTED
                     || nf == TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED
+                    || nf == TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED
             )
         );
     }
@@ -131,86 +122,6 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
         }
     }
 
-    public void testRewriteInnerRetriever() throws IOException {
-        final boolean[] rewritten = { false };
-        List<QueryBuilder> preFilterQueryBuilders = new ArrayList<>();
-        if (randomBoolean()) {
-            for (int i = 0; i < randomIntBetween(1, 5); i++) {
-                preFilterQueryBuilders.add(RandomQueryBuilder.createQuery(random()));
-            }
-        }
-        RetrieverBuilder innerRetriever = new TestRetrieverBuilder("top-level-retriever") {
-            @Override
-            public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
-                if (randomBoolean()) {
-                    return this;
-                }
-                rewritten[0] = true;
-                return new TestRetrieverBuilder("nested-rewritten-retriever") {
-                    @Override
-                    public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
-                        if (preFilterQueryBuilders.isEmpty() == false) {
-                            BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-
-                            for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) {
-                                boolQueryBuilder.filter(preFilterQueryBuilder);
-                            }
-                            boolQueryBuilder.must(new RangeQueryBuilder("some_field"));
-                            searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder));
-                        } else {
-                            searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(new RangeQueryBuilder("some_field")));
-                        }
-                    }
-                };
-            }
-
-            @Override
-            public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
-                if (preFilterQueryBuilders.isEmpty() == false) {
-                    BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-
-                    for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) {
-                        boolQueryBuilder.filter(preFilterQueryBuilder);
-                    }
-                    boolQueryBuilder.must(new TermQueryBuilder("field", "value"));
-                    searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(boolQueryBuilder));
-                } else {
-                    searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(new TermQueryBuilder("field", "value")));
-                }
-            }
-        };
-        TextSimilarityRankRetrieverBuilder textSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder(
-            innerRetriever
-        );
-        textSimilarityRankRetrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
-        SearchSourceBuilder source = new SearchSourceBuilder().retriever(textSimilarityRankRetrieverBuilder);
-        QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
-        source = Rewriteable.rewrite(source, queryRewriteContext);
-        assertNull(source.retriever());
-        if (false == preFilterQueryBuilders.isEmpty()) {
-            if (source.query() instanceof MatchAllQueryBuilder == false && source.query() instanceof MatchNoneQueryBuilder == false) {
-                assertThat(source.query(), instanceOf(BoolQueryBuilder.class));
-                BoolQueryBuilder bq = (BoolQueryBuilder) source.query();
-                assertFalse(bq.must().isEmpty());
-                assertThat(bq.must().size(), equalTo(1));
-                if (rewritten[0]) {
-                    assertThat(bq.must().get(0), instanceOf(RangeQueryBuilder.class));
-                } else {
-                    assertThat(bq.must().get(0), instanceOf(TermQueryBuilder.class));
-                }
-                for (int j = 0; j < bq.filter().size(); j++) {
-                    assertEqualQueryOrMatchAllNone(bq.filter().get(j), preFilterQueryBuilders.get(j));
-                }
-            }
-        } else {
-            if (rewritten[0]) {
-                assertThat(source.query(), instanceOf(RangeQueryBuilder.class));
-            } else {
-                assertThat(source.query(), instanceOf(TermQueryBuilder.class));
-            }
-        }
-    }
-
     public void testTextSimilarityRetrieverParsing() throws IOException {
         String restContent = "{"
             + "  \"retriever\": {"
@@ -250,29 +161,6 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
         }
     }
 
-    public void testIsCompound() {
-        RetrieverBuilder compoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) {
-            @Override
-            public boolean isCompound() {
-                return true;
-            }
-        };
-        RetrieverBuilder nonCompoundInnerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) {
-            @Override
-            public boolean isCompound() {
-                return false;
-            }
-        };
-        TextSimilarityRankRetrieverBuilder compoundTextSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder(
-            compoundInnerRetriever
-        );
-        assertTrue(compoundTextSimilarityRankRetrieverBuilder.isCompound());
-        TextSimilarityRankRetrieverBuilder nonCompoundTextSimilarityRankRetrieverBuilder = createRandomTextSimilarityRankRetrieverBuilder(
-            nonCompoundInnerRetriever
-        );
-        assertFalse(nonCompoundTextSimilarityRankRetrieverBuilder.isCompound());
-    }
-
     public void testTopDocsQuery() {
         RetrieverBuilder innerRetriever = new TestRetrieverBuilder(ESTestCase.randomAlphaOfLengthBetween(5, 10)) {
             @Override
@@ -281,11 +169,6 @@ public class TextSimilarityRankRetrieverBuilderTests extends AbstractXContentTes
             }
         };
         TextSimilarityRankRetrieverBuilder retriever = createRandomTextSimilarityRankRetrieverBuilder(innerRetriever);
-        assertThat(retriever.topDocsQuery(), instanceOf(TermQueryBuilder.class));
-    }
-
-    private static void assertEqualQueryOrMatchAllNone(QueryBuilder actual, QueryBuilder expected) {
-        assertThat(actual, anyOf(instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class), equalTo(expected)));
+        expectThrows(IllegalStateException.class, "Should not be called, missing a rewrite?", retriever::topDocsQuery);
     }
-
 }

+ 37 - 3
x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/70_text_similarity_rank_retriever.yml

@@ -87,11 +87,9 @@ setup:
   - length: { hits.hits: 2 }
 
   - match: { hits.hits.0._id: "doc_2" }
-  - match: { hits.hits.0._rank: 1 }
   - close_to: { hits.hits.0._score: { value: 0.4, error: 0.001 } }
 
   - match: { hits.hits.1._id: "doc_1" }
-  - match: { hits.hits.1._rank: 2 }
   - close_to: { hits.hits.1._score: { value: 0.2, error: 0.001 } }
 
 ---
@@ -123,7 +121,6 @@ setup:
   - length: { hits.hits: 1 }
 
   - match: { hits.hits.0._id: "doc_1" }
-  - match: { hits.hits.0._rank: 1 }
   - close_to: { hits.hits.0._score: { value: 0.2, error: 0.001 } }
 
 
@@ -178,3 +175,40 @@ setup:
               field: text
           size: 10
 
+
+---
+"text similarity reranking with explain":
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever: {
+            text_similarity_reranker: {
+              retriever:
+                {
+                  standard: {
+                    query: {
+                      term: {
+                        topic: "science"
+                      }
+                    }
+                  }
+                },
+              rank_window_size: 10,
+              inference_id: my-rerank-model,
+              inference_text: "How often does the moon hide the sun?",
+              field: text
+            }
+          }
+          size: 10
+          explain: true
+
+  - match: { hits.hits.0._id: "doc_2" }
+  - match: { hits.hits.1._id: "doc_1" }
+
+  - close_to: { hits.hits.0._explanation.value: { value: 0.4, error: 0.000001 } }
+  - match: {hits.hits.0._explanation.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" }
+  - match: {hits.hits.0._explanation.details.0.description: "/weight.*science.*/" }

+ 4 - 0
x-pack/plugin/rank-rrf/build.gradle

@@ -20,7 +20,11 @@ dependencies {
   compileOnly project(path: xpackModule('core'))
 
   testImplementation(testArtifact(project(xpackModule('core'))))
+  testImplementation(testArtifact(project(':server')))
 
   clusterModules project(xpackModule('rank-rrf'))
+  clusterModules project(xpackModule('inference'))
   clusterModules project(':modules:lang-painless')
+
+  clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
 }

+ 6 - 0
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.rank.rrf;
 
 import org.apache.lucene.search.Explanation;
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -169,4 +170,9 @@ public final class RRFRankDoc extends RankDoc {
         builder.field("scores", scores);
         builder.field("rankConstant", rankConstant);
     }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.RRF_QUERY_REWRITE;
+    }
 }

+ 14 - 7
x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRankDocTests.java

@@ -7,15 +7,17 @@
 
 package org.elasticsearch.xpack.rank.rrf;
 
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable.Reader;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.search.rank.AbstractRankDocWireSerializingTestCase;
 import org.elasticsearch.test.ESTestCase;
 
 import java.io.IOException;
+import java.util.List;
 
 import static org.elasticsearch.xpack.rank.rrf.RRFRankDoc.NO_RANK;
 
-public class RRFRankDocTests extends AbstractWireSerializingTestCase<RRFRankDoc> {
+public class RRFRankDocTests extends AbstractRankDocWireSerializingTestCase<RRFRankDoc> {
 
     static RRFRankDoc createTestRRFRankDoc(int queryCount) {
         RRFRankDoc instance = new RRFRankDoc(
@@ -35,9 +37,13 @@ public class RRFRankDocTests extends AbstractWireSerializingTestCase<RRFRankDoc>
         return instance;
     }
 
-    static RRFRankDoc createTestRRFRankDoc() {
-        int queryCount = randomIntBetween(2, 20);
-        return createTestRRFRankDoc(queryCount);
+    @Override
+    protected List<NamedWriteableRegistry.Entry> getAdditionalNamedWriteables() {
+        try (RRFRankPlugin rrfRankPlugin = new RRFRankPlugin()) {
+            return rrfRankPlugin.getNamedWriteables();
+        } catch (IOException ex) {
+            throw new AssertionError("Failed to create RRFRankPlugin", ex);
+        }
     }
 
     @Override
@@ -46,8 +52,9 @@ public class RRFRankDocTests extends AbstractWireSerializingTestCase<RRFRankDoc>
     }
 
     @Override
-    protected RRFRankDoc createTestInstance() {
-        return createTestRRFRankDoc();
+    protected RRFRankDoc createTestRankDoc() {
+        int queryCount = randomIntBetween(2, 20);
+        return createTestRRFRankDoc(queryCount);
     }
 
     @Override

+ 2 - 0
x-pack/plugin/rank-rrf/src/yamlRestTest/java/org/elasticsearch/xpack/rank/rrf/RRFRankClientYamlTestSuiteIT.java

@@ -23,7 +23,9 @@ public class RRFRankClientYamlTestSuiteIT extends ESClientYamlSuiteTestCase {
         .nodes(2)
         .module("rank-rrf")
         .module("lang-painless")
+        .module("x-pack-inference")
         .setting("xpack.license.self_generated.type", "trial")
+        .plugin("inference-service-test")
         .build();
 
     public RRFRankClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) {

+ 334 - 0
x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/800_rrf_with_text_similarity_reranker_retriever.yml

@@ -0,0 +1,334 @@
+setup:
+  - requires:
+      cluster_features: ['rrf_retriever_composition_supported', 'text_similarity_reranker_retriever_supported']
+      reason: need to have support for rrf and semantic reranking composition
+      test_runner_features: "close_to"
+
+  - do:
+      inference.put:
+        task_type: rerank
+        inference_id: my-rerank-model
+        body: >
+          {
+            "service": "test_reranking_service",
+            "service_settings": {
+              "model_id": "my_model",
+              "api_key": "abc64"
+            },
+            "task_settings": {
+            }
+          }
+
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          settings:
+            number_of_shards: 1
+          mappings:
+            properties:
+              text:
+                type: text
+              topic:
+                type: keyword
+              subtopic:
+                type: keyword
+              integer:
+                type: integer
+
+  - do:
+      index:
+        index: test-index
+        id: doc_1
+        body:
+          text: "Sun Moon Lake is a lake in Nantou County, Taiwan. It is the largest lake in Taiwan."
+          topic: [ "geography" ]
+          integer: 1
+
+  - do:
+      index:
+        index: test-index
+        id: doc_2
+        body:
+          text: "The phases of the Moon come from the position of the Moon relative to the Earth and Sun."
+          topic: [ "science" ]
+          subtopic: [ "astronomy" ]
+          integer: 2
+
+  - do:
+      index:
+        index: test-index
+        id: doc_3
+        body:
+          text: "As seen from Earth, a solar eclipse happens when the Moon is directly between the Earth and the Sun."
+          topic: [ "science" ]
+          subtopic: [ "technology" ]
+          integer: 3
+
+  - do:
+      indices.refresh: {}
+
+---
+"rrf retriever with a nested text similarity reranker":
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever:
+            rrf: {
+              retrievers:
+                [
+                  {
+                    standard: {
+                      query: {
+                        bool: {
+                          should:
+                            [
+                              {
+                                constant_score: {
+                                  filter: {
+                                    term: {
+                                      integer: 1
+                                    }
+                                  },
+                                  boost: 10
+                                }
+                              },
+                              {
+                                constant_score:
+                                  {
+                                    filter:
+                                      {
+                                        term:
+                                          {
+                                            integer: 2
+                                          }
+                                      },
+                                    boost: 1
+                                  }
+                              }
+                            ]
+                        }
+                      }
+                    }
+                  },
+                  {
+                    text_similarity_reranker: {
+                      retriever:
+                        {
+                          standard: {
+                            query: {
+                              term: {
+                                topic: "science"
+                              }
+                            }
+                          }
+                        },
+                      rank_window_size: 10,
+                      inference_id: my-rerank-model,
+                      inference_text: "How often does the moon hide the sun?",
+                      field: text
+                    }
+                  }
+                ],
+              rank_window_size: 10,
+              rank_constant: 1
+            }
+          size: 10
+          from: 1
+          aggs:
+            topics:
+              terms:
+                field: topic
+                size: 10
+
+  - match: { hits.total.value: 3 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_1" }
+  - match: { hits.hits.1._id: "doc_3" }
+
+  - match: { aggregations.topics.buckets.0.key: "science" }
+  - match: { aggregations.topics.buckets.0.doc_count: 2 }
+  - match: { aggregations.topics.buckets.1.key: "geography" }
+  - match: { aggregations.topics.buckets.1.doc_count: 1 }
+
+---
+"Text similarity reranker on top of an RRF retriever":
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever:
+            {
+              text_similarity_reranker: {
+                retriever:
+                  {
+                    rrf: {
+                      retrievers:
+                        [
+                          {
+                            standard: {
+                              query: {
+                                bool: {
+                                  should:
+                                    [
+                                      {
+                                        constant_score: {
+                                          filter: {
+                                            term: {
+                                              integer: 1
+                                            }
+                                          },
+                                          boost: 10
+                                        }
+                                      },
+                                      {
+                                        constant_score:
+                                          {
+                                            filter:
+                                              {
+                                                term:
+                                                  {
+                                                    integer: 3
+                                                  }
+                                              },
+                                            boost: 1
+                                          }
+                                      }
+                                    ]
+                                }
+                              }
+                            }
+                          },
+                          {
+                            standard: {
+                              query: {
+                                term: {
+                                  topic: "geography"
+                                }
+                              }
+                            }
+                          }
+                        ],
+                      rank_window_size: 10,
+                      rank_constant: 1
+                    }
+                  },
+                rank_window_size: 10,
+                inference_id: my-rerank-model,
+                inference_text: "How often does the moon hide the sun?",
+                field: text
+              }
+            }
+          size: 10
+          aggs:
+            topics:
+              terms:
+                field: topic
+                size: 10
+
+  - match: { hits.total.value: 2 }
+  - length: { hits.hits: 2 }
+
+  - match: { hits.hits.0._id: "doc_3" }
+  - match: { hits.hits.1._id: "doc_1" }
+
+  - match: { aggregations.topics.buckets.0.key: "geography" }
+  - match: { aggregations.topics.buckets.0.doc_count: 1 }
+  - match: { aggregations.topics.buckets.1.key: "science" }
+  - match: { aggregations.topics.buckets.1.doc_count: 1 }
+
+
+---
+"explain using rrf retriever and text-similarity":
+
+  - do:
+      search:
+        index: test-index
+        body:
+          track_total_hits: true
+          fields: [ "text", "topic" ]
+          retriever:
+            rrf: {
+              retrievers:
+                [
+                  {
+                    standard: {
+                      query: {
+                        bool: {
+                          should:
+                            [
+                              {
+                                constant_score: {
+                                  filter: {
+                                    term: {
+                                      integer: 1
+                                    }
+                                  },
+                                  boost: 10
+                                }
+                              },
+                              {
+                                constant_score:
+                                  {
+                                    filter:
+                                      {
+                                        term:
+                                          {
+                                            integer: 2
+                                          }
+                                      },
+                                    boost: 1
+                                  }
+                              }
+                            ]
+                        }
+                      }
+                    }
+                  },
+                  {
+                    text_similarity_reranker: {
+                      retriever:
+                        {
+                          standard: {
+                            query: {
+                              term: {
+                                topic: "science"
+                              }
+                            }
+                          }
+                        },
+                      rank_window_size: 10,
+                      inference_id: my-rerank-model,
+                      inference_text: "How often does the moon hide the sun?",
+                      field: text
+                    }
+                  }
+                ],
+              rank_window_size: 10,
+              rank_constant: 1
+            }
+          size: 10
+          explain: true
+
+  - match: { hits.hits.0._id: "doc_2" }
+  - match: { hits.hits.1._id: "doc_1" }
+  - match: { hits.hits.2._id: "doc_3" }
+
+  - close_to: { hits.hits.0._explanation.value: { value: 0.6666667, error: 0.000001 } }
+  - match: {hits.hits.0._explanation.description: "/rrf.score:.\\[0.6666667\\].*/" }
+  - match: {hits.hits.0._explanation.details.0.value: 2}
+  - match: {hits.hits.0._explanation.details.0.description: "/rrf.score:.\\[0.33333334\\].*/" }
+  - match: {hits.hits.0._explanation.details.0.details.0.details.0.description: "/ConstantScore.*/" }
+  - match: {hits.hits.0._explanation.details.1.value: 2}
+  - match: {hits.hits.0._explanation.details.1.description: "/rrf.score:.\\[0.33333334\\].*/" }
+  - match: {hits.hits.0._explanation.details.1.details.0.description: "/text_similarity_reranker.match.using.inference.endpoint:.\\[my-rerank-model\\].on.document.field:.\\[text\\].*/" }
+  - match: {hits.hits.0._explanation.details.1.details.0.details.0.description: "/weight.*science.*/" }