Ver código fonte

[8.x] Backporting propagating nested inner_hits to the parent compound retriever (#116707)

Panagiotis Bailis 11 meses atrás
pai
commit
7d33c5c597
19 arquivos alterados com 248 adições e 42 exclusões
  1. 6 0
      docs/changelog/116408.yaml
  2. 60 0
      server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java
  3. 1 1
      server/src/main/java/org/elasticsearch/TransportVersions.java
  4. 13 6
      server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java
  5. 3 0
      server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
  6. 1 1
      server/src/main/java/org/elasticsearch/search/SearchModule.java
  7. 6 2
      server/src/main/java/org/elasticsearch/search/SearchService.java
  8. 27 2
      server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java
  9. 6 1
      server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java
  10. 1 1
      server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java
  11. 1 1
      server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java
  12. 1 1
      server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java
  13. 2 3
      server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java
  14. 1 1
      server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java
  15. 1 1
      server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java
  16. 1 1
      server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java
  17. 5 7
      x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java
  18. 1 13
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java
  19. 111 0
      x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml

+ 6 - 0
docs/changelog/116408.yaml

@@ -0,0 +1,6 @@
+pr: 116408
+summary: Propagating nested `inner_hits` to the parent compound retriever
+area: Ranking
+type: bug
+issues:
+ - 116397

+ 60 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java

@@ -21,7 +21,9 @@ import org.elasticsearch.action.search.SearchType;
 import org.elasticsearch.cluster.health.ClusterHealthStatus;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.query.InnerHitBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.sort.NestedSortBuilder;
 import org.elasticsearch.search.sort.SortBuilders;
 import org.elasticsearch.search.sort.SortMode;
@@ -1581,6 +1583,64 @@ public class SimpleNestedIT extends ESIntegTestCase {
         assertThat(clusterStatsResponse.getIndicesStats().getSegments().getBitsetMemoryInBytes(), equalTo(0L));
     }
 
+    public void testSkipNestedInnerHits() throws Exception {
+        assertAcked(prepareCreate("test").setMapping("nested1", "type=nested"));
+        ensureGreen();
+
+        prepareIndex("test").setId("1")
+            .setSource(
+                jsonBuilder().startObject()
+                    .field("field1", "value1")
+                    .startArray("nested1")
+                    .startObject()
+                    .field("n_field1", "foo")
+                    .field("n_field2", "bar")
+                    .endObject()
+                    .endArray()
+                    .endObject()
+            )
+            .get();
+
+        waitForRelocation(ClusterHealthStatus.GREEN);
+        GetResponse getResponse = client().prepareGet("test", "1").get();
+        assertThat(getResponse.isExists(), equalTo(true));
+        assertThat(getResponse.getSourceAsBytesRef(), notNullValue());
+        refresh();
+
+        assertNoFailuresAndResponse(
+            prepareSearch("test").setSource(
+                new SearchSourceBuilder().query(
+                    QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg)
+                        .innerHit(new InnerHitBuilder())
+                )
+            ),
+            res -> {
+                assertNotNull(res.getHits());
+                assertHitCount(res, 1);
+                assertThat(res.getHits().getHits().length, equalTo(1));
+                // by default we should get inner hits
+                assertNotNull(res.getHits().getHits()[0].getInnerHits());
+                assertNotNull(res.getHits().getHits()[0].getInnerHits().get("nested1"));
+            }
+        );
+
+        assertNoFailuresAndResponse(
+            prepareSearch("test").setSource(
+                new SearchSourceBuilder().query(
+                    QueryBuilders.nestedQuery("nested1", QueryBuilders.termQuery("nested1.n_field1", "foo"), ScoreMode.Avg)
+                        .innerHit(new InnerHitBuilder())
+                ).skipInnerHits(true)
+            ),
+            res -> {
+                assertNotNull(res.getHits());
+                assertHitCount(res, 1);
+                assertThat(res.getHits().getHits().length, equalTo(1));
+                // if we explicitly say to ignore inner hits, then this should now be null
+                assertNull(res.getHits().getHits()[0].getInnerHits());
+            }
+        );
+    }
+
     private void assertDocumentCount(String index, long numdocs) {
         IndicesStatsResponse stats = indicesAdmin().prepareStats(index).clear().setDocs(true).get();
         assertNoFailures(stats);

+ 1 - 1
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -194,7 +194,7 @@ public class TransportVersions {
     public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0);
     public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0);
     public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0);
-
+    public static final TransportVersion SKIP_INNER_HITS_SEARCH_SOURCE = def(8_791_00_0);
     /*
      * STOP! READ THIS FIRST! No, really,
      *        ____ _____ ___  ____  _        ____  _____    _    ____    _____ _   _ ___ ____    _____ ___ ____  ____ _____ _

+ 13 - 6
server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilder.java → server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java

@@ -7,7 +7,7 @@
  * License v3.0 only", or the "Server Side Public License, v 1".
  */
 
-package org.elasticsearch.search.retriever.rankdoc;
+package org.elasticsearch.index.query;
 
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.search.Query;
@@ -16,15 +16,13 @@ import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.index.query.AbstractQueryBuilder;
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.QueryRewriteContext;
-import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.rank.RankDoc;
+import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE;
@@ -55,6 +53,15 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
         }
     }
 
+    @Override
+    protected void extractInnerHitBuilders(Map<String, InnerHitContextBuilder> innerHits) {
+        if (queryBuilders != null) {
+            for (QueryBuilder query : queryBuilders) {
+                InnerHitContextBuilder.extractInnerHits(query, innerHits);
+            }
+        }
+    }
+
     @Override
     protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
         if (queryBuilders != null) {
@@ -71,7 +78,7 @@ public class RankDocsQueryBuilder extends AbstractQueryBuilder<RankDocsQueryBuil
         return super.doRewrite(queryRewriteContext);
     }
 
-    RankDoc[] rankDocs() {
+    public RankDoc[] rankDocs() {
         return rankDocs;
     }
 

+ 3 - 0
server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

@@ -34,6 +34,8 @@ public final class SearchCapabilities {
     private static final String KQL_QUERY_SUPPORTED = "kql_query";
     /** Support multi-dense-vector field mapper. */
     private static final String MULTI_DENSE_VECTOR_FIELD_MAPPER = "multi_dense_vector_field_mapper";
+    /** Support propagating nested retrievers' inner_hits to top-level compound retrievers. */
+    private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";
 
     public static final Set<String> CAPABILITIES;
     static {
@@ -42,6 +44,7 @@ public final class SearchCapabilities {
         capabilities.add(BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY);
         capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
         capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
+        capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
         if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
             capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
         }

+ 1 - 1
server/src/main/java/org/elasticsearch/search/SearchModule.java

@@ -52,6 +52,7 @@ import org.elasticsearch.index.query.PrefixQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryStringQueryBuilder;
 import org.elasticsearch.index.query.RangeQueryBuilder;
+import org.elasticsearch.index.query.RankDocsQueryBuilder;
 import org.elasticsearch.index.query.RegexpQueryBuilder;
 import org.elasticsearch.index.query.ScriptQueryBuilder;
 import org.elasticsearch.index.query.SimpleQueryStringBuilder;
@@ -238,7 +239,6 @@ import org.elasticsearch.search.retriever.KnnRetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
 import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
-import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.search.sort.FieldSortBuilder;
 import org.elasticsearch.search.sort.GeoDistanceSortBuilder;
 import org.elasticsearch.search.sort.ScoreSortBuilder;

+ 6 - 2
server/src/main/java/org/elasticsearch/search/SearchService.java

@@ -1300,13 +1300,17 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
         );
         if (query != null) {
             QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(query, innerHitsRewriteContext, true);
-            InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
+            if (false == source.skipInnerHits()) {
+                InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
+            }
             searchExecutionContext.setAliasFilter(context.request().getAliasFilter().getQueryBuilder());
             context.parsedQuery(searchExecutionContext.toQuery(query));
         }
         if (source.postFilter() != null) {
             QueryBuilder rewrittenForInnerHits = Rewriteable.rewrite(source.postFilter(), innerHitsRewriteContext, true);
-            InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
+            if (false == source.skipInnerHits()) {
+                InnerHitContextBuilder.extractInnerHits(rewrittenForInnerHits, innerHitBuilders);
+            }
             context.parsedPostFilter(searchExecutionContext.toQuery(source.postFilter()));
         }
         if (innerHitBuilders.size() > 0) {

+ 27 - 2
server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java

@@ -213,6 +213,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
 
     private Map<String, Object> runtimeMappings = emptyMap();
 
+    private boolean skipInnerHits = false;
+
     /**
      * Constructs a new search source builder.
      */
@@ -289,6 +291,11 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
             rankBuilder = in.readOptionalNamedWriteable(RankBuilder.class);
         }
+        if (in.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) {
+            skipInnerHits = in.readBoolean();
+        } else {
+            skipInnerHits = false;
+        }
     }
 
     @Override
@@ -378,6 +385,9 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         } else if (rankBuilder != null) {
             throw new IllegalArgumentException("cannot serialize [rank] to version [" + out.getTransportVersion().toReleaseVersion() + "]");
         }
+        if (out.getTransportVersion().onOrAfter(TransportVersions.SKIP_INNER_HITS_SEARCH_SOURCE)) {
+            out.writeBoolean(skipInnerHits);
+        }
     }
 
     /**
@@ -1279,6 +1289,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         rewrittenBuilder.collapse = collapse;
         rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder;
         rewrittenBuilder.runtimeMappings = runtimeMappings;
+        rewrittenBuilder.skipInnerHits = skipInnerHits;
         return rewrittenBuilder;
     }
 
@@ -1855,6 +1866,9 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         if (false == runtimeMappings.isEmpty()) {
             builder.field(RUNTIME_MAPPINGS_FIELD.getPreferredName(), runtimeMappings);
         }
+        if (skipInnerHits) {
+            builder.field("skipInnerHits", true);
+        }
 
         return builder;
     }
@@ -1867,6 +1881,15 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         return builder;
     }
 
+    public SearchSourceBuilder skipInnerHits(boolean skipInnerHits) {
+        this.skipInnerHits = skipInnerHits;
+        return this;
+    }
+
+    public boolean skipInnerHits() {
+        return this.skipInnerHits;
+    }
+
     public static class IndexBoost implements Writeable, ToXContentObject {
         private final String index;
         private final float boost;
@@ -2121,7 +2144,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             collapse,
             trackTotalHitsUpTo,
             pointInTimeBuilder,
-            runtimeMappings
+            runtimeMappings,
+            skipInnerHits
         );
     }
 
@@ -2166,7 +2190,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             && Objects.equals(collapse, other.collapse)
             && Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo)
             && Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder)
-            && Objects.equals(runtimeMappings, other.runtimeMappings);
+            && Objects.equals(runtimeMappings, other.runtimeMappings)
+            && Objects.equals(skipInnerHits, other.skipInnerHits);
     }
 
     @Override

+ 6 - 1
server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

@@ -236,7 +236,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
         return Objects.hash(innerRetrievers);
     }
 
-    protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
+    protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
         var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
             .trackTotalHits(false)
             .storedFields(new StoredFieldsContext(false))
@@ -254,6 +254,11 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
         }
         sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME));
         sourceBuilder.sort(sortBuilders);
+        sourceBuilder.skipInnerHits(true);
+        return finalizeSourceBuilder(sourceBuilder);
+    }
+
+    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
         return sourceBuilder;
     }
 

+ 1 - 1
server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

@@ -15,8 +15,8 @@ import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.RankDocsQueryBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
 import org.elasticsearch.search.vectors.KnnSearchBuilder;
 import org.elasticsearch.search.vectors.QueryVectorBuilder;

+ 1 - 1
server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java

@@ -12,9 +12,9 @@ package org.elasticsearch.search.retriever;
 import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.RankDocsQueryBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rank.RankDoc;
-import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;

+ 1 - 1
server/src/main/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQuery.java

@@ -283,7 +283,7 @@ public class RankDocsQuery extends Query {
         return starts;
     }
 
-    RankDoc[] rankDocs() {
+    public RankDoc[] rankDocs() {
         return docs;
     }
 

+ 2 - 3
server/src/test/java/org/elasticsearch/search/retriever/rankdoc/RankDocsQueryBuilderTests.java → server/src/test/java/org/elasticsearch/index/query/RankDocsQueryBuilderTests.java

@@ -7,7 +7,7 @@
  * License v3.0 only", or the "Server Side Public License, v 1".
  */
 
-package org.elasticsearch.search.retriever.rankdoc;
+package org.elasticsearch.index.query;
 
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.NumericDocValuesField;
@@ -22,9 +22,8 @@ import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopScoreDocCollectorManager;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
-import org.elasticsearch.index.query.QueryBuilder;
-import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.search.rank.RankDoc;
+import org.elasticsearch.search.retriever.rankdoc.RankDocsQuery;
 import org.elasticsearch.test.AbstractQueryTestCase;
 
 import java.io.IOException;

+ 1 - 1
server/src/test/java/org/elasticsearch/search/rank/AbstractRankDocWireSerializingTestCase.java

@@ -12,8 +12,8 @@ package org.elasticsearch.search.rank;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.RankDocsQueryBuilder;
 import org.elasticsearch.search.SearchModule;
-import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 
 import java.io.IOException;

+ 1 - 1
server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java

@@ -17,11 +17,11 @@ import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.RandomQueryBuilder;
+import org.elasticsearch.index.query.RankDocsQueryBuilder;
 import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rank.RankDoc;
-import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.test.AbstractXContentTestCase;
 import org.elasticsearch.usage.SearchUsage;
 import org.elasticsearch.xcontent.NamedXContentRegistry;

+ 1 - 1
server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java

@@ -13,11 +13,11 @@ import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.RandomQueryBuilder;
+import org.elasticsearch.index.query.RankDocsQueryBuilder;
 import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rank.RankDoc;
-import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.test.ESTestCase;
 
 import java.io.IOException;

+ 5 - 7
x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java

@@ -11,15 +11,14 @@ import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.RankDocsQueryBuilder;
 import org.elasticsearch.license.LicenseUtils;
-import org.elasticsearch.search.builder.PointInTimeBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilderWrapper;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
-import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
 import org.elasticsearch.search.sort.ScoreSortBuilder;
 import org.elasticsearch.search.sort.SortBuilder;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
@@ -129,11 +128,10 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
     }
 
     @Override
-    protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
-        var ret = super.createSearchSourceBuilder(pit, retrieverBuilder);
-        checkValidSort(ret.sorts());
-        ret.query(new RuleQueryBuilder(ret.query(), matchCriteria, rulesetIds));
-        return ret;
+    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) {
+        checkValidSort(source.sorts());
+        source.query(new RuleQueryBuilder(source.query(), matchCriteria, rulesetIds));
+        return source;
     }
 
     private static void checkValidSort(List<SortBuilder<?>> sortBuilders) {

+ 1 - 13
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

@@ -12,9 +12,7 @@ import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.license.LicenseUtils;
-import org.elasticsearch.search.builder.PointInTimeBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
-import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverBuilder;
@@ -157,17 +155,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     }
 
     @Override
-    protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
-        var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
-            .trackTotalHits(false)
-            .storedFields(new StoredFieldsContext(false))
-            .size(rankWindowSize);
-        // apply the pre-filters downstream once
-        if (preFilterQueryBuilders.isEmpty() == false) {
-            retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
-        }
-        retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);
-
+    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
         sourceBuilder.rankBuilder(
             new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore)
         );

+ 111 - 0
x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml

@@ -35,6 +35,16 @@ setup:
                 properties:
                   views:
                     type: long
+              nested_inner_hits:
+                type: nested
+                properties:
+                  data:
+                    type: keyword
+                  paragraph_id:
+                    type: dense_vector
+                    dims: 1
+                    index: true
+                    similarity: l2_norm
 
   - do:
       index:
@@ -125,6 +135,16 @@ setup:
           integer: 2
           keyword: "technology"
           nested: { views: 10}
+          nested_inner_hits: [{"data": "foo"}, {"data": "bar"}, {"data": "baz"}]
+
+  - do:
+      index:
+        index: test
+        id: "10"
+        body:
+          id: 10
+          integer: 3
+          nested_inner_hits: [ {"data": "foo", "paragraph_id": [1]}]
   - do:
       indices.refresh: {}
 
@@ -960,3 +980,94 @@ setup:
   - length: { hits.hits : 1 }
 
   - match: { hits.hits.0._id: "1" }
+
+---
+"rrf retriever with inner_hits for sub-retriever":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ nested_retriever_inner_hits_support ]
+      test_runner_features: capabilities
+      reason: "Support for propagating nested retrievers' inner hits to the top-level compound retriever is required"
+
+  - do:
+      search:
+        _source: false
+        index: test
+        body:
+          retriever:
+            rrf:
+              retrievers: [
+                {
+                  # this will return doc 9 and doc 10
+                  standard: {
+                    query: {
+                      nested: {
+                        path: nested_inner_hits,
+                        inner_hits: {
+                          name: nested_data_field,
+                          _source: false,
+                          "sort": [ {
+                              "nested_inner_hits.data": "asc"
+                            }
+                          ],
+                          fields: [ nested_inner_hits.data ]
+                        },
+                        query: {
+                          match_all: { }
+                        }
+                      }
+                    }
+                  }
+                },
+                {
+                  # this will return doc 10
+                  standard: {
+                    query: {
+                      nested: {
+                        path: nested_inner_hits,
+                        inner_hits: {
+                          name: nested_vector_field,
+                          _source: false,
+                          size: 1,
+                          "fields": [ "nested_inner_hits.paragraph_id" ]
+                        },
+                        query: {
+                          knn: {
+                            field: nested_inner_hits.paragraph_id,
+                            query_vector: [ 1 ],
+                            num_candidates: 10
+                          }
+                        }
+                      }
+                    }
+                  }
+                },
+                {
+                  standard: {
+                    query: {
+                      match_all: { }
+                    }
+                  }
+                }
+              ]
+              rank_window_size: 10
+              rank_constant: 10
+          size: 3
+
+  - match: { hits.total.value: 10 }
+
+  - match: { hits.hits.0.inner_hits.nested_data_field.hits.total.value: 1 }
+  - match: { hits.hits.0.inner_hits.nested_data_field.hits.hits.0.fields.nested_inner_hits.0.data.0: foo }
+  - match: { hits.hits.0.inner_hits.nested_vector_field.hits.total.value: 1 }
+  - match: { hits.hits.0.inner_hits.nested_vector_field.hits.hits.0.fields.nested_inner_hits.0.paragraph_id: [ 1 ] }
+
+  - match: { hits.hits.1.inner_hits.nested_data_field.hits.total.value: 3 }
+  - match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.0.fields.nested_inner_hits.0.data.0: bar }
+  - match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.1.fields.nested_inner_hits.0.data.0: baz }
+  - match: { hits.hits.1.inner_hits.nested_data_field.hits.hits.2.fields.nested_inner_hits.0.data.0: foo }
+  - match: { hits.hits.1.inner_hits.nested_vector_field.hits.total.value: 0 }
+
+  - match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 }
+  - match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 }