Browse Source

[8.x] Fix for propagating filters from compound to inner retrievers (#117914) (#118046)

* Fix for propagating filters from compound to inner retrievers (#117914)

* Update RRFRetrieverBuilderIT.java
Panagiotis Bailis 10 months ago
parent
commit
3e57a57b28

+ 5 - 0
docs/changelog/117914.yaml

@@ -0,0 +1,5 @@
+pr: 117914
+summary: Fix for propagating filters from compound to inner retrievers
+area: Ranking
+type: bug
+issues: []

+ 21 - 11
server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java

@@ -20,6 +20,7 @@ import org.elasticsearch.action.search.MultiSearchResponse;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.search.TransportMultiSearchAction;
+import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.rest.RestStatus;
@@ -46,6 +47,8 @@ import static org.elasticsearch.action.ValidateActions.addValidationError;
  */
 public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>> extends RetrieverBuilder {
 
+    public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
+
     public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {}
 
     protected final int rankWindowSize;
@@ -64,9 +67,9 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
 
     /**
      * Returns a clone of the original retriever, replacing the sub-retrievers with
-     * the provided {@code newChildRetrievers}.
+     * the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}.
      */
-    protected abstract T clone(List<RetrieverSource> newChildRetrievers);
+    protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);
 
     /**
      * Combines the provided {@code rankResults} to return the final top documents.
@@ -85,13 +88,25 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
         }
 
         // Rewrite prefilters
-        boolean hasChanged = false;
+        // We eagerly rewrite prefilters, because some of the innerRetrievers
+        // could be compound too, so we want to propagate all the necessary filter information to them
+        // and have it available as part of their own rewrite step
         var newPreFilters = rewritePreFilters(ctx);
-        hasChanged |= newPreFilters != preFilterQueryBuilders;
+        if (newPreFilters != preFilterQueryBuilders) {
+            return clone(innerRetrievers, newPreFilters);
+        }
 
+        boolean hasChanged = false;
         // Rewrite retriever sources
         List<RetrieverSource> newRetrievers = new ArrayList<>();
         for (var entry : innerRetrievers) {
+            // we propagate the filters only for compound retrievers as they won't be attached through
+            // the createSearchSourceBuilder.
+            // We could remove this check, but we would end up adding the same filters
+            // multiple times in case an inner retriever rewrites itself, when we re-enter to rewrite
+            if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) {
+                entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
+            }
             RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx);
             if (newRetriever != entry.retriever) {
                 newRetrievers.add(new RetrieverSource(newRetriever, null));
@@ -106,7 +121,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
             }
         }
         if (hasChanged) {
-            return clone(newRetrievers);
+            return clone(newRetrievers, newPreFilters);
         }
 
         // execute searches
@@ -166,12 +181,7 @@ public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilde
             });
         });
 
-        return new RankDocsRetrieverBuilder(
-            rankWindowSize,
-            newRetrievers.stream().map(s -> s.retriever).toList(),
-            results::get,
-            newPreFilters
-        );
+        return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get);
     }
 
     @Override

+ 1 - 2
server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

@@ -184,8 +184,7 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
                     ll.onResponse(null);
                 }));
             });
-            var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null);
-            return rewritten;
+            return new KnnRetrieverBuilder(this, () -> toSet.get(), null);
         }
         return super.rewrite(ctx);
     }

+ 3 - 13
server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java

@@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
     final List<RetrieverBuilder> sources;
     final Supplier<RankDoc[]> rankDocs;
 
-    public RankDocsRetrieverBuilder(
-        int rankWindowSize,
-        List<RetrieverBuilder> sources,
-        Supplier<RankDoc[]> rankDocs,
-        List<QueryBuilder> preFilterQueryBuilders
-    ) {
+    public RankDocsRetrieverBuilder(int rankWindowSize, List<RetrieverBuilder> sources, Supplier<RankDoc[]> rankDocs) {
         this.rankWindowSize = rankWindowSize;
         this.rankDocs = rankDocs;
         if (sources == null || sources.isEmpty()) {
             throw new IllegalArgumentException("sources must not be null or empty");
         }
         this.sources = sources;
-        this.preFilterQueryBuilders = preFilterQueryBuilders;
     }
 
     @Override
@@ -73,10 +67,6 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
     @Override
     public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
         assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first";
-        var rewrittenFilters = rewritePreFilters(ctx);
-        if (rewrittenFilters != preFilterQueryBuilders) {
-            return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters);
-        }
         return this;
     }
 
@@ -94,7 +84,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
                 boolQuery.should(query);
             }
         }
-        // ignore prefilters of this level, they are already propagated to children
+        // ignore prefilters of this level, they were already propagated to children
         return boolQuery;
     }
 
@@ -133,7 +123,7 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder {
         } else {
             rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false);
         }
-        // ignore prefilters of this level, they are already propagated to children
+        // ignore prefilters of this level, they were already propagated to children
         searchSourceBuilder.query(rankQuery);
         if (sourceHasMinScore()) {
             searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore());

+ 1 - 6
server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java

@@ -95,12 +95,7 @@ public class RankDocsRetrieverBuilderTests extends ESTestCase {
     }
 
     private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException {
-        return new RankDocsRetrieverBuilder(
-            randomIntBetween(1, 100),
-            innerRetrievers(queryRewriteContext),
-            rankDocsSupplier(),
-            preFilters(queryRewriteContext)
-        );
+        return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier());
     }
 
     public void testExtractToSearchSourceBuilder() throws IOException {

+ 4 - 4
server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java

@@ -27,9 +27,9 @@ import java.util.Objects;
 /**
  * A SearchPlugin to exercise query vector builder
  */
-class TestQueryVectorBuilderPlugin implements SearchPlugin {
+public class TestQueryVectorBuilderPlugin implements SearchPlugin {
 
-    static class TestQueryVectorBuilder implements QueryVectorBuilder {
+    public static class TestQueryVectorBuilder implements QueryVectorBuilder {
         private static final String NAME = "test_query_vector_builder";
 
         private static final ParseField QUERY_VECTOR = new ParseField("query_vector");
@@ -47,11 +47,11 @@ class TestQueryVectorBuilderPlugin implements SearchPlugin {
 
         private List<Float> vectorToBuild;
 
-        TestQueryVectorBuilder(List<Float> vectorToBuild) {
+        public TestQueryVectorBuilder(List<Float> vectorToBuild) {
             this.vectorToBuild = vectorToBuild;
         }
 
-        TestQueryVectorBuilder(float[] expected) {
+        public TestQueryVectorBuilder(float[] expected) {
             this.vectorToBuild = new ArrayList<>(expected.length);
             for (float f : expected) {
                 vectorToBuild.add(f);

+ 6 - 4
test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java

@@ -10,6 +10,7 @@
 package org.elasticsearch.search.retriever;
 
 import org.apache.lucene.search.ScoreDoc;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.rank.RankDoc;
 import org.elasticsearch.xcontent.XContentBuilder;
 
@@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder<TestC
     public static final String NAME = "test_compound_retriever_builder";
 
     public TestCompoundRetrieverBuilder(int rankWindowSize) {
-        this(new ArrayList<>(), rankWindowSize);
+        this(new ArrayList<>(), rankWindowSize, new ArrayList<>());
     }
 
-    TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize) {
+    TestCompoundRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, List<QueryBuilder> preFilterQueryBuilders) {
         super(childRetrievers, rankWindowSize);
+        this.preFilterQueryBuilders = preFilterQueryBuilders;
     }
 
     @Override
-    protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
-        return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize);
+    protected TestCompoundRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
+        return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders);
     }
 
     @Override

+ 12 - 3
x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java

@@ -110,12 +110,14 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
         Map<String, Object> matchCriteria,
         List<RetrieverSource> retrieverSource,
         int rankWindowSize,
-        String retrieverName
+        String retrieverName,
+        List<QueryBuilder> preFilterQueryBuilders
     ) {
         super(retrieverSource, rankWindowSize);
         this.rulesetIds = rulesetIds;
         this.matchCriteria = matchCriteria;
         this.retrieverName = retrieverName;
+        this.preFilterQueryBuilders = preFilterQueryBuilders;
     }
 
     @Override
@@ -156,8 +158,15 @@ public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<Qu
     }
 
     @Override
-    protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
-        return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName);
+    protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
+        return new QueryRuleRetrieverBuilder(
+            rulesetIds,
+            matchCriteria,
+            newChildRetrievers,
+            rankWindowSize,
+            retrieverName,
+            newPreFilterQueryBuilders
+        );
     }
 
     @Override

+ 5 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java

@@ -129,7 +129,10 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
     }
 
     @Override
-    protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
+    protected TextSimilarityRankRetrieverBuilder clone(
+        List<RetrieverSource> newChildRetrievers,
+        List<QueryBuilder> newPreFilterQueryBuilders
+    ) {
         return new TextSimilarityRankRetrieverBuilder(
             newChildRetrievers,
             inferenceId,
@@ -138,7 +141,7 @@ public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder
             rankWindowSize,
             minScore,
             retrieverName,
-            preFilterQueryBuilders
+            newPreFilterQueryBuilders
         );
     }
 

+ 37 - 1
x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

@@ -33,6 +33,7 @@ import org.elasticsearch.search.sort.FieldSortBuilder;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
 import org.elasticsearch.search.vectors.QueryVectorBuilder;
+import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin;
 import org.elasticsearch.test.ESIntegTestCase;
 import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -57,7 +58,6 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo;
 public class RRFRetrieverBuilderIT extends ESIntegTestCase {
 
     protected static String INDEX = "test_index";
-    protected static final String ID_FIELD = "_id";
     protected static final String DOC_FIELD = "doc";
     protected static final String TEXT_FIELD = "text";
     protected static final String VECTOR_FIELD = "vector";
@@ -743,6 +743,42 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase {
         expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get());
     }
 
+    public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() {
+        final int rankWindowSize = 100;
+        final int rankConstant = 10;
+        SearchSourceBuilder source = new SearchSourceBuilder();
+        // match_all would retrieve every doc, but only doc 7 survives the top-level filter
+        StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery());
+        // this will also retrieve just doc 7
+        KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder(
+            "vector",
+            null,
+            new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }),
+            10,
+            10,
+            null
+        );
+        source.retriever(
+            new RRFRetrieverBuilder(
+                Arrays.asList(
+                    new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null),
+                    new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null)
+                ),
+                rankWindowSize,
+                rankConstant
+            )
+        );
+        source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7")));
+        source.size(10);
+        SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source);
+        ElasticsearchAssertions.assertResponse(req, resp -> {
+            assertNull(resp.pointInTimeId());
+            assertNotNull(resp.getHits().getTotalHits());
+            assertThat(resp.getHits().getTotalHits().value, equalTo(1L));
+            assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7"));
+        });
+    }
+
     public void testRewriteOnce() {
         final float[] vector = new float[] { 1 };
         AtomicInteger numAsyncCalls = new AtomicInteger();

+ 6 - 0
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java

@@ -12,6 +12,7 @@ import org.elasticsearch.features.NodeFeature;
 
 import java.util.Set;
 
+import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT;
 import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED;
 
 /**
@@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification {
     public Set<NodeFeature> getFeatures() {
         return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED);
     }
+
+    @Override
+    public Set<NodeFeature> getTestFeatures() {
+        return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT);
+    }
 }

+ 5 - 2
x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

@@ -11,6 +11,7 @@ import org.apache.lucene.search.ScoreDoc;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.search.rank.RankBuilder;
 import org.elasticsearch.search.rank.RankDoc;
@@ -108,8 +109,10 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
     }
 
     @Override
-    protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers) {
-        return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
+    protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
+        RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant);
+        clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
+        return clone;
     }
 
     @Override

+ 74 - 0
x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml

@@ -1071,3 +1071,77 @@ setup:
 
   - match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 }
   - match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 }
+
+
+---
+"rrf retriever with filters to be passed to nested rrf retrievers":
+  - requires:
+      cluster_features: 'inner_retrievers_filter_support'
+      reason: 'requires fix for properly propagating filters to nested sub-retrievers'
+
+  - do:
+      search:
+        _source: false
+        index: test
+        body:
+          retriever:
+            {
+              rrf:
+                {
+                  filter: {
+                    term: {
+                      keyword: "technology"
+                    }
+                  },
+                  retrievers: [
+                    {
+                      rrf: {
+                        retrievers: [
+                          {
+                            # this should only return docs 3 and 5 due to top level filter
+                            standard: {
+                              query: {
+                                knn: {
+                                  field: vector,
+                                  query_vector: [ 4.0 ],
+                                  k: 3
+                                }
+                              }
+                            } },
+                          {
+                            # this should return no docs as no docs match both biology and technology
+                            standard: {
+                              query: {
+                                term: {
+                                  keyword: "biology"
+                                }
+                              }
+                            }
+                          }
+                        ],
+                        rank_window_size: 10,
+                        rank_constant: 10
+                      }
+                    },
+                    # this should only return doc 5
+                    {
+                      standard: {
+                        query: {
+                          term: {
+                            text: "term5"
+                          }
+                        }
+                      }
+                    }
+                  ],
+                  rank_window_size: 10,
+                  rank_constant: 10
+                }
+            }
+          size: 10
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "5" }
+  - match: { hits.hits.1._id: "3" }
+
+