Browse Source

Support kNN filter on nested metadata (#113949)

Currently, kNN search over nested vectors only supports filtering on
the parent's metadata. This adds support for filtering on nested
metadata as well.

Closes #106994 Closes #128803
Mayya Sharipova 2 months ago
parent
commit
0e63c901be

+ 7 - 0
docs/changelog/113949.yaml

@@ -0,0 +1,7 @@
+pr: 113949
+summary: Support kNN filter on nested metadata
+area: Vector Search
+type: enhancement
+issues:
+ - 128803
+ - 106994

+ 51 - 8
docs/reference/query-languages/query-dsl/query-dsl-knn-query.md

@@ -203,10 +203,19 @@ POST my-image-index/_search
 `knn` query can be used inside a nested query. The behaviour here is similar to [top level nested kNN search](docs-content://solutions/search/vector/knn.md#nested-knn-search):
 
 * kNN search over nested dense_vectors diversifies the top results over the top-level document
-* `filter`  over the top-level document metadata is supported and acts as a pre-filter
-* `filter` over `nested` field metadata is not supported
+* `filter` over both the top-level document metadata and `nested` field metadata is supported and acts as a pre-filter
+
+::::{note}
+To ensure correct results, each individual filter must be over either
+the top-level metadata or the `nested` metadata, not both. However, a
+single knn query supports multiple filters, where some filters are over
+the top-level metadata and others are over `nested` metadata.
+::::
 
-A sample query can look like below:
+
+Below is a sample query with a filter over nested metadata.
+When scoring parent documents, this query only considers vectors that
+have "paragraph.language" set to "EN".
 
 ```json
 {
@@ -215,12 +224,46 @@ A sample query can look like below:
       "path" : "paragraph",
         "query" : {
           "knn": {
-            "query_vector": [
-                0.45,
-                45
-            ],
+            "query_vector": [0.45, 0.50],
             "field": "paragraph.vector",
-            "num_candidates": 2
+            "filter": {
+              "match": {
+                "paragraph.language": "EN"
+              }
+            }
+        }
+      }
+    }
+  }
+}
+```
+
+Below is a sample query with two filters: one over nested metadata
+and another over the top-level metadata. When scoring parent documents,
+this query only considers vectors whose parent's title contains the word
+"essay" and that have "paragraph.language" set to "EN".
+
+```json
+{
+  "query" : {
+    "nested" : {
+      "path" : "paragraph",
+      "query" : {
+        "knn": {
+          "query_vector": [0.45, 0.50],
+          "field": "paragraph.vector",
+          "filter": [
+            {
+              "match": {
+                "paragraph.language": "EN"
+              }
+            },
+            {
+              "match": {
+                "title": "essay"
+              }
+            }
+          ]
         }
       }
     }

+ 153 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/100_knn_nested_search.yml

@@ -16,6 +16,8 @@ setup:
               nested:
                 type: nested
                 properties:
+                  language:
+                    type: keyword
                   paragraph_id:
                     type: keyword
                   vector:
@@ -27,6 +29,13 @@ setup:
                       type: hnsw
                       m: 16
                       ef_construction: 200
+              nested2:
+                type: nested
+                properties:
+                  key:
+                    type: keyword
+                  value:
+                    type: keyword
 
   - do:
       index:
@@ -37,8 +46,16 @@ setup:
           nested:
           - paragraph_id: 0
             vector: [230.0, 300.33, -34.8988, 15.555, -200.0]
+            language: EN
           - paragraph_id: 1
             vector: [240.0, 300, -3, 1, -20]
+            language: FR
+          nested2:
+            - key: "category"
+              value: "domestic"
+            - key: "level"
+              value: "beginner"
+
 
   - do:
       index:
@@ -49,10 +66,18 @@ setup:
           nested:
           - paragraph_id: 0
             vector: [-0.5, 100.0, -13, 14.8, -156.0]
+            language: EN
           - paragraph_id: 2
             vector: [0, 100.0, 0, 14.8, -156.0]
+            language: EN
           - paragraph_id: 3
             vector: [0, 1.0, 0, 1.8, -15.0]
+            language: FR
+          nested2:
+            - key: "category"
+              value: "wild"
+            - key: "level"
+              value: "beginner"
 
   - do:
       index:
@@ -63,6 +88,12 @@ setup:
           nested:
             - paragraph_id: 0
               vector: [0.5, 111.3, -13.0, 14.8, -156.0]
+              language: FR
+          nested2:
+            - key: "category"
+              value: "domestic"
+            - key: "level"
+              value: "advanced"
 
   - do:
       indices.refresh: {}
@@ -461,3 +492,125 @@ setup:
   - match: {hits.hits.0._id: "2"}
   - length: {hits.hits.0.inner_hits.nested.hits.hits: 1}
   - match: {hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0"}
+
+
+---
+"Filter on nested fields":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ knn_filter_on_nested_fields ]
+      test_runner_features: ["capabilities", "close_to"]
+      reason: "Capability for filtering on nested fields required"
+
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          knn:
+            boost: 2
+            field: nested.vector
+            query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
+            k: 3
+            filter: { match: { nested.language: "EN" } }
+            inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }
+
+  - match: { hits.total.value: 2 }
+  - match: { hits.hits.0._id: "2" }
+  - match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" }
+  - close_to: { hits.hits.0._score: { value: 0.0182, error: 0.0001 } }
+  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value:  0.0182, error: 0.0001 } }
+  - match: { hits.hits.1._id: "1" }
+  - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }
+
+
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          knn:
+            boost: 2
+            field: nested.vector
+            query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
+            k: 3
+            filter: { match: { nested.language: "FR" } }
+            inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }
+
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+  - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" }
+  - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+
+  # filter on both nested and parent metadata with 2 different filters
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          knn:
+            boost: 2
+            field: nested.vector
+            query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
+            k: 3
+            num_candidates: 10
+            filter: [{ match: { nested.language: "FR" }}, {term: {name: "rabbit.jpg"}} ]
+            inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+  - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+
+
+---
+"Test filter on sibling nested fields works":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ knn_filter_on_nested_fields ]
+      test_runner_features: ["capabilities", "close_to"]
+      reason: "Capability for filtering on nested fields required"
+
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          knn:
+            field: nested.vector
+            query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
+            filter:
+              nested:
+                path: nested2
+                query:
+                  bool:
+                    filter:
+                      - match:
+                          nested2.key: "category"
+                      - match:
+                          nested2.value: "domestic"
+  - match: { hits.total.value: 2}

+ 179 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/130_knn_query_nested_search.yml

@@ -16,6 +16,8 @@ setup:
               nested:
                 type: nested
                 properties:
+                  language:
+                    type: keyword
                   paragraph_id:
                     type: keyword
                   vector:
@@ -23,6 +25,17 @@ setup:
                     dims: 5
                     index: true
                     similarity: l2_norm
+                    index_options:
+                      type: hnsw
+                      m: 16
+                      ef_construction: 200
+              nested2:
+                type: nested
+                properties:
+                  key:
+                    type: keyword
+                  value:
+                    type: keyword
           aliases:
             my_alias:
               filter:
@@ -38,8 +51,15 @@ setup:
           nested:
           - paragraph_id: 0
             vector: [230.0, 300.33, -34.8988, 15.555, -200.0]
+            language: EN
           - paragraph_id: 1
             vector: [240.0, 300, -3, 1, -20]
+            language: FR
+          nested2:
+            - key: "category"
+              value: "domestic"
+            - key: "level"
+              value: "beginner"
 
   - do:
       index:
@@ -50,10 +70,19 @@ setup:
           nested:
           - paragraph_id: 0
             vector: [-0.5, 100.0, -13, 14.8, -156.0]
+            language: EN
           - paragraph_id: 2
             vector: [0, 100.0, 0, 14.8, -156.0]
+            language: EN
           - paragraph_id: 3
             vector: [0, 1.0, 0, 1.8, -15.0]
+            language: FR
+          nested2:
+            - key: "category"
+              value: "wild"
+            - key: "level"
+              value: "beginner"
+
 
   - do:
       index:
@@ -64,6 +93,12 @@ setup:
           nested:
             - paragraph_id: 0
               vector: [0.5, 111.3, -13.0, 14.8, -156.0]
+              language: FR
+          nested2:
+            - key: "category"
+              value: "domestic"
+            - key: "level"
+              value: "advanced"
 
   - do:
       indices.refresh: {}
@@ -408,3 +443,147 @@ setup:
 
   - match: {hits.total.value: 1}
   - match: {hits.hits.0._id: "2"}
+
+
+---
+"Filter on nested fields":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ knn_filter_on_nested_fields ]
+      test_runner_features: ["capabilities", "close_to"]
+      reason: "Capability for filtering on nested fields required"
+
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          query:
+            nested:
+              path: nested
+              query:
+                knn:
+                  boost: 2
+                  field: nested.vector
+                  query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
+                  k: 3
+                  filter:
+                    match:
+                      nested.language: "EN"
+              inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }
+
+  - match: {hits.total.value: 2}
+  - match: {hits.hits.0._id: "2"}
+  - match: { hits.hits.0.inner_hits.nested.hits.total.value: 2 }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.paragraph_id.0: "2" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.1.fields.nested.0.language.0: "EN" }
+  - close_to: { hits.hits.0._score: { value: 0.0182, error: 0.0001 } }
+  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0182, error: 0.0001 } }
+  - match: {hits.hits.1._id: "1"}
+  - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "EN" }
+
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          query:
+            nested:
+              path: nested
+              query:
+                knn:
+                  boost: 2
+                  field: nested.vector
+                  query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
+                  k: 3
+                  filter:
+                    match:
+                      nested.language: "FR"
+              inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language" ], _source: false }
+
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+  - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "3" }
+  - match: { hits.hits.1.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "1" }
+  - match: { hits.hits.2.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+
+  # filter on both nested and parent metadata
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          query:
+            nested:
+              path: nested
+              query:
+                knn:
+                  boost: 2
+                  field: nested.vector
+                  query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
+                  k: 10
+                  filter: [{ match: { nested.language: "FR" }}, {term: {name: "rabbit.jpg"}} ]
+              inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language" ], _source: false }
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0.inner_hits.nested.hits.total.value: 1 }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.paragraph_id.0: "0" }
+  - match: { hits.hits.0.inner_hits.nested.hits.hits.0.fields.nested.0.language.0: "FR" }
+  - close_to: { hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+  - close_to: { hits.hits.0.inner_hits.nested.hits.hits.0._score: { value: 0.0043, error: 0.0001 } }
+
+
+---
+"Test filter on sibling nested fields doesn't work":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ knn_filter_on_nested_fields ]
+      test_runner_features: ["capabilities", "close_to"]
+      reason: "Capability for filtering on nested fields required"
+
+  - do:
+      search:
+        index: test
+        body:
+          _source: false
+          query:
+            nested:
+              path: nested
+              query:
+                knn:
+                  field: nested.vector
+                  query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
+                  k: 10
+                  filter:
+                    nested:
+                      path: nested2
+                      query:
+                        bool:
+                          filter:
+                            - match:
+                                nested2.key: "category"
+                            - match:
+                                nested2.value: "domestic"
+              inner_hits: { size: 3, "fields": [ "nested.paragraph_id", "nested.language"], _source: false }
+
+  - match: { hits.total.value: 0 }
+

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -354,6 +354,7 @@ public class TransportVersions {
     public static final TransportVersion RERANK_SNIPPETS = def(9_130_0_00);
     public static final TransportVersion PIPELINE_TRACKING_INFO = def(9_131_0_00);
     public static final TransportVersion COMPONENT_TEMPLATE_TRACKING_INFO = def(9_132_0_00);
+    public static final TransportVersion TO_CHILD_BLOCK_JOIN_QUERY = def(9_133_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 2 - 1
server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

@@ -165,7 +165,8 @@ class DfsQueryPhase extends SearchPhase {
                 scoreDocs.toArray(Lucene.EMPTY_SCORE_DOCS),
                 source.knnSearch().get(i).getField(),
                 source.knnSearch().get(i).getQueryVector(),
-                source.knnSearch().get(i).getSimilarity()
+                source.knnSearch().get(i).getSimilarity(),
+                source.knnSearch().get(i).getFilterQueries()
             ).boost(source.knnSearch().get(i).boost()).queryName(source.knnSearch().get(i).queryName());
             if (nestedPath != null) {
                 query = new NestedQueryBuilder(nestedPath, query, ScoreMode.Max).innerHit(source.knnSearch().get(i).innerHit());

+ 113 - 0
server/src/main/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilder.java

@@ -0,0 +1,113 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.query;
+
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.join.BitSetProducer;
+import org.apache.lucene.search.join.ToChildBlockJoinQuery;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.lucene.search.Queries;
+import org.elasticsearch.index.mapper.NestedObjectMapper;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * A query that returns child documents whose parent matches the provided query.
+ * This query is used only for internal purposes and is not exposed to a user.
+ * <p>
+ * When converted to a Lucene query, the wrapped parent query is first restricted
+ * to parent documents and then handed to Lucene's {@link ToChildBlockJoinQuery}.
+ * The parent filter is the nested type filter of the level above the current
+ * nested scope, or the non-nested document filter when no such level exists.
+ * <p>
+ * The builder serializes only the wrapped parent query (plus the boost and query
+ * name written in {@code doXContent}); equality and hash code are likewise based
+ * on the wrapped parent query.
+ */
+public class ToChildBlockJoinQueryBuilder extends AbstractQueryBuilder<ToChildBlockJoinQueryBuilder> {
+    public static final String NAME = "to_child_block_join"; // writeable name and XContent object name
+    private final QueryBuilder parentQueryBuilder;
+
+    public ToChildBlockJoinQueryBuilder(QueryBuilder parentQueryBuilder) {
+        this.parentQueryBuilder = parentQueryBuilder;
+    }
+
+    public ToChildBlockJoinQueryBuilder(StreamInput in) throws IOException {
+        super(in);
+        parentQueryBuilder = in.readNamedWriteable(QueryBuilder.class); // mirrors doWriteTo
+    }
+
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeNamedWriteable(parentQueryBuilder);
+    }
+
+    @Override
+    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(NAME);
+        builder.field("query"); // the wrapped parent query is emitted under "query"
+        parentQueryBuilder.toXContent(builder, params);
+        boostAndQueryNameToXContent(builder);
+        builder.endObject();
+    }
+
+    @Override
+    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
+        QueryBuilder rewritten = parentQueryBuilder.rewrite(queryRewriteContext);
+        if (rewritten instanceof MatchNoneQueryBuilder) { // propagate match_none unwrapped
+            return rewritten;
+        }
+        if (rewritten != parentQueryBuilder) { // parent query changed during rewrite: re-wrap it
+            return new ToChildBlockJoinQueryBuilder(rewritten);
+        }
+        return this;
+    }
+
+    @Override
+    protected Query doToQuery(SearchExecutionContext context) throws IOException {
+        final Query parentFilter;
+        NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
+        if (originalObjectMapper != null) {
+            try {
+                // we are in a nested context, to get the parent filter we need to go up one level
+                context.nestedScope().previousLevel();
+                NestedObjectMapper objectMapper = context.nestedScope().getObjectMapper();
+                parentFilter = objectMapper == null
+                    ? Queries.newNonNestedFilter(context.indexVersionCreated())
+                    : objectMapper.nestedTypeFilter();
+            } finally {
+                context.nestedScope().nextLevel(originalObjectMapper); // re-enter the original level even if the lookup throws
+            }
+        } else {
+            // we are NOT in a nested context, coming from the top level knn search
+            parentFilter = Queries.newNonNestedFilter(context.indexVersionCreated());
+        }
+        final BitSetProducer parentBitSet = context.bitsetFilter(parentFilter);
+        Query parentQuery = parentQueryBuilder.toQuery(context);
+        // ensure that parentQuery only applies to parent docs by adding parentFilter
+        return new ToChildBlockJoinQuery(Queries.filtered(parentQuery, parentFilter), parentBitSet);
+    }
+
+    @Override
+    protected boolean doEquals(ToChildBlockJoinQueryBuilder other) {
+        return Objects.equals(parentQueryBuilder, other.parentQueryBuilder);
+    }
+
+    @Override
+    protected int doHashCode() {
+        return Objects.hash(parentQueryBuilder);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY;
+    }
+}

+ 2 - 0
server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

@@ -57,6 +57,7 @@ public final class SearchCapabilities {
     private static final String FIELD_EXISTS_QUERY_FOR_TEXT_FIELDS_NO_INDEX_OR_DV = "field_exists_query_for_text_fields_no_index_or_dv";
     private static final String SYNTHETIC_VECTORS_SETTING = "synthetic_vectors_setting";
     private static final String UPDATE_FIELD_TO_BBQ_DISK = "update_field_to_bbq_disk";
+    private static final String KNN_FILTER_ON_NESTED_FIELDS_CAPABILITY = "knn_filter_on_nested_fields";
 
     public static final Set<String> CAPABILITIES;
     static {
@@ -82,6 +83,7 @@ public final class SearchCapabilities {
         capabilities.add(DENSE_VECTOR_UPDATABLE_BBQ);
         capabilities.add(FIELD_EXISTS_QUERY_FOR_TEXT_FIELDS_NO_INDEX_OR_DV);
         capabilities.add(UPDATE_FIELD_TO_BBQ_DISK);
+        capabilities.add(KNN_FILTER_ON_NESTED_FIELDS_CAPABILITY);
         if (SYNTHETIC_VECTORS) {
             capabilities.add(SYNTHETIC_VECTORS_SETTING);
         }

+ 4 - 0
server/src/main/java/org/elasticsearch/search/SearchModule.java

@@ -66,6 +66,7 @@ import org.elasticsearch.index.query.SpanWithinQueryBuilder;
 import org.elasticsearch.index.query.TermQueryBuilder;
 import org.elasticsearch.index.query.TermsQueryBuilder;
 import org.elasticsearch.index.query.TermsSetQueryBuilder;
+import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder;
 import org.elasticsearch.index.query.WildcardQueryBuilder;
 import org.elasticsearch.index.query.WrapperQueryBuilder;
 import org.elasticsearch.index.query.functionscore.ExponentialDecayFunctionBuilder;
@@ -1187,6 +1188,9 @@ public class SearchModule {
         registerQuery(new QuerySpec<>(ExactKnnQueryBuilder.NAME, ExactKnnQueryBuilder::new, parser -> {
             throw new IllegalArgumentException("[exact_knn] queries cannot be provided directly");
         }));
+        registerQuery(new QuerySpec<>(ToChildBlockJoinQueryBuilder.NAME, ToChildBlockJoinQueryBuilder::new, parser -> {
+            throw new IllegalArgumentException("[to_child_block_join] queries cannot be provided directly");
+        }));
         registerQuery(
             new QuerySpec<>(RandomSamplingQueryBuilder.NAME, RandomSamplingQueryBuilder::new, RandomSamplingQueryBuilder::fromXContent)
         );

+ 44 - 4
server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java

@@ -17,13 +17,16 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;
 
 /**
@@ -37,6 +40,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
     private final String fieldName;
     private final VectorData queryVector;
     private final Float vectorSimilarity;
+    private final List<QueryBuilder> filterQueries;
 
     /**
      * Creates a query builder.
@@ -44,11 +48,18 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
      * @param scoreDocs the docs and scores this query should match. The array must be
      *                  sorted in order of ascending doc IDs.
      */
-    public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, VectorData queryVector, Float vectorSimilarity) {
+    public KnnScoreDocQueryBuilder(
+        ScoreDoc[] scoreDocs,
+        String fieldName,
+        VectorData queryVector,
+        Float vectorSimilarity,
+        List<QueryBuilder> filterQueries
+    ) {
         this.scoreDocs = scoreDocs;
         this.fieldName = fieldName;
         this.queryVector = queryVector;
         this.vectorSimilarity = vectorSimilarity;
+        this.filterQueries = filterQueries;
     }
 
     public KnnScoreDocQueryBuilder(StreamInput in) throws IOException {
@@ -74,6 +85,11 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
         } else {
             this.vectorSimilarity = null;
         }
+        if (in.getTransportVersion().onOrAfter(TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY)) {
+            this.filterQueries = readQueries(in);
+        } else {
+            this.filterQueries = List.of();
+        }
     }
 
     @Override
@@ -116,6 +132,9 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
             out.writeOptionalFloat(vectorSimilarity);
         }
+        if (out.getTransportVersion().onOrAfter(TransportVersions.TO_CHILD_BLOCK_JOIN_QUERY)) {
+            writeQueries(out, filterQueries);
+        }
     }
 
     @Override
@@ -135,6 +154,13 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
         if (vectorSimilarity != null) {
             builder.field("similarity", vectorSimilarity);
         }
+        if (filterQueries.isEmpty() == false) {
+            builder.startArray("filter");
+            for (QueryBuilder filterQuery : filterQueries) {
+                filterQuery.toXContent(builder, params);
+            }
+            builder.endArray();
+        }
         boostAndQueryNameToXContent(builder);
         builder.endObject();
     }
@@ -150,7 +176,20 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
             return new MatchNoneQueryBuilder("The \"" + getName() + "\" query was rewritten to a \"match_none\" query.");
         }
         if (queryRewriteContext.convertToInnerHitsRewriteContext() != null && queryVector != null && fieldName != null) {
-            return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity);
+            QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity);
+            if (filterQueries.isEmpty()) {
+                return exactKnnQuery;
+            } else {
+                BoolQueryBuilder boolQuery = new BoolQueryBuilder();
+                boolQuery.must(exactKnnQuery);
+                for (QueryBuilder filter : this.filterQueries) {
+                    // filter can be both over parents or nested docs, so add them as should clauses to a filter
+                    BoolQueryBuilder adjustedFilter = new BoolQueryBuilder().should(filter)
+                        .should(new ToChildBlockJoinQueryBuilder(filter));
+                    boolQuery.filter(adjustedFilter);
+                }
+                return boolQuery;
+            }
         }
         return super.doRewrite(queryRewriteContext);
     }
@@ -173,7 +212,8 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
         }
         return Objects.equals(fieldName, other.fieldName)
             && Objects.equals(queryVector, other.queryVector)
-            && Objects.equals(vectorSimilarity, other.vectorSimilarity);
+            && Objects.equals(vectorSimilarity, other.vectorSimilarity)
+            && Objects.equals(filterQueries, other.filterQueries);
     }
 
     @Override
@@ -183,7 +223,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
             int hashCode = Objects.hash(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
             result = 31 * result + hashCode;
         }
-        return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector));
+        return Objects.hash(result, fieldName, vectorSimilarity, Objects.hashCode(queryVector), filterQueries);
     }
 
     @Override

+ 48 - 21
server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

@@ -27,10 +27,12 @@ import org.elasticsearch.index.mapper.NestedObjectMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
 import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.index.query.ToChildBlockJoinQueryBuilder;
 import org.elasticsearch.index.search.NestedHelper;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ObjectParser;
@@ -454,9 +456,6 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
                 vectorSimilarity
             ).boost(boost).queryName(queryName).addFilterQueries(filterQueries);
         }
-        if (ctx.convertToInnerHitsRewriteContext() != null) {
-            return new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity).boost(boost).queryName(queryName);
-        }
         boolean changed = false;
         List<QueryBuilder> rewrittenQueries = new ArrayList<>(filterQueries.size());
         for (QueryBuilder query : filterQueries) {
@@ -481,6 +480,22 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
                 vectorSimilarity
             ).boost(boost).queryName(queryName).addFilterQueries(rewrittenQueries);
         }
+        if (ctx.convertToInnerHitsRewriteContext() != null) {
+            QueryBuilder exactKnnQuery = new ExactKnnQueryBuilder(queryVector, fieldName, vectorSimilarity);
+            if (filterQueries.isEmpty()) {
+                return exactKnnQuery;
+            } else {
+                BoolQueryBuilder boolQuery = new BoolQueryBuilder();
+                boolQuery.must(exactKnnQuery);
+                for (QueryBuilder filter : this.filterQueries) {
+                    // a filter can be over either parent or nested docs, so add both forms as should clauses inside a filter
+                    BoolQueryBuilder adjustedFilter = new BoolQueryBuilder().should(filter)
+                        .should(new ToChildBlockJoinQueryBuilder(filter));
+                    boolQuery.filter(adjustedFilter);
+                }
+                return boolQuery;
+            }
+        }
         return this;
     }
 
@@ -500,29 +515,27 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         if (fieldType == null) {
             return new MatchNoDocsQuery();
         }
-
         if (fieldType instanceof DenseVectorFieldType == false) {
             throw new IllegalArgumentException(
                 "[" + NAME + "] queries are only supported on [" + DenseVectorFieldMapper.CONTENT_TYPE + "] fields"
             );
         }
+        DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
 
-        BooleanQuery.Builder builder = new BooleanQuery.Builder();
+        List<Query> filtersInitial = new ArrayList<>(filterQueries.size());
         for (QueryBuilder query : this.filterQueries) {
-            builder.add(query.toQuery(context), BooleanClause.Occur.FILTER);
+            filtersInitial.add(query.toQuery(context));
         }
         if (context.getAliasFilter() != null) {
-            builder.add(context.getAliasFilter().toQuery(context), BooleanClause.Occur.FILTER);
+            filtersInitial.add(context.getAliasFilter().toQuery(context));
         }
-        BooleanQuery booleanQuery = builder.build();
-        Query filterQuery = booleanQuery.clauses().isEmpty() ? null : booleanQuery;
 
-        DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) fieldType;
         String parentPath = context.nestedLookup().getNestedParent(fieldName);
-        Float oversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample();
-
         BitSetProducer parentBitSet = null;
-        if (parentPath != null) {
+        Query filterQuery;
+        if (parentPath == null) {
+            filterQuery = buildFilterQuery(filtersInitial);
+        } else {
             final Query parentFilter;
             NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
             if (originalObjectMapper != null) {
@@ -541,19 +554,23 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
                 parentFilter = Queries.newNonNestedFilter(context.indexVersionCreated());
             }
             parentBitSet = context.bitsetFilter(parentFilter);
-            if (filterQuery != null) {
-                // We treat the provided filter as a filter over PARENT documents, so if it might match nested documents
-                // we need to adjust it.
-                if (NestedHelper.mightMatchNestedDocs(filterQuery, context)) {
-                    // Ensure that the query only returns parent documents matching `filterQuery`
-                    filterQuery = Queries.filtered(filterQuery, parentFilter);
+            List<Query> filterAdjusted = new ArrayList<>(filtersInitial.size());
+            for (Query f : filtersInitial) {
+                // If the filter might match non-nested docs, we assume it is a filter over parent docs,
+                // so we adjust it accordingly: match parent docs, then join to their child docs
+                if (NestedHelper.mightMatchNonNestedDocs(f, parentPath, context)) {
+                    // Ensure that the query only returns parent documents matching filter
+                    f = Queries.filtered(f, parentFilter);
+                    f = new ToChildBlockJoinQuery(f, parentBitSet);
                 }
-                // Now join the filterQuery & parentFilter to provide the matching blocks of children
-                filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
+                filterAdjusted.add(f);
             }
+            filterQuery = buildFilterQuery(filterAdjusted);
         }
+
         DenseVectorFieldMapper.FilterHeuristic heuristic = context.getIndexSettings().getHnswFilterHeuristic();
         boolean hnswEarlyTermination = context.getIndexSettings().getHnswEarlyTermination();
+        Float oversample = rescoreVectorBuilder() == null ? null : rescoreVectorBuilder.oversample();
         return vectorFieldType.createKnnQuery(
             queryVector,
             k,
@@ -567,6 +584,16 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         );
     }
 
+    /** Combines {@code filters} as FILTER clauses; returns {@code null} when there are none. */
+    private static Query buildFilterQuery(List<Query> filters) {
+        BooleanQuery.Builder combined = new BooleanQuery.Builder();
+        for (Query filter : filters) {
+            combined.add(filter, BooleanClause.Occur.FILTER);
+        }
+        BooleanQuery result = combined.build();
+        return result.clauses().isEmpty() ? null : result;
+    }
+
     @Override
     protected int doHashCode() {
         return Objects.hash(

+ 4 - 2
server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java

@@ -367,13 +367,15 @@ public class DfsQueryPhaseTests extends ESTestCase {
             new ScoreDoc[] { new ScoreDoc(1, 3.0f, 1), new ScoreDoc(4, 1.5f, 1) },
             "vector",
             VectorData.fromFloats(new float[] { 0.0f }),
-            null
+            null,
+            List.of()
         );
         KnnScoreDocQueryBuilder ksdqb1 = new KnnScoreDocQueryBuilder(
             new ScoreDoc[] { new ScoreDoc(1, 2.0f, 1) },
             "vector2",
             VectorData.fromFloats(new float[] { 0.0f }),
-            null
+            null,
+            List.of()
         );
         assertEquals(
             List.of(bm25, ksdqb0, ksdqb1),

+ 53 - 0
server/src/test/java/org/elasticsearch/index/query/ToChildBlockJoinQueryBuilderTests.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.query;
+
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.join.ToChildBlockJoinQuery;
+import org.elasticsearch.test.AbstractQueryTestCase;
+
+import java.io.IOException;
+
+import static org.hamcrest.CoreMatchers.instanceOf;
+
+/** Tests for {@link ToChildBlockJoinQueryBuilder}; the xContent parsing tests are overridden as no-ops. */
+public class ToChildBlockJoinQueryBuilderTests extends AbstractQueryTestCase<ToChildBlockJoinQueryBuilder> {
+    @Override
+    protected ToChildBlockJoinQueryBuilder doCreateTestQueryBuilder() {
+        final String field = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME;
+        return new ToChildBlockJoinQueryBuilder(QueryBuilders.termQuery(field, randomAlphaOfLength(10)));
+    }
+
+    @Override
+    protected void doAssertLuceneQuery(ToChildBlockJoinQueryBuilder queryBuilder, Query query, SearchExecutionContext context)
+        throws IOException {
+        assertThat(query, instanceOf(ToChildBlockJoinQuery.class));
+    }
+
+    @Override
+    public void testUnknownField() throws IOException {
+        // Not applicable: this query is never parsed from xContent
+    }
+
+    @Override
+    public void testUnknownObjectException() {
+        // Not applicable: this query is never parsed from xContent
+    }
+
+    @Override
+    public void testFromXContent() throws IOException {
+        // Not applicable: this query is never parsed from xContent
+    }
+
+    @Override
+    public void testValidOutput() {
+        // Not applicable: this query is never parsed from xContent
+    }
+}

+ 2 - 1
server/src/test/java/org/elasticsearch/search/SearchModuleTests.java

@@ -463,7 +463,8 @@ public class SearchModuleTests extends ESTestCase {
         "terms_set",
         "wildcard",
         "wrapper",
-        "distance_feature" };
+        "distance_feature",
+        "to_child_block_join" };
 
     // add here deprecated queries to make sure we log a deprecation warnings when they are used
     private static final String[] DEPRECATED_QUERIES = new String[] { "field_masking_span", "geo_polygon" };

+ 10 - 2
server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

@@ -25,6 +25,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.index.IndexVersions;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.InnerHitsRewriteContext;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
@@ -482,12 +483,19 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         queryBuilder.boost(randomFloat());
         queryBuilder.queryName(randomAlphaOfLength(10));
         QueryBuilder rewritten = queryBuilder.rewrite(innerHitsRewriteContext);
+        float queryBoost = rewritten.boost();
+        String queryName = rewritten.queryName();
+        if (queryBuilder.filterQueries().isEmpty() == false) {
+            assertTrue(rewritten instanceof BoolQueryBuilder);
+            BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) rewritten;
+            rewritten = boolQueryBuilder.must().get(0);
+        }
         assertTrue(rewritten instanceof ExactKnnQueryBuilder);
         ExactKnnQueryBuilder exactKnnQueryBuilder = (ExactKnnQueryBuilder) rewritten;
         assertEquals(queryBuilder.queryVector(), exactKnnQueryBuilder.getQuery());
         assertEquals(queryBuilder.getFieldName(), exactKnnQueryBuilder.getField());
-        assertEquals(queryBuilder.boost(), exactKnnQueryBuilder.boost(), 0.0001f);
-        assertEquals(queryBuilder.queryName(), exactKnnQueryBuilder.queryName());
+        assertEquals(queryBuilder.boost(), queryBoost, 0.0001f);
+        assertEquals(queryBuilder.queryName(), queryName);
         assertEquals(queryBuilder.getVectorSimilarity(), exactKnnQueryBuilder.vectorSimilarity());
     }
 

+ 43 - 8
server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java

@@ -24,9 +24,11 @@ import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.InnerHitsRewriteContext;
 import org.elasticsearch.index.query.MatchNoneQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.test.AbstractQueryTestCase;
@@ -54,11 +56,20 @@ public class KnnScoreDocQueryBuilderTests extends AbstractQueryTestCase<KnnScore
         for (int doc = 0; doc < numDocs; doc++) {
             scoreDocs.add(new ScoreDoc(doc, randomFloat()));
         }
+        List<QueryBuilder> filters = new ArrayList<>();
+        if (randomBoolean()) {
+            int numFilters = randomIntBetween(1, 5);
+            for (int i = 0; i < numFilters; i++) {
+                String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME;
+                filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10)));
+            }
+        }
         return new KnnScoreDocQueryBuilder(
             scoreDocs.toArray(new ScoreDoc[0]),
             randomBoolean() ? "field" : null,
             randomBoolean() ? VectorData.fromFloats(randomVector(10)) : null,
-            randomBoolean() ? randomFloat() : null
+            randomBoolean() ? randomFloat() : null,
+            filters
         );
     }
 
@@ -68,7 +79,8 @@ public class KnnScoreDocQueryBuilderTests extends AbstractQueryTestCase<KnnScore
             new ScoreDoc[] { new ScoreDoc(0, 4.25f), new ScoreDoc(5, 1.6f) },
             "field",
             VectorData.fromFloats(new float[] { 1.0f, 2.0f }),
-            null
+            null,
+            List.of()
         );
         String expected = """
             {
@@ -159,7 +171,8 @@ public class KnnScoreDocQueryBuilderTests extends AbstractQueryTestCase<KnnScore
             new ScoreDoc[0],
             randomBoolean() ? "field" : null,
             randomBoolean() ? VectorData.fromFloats(randomVector(10)) : null,
-            randomBoolean() ? randomFloat() : null
+            randomBoolean() ? randomFloat() : null,
+            List.of()
         );
         QueryRewriteContext context = randomBoolean()
             ? new InnerHitsRewriteContext(createSearchExecutionContext().getParserConfig(), System::currentTimeMillis)
@@ -170,21 +183,41 @@ public class KnnScoreDocQueryBuilderTests extends AbstractQueryTestCase<KnnScore
     public void testRewriteForInnerHits() throws IOException {
         SearchExecutionContext context = createSearchExecutionContext();
         InnerHitsRewriteContext innerHitsRewriteContext = new InnerHitsRewriteContext(context.getParserConfig(), System::currentTimeMillis);
+        List<QueryBuilder> filters = new ArrayList<>();
+        boolean hasFilters = randomBoolean();
+        if (hasFilters) {
+            int numFilters = randomIntBetween(1, 5);
+            for (int i = 0; i < numFilters; i++) {
+                String filterFieldName = randomBoolean() ? KEYWORD_FIELD_NAME : TEXT_FIELD_NAME;
+                filters.add(QueryBuilders.termQuery(filterFieldName, randomAlphaOfLength(10)));
+            }
+        }
+
         KnnScoreDocQueryBuilder queryBuilder = new KnnScoreDocQueryBuilder(
             new ScoreDoc[] { new ScoreDoc(0, 4.25f), new ScoreDoc(5, 1.6f) },
             randomAlphaOfLength(10),
             VectorData.fromFloats(randomVector(10)),
-            randomBoolean() ? randomFloat() : null
+            randomBoolean() ? randomFloat() : null,
+            filters
         );
         queryBuilder.boost(randomFloat());
         queryBuilder.queryName(randomAlphaOfLength(10));
         QueryBuilder rewritten = queryBuilder.rewrite(innerHitsRewriteContext);
+        float queryBoost = rewritten.boost();
+        String queryName = rewritten.queryName();
+
+        if (hasFilters) {
+            assertTrue(rewritten instanceof BoolQueryBuilder);
+            BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) rewritten;
+            rewritten = boolQueryBuilder.must().get(0);
+        }
+
         assertTrue(rewritten instanceof ExactKnnQueryBuilder);
         ExactKnnQueryBuilder exactKnnQueryBuilder = (ExactKnnQueryBuilder) rewritten;
         assertEquals(queryBuilder.queryVector(), exactKnnQueryBuilder.getQuery());
         assertEquals(queryBuilder.fieldName(), exactKnnQueryBuilder.getField());
-        assertEquals(queryBuilder.boost(), exactKnnQueryBuilder.boost(), 0.0001f);
-        assertEquals(queryBuilder.queryName(), exactKnnQueryBuilder.queryName());
+        assertEquals(queryBuilder.boost(), queryBoost, 0.0001f);
+        assertEquals(queryBuilder.queryName(), queryName);
         assertEquals(queryBuilder.vectorSimilarity(), exactKnnQueryBuilder.vectorSimilarity());
     }
 
@@ -228,7 +261,8 @@ public class KnnScoreDocQueryBuilderTests extends AbstractQueryTestCase<KnnScore
                     scoreDocs,
                     "field",
                     VectorData.fromFloats(randomVector(10)),
-                    null
+                    null,
+                    List.of()
                 );
                 Query query = queryBuilder.doToQuery(context);
                 final Weight w = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);
@@ -276,7 +310,8 @@ public class KnnScoreDocQueryBuilderTests extends AbstractQueryTestCase<KnnScore
                     scoreDocs,
                     "field",
                     VectorData.fromFloats(randomVector(10)),
-                    null
+                    null,
+                    List.of()
                 );
                 final Query query = queryBuilder.doToQuery(context);
                 final Weight w = query.createWeight(searcher, ScoreMode.TOP_SCORES, 1.0f);