Browse Source

Move Retriever Handling to Rewrite Phase (#110641)

This change moves the handling of the retriever to the rewrite phase. It also adds validation of the search source builder after extracting the retriever into the source builder.

Relates #110482
Jim Ferenczi 1 year ago
parent
commit
20071da493

+ 4 - 127
server/src/main/java/org/elasticsearch/action/search/SearchRequest.java

@@ -17,20 +17,16 @@ import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.mapper.SourceLoader;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.search.Scroll;
-import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.builder.PointInTimeBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.internal.SearchContext;
-import org.elasticsearch.search.rescore.RescorerBuilder;
 import org.elasticsearch.search.sort.FieldSortBuilder;
-import org.elasticsearch.search.sort.ShardDocSortField;
 import org.elasticsearch.search.sort.SortBuilder;
 import org.elasticsearch.search.sort.SortBuilders;
 import org.elasticsearch.tasks.TaskId;
@@ -324,124 +320,15 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
     public ActionRequestValidationException validate() {
         ActionRequestValidationException validationException = null;
         boolean scroll = scroll() != null;
+
+        if (source != null) {
+            validationException = source.validate(validationException, scroll);
+        }
         if (scroll) {
-            if (source != null) {
-                if (source.trackTotalHitsUpTo() != null && source.trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
-                    validationException = addValidationError(
-                        "disabling [track_total_hits] is not allowed in a scroll context",
-                        validationException
-                    );
-                }
-                if (source.from() > 0) {
-                    validationException = addValidationError("using [from] is not allowed in a scroll context", validationException);
-                }
-                if (source.size() == 0) {
-                    validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException);
-                }
-                if (source.rescores() != null && source.rescores().isEmpty() == false) {
-                    validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException);
-                }
-                if (CollectionUtils.isEmpty(source.searchAfter()) == false) {
-                    validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException);
-                }
-                if (source.collapse() != null) {
-                    validationException = addValidationError("cannot use `collapse` in a scroll context", validationException);
-                }
-            }
             if (requestCache != null && requestCache) {
                 validationException = addValidationError("[request_cache] cannot be used in a scroll context", validationException);
             }
         }
-        if (source != null) {
-            if (source.slice() != null) {
-                if (source.pointInTimeBuilder() == null && (scroll == false)) {
-                    validationException = addValidationError(
-                        "[slice] can only be used with [scroll] or [point-in-time] requests",
-                        validationException
-                    );
-                }
-            }
-            if (source.from() > 0 && CollectionUtils.isEmpty(source.searchAfter()) == false) {
-                validationException = addValidationError(
-                    "[from] parameter must be set to 0 when [search_after] is used",
-                    validationException
-                );
-            }
-            if (source.storedFields() != null) {
-                if (source.storedFields().fetchFields() == false) {
-                    if (source.fetchSource() != null && source.fetchSource().fetchSource()) {
-                        validationException = addValidationError(
-                            "[stored_fields] cannot be disabled if [_source] is requested",
-                            validationException
-                        );
-                    }
-                    if (source.fetchFields() != null) {
-                        validationException = addValidationError(
-                            "[stored_fields] cannot be disabled when using the [fields] option",
-                            validationException
-                        );
-                    }
-
-                }
-            }
-            if (source.subSearches().size() >= 2 && source.rankBuilder() == null) {
-                validationException = addValidationError("[sub_searches] requires [rank]", validationException);
-            }
-            if (source.aggregations() != null) {
-                validationException = source.aggregations().validate(validationException);
-            }
-            if (source.rankBuilder() != null) {
-                int size = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size();
-                if (size == 0) {
-                    validationException = addValidationError("[rank] requires [size] greater than [0]", validationException);
-                }
-                if (size > source.rankBuilder().rankWindowSize()) {
-                    validationException = addValidationError(
-                        "[rank] requires [rank_window_size: "
-                            + source.rankBuilder().rankWindowSize()
-                            + "]"
-                            + " be greater than or equal to [size: "
-                            + size
-                            + "]",
-                        validationException
-                    );
-                }
-                int queryCount = source.subSearches().size() + source.knnSearch().size();
-                if (source.rankBuilder().isCompoundBuilder() && queryCount < 2) {
-                    validationException = addValidationError(
-                        "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches",
-                        validationException
-                    );
-                }
-                if (scroll) {
-                    validationException = addValidationError("[rank] cannot be used in a scroll context", validationException);
-                }
-                if (source.rescores() != null && source.rescores().isEmpty() == false) {
-                    validationException = addValidationError("[rank] cannot be used with [rescore]", validationException);
-                }
-                if (source.sorts() != null && source.sorts().isEmpty() == false) {
-                    validationException = addValidationError("[rank] cannot be used with [sort]", validationException);
-                }
-                if (source.collapse() != null) {
-                    validationException = addValidationError("[rank] cannot be used with [collapse]", validationException);
-                }
-                if (source.suggest() != null && source.suggest().getSuggestions().isEmpty() == false) {
-                    validationException = addValidationError("[rank] cannot be used with [suggest]", validationException);
-                }
-                if (source.highlighter() != null) {
-                    validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException);
-                }
-                if (source.pointInTimeBuilder() != null) {
-                    validationException = addValidationError("[rank] cannot be used with [point in time]", validationException);
-                }
-            }
-            if (source.rescores() != null) {
-                for (@SuppressWarnings("rawtypes")
-                RescorerBuilder rescoreBuilder : source.rescores()) {
-                    validationException = rescoreBuilder.validate(this, validationException);
-                }
-            }
-        }
         if (pointInTimeBuilder() != null) {
             if (scroll) {
                 validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException);
@@ -461,16 +348,6 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
             if (preference() != null) {
                 validationException = addValidationError("[preference] cannot be used with point in time", validationException);
             }
-        } else if (source != null && source.sorts() != null) {
-            for (SortBuilder<?> sortBuilder : source.sorts()) {
-                if (sortBuilder instanceof FieldSortBuilder
-                    && ShardDocSortField.NAME.equals(((FieldSortBuilder) sortBuilder).getFieldName())) {
-                    validationException = addValidationError(
-                        "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]",
-                        validationException
-                    );
-                }
-            }
         }
         if (minCompatibleShardNode() != null) {
             if (isCcsMinimizeRoundtrips()) {

+ 211 - 37
server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java

@@ -10,12 +10,16 @@ package org.elasticsearch.search.builder;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.logging.DeprecationLogger;
+import org.elasticsearch.common.util.CollectionUtils;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.Booleans;
 import org.elasticsearch.core.Nullable;
@@ -28,6 +32,7 @@ import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.script.Script;
 import org.elasticsearch.search.SearchExtBuilder;
+import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
@@ -43,7 +48,9 @@ import org.elasticsearch.search.retriever.RetrieverBuilder;
 import org.elasticsearch.search.retriever.RetrieverParserContext;
 import org.elasticsearch.search.searchafter.SearchAfterBuilder;
 import org.elasticsearch.search.slice.SliceBuilder;
+import org.elasticsearch.search.sort.FieldSortBuilder;
 import org.elasticsearch.search.sort.ScoreSortBuilder;
+import org.elasticsearch.search.sort.ShardDocSortField;
 import org.elasticsearch.search.sort.SortBuilder;
 import org.elasticsearch.search.sort.SortBuilders;
 import org.elasticsearch.search.sort.SortOrder;
@@ -71,6 +78,7 @@ import java.util.function.ToLongFunction;
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyMap;
+import static org.elasticsearch.action.ValidateActions.addValidationError;
 import static org.elasticsearch.index.query.AbstractQueryBuilder.parseTopLevelQuery;
 import static org.elasticsearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;
 import static org.elasticsearch.search.internal.SearchContext.TRACK_TOTAL_HITS_ACCURATE;
@@ -78,10 +86,9 @@ import static org.elasticsearch.search.internal.SearchContext.TRACK_TOTAL_HITS_D
 
 /**
  * A search source builder allowing to easily build search source. Simple
- * construction using
- * {@link org.elasticsearch.search.builder.SearchSourceBuilder#searchSource()}.
+ * construction using {@link SearchSourceBuilder#searchSource()}.
  *
- * @see org.elasticsearch.action.search.SearchRequest#source(SearchSourceBuilder)
+ * @see SearchRequest#source(SearchSourceBuilder)
  */
 public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable<SearchSourceBuilder> {
     private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SearchSourceBuilder.class);
@@ -141,6 +148,8 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         return new HighlightBuilder();
     }
 
+    private transient RetrieverBuilder retrieverBuilder;
+
     private List<SubSearchSourceBuilder> subSearchSourceBuilders = new ArrayList<>();
 
     private QueryBuilder postQueryBuilder;
@@ -283,6 +292,9 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        if (retrieverBuilder != null) {
+            throw new IllegalStateException("SearchSourceBuilder should be rewritten first");
+        }
         out.writeOptionalWriteable(aggregations);
         out.writeOptionalBoolean(explain);
         out.writeOptionalWriteable(fetchSourceContext);
@@ -367,6 +379,18 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         }
     }
 
+    /**
+     * Sets the retriever for this request.
+     */
+    public SearchSourceBuilder retriever(RetrieverBuilder retrieverBuilder) {
+        this.retrieverBuilder = retrieverBuilder;
+        return this;
+    }
+
+    public RetrieverBuilder retriever() {
+        return retrieverBuilder;
+    }
+
     /**
      * Sets the query for this request.
      */
@@ -1134,6 +1158,21 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
                 highlightBuilder
             )
         ));
+        if (retrieverBuilder != null) {
+            var newRetriever = retrieverBuilder.rewrite(context);
+            if (newRetriever != retrieverBuilder) {
+                var rewritten = shallowCopy();
+                rewritten.retrieverBuilder = newRetriever;
+                return rewritten;
+            } else {
+                // retriever is transient, the rewritten version is extracted in this source.
+                var retriever = retrieverBuilder;
+                retrieverBuilder = null;
+                retriever.extractToSearchSourceBuilder(this, false);
+                validate();
+            }
+        }
+
         List<SubSearchSourceBuilder> subSearchSourceBuilders = Rewriteable.rewrite(this.subSearchSourceBuilders, context);
         QueryBuilder postQueryBuilder = null;
         if (this.postQueryBuilder != null) {
@@ -1293,7 +1332,6 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         }
         List<KnnSearchBuilder.Builder> knnBuilders = new ArrayList<>();
 
-        RetrieverBuilder retrieverBuilder = null;
         SearchUsage searchUsage = new SearchUsage();
         while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
             if (token == XContentParser.Token.FIELD_NAME) {
@@ -1627,39 +1665,6 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         }
 
         knnSearch = knnBuilders.stream().map(knnBuilder -> knnBuilder.build(size())).collect(Collectors.toList());
-
-        if (retrieverBuilder != null) {
-            List<String> specified = new ArrayList<>();
-            if (subSearchSourceBuilders.isEmpty() == false) {
-                specified.add(QUERY_FIELD.getPreferredName());
-            }
-            if (knnSearch.isEmpty() == false) {
-                specified.add(KNN_FIELD.getPreferredName());
-            }
-            if (searchAfterBuilder != null) {
-                specified.add(SEARCH_AFTER.getPreferredName());
-            }
-            if (terminateAfter != DEFAULT_TERMINATE_AFTER) {
-                specified.add(TERMINATE_AFTER_FIELD.getPreferredName());
-            }
-            if (sorts != null) {
-                specified.add(SORT_FIELD.getPreferredName());
-            }
-            if (rescoreBuilders != null) {
-                specified.add(RESCORE_FIELD.getPreferredName());
-            }
-            if (minScore != null) {
-                specified.add(MIN_SCORE_FIELD.getPreferredName());
-            }
-            if (rankBuilder != null) {
-                specified.add(RANK_FIELD.getPreferredName());
-            }
-            if (specified.isEmpty() == false) {
-                throw new IllegalArgumentException("cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified);
-            }
-            retrieverBuilder.extractToSearchSourceBuilder(this, false);
-        }
-
         searchUsageConsumer.accept(searchUsage);
         return this;
     }
@@ -1689,6 +1694,10 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             builder.field(TERMINATE_AFTER_FIELD.getPreferredName(), terminateAfter);
         }
 
+        if (retrieverBuilder != null) {
+            builder.field(RETRIEVER.getPreferredName(), retrieverBuilder);
+        }
+
         if (subSearchSourceBuilders.isEmpty() == false) {
             if (subSearchSourceBuilders.size() == 1) {
                 builder.field(QUERY_FIELD.getPreferredName(), subSearchSourceBuilders.get(0).getQueryBuilder());
@@ -2183,4 +2192,169 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
 
         return collapse == null && (aggregations == null || aggregations.supportsParallelCollection(fieldCardinality));
     }
+
+    private void validate() throws ValidationException {
+        var exceptions = validate(null, false);
+        if (exceptions != null) {
+            throw exceptions;
+        }
+    }
+
+    public ActionRequestValidationException validate(ActionRequestValidationException validationException, boolean isScroll) {
+        if (retriever() != null) {
+            List<String> specified = new ArrayList<>();
+            if (subSearches().isEmpty() == false) {
+                specified.add(QUERY_FIELD.getPreferredName());
+            }
+            if (knnSearch().isEmpty() == false) {
+                specified.add(KNN_FIELD.getPreferredName());
+            }
+            if (searchAfter() != null) {
+                specified.add(SEARCH_AFTER.getPreferredName());
+            }
+            if (terminateAfter() != DEFAULT_TERMINATE_AFTER) {
+                specified.add(TERMINATE_AFTER_FIELD.getPreferredName());
+            }
+            if (sorts() != null) {
+                specified.add(SORT_FIELD.getPreferredName());
+            }
+            if (rescores() != null) {
+                specified.add(RESCORE_FIELD.getPreferredName());
+            }
+            if (minScore() != null) {
+                specified.add(MIN_SCORE_FIELD.getPreferredName());
+            }
+            if (rankBuilder() != null) {
+                specified.add(RANK_FIELD.getPreferredName());
+            }
+            if (specified.isEmpty() == false) {
+                validationException = addValidationError(
+                    "cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified,
+                    validationException
+                );
+            }
+        }
+        if (isScroll) {
+            if (trackTotalHitsUpTo() != null && trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
+                validationException = addValidationError(
+                    "disabling [track_total_hits] is not allowed in a scroll context",
+                    validationException
+                );
+            }
+            if (from() > 0) {
+                validationException = addValidationError("using [from] is not allowed in a scroll context", validationException);
+            }
+            if (size() == 0) {
+                validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException);
+            }
+            if (rescores() != null && rescores().isEmpty() == false) {
+                validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException);
+            }
+            if (CollectionUtils.isEmpty(searchAfter()) == false) {
+                validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException);
+            }
+            if (collapse() != null) {
+                validationException = addValidationError("cannot use `collapse` in a scroll context", validationException);
+            }
+        }
+        if (slice() != null) {
+            if (pointInTimeBuilder() == null && (isScroll == false)) {
+                validationException = addValidationError(
+                    "[slice] can only be used with [scroll] or [point-in-time] requests",
+                    validationException
+                );
+            }
+        }
+        if (from() > 0 && CollectionUtils.isEmpty(searchAfter()) == false) {
+            validationException = addValidationError("[from] parameter must be set to 0 when [search_after] is used", validationException);
+        }
+        if (storedFields() != null) {
+            if (storedFields().fetchFields() == false) {
+                if (fetchSource() != null && fetchSource().fetchSource()) {
+                    validationException = addValidationError(
+                        "[stored_fields] cannot be disabled if [_source] is requested",
+                        validationException
+                    );
+                }
+                if (fetchFields() != null) {
+                    validationException = addValidationError(
+                        "[stored_fields] cannot be disabled when using the [fields] option",
+                        validationException
+                    );
+                }
+            }
+        }
+        if (subSearches().size() >= 2 && rankBuilder() == null) {
+            validationException = addValidationError("[sub_searches] requires [rank]", validationException);
+        }
+        if (aggregations() != null) {
+            validationException = aggregations().validate(validationException);
+        }
+
+        if (rankBuilder() != null) {
+            int s = size() == -1 ? SearchService.DEFAULT_SIZE : size();
+            if (s == 0) {
+                validationException = addValidationError("[rank] requires [size] greater than [0]", validationException);
+            }
+            if (s > rankBuilder().rankWindowSize()) {
+                validationException = addValidationError(
+                    "[rank] requires [rank_window_size: "
+                        + rankBuilder().rankWindowSize()
+                        + "]"
+                        + " be greater than or equal to [size: "
+                        + s
+                        + "]",
+                    validationException
+                );
+            }
+            int queryCount = subSearches().size() + knnSearch().size();
+            if (rankBuilder().isCompoundBuilder() && queryCount < 2) {
+                validationException = addValidationError(
+                    "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches",
+                    validationException
+                );
+            }
+            if (isScroll) {
+                validationException = addValidationError("[rank] cannot be used in a scroll context", validationException);
+            }
+            if (rescores() != null && rescores().isEmpty() == false) {
+                validationException = addValidationError("[rank] cannot be used with [rescore]", validationException);
+            }
+            if (sorts() != null && sorts().isEmpty() == false) {
+                validationException = addValidationError("[rank] cannot be used with [sort]", validationException);
+            }
+            if (collapse() != null) {
+                validationException = addValidationError("[rank] cannot be used with [collapse]", validationException);
+            }
+            if (suggest() != null && suggest().getSuggestions().isEmpty() == false) {
+                validationException = addValidationError("[rank] cannot be used with [suggest]", validationException);
+            }
+            if (highlighter() != null) {
+                validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException);
+            }
+            if (pointInTimeBuilder() != null) {
+                validationException = addValidationError("[rank] cannot be used with [point in time]", validationException);
+            }
+        }
+
+        if (rescores() != null) {
+            for (@SuppressWarnings("rawtypes")
+            var rescorer : rescores()) {
+                validationException = rescorer.validate(this, validationException);
+            }
+        }
+
+        if (pointInTimeBuilder() == null && sorts() != null) {
+            for (var sortBuilder : sorts()) {
+                if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder
+                    && ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) {
+                    validationException = addValidationError(
+                        "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]",
+                        validationException
+                    );
+                }
+            }
+        }
+        return validationException;
+    }
 }

+ 2 - 2
server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java

@@ -9,7 +9,6 @@
 package org.elasticsearch.search.rescore;
 
 import org.elasticsearch.action.ActionRequestValidationException;
-import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -17,6 +16,7 @@ import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
 import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -120,7 +120,7 @@ public abstract class RescorerBuilder<RB extends RescorerBuilder<RB>>
         return builder;
     }
 
-    public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) {
+    public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) {
         return validationException;
     }
 

+ 20 - 5
server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java

@@ -14,6 +14,8 @@ import org.elasticsearch.common.xcontent.SuggestingErrorOnUnknown;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryRewriteContext;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.xcontent.AbstractObjectParser;
 import org.elasticsearch.xcontent.FilterXContentParserWrapper;
@@ -33,16 +35,17 @@ import java.util.Objects;
 /**
  * A retriever represents an API element that returns an ordered list of top
  * documents. These can be obtained from a query, from another retriever, etc.
- * Internally, a {@link RetrieverBuilder} is just a wrapper for other search
- * elements that are extracted into a {@link SearchSourceBuilder}. The advantage
- * retrievers have is in the API they appear as a tree-like structure enabling
+ * Internally, a {@link RetrieverBuilder} is first rewritten into its simplest
+ * form and then its elements are extracted into a {@link SearchSourceBuilder}.
+ *
+ * The advantage retrievers have is in the API they appear as a tree-like structure enabling
  * easier reasoning about what a search does.
  *
  * This is the base class for all other retrievers. This class does not support
  * serialization and is expected to be fully extracted to a {@link SearchSourceBuilder}
  * prior to any transport calls.
  */
-public abstract class RetrieverBuilder implements ToXContent {
+public abstract class RetrieverBuilder implements Rewriteable<RetrieverBuilder>, ToXContent {
 
     public static final NodeFeature RETRIEVERS_SUPPORTED = new NodeFeature("retrievers_supported");
 
@@ -181,6 +184,13 @@ public abstract class RetrieverBuilder implements ToXContent {
 
     protected String retrieverName;
 
+    /**
+     * Determines if this retriever contains sub-retrievers that need to be executed prior to search.
+     */
+    public boolean isCompound() {
+        return false;
+    }
+
     /**
      * Gets the filters for this retriever.
      */
@@ -188,8 +198,13 @@ public abstract class RetrieverBuilder implements ToXContent {
         return preFilterQueryBuilders;
     }
 
+    @Override
+    public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
+        return this;
+    }
+
     /**
-     * This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}.
+     * This method is called at the end of rewriting on behalf of a {@link SearchSourceBuilder}.
      * Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}.
      */
     public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed);

+ 4 - 0
server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

@@ -298,6 +298,10 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         return field;
     }
 
+    public List<QueryBuilder> getFilterQueries() {
+        return filterQueries;
+    }
+
     public KnnSearchBuilder addFilterQuery(QueryBuilder filterQuery) {
         Objects.requireNonNull(filterQuery);
         this.filterQueries.add(filterQuery);

+ 33 - 1
server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java

@@ -10,9 +10,14 @@ package org.elasticsearch.search.retriever;
 
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.RandomQueryBuilder;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.AbstractXContentTestCase;
 import org.elasticsearch.usage.SearchUsage;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -23,6 +28,10 @@ import java.util.ArrayList;
 import java.util.List;
 
 import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.mock;
 
 public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase<KnnRetrieverBuilder> {
 
@@ -34,7 +43,7 @@ public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase<Kn
     public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() {
         String field = randomAlphaOfLength(6);
         int dim = randomIntBetween(2, 30);
-        float[] vector = randomBoolean() ? null : randomVector(dim);
+        float[] vector = randomVector(dim);
         int k = randomIntBetween(1, 100);
         int numCands = randomIntBetween(k + 20, 1000);
         Float similarity = randomBoolean() ? null : randomFloat();
@@ -70,6 +79,29 @@ public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase<Kn
         );
     }
 
+    public void testRewrite() throws IOException {
+        for (int i = 0; i < 10; i++) {
+            KnnRetrieverBuilder knnRetriever = createRandomKnnRetrieverBuilder();
+            SearchSourceBuilder source = new SearchSourceBuilder().retriever(knnRetriever);
+            QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
+            source = Rewriteable.rewrite(source, queryRewriteContext);
+            assertNull(source.retriever());
+            assertNull(source.query());
+            assertThat(source.knnSearch().size(), equalTo(1));
+            assertThat(source.knnSearch().get(0).getFilterQueries().size(), equalTo(knnRetriever.preFilterQueryBuilders.size()));
+            for (int j = 0; j < knnRetriever.preFilterQueryBuilders.size(); j++) {
+                assertThat(
+                    source.knnSearch().get(0).getFilterQueries().get(j),
+                    anyOf(
+                        instanceOf(MatchAllQueryBuilder.class),
+                        instanceOf(MatchNoneQueryBuilder.class),
+                        equalTo(knnRetriever.preFilterQueryBuilders.get(j))
+                    )
+                );
+            }
+        }
+    }
+
     @Override
     protected boolean supportsUnknownFields() {
         return false;

+ 36 - 16
server/src/test/java/org/elasticsearch/search/retriever/RetrieverBuilderErrorTests.java

@@ -8,6 +8,7 @@
 
 package org.elasticsearch.search.retriever;
 
+import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -19,6 +20,8 @@ import org.elasticsearch.xcontent.json.JsonXContent;
 import java.io.IOException;
 import java.util.List;
 
+import static org.hamcrest.Matchers.containsString;
+
 /**
  * Tests exceptions related to usage of restricted global values with a retriever.
  */
@@ -32,8 +35,10 @@ public class RetrieverBuilderErrorTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [query]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query]"));
         }
 
         try (
@@ -44,26 +49,35 @@ public class RetrieverBuilderErrorTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [knn]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [knn]"));
         }
 
         try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"search_after\": [1], \"retriever\":{\"standard\":{}}}")) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [search_after]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [search_after]"));
+
         }
 
         try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"terminate_after\": 1, \"retriever\":{\"standard\":{}}}")) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [terminate_after]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [terminate_after]"));
         }
 
         try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"sort\": [\"field\"], \"retriever\":{\"standard\":{}}}")) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [sort]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [sort]"));
         }
 
         try (
@@ -73,14 +87,18 @@ public class RetrieverBuilderErrorTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [rescore]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [rescore]"));
         }
 
         try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"min_score\": 2, \"retriever\":{\"standard\":{}}}")) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [min_score]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [min_score]"));
         }
 
         try (
@@ -90,8 +108,10 @@ public class RetrieverBuilderErrorTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
-            assertEquals("cannot specify [retriever] and [query, terminate_after, min_score]", iae.getMessage());
+            ssb.parseXContent(parser, true, nf -> true);
+            ActionRequestValidationException iae = ssb.validate(null, false);
+            assertNotNull(iae);
+            assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query, terminate_after, min_score]"));
         }
     }
 

+ 61 - 1
server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java

@@ -11,8 +11,15 @@ package org.elasticsearch.search.retriever;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.index.query.MatchNoneQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.RandomQueryBuilder;
+import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.search.SearchModule;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.collapse.CollapseBuilderTests;
 import org.elasticsearch.search.searchafter.SearchAfterBuilderTests;
 import org.elasticsearch.search.sort.SortBuilderTests;
@@ -27,6 +34,11 @@ import java.io.UncheckedIOException;
 import java.util.List;
 import java.util.function.BiFunction;
 
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.Mockito.mock;
+
 public class StandardRetrieverBuilderParsingTests extends AbstractXContentTestCase<StandardRetrieverBuilder> {
 
     /**
@@ -59,7 +71,7 @@ public class StandardRetrieverBuilderParsingTests extends AbstractXContentTestCa
             }
 
             if (randomBoolean()) {
-                standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList();
+                standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList(false);
             }
 
             if (randomBoolean()) {
@@ -109,4 +121,52 @@ public class StandardRetrieverBuilderParsingTests extends AbstractXContentTestCa
     protected NamedXContentRegistry xContentRegistry() {
         return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of()).getNamedXContents());
     }
+
+    public void testRewrite() throws IOException {
+        for (int i = 0; i < 10; i++) {
+            StandardRetrieverBuilder standardRetriever = createTestInstance();
+            SearchSourceBuilder source = new SearchSourceBuilder().retriever(standardRetriever);
+            QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
+            source = Rewriteable.rewrite(source, queryRewriteContext);
+            assertNull(source.retriever());
+            assertTrue(source.knnSearch().isEmpty());
+            if (standardRetriever.queryBuilder != null) {
+                assertNotNull(source.query());
+                if (standardRetriever.preFilterQueryBuilders.size() > 0) {
+                    if (source.query() instanceof MatchAllQueryBuilder == false
+                        && source.query() instanceof MatchNoneQueryBuilder == false) {
+                        assertThat(source.query(), instanceOf(BoolQueryBuilder.class));
+                        BoolQueryBuilder bq = (BoolQueryBuilder) source.query();
+                        assertFalse(bq.must().isEmpty());
+                        assertThat(bq.must().size(), equalTo(1));
+                        assertThat(bq.must().get(0), equalTo(standardRetriever.queryBuilder));
+                        for (int j = 0; j < bq.filter().size(); j++) {
+                            assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j));
+                        }
+                    }
+                } else {
+                    assertEqualQueryOrMatchAllNone(source.query(), standardRetriever.queryBuilder);
+                }
+            } else if (standardRetriever.preFilterQueryBuilders.size() > 0) {
+                if (source.query() instanceof MatchAllQueryBuilder == false && source.query() instanceof MatchNoneQueryBuilder == false) {
+                    assertNotNull(source.query());
+                    assertThat(source.query(), instanceOf(BoolQueryBuilder.class));
+                    BoolQueryBuilder bq = (BoolQueryBuilder) source.query();
+                    assertTrue(bq.must().isEmpty());
+                    for (int j = 0; j < bq.filter().size(); j++) {
+                        assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j));
+                    }
+                }
+            } else {
+                assertNull(source.query());
+            }
+            if (standardRetriever.sortBuilders != null) {
+                assertThat(source.sorts().size(), equalTo(standardRetriever.sortBuilders.size()));
+            }
+        }
+    }
+
+    private static void assertEqualQueryOrMatchAllNone(QueryBuilder actual, QueryBuilder expected) {
+        assertThat(actual, anyOf(instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class), equalTo(expected)));
+    }
 }

+ 3 - 3
server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java

@@ -119,7 +119,7 @@ public class SortBuilderTests extends ESTestCase {
     public void testRandomSortBuilders() throws IOException {
         for (int runs = 0; runs < NUMBER_OF_RUNS; runs++) {
             Set<String> expectedWarningHeaders = new HashSet<>();
-            List<SortBuilder<?>> testBuilders = randomSortBuilderList();
+            List<SortBuilder<?>> testBuilders = randomSortBuilderList(randomBoolean());
             XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
             xContentBuilder.startObject();
             if (testBuilders.size() > 1) {
@@ -171,7 +171,7 @@ public class SortBuilderTests extends ESTestCase {
         }
     }
 
-    public static List<SortBuilder<?>> randomSortBuilderList() {
+    public static List<SortBuilder<?>> randomSortBuilderList(boolean hasPIT) {
         int size = randomIntBetween(1, 5);
         List<SortBuilder<?>> list = new ArrayList<>(size);
         for (int i = 0; i < size; i++) {
@@ -181,7 +181,7 @@ public class SortBuilderTests extends ESTestCase {
                 case 2 -> SortBuilders.fieldSort(FieldSortBuilder.DOC_FIELD_NAME);
                 case 3 -> GeoDistanceSortBuilderTests.randomGeoDistanceSortBuilder();
                 case 4 -> ScriptSortBuilderTests.randomScriptSortBuilder();
-                case 5 -> SortBuilders.pitTiebreaker();
+                case 5 -> hasPIT ? SortBuilders.pitTiebreaker() : ScriptSortBuilderTests.randomScriptSortBuilder();
                 default -> throw new IllegalStateException("unexpected randomization in tests");
             });
         }

+ 5 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java

@@ -11,12 +11,12 @@ import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequestValidationException;
-import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.Rewriteable;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rescore.RescorerBuilder;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
@@ -134,10 +134,10 @@ public class LearningToRankRescorerBuilder extends RescorerBuilder<LearningToRan
     }
 
     @Override
-    public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) {
-        validationException = super.validate(searchRequest, validationException);
+    public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) {
+        validationException = super.validate(source, validationException);
 
-        int searchRequestPaginationSize = searchRequest.source().from() + searchRequest.source().size();
+        int searchRequestPaginationSize = source.from() + source.size();
 
         if (windowSize() < searchRequestPaginationSize) {
             return addValidationError(
@@ -151,7 +151,7 @@ public class LearningToRankRescorerBuilder extends RescorerBuilder<LearningToRan
         }
 
         @SuppressWarnings("rawtypes")
-        List<RescorerBuilder> rescorers = searchRequest.source().rescores();
+        List<RescorerBuilder> rescorers = source.rescores();
         assert rescorers != null && rescorers.contains(this);
 
         for (int i = rescorers.indexOf(this) + 1; i < rescorers.size(); i++) {

+ 28 - 7
x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java

@@ -48,7 +48,10 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null)
+            );
             assertEquals("[search_after] cannot be used in children of compound retrievers", iae.getMessage());
         }
 
@@ -60,7 +63,10 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null)
+            );
             assertEquals("[terminate_after] cannot be used in children of compound retrievers", iae.getMessage());
         }
 
@@ -71,7 +77,10 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null)
+            );
             assertEquals("[sort] cannot be used in children of compound retrievers", iae.getMessage());
         }
 
@@ -82,7 +91,10 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null)
+            );
             assertEquals("[min_score] cannot be used in children of compound retrievers", iae.getMessage());
         }
 
@@ -94,7 +106,10 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null)
+            );
             assertEquals("[collapse] cannot be used in children of compound retrievers", iae.getMessage());
         }
 
@@ -105,7 +120,10 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null)
+            );
             assertEquals("[rank] cannot be used in children of compound retrievers", iae.getMessage());
         }
     }
@@ -119,7 +137,10 @@ public class RRFRetrieverBuilderTests extends ESTestCase {
             )
         ) {
             SearchSourceBuilder ssb = new SearchSourceBuilder();
-            IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true));
+            IllegalArgumentException iae = expectThrows(
+                IllegalArgumentException.class,
+                () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null)
+            );
             assertEquals("[1:65] [rrf] failed to parse field [retrievers]", iae.getMessage());
             assertEquals(
                 "the nested depth of the [standard] retriever exceeds the maximum nested depth [2] for retrievers",