浏览代码

Support k parameter for knn query (#110233)

Introduce an optional k param for knn query

If k is not set, knn query has the previous behaviour:
- `num_candidates` docs  is collected from each shard. This `num_candidates` docs
are used for combining with results with other queries and aggregations on each shard.
- docs from all shards are merged to produce the top global `size` results

If k is set, the behaviour instead is following:
- `k` docs is collected from each shard. This `k` docs are used for
combining results with other queries and aggregations on each shard.
- similarly, docs from all shards are merged to produce the top global `size`
results.

Having `k` param makes it more intuitive for users to address their needs.
They also don't need to care and can skip `num_candidates` param for this query
as it is of more internal details to tune how knn search operates.

Closes #108473
Mayya Sharipova 1 年之前
父节点
当前提交
405e39660b
共有 29 个文件被更改,包括 523 次插入116 次删除
  1. 6 0
      docs/changelog/110233.yaml
  2. 20 17
      docs/reference/query-dsl/knn-query.asciidoc
  3. 1 1
      docs/reference/rest-api/common-parms.asciidoc
  4. 1 1
      modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java
  5. 0 9
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml
  6. 262 0
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/190_knn_query-with-k-param.yml
  7. 1 0
      server/src/main/java/module-info.java
  8. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  9. 22 11
      server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
  10. 22 0
      server/src/main/java/org/elasticsearch/search/SearchFeatures.java
  11. 12 3
      server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java
  12. 12 3
      server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java
  13. 6 4
      server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java
  14. 6 3
      server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java
  15. 1 1
      server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java
  16. 1 1
      server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java
  17. 61 17
      server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java
  18. 1 0
      server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification
  19. 1 1
      server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java
  20. 10 10
      server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
  21. 11 11
      server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java
  22. 1 0
      server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java
  23. 53 13
      server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java
  24. 2 2
      server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java
  25. 2 2
      server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java
  26. 2 1
      server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java
  27. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
  28. 1 1
      x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java
  29. 3 3
      x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java

+ 6 - 0
docs/changelog/110233.yaml

@@ -0,0 +1,6 @@
+pr: 110233
+summary: Support k parameter for knn query
+area: Vector Search
+type: enhancement
+issues:
+ - 108473

+ 20 - 17
docs/reference/query-dsl/knn-query.asciidoc

@@ -50,7 +50,8 @@ POST my-image-index/_bulk?refresh=true
 ----
 //TEST[continued]
 
-. Run the search using the `knn` query, asking for the top 3 nearest vectors.
+. Run the search using the `knn` query, asking for the top 10 nearest vectors
+from each shard, and then combine shard results to get the top 3 global results.
 +
 [source,console]
 ----
@@ -61,18 +62,13 @@ POST my-image-index/_search
     "knn": {
       "field": "image-vector",
       "query_vector": [-5, 9, -12],
-      "num_candidates": 10
+      "k": 10
     }
   }
 }
 ----
 //TEST[continued]
 
-NOTE: `knn` query doesn't have a separate `k` parameter. `k` is defined by
-`size` parameter of a search request similar to other queries. `knn` query
-collects `num_candidates` results from each shard, then merges them to get
-the top `size` results.
-
 
 [[knn-query-top-level-parameters]]
 ==== Top-level parameters for `knn`
@@ -99,14 +95,21 @@ Either this or `query_vector_builder` must be provided.
 include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector-builder]
 --
 
+`k`::
++
+--
+(Optional, integer) The number of nearest neighbors to return from each shard.
+{es} collects `k` results from each shard, then merges them to find the global top results.
+This value must be less than or equal to `num_candidates`. Defaults to `num_candidates`.
+--
 
 `num_candidates`::
 +
 --
-(Optional, integer) The number of nearest neighbor candidates to consider per shard.
-Cannot exceed 10,000. {es} collects `num_candidates` results from each shard, then
-merges them to find the top results. Increasing `num_candidates` tends to improve the
-accuracy of the final results. Defaults to `Math.min(1.5 * size, 10_000)`.
+(Optional, integer) The number of nearest neighbor candidates to consider per shard
+while doing knn search. Cannot exceed 10,000. Increasing `num_candidates` tends to
+improve the accuracy of the final results.
+Defaults to `1.5 * k` if `k` is set, or `1.5 * size` if `k` is not set.
 --
 
 `filter`::
@@ -160,7 +163,7 @@ Also filters from <<filter-alias,aliases>> are applied as pre-filters.
 
 All other filters found in the Query DSL tree are applied as post-filters.
 For example, `knn` query finds the top 3 documents with the nearest vectors
-(num_candidates=3), which are combined with  `term` filter, that is
+(k=3), which are combined with  `term` filter, that is
 post-filtered. The final set of documents will contain only a single document
 that passes the post-filter.
 
@@ -176,7 +179,7 @@ POST my-image-index/_search
         "knn": {
           "field": "image-vector",
           "query_vector": [-5, 9, -12],
-          "num_candidates": 3
+          "k": 3
         }
       },
       "filter" : {
@@ -217,7 +220,7 @@ POST my-image-index/_search
           "knn": {
             "field": "image-vector",
             "query_vector": [-5, 9, -12],
-            "num_candidates": 10,
+            "k": 10,
             "boost": 2
           }
         }
@@ -267,8 +270,8 @@ A sample query can look like below:
 
 [[knn-query-aggregations]]
 ==== Knn query with aggregations
-`knn` query calculates aggregations on `num_candidates` from each shard.
+`knn` query calculates aggregations on top `k` documents from each shard.
 Thus, the final results from aggregations contain
-`num_candidates * number_of_shards` documents. This is different from
+`k * number_of_shards` documents. This is different from
 the <<knn-search,top level knn section>> where aggregations are
-calculated on the global top k nearest documents.
+calculated on the global top `k` nearest documents.

+ 1 - 1
docs/reference/rest-api/common-parms.asciidoc

@@ -594,7 +594,7 @@ end::knn-filter[]
 
 tag::knn-k[]
 Number of nearest neighbors to return as top hits. This value must be less than
-`num_candidates`. Defaults to `size`.
+or equal to `num_candidates`. Defaults to `size`.
 end::knn-k[]
 
 tag::knn-num-candidates[]

+ 1 - 1
modules/percolator/src/internalClusterTest/java/org/elasticsearch/percolator/PercolatorQuerySearchIT.java

@@ -1358,7 +1358,7 @@ public class PercolatorQuerySearchIT extends ESIntegTestCase {
             """);
         indicesAdmin().prepareCreate("index1").setMapping(mappings).get();
         ensureGreen();
-        QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, null);
+        QueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder("my_vector", new float[] { 1, 1, 1, 1, 1 }, 10, 10, null);
 
         IndexRequestBuilder indexRequestBuilder = prepareIndex("index1").setId("knn_query1")
             .setSource(jsonBuilder().startObject().field("my_query", knnVectorQueryBuilder).endObject());

+ 0 - 9
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/140_knn_query_with_other_queries.yml

@@ -26,15 +26,6 @@ setup:
               my_name:
                 type: keyword
                 store: true
-          aliases:
-            my_alias:
-              filter:
-                term:
-                  my_name: v2
-            my_alias1:
-              filter:
-                term:
-                  my_name: v1
 
   - do:
       bulk:

+ 262 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/190_knn_query-with-k-param.yml

@@ -0,0 +1,262 @@
+# test how knn query interact with other queries
+setup:
+  - requires:
+      cluster_features: "search.vectors.k_param_supported"
+      reason: 'k param for knn as query is required'
+      test_runner_features: close_to
+
+  - do:
+      indices.create:
+        index: my_index
+        body:
+          settings:
+            number_of_shards: 1
+          mappings:
+            dynamic: false
+            properties:
+              my_vector:
+                type: dense_vector
+                dims: 4
+                index : true
+                similarity : l2_norm
+                index_options:
+                  type: hnsw
+                  m: 16
+                  ef_construction: 200
+              my_name:
+                type: keyword
+                store: true
+
+  - do:
+      bulk:
+        refresh: true
+        index: my_index
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"my_vector": [1, 1, 1, 1], "my_name": "v1"}'
+          - '{"index": {"_id": "2"}}'
+          - '{"my_vector": [1, 1, 1, 2], "my_name": "v2"}'
+          - '{"index": {"_id": "3"}}'
+          - '{"my_vector": [1, 1, 1, 3], "my_name": "v1"}'
+          - '{"index": {"_id": "4"}}'
+          - '{"my_vector": [1, 1, 1, 4], "my_name": "v2"}'
+          - '{"index": {"_id": "5"}}'
+          - '{"my_vector": [1, 1, 1, 5], "my_name": "v1"}'
+          - '{"index": {"_id": "6"}}'
+          - '{"my_vector": [1, 1, 1, 6], "my_name": "v2"}'
+          - '{"index": {"_id": "7"}}'
+          - '{"my_vector": [1, 1, 1, 7], "my_name": "v1"}'
+          - '{"index": {"_id": "8"}}'
+          - '{"my_vector": [1, 1, 1, 8], "my_name": "v2"}'
+          - '{"index": {"_id": "9"}}'
+          - '{"my_vector": [1, 1, 1, 9], "my_name": "v1"}'
+          - '{"index": {"_id": "10"}}'
+          - '{"my_vector": [1, 1, 1, 10], "my_name": "v2"}'
+
+---
+"Simple knn query with k param":
+  - do:
+      search:
+        index: my_index
+        body:
+          query:
+            knn:
+              field: my_vector
+              query_vector: [1, 1, 1, 1]
+              k: 5
+
+  - match: { hits.total.value: 5 } # collector sees k docs
+  - length: {hits.hits: 5} # k docs retrieved
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.3._id: "4" }
+  - match: { hits.hits.4._id: "5" }
+
+  - do:
+      search:
+        index: my_index
+        body:
+          size: 3
+          query:
+            knn:
+              field: my_vector
+              query_vector: [ 1, 1, 1, 1 ]
+              k: 5
+
+  - match: { hits.total.value: 5 } # collector sees k docs
+  - length: { hits.hits: 3 } # size docs retrieved
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.2._id: "3" }
+
+  - do:
+      search:
+        index: my_index
+        body:
+          size: 3
+          query:
+            knn:
+              field: my_vector
+              query_vector: [ 1, 1, 1, 1 ]
+              k: 5
+              num_candidates: 10
+
+  - match: { hits.total.value: 5 } # collector sees k docs
+  - length: { hits.hits: 3 } # size docs retrieved
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.2._id: "3" }
+
+---
+"Knn query within the standard retriever":
+  - do:
+      search:
+        index: my_index
+        body:
+          retriever:
+            standard:
+              filter:
+                bool:
+                  must:
+                    term:
+                      my_name: "v1"
+              query:
+                knn:
+                  field: my_vector
+                  query_vector: [ 1, 1, 1, 1 ]
+                  k: 10
+  - match: { hits.total.value: 5 } # docs that pass post-filter
+  - length: { hits.hits: 5 }
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.1._id: "3" }
+  - match: { hits.hits.2._id: "5" }
+  - match: { hits.hits.3._id: "7" }
+  - match: { hits.hits.4._id: "9" }
+
+---
+"Incorrect k param":
+  - do:
+      catch: bad_request
+      search:
+        index: my_index
+        body:
+          query:
+            knn:
+              field: my_vector
+              query_vector: [ 1, 1, 1, 1 ]
+              k: 5
+              num_candidates: 3
+  - match: { status: 400 }
+  - match: { error.type: "x_content_parse_exception" }
+  - match: { error.caused_by.type: "illegal_argument_exception" }
+  - match: { error.caused_by.reason: "[num_candidates] cannot be less than [k]" }
+
+  - do:
+      catch: bad_request
+      search:
+        index: my_index
+        body:
+          query:
+            knn:
+              field: my_vector
+              query_vector: [ 1, 1, 1, 1 ]
+              k: 0
+  - match: { status: 400 }
+  - match: { error.type: "x_content_parse_exception" }
+  - match: { error.caused_by.type: "illegal_argument_exception" }
+  - match: { error.caused_by.reason: "[k] must be greater than 0" }
+
+---
+"Function score query with knn query with k param":
+  # find top 5 knn docs, then boost docs with name v1 by 10 and docs with name v2 by 100
+  - do:
+      search:
+        index: my_index
+        body:
+          size: 3
+          fields: [ my_name ]
+          query:
+            function_score:
+              query:
+                knn:
+                  field: my_vector
+                  query_vector: [ 1, 1, 1, 1 ]
+                  k : 5
+              functions:
+                - filter: { match: { my_name: v1 } }
+                  weight: 10
+                - filter: { match: { my_name: v2 } }
+                  weight: 100
+              boost_mode: multiply
+
+  - match: { hits.total.value: 5 } # collector sees k docs
+  - length: { hits.hits: 3 }
+  - match: { hits.hits.0._id: "2" }
+  - match: { hits.hits.0.fields.my_name.0: v2 }
+  - close_to: { hits.hits.0._score: { value: 50.0, error: 0.001 } }
+  - match: { hits.hits.1._id: "1" }
+  - match: { hits.hits.1.fields.my_name.0: v1 }
+  - close_to: { hits.hits.1._score: { value: 10.0, error: 0.001 } }
+  - match: { hits.hits.2._id: "4" }
+  - match: { hits.hits.2.fields.my_name.0: v2 }
+  - close_to: { hits.hits.2._score: { value: 10.0, error: 0.001 } }
+
+---
+"dis_max query with knn query":
+  - do:
+      search:
+        index: my_index
+        body:
+          size: 10
+          fields: [ my_name ]
+          query:
+            dis_max:
+              queries:
+                - knn: { field: my_vector, query_vector: [ 1, 1, 1, 1 ], k: 5, num_candidates: 10 }
+                - match: { my_name: v2 }
+              tie_breaker: 0.8
+
+  - match: { hits.total.value: 8 } # 5 knn results + extra results from match query
+  - match: { hits.hits.0._id: "2" }
+  - match: { hits.hits.0.fields.my_name.0: v2 }
+  - match: { hits.hits.1._id: "1" }
+  - match: { hits.hits.1.fields.my_name.0: v1 }
+  - match: { hits.hits.2._id: "4" }
+  - match: { hits.hits.2.fields.my_name.0: v2 }
+  - match: { hits.hits.3._id: "6" }
+  - match: { hits.hits.3.fields.my_name.0: v2 }
+  - match: { hits.hits.4._id: "8" }
+  - match: { hits.hits.4.fields.my_name.0: v2 }
+  - match: { hits.hits.5._id: "10" }
+  - match: { hits.hits.5.fields.my_name.0: v2 }
+  - match: { hits.hits.6._id: "3" }
+  - match: { hits.hits.6.fields.my_name.0: v1 }
+  - match: { hits.hits.7._id: "5" }
+  - match: { hits.hits.7.fields.my_name.0: v1 }
+
+---
+"Aggregations with collected number of docs depends on k param":
+  - do:
+      search:
+        index: my_index
+        body:
+          size: 2
+          query:
+            knn:
+              field: my_vector
+              query_vector: [1, 1, 1, 1]
+              k: 5 # collect 5 results from each shard
+          aggs:
+            my_agg:
+              terms:
+                field: my_name
+                order:
+                  _key: asc
+
+  - length: {hits.hits: 2}
+  - match: {hits.total.value: 5}
+  - match: {aggregations.my_agg.buckets.0.key: 'v1'}
+  - match: {aggregations.my_agg.buckets.1.key: 'v2'}
+  - match: {aggregations.my_agg.buckets.0.doc_count: 3}
+  - match: {aggregations.my_agg.buckets.1.doc_count: 2}

+ 1 - 0
server/src/main/java/module-info.java

@@ -431,6 +431,7 @@ module org.elasticsearch.server {
             org.elasticsearch.indices.IndicesFeatures,
             org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures,
             org.elasticsearch.index.mapper.MapperFeatures,
+            org.elasticsearch.search.SearchFeatures,
             org.elasticsearch.script.ScriptFeatures,
             org.elasticsearch.search.retriever.RetrieversFeatures,
             org.elasticsearch.reservedstate.service.FileSettingsFeatures;

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -204,6 +204,7 @@ public class TransportVersions {
     public static final TransportVersion EVENT_INGESTED_RANGE_IN_CLUSTER_STATE = def(8_695_00_0);
     public static final TransportVersion ESQL_ADD_AGGREGATE_TYPE = def(8_696_00_0);
     public static final TransportVersion SECURITY_MIGRATIONS_MIGRATION_NEEDED_ADDED = def(8_697_00_0);
+    public static final TransportVersion K_FOR_KNN_QUERY_ADDED = def(8_698_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 22 - 11
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -1749,12 +1749,20 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return new DenseVectorQuery.Floats(queryVector, name());
         }
 
-        Query createKnnQuery(float[] queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter) {
-            return createKnnQuery(VectorData.fromFloats(queryVector), numCands, filter, similarityThreshold, parentFilter);
+        Query createKnnQuery(
+            float[] queryVector,
+            Integer k,
+            int numCands,
+            Query filter,
+            Float similarityThreshold,
+            BitSetProducer parentFilter
+        ) {
+            return createKnnQuery(VectorData.fromFloats(queryVector), k, numCands, filter, similarityThreshold, parentFilter);
         }
 
         public Query createKnnQuery(
             VectorData queryVector,
+            Integer k,
             int numCands,
             Query filter,
             Float similarityThreshold,
@@ -1766,14 +1774,15 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 );
             }
             return switch (getElementType()) {
-                case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
-                case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter);
-                case BIT -> createKnnBitQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
+                case BYTE -> createKnnByteQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
+                case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), k, numCands, filter, similarityThreshold, parentFilter);
+                case BIT -> createKnnBitQuery(queryVector.asByteVector(), k, numCands, filter, similarityThreshold, parentFilter);
             };
         }
 
         private Query createKnnBitQuery(
             byte[] queryVector,
+            Integer k,
             int numCands,
             Query filter,
             Float similarityThreshold,
@@ -1781,8 +1790,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
         ) {
             elementType.checkDimensions(dims, queryVector.length);
             Query knnQuery = parentFilter != null
-                ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
-                : new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
+                ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
+                : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
                     knnQuery,
@@ -1795,6 +1804,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
         private Query createKnnByteQuery(
             byte[] queryVector,
+            Integer k,
             int numCands,
             Query filter,
             Float similarityThreshold,
@@ -1807,8 +1817,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
             }
             Query knnQuery = parentFilter != null
-                ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
-                : new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
+                ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
+                : new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter);
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
                     knnQuery,
@@ -1821,6 +1831,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
         private Query createKnnFloatQuery(
             float[] queryVector,
+            Integer k,
             int numCands,
             Query filter,
             Float similarityThreshold,
@@ -1842,8 +1853,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
             }
             Query knnQuery = parentFilter != null
-                ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
-                : new ESKnnFloatVectorQuery(name(), queryVector, numCands, filter);
+                ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter)
+                : new ESKnnFloatVectorQuery(name(), queryVector, k, numCands, filter);
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
                     knnQuery,

+ 22 - 0
server/src/main/java/org/elasticsearch/search/SearchFeatures.java

@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search;
+
+import org.elasticsearch.features.FeatureSpecification;
+import org.elasticsearch.features.NodeFeature;
+import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
+
+import java.util.Set;
+
+public final class SearchFeatures implements FeatureSpecification {
+    @Override
+    public Set<NodeFeature> getFeatures() {
+        return Set.of(KnnVectorQueryBuilder.K_PARAM_SUPPORTED);
+    }
+}

+ 12 - 3
server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenByteKnnVectorQuery.java

@@ -15,15 +15,24 @@ import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 
 public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements ProfilingQuery {
+    private final Integer kParam;
     private long vectorOpsCount;
 
-    public ESDiversifyingChildrenByteKnnVectorQuery(String field, byte[] query, Query childFilter, int k, BitSetProducer parentsFilter) {
-        super(field, query, childFilter, k, parentsFilter);
+    public ESDiversifyingChildrenByteKnnVectorQuery(
+        String field,
+        byte[] query,
+        Query childFilter,
+        Integer k,
+        int numCands,
+        BitSetProducer parentsFilter
+    ) {
+        super(field, query, childFilter, numCands, parentsFilter);
+        this.kParam = k;
     }
 
     @Override
     protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
-        TopDocs topK = super.mergeLeafResults(perLeafResults);
+        TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
         vectorOpsCount = topK.totalHits.value;
         return topK;
     }

+ 12 - 3
server/src/main/java/org/elasticsearch/search/vectors/ESDiversifyingChildrenFloatKnnVectorQuery.java

@@ -15,15 +15,24 @@ import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 
 public class ESDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery implements ProfilingQuery {
+    private final Integer kParam;
     private long vectorOpsCount;
 
-    public ESDiversifyingChildrenFloatKnnVectorQuery(String field, float[] query, Query childFilter, int k, BitSetProducer parentsFilter) {
-        super(field, query, childFilter, k, parentsFilter);
+    public ESDiversifyingChildrenFloatKnnVectorQuery(
+        String field,
+        float[] query,
+        Query childFilter,
+        Integer k,
+        int numCands,
+        BitSetProducer parentsFilter
+    ) {
+        super(field, query, childFilter, numCands, parentsFilter);
+        this.kParam = k;
     }
 
     @Override
     protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
-        TopDocs topK = super.mergeLeafResults(perLeafResults);
+        TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
         vectorOpsCount = topK.totalHits.value;
         return topK;
     }

+ 6 - 4
server/src/main/java/org/elasticsearch/search/vectors/ESKnnByteVectorQuery.java

@@ -14,16 +14,18 @@ import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 
 public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements ProfilingQuery {
-
+    private final Integer kParam;
     private long vectorOpsCount;
 
-    public ESKnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
-        super(field, target, k, filter);
+    public ESKnnByteVectorQuery(String field, byte[] target, Integer k, int numCands, Query filter) {
+        super(field, target, numCands, filter);
+        this.kParam = k;
     }
 
     @Override
     protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
-        TopDocs topK = super.mergeLeafResults(perLeafResults);
+        // if k param is set, we get only top k results from each shard
+        TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
         vectorOpsCount = topK.totalHits.value;
         return topK;
     }

+ 6 - 3
server/src/main/java/org/elasticsearch/search/vectors/ESKnnFloatVectorQuery.java

@@ -14,15 +14,18 @@ import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.search.profile.query.QueryProfiler;
 
 public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements ProfilingQuery {
+    private final Integer kParam;
     private long vectorOpsCount;
 
-    public ESKnnFloatVectorQuery(String field, float[] target, int k, Query filter) {
-        super(field, target, k, filter);
+    public ESKnnFloatVectorQuery(String field, float[] target, Integer k, int numCands, Query filter) {
+        super(field, target, numCands, filter);
+        this.kParam = k;
     }
 
     @Override
     protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
-        TopDocs topK = super.mergeLeafResults(perLeafResults);
+        // if k param is set, we get only top k results from each shard
+        TopDocs topK = kParam == null ? super.mergeLeafResults(perLeafResults) : TopDocs.merge(kParam, perLeafResults);
         vectorOpsCount = topK.totalHits.value;
         return topK;
     }

+ 1 - 1
server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

@@ -398,7 +398,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         if (queryVectorBuilder != null) {
             throw new IllegalArgumentException("missing rewrite");
         }
-        return new KnnVectorQueryBuilder(field, queryVector, numCands, similarity).boost(boost)
+        return new KnnVectorQueryBuilder(field, queryVector, null, numCands, similarity).boost(boost)
             .queryName(queryName)
             .addFilterQueries(filterQueries);
     }

+ 1 - 1
server/src/main/java/org/elasticsearch/search/vectors/KnnSearchRequestParser.java

@@ -255,7 +255,7 @@ public class KnnSearchRequestParser {
             if (numCands > NUM_CANDS_LIMIT) {
                 throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
             }
-            return new KnnVectorQueryBuilder(field, queryVector, numCands, null);
+            return new KnnVectorQueryBuilder(field, queryVector, null, numCands, null);
         }
 
         @Override

+ 61 - 17
server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

@@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.search.Queries;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.NestedObjectMapper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -52,11 +53,14 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
  * {@link org.apache.lucene.search.KnnByteVectorQuery}.
  */
 public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
+    public static final NodeFeature K_PARAM_SUPPORTED = new NodeFeature("search.vectors.k_param_supported");
+
     public static final String NAME = "knn";
     private static final int NUM_CANDS_LIMIT = 10_000;
     private static final float NUM_CANDS_MULTIPLICATIVE_FACTOR = 1.5f;
 
     public static final ParseField FIELD_FIELD = new ParseField("field");
+    public static final ParseField K_FIELD = new ParseField("k");
     public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
     public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
     public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
@@ -69,10 +73,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         args -> new KnnVectorQueryBuilder(
             (String) args[0],
             (VectorData) args[1],
-            (QueryVectorBuilder) args[4],
+            (QueryVectorBuilder) args[5],
             null,
             (Integer) args[2],
-            (Float) args[3]
+            (Integer) args[3],
+            (Float) args[4]
         )
     );
 
@@ -84,6 +89,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             QUERY_VECTOR_FIELD,
             ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER
         );
+        PARSER.declareInt(optionalConstructorArg(), K_FIELD);
         PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
         PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD);
         PARSER.declareNamedObject(
@@ -106,26 +112,33 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
 
     private final String fieldName;
     private final VectorData queryVector;
+    private final Integer k;
     private Integer numCands;
     private final List<QueryBuilder> filterQueries = new ArrayList<>();
     private final Float vectorSimilarity;
     private final QueryVectorBuilder queryVectorBuilder;
     private final Supplier<float[]> queryVectorSupplier;
 
-    public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) {
-        this(fieldName, VectorData.fromFloats(queryVector), null, null, numCands, vectorSimilarity);
+    public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
+        this(fieldName, VectorData.fromFloats(queryVector), null, null, k, numCands, vectorSimilarity);
     }
 
-    protected KnnVectorQueryBuilder(String fieldName, QueryVectorBuilder queryVectorBuilder, Integer numCands, Float vectorSimilarity) {
-        this(fieldName, null, queryVectorBuilder, null, numCands, vectorSimilarity);
+    protected KnnVectorQueryBuilder(
+        String fieldName,
+        QueryVectorBuilder queryVectorBuilder,
+        Integer k,
+        Integer numCands,
+        Float vectorSimilarity
+    ) {
+        this(fieldName, null, queryVectorBuilder, null, k, numCands, vectorSimilarity);
     }
 
-    public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer numCands, Float vectorSimilarity) {
-        this(fieldName, VectorData.fromBytes(queryVector), null, null, numCands, vectorSimilarity);
+    public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
+        this(fieldName, VectorData.fromBytes(queryVector), null, null, k, numCands, vectorSimilarity);
     }
 
-    public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer numCands, Float vectorSimilarity) {
-        this(fieldName, queryVector, null, null, numCands, vectorSimilarity);
+    public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer k, Integer numCands, Float vectorSimilarity) {
+        this(fieldName, queryVector, null, null, k, numCands, vectorSimilarity);
     }
 
     private KnnVectorQueryBuilder(
@@ -133,12 +146,21 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         VectorData queryVector,
         QueryVectorBuilder queryVectorBuilder,
         Supplier<float[]> queryVectorSupplier,
+        Integer k,
         Integer numCands,
         Float vectorSimilarity
     ) {
+        if (k != null && k < 1) {
+            throw new IllegalArgumentException("[" + K_FIELD.getPreferredName() + "] must be greater than 0");
+        }
         if (numCands != null && numCands > NUM_CANDS_LIMIT) {
             throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
         }
+        if (k != null && numCands != null && numCands < k) {
+            throw new IllegalArgumentException(
+                "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot be less than [" + K_FIELD.getPreferredName() + "]"
+            );
+        }
         if (queryVector == null && queryVectorBuilder == null) {
             throw new IllegalArgumentException(
                 format(
@@ -158,6 +180,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         }
         this.fieldName = fieldName;
         this.queryVector = queryVector;
+        this.k = k;
         this.numCands = numCands;
         this.vectorSimilarity = vectorSimilarity;
         this.queryVectorBuilder = queryVectorBuilder;
@@ -167,6 +190,11 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
     public KnnVectorQueryBuilder(StreamInput in) throws IOException {
         super(in);
         this.fieldName = in.readString();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.K_FOR_KNN_QUERY_ADDED)) {
+            this.k = in.readOptionalVInt();
+        } else {
+            this.k = null;
+        }
         if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
             this.numCands = in.readOptionalVInt();
         } else {
@@ -214,6 +242,10 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         return vectorSimilarity;
     }
 
+    public Integer k() {
+        return k;
+    }
+
     public Integer numCands() {
         return numCands;
     }
@@ -245,6 +277,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             throw new IllegalStateException("missing a rewriteAndFetch?");
         }
         out.writeString(fieldName);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.K_FOR_KNN_QUERY_ADDED)) {
+            out.writeOptionalVInt(k);
+        }
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
             out.writeOptionalVInt(numCands);
         } else {
@@ -302,6 +337,9 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         if (queryVector != null) {
             builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
         }
+        if (k != null) {
+            builder.field(K_FIELD.getPreferredName(), k);
+        }
         if (numCands != null) {
             builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
         }
@@ -335,7 +373,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             if (queryVectorSupplier.get() == null) {
                 return this;
             }
-            return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), numCands, vectorSimilarity).boost(boost)
+            return new KnnVectorQueryBuilder(fieldName, queryVectorSupplier.get(), k, numCands, vectorSimilarity).boost(boost)
                 .queryName(queryName)
                 .addFilterQueries(filterQueries);
         }
@@ -357,7 +395,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
                 }
                 ll.onResponse(null);
             })));
-            return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, numCands, vectorSimilarity).boost(
+            return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, toSet::get, k, numCands, vectorSimilarity).boost(
                 boost
             ).queryName(queryName).addFilterQueries(filterQueries);
         }
@@ -377,7 +415,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
             rewrittenQueries.add(rewrittenQuery);
         }
         if (changed) {
-            return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, numCands, vectorSimilarity)
+            return new KnnVectorQueryBuilder(fieldName, queryVector, queryVectorBuilder, queryVectorSupplier, k, numCands, vectorSimilarity)
                 .boost(boost)
                 .queryName(queryName)
                 .addFilterQueries(rewrittenQueries);
@@ -388,7 +426,12 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
     @Override
     protected Query doToQuery(SearchExecutionContext context) throws IOException {
         MappedFieldType fieldType = context.getFieldType(fieldName);
-        int requestSize = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize();
+        int requestSize;
+        if (k != null) {
+            requestSize = k;
+        } else {
+            requestSize = context.requestSize() == null || context.requestSize() < 0 ? DEFAULT_SIZE : context.requestSize();
+        }
         int adjustedNumCands = numCands == null
             ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * requestSize, NUM_CANDS_LIMIT))
             : numCands;
@@ -446,20 +489,21 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
                 // Now join the filterQuery & parentFilter to provide the matching blocks of children
                 filterQuery = new ToChildBlockJoinQuery(filterQuery, parentBitSet);
             }
-            return vectorFieldType.createKnnQuery(queryVector, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet);
+            return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, parentBitSet);
         }
-        return vectorFieldType.createKnnQuery(queryVector, adjustedNumCands, filterQuery, vectorSimilarity, null);
+        return vectorFieldType.createKnnQuery(queryVector, k, adjustedNumCands, filterQuery, vectorSimilarity, null);
     }
 
     @Override
     protected int doHashCode() {
-        return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity, queryVectorBuilder);
+        return Objects.hash(fieldName, Objects.hashCode(queryVector), k, numCands, filterQueries, vectorSimilarity, queryVectorBuilder);
     }
 
     @Override
     protected boolean doEquals(KnnVectorQueryBuilder other) {
         return Objects.equals(fieldName, other.fieldName)
             && Objects.equals(queryVector, other.queryVector)
+            && Objects.equals(k, other.k)
             && Objects.equals(numCands, other.numCands)
             && Objects.equals(filterQueries, other.filterQueries)
             && Objects.equals(vectorSimilarity, other.vectorSimilarity)

+ 1 - 0
server/src/main/resources/META-INF/services/org.elasticsearch.features.FeatureSpecification

@@ -14,6 +14,7 @@ org.elasticsearch.rest.RestFeatures
 org.elasticsearch.indices.IndicesFeatures
 org.elasticsearch.action.admin.cluster.allocation.AllocationStatsFeatures
 org.elasticsearch.index.mapper.MapperFeatures
+org.elasticsearch.search.SearchFeatures
 org.elasticsearch.search.retriever.RetrieversFeatures
 org.elasticsearch.script.ScriptFeatures
 org.elasticsearch.reservedstate.service.FileSettingsFeatures

+ 1 - 1
server/src/test/java/org/elasticsearch/action/search/KnnSearchSingleNodeTests.java

@@ -416,7 +416,7 @@ public class KnnSearchSingleNodeTests extends ESSingleNodeTestCase {
         // how the action works (it builds a kNN query under the hood)
         float[] queryVector = randomVector();
         assertResponse(
-            client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, 5, null)).setSize(2),
+            client().prepareSearch("index1", "index2").setQuery(new KnnVectorQueryBuilder("vector", queryVector, null, 5, null)).setSize(2),
             response -> {
                 // The total hits is num_cands * num_shards, since the query gathers num_cands hits from each shard
                 assertHitCount(response, 5 * 2);

+ 10 - 10
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

@@ -1126,7 +1126,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         Exception e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 128, 0, 0 }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 128, 0, 0 }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -1135,7 +1135,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0f, -129.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0f, -129.0f }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -1144,7 +1144,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0.5f, 0.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0.0f, 0.5f, 0.0f }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -1153,7 +1153,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, 0.0f, -0.25f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, 0.0f, -0.25f }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -1162,13 +1162,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, 3, null, null, null)
         );
         assertThat(e.getMessage(), containsString("element_type [byte] vectors do not support NaN values but found [NaN] at dim [0];"));
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -1177,7 +1177,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -1203,13 +1203,13 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         Exception e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.NaN, 0f, 0.0f }, 3, 3, null, null, null)
         );
         assertThat(e.getMessage(), containsString("element_type [float] vectors do not support NaN values but found [NaN] at dim [0];"));
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { Float.POSITIVE_INFINITY, 0f, 0.0f }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),
@@ -1218,7 +1218,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, null, null, null)
+            () -> denseVectorFieldType.createKnnQuery(new float[] { 0, Float.NEGATIVE_INFINITY, 0.0f }, 3, 3, null, null, null)
         );
         assertThat(
             e.getMessage(),

+ 11 - 11
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

@@ -165,7 +165,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             for (int i = 0; i < dims; i++) {
                 queryVector[i] = randomFloat();
             }
-            Query query = field.createKnnQuery(queryVector, 10, null, null, producer);
+            Query query = field.createKnnQuery(queryVector, 10, 10, null, null, producer);
             assertThat(query, instanceOf(DiversifyingChildrenFloatKnnVectorQuery.class));
         }
         {
@@ -186,11 +186,11 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 floatQueryVector[i] = queryVector[i];
             }
             VectorData vectorData = new VectorData(null, queryVector);
-            Query query = field.createKnnQuery(vectorData, 10, null, null, producer);
+            Query query = field.createKnnQuery(vectorData, 10, 10, null, null, producer);
             assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
 
             vectorData = new VectorData(floatQueryVector, null);
-            query = field.createKnnQuery(vectorData, 10, null, null, producer);
+            query = field.createKnnQuery(vectorData, 10, 10, null, null, producer);
             assertThat(query, instanceOf(DiversifyingChildrenByteKnnVectorQuery.class));
         }
     }
@@ -251,7 +251,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
-            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }, 10, null, null, null)
+            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }, 10, 10, null, null, null)
         );
         assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
 
@@ -267,7 +267,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }, 10, null, null, null)
+            () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f, 0.0f }, 10, 10, null, null, null)
         );
         assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
 
@@ -283,7 +283,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f, 0.0f }, 10, null, null, null)
+            () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f, 0.0f }, 10, 10, null, null, null)
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
     }
@@ -304,7 +304,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             for (int i = 0; i < 4096; i++) {
                 queryVector[i] = randomFloat();
             }
-            Query query = fieldWith4096dims.createKnnQuery(queryVector, 10, null, null, null);
+            Query query = fieldWith4096dims.createKnnQuery(queryVector, 10, 10, null, null, null);
             assertThat(query, instanceOf(KnnFloatVectorQuery.class));
         }
 
@@ -324,7 +324,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 queryVector[i] = randomByte();
             }
             VectorData vectorData = new VectorData(null, queryVector);
-            Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, null, null, null);
+            Query query = fieldWith4096dims.createKnnQuery(vectorData, 10, 10, null, null, null);
             assertThat(query, instanceOf(KnnByteVectorQuery.class));
         }
     }
@@ -342,7 +342,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         IllegalArgumentException e = expectThrows(
             IllegalArgumentException.class,
-            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null, null, null)
+            () -> unindexedField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, 10, null, null, null)
         );
         assertThat(e.getMessage(), containsString("to perform knn search on field [f], its mapping must have [index] set to [true]"));
 
@@ -358,13 +358,13 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null, null, null)
+            () -> cosineField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, 10, null, null, null)
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
 
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, null, null, null)
+            () -> cosineField.createKnnQuery(new VectorData(null, new byte[] { 0, 0, 0 }), 10, 10, null, null, null)
         );
         assertThat(e.getMessage(), containsString("The [cosine] similarity does not support vectors with zero magnitude."));
     }

+ 1 - 0
server/src/test/java/org/elasticsearch/index/query/NestedQueryBuilderTests.java

@@ -267,6 +267,7 @@ public class NestedQueryBuilderTests extends AbstractQueryTestCase<NestedQueryBu
         KnnVectorQueryBuilder innerQueryBuilder = new KnnVectorQueryBuilder(
             "nested1." + VECTOR_FIELD,
             new float[] { 1.0f, 2.0f, 3.0f },
+            null,
             1,
             null
         );

+ 53 - 13
server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

@@ -52,7 +52,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 
     abstract DenseVectorFieldMapper.ElementType elementType();
 
-    abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity);
+    abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity);
 
     @Override
     protected void initializeAdditionalMappings(MapperService mapperService) throws IOException {
@@ -82,8 +82,9 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
     @Override
     protected KnnVectorQueryBuilder doCreateTestQueryBuilder() {
         String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD;
-        int numCands = randomIntBetween(DEFAULT_SIZE, 1000);
-        KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, numCands, randomBoolean() ? null : randomFloat());
+        Integer k = randomBoolean() ? null : randomIntBetween(1, 100);
+        int numCands = randomIntBetween(k == null ? DEFAULT_SIZE : k + 20, 1000);
+        KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, k, numCands, randomFloat());
 
         if (randomBoolean()) {
             List<QueryBuilder> filters = new ArrayList<>();
@@ -125,12 +126,14 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
             case BYTE, BIT -> new ESKnnByteVectorQuery(
                 VECTOR_FIELD,
                 queryBuilder.queryVector().asByteVector(),
+                queryBuilder.k(),
                 queryBuilder.numCands(),
                 filterQuery
             );
             case FLOAT -> new ESKnnFloatVectorQuery(
                 VECTOR_FIELD,
                 queryBuilder.queryVector().asFloatVector(),
+                queryBuilder.k(),
                 queryBuilder.numCands(),
                 filterQuery
             );
@@ -143,7 +146,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 
     public void testWrongDimension() {
         SearchExecutionContext context = createSearchExecutionContext();
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 10, null);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f }, 5, 10, null);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
         assertThat(
             e.getMessage(),
@@ -153,7 +156,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 
     public void testNonexistentField() {
         SearchExecutionContext context = createSearchExecutionContext();
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 10, null);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null);
         IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context));
         assertThat(e.getMessage(), containsString("field [nonexistent] does not exist in the mapping"));
     }
@@ -163,6 +166,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(
             AbstractBuilderTestCase.KEYWORD_FIELD_NAME,
             new float[] { 1.0f, 1.0f, 1.0f },
+            5,
             10,
             null
         );
@@ -170,9 +174,19 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         assertThat(e.getMessage(), containsString("[knn] queries are only supported on [dense_vector] fields"));
     }
 
+    public void testNumCandsLessThanK() {
+        int k = 5;
+        int numCands = 3;
+        IllegalArgumentException e = expectThrows(
+            IllegalArgumentException.class,
+            () -> new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 1.0f, 1.0f }, k, numCands, null)
+        );
+        assertThat(e.getMessage(), containsString("[num_candidates] cannot be less than [k]"));
+    }
+
     @Override
     public void testValidOutput() {
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 10, null);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, null, 10, null);
         String expected = """
             {
               "knn" : {
@@ -186,6 +200,22 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
               }
             }""";
         assertEquals(expected, query.toString());
+
+        KnnVectorQueryBuilder query2 = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, 5, 10, null);
+        String expected2 = """
+            {
+              "knn" : {
+                "field" : "vector",
+                "query_vector" : [
+                  1.0,
+                  2.0,
+                  3.0
+                ],
+                "k" : 5,
+                "num_candidates" : 10
+              }
+            }""";
+        assertEquals(expected2, query2.toString());
     }
 
     @Override
@@ -193,7 +223,13 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         SearchExecutionContext context = createSearchExecutionContext();
         context.setAllowUnmappedFields(true);
         TermQueryBuilder termQuery = new TermQueryBuilder("unmapped_field", 42);
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(VECTOR_FIELD, new float[] { 1.0f, 2.0f, 3.0f }, VECTOR_DIMENSION, null);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder(
+            VECTOR_FIELD,
+            new float[] { 1.0f, 2.0f, 3.0f },
+            VECTOR_DIMENSION,
+            null,
+            null
+        );
         query.addFilterQuery(termQuery);
 
         IllegalStateException e = expectThrows(IllegalStateException.class, () -> query.toQuery(context));
@@ -206,7 +242,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
     public void testBWCVersionSerializationFilters() throws IOException {
         KnnVectorQueryBuilder query = createTestQueryBuilder();
         VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
-        KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), null)
+        KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null)
             .queryName(query.queryName())
             .boost(query.boost());
         TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween(
@@ -220,7 +256,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
     public void testBWCVersionSerializationSimilarity() throws IOException {
         KnnVectorQueryBuilder query = createTestQueryBuilder();
         VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
-        KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), null)
+        KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, null, query.numCands(), null)
             .queryName(query.queryName())
             .boost(query.boost())
             .addFilterQueries(query.filterQueries());
@@ -236,10 +272,13 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         );
         Float similarity = differentQueryVersion.before(TransportVersions.V_8_8_0) ? null : query.getVectorSimilarity();
         VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
-        KnnVectorQueryBuilder queryOlderVersion = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), similarity)
-            .queryName(query.queryName())
-            .boost(query.boost())
-            .addFilterQueries(query.filterQueries());
+        KnnVectorQueryBuilder queryOlderVersion = new KnnVectorQueryBuilder(
+            query.getFieldName(),
+            vectorData,
+            null,
+            query.numCands(),
+            similarity
+        ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
         assertBWCSerialization(query, queryOlderVersion, differentQueryVersion);
     }
 
@@ -266,6 +305,7 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         KnnVectorQueryBuilder knnVectorQueryBuilder = new KnnVectorQueryBuilder(
             "field",
             new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(expectedArray),
+            null,
             5,
             1f
         );

+ 2 - 2
server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java

@@ -17,11 +17,11 @@ public class KnnByteVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilde
     }
 
     @Override
-    protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity) {
+    protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) {
         byte[] vector = new byte[VECTOR_DIMENSION];
         for (int i = 0; i < vector.length; i++) {
             vector[i] = randomByte();
         }
-        return new KnnVectorQueryBuilder(fieldName, vector, numCands, similarity);
+        return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity);
     }
 }

+ 2 - 2
server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java

@@ -17,11 +17,11 @@ public class KnnFloatVectorQueryBuilderTests extends AbstractKnnVectorQueryBuild
     }
 
     @Override
-    KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity) {
+    KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, Integer k, int numCands, Float similarity) {
         float[] vector = new float[VECTOR_DIMENSION];
         for (int i = 0; i < vector.length; i++) {
             vector[i] = randomFloat();
         }
-        return new KnnVectorQueryBuilder(fieldName, vector, numCands, similarity);
+        return new KnnVectorQueryBuilder(fieldName, vector, k, numCands, similarity);
     }
 }

+ 2 - 1
server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

@@ -166,7 +166,8 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
             builder.addFilterQuery(filter);
         }
 
-        QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, numCands, similarity).addFilterQueries(filterQueries).boost(boost);
+        QueryBuilder expected = new KnnVectorQueryBuilder(field, vector, null, numCands, similarity).addFilterQueries(filterQueries)
+            .boost(boost);
         assertEquals(expected, builder.toQueryBuilder());
     }
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

@@ -427,7 +427,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
                             );
                         }
 
-                        yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null);
+                        yield new KnnVectorQueryBuilder(inferenceResultsFieldName, inference, null, null, null);
                     }
                     default -> throw new IllegalStateException(
                         "Field ["

+ 1 - 1
x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/DocumentLevelSecurityTests.java

@@ -884,7 +884,7 @@ public class DocumentLevelSecurityTests extends SecurityIntegTestCase {
         // Since there's no kNN search action at the transport layer, we just emulate
         // how the action works (it builds a kNN query under the hood)
         float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f };
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50, null);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 50, 50, null);
 
         if (randomBoolean()) {
             query.addFilterQuery(new WildcardQueryBuilder("other", "value*"));

+ 3 - 3
x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/integration/FieldLevelSecurityTests.java

@@ -441,7 +441,7 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
         // Since there's no kNN search action at the transport layer, we just emulate
         // how the action works (it builds a kNN query under the hood)
         float[] queryVector = new float[] { 0.0f, 0.0f, 0.0f };
-        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10, null);
+        KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null);
 
         // user1 has access to vector field, so the query should match with the document:
         assertResponse(
@@ -475,7 +475,7 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
             }
         );
         // user1 can access field1, so the filtered query should match with the document:
-        KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10, null).addFilterQuery(
+        KnnVectorQueryBuilder filterQuery1 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null).addFilterQuery(
             QueryBuilders.matchQuery("field1", "value1")
         );
         assertHitCount(
@@ -486,7 +486,7 @@ public class FieldLevelSecurityTests extends SecurityIntegTestCase {
         );
 
         // user1 cannot access field2, so the filtered query should not match with the document:
-        KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10, null).addFilterQuery(
+        KnnVectorQueryBuilder filterQuery2 = new KnnVectorQueryBuilder("vector", queryVector, 10, 10, null).addFilterQuery(
             QueryBuilders.matchQuery("field2", "value2")
         );
         assertHitCount(