Pārlūkot izejas kodu

Allow more than one KNN search clause (#92118)

It makes sense to allow more than one KNN search clause per individual search request. It may be that different documents have separate vector spaces or that a single doc is index with more than one vector space. In both of these scenarios, users may want to retrieve a resulting set that takes into account all their indexed vector spaces. 

A prime example here would be searching a semantic text embedding along with searching an image embedding. 


closes https://github.com/elastic/elasticsearch/issues/91187
Benjamin Trent 2 gadi atpakaļ
vecāks
revīzija
a46e532cda
27 mainītis faili ar 631 papildinājumiem un 151 dzēšanām
  1. 14 0
      docs/changelog/92118.yaml
  2. 3 2
      docs/reference/search/profile.asciidoc
  3. 58 4
      docs/reference/search/search-your-data/knn-search.asciidoc
  4. 4 4
      docs/reference/search/search.asciidoc
  5. 50 1
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml
  6. 27 27
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/370_profile.yml
  7. 21 9
      server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java
  8. 24 7
      server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java
  9. 1 1
      server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java
  10. 20 13
      server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
  11. 1 1
      server/src/main/java/org/elasticsearch/action/search/SearchRequest.java
  12. 1 1
      server/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java
  13. 49 18
      server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java
  14. 29 19
      server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java
  15. 22 5
      server/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java
  16. 1 1
      server/src/main/java/org/elasticsearch/search/profile/Profilers.java
  17. 48 8
      server/src/main/java/org/elasticsearch/search/profile/SearchProfileDfsPhaseResult.java
  18. 18 7
      server/src/main/java/org/elasticsearch/search/profile/dfs/DfsProfiler.java
  19. 142 4
      server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java
  20. 38 1
      server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java
  21. 3 3
      server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java
  22. 1 1
      server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java
  23. 2 1
      server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java
  24. 10 3
      server/src/test/java/org/elasticsearch/search/internal/ShardSearchRequestTests.java
  25. 30 1
      server/src/test/java/org/elasticsearch/search/profile/SearchProfileDfsPhaseResultTests.java
  26. 13 8
      test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java
  27. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSemanticSearchAction.java

+ 14 - 0
docs/changelog/92118.yaml

@@ -0,0 +1,14 @@
+pr: 92118
+summary: Allow more than one KNN search clause
+area: Vector Search
+type: enhancement
+issues:
+ - 91187
+highlight:
+  title: Allow more than one KNN search clause
+  body: "Some vector search scenarios require relevance ranking using a few kNN clauses,\n\
+  e.g. when ranking based on several fields, each with its own vector, or when a document \n\
+  includes a vector for the image and another vector for the text. The user may want to obtain\n\
+   relevance ranking based on a combination of all of these kNN clauses.\n\ncloses\
+    \ https://github.com/elastic/elasticsearch/issues/91187"
+  notable: true

+ 3 - 2
docs/reference/search/profile.asciidoc

@@ -1248,7 +1248,8 @@ One of the `dfs.knn` sections for a shard looks like the following:
 [source,js]
 --------------------------------------------------
 "dfs" : {
-    "knn" : {
+    "knn" : [
+        {
         "query" : [
             {
                 "type" : "DocAndScoreQuery",
@@ -1286,7 +1287,7 @@ One of the `dfs.knn` sections for a shard looks like the following:
                 "time_in_nanos" : 17163
             }
         ]
-    }
+    }   ]
 }
 --------------------------------------------------
 // TESTRESPONSE[s/^/{\n"took": $body.took,\n"timed_out": $body.timed_out,\n"_shards": $body._shards,\n"hits": $body.hits,\n"profile": {\n"shards": [ {\n"id": "$body.$_path",\n/]

+ 58 - 4
docs/reference/search/search-your-data/knn-search.asciidoc

@@ -151,7 +151,7 @@ page cache for it to be efficient. Please consult the
 configuration and sizing.
 
 To run an approximate kNN search, use the <<search-api-knn, `knn` option>>
-to search a `dense_vector` field with indexing enabled.
+to search one or more `dense_vector` fields with indexing enabled.
 
 . Explicitly map one or more `dense_vector` fields. Approximate kNN search
 requires the following mapping options:
@@ -176,6 +176,12 @@ PUT image-index
         "index": true,
         "similarity": "l2_norm"
       },
+      "title-vector": {
+        "type": "dense_vector",
+        "dims": 5,
+        "index": true,
+        "similarity": "l2_norm"
+      },
       "title": {
         "type": "text"
       },
@@ -194,11 +200,11 @@ PUT image-index
 ----
 POST image-index/_bulk?refresh=true
 { "index": { "_id": "1" } }
-{ "image-vector": [1, 5, -20], "title": "moose family", "file-type": "jpg" }
+{ "image-vector": [1, 5, -20], "title-vector": [12, 50, -10, 0, 1], "title": "moose family", "file-type": "jpg" }
 { "index": { "_id": "2" } }
-{ "image-vector": [42, 8, -15], "title": "alpine lake", "file-type": "png" }
+{ "image-vector": [42, 8, -15], "title-vector": [25, 1, 4, -12, 2], "title": "alpine lake", "file-type": "png" }
 { "index": { "_id": "3" } }
-{ "image-vector": [15, 11, 23], "title": "full moon", "file-type": "jpg" }
+{ "image-vector": [15, 11, 23], "title-vector": [1, 5, 25, 50, 20], "title": "full moon", "file-type": "jpg" }
 ...
 ----
 //TEST[continued]
@@ -406,6 +412,54 @@ over all documents that match the search. So for approximate kNN search, aggrega
 nearest documents. If the search also includes a `query`, then aggregations are calculated on the combined set of `knn`
 and `query` matches.
 
+[discrete]
+==== Search multiple kNN fields
+
+In addition to 'hybrid retrieval', you can search more than one kNN vector field at a time:
+
+[source,console]
+----
+POST image-index/_search
+{
+  "query": {
+    "match": {
+      "title": {
+        "query": "mountain lake",
+        "boost": 0.9
+      }
+    }
+  },
+  "knn": [ {
+    "field": "image-vector",
+    "query_vector": [54, 10, -2],
+    "k": 5,
+    "num_candidates": 50,
+    "boost": 0.1
+  },
+  {
+    "field": "title-vector",
+    "query_vector": [1, 20, -52, 23, 10],
+    "k": 10,
+    "num_candidates": 10,
+    "boost": 0.5
+  }],
+  "size": 10
+}
+----
+// TEST[continued]
+
+This search finds the global top `k = 5` vector matches for `image-vector` and the global `k = 10` for the `title-vector`.
+These top values are then combined with the matches from the `match` query and the top-10 documents are returned.
+The multiple `knn` entries and the `query` matches are combined through a disjunction,
+as if you took a boolean 'or' between them. The top `k` vector results represent the global nearest neighbors across
+all index shards.
+
+The scoring for a doc with the above configured boosts would be:
+
+```
+score = 0.9 * match_score + 0.1 * knn_score_image-vector + 0.5 * knn_score_title-vector
+```
+
 [discrete]
 [[knn-indexing-considerations]]
 ==== Indexing considerations

+ 4 - 4
docs/reference/search/search.asciidoc

@@ -483,14 +483,14 @@ A boost value greater than `1.0` increases the score. A boost value between
 experimental::[]
 [[search-api-knn]]
 `knn`::
-(Optional, object) 
+(Optional, object or array of objects)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn]
 +
 .Properties of `knn` object
 [%collapsible%open]
 ====
 `field`::
-(Required, string) 
+(Required, string)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-field]
 
 `filter`::
@@ -498,11 +498,11 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-field]
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-filter]
 
 `k`::
-(Required, integer) 
+(Required, integer)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-k]
 
 `num_candidates`::
-(Required, integer) 
+(Required, integer)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-num-candidates]
 
 `query_vector`::

+ 50 - 1
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml

@@ -17,6 +17,11 @@ setup:
                 dims: 5
                 index: true
                 similarity: l2_norm
+              another_vector:
+                type: dense_vector
+                dims: 5
+                index: true
+                similarity: l2_norm
 
   - do:
       index:
@@ -25,6 +30,7 @@ setup:
         body:
           name: cow.jpg
           vector: [230.0, 300.33, -34.8988, 15.555, -200.0]
+          another_vector: [130.0, 115.0, -1.02, 15.555, -100.0]
 
   - do:
       index:
@@ -33,6 +39,7 @@ setup:
         body:
           name: moose.jpg
           vector: [-0.5, 100.0, -13, 14.8, -156.0]
+          another_vector: [-0.5, 50.0, -1, 1, 120]
 
   - do:
       index:
@@ -41,6 +48,7 @@ setup:
         body:
           name: rabbit.jpg
           vector: [0.5, 111.3, -13.0, 14.8, -156.0]
+          another_vector: [-0.5, 11.0, 0, 12, 111.0]
 
   - do:
       indices.refresh: {}
@@ -66,7 +74,25 @@ setup:
 
   - match: {hits.hits.1._id: "3"}
   - match: {hits.hits.1.fields.name.0: "rabbit.jpg"}
+---
+"kNN multi-field search only":
+  - skip:
+      version: ' - 8.6.99'
+      reason: 'multi-field kNN search added to search endpoint in 8.7'
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+           - {field: vector, query_vector: [-0.5, 90.0, -10, 14.8, -156.0], k: 2, num_candidates: 3}
+           - {field: another_vector, query_vector: [-0.5, 11.0, 0, 12, 111.0], k: 2, num_candidates: 3}
+
+  - match: {hits.hits.0._id: "3"}
+  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
 
+  - match: {hits.hits.1._id: "2"}
+  - match: {hits.hits.1.fields.name.0: "moose.jpg"}
 ---
 "kNN search plus query":
   - skip:
@@ -94,7 +120,31 @@ setup:
 
   - match: {hits.hits.2._id: "3"}
   - match: {hits.hits.2.fields.name.0: "rabbit.jpg"}
+---
+"kNN multi-field search with query":
+  - skip:
+      version: ' - 8.6.99'
+      reason: 'multi-field kNN search added to search endpoint in 8.7'
+  - do:
+      search:
+        index: test
+        body:
+          fields: [ "name" ]
+          knn:
+            - {field: vector, query_vector: [-0.5, 90.0, -10, 14.8, -156.0], k: 2, num_candidates: 3}
+            - {field: another_vector, query_vector: [-0.5, 11.0, 0, 12, 111.0], k: 2, num_candidates: 3}
+          query:
+            term:
+              name: cow.jpg
+
+  - match: {hits.hits.0._id: "3"}
+  - match: {hits.hits.0.fields.name.0: "rabbit.jpg"}
 
+  - match: {hits.hits.1._id: "1"}
+  - match: {hits.hits.1.fields.name.0: "cow.jpg"}
+
+  - match: {hits.hits.2._id: "2"}
+  - match: {hits.hits.2.fields.name.0: "moose.jpg"}
 ---
 "kNN search with filter":
   - skip:
@@ -110,7 +160,6 @@ setup:
             query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
             k: 2
             num_candidates: 3
-
             filter:
               term:
                 name: "rabbit.jpg"

+ 27 - 27
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/370_profile.yml

@@ -168,8 +168,8 @@ disabling stored fields removes fetch sub phases:
 ---
 dfs knn vector profiling:
   - skip:
-      version: ' - 8.5.99'
-      reason: dfs profiling implemented in 8.6.0
+      version: ' - 8.6.99'
+      reason: multi-knn dfs profiling implemented in 8.7.0
 
   - do:
       indices.create:
@@ -205,31 +205,31 @@ dfs knn vector profiling:
             num_candidates: 100
 
   - match: { hits.total.value: 1 }
-  - match: { profile.shards.0.dfs.knn.query.0.type: "DocAndScoreQuery" }
-  - match: { profile.shards.0.dfs.knn.query.0.description: "DocAndScore[100]" }
-  - gt: { profile.shards.0.dfs.knn.query.0.time_in_nanos: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.set_min_competitive_score_count: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.set_min_competitive_score: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.match_count: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.match: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.shallow_advance_count: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.shallow_advance: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.next_doc_count: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.next_doc: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.score_count: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.score: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.compute_max_score_count: 0 }
-  - match: { profile.shards.0.dfs.knn.query.0.breakdown.compute_max_score: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.advance_count: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.advance: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.build_scorer_count: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.build_scorer: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.create_weight: 0 }
-  - gt: { profile.shards.0.dfs.knn.query.0.breakdown.create_weight_count: 0 }
-  - gt: { profile.shards.0.dfs.knn.rewrite_time: 0 }
-  - match: { profile.shards.0.dfs.knn.collector.0.name: "SimpleTopScoreDocCollector" }
-  - match: { profile.shards.0.dfs.knn.collector.0.reason: "search_top_hits" }
-  - gt: { profile.shards.0.dfs.knn.collector.0.time_in_nanos: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.type: "DocAndScoreQuery" }
+  - match: { profile.shards.0.dfs.knn.0.query.0.description: "DocAndScore[100]" }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.time_in_nanos: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.set_min_competitive_score_count: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.set_min_competitive_score: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.match_count: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.match: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.shallow_advance_count: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.shallow_advance: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.next_doc_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.next_doc: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.score_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.score: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.compute_max_score_count: 0 }
+  - match: { profile.shards.0.dfs.knn.0.query.0.breakdown.compute_max_score: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.advance_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.advance: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.build_scorer_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.build_scorer: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.create_weight: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.query.0.breakdown.create_weight_count: 0 }
+  - gt: { profile.shards.0.dfs.knn.0.rewrite_time: 0 }
+  - match: { profile.shards.0.dfs.knn.0.collector.0.name: "SimpleTopScoreDocCollector" }
+  - match: { profile.shards.0.dfs.knn.0.collector.0.reason: "search_top_hits" }
+  - gt: { profile.shards.0.dfs.knn.0.collector.0.time_in_nanos: 0 }
 
 ---
 dfs profile for search with dfs_query_then_fetch:

+ 21 - 9
server/src/internalClusterTest/java/org/elasticsearch/search/profile/dfs/DfsProfilerIT.java

@@ -67,13 +67,23 @@ public class DfsProfilerIT extends ESIntegTestCase {
         for (int i = 0; i < iters; i++) {
             QueryBuilder q = randomQueryBuilder(List.of(textField), List.of(numericField), numDocs, 3);
             logger.info("Query: {}", q);
-
             SearchResponse resp = client().prepareSearch()
                 .setQuery(q)
                 .setTrackTotalHits(true)
                 .setProfile(true)
                 .setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
-                .setKnnSearch(new KnnSearchBuilder(vectorField, new float[] { randomFloat(), randomFloat(), randomFloat() }, 5, 50))
+                .setKnnSearch(
+                    randomList(
+                        2,
+                        5,
+                        () -> new KnnSearchBuilder(
+                            vectorField,
+                            new float[] { randomFloat(), randomFloat(), randomFloat() },
+                            randomIntBetween(5, 10),
+                            50
+                        )
+                    )
+                )
                 .get();
 
             assertNotNull("Profile response element should not be null", resp.getProfileResults());
@@ -91,17 +101,19 @@ public class DfsProfilerIT extends ESIntegTestCase {
                 }
                 SearchProfileDfsPhaseResult searchProfileDfsPhaseResult = shard.getValue().getSearchProfileDfsPhaseResult();
                 assertThat(searchProfileDfsPhaseResult, is(notNullValue()));
-                for (ProfileResult result : searchProfileDfsPhaseResult.getQueryProfileShardResult().getQueryResults()) {
-                    assertNotNull(result.getQueryName());
-                    assertNotNull(result.getLuceneDescription());
+                for (QueryProfileShardResult queryProfileShardResult : searchProfileDfsPhaseResult.getQueryProfileShardResult()) {
+                    for (ProfileResult result : queryProfileShardResult.getQueryResults()) {
+                        assertNotNull(result.getQueryName());
+                        assertNotNull(result.getLuceneDescription());
+                        assertThat(result.getTime(), greaterThan(0L));
+                    }
+                    CollectorResult result = queryProfileShardResult.getCollectorResult();
+                    assertThat(result.getName(), is(not(emptyOrNullString())));
+                    assertThat(result.getTime(), greaterThan(0L));
                     assertThat(result.getTime(), greaterThan(0L));
                 }
-                CollectorResult result = searchProfileDfsPhaseResult.getQueryProfileShardResult().getCollectorResult();
-                assertThat(result.getName(), is(not(emptyOrNullString())));
-                assertThat(result.getTime(), greaterThan(0L));
                 ProfileResult statsResult = searchProfileDfsPhaseResult.getDfsShardResult();
                 assertThat(statsResult.getQueryName(), equalTo("statistics"));
-                assertThat(result.getTime(), greaterThan(0L));
             }
         }
     }

+ 24 - 7
server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

@@ -38,7 +38,7 @@ final class DfsQueryPhase extends SearchPhase {
     private final QueryPhaseResultConsumer queryResult;
     private final List<DfsSearchResult> searchResults;
     private final AggregatedDfs dfs;
-    private final DfsKnnResults knnResults;
+    private final List<DfsKnnResults> knnResults;
     private final Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
     private final SearchPhaseContext context;
     private final SearchTransportService searchTransportService;
@@ -47,7 +47,7 @@ final class DfsQueryPhase extends SearchPhase {
     DfsQueryPhase(
         List<DfsSearchResult> searchResults,
         AggregatedDfs dfs,
-        DfsKnnResults knnResults,
+        List<DfsKnnResults> knnResults,
         QueryPhaseResultConsumer queryResult,
         Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
         SearchPhaseContext context
@@ -132,20 +132,37 @@ final class DfsQueryPhase extends SearchPhase {
 
     private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
         SearchSourceBuilder source = request.source();
-        if (source == null || source.knnSearch() == null) {
+        if (source == null || source.knnSearch().isEmpty()) {
             return request;
         }
 
         List<ScoreDoc> scoreDocs = new ArrayList<>();
-        for (ScoreDoc scoreDoc : knnResults.scoreDocs()) {
-            if (scoreDoc.shardIndex == request.shardRequestIndex()) {
-                scoreDocs.add(scoreDoc);
+        for (DfsKnnResults dfsKnnResults : knnResults) {
+            for (ScoreDoc scoreDoc : dfsKnnResults.scoreDocs()) {
+                if (scoreDoc.shardIndex == request.shardRequestIndex()) {
+                    scoreDocs.add(scoreDoc);
+                }
             }
         }
         scoreDocs.sort(Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
+        // It is possible that the different results refer to the same doc.
+        for (int i = 0; i < scoreDocs.size() - 1; i++) {
+            ScoreDoc scoreDoc = scoreDocs.get(i);
+            int j = i + 1;
+            for (; j < scoreDocs.size(); j++) {
+                ScoreDoc otherScoreDoc = scoreDocs.get(j);
+                if (otherScoreDoc.doc != scoreDoc.doc) {
+                    break;
+                }
+                scoreDoc.score += otherScoreDoc.score;
+            }
+            if (j > i + 1) {
+                scoreDocs.subList(i + 1, j).clear();
+            }
+        }
         KnnScoreDocQueryBuilder knnQuery = new KnnScoreDocQueryBuilder(scoreDocs.toArray(new ScoreDoc[0]));
 
-        SearchSourceBuilder newSource = source.shallowCopy().knnSearch(null);
+        SearchSourceBuilder newSource = source.shallowCopy().knnSearch(List.of());
         if (source.query() == null) {
             newSource.query(knnQuery);
         } else {

+ 1 - 1
server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

@@ -87,7 +87,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
     protected SearchPhase getNextPhase(final SearchPhaseResults<DfsSearchResult> results, SearchPhaseContext context) {
         final List<DfsSearchResult> dfsSearchResults = results.getAtomicArray().asList();
         final AggregatedDfs aggregatedDfs = SearchPhaseController.aggregateDfs(dfsSearchResults);
-        final DfsKnnResults mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults);
+        final List<DfsKnnResults> mergedKnnResults = SearchPhaseController.mergeKnnResults(getRequest(), dfsSearchResults);
 
         return new DfsQueryPhase(
             dfsSearchResults,

+ 20 - 13
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

@@ -46,7 +46,6 @@ import org.elasticsearch.search.query.QuerySearchResult;
 import org.elasticsearch.search.suggest.Suggest;
 import org.elasticsearch.search.suggest.Suggest.Suggestion;
 import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
-import org.elasticsearch.search.vectors.KnnSearchBuilder;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -128,26 +127,34 @@ public final class SearchPhaseController {
         return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
     }
 
-    public static DfsKnnResults mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
+    public static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
         if (request.hasKnnSearch() == false) {
             return null;
         }
 
-        List<TopDocs> topDocs = new ArrayList<>();
+        List<List<TopDocs>> topDocsLists = new ArrayList<>(request.source().knnSearch().size());
+        for (int i = 0; i < request.source().knnSearch().size(); i++) {
+            topDocsLists.add(new ArrayList<>());
+        }
+
         for (DfsSearchResult dfsSearchResult : dfsSearchResults) {
             if (dfsSearchResult.knnResults() != null) {
-                ScoreDoc[] scoreDocs = dfsSearchResult.knnResults().scoreDocs();
-                TotalHits totalHits = new TotalHits(scoreDocs.length, Relation.EQUAL_TO);
-
-                TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
-                setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
-                topDocs.add(shardTopDocs);
+                for (int i = 0; i < dfsSearchResult.knnResults().size(); i++) {
+                    DfsKnnResults knnResults = dfsSearchResult.knnResults().get(i);
+                    ScoreDoc[] scoreDocs = knnResults.scoreDocs();
+                    TotalHits totalHits = new TotalHits(scoreDocs.length, Relation.EQUAL_TO);
+                    TopDocs shardTopDocs = new TopDocs(totalHits, scoreDocs);
+                    setShardIndex(shardTopDocs, dfsSearchResult.getShardIndex());
+                    topDocsLists.get(i).add(shardTopDocs);
+                }
             }
         }
-
-        KnnSearchBuilder knnSearch = request.source().knnSearch();
-        TopDocs mergedTopDocs = TopDocs.merge(knnSearch.k(), topDocs.toArray(new TopDocs[0]));
-        return new DfsKnnResults(mergedTopDocs.scoreDocs);
+        List<DfsKnnResults> mergedResults = new ArrayList<>(request.source().knnSearch().size());
+        for (int i = 0; i < request.source().knnSearch().size(); i++) {
+            TopDocs mergedTopDocs = TopDocs.merge(request.source().knnSearch().get(i).k(), topDocsLists.get(i).toArray(new TopDocs[0]));
+            mergedResults.add(new DfsKnnResults(mergedTopDocs.scoreDocs));
+        }
+        return mergedResults;
     }
 
     /**

+ 1 - 1
server/src/main/java/org/elasticsearch/action/search/SearchRequest.java

@@ -743,7 +743,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
      * @return true if the request contains kNN search
      */
     public boolean hasKnnSearch() {
-        return source != null && source.knnSearch() != null;
+        return source != null && source.knnSearch().isEmpty() == false;
     }
 
     public int resolveTrackTotalHitsUpTo() {

+ 1 - 1
server/src/main/java/org/elasticsearch/action/search/SearchRequestBuilder.java

@@ -177,7 +177,7 @@ public class SearchRequestBuilder extends ActionRequestBuilder<SearchRequest, Se
      * Defines a kNN search. If a query is also provided, the kNN hits
      * are combined with the query hits.
      */
-    public SearchRequestBuilder setKnnSearch(KnnSearchBuilder knnSearch) {
+    public SearchRequestBuilder setKnnSearch(List<KnnSearchBuilder> knnSearch) {
         sourceBuilder().knnSearch(knnSearch);
         return this;
     }

+ 49 - 18
server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java

@@ -50,6 +50,7 @@ import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentFragment;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParseException;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentType;
 
@@ -130,7 +131,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
 
     private QueryBuilder postQueryBuilder;
 
-    private KnnSearchBuilder knnSearch;
+    private List<KnnSearchBuilder> knnSearch = new ArrayList<>();
 
     private int from = -1;
 
@@ -249,7 +250,12 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             runtimeMappings = in.readMap();
         }
         if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
-            knnSearch = in.readOptionalWriteable(KnnSearchBuilder::new);
+            if (in.getVersion().before(Version.V_8_7_0)) {
+                KnnSearchBuilder searchBuilder = in.readOptionalWriteable(KnnSearchBuilder::new);
+                knnSearch = searchBuilder != null ? List.of(searchBuilder) : List.of();
+            } else {
+                knnSearch = in.readList(KnnSearchBuilder::new);
+            }
         }
     }
 
@@ -319,7 +325,18 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             }
         }
         if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
-            out.writeOptionalWriteable(knnSearch);
+            if (out.getVersion().before(Version.V_8_7_0)) {
+                if (knnSearch.size() > 1) {
+                    throw new IllegalArgumentException(
+                        "Versions before 8.7.0 don't support multiple [knn] search clauses and search was sent to ["
+                            + out.getVersion()
+                            + "]"
+                    );
+                }
+                out.writeOptionalWriteable(knnSearch.isEmpty() ? null : knnSearch.get(0));
+            } else {
+                out.writeCollection(knnSearch);
+            }
         }
     }
 
@@ -361,16 +378,16 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
      * Defines a kNN search. If a query is also provided, the kNN hits
      * are combined with the query hits.
      */
-    public SearchSourceBuilder knnSearch(KnnSearchBuilder knnSearch) {
-        this.knnSearch = knnSearch;
+    public SearchSourceBuilder knnSearch(List<KnnSearchBuilder> knnSearch) {
+        this.knnSearch = Objects.requireNonNull(knnSearch);
         return this;
     }
 
     /**
      * An optional kNN search definition.
      */
-    public KnnSearchBuilder knnSearch() {
-        return knnSearch;
+    public List<KnnSearchBuilder> knnSearch() {
+        return Collections.unmodifiableList(knnSearch);
     }
 
     /**
@@ -986,7 +1003,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
      * @return true if the source only has suggest
      */
     public boolean isSuggestOnly() {
-        return suggestBuilder != null && queryBuilder == null && knnSearch == null && aggregations == null;
+        return suggestBuilder != null && queryBuilder == null && knnSearch.isEmpty() && aggregations == null;
     }
 
     /**
@@ -1039,10 +1056,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
         if (this.postQueryBuilder != null) {
             postQueryBuilder = this.postQueryBuilder.rewrite(context);
         }
-        KnnSearchBuilder knnSearch = null;
-        if (this.knnSearch != null) {
-            knnSearch = this.knnSearch.rewrite(context);
-        }
+        List<KnnSearchBuilder> knnSearch = Rewriteable.rewrite(this.knnSearch, context);
         AggregatorFactories.Builder aggregations = null;
         if (this.aggregations != null) {
             aggregations = this.aggregations.rewrite(context);
@@ -1092,7 +1106,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
     private SearchSourceBuilder shallowCopy(
         QueryBuilder queryBuilder,
         QueryBuilder postQueryBuilder,
-        KnnSearchBuilder knnSearch,
+        List<KnnSearchBuilder> knnSearch,
         AggregatorFactories.Builder aggregations,
         SliceBuilder slice,
         List<SortBuilder<?>> sorts,
@@ -1243,7 +1257,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
                     postQueryBuilder = parseTopLevelQuery(parser, searchUsage::trackQueryUsage);
                     searchUsage.trackSectionUsage(POST_FILTER_FIELD.getPreferredName());
                 } else if (KNN_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
-                    knnSearch = KnnSearchBuilder.fromXContent(parser);
+                    knnSearch = List.of(KnnSearchBuilder.fromXContent(parser));
                     searchUsage.trackSectionUsage(KNN_FIELD.getPreferredName());
                 } else if (_SOURCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                     fetchSourceContext = FetchSourceContext.fromXContent(parser);
@@ -1420,6 +1434,19 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
                 } else if (SEARCH_AFTER.match(currentFieldName, parser.getDeprecationHandler())) {
                     searchAfterBuilder = SearchAfterBuilder.fromXContent(parser);
                     searchUsage.trackSectionUsage(SEARCH_AFTER.getPreferredName());
+                } else if (KNN_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+                    knnSearch = new ArrayList<>();
+                    while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
+                        if (token == XContentParser.Token.START_OBJECT) {
+                            knnSearch.add(KnnSearchBuilder.fromXContent(parser));
+                        } else {
+                            throw new XContentParseException(
+                                parser.getTokenLocation(),
+                                "malformed knn format, within the knn search array only objects are allowed; found " + token
+                            );
+                        }
+                    }
+                    searchUsage.trackSectionUsage(KNN_FIELD.getPreferredName());
                 } else {
                     throw new ParsingException(
                         parser.getTokenLocation(),
@@ -1469,10 +1496,14 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
             builder.field(POST_FILTER_FIELD.getPreferredName(), postQueryBuilder);
         }
 
-        if (knnSearch != null) {
-            builder.startObject(KNN_FIELD.getPreferredName());
-            knnSearch.toXContent(builder, params);
-            builder.endObject();
+        if (knnSearch.isEmpty() == false) {
+            builder.startArray(KNN_FIELD.getPreferredName());
+            for (KnnSearchBuilder knnSearchBuilder : knnSearch) {
+                builder.startObject();
+                knnSearchBuilder.toXContent(builder, params);
+                builder.endObject();
+            }
+            builder.endArray();
         }
 
         if (minScore != null) {

+ 29 - 19
server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java

@@ -23,12 +23,14 @@ import org.elasticsearch.search.profile.dfs.DfsProfiler;
 import org.elasticsearch.search.profile.dfs.DfsTimingType;
 import org.elasticsearch.search.profile.query.CollectorResult;
 import org.elasticsearch.search.profile.query.InternalProfileCollector;
+import org.elasticsearch.search.profile.query.QueryProfiler;
 import org.elasticsearch.search.rescore.RescoreContext;
 import org.elasticsearch.search.vectors.KnnSearchBuilder;
 import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
 import org.elasticsearch.tasks.TaskCancelledException;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -156,34 +158,42 @@ public class DfsPhase {
 
     private void executeKnnVectorQuery(SearchContext context) throws IOException {
         SearchSourceBuilder source = context.request().source();
-        if (source == null || source.knnSearch() == null) {
+        if (source == null || source.knnSearch().isEmpty()) {
             return;
         }
 
         SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
-        KnnSearchBuilder knnSearch = context.request().source().knnSearch();
-        KnnVectorQueryBuilder knnVectorQueryBuilder = knnSearch.toQueryBuilder();
+        List<KnnSearchBuilder> knnSearch = context.request().source().knnSearch();
+        List<KnnVectorQueryBuilder> knnVectorQueryBuilders = knnSearch.stream().map(KnnSearchBuilder::toQueryBuilder).toList();
 
         if (context.request().getAliasFilter().getQueryBuilder() != null) {
-            knnVectorQueryBuilder.addFilterQuery(context.request().getAliasFilter().getQueryBuilder());
+            for (KnnVectorQueryBuilder knnVectorQueryBuilder : knnVectorQueryBuilders) {
+                knnVectorQueryBuilder.addFilterQuery(context.request().getAliasFilter().getQueryBuilder());
+            }
         }
-
-        Query query = searchExecutionContext.toQuery(knnVectorQueryBuilder).query();
-        TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(knnSearch.k(), Integer.MAX_VALUE);
-        Collector collector = topScoreDocCollector;
-
+        List<DfsKnnResults> knnResults = new ArrayList<>(knnVectorQueryBuilders.size());
+        for (int i = 0; i < knnSearch.size(); i++) {
+            Query knnQuery = searchExecutionContext.toQuery(knnVectorQueryBuilders.get(i)).query();
+            TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(knnSearch.get(i).k(), Integer.MAX_VALUE);
+            Collector collector = topScoreDocCollector;
+            if (context.getProfilers() != null) {
+                InternalProfileCollector ipc = new InternalProfileCollector(
+                    topScoreDocCollector,
+                    CollectorResult.REASON_SEARCH_TOP_HITS,
+                    List.of()
+                );
+                QueryProfiler knnProfiler = context.getProfilers().getDfsProfiler().addQueryProfiler(ipc);
+                collector = ipc;
+                // Set the current searcher profiler to gather query profiling information for gathering top K docs
+                context.searcher().setProfiler(knnProfiler);
+            }
+            context.searcher().search(knnQuery, collector);
+            knnResults.add(new DfsKnnResults(topScoreDocCollector.topDocs().scoreDocs));
+        }
+        // Set profiler back after running KNN searches
         if (context.getProfilers() != null) {
-            InternalProfileCollector ipc = new InternalProfileCollector(
-                topScoreDocCollector,
-                CollectorResult.REASON_SEARCH_TOP_HITS,
-                List.of()
-            );
-            context.getProfilers().getDfsProfiler().setCollector(ipc);
-            collector = ipc;
+            context.searcher().setProfiler(context.getProfilers().getCurrentQueryProfiler());
         }
-
-        context.searcher().search(query, collector);
-        DfsKnnResults knnResults = new DfsKnnResults(topScoreDocCollector.topDocs().scoreDocs);
         context.dfsResult().knnResults(knnResults);
     }
 }

+ 22 - 5
server/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java

@@ -23,6 +23,7 @@ import org.elasticsearch.search.profile.SearchProfileDfsPhaseResult;
 
 import java.io.IOException;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 public class DfsSearchResult extends SearchPhaseResult {
@@ -32,7 +33,7 @@ public class DfsSearchResult extends SearchPhaseResult {
     private Term[] terms;
     private TermStatistics[] termStatistics;
     private Map<String, CollectionStatistics> fieldStatistics = new HashMap<>();
-    private DfsKnnResults knnResults;
+    private List<DfsKnnResults> knnResults;
     private int maxDoc;
     private SearchProfileDfsPhaseResult searchProfileDfsPhaseResult;
 
@@ -56,7 +57,12 @@ public class DfsSearchResult extends SearchPhaseResult {
             setShardSearchRequest(in.readOptionalWriteable(ShardSearchRequest::new));
         }
         if (in.getVersion().onOrAfter(Version.V_8_4_0)) {
-            knnResults = in.readOptionalWriteable(DfsKnnResults::new);
+            if (in.getVersion().onOrAfter(Version.V_8_7_0)) {
+                knnResults = in.readOptionalList(DfsKnnResults::new);
+            } else {
+                DfsKnnResults results = in.readOptionalWriteable(DfsKnnResults::new);
+                knnResults = results != null ? List.of(results) : List.of();
+            }
         }
         if (in.getVersion().onOrAfter(Version.V_8_6_0)) {
             searchProfileDfsPhaseResult = in.readOptionalWriteable(SearchProfileDfsPhaseResult::new);
@@ -89,7 +95,7 @@ public class DfsSearchResult extends SearchPhaseResult {
         return this;
     }
 
-    public DfsSearchResult knnResults(DfsKnnResults knnResults) {
+    public DfsSearchResult knnResults(List<DfsKnnResults> knnResults) {
         this.knnResults = knnResults;
         return this;
     }
@@ -111,7 +117,7 @@ public class DfsSearchResult extends SearchPhaseResult {
         return fieldStatistics;
     }
 
-    public DfsKnnResults knnResults() {
+    public List<DfsKnnResults> knnResults() {
         return knnResults;
     }
 
@@ -133,7 +139,18 @@ public class DfsSearchResult extends SearchPhaseResult {
             out.writeOptionalWriteable(getShardSearchRequest());
         }
         if (out.getVersion().onOrAfter(Version.V_8_4_0)) {
-            out.writeOptionalWriteable(knnResults);
+            if (out.getVersion().onOrAfter(Version.V_8_7_0)) {
+                out.writeOptionalCollection(knnResults);
+            } else {
+                if (knnResults != null && knnResults.size() > 1) {
+                    throw new IllegalArgumentException(
+                        "Versions before 8.7.0 don't support multiple [knn] search clauses and search was sent to ["
+                            + out.getVersion()
+                            + "]"
+                    );
+                }
+                out.writeOptionalWriteable(knnResults == null || knnResults.isEmpty() ? null : knnResults.get(0));
+            }
         }
         if (out.getVersion().onOrAfter(Version.V_8_6_0)) {
             out.writeOptionalWriteable(searchProfileDfsPhaseResult);

+ 1 - 1
server/src/main/java/org/elasticsearch/search/profile/Profilers.java

@@ -66,7 +66,7 @@ public final class Profilers {
      */
     public DfsProfiler getDfsProfiler() {
         if (dfsProfiler == null) {
-            dfsProfiler = new DfsProfiler(getCurrentQueryProfiler());
+            dfsProfiler = new DfsProfiler();
         }
 
         return dfsProfiler;

+ 48 - 8
server/src/main/java/org/elasticsearch/search/profile/SearchProfileDfsPhaseResult.java

@@ -8,10 +8,12 @@
 
 package org.elasticsearch.search.profile;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.search.profile.query.CollectorResult;
 import org.elasticsearch.search.profile.query.QueryProfileShardResult;
 import org.elasticsearch.xcontent.InstantiatingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
@@ -21,6 +23,8 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;
 
 import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -28,23 +32,35 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
 public class SearchProfileDfsPhaseResult implements Writeable, ToXContentObject {
 
     private final ProfileResult dfsShardResult;
-    private final QueryProfileShardResult queryProfileShardResult;
+    private final List<QueryProfileShardResult> queryProfileShardResult;
 
     @ParserConstructor
-    public SearchProfileDfsPhaseResult(@Nullable ProfileResult dfsShardResult, @Nullable QueryProfileShardResult queryProfileShardResult) {
+    public SearchProfileDfsPhaseResult(
+        @Nullable ProfileResult dfsShardResult,
+        @Nullable List<QueryProfileShardResult> queryProfileShardResult
+    ) {
         this.dfsShardResult = dfsShardResult;
         this.queryProfileShardResult = queryProfileShardResult;
     }
 
     public SearchProfileDfsPhaseResult(StreamInput in) throws IOException {
         dfsShardResult = in.readOptionalWriteable(ProfileResult::new);
-        queryProfileShardResult = in.readOptionalWriteable(QueryProfileShardResult::new);
+        if (in.getVersion().onOrAfter(Version.V_8_7_0)) {
+            queryProfileShardResult = in.readOptionalList(QueryProfileShardResult::new);
+        } else {
+            QueryProfileShardResult singleResult = in.readOptionalWriteable(QueryProfileShardResult::new);
+            queryProfileShardResult = singleResult != null ? List.of(singleResult) : null;
+        }
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeOptionalWriteable(dfsShardResult);
-        out.writeOptionalWriteable(queryProfileShardResult);
+        if (out.getVersion().onOrAfter(Version.V_8_7_0)) {
+            out.writeOptionalCollection(queryProfileShardResult);
+        } else {
+            out.writeOptionalWriteable(combineQueryProfileShardResults());
+        }
     }
 
     private static final ParseField STATISTICS = new ParseField("statistics");
@@ -58,7 +74,7 @@ public class SearchProfileDfsPhaseResult implements Writeable, ToXContentObject
             SearchProfileDfsPhaseResult.class
         );
         parser.declareObject(optionalConstructorArg(), (p, c) -> ProfileResult.fromXContent(p), STATISTICS);
-        parser.declareObject(optionalConstructorArg(), (p, c) -> QueryProfileShardResult.fromXContent(p), KNN);
+        parser.declareObjectArray(optionalConstructorArg(), (p, c) -> QueryProfileShardResult.fromXContent(p), KNN);
         PARSER = parser.build();
     }
 
@@ -74,8 +90,11 @@ public class SearchProfileDfsPhaseResult implements Writeable, ToXContentObject
             dfsShardResult.toXContent(builder, params);
         }
         if (queryProfileShardResult != null) {
-            builder.field(KNN.getPreferredName());
-            queryProfileShardResult.toXContent(builder, params);
+            builder.startArray(KNN.getPreferredName());
+            for (QueryProfileShardResult qpsr : queryProfileShardResult) {
+                qpsr.toXContent(builder, params);
+            }
+            builder.endArray();
         }
         builder.endObject();
         return builder;
@@ -108,7 +127,28 @@ public class SearchProfileDfsPhaseResult implements Writeable, ToXContentObject
         return dfsShardResult;
     }
 
-    public QueryProfileShardResult getQueryProfileShardResult() {
+    public List<QueryProfileShardResult> getQueryProfileShardResult() {
         return queryProfileShardResult;
     }
+
+    QueryProfileShardResult combineQueryProfileShardResults() {
+        if (queryProfileShardResult == null) {
+            return null;
+        }
+        List<CollectorResult> subCollectorResults = new ArrayList<>(queryProfileShardResult.size());
+        long totalRewriteTime = 0;
+        long totalCollectionTime = 0;
+        List<ProfileResult> profileResults = new ArrayList<>();
+        for (QueryProfileShardResult queryProfiler : queryProfileShardResult) {
+            totalRewriteTime += queryProfiler.getRewriteTime();
+            profileResults.addAll(queryProfiler.getQueryResults());
+            subCollectorResults.add(queryProfiler.getCollectorResult());
+            totalCollectionTime += queryProfiler.getCollectorResult().getTime();
+        }
+        return new QueryProfileShardResult(
+            profileResults,
+            totalRewriteTime,
+            new CollectorResult("KnnQueryCollector", CollectorResult.REASON_SEARCH_MULTI, totalCollectionTime, subCollectorResults)
+        );
+    }
 }

+ 18 - 7
server/src/main/java/org/elasticsearch/search/profile/dfs/DfsProfiler.java

@@ -15,6 +15,7 @@ import org.elasticsearch.search.profile.query.InternalProfileCollector;
 import org.elasticsearch.search.profile.query.QueryProfileShardResult;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 
+import java.util.ArrayList;
 import java.util.List;
 
 /**
@@ -28,12 +29,11 @@ public class DfsProfiler extends AbstractProfileBreakdown<DfsTimingType> {
     private long startTime;
     private long totalTime;
 
-    private final QueryProfiler queryProfiler;
+    private final List<QueryProfiler> knnQueryProfilers = new ArrayList<>();
     private boolean collectorSet = false;
 
-    public DfsProfiler(QueryProfiler queryProfiler) {
+    public DfsProfiler() {
         super(DfsTimingType.class);
-        this.queryProfiler = queryProfiler;
     }
 
     public void start() {
@@ -52,9 +52,12 @@ public class DfsProfiler extends AbstractProfileBreakdown<DfsTimingType> {
         getTimer(dfsTimingType).stop();
     }
 
-    public void setCollector(InternalProfileCollector collector) {
+    public QueryProfiler addQueryProfiler(InternalProfileCollector collector) {
+        QueryProfiler queryProfiler = new QueryProfiler();
         queryProfiler.setCollector(collector);
+        knnQueryProfilers.add(queryProfiler);
         collectorSet = true;
+        return queryProfiler;
     }
 
     public SearchProfileDfsPhaseResult buildDfsPhaseResults() {
@@ -66,9 +69,17 @@ public class DfsProfiler extends AbstractProfileBreakdown<DfsTimingType> {
             totalTime,
             List.of()
         );
-        QueryProfileShardResult queryProfileShardResult = collectorSet
-            ? new QueryProfileShardResult(queryProfiler.getTree(), queryProfiler.getRewriteTime(), queryProfiler.getCollector())
-            : null;
+        final List<QueryProfileShardResult> queryProfileShardResult;
+        if (collectorSet) {
+            queryProfileShardResult = new ArrayList<>(knnQueryProfilers.size());
+            for (QueryProfiler queryProfiler : knnQueryProfilers) {
+                queryProfileShardResult.add(
+                    new QueryProfileShardResult(queryProfiler.getTree(), queryProfiler.getRewriteTime(), queryProfiler.getCollector())
+                );
+            }
+        } else {
+            queryProfileShardResult = null;
+        }
         return new SearchProfileDfsPhaseResult(dfsProfileResult, queryProfileShardResult);
     }
 }

+ 142 - 4
server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java

@@ -13,6 +13,8 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.indices.TermsLookup;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.metrics.InternalStats;
 import org.elasticsearch.search.vectors.KnnSearchBuilder;
 import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
 import org.elasticsearch.test.ESSingleNodeTestCase;
@@ -20,8 +22,11 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 
 import java.io.IOException;
+import java.util.List;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
 
 public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
     private static final int VECTOR_DIMENSION = 10;
@@ -56,7 +61,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         float[] queryVector = randomVector();
         KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
         SearchResponse response = client().prepareSearch("index")
-            .setKnnSearch(knnSearch)
+            .setKnnSearch(List.of(knnSearch))
             .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
             .addFetchField("*")
             .setSize(10)
@@ -101,7 +106,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
             QueryBuilders.termsQuery("field", "second")
         );
-        SearchResponse response = client().prepareSearch("index").setKnnSearch(knnSearch).addFetchField("*").setSize(10).get();
+        SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).addFetchField("*").setSize(10).get();
 
         assertHitCount(response, 5);
         assertEquals(5, response.getHits().getHits().length);
@@ -144,12 +149,145 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).addFilterQuery(
             QueryBuilders.termsLookupQuery("field", new TermsLookup("index", "lookup-doc", "other-field"))
         );
-        SearchResponse response = client().prepareSearch("index").setKnnSearch(knnSearch).setSize(10).get();
+        SearchResponse response = client().prepareSearch("index").setKnnSearch(List.of(knnSearch)).setSize(10).get();
 
         assertHitCount(response, 5);
         assertEquals(5, response.getHits().getHits().length);
     }
 
+    public void testMultiKnnClauses() throws IOException {
+        // This tests the recall from vectors being searched in different docs
+        int numShards = 1 + randomInt(3);
+        Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
+
+        XContentBuilder builder = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("properties")
+            .startObject("vector")
+            .field("type", "dense_vector")
+            .field("dims", VECTOR_DIMENSION)
+            .field("index", true)
+            .field("similarity", "l2_norm")
+            .endObject()
+            .startObject("vector_2")
+            .field("type", "dense_vector")
+            .field("dims", VECTOR_DIMENSION)
+            .field("index", true)
+            .field("similarity", "l2_norm")
+            .endObject()
+            .startObject("text")
+            .field("type", "text")
+            .endObject()
+            .startObject("number")
+            .field("type", "long")
+            .endObject()
+            .endObject()
+            .endObject();
+        createIndex("index", indexSettings, builder);
+
+        for (int doc = 0; doc < 10; doc++) {
+            client().prepareIndex("index").setSource("vector", randomVector(), "text", "hello world", "number", 1).get();
+            client().prepareIndex("index").setSource("vector_2", randomVector(), "text", "hello world", "number", 2).get();
+            client().prepareIndex("index").setSource("text", "goodnight world", "number", 3).get();
+        }
+        client().admin().indices().prepareRefresh("index").get();
+
+        float[] queryVector = randomVector();
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50).boost(5.0f);
+        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50).boost(10.0f);
+        SearchResponse response = client().prepareSearch("index")
+            .setKnnSearch(List.of(knnSearch, knnSearch2))
+            .setQuery(QueryBuilders.matchQuery("text", "goodnight"))
+            .addFetchField("*")
+            .setSize(10)
+            .addAggregation(AggregationBuilders.stats("stats").field("number"))
+            .get();
+
+        // The total hits is k plus the number of text matches
+        assertHitCount(response, 20);
+        assertEquals(10, response.getHits().getHits().length);
+        InternalStats agg = response.getAggregations().get("stats");
+        assertThat(agg.getCount(), equalTo(20L));
+        assertThat(agg.getMax(), equalTo(3.0));
+        assertThat(agg.getMin(), equalTo(1.0));
+        assertThat(agg.getAvg(), equalTo(2.25));
+        assertThat(agg.getSum(), equalTo(45.0));
+
+        // Because of the boost, vector_2 results should appear first
+        assertNotNull(response.getHits().getAt(0).field("vector_2"));
+    }
+
+    public void testMultiKnnClausesSameDoc() throws IOException {
+        int numShards = 1 + randomInt(3);
+        Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
+
+        XContentBuilder builder = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("properties")
+            .startObject("vector")
+            .field("type", "dense_vector")
+            .field("dims", VECTOR_DIMENSION)
+            .field("index", true)
+            .field("similarity", "l2_norm")
+            .endObject()
+            .startObject("vector_2")
+            .field("type", "dense_vector")
+            .field("dims", VECTOR_DIMENSION)
+            .field("index", true)
+            .field("similarity", "l2_norm")
+            .endObject()
+            .startObject("number")
+            .field("type", "long")
+            .endObject()
+            .endObject()
+            .endObject();
+        createIndex("index", indexSettings, builder);
+
+        for (int doc = 0; doc < 10; doc++) {
+            // Make them have hte same vector. This will allow us to test the recall is the same but scores take into account both fields
+            float[] vector = randomVector();
+            client().prepareIndex("index").setSource("vector", vector, "vector_2", vector, "number", doc).get();
+        }
+        client().admin().indices().prepareRefresh("index").get();
+
+        float[] queryVector = randomVector();
+        // Having the same query vector and same docs should mean our KNN scores are linearly combined if the same doc is matched
+        KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 5, 50);
+        KnnSearchBuilder knnSearch2 = new KnnSearchBuilder("vector_2", queryVector, 5, 50);
+        SearchResponse responseOneKnn = client().prepareSearch("index")
+            .setKnnSearch(List.of(knnSearch))
+            .addFetchField("*")
+            .setSize(10)
+            .addAggregation(AggregationBuilders.stats("stats").field("number"))
+            .get();
+        SearchResponse responseBothKnn = client().prepareSearch("index")
+            .setKnnSearch(List.of(knnSearch, knnSearch2))
+            .addFetchField("*")
+            .setSize(10)
+            .addAggregation(AggregationBuilders.stats("stats").field("number"))
+            .get();
+
+        // The total hits is k matched docs
+        assertHitCount(responseOneKnn, 5);
+        assertHitCount(responseBothKnn, 5);
+        assertEquals(5, responseOneKnn.getHits().getHits().length);
+        assertEquals(5, responseBothKnn.getHits().getHits().length);
+
+        for (int i = 0; i < responseOneKnn.getHits().getHits().length; i++) {
+            SearchHit oneHit = responseOneKnn.getHits().getHits()[i];
+            SearchHit bothHit = responseBothKnn.getHits().getHits()[i];
+            assertThat(bothHit.getId(), equalTo(oneHit.getId()));
+            assertThat(bothHit.getScore(), greaterThan(oneHit.getScore()));
+        }
+        InternalStats oneAgg = responseOneKnn.getAggregations().get("stats");
+        InternalStats bothAgg = responseBothKnn.getAggregations().get("stats");
+        assertThat(bothAgg.getCount(), equalTo(oneAgg.getCount()));
+        assertThat(bothAgg.getAvg(), equalTo(oneAgg.getAvg()));
+        assertThat(bothAgg.getMax(), equalTo(oneAgg.getMax()));
+        assertThat(bothAgg.getSum(), equalTo(oneAgg.getSum()));
+        assertThat(bothAgg.getMin(), equalTo(oneAgg.getMin()));
+    }
+
     public void testKnnFilteredAlias() throws IOException {
         int numShards = 1 + randomInt(3);
         Settings indexSettings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numShards).build();
@@ -184,7 +322,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
 
         float[] queryVector = randomVector();
         KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", queryVector, 10, 50);
-        SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(knnSearch).setSize(10).get();
+        SearchResponse response = client().prepareSearch("test-alias").setKnnSearch(List.of(knnSearch)).setSize(10).get();
 
         assertHitCount(response, expectedHits);
         assertEquals(expectedHits, response.getHits().getHits().length);

+ 38 - 1
server/src/test/java/org/elasticsearch/action/search/SearchRequestTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.search.Scroll;
 import org.elasticsearch.search.builder.PointInTimeBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.rescore.QueryRescorerBuilder;
+import org.elasticsearch.search.vectors.KnnSearchBuilder;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.VersionUtils;
@@ -85,6 +86,38 @@ public class SearchRequestTests extends AbstractSearchTestCase {
         assertNotSame(deserializedRequest, searchRequest);
     }
 
+    public void testSerializationMultiKNN() throws Exception {
+        SearchRequest searchRequest = createSearchRequest();
+        if (searchRequest.source() == null) {
+            searchRequest.source(new SearchSourceBuilder());
+        }
+        searchRequest.source()
+            .knnSearch(
+                List.of(
+                    new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10),
+                    new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 4, 12, 41 }, 3, 5)
+                )
+            );
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> copyWriteable(
+                searchRequest,
+                namedWriteableRegistry,
+                SearchRequest::new,
+                VersionUtils.randomVersionBetween(random(), Version.V_8_4_0, Version.V_8_6_0)
+            )
+        );
+
+        searchRequest.source().knnSearch(List.of(new KnnSearchBuilder(randomAlphaOfLength(10), new float[] { 1, 2 }, 5, 10)));
+        // Shouldn't throw because its just one KNN request
+        copyWriteable(
+            searchRequest,
+            namedWriteableRegistry,
+            SearchRequest::new,
+            VersionUtils.randomVersionBetween(random(), Version.V_8_4_0, Version.V_8_6_0)
+        );
+    }
+
     public void testRandomVersionSerialization() throws IOException {
         SearchRequest searchRequest = createSearchRequest();
         Version version = VersionUtils.randomVersion(random());
@@ -93,9 +126,13 @@ public class SearchRequestTests extends AbstractSearchTestCase {
             searchRequest.source().runtimeMappings(emptyMap());
         }
         if (version.before(Version.V_8_4_0)) {
-            // Versionse before 8.4.0 don't support force_synthetic_source
+            // Versions before 8.4.0 don't support force_synthetic_source
             searchRequest.setForceSyntheticSource(false);
         }
+        if (version.before(Version.V_8_7_0) && searchRequest.hasKnnSearch() && searchRequest.source().knnSearch().size() > 1) {
+            // Versions before 8.7.0 don't support more than one KNN clause
+            searchRequest.source().knnSearch(List.of(searchRequest.source().knnSearch().get(0)));
+        }
         SearchRequest deserializedRequest = copyWriteable(searchRequest, namedWriteableRegistry, SearchRequest::new, version);
         assertEquals(searchRequest.isCcsMinimizeRoundtrips(), deserializedRequest.isCcsMinimizeRoundtrips());
         assertEquals(searchRequest.getLocalClusterAlias(), deserializedRequest.getLocalClusterAlias());

+ 3 - 3
server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java

@@ -1051,7 +1051,7 @@ public class TransportSearchActionTests extends ESTestCase {
             SearchSourceBuilder source = searchRequest.source();
             if (source != null) {
                 source.pointInTimeBuilder(null);
-                source.knnSearch(null);
+                source.knnSearch(List.of());
                 CollapseBuilder collapse = source.collapse();
                 if (collapse != null) {
                     collapse.setInnerHits(Collections.emptyList());
@@ -1065,7 +1065,7 @@ public class TransportSearchActionTests extends ESTestCase {
         {
             SearchRequest searchRequest = new SearchRequest();
             SearchSourceBuilder source = new SearchSourceBuilder();
-            source.knnSearch(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50));
+            source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50)));
             searchRequest.source(source);
 
             searchRequest.setCcsMinimizeRoundtrips(true);
@@ -1080,7 +1080,7 @@ public class TransportSearchActionTests extends ESTestCase {
             // If the search includes kNN, we should always use DFS_QUERY_THEN_FETCH
             SearchRequest searchRequest = new SearchRequest();
             SearchSourceBuilder source = new SearchSourceBuilder();
-            source.knnSearch(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50));
+            source.knnSearch(List.of(new KnnSearchBuilder("field", new float[] { 1, 2, 3 }, 10, 50)));
             searchRequest.source(source);
 
             TransportSearchAction.adjustSearchType(searchRequest, randomBoolean());

+ 1 - 1
server/src/test/java/org/elasticsearch/rest/action/search/RestSearchActionTests.java

@@ -104,7 +104,7 @@ public class RestSearchActionTests extends RestActionTestCase {
 
             SearchRequest searchRequest = new SearchRequest();
             KnnSearchBuilder knnSearch = new KnnSearchBuilder("vector", new float[] { 1, 1, 1 }, 10, 100);
-            searchRequest.source(new SearchSourceBuilder().knnSearch(knnSearch));
+            searchRequest.source(new SearchSourceBuilder().knnSearch(List.of(knnSearch)));
 
             Exception ex = expectThrows(
                 IllegalArgumentException.class,

+ 2 - 1
server/src/test/java/org/elasticsearch/search/builder/SearchSourceBuilderTests.java

@@ -58,6 +58,7 @@ import org.elasticsearch.xcontent.json.JsonXContent;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
@@ -739,7 +740,7 @@ public class SearchSourceBuilderTests extends AbstractSearchTestCase {
         searchSourceBuilder.fetchField("field");
         // these are not correct runtime mappings but they are counted compared to empty object
         searchSourceBuilder.runtimeMappings(Collections.singletonMap("field", "keyword"));
-        searchSourceBuilder.knnSearch(new KnnSearchBuilder("field", new float[] {}, 2, 5));
+        searchSourceBuilder.knnSearch(List.of(new KnnSearchBuilder("field", new float[] {}, 2, 5)));
         searchSourceBuilder.pointInTimeBuilder(new PointInTimeBuilder("pitid"));
         searchSourceBuilder.docValueField("field");
         searchSourceBuilder.storedField("field");

+ 10 - 3
server/src/test/java/org/elasticsearch/search/internal/ShardSearchRequestTests.java

@@ -29,6 +29,7 @@ import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.InvalidAliasNameException;
 import org.elasticsearch.search.AbstractSearchTestCase;
 import org.elasticsearch.search.SearchSortValuesAndFormatsTests;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.test.VersionUtils;
 import org.elasticsearch.xcontent.DeprecationHandler;
 import org.elasticsearch.xcontent.ToXContent;
@@ -38,6 +39,8 @@ import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.List;
+import java.util.Optional;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.elasticsearch.index.query.AbstractQueryBuilder.parseTopLevelQuery;
@@ -220,9 +223,13 @@ public class ShardSearchRequestTests extends AbstractSearchTestCase {
         int iterations = between(0, 5);
         // New version
         for (int i = 0; i < iterations; i++) {
-            Version version = request.isForceSyntheticSource()
-                ? VersionUtils.randomVersionBetween(random(), Version.V_8_4_0, Version.CURRENT)
-                : VersionUtils.randomCompatibleVersion(random(), Version.CURRENT);
+            Version version = VersionUtils.randomCompatibleVersion(random(), Version.CURRENT);
+            if (request.isForceSyntheticSource()) {
+                version = VersionUtils.randomVersionBetween(random(), Version.V_8_4_0, Version.CURRENT);
+            }
+            if (Optional.ofNullable(request.source()).map(SearchSourceBuilder::knnSearch).map(List::size).orElse(0) > 1) {
+                version = VersionUtils.randomVersionBetween(random(), Version.V_8_7_0, Version.CURRENT);
+            }
             request = copyWriteable(request, namedWriteableRegistry, ShardSearchRequest::new, version);
             channelVersion = Version.min(channelVersion, version);
             assertThat(request.getChannelVersion(), equalTo(channelVersion));

+ 30 - 1
server/src/test/java/org/elasticsearch/search/profile/SearchProfileDfsPhaseResultTests.java

@@ -9,18 +9,25 @@
 package org.elasticsearch.search.profile;
 
 import org.elasticsearch.common.io.stream.Writeable.Reader;
+import org.elasticsearch.search.profile.query.CollectorResult;
+import org.elasticsearch.search.profile.query.QueryProfileShardResult;
 import org.elasticsearch.search.profile.query.QueryProfileShardResultTests;
 import org.elasticsearch.test.AbstractXContentSerializingTestCase;
 import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
+import java.util.List;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
 
 public class SearchProfileDfsPhaseResultTests extends AbstractXContentSerializingTestCase<SearchProfileDfsPhaseResult> {
 
     static SearchProfileDfsPhaseResult createTestItem() {
         return new SearchProfileDfsPhaseResult(
             randomBoolean() ? null : ProfileResultTests.createTestItem(1),
-            randomBoolean() ? null : QueryProfileShardResultTests.createTestItem()
+            randomBoolean() ? null : randomList(1, 10, QueryProfileShardResultTests::createTestItem)
         );
     }
 
@@ -38,4 +45,26 @@ public class SearchProfileDfsPhaseResultTests extends AbstractXContentSerializin
     protected SearchProfileDfsPhaseResult doParseInstance(XContentParser parser) throws IOException {
         return SearchProfileDfsPhaseResult.fromXContent(parser);
     }
+
+    public void testCombineQueryProfileShardResults() {
+        assertThat(new SearchProfileDfsPhaseResult(null, null).combineQueryProfileShardResults(), is(nullValue()));
+
+        List<QueryProfileShardResult> resultList = randomList(5, 5, QueryProfileShardResultTests::createTestItem);
+
+        SearchProfileDfsPhaseResult result = new SearchProfileDfsPhaseResult(null, resultList);
+        QueryProfileShardResult queryProfileShardResult = result.combineQueryProfileShardResults();
+        assertThat(
+            queryProfileShardResult.getRewriteTime(),
+            equalTo(resultList.stream().mapToLong(QueryProfileShardResult::getRewriteTime).sum())
+        );
+        assertThat(
+            queryProfileShardResult.getCollectorResult().getTime(),
+            equalTo(resultList.stream().map(QueryProfileShardResult::getCollectorResult).mapToLong(CollectorResult::getTime).sum())
+        );
+        assertThat(queryProfileShardResult.getCollectorResult().getProfiledChildren().size(), equalTo(resultList.size()));
+        assertThat(
+            queryProfileShardResult.getQueryResults().size(),
+            equalTo((int) resultList.stream().mapToLong(q -> q.getQueryResults().size()).sum())
+        );
+    }
 }

+ 13 - 8
test/framework/src/main/java/org/elasticsearch/search/RandomSearchRequestGenerator.java

@@ -248,15 +248,20 @@ public class RandomSearchRequestGenerator {
         }
 
         if (randomBoolean()) {
-            String field = randomAlphaOfLength(6);
-            int dim = randomIntBetween(2, 30);
-            float[] vector = new float[dim];
-            for (int i = 0; i < vector.length; i++) {
-                vector[i] = randomFloat();
+            int numKClauses = randomIntBetween(1, 5);
+            List<KnnSearchBuilder> knnSearchBuilders = new ArrayList<>(numKClauses);
+            for (int j = 0; j < numKClauses; j++) {
+                String field = randomAlphaOfLength(6);
+                int dim = randomIntBetween(2, 30);
+                float[] vector = new float[dim];
+                for (int i = 0; i < vector.length; i++) {
+                    vector[i] = randomFloat();
+                }
+                int k = randomIntBetween(1, 100);
+                int numCands = randomIntBetween(k, 1000);
+                knnSearchBuilders.add(new KnnSearchBuilder(field, vector, k, numCands));
             }
-            int k = randomIntBetween(1, 100);
-            int numCands = randomIntBetween(k, 1000);
-            builder.knnSearch(new KnnSearchBuilder(field, vector, k, numCands));
+            builder.knnSearch(knnSearchBuilders);
         }
 
         if (randomBoolean()) {

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportSemanticSearchAction.java

@@ -104,7 +104,7 @@ public class TransportSemanticSearchAction extends HandledTransportAction<Semant
 
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
         sourceBuilder.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_ACCURATE);
-        sourceBuilder.knnSearch(knnSearchBuilder);
+        sourceBuilder.knnSearch(List.of(knnSearchBuilder));
         if (request.getSize() != -1) {
             sourceBuilder.size(request.getSize());
         }